
# Source code for mmtrack.models.track_heads.quasi_dense_embed_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmdet.models import HEADS, build_loss

from mmtrack.core import embed_similarity
from .roi_embed_head import RoIEmbedHead

@HEADS.register_module()
class QuasiDenseEmbedHead(RoIEmbedHead):
    """The quasi-dense roi embed head.

    Args:
        embed_channels (int): Output dimension of the final embedding layer
            ``fc_embed``. Defaults to 256.
        softmax_temp (int): Softmax temperature, forwarded as
            ``temperature`` to ``embed_similarity`` when computing the
            dot-product similarity. Defaults to -1.
        loss_track (dict): The loss function for tracking. Defaults to
            MultiPosCrossEntropyLoss.
        loss_track_aux (dict | None): The auxiliary loss function for
            tracking. If None, no auxiliary loss is built or computed.
            Defaults to L2Loss.
        init_cfg (dict): Initialization config, forwarded to the parent
            class.
    """

    def __init__(self,
                 embed_channels=256,
                 softmax_temp=-1,
                 loss_track=dict(
                     type='MultiPosCrossEntropyLoss', loss_weight=0.25),
                 loss_track_aux=dict(
                     type='L2Loss',
                     sample_ratio=3,
                     margin=0.3,
                     loss_weight=1.0,
                     hard_mining=True),
                 init_cfg=dict(
                     type='Xavier',
                     layer='Linear',
                     distribution='uniform',
                     bias=0,
                     override=dict(
                         type='Normal',
                         name='fc_embed',
                         mean=0,
                         std=0.01,
                         bias=0)),
                 *args,
                 **kwargs):
        # Positional arguments first, then keywords: same call semantics as
        # before, but the argument order no longer reads as if ``init_cfg``
        # preceded ``*args``.
        super().__init__(*args, init_cfg=init_cfg, **kwargs)

        # ``last_layer_dim`` is provided by the parent class.
        self.fc_embed = nn.Linear(self.last_layer_dim, embed_channels)
        self.softmax_temp = softmax_temp
        self.loss_track = build_loss(loss_track)
        if loss_track_aux is not None:
            self.loss_track_aux = build_loss(loss_track_aux)
        else:
            self.loss_track_aux = None

    def forward(self, x):
        """Forward the input `x`.

        Applies the parent's conv layers (if any), flattens everything
        after the first dimension, applies the parent's fc layers with ReLU
        (if any) and finally projects with ``fc_embed``.
        """
        if self.num_convs > 0:
            for conv in self.convs:
                x = conv(x)
        x = x.flatten(1)
        if self.num_fcs > 0:
            for fc in self.fcs:
                x = self.relu(fc(x))
        x = self.fc_embed(x)
        return x

    def get_targets(self, gt_match_indices, key_sampling_results,
                    ref_sampling_results):
        """Calculate the track targets and track weights for all samples in
        a batch according to the sampling_results.

        Args:
            gt_match_indices (list(Tensor)): Mapping from gt_instance_ids to
                ref_gt_instance_ids of the same tracklet in a pair of
                images.
            key_sampling_results (List[obj:SamplingResults]): Assign results
                of all images in a batch after sampling.
            ref_sampling_results (List[obj:SamplingResults]): Assign results
                of all reference images in a batch after sampling.

        Returns:
            Tuple[list[Tensor]]: Association results.
            Containing the following list of Tensors:

                - track_targets (list[Tensor]): The mapping instance ids
                  from all positive proposals in the key image to all
                  proposals in the reference image, each tensor in list has
                  shape (len(key_pos_bboxes), len(ref_bboxes)).
                - track_weights (list[Tensor]): Loss weights for all
                  positive proposals in a batch, each tensor in list has
                  shape (len(key_pos_bboxes),).
        """
        track_targets = []
        track_weights = []
        for _gt_match_indices, key_res, ref_res in zip(
                gt_match_indices, key_sampling_results,
                ref_sampling_results):
            # Integer dtype matches the ``.int()`` cast applied to
            # ``pos2pos`` below.
            targets = _gt_match_indices.new_zeros(
                (key_res.pos_bboxes.size(0), ref_res.bboxes.size(0)),
                dtype=torch.int)
            # For every positive key proposal, look up the entry of
            # ``_gt_match_indices`` selected by its assigned gt index.
            _match_indices = _gt_match_indices[key_res.pos_assigned_gt_inds]
            # 1 where the key proposal's matched index equals the gt index
            # assigned to a positive reference proposal, else 0.
            pos2pos = (_match_indices.view(-1, 1) ==
                       ref_res.pos_assigned_gt_inds.view(1, -1)).int()
            # NOTE(review): this assumes the positive reference proposals
            # occupy the leading rows of ``ref_res.bboxes`` -- confirm
            # against the sampler.
            targets[:, :pos2pos.size(1)] = pos2pos
            # Only key proposals with at least one match get weight 1.
            weights = (targets.sum(dim=1) > 0).float()
            track_targets.append(targets)
            track_weights.append(weights)
        return track_targets, track_weights

    def match(self, key_embeds, ref_embeds, key_sampling_results,
              ref_sampling_results):
        """Calculate the dist matrixes for loss measurement.

        Args:
            key_embeds (Tensor): Embeds of positive bboxes in sampling
                results of key image.
            ref_embeds (Tensor): Embeds of all bboxes in sampling results
                of the reference image.
            key_sampling_results (List[obj:SamplingResults]): Assign results
                of all images in a batch after sampling.
            ref_sampling_results (List[obj:SamplingResults]): Assign results
                of all reference images in a batch after sampling.

        Returns:
            Tuple[list[Tensor]]: Calculation results.
            Containing the following list of Tensors:

                - dists (list[Tensor]): Dot-product dists between
                  key_embeds and ref_embeds, each tensor in list has
                  shape (len(key_pos_bboxes), len(ref_bboxes)).
                - cos_dists (list[Tensor | None]): Cosine dists between
                  key_embeds and ref_embeds, each tensor in list has
                  shape (len(key_pos_bboxes), len(ref_bboxes)). Every
                  entry is None when no auxiliary loss is configured.
        """
        # Split the batched embeddings back into per-image chunks.
        num_key_rois = [res.pos_bboxes.size(0) for res in key_sampling_results]
        key_embeds = torch.split(key_embeds, num_key_rois)
        num_ref_rois = [res.bboxes.size(0) for res in ref_sampling_results]
        ref_embeds = torch.split(ref_embeds, num_ref_rois)

        dists, cos_dists = [], []
        for key_embed, ref_embed in zip(key_embeds, ref_embeds):
            dist = embed_similarity(
                key_embed,
                ref_embed,
                method='dot_product',
                temperature=self.softmax_temp)
            dists.append(dist)
            if self.loss_track_aux is not None:
                cos_dist = embed_similarity(
                    key_embed, ref_embed, method='cosine')
                cos_dists.append(cos_dist)
            else:
                # Keep ``cos_dists`` aligned with ``dists`` so that
                # ``loss`` can zip over both lists.
                cos_dists.append(None)
        return dists, cos_dists

    def loss(self, dists, cos_dists, targets, weights):
        """Calculate the track loss and the auxiliary track loss.

        Args:
            dists (list[Tensor]): Dot-product dists between
                key_embeds and ref_embeds.
            cos_dists (list[Tensor]): Cosine dists between
                key_embeds and ref_embeds.
            targets (list[Tensor]): The mapping instance ids from all
                positive proposals in the key image to all proposals
                in the reference image, each tensor in list has
                shape (len(key_pos_bboxes), len(ref_bboxes)).
            weights (list[Tensor]): Loss weights for all positive
                proposals in a batch, each tensor in list has
                shape (len(key_pos_bboxes),).

        Returns:
            Dict [str: Tensor]: Calculation results.
            Containing the following list of Tensors:

                - loss_track (Tensor): Results of loss_track function,
                  averaged over the images in the batch.
                - loss_track_aux (Tensor): Results of loss_track_aux
                  function, averaged over the images in the batch. Only
                  present when an auxiliary loss is configured.
        """
        losses = dict()

        loss_track = 0.
        loss_track_aux = 0.
        for _dists, _cos_dists, _targets, _weights in zip(
                dists, cos_dists, targets, weights):
            # NOTE(review): ``_weights.sum()`` is 0 when no key proposal
            # has a match -- confirm the loss tolerates a zero avg_factor.
            loss_track += self.loss_track(
                _dists, _targets, _weights, avg_factor=_weights.sum())
            if self.loss_track_aux is not None:
                loss_track_aux += self.loss_track_aux(_cos_dists, _targets)

        # Average the accumulated losses over the number of image pairs.
        losses['loss_track'] = loss_track / len(dists)
        if self.loss_track_aux is not None:
            losses['loss_track_aux'] = loss_track_aux / len(dists)

        return losses
Read the Docs v: stable
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.