Shortcuts

Source code for mmtrack.datasets.imagenet_vid_dataset

# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets import DATASETS
from mmdet.datasets.api_wrappers import COCO

from .coco_video_dataset import CocoVideoDataset
from .parsers import CocoVID


@DATASETS.register_module()
class ImagenetVIDDataset(CocoVideoDataset):
    """ImageNet VID dataset for video object detection.

    Annotations are read either as plain COCO images (``load_image_anns``) or
    as COCOVID videos (``load_video_anns``). The choice depends on
    ``self.load_as_video``, which is presumably set by the base class
    constructor (not visible here).

    The constructor is inherited unchanged from ``CocoVideoDataset``.
    """

    CLASSES = ('airplane', 'antelope', 'bear', 'bicycle', 'bird', 'bus',
               'car', 'cattle', 'dog', 'domestic_cat', 'elephant', 'fox',
               'giant_panda', 'hamster', 'horse', 'lion', 'lizard', 'monkey',
               'motorcycle', 'rabbit', 'red_panda', 'sheep', 'snake',
               'squirrel', 'tiger', 'train', 'turtle', 'watercraft', 'whale',
               'zebra')

    def _build_cat_mapping(self):
        """Populate ``self.cat_ids`` / ``self.cat2label`` from ``self.coco``.

        The category ids are looked up by the names in ``CLASSES``.
        ``cat2label`` maps each category id to its position in that lookup
        result, i.e. a contiguous 0-based label.
        """
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}

    def load_annotations(self, ann_file):
        """Load annotations from COCO/COCOVID style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation information from COCO/COCOVID api.
        """
        if self.load_as_video:
            return self.load_video_anns(ann_file)
        return self.load_image_anns(ann_file)

    def load_image_anns(self, ann_file):
        """Load annotations from COCO style annotation file.

        Only images whose ``is_vid_train_frame`` flag is truthy are kept.
        ``self.img_ids`` is filled with their ids in the same order as the
        returned list.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation information from COCO api. Each dict also
            carries a ``'filename'`` key copied from ``'file_name'``.
        """
        self.coco = COCO(ann_file)
        self._build_cat_mapping()

        all_img_ids = self.coco.get_img_ids()
        self.img_ids = []
        data_infos = []
        # One batched lookup instead of one ``load_imgs`` call per image id;
        # the returned list is in the same order as ``all_img_ids``.
        for img_id, info in zip(all_img_ids,
                                self.coco.load_imgs(all_img_ids)):
            info['filename'] = info['file_name']
            if info['is_vid_train_frame']:
                self.img_ids.append(img_id)
                data_infos.append(info)
        return data_infos

    def load_video_anns(self, ann_file):
        """Load annotations from COCOVID style annotation file.

        Frames are visited video by video, following the order returned by
        ``get_vid_ids`` / ``get_img_ids_from_vid``.

        * In test mode every frame is kept. A frame flagged as
          ``is_vid_train_frame`` is treated as corrupt data.
        * Otherwise only frames with a truthy ``is_vid_train_frame`` are kept.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation information from COCOVID api. Each dict
            also carries a ``'filename'`` key copied from ``'file_name'``.

        Raises:
            ValueError: In test mode, if any frame has ``is_vid_train_frame``
                set. This is an explicit raise rather than an ``assert``, so
                the check still runs under ``python -O``.
        """
        self.coco = CocoVID(ann_file)
        self._build_cat_mapping()

        data_infos = []
        self.vid_ids = self.coco.get_vid_ids()
        self.img_ids = []
        for vid_id in self.vid_ids:
            img_ids = self.coco.get_img_ids_from_vid(vid_id)
            for img_id, info in zip(img_ids, self.coco.load_imgs(img_ids)):
                info['filename'] = info['file_name']
                if self.test_mode:
                    if info['is_vid_train_frame']:
                        raise ValueError(
                            'is_vid_train_frame must be False in testing, '
                            f'but image {img_id} is marked as a VID '
                            'training frame')
                elif not info['is_vid_train_frame']:
                    # Training: skip frames not selected for VID training.
                    continue
                self.img_ids.append(img_id)
                data_infos.append(info)
        return data_infos
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.