
Source code for mmtrack.models.track_heads.roi_embed_head

# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from mmdet.models import HEADS, build_loss
from mmdet.models.losses import accuracy
from torch.nn.modules.utils import _pair

from mmtrack.core import embed_similarity

@HEADS.register_module()
class RoIEmbedHead(BaseModule):
    """The RoI embed head.

    This module is used in multi-object tracking methods, such as
    MaskTrack R-CNN. It embeds the RoI features of key-frame and
    reference-frame proposals, scores every key proposal against every
    reference proposal with a dot product, and prepends an extra
    "no match" column (index 0) to those scores.

    Args:
        num_convs (int): The number of convolutional layers used to embed
            RoI features. Defaults to 0.
        num_fcs (int): The number of fully connected layers used to embed
            RoI features. Defaults to 0.
        roi_feat_size (int | tuple(int)): The spatial size of RoI features.
            Defaults to 7.
        in_channels (int): The input channel of RoI features.
            Defaults to 256.
        conv_out_channels (int): The output channel of RoI features after
            the convolutional layers. Defaults to 256.
        with_avg_pool (bool): Whether to apply average pooling before
            passing RoI features into the fully connected layers.
            Defaults to False.
        fc_out_channels (int): The output channel of RoI features after the
            fully connected layers. Defaults to 1024.
        conv_cfg (dict): Config dict for convolution layer. Defaults to
            None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to None.
        loss_match (dict): The loss function. Defaults to
            dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
        init_cfg (dict): Configuration of initialization. Defaults to None.
        **kwargs: Extra keyword arguments; accepted and ignored.
    """

    def __init__(self,
                 num_convs=0,
                 num_fcs=0,
                 roi_feat_size=7,
                 in_channels=256,
                 conv_out_channels=256,
                 with_avg_pool=False,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=None,
                 loss_match=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 init_cfg=None,
                 **kwargs):
        super(RoIEmbedHead, self).__init__(init_cfg=init_cfg)
        self.num_convs = num_convs
        self.num_fcs = num_fcs
        self.roi_feat_size = _pair(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.with_avg_pool = with_avg_pool
        self.fc_out_channels = fc_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.loss_match = build_loss(loss_match)
        self.fp16_enabled = False

        if self.with_avg_pool:
            # Kernel covers the whole RoI map, so each channel collapses to
            # one value (assuming input spatial size == roi_feat_size).
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)

        # add convs and fcs
        self.convs, self.fcs, self.last_layer_dim = self._add_conv_fc_branch(
            self.num_convs, self.num_fcs, self.in_channels)
        self.relu = nn.ReLU(inplace=True)

    def _add_conv_fc_branch(self, num_branch_convs, num_branch_fcs,
                            in_channels):
        """Build the conv and fc layers of the embedding branch.

        The resulting pipeline is: convs -> avg pool (optional) -> fcs.

        Args:
            num_branch_convs (int): Number of 3x3 conv layers.
            num_branch_fcs (int): Number of fully connected layers.
            in_channels (int): Channel count of the incoming RoI features.

        Returns:
            tuple: ``(branch_convs, branch_fcs, last_layer_dim)`` where the
            first two are ``nn.ModuleList`` objects and ``last_layer_dim`` is
            the feature size produced by the last layer.
        """
        last_layer_dim = in_channels

        # add branch specific conv layers
        branch_convs = nn.ModuleList()
        if num_branch_convs > 0:
            for i in range(num_branch_convs):
                conv_in_channels = (
                    last_layer_dim if i == 0 else self.conv_out_channels)
                # padding=1 with a 3x3 kernel keeps the spatial size.
                branch_convs.append(
                    ConvModule(
                        conv_in_channels,
                        self.conv_out_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg))
            last_layer_dim = self.conv_out_channels

        # add branch specific fc layers
        branch_fcs = nn.ModuleList()
        if num_branch_fcs > 0:
            # Without pooling, the fcs see the flattened C*H*W features.
            if not self.with_avg_pool:
                last_layer_dim *= self.roi_feat_area
            for i in range(num_branch_fcs):
                fc_in_channels = (
                    last_layer_dim if i == 0 else self.fc_out_channels)
                branch_fcs.append(
                    nn.Linear(fc_in_channels, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels

        return branch_convs, branch_fcs, last_layer_dim

    @property
    def custom_activation(self):
        """bool: Whether the match loss exposes its own activation/accuracy.

        Reads ``custom_activation`` from ``self.loss_match`` and falls back to
        ``False`` when the loss does not define it.
        """
        return getattr(self.loss_match, 'custom_activation', False)

    def _forward(self, x, num_x_per_img):
        """Forward the input `x`, and split the output to a list.

        Args:
            x (Tensor): of shape [N, C, H, W]. N is the number of proposals.
            num_x_per_img (list[int]): The `x` contains proposals of
                multi-images. `num_x_per_img` denotes the number of proposals
                for each image.

        Returns:
            list[Tensor]: Each Tensor denotes the embed features belonging to
            an image in a batch.
        """
        if self.num_convs > 0:
            for conv in self.convs:
                x = conv(x)

        if self.num_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.flatten(1)
            for fc in self.fcs:
                x = self.relu(fc(x))
        else:
            x = x.flatten(1)

        x_split = torch.split(x, num_x_per_img, dim=0)
        return x_split

    @auto_fp16(apply_to=('x', 'ref_x'))
    def forward(self, x, ref_x, num_x_per_img, num_x_per_ref_img):
        """Compute the similarity scores between `x` and `ref_x`.

        Args:
            x (Tensor): of shape [N, C, H, W]. N is the number of key frame
                proposals.
            ref_x (Tensor): of shape [M, C, H, W]. M is the number of
                reference frame proposals.
            num_x_per_img (list[int]): The `x` contains proposals of
                multi-images. `num_x_per_img` denotes the number of proposals
                for each key image.
            num_x_per_ref_img (list[int]): The `ref_x` contains proposals of
                multi-images. `num_x_per_ref_img` denotes the number of
                proposals for each reference image.

        Returns:
            list[Tensor]: The predicted similarity logits of each pair of key
            image and reference image, presumably of shape
            (N_i, M_i + 1) where column 0 is the "no match" score.
        """
        x_split = self._forward(x, num_x_per_img)
        ref_x_split = self._forward(ref_x, num_x_per_ref_img)

        similarity_logits = []
        for one_x, one_ref_x in zip(x_split, ref_x_split):
            similarity_logit = embed_similarity(
                one_x, one_ref_x, method='dot_product')
            # Prepend a constant zero logit as class 0. This pairs with
            # get_targets, where 0 means "not present in the reference
            # image" and a match at reference index k is encoded as k + 1.
            dummy = similarity_logit.new_zeros(one_x.shape[0], 1)
            similarity_logit = torch.cat((dummy, similarity_logit), dim=1)
            similarity_logits.append(similarity_logit)
        return similarity_logits

    def get_targets(self, sampling_results, gt_instance_ids,
                    ref_gt_instance_ids):
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Args:
            sampling_results (List[obj:SamplingResults]): Assign results of
                all images in a batch after sampling.
            gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes of
                all images in a batch, each tensor has shape (num_gt, ).
            ref_gt_instance_ids (list[Tensor]): The instance ids of gt_bboxes
                of all reference images in a batch, each tensor has shape
                (num_gt, ).

        Returns:
            Tuple[list[Tensor]]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

                - track_id_targets (list[Tensor]): The match target of every
                  proposal in a batch: 0 for "no match" (including all
                  negatives), or 1 + the index of the matching instance in
                  the reference image. Each tensor has shape
                  (num_proposals,).
                - track_id_weights (list[Tensor]): Weights for all proposals
                  in a batch (1 for positives, 0 otherwise), each tensor in
                  list has shape (num_proposals,).
        """
        track_id_targets = []
        track_id_weights = []
        for res, gt_instance_id, ref_gt_instance_id in zip(
                sampling_results, gt_instance_ids, ref_gt_instance_ids):
            # Build the id -> (reference index + 1) lookup once per image.
            # setdefault keeps the first occurrence, matching the previous
            # list.index() semantics for duplicated ids.
            ref_id_to_target = {}
            for ref_idx, ref_id in enumerate(ref_gt_instance_id.tolist()):
                ref_id_to_target.setdefault(ref_id, ref_idx + 1)

            pos_instance_ids = gt_instance_id[res.pos_assigned_gt_inds]
            # new_tensor keeps gt_instance_id's dtype and device.
            pos_match_id = gt_instance_id.new_tensor([
                ref_id_to_target.get(instance_id, 0)
                for instance_id in pos_instance_ids.tolist()
            ])

            num_samples = len(res.bboxes)
            num_pos = len(res.pos_bboxes)
            # NOTE(review): assumes positive samples occupy the first
            # num_pos rows of res.bboxes — confirm against SamplingResult.
            track_id_target = gt_instance_id.new_zeros(
                num_samples, dtype=torch.int64)
            track_id_target[:num_pos] = pos_match_id
            track_id_weight = res.bboxes.new_zeros(num_samples)
            track_id_weight[:num_pos] = 1.0

            track_id_targets.append(track_id_target)
            track_id_weights.append(track_id_weight)
        return track_id_targets, track_id_weights

    @force_fp32(apply_to=('similarity_logits', ))
    def loss(self,
             similarity_logits,
             track_id_targets,
             track_id_weights,
             reduction_override=None):
        """Calculate the loss in a batch.

        Args:
            similarity_logits (list[Tensor]): The predicted similarity_logits
                of each pair of key image and reference image.
            track_id_targets (list[Tensor]): The match targets for all
                proposals in a batch, each tensor in list has shape
                (num_proposals,).
            track_id_weights (list[Tensor]): Weights for all proposals in a
                batch, each tensor in list has shape (num_proposals,).
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".

        Returns:
            dict[str, Tensor]: a dictionary of loss components, each averaged
            over the images that actually contributed to it (images with
            empty logits are skipped).
        """
        assert isinstance(similarity_logits, list)
        assert isinstance(track_id_targets, list)
        assert isinstance(track_id_weights, list)
        assert len(similarity_logits) == len(track_id_targets)
        assert len(track_id_weights) == len(track_id_targets)

        losses = defaultdict(list)
        for similarity_logit, track_id_target, track_id_weight in zip(
                similarity_logits, track_id_targets, track_id_weights):
            # Normalize by the number of proposals that have a real match
            # (target > 0), clamped to at least 1 to avoid dividing by 0.
            avg_factor = max(torch.sum(track_id_target > 0).float().item(), 1.)
            if similarity_logit.numel() > 0:
                loss_match = self.loss_match(
                    similarity_logit,
                    track_id_target,
                    track_id_weight,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                if isinstance(loss_match, dict):
                    for key, value in loss_match.items():
                        losses[key].append(value)
                else:
                    losses['loss_match'].append(loss_match)

                # Accuracy is measured only on weighted (positive) proposals.
                valid_index = track_id_weight > 0
                valid_similarity_logit = similarity_logit[valid_index]
                valid_track_id_target = track_id_target[valid_index]
                if self.custom_activation:
                    match_accuracy = self.loss_match.get_accuracy(
                        valid_similarity_logit, valid_track_id_target)
                    for key, value in match_accuracy.items():
                        losses[key].append(value)
                else:
                    losses['match_accuracy'].append(
                        accuracy(valid_similarity_logit,
                                 valid_track_id_target))

        # Average each component over the images that produced it. Dividing
        # by len(similarity_logits) would silently shrink losses/accuracy
        # whenever an image had empty logits and was skipped above.
        for key, values in losses.items():
            losses[key] = sum(values) / len(values)
        return losses
Read the Docs v: stable
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.