Shortcuts

Source code for mmtrack.datasets.samplers.video_sampler

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from torch.utils.data import DistributedSampler as _DistributedSampler
from torch.utils.data import Sampler

from mmtrack.datasets.base_sot_dataset import BaseSOTDataset


class SOTVideoSampler(Sampler):
    """Sampler for single-object-tracking (SOT) testing on a single GPU.

    The SOT dataset's ``__getitem__`` takes a ``(video_index, frame_index)``
    tuple at test time, so this sampler enumerates every frame of every
    video, in order, and yields those tuples.

    Args:
        dataset (Dataset): Test dataset. It must have a
            ``num_frames_per_video`` attribute recording the number of
            frames in each video.
    """

    def __init__(self, dataset):
        super().__init__(dataset)
        self.dataset = dataset
        # Videos are visited in order and frames within each video in
        # order, so every video is consumed sequentially from frame 0.
        self.indices = [
            (video_ind, frame_ind)
            for video_ind, num_frames in enumerate(
                self.dataset.num_frames_per_video)
            for frame_ind in range(num_frames)
        ]

    def __iter__(self):
        """Iterate over ``(video_index, frame_index)`` tuples."""
        return iter(self.indices)

    def __len__(self):
        """Return the number of ``(video_index, frame_index)`` pairs.

        The length is taken from the indices actually yielded by
        ``__iter__`` rather than from ``len(self.dataset)``, so the two
        always agree even if the dataset reports a different length.
        """
        return len(self.indices)
class DistributedVideoSampler(_DistributedSampler):
    """Distribute whole videos across GPUs for sequential video testing.

    Each rank receives a contiguous group of complete videos, so all frames
    of a video are processed in order by a single process.

    Args:
        dataset (Dataset): Test dataset. It must have a ``data_infos``
            attribute. For an SOT dataset (``BaseSOTDataset``), each entry of
            ``data_infos`` describes one video and the dataset must also have
            ``num_frames_per_video``. Otherwise each entry describes one
            frame, and every video must have one entry with
            ``data_info['frame_id'] == 0``.
        num_replicas (int, optional): The number of GPUs. If None, it is
            resolved by ``torch.utils.data.DistributedSampler``.
            Defaults to None.
        rank (int, optional): GPU rank id. If None, it is resolved by
            ``torch.utils.data.DistributedSampler``. Defaults to None.
        shuffle (bool): Must be False; shuffling would break the sequential
            frame order. Defaults to False.

    Raises:
        ValueError: If fewer videos are loaded than there are replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        assert not self.shuffle, 'Specific for video sequential testing.'
        # NOTE: this is the size of the whole dataset, not of this rank's
        # share; the non-SOT split also uses it as the end boundary of the
        # last rank's range.
        self.num_samples = len(dataset)

        if isinstance(dataset, BaseSOTDataset):
            self.indices = self._split_sot_videos()
        else:
            self.indices = self._split_frame_videos()

    def _check_enough_videos(self, num_videos):
        """Raise ValueError if there are fewer videos than replicas."""
        # Compare against ``self.num_replicas`` (resolved by the parent
        # class) rather than the raw ``num_replicas`` argument, which may be
        # None and cannot be compared with an int.
        if num_videos < self.num_replicas:
            raise ValueError(f'only {num_videos} videos loaded, '
                             f'but {self.num_replicas} gpus were given.')

    def _split_sot_videos(self):
        """Return per-rank lists of ``(video_index, frame_index)`` tuples.

        The SOT dataset's ``__getitem__`` takes such a tuple at test time.
        Also sets ``self.num_videos`` and ``self.num_frames_per_video``.
        """
        self.num_videos = len(self.dataset.data_infos)
        self.num_frames_per_video = self.dataset.num_frames_per_video
        self._check_enough_videos(self.num_videos)

        chunks = np.array_split(
            list(range(self.num_videos)), self.num_replicas)
        return [[(video_ind, frame_ind)
                 for video_ind in videos
                 for frame_ind in range(self.num_frames_per_video[video_ind])]
                for videos in chunks]

    def _split_frame_videos(self):
        """Return per-rank lists of frame indices, split at video starts.

        Each rank gets the index range from the first frame of its first
        video up to the first frame of the next rank's first video (or the
        end of the dataset). This relies on each video's frames being stored
        contiguously in ``data_infos`` starting with its ``frame_id == 0``
        entry; entries before the first such entry are not assigned.
        """
        first_frame_indices = [
            i for i, img_info in enumerate(self.dataset.data_infos)
            if img_info['frame_id'] == 0
        ]
        self._check_enough_videos(len(first_frame_indices))

        # Every chunk is non-empty because of the check above.
        chunks = np.array_split(first_frame_indices, self.num_replicas)
        split_flags = [c[0] for c in chunks]
        split_flags.append(self.num_samples)
        return [
            list(range(start, end))
            for start, end in zip(split_flags[:-1], split_flags[1:])
        ]

    def __iter__(self):
        """Iterate over the indices assigned to this process's rank."""
        return iter(self.indices[self.rank])
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.