Source code for mmtrack.models.reid.linear_reid_head

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
from mmcls.models.builder import HEADS
from mmcls.models.heads.base_head import BaseHead
from mmcls.models.losses import Accuracy
from mmcv.runner import auto_fp16, force_fp32
from mmdet.models.builder import build_loss

from .fc_module import FcModule


@HEADS.register_module()
class LinearReIDHead(BaseHead):
    """Linear head for re-identification.

    Args:
        num_fcs (int): Number of fc layers.
        in_channels (int): Number of channels in the input.
        fc_channels (int): Number of channels in the fcs.
        out_channels (int): Number of channels in the output.
        norm_cfg (dict, optional): Configuration of normalization method
            after fc. Defaults to None.
        act_cfg (dict, optional): Configuration of activation method after
            fc. Defaults to None.
        num_classes (int, optional): Number of identities. Defaults to None.
        loss (dict, optional): Cross entropy loss to train the
            re-identification module.
        loss_pairwise (dict, optional): Triplet loss to train the
            re-identification module.
        topk (int | tuple[int], optional): Top-k accuracy indices to
            compute. Defaults to (1, ).
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to dict(type='Normal', layer='Linear', mean=0, std=0.01,
            bias=0).
    """

    def __init__(self,
                 num_fcs,
                 in_channels,
                 fc_channels,
                 out_channels,
                 norm_cfg=None,
                 act_cfg=None,
                 num_classes=None,
                 loss=None,
                 loss_pairwise=None,
                 topk=(1, ),
                 init_cfg=dict(
                     type='Normal', layer='Linear', mean=0, std=0.01,
                     bias=0)):
        super(LinearReIDHead, self).__init__(init_cfg)
        assert isinstance(topk, (int, tuple))
        if isinstance(topk, int):
            topk = (topk, )
        for _topk in topk:
            assert _topk > 0, 'Top-k should be larger than 0'
        self.topk = topk

        if not loss:
            if isinstance(num_classes, int):
                warnings.warn('Since cross entropy loss is not set, '
                              'num_classes will be ignored.')
            if not loss_pairwise:
                raise ValueError('Please choose at least one loss between '
                                 'triplet loss and cross entropy loss.')
        elif not isinstance(num_classes, int):
            raise TypeError('The num_classes must be an integer when '
                            'cross entropy loss is used.')
        self.loss_cls = build_loss(loss) if loss else None
        self.loss_triplet = build_loss(
            loss_pairwise) if loss_pairwise else None

        self.num_fcs = num_fcs
        self.in_channels = in_channels
        self.fc_channels = fc_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.num_classes = num_classes
        self.accuracy = Accuracy(topk=self.topk)
        self.fp16_enabled = False

        self._init_layers()

    def _init_layers(self):
        """Initialize fc layers."""
        self.fcs = nn.ModuleList()
        for i in range(self.num_fcs):
            in_channels = self.in_channels if i == 0 else self.fc_channels
            self.fcs.append(
                FcModule(in_channels, self.fc_channels, self.norm_cfg,
                         self.act_cfg))
        in_channels = self.in_channels if self.num_fcs == 0 else \
            self.fc_channels
        self.fc_out = nn.Linear(in_channels, self.out_channels)
        if self.loss_cls:
            self.bn = nn.BatchNorm1d(self.out_channels)
            self.classifier = nn.Linear(self.out_channels, self.num_classes)
    @auto_fp16()
    def forward_train(self, x):
        """Model forward."""
        for m in self.fcs:
            x = m(x)
        feats = self.fc_out(x)
        if self.loss_cls:
            feats_bn = self.bn(feats)
            cls_score = self.classifier(feats_bn)
            return (feats, cls_score)
        return (feats, )
    @force_fp32(apply_to=('feats', 'cls_score'))
    def loss(self, gt_label, feats, cls_score=None):
        """Compute losses."""
        losses = dict()
        if self.loss_triplet:
            losses['triplet_loss'] = self.loss_triplet(feats, gt_label)
        if self.loss_cls:
            assert cls_score is not None
            losses['ce_loss'] = self.loss_cls(cls_score, gt_label)
            # compute accuracy
            acc = self.accuracy(cls_score, gt_label)
            assert len(acc) == len(self.topk)
            losses['accuracy'] = {
                f'top-{k}': a
                for k, a in zip(self.topk, acc)
            }
        return losses
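
Below is a minimal usage sketch, not part of the module above. It assumes mmtrack is installed so that TripletLoss is registered in the shared LOSSES registry consulted by build_loss; the channel sizes, identity count, batch size, and loss settings are illustrative placeholders.

# Minimal usage sketch; all numeric values below are placeholders.
import torch

from mmtrack.models.reid.linear_reid_head import LinearReIDHead

head = LinearReIDHead(
    num_fcs=1,
    in_channels=2048,   # e.g. pooled ResNet-50 features
    fc_channels=1024,
    out_channels=128,   # embedding dimension
    num_classes=380,    # number of training identities (placeholder)
    loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    loss_pairwise=dict(type='TripletLoss', margin=0.3, loss_weight=1.0),
    topk=(1, 5))

feats_in = torch.randn(32, 2048)          # a batch of backbone features
gt_label = torch.randint(0, 380, (32, ))  # identity labels for the batch
feats, cls_score = head.forward_train(feats_in)
losses = head.loss(gt_label, feats, cls_score)
# losses now holds 'triplet_loss', 'ce_loss' and an 'accuracy' sub-dict
# with 'top-1' and 'top-5' entries.

Note the design choice visible in forward_train and loss: the triplet loss is applied to the raw embedding feats, while classification runs on the batch-normalized embedding, matching the common "BNNeck" arrangement for re-identification heads.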