Shortcuts

Source code for mmtrack.models.mot.ocsort

# Copyright (c) OpenMMLab. All rights reserved.

import torch
from mmdet.models import build_detector

from mmtrack.core import outs2results, results2outs
from ..builder import MODELS, build_motion, build_tracker
from .base import BaseMultiObjectTracker


@MODELS.register_module()
class OCSORT(BaseMultiObjectTracker):
    """OC-SORT: Observation-Centric SORT: Rethinking SORT for Robust
    Multi-Object Tracking.

    This multi object tracker is the implementation of `OC-SORT
    <https://arxiv.org/abs/2203.14360>`_.

    Args:
        detector (dict): Configuration of detector. Defaults to None.
        tracker (dict): Configuration of tracker. Defaults to None.
        motion (dict): Configuration of motion. Defaults to None.
        init_cfg (dict): Configuration of initialization. Defaults to None.
    """

    def __init__(self, detector=None, tracker=None, motion=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        # Every sub-module is optional: build only the components the
        # config actually provides, so partial configs (e.g. detector-only
        # training) still work.
        if detector is not None:
            self.detector = build_detector(detector)

        if motion is not None:
            self.motion = build_motion(motion)

        if tracker is not None:
            self.tracker = build_tracker(tracker)

    def forward_train(self, *args, **kwargs):
        """Forward function during training.

        OC-SORT trains only the detector; all arguments are forwarded to
        it unchanged.
        """
        return self.detector.forward_train(*args, **kwargs)

    def simple_test(self, img, img_metas, rescale=False, **kwargs):
        """Test without augmentations.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'.
            rescale (bool, optional): If False, then returned bboxes and
                masks will fit the scale of img, otherwise, returned bboxes
                and masks will fit the scale of original image shape.
                Defaults to False.

        Returns:
            dict[str : list(ndarray)]: The tracking results.
        """
        # A new sequence starts at frame 0: drop all existing tracklets.
        frame_id = img_metas[0].get('frame_id', -1)
        if frame_id == 0:
            self.tracker.reset()

        det_results = self.detector.simple_test(
            img, img_metas, rescale=rescale)
        assert len(det_results) == 1, 'Batch inference is not supported.'
        bbox_results = det_results[0]
        num_classes = len(bbox_results)

        # Convert per-class ndarray results to flat bbox/label tensors on
        # the same device/dtype as the input image tensor.
        outs_det = results2outs(bbox_results=bbox_results)
        det_bboxes = torch.from_numpy(outs_det['bboxes']).to(img)
        det_labels = torch.from_numpy(outs_det['labels']).to(img).long()

        track_bboxes, track_labels, track_ids = self.tracker.track(
            img=img,
            img_metas=img_metas,
            model=self,
            bboxes=det_bboxes,
            labels=det_labels,
            frame_id=frame_id,
            rescale=rescale,
            **kwargs)

        # Re-pack tensors into the per-class list-of-ndarray result format;
        # track results additionally carry instance ids.
        track_results = outs2results(
            bboxes=track_bboxes,
            labels=track_labels,
            ids=track_ids,
            num_classes=num_classes)
        det_results = outs2results(
            bboxes=det_bboxes, labels=det_labels, num_classes=num_classes)

        return dict(
            det_bboxes=det_results['bbox_results'],
            track_bboxes=track_results['bbox_results'])
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.