Shortcuts

Source code for mmtrack.models.losses.triplet_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmdet.models import LOSSES


@LOSSES.register_module()
class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person
        Re-Identification. arXiv:1703.07737.

    Imported from `<https://github.com/KaiyangZhou/deep-person-reid/blob/
    master/torchreid/losses/hard_mine_triplet_loss.py>`_.

    Args:
        margin (float, optional): Margin for triplet loss. Defaults to 0.3.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        hard_mining (bool, optional): Whether to use batch-hard mining
            (furthest positive / nearest negative per anchor). Only ``True``
            is implemented; ``forward`` raises ``NotImplementedError``
            otherwise. Defaults to True.
    """

    def __init__(self, margin=0.3, loss_weight=1.0, hard_mining=True):
        super().__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        self.loss_weight = loss_weight
        self.hard_mining = hard_mining

    def hard_mining_triplet_loss_forward(self, inputs, targets):
        """Compute the batch-hard triplet loss.

        For every anchor, the furthest sample with the same label (hardest
        positive) and the nearest sample with a different label (hardest
        negative) are selected. The loss is
        ``mean(max(0, d_ap - d_an + margin)) * loss_weight``.

        Args:
            inputs (torch.Tensor): Feature matrix with shape
                (batch_size, feat_dim).
            targets (torch.LongTensor): Ground-truth labels with shape
                (batch_size, ).

        Returns:
            torch.Tensor: Scalar loss tensor. An anchor with no negative
            sample in the batch (all labels equal) contributes zero loss.
            An empty batch yields zero.
        """
        batch_size = inputs.size(0)
        if batch_size == 0:
            # The per-anchor max/min reductions below cannot reduce over a
            # zero-length dimension. Return a zero that is still attached to
            # the autograd graph so ``backward()`` keeps working.
            return inputs.sum() * 0.0

        # Pairwise Euclidean distances via ||a||^2 + ||b||^2 - 2 * a.b
        sq_norm = torch.pow(inputs, 2).sum(
            dim=1, keepdim=True).expand(batch_size, batch_size)
        dist = sq_norm + sq_norm.t()
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        # Clamp before sqrt: rounding can make near-zero entries (e.g. the
        # diagonal) negative, and the gradient of sqrt at 0 is infinite.
        dist = dist.clamp(min=1e-12).sqrt()

        # is_pos[i][j] is True when sample j shares anchor i's label.
        labels = targets.expand(batch_size, batch_size)
        is_pos = labels.eq(labels.t())

        # Hardest positive: furthest same-label sample. Every anchor is its
        # own positive, so each row has at least one finite candidate.
        dist_ap = dist.masked_fill(~is_pos, float('-inf')).max(dim=1).values
        # Hardest negative: nearest different-label sample. If an anchor has
        # no negative, this is +inf and its hinge term evaluates to 0
        # (instead of raising on an empty reduction).
        dist_an = dist.masked_fill(is_pos, float('inf')).min(dim=1).values

        # y = 1 asks the ranking loss for dist_an > dist_ap + margin.
        y = torch.ones_like(dist_an)
        return self.loss_weight * self.ranking_loss(dist_an, dist_ap, y)

    def forward(self, inputs, targets, **kwargs):
        """Compute the triplet loss.

        Args:
            inputs (torch.Tensor): Feature matrix with shape
                (batch_size, feat_dim).
            targets (torch.LongTensor): Ground-truth labels with shape
                (batch_size, ).
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            torch.Tensor: Scalar loss tensor.

        Raises:
            NotImplementedError: If ``hard_mining`` is False.
        """
        if not self.hard_mining:
            raise NotImplementedError(
                'TripletLoss only supports hard_mining=True.')
        return self.hard_mining_triplet_loss_forward(inputs, targets)
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.