Shortcuts

mmtrack.models.vid.fgfa 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
from addict import Dict
from mmdet.core import bbox2result
from mmdet.models import build_detector

from mmtrack.core import flow_warp_feats
from ..builder import MODELS, build_aggregator, build_motion
from .base import BaseVideoDetector


[文档]@MODELS.register_module() class FGFA(BaseVideoDetector): """Flow-Guided Feature Aggregation for Video Object Detection. This video object detector is the implementation of `FGFA <https://arxiv.org/abs/1703.10025>`_. """ def __init__(self, detector, motion, aggregator, pretrains=None, init_cfg=None, frozen_modules=None, train_cfg=None, test_cfg=None): super(FGFA, self).__init__(init_cfg) if isinstance(pretrains, dict): warnings.warn('DeprecationWarning: pretrains is deprecated, ' 'please use "init_cfg" instead') motion_pretrain = pretrains.get('motion', None) if motion_pretrain: motion.init_cfg = dict( type='Pretrained', checkpoint=motion_pretrain) else: motion.init_cfg = None detector_pretrain = pretrains.get('detector', None) if detector_pretrain: detector.init_cfg = dict( type='Pretrained', checkpoint=detector_pretrain) else: detector.init_cfg = None self.detector = build_detector(detector) self.motion = build_motion(motion) self.aggregator = build_aggregator(aggregator) self.train_cfg = train_cfg self.test_cfg = test_cfg if frozen_modules is not None: self.freeze_module(frozen_modules)
[文档] def forward_train(self, img, img_metas, gt_bboxes, gt_labels, ref_img, ref_img_metas, ref_gt_bboxes, ref_gt_labels, gt_instance_ids=None, gt_bboxes_ignore=None, gt_masks=None, proposals=None, ref_gt_instance_ids=None, ref_gt_bboxes_ignore=None, ref_gt_masks=None, ref_proposals=None, **kwargs): """ Args: img (Tensor): of shape (N, C, H, W) encoding input images. Typically these should be mean centered and std scaled. img_metas (list[dict]): list of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmtrack/datasets/pipelines/formatting.py:VideoCollect`. gt_bboxes (list[Tensor]): Ground truth bboxes for each image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. gt_labels (list[Tensor]): class indices corresponding to each box. ref_img (Tensor): of shape (N, 2, C, H, W) encoding input images. Typically these should be mean centered and std scaled. 2 denotes there is two reference images for each input image. ref_img_metas (list[list[dict]]): The first list only has one element. The second list contains reference image information dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmtrack/datasets/pipelines/formatting.py:VideoCollect`. ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The Tensor contains ground truth bboxes for each reference image with shape (num_all_ref_gts, 5) in [ref_img_id, tl_x, tl_y, br_x, br_y] format. The ref_img_id start from 0, and denotes the id of reference image for each key image. ref_gt_labels (list[Tensor]): The list only has one Tensor. The Tensor contains class indices corresponding to each reference box with shape (num_all_ref_gts, 2) in [ref_img_id, class_indice]. gt_instance_ids (None | list[Tensor]): specify the instance id for each ground truth bbox. gt_bboxes_ignore (None | list[Tensor]): specify which bounding boxes can be ignored when computing the loss. gt_masks (None | Tensor) : true segmentation masks for each box used if the architecture supports a segmentation task. proposals (None | Tensor) : override rpn proposals with custom proposals. Use when `with_rpn` is False. ref_gt_instance_ids (None | list[Tensor]): specify the instance id for each ground truth bboxes of reference images. ref_gt_bboxes_ignore (None | list[Tensor]): specify which bounding boxes of reference images can be ignored when computing the loss. ref_gt_masks (None | Tensor) : True segmentation masks for each box of reference image used if the architecture supports a segmentation task. ref_proposals (None | Tensor) : override rpn proposals with custom proposals of reference images. Use when `with_rpn` is False. Returns: dict[str, Tensor]: a dictionary of loss components """ assert len(img) == 1, \ 'fgfa video detectors only support 1 batch size per gpu for now.' flow_imgs = torch.cat((img, ref_img[:, 0]), dim=1) for i in range(1, ref_img.shape[1]): flow_img = torch.cat((img, ref_img[:, i]), dim=1) flow_imgs = torch.cat((flow_imgs, flow_img), dim=0) flows = self.motion(flow_imgs, img_metas) all_imgs = torch.cat((img, ref_img[0]), dim=0) all_x = self.detector.extract_feat(all_imgs) x = [] for i in range(len(all_x)): ref_x_single = flow_warp_feats(all_x[i][1:], flows) agg_x_single = self.aggregator(all_x[i][[0]], ref_x_single) x.append(agg_x_single) losses = dict() # Two stage detector if hasattr(self.detector, 'roi_head'): # RPN forward and loss if self.detector.with_rpn: proposal_cfg = self.detector.train_cfg.get( 'rpn_proposal', self.detector.test_cfg.rpn) rpn_losses, proposal_list = \ self.detector.rpn_head.forward_train( x, img_metas, gt_bboxes, gt_labels=None, gt_bboxes_ignore=gt_bboxes_ignore, proposal_cfg=proposal_cfg) losses.update(rpn_losses) else: proposal_list = proposals roi_losses = self.detector.roi_head.forward_train( x, img_metas, proposal_list, gt_bboxes, gt_labels, gt_bboxes_ignore, gt_masks, **kwargs) losses.update(roi_losses) # Single stage detector elif hasattr(self.detector, 'bbox_head'): bbox_losses = self.detector.bbox_head.forward_train( x, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore) losses.update(bbox_losses) else: raise TypeError('detector must has roi_head or bbox_head.') return losses
[文档] def extract_feats(self, img, img_metas, ref_img, ref_img_metas): """Extract features for `img` during testing. Args: img (Tensor): of shape (1, C, H, W) encoding input image. Typically these should be mean centered and std scaled. img_metas (list[dict]): list of image information dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmtrack/datasets/pipelines/formatting.py:VideoCollect`. ref_img (Tensor | None): of shape (1, N, C, H, W) encoding input reference images. Typically these should be mean centered and std scaled. N denotes the number of reference images. There may be no reference images in some cases. ref_img_metas (list[list[dict]] | None): The first list only has one element. The second list contains image information dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmtrack/datasets/pipelines/formatting.py:VideoCollect`. There may be no reference images in some cases. Returns: list[Tensor]: Multi level feature maps of `img`. """ frame_id = img_metas[0].get('frame_id', -1) assert frame_id >= 0 num_left_ref_imgs = img_metas[0].get('num_left_ref_imgs', -1) frame_stride = img_metas[0].get('frame_stride', -1) # test with adaptive stride if frame_stride < 1: if frame_id == 0: self.memo = Dict() self.memo.img = ref_img[0] ref_x = self.detector.extract_feat(ref_img[0]) # 'tuple' object (e.g. the output of FPN) does not support # item assignment self.memo.feats = [] for i in range(len(ref_x)): self.memo.feats.append(ref_x[i]) x = self.detector.extract_feat(img) # test with fixed stride else: if frame_id == 0: self.memo = Dict() self.memo.img = ref_img[0] ref_x = self.detector.extract_feat(ref_img[0]) # 'tuple' object (e.g. the output of FPN) does not support # item assignment self.memo.feats = [] # the features of img is same as ref_x[i][[num_left_ref_imgs]] x = [] for i in range(len(ref_x)): self.memo.feats.append(ref_x[i]) x.append(ref_x[i][[num_left_ref_imgs]]) elif frame_id % frame_stride == 0: assert ref_img is not None x = [] ref_x = self.detector.extract_feat(ref_img[0]) for i in range(len(ref_x)): self.memo.feats[i] = torch.cat( (self.memo.feats[i], ref_x[i]), dim=0)[1:] x.append(self.memo.feats[i][[num_left_ref_imgs]]) self.memo.img = torch.cat((self.memo.img, ref_img[0]), dim=0)[1:] else: assert ref_img is None x = self.detector.extract_feat(img) flow_imgs = torch.cat( (img.repeat(self.memo.img.shape[0], 1, 1, 1), self.memo.img), dim=1) flows = self.motion(flow_imgs, img_metas) agg_x = [] for i in range(len(x)): agg_x_single = flow_warp_feats(self.memo.feats[i], flows) if frame_stride < 1: agg_x_single = torch.cat((x[i], agg_x_single), dim=0) else: agg_x_single[num_left_ref_imgs] = x[i] agg_x_single = self.aggregator(x[i], agg_x_single) agg_x.append(agg_x_single) return agg_x
[文档] def simple_test(self, img, img_metas, ref_img=None, ref_img_metas=None, proposals=None, rescale=False): """Test without augmentation. Args: img (Tensor): of shape (1, C, H, W) encoding input image. Typically these should be mean centered and std scaled. img_metas (list[dict]): list of image information dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmtrack/datasets/pipelines/formatting.py:VideoCollect`. ref_img (list[Tensor] | None): The list only contains one Tensor of shape (1, N, C, H, W) encoding input reference images. Typically these should be mean centered and std scaled. N denotes the number for reference images. There may be no reference images in some cases. ref_img_metas (list[list[list[dict]]] | None): The first and second list only has one element. The third list contains image information dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmtrack/datasets/pipelines/formatting.py:VideoCollect`. There may be no reference images in some cases. proposals (None | Tensor): Override rpn proposals with custom proposals. Use when `with_rpn` is False. Defaults to None. rescale (bool): If False, then returned bboxes and masks will fit the scale of img, otherwise, returned bboxes and masks will fit the scale of original image shape. Defaults to False. Returns: dict[str : list(ndarray)]: The detection results. """ if ref_img is not None: ref_img = ref_img[0] if ref_img_metas is not None: ref_img_metas = ref_img_metas[0] x = self.extract_feats(img, img_metas, ref_img, ref_img_metas) # Two stage detector if hasattr(self.detector, 'roi_head'): if proposals is None: proposal_list = self.detector.rpn_head.simple_test_rpn( x, img_metas) else: proposal_list = proposals outs = self.detector.roi_head.simple_test( x, proposal_list, img_metas, rescale=rescale) # Single stage detector elif hasattr(self.detector, 'bbox_head'): outs = self.bbox_head(x) bbox_list = self.bbox_head.get_bboxes( *outs, img_metas, rescale=rescale) # skip post-processing when exporting to ONNX if torch.onnx.is_in_onnx_export(): return bbox_list outs = [ bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) for det_bboxes, det_labels in bbox_list ] else: raise TypeError('detector must has roi_head or bbox_head.') results = dict() results['det_bboxes'] = outs[0] if len(outs) == 2: results['det_masks'] = outs[1] return results
[文档] def aug_test(self, imgs, img_metas, **kwargs): """Test function with test time augmentation.""" raise NotImplementedError
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.