Shortcuts

mmtrack.datasets.sot_imagenet_vid_dataset 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmdet.datasets import DATASETS

from mmtrack.datasets.parsers import CocoVID
from .base_sot_dataset import BaseSOTDataset


[文档]@DATASETS.register_module() class SOTImageNetVIDDataset(BaseSOTDataset): """ImageNet VID dataset of single object tracking. The dataset only support training mode. """ def __init__(self, ann_file, *args, **kwargs): """Initialization of SOT dataset class. Args: ann_file (str): The coco-format annotation file of ImageNet VID Dataset. It will be loaded and parsed in the `self.load_data_infos` function. """ file_client_args = kwargs.get('file_client_args', dict(backend='disk')) self.file_client = mmcv.FileClient(**file_client_args) with self.file_client.get_local_path(ann_file) as local_path: self.coco = CocoVID(local_path) super().__init__(*args, **kwargs)
[文档] def load_data_infos(self, split='train'): """Load dataset information. Args: split (str, optional): The split of dataset. Defaults to 'train'. Returns: list[int]: The length of the list is the number of instances. The elemment in the list is instance ID in coco API. """ data_infos = list(self.coco.instancesToImgs.keys()) return data_infos
[文档] def get_bboxes_from_video(self, video_ind): """Get bbox annotation about the instance in a video. Considering `get_bboxes_from_video` in `SOTBaseDataset` is not compatible with `SOTImageNetVIDDataset`, we oveload this function though it's not called by `self.get_ann_infos_from_video`. Args: video_ind (int): video index. Each video_ind denotes an instance. Returns: ndarray: in [N, 4] shape. The bbox is in (x, y, w, h) format. """ instance_id = self.data_infos[video_ind] img_ids = self.coco.instancesToImgs[instance_id] bboxes = [] for img_id in img_ids: for ann in self.coco.imgToAnns[img_id]: if ann['instance_id'] == instance_id: bboxes.append(ann['bbox']) bboxes = np.array(bboxes).reshape(-1, 4) return bboxes
[文档] def get_img_infos_from_video(self, video_ind): """Get image information in a video. Args: video_ind (int): video index Returns: dict: {'filename': list[str], 'frame_ids':ndarray, 'video_id':int} """ instance_id = self.data_infos[video_ind] img_ids = self.coco.instancesToImgs[instance_id] frame_ids = [] img_names = [] # In ImageNetVID dataset, frame_ids may not be continuous. for img_id in img_ids: frame_ids.append(self.coco.imgs[img_id]['frame_id']) img_names.append(self.coco.imgs[img_id]['file_name']) img_infos = dict( filename=img_names, frame_ids=frame_ids, video_id=video_ind) return img_infos
[文档] def get_ann_infos_from_video(self, video_ind): """Get annotation information in a video. Note: We overload this function for speed up loading video information. Args: video_ind (int): video index. Each video_ind denotes an instance. Returns: dict: {'bboxes': ndarray in (N, 4) shape, 'bboxes_isvalid': ndarray, 'visible':ndarray}. The bbox is in (x1, y1, x2, y2) format. """ instance_id = self.data_infos[video_ind] img_ids = self.coco.instancesToImgs[instance_id] bboxes = [] visible = [] for img_id in img_ids: for ann in self.coco.imgToAnns[img_id]: if ann['instance_id'] == instance_id: bboxes.append(ann['bbox']) visible.append(not ann.get('occluded', False)) bboxes = np.array(bboxes).reshape(-1, 4) bboxes_isvalid = (bboxes[:, 2] > self.bbox_min_size) & ( bboxes[:, 3] > self.bbox_min_size) bboxes[:, 2:] += bboxes[:, :2] visible = np.array(visible, dtype=np.bool_) & bboxes_isvalid ann_infos = ann_infos = dict( bboxes=bboxes, bboxes_isvalid=bboxes_isvalid, visible=visible) return ann_infos
[文档] def get_visibility_from_video(self, video_ind): """Get the visible information in a video. Considering `get_visibility_from_video` in `SOTBaseDataset` is not compatible with `SOTImageNetVIDDataset`, we oveload this function though it's not called by `self.get_ann_infos_from_video`. """ instance_id = self.data_infos[video_ind] img_ids = self.coco.instancesToImgs[instance_id] visible = [] for img_id in img_ids: for ann in self.coco.imgToAnns[img_id]: if ann['instance_id'] == instance_id: visible.append(not ann.get('occluded', False)) visible_info = dict(visible=np.array(visible, dtype=np.bool_)) return visible_info
[文档] def get_len_per_video(self, video_ind): """Get the number of frames in a video.""" instance_id = self.data_infos[video_ind] return len(self.coco.instancesToImgs[instance_id])
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.