Shortcuts

Source code for mmtrack.datasets.samplers.entire_video_batch_sampler

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

from torch.utils.data import BatchSampler, Sampler

from mmtrack.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class EntireVideoBatchSampler(BatchSampler):
    """A sampler wrapper for grouping images from one video into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch. Here, we take a video as a
            batch. Defaults to 1.
        drop_last (bool): If ``True``, the sampler will drop the last batch
            if its size would be less than ``batch_size``. Defaults to False.
    """

    def __init__(self,
                 sampler: Sampler,
                 batch_size: int = 1,
                 drop_last: bool = False) -> None:
        # Validate the sampler's type *before* touching its attributes so
        # that a non-Sampler argument raises the intended TypeError instead
        # of an incidental AttributeError from ``sampler.dataset``.
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size != 1:
            raise ValueError('batch_size should be 1, '
                             f'but got batch_size={batch_size}')
        # This wrapper only makes sense at test time, where a whole video is
        # evaluated as one batch.
        assert sampler.dataset.test_mode
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Sequence[int]:
        """Yield one batch of indices per video.

        Indices from the base sampler are accumulated until the batch length
        equals the current sample's ``video_length``, at which point the
        batch (the entire video) is yielded.
        """
        batch = []
        for idx in self.sampler:
            data_info = self.sampler.dataset.get_data_info(idx)
            video_length = data_info['video_length']
            batch.append(idx)
            if len(batch) == video_length:
                yield batch
                batch = []

    def __len__(self) -> int:
        """Return the number of batches, i.e. the number of videos."""
        return self.sampler.num_videos
Read the Docs v: 1.x
Versions
latest
stable
1.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.