Shortcuts

mmtrack.models.roi_heads.selsa_roi_head 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.core import bbox2result, bbox2roi
from mmdet.models import HEADS, StandardRoIHead


@HEADS.register_module()
class SelsaRoIHead(StandardRoIHead):
    """selsa roi head.

    Extends ``StandardRoIHead`` so that the box branch also receives RoI
    features extracted from reference images (``ref_x`` together with the
    reference proposals), in addition to the RoI features of the image
    being processed.
    """

    def forward_train(self,
                      x,
                      ref_x,
                      img_metas,
                      proposal_list,
                      ref_proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """Assign and sample proposals, then compute the head losses.

        Args:
            x (list[Tensor]): list of multi-level img features.
            ref_x (list[Tensor]): list of multi-level ref_img features.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposal_list (list[Tensors]): list of region proposals.
            ref_proposal_list (list[Tensors]): list of region proposals
                from ref_imgs.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            for i in range(num_imgs):
                assign_result = self.bbox_assigner.assign(
                    proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
                    gt_labels[i])
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    # keep a leading batch dim of 1 on each level's feature
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)

        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self._bbox_forward_train(x, ref_x, sampling_results,
                                                    ref_proposal_list,
                                                    gt_bboxes, gt_labels)
            losses.update(bbox_results['loss_bbox'])

        # mask head forward and loss
        # NOTE(review): this branch reads ``bbox_results``, which is only
        # assigned when ``self.with_bbox`` is true.
        if self.with_mask:
            mask_results = self._mask_forward_train(x, sampling_results,
                                                    bbox_results['bbox_feats'],
                                                    gt_masks, img_metas)
            # TODO: Support empty tensor input. #2280
            if mask_results['loss_mask'] is not None:
                losses.update(mask_results['loss_mask'])

        return losses

    def _bbox_forward(self, x, ref_x, rois, ref_rois):
        """Box head forward function used in both training and testing.

        Args:
            x (list[Tensor]): list of multi-level img features.
            ref_x (list[Tensor]): list of multi-level ref_img features.
            rois (Tensor): RoIs built from the proposals of the images.
            ref_rois (Tensor): RoIs built from the proposals of the ref_imgs.

        Returns:
            dict: has the keys 'cls_score', 'bbox_pred' and 'bbox_feats'.
        """
        # TODO: a more flexible way to decide which feature maps to use
        num_inputs = self.bbox_roi_extractor.num_inputs
        # The extractor is additionally handed the reference features when
        # pooling the RoIs of the current images.
        bbox_feats = self.bbox_roi_extractor(
            x[:num_inputs], rois, ref_feats=ref_x[:num_inputs])

        ref_bbox_feats = self.bbox_roi_extractor(ref_x[:num_inputs], ref_rois)

        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
            ref_bbox_feats = self.shared_head(ref_bbox_feats)

        cls_score, bbox_pred = self.bbox_head(bbox_feats, ref_bbox_feats)

        bbox_results = dict(
            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
        return bbox_results

    def _bbox_forward_train(self, x, ref_x, sampling_results,
                            ref_proposal_list, gt_bboxes, gt_labels):
        """Run forward function and calculate loss for box head in training.

        Args:
            x (list[Tensor]): list of multi-level img features.
            ref_x (list[Tensor]): list of multi-level ref_img features.
            sampling_results (list): per-image sampling results; their
                ``bboxes`` attribute supplies the RoIs.
            ref_proposal_list (list[Tensors]): list of region proposals
                from ref_imgs.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image.
            gt_labels (list[Tensor]): class indices corresponding to each box.

        Returns:
            dict: the output of ``_bbox_forward`` plus the key 'loss_bbox'.
        """
        rois = bbox2roi([res.bboxes for res in sampling_results])
        ref_rois = bbox2roi(ref_proposal_list)
        bbox_results = self._bbox_forward(x, ref_x, rois, ref_rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
                                                  gt_labels, self.train_cfg)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results

    def simple_test(self,
                    x,
                    ref_x,
                    proposals_list,
                    ref_proposals_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation.

        Args:
            x (list[Tensor]): list of multi-level img features.
            ref_x (list[Tensor]): list of multi-level ref_img features.
            proposals_list (list[Tensors]): list of region proposals.
            ref_proposals_list (list[Tensors]): list of region proposals
                from ref_imgs.
            img_metas (list[dict]): list of image info dict.
            proposals: not used by this method; kept in the signature.
            rescale (bool): forwarded to the bbox (and mask) test functions.

        Returns:
            list: the per-image bbox results when there is no mask head,
            otherwise a list of (bbox_result, mask_result) tuples.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        det_bboxes, det_labels = self.simple_test_bboxes(
            x,
            ref_x,
            proposals_list,
            ref_proposals_list,
            img_metas,
            self.test_cfg,
            rescale=rescale)

        bbox_results = [
            bbox2result(bboxes, labels, self.bbox_head.num_classes)
            for bboxes, labels in zip(det_bboxes, det_labels)
        ]

        if not self.with_mask:
            return bbox_results

        mask_results = self.simple_test_mask(
            x, img_metas, det_bboxes, det_labels, rescale=rescale)
        return list(zip(bbox_results, mask_results))

    def simple_test_bboxes(self,
                           x,
                           ref_x,
                           proposals,
                           ref_proposals,
                           img_metas,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation.

        Args:
            x (list[Tensor]): list of multi-level img features.
            ref_x (list[Tensor]): list of multi-level ref_img features.
            proposals (list[Tensors]): list of region proposals, one entry
                per image.
            ref_proposals (list[Tensors]): list of region proposals
                from ref_imgs.
            img_metas (list[dict]): list of image info dict; 'img_shape' and
                'scale_factor' are read from each entry.
            rcnn_test_cfg: test config passed to ``bbox_head.get_bboxes``.
            rescale (bool): passed to ``bbox_head.get_bboxes``.

        Returns:
            tuple[list, list]: per-image detected bboxes and labels.
        """
        rois = bbox2roi(proposals)
        ref_rois = bbox2roi(ref_proposals)
        bbox_results = self._bbox_forward(x, ref_x, rois, ref_rois)
        img_shapes = tuple(meta['img_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        # split batch bbox prediction back to each image
        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']
        num_proposals_per_img = tuple(len(p) for p in proposals)
        rois = rois.split(num_proposals_per_img, 0)
        cls_score = cls_score.split(num_proposals_per_img, 0)
        # some detector with_reg is False, bbox_pred will be None
        if bbox_pred is not None:
            bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
        else:
            # One placeholder per image, so that ``bbox_pred[i]`` below is
            # valid for any number of images (not only for up to two).
            bbox_pred = (None, ) * len(proposals)

        # apply bbox post-processing to each image individually
        det_bboxes = []
        det_labels = []
        for i in range(len(proposals)):
            det_bbox, det_label = self.bbox_head.get_bboxes(
                rois[i],
                cls_score[i],
                bbox_pred[i],
                img_shapes[i],
                scale_factors[i],
                rescale=rescale,
                cfg=rcnn_test_cfg)
            det_bboxes.append(det_bbox)
            det_labels.append(det_label)
        return det_bboxes, det_labels
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.