
Source code for mmtrack.models.sot.mixformer

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.core.bbox.transforms import bbox_xyxy_to_cxcywh
from mmdet.models.builder import build_backbone, build_head
from torchvision.transforms.functional import normalize

from ..builder import MODELS
from .stark import Stark


[docs]@MODELS.register_module()
class MixFormer(Stark):
    """MixFormer: End-to-End Tracking with Iterative Mixed Attention.

    This single object tracker is the implementation of
    `MixFormer <https://arxiv.org/abs/2203.11082>`_.
    """

    def __init__(self,
                 backbone,
                 head=None,
                 init_cfg=None,
                 frozen_modules=None,
                 train_cfg=None,
                 test_cfg=None):
        super(Stark, self).__init__(init_cfg)
        self.backbone = build_backbone(backbone)
        self.head = build_head(head)
        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        self.update_interval = self.test_cfg['update_interval'][0]
        self.online_size = self.test_cfg['online_size'][0]
        self.max_score_decay = self.test_cfg['max_score_decay'][0]
        if frozen_modules is not None:
            self.freeze_module(frozen_modules)

    def set_online(self, template, online_template):
        self.backbone.set_online(template, online_template)
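The constructor reads three list-valued keys from `test_cfg`, and the crop helpers used by the methods below read four more. A minimal sketch of a `test_cfg` covering exactly those keys; the values here are illustrative placeholders, not the shipped mmtracking config:

# Illustrative values only -- not the shipped config.
test_cfg = dict(
    update_interval=[200],   # refresh the online templates every N frames
    online_size=[1],         # number of dynamic templates kept online
    max_score_decay=[1.0],   # per-frame decay applied to best_conf_score
    template_factor=2.0,     # context scale of the template crop
    template_size=128,       # template crop resolution
    search_factor=4.5,       # context scale of the search crop
    search_size=320,         # search crop resolution
)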
[docs]    def init(self, img, bbox):
        """Initialize the single object tracker in the first frame.

        Args:
            img (Tensor): input image of shape (1, C, H, W).
            bbox (list | Tensor): in [cx, cy, w, h] format.
        """
        self.z_dict_list = []  # store templates
        # get the 1st template
        z_patch, _, z_mask = self.get_cropped_img(
            img, bbox, self.test_cfg['template_factor'],
            self.test_cfg['template_size']
        )  # z_patch of shape [1, C, H, W]; z_mask of shape [1, H, W]
        z_patch = normalize(
            z_patch.squeeze() / 255.,
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]).unsqueeze(0)
        with torch.no_grad():
            self.set_online(z_patch, z_patch)

        self.template = z_patch
        self.online_template = z_patch
        self.best_online_template = z_patch
        self.best_conf_score = -1.0
        self.online_forget_id = 0
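A standalone sketch, with assumed sizes, of the normalization applied to the template crop above: the crop comes back as a (1, C, H, W) tensor in the 0-255 range, is squeezed, scaled to [0, 1], normalized with ImageNet statistics, and re-batched.

import torch
from torchvision.transforms.functional import normalize

z_patch = torch.randint(0, 256, (1, 3, 128, 128)).float()  # dummy crop
z_patch = normalize(
    z_patch.squeeze() / 255.,
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]).unsqueeze(0)
print(z_patch.shape)  # torch.Size([1, 3, 128, 128])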
[docs]    def update_template(self, img, bbox, conf_score):
        """Update the dynamic templates.

        Args:
            img (Tensor): of shape (1, C, H, W).
            bbox (list | ndarray): in [cx, cy, w, h] format.
            conf_score (float): the confidence score of the predicted bbox.
        """
        if conf_score > 0.5 and conf_score > self.best_conf_score:
            z_patch, _, z_mask = self.get_cropped_img(
                img,
                bbox,
                self.test_cfg['template_factor'],
                output_size=self.test_cfg['template_size'],
            )
            z_patch = normalize(
                z_patch.squeeze() / 255.,
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ).unsqueeze(0)
            self.best_online_template = z_patch
            self.best_conf_score = conf_score

        if self.frame_id % self.update_interval == 0:
            if self.online_size == 1:
                self.online_template = self.best_online_template
            elif self.online_template.shape[0] < self.online_size:
                self.online_template = torch.cat(
                    [self.online_template, self.best_online_template])
            else:
                self.online_template[self.online_forget_id:
                                     self.online_forget_id +
                                     1] = self.best_online_template
                self.online_forget_id = (self.online_forget_id +
                                         1) % self.online_size

            with torch.no_grad():
                self.set_online(self.template, self.online_template)

            self.best_conf_score = -1
            self.best_online_template = self.template
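The branchy buffer update above is round-robin replacement: the template store grows until it holds `online_size` entries, after which `online_forget_id` walks through the slots and overwrites the oldest one. A self-contained illustration with assumed shapes:

import torch

online_size = 3
online_template = torch.zeros(0, 3, 128, 128)  # empty buffer of templates
online_forget_id = 0

for step in range(5):
    new_template = torch.full((1, 3, 128, 128), float(step))
    if online_template.shape[0] < online_size:
        online_template = torch.cat([online_template, new_template])
    else:
        online_template[online_forget_id:online_forget_id + 1] = new_template
        online_forget_id = (online_forget_id + 1) % online_size

print(online_template[:, 0, 0, 0])  # tensor([3., 4., 2.]): slots 0, 1 recycled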
[docs]    def track(self, img, bbox):
        """Track the box `bbox` of previous frame to current frame `img`.

        Args:
            img (Tensor): of shape (1, C, H, W).
            bbox (list | Tensor): The bbox in previous frame. The shape of
                the bbox is (4, ) in [x, y, w, h] format.
        """
        H, W = img.shape[2:]
        # get the t-th search region
        x_patch, resize_factor, x_mask = self.get_cropped_img(
            img, bbox, self.test_cfg['search_factor'],
            self.test_cfg['search_size']
        )  # bbox: of shape (x1, y1, w, h), x_mask: of shape (1, h, w)
        x_patch = normalize(
            x_patch.squeeze() / 255.,
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]).unsqueeze(0)

        with torch.no_grad():
            template, search = self.backbone.forward_test(x_patch)
            out_dict = self.head(template, search)

        pred_box = out_dict['pred_bboxes']  # xyxy
        pred_box = self.mapping_bbox_back(pred_box, self.memo.bbox,
                                          resize_factor)
        pred_box = self._bbox_clip(pred_box, H, W, margin=10)

        # update template
        self.best_conf_score = self.best_conf_score * self.max_score_decay
        conf_score = -1.
        if self.head.score_decoder_head is not None:
            # get confidence score (whether the search region is reliable)
            conf_score = out_dict['pred_scores'].view(1).sigmoid().item()
            crop_bbox = bbox_xyxy_to_cxcywh(pred_box)
            self.update_template(img, crop_bbox, conf_score)

        return conf_score, pred_box
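The head predicts boxes in [x1, y1, x2, y2], while `update_template` expects [cx, cy, w, h]; `bbox_xyxy_to_cxcywh` from mmdet does that conversion at the end of `track`. A quick check:

import torch
from mmdet.core.bbox.transforms import bbox_xyxy_to_cxcywh

pred_box = torch.tensor([[10., 20., 50., 100.]])  # [x1, y1, x2, y2]
print(bbox_xyxy_to_cxcywh(pred_box))  # tensor([[30., 60., 40., 80.]])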
[docs]    def forward_train(self, imgs, img_metas, search_img,
                      search_img_metas, **kwargs):
        """Forward function of training.

        Args:
            imgs (Tensor): template images of shape
                (N, num_templates, C, H, W). Typically, there are 2 template
                images, and H and W are both equal to 128.
            img_metas (list[dict]): list of image information dict where
                each dict has: 'img_shape', 'scale_factor', 'flip', and may
                also contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details on the values of these keys see
                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.
            search_img (Tensor): of shape (N, 1, C, H, W) encoding input
                search images. 1 denotes there is only one search image for
                each exemplar image. Typically H and W are both equal
                to 320.
            search_img_metas (list[list[dict]]): The second list only has
                one element. The first list contains search image
                information dict where each dict has: 'img_shape',
                'scale_factor', 'flip', and may also contain 'filename',
                'ori_shape', 'pad_shape' and 'img_norm_cfg'. For details on
                the values of these keys see
                `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for template
                images with shape (N, 4) in [tl_x, tl_y, br_x, br_y]
                format.
            padding_mask (Tensor): padding mask of template images. It's of
                shape (N, num_templates, H, W). Typically, there are 2
                padding masks of template images, and H and W are both
                equal to those of the template images.
            search_gt_bboxes (list[Tensor]): Ground truth bboxes for search
                images with shape (N, 5) in [0., tl_x, tl_y, br_x, br_y]
                format.
            search_padding_mask (Tensor): padding mask of search images.
                It's of shape (N, 1, H, W). There is 1 padding mask per
                search image, and H and W are both equal to those of the
                search image.
            search_gt_labels (list[Tensor], optional): Ground truth labels
                for search images with shape (N, 2).

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        template, online_template = imgs[:, 0, ...], imgs[:, 1, ...]
        search = search_img.squeeze(1)
        template, search = self.backbone(template, online_template, search)
        # box head
        out_dict = self.head(template, search, **kwargs)
        # compute loss
        return out_dict
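A shape walk-through (assuming N=2 and the typical 128/320 resolutions from the docstring) of the tensor handling at the top of `forward_train`: the two templates are split out of the `num_templates` axis and the singleton search axis is squeezed away.

import torch

imgs = torch.randn(2, 2, 3, 128, 128)        # (N, num_templates, C, H, W)
search_img = torch.randn(2, 1, 3, 320, 320)  # (N, 1, C, H, W)

template, online_template = imgs[:, 0, ...], imgs[:, 1, ...]
search = search_img.squeeze(1)
print(template.shape, search.shape)
# torch.Size([2, 3, 128, 128]) torch.Size([2, 3, 320, 320])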