Shortcuts

mmtrack.models.reid.base_reid 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from mmcls.models import ImageClassifier
from mmcv.runner import auto_fp16

from ..builder import REID


[文档]@REID.register_module() class BaseReID(ImageClassifier): """Base class for re-identification."""
[文档] def forward_train(self, img, gt_label, **kwargs): """"Training forward function.""" if img.ndim == 5: # change the shape of image tensor from NxSxCxHxW to NSxCxHxW # where S is the number of samples by triplet sampling img = img.view(-1, *img.shape[2:]) # change the shape of label tensor from NxS to NS gt_label = gt_label.view(-1) x = self.extract_feat(img) head_outputs = self.head.forward_train(x[0]) losses = dict() reid_loss = self.head.loss(gt_label, *head_outputs) losses.update(reid_loss) return losses
[文档] @auto_fp16(apply_to=('img', ), out_fp32=True) def simple_test(self, img, **kwargs): """Test without augmentation.""" if img.nelement() > 0: x = self.extract_feat(img) head_outputs = self.head.forward_train(x[0]) feats = head_outputs[0] return feats else: return img.new_zeros(0, self.head.out_channels)
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.