Source code for mmtrack.core.optimizer.sot_lr_updater

# Copyright (c) OpenMMLab. All rights reserved.
import math

import numpy as np
from mmcv.runner.hooks import HOOKS, LrUpdaterHook


def step_lr_interval(start_lr_factor, end_lr_factor, start_epoch, end_epoch):
    """Exponentially varying learning rate.

    Generate learning rate factors varying exponentially from
    `start_lr_factor` to `end_lr_factor` over `end_epoch - start_epoch`
    epochs.

    Args:
        start_lr_factor (float): Start learning rate factor.
        end_lr_factor (float): End learning rate factor.
        start_epoch (int): Start epoch.
        end_epoch (int): End epoch.

    Returns:
        ndarray: The exponentially varying learning rate factors.
    """
    epochs = end_epoch - start_epoch
    mult = math.pow(end_lr_factor / start_lr_factor, 1. / epochs)
    lr_intervals = start_lr_factor * (mult**np.arange(epochs))
    return lr_intervals
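
# A minimal usage sketch (not part of the original module), assuming the
# default SiamRPN++ warmup settings used by the hook below: 5 warmup epochs
# starting at factor 0.2 and rising towards 1.0. Note that the end factor is
# approached but not included, since `np.arange(epochs)` stops at
# `epochs - 1`:
#
#     >>> step_lr_interval(0.2, 1.0, start_epoch=0, end_epoch=5)
#     array([0.2, 0.276..., 0.381..., 0.525..., 0.725...])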


def log_lr_interval(start_lr_factor, end_lr_factor, start_epoch, end_epoch):
    """Logarithmically varying learning rate.

    Generate learning rate factors varying logarithmically from
    `start_lr_factor` to `end_lr_factor` over `end_epoch - start_epoch`
    epochs.

    Args:
        start_lr_factor (float): Start learning rate factor.
        end_lr_factor (float): End learning rate factor.
        start_epoch (int): Start epoch.
        end_epoch (int): End epoch.

    Returns:
        ndarray: The logarithmically varying learning rate factors.
    """
    epochs = end_epoch - start_epoch
    lr_intervals = np.logspace(
        math.log10(start_lr_factor), math.log10(end_lr_factor), epochs)
    return lr_intervals
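
# A minimal usage sketch (not part of the original module), assuming the
# default SiamRPN++ decay settings used by the hook below: 15 decay epochs
# from factor 1.0 down to 0.1. Unlike `step_lr_interval`, `np.logspace`
# includes the endpoint, so the final factor is exactly 0.1:
#
#     >>> log_lr_interval(1.0, 0.1, start_epoch=5, end_epoch=20)
#     array([1.0, 0.848..., ..., 0.117..., 0.1])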


@HOOKS.register_module()
class SiameseRPNLrUpdaterHook(LrUpdaterHook):
    """Learning rate updater for Siamese RPN.

    Args:
        lr_configs (list[dict]): List of dicts where each dict denotes the
            configuration of a specific learning rate updater and must
            contain 'type'.
    """

    lr_types = {'step': step_lr_interval, 'log': log_lr_interval}

    def __init__(self,
                 lr_configs=[
                     dict(
                         type='step',
                         start_lr_factor=0.2,
                         end_lr_factor=1.0,
                         end_epoch=5),
                     dict(
                         type='log',
                         start_lr_factor=1.0,
                         end_lr_factor=0.1,
                         end_epoch=20),
                 ],
                 **kwargs):
        super(SiameseRPNLrUpdaterHook, self).__init__(**kwargs)
        assert self.by_epoch is True
        self.lr_intervals = []

        # Chain the configured schedules: each one starts where the
        # previous one ended, then all per-epoch factors are concatenated.
        start_epoch = 0
        for lr_config in lr_configs:
            lr_type = self.lr_types[lr_config.pop('type')]
            lr_config['start_epoch'] = start_epoch
            lr_intervals = lr_type(**lr_config)
            self.lr_intervals.append(lr_intervals)
            start_epoch = lr_config['end_epoch']
        self.lr_intervals = np.concatenate(self.lr_intervals)

    def get_lr(self, runner, base_lr):
        """Get a specific learning rate for each epoch."""
        return base_lr * self.lr_intervals[runner.epoch]
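
# A minimal config sketch (an assumption based on mmcv's LrUpdaterHook naming
# convention, where `policy='SiameseRPN'` resolves to
# `SiameseRPNLrUpdaterHook`): warm up the learning rate over epochs 0-5, then
# decay it logarithmically until epoch 20, mirroring the defaults above.
#
#     lr_config = dict(
#         policy='SiameseRPN',
#         lr_configs=[
#             dict(type='step', start_lr_factor=0.2, end_lr_factor=1.0,
#                  end_epoch=5),
#             dict(type='log', start_lr_factor=1.0, end_lr_factor=0.1,
#                  end_epoch=20),
#         ])
#
# At each epoch, `get_lr` scales the optimizer's base learning rate by the
# precomputed factor for that epoch: `base_lr * lr_intervals[runner.epoch]`.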