Shortcuts

Source code for mmtrack.models.backbones.sot_resnet

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmdet.models.backbones.resnet import Bottleneck, ResNet
from mmdet.models.builder import BACKBONES


class SOTBottleneck(Bottleneck):
    """Bottleneck block used by the SOT ResNet backbone.

    The layers of the mmdet ``Bottleneck`` are rebuilt with SiamRPN++-style
    settings: without dilation the 3x3 conv uses ``padding = 2 - stride``
    (i.e. no padding for stride-2 convs), and a dilated block that receives a
    ``downsample`` branch halves its dilation (and padding).
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_cfg=None):
        """Bottleneck block for ResNet.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
        it is "caffe", the stride-two layer is the first 1x1 conv layer.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Number of channels of conv1/conv2; conv3 outputs
                ``planes * expansion`` channels.
            stride (int): Stride of the block. Default: 1.
            dilation (int): Dilation of the 3x3 conv. Default: 1.
            downsample (nn.Module, optional): Stored as ``self.downsample``.
                When given to a dilated block, the dilation of the 3x3 conv
                is halved. Default: None.
            style (str): ``'pytorch'`` or ``'caffe'``. Default: ``'pytorch'``.
            with_cp (bool): Stored as ``self.with_cp`` for the inherited
                ``forward``. Default: False.
            conv_cfg (dict, optional): Config used to build the convs.
                Default: None.
            norm_cfg (dict): Config used to build the norm layers.
                Default: dict(type='BN').
            dcn (dict, optional): Conv config used for conv2 instead of
                ``conv_cfg`` (presumably a deformable conv). Its optional
                ``fallback_on_stride`` key is read from a copy, so the
                caller's dict is not modified. Default: None.
            plugins (list[dict], optional): Plugin configs, each with
                ``'cfg'`` and ``'position'`` (``'after_conv1'``,
                ``'after_conv2'`` or ``'after_conv3'``) keys. Default: None.
            init_cfg (dict, optional): Initialization config. Default: None.
        """
        # The parent constructor builds a standard bottleneck whose layers
        # (conv1-3, norm1-3) are all rebuilt below with SOT-specific padding
        # and dilation. ``dcn`` and ``plugins`` are handled only here so they
        # are not processed twice. ``norm_cfg`` is forwarded so that the
        # parent's norm layers are registered under the same names as the
        # rebuilt ones (the name returned by ``build_norm_layer`` depends on
        # the cfg) and are replaced instead of lingering as unused modules.
        super(SOTBottleneck, self).__init__(
            inplanes=inplanes,
            planes=planes,
            stride=1,
            dilation=1,
            downsample=None,
            style='pytorch',
            with_cp=False,
            conv_cfg=None,
            norm_cfg=norm_cfg,
            dcn=None,
            plugins=None,
            init_cfg=init_cfg)
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert plugins is None or isinstance(plugins, list)
        if plugins is not None:
            allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
            assert all(p['position'] in allowed_position for p in plugins)

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.plugins = plugins
        self.with_plugins = plugins is not None

        # SiamRPN++-style padding: 1 for stride-1 convs, 0 for stride-2 convs.
        padding = 2 - stride
        if dilation > 1:
            padding = dilation
            if downsample is not None:
                # The first (downsampling) block of a dilated stage uses half
                # of the stage's dilation.
                dilation = dilation // 2
                padding = dilation

        if self.with_plugins:
            # collect plugins for conv1/conv2/conv3
            self.after_conv1_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv1'
            ]
            self.after_conv2_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv2'
            ]
            self.after_conv3_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv3'
            ]

        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        if self.with_dcn:
            # Pop from a copy: the same ``dcn`` dict may be passed to every
            # block of a stage, and mutating it would drop
            # ``fallback_on_stride`` for all blocks after the first.
            dcn = dict(dcn)
            fallback_on_stride = dcn.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=padding,
                dilation=dilation,
                bias=False)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                dcn,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=padding,
                dilation=dilation,
                bias=False)

        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        if self.with_plugins:
            self.after_conv1_plugin_names = self.make_block_plugins(
                planes, self.after_conv1_plugins)
            self.after_conv2_plugin_names = self.make_block_plugins(
                planes, self.after_conv2_plugins)
            self.after_conv3_plugin_names = self.make_block_plugins(
                planes * self.expansion, self.after_conv3_plugins)


@BACKBONES.register_module()
class SOTResNet(ResNet):
    """ResNet backbone for single object tracking (SOT).

    Compared with the torchvision-style ResNet, the convolutions of this
    backbone use different padding and dilation; see `SiamRPN++
    <https://arxiv.org/abs/1812.11703>`_ for the detailed analysis.

    Args:
        depth (int): Depth of the ResNet. Only 50 is supported.
        unfreeze_backbone (bool): If True, right after construction the stem
            and the first ``frozen_stages`` stages are switched to train mode
            and their parameters get ``requires_grad = True`` again, so that
            DDP can put every parameter into its buckets. Default: True.
    """

    arch_settings = {50: (SOTBottleneck, (3, 4, 6, 3))}

    def __init__(self, depth, unfreeze_backbone=True, **kwargs):
        assert depth == 50, 'Only support r50 backbone for sot.'
        super().__init__(depth, **kwargs)
        self.unfreeze_backbone = unfreeze_backbone
        # Make every backbone parameter trainable at construction time so
        # that DDP can build all of them into buckets.
        if unfreeze_backbone:
            self._unfreeze_stages()

    def _unfreeze_stages(self):
        """Re-enable gradients for the stem and the first frozen stages."""
        if self.frozen_stages < 0:
            return

        # Only the stem's norm (or the whole deep stem) is put in train mode;
        # the plain stem conv only has its gradients re-enabled.
        if self.deep_stem:
            self.stem.train()
            stem_parts = [self.stem]
        else:
            self.norm1.train()
            stem_parts = [self.conv1, self.norm1]

        stages = [
            getattr(self, f'layer{stage_idx}')
            for stage_idx in range(1, self.frozen_stages + 1)
        ]
        for stage in stages:
            stage.train()

        for part in stem_parts + stages:
            for weight in part.parameters():
                weight.requires_grad = True

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``SOTResLayer``."""
        return SOTResLayer(**kwargs)

    def _make_stem_layer(self, in_channels, stem_channels):
        """Build the stem; every stem conv is created with ``padding=0``."""
        if self.deep_stem:
            half_channels = stem_channels // 2
            # (input channels, output channels, stride) of each 3x3 conv.
            conv_specs = [
                (in_channels, half_channels, 2),
                (half_channels, half_channels, 1),
                (half_channels, stem_channels, 1),
            ]
            stem_modules = []
            for conv_in, conv_out, conv_stride in conv_specs:
                stem_modules.append(
                    build_conv_layer(
                        self.conv_cfg,
                        conv_in,
                        conv_out,
                        kernel_size=3,
                        stride=conv_stride,
                        padding=0,
                        bias=False))
                stem_modules.append(
                    build_norm_layer(self.norm_cfg, conv_out)[1])
                stem_modules.append(nn.ReLU(inplace=True))
            self.stem = nn.Sequential(*stem_modules)
        else:
            self.conv1 = build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels,
                kernel_size=7,
                stride=2,
                padding=0,
                bias=False)
            self.norm1_name, stem_norm = build_norm_layer(
                self.norm_cfg, stem_channels, postfix=1)
            self.add_module(self.norm1_name, stem_norm)
            self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
class SOTResLayer(nn.Sequential):
    """SOTResLayer to build ResNet style backbone for SOT.

    Args:
        block (nn.Module): Block used to build SOTResLayer.
        inplanes (int): Inplanes of block.
        planes (int): Planes of block.
        num_blocks (int): Number of blocks.
        stride (int): Stride of the first block. Default: 1
        dilation (int): Dilation passed to every block; it also sets the
            dilation (``dilation // 2``) of the downsample conv when > 1.
            Default: 1
        avg_down (bool): Accepted for interface compatibility with the
            ResNet stage builder; it is not used by this layer.
            Default: False
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        downsample_first (bool): Downsample at the first block or last block.
            False for Hourglass, True for ResNet. Default: True
    """

    def __init__(self,
                 block,
                 inplanes,
                 planes,
                 num_blocks,
                 stride=1,
                 dilation=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 downsample_first=True,
                 **kwargs):
        self.block = block

        # Without this default, a stage whose identity branch needs no
        # projection would hit an unbound ``downsample`` below.
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            if stride == 1 and dilation == 1:
                kernel_size = 1
                dd = 1
                padding = 0
            else:
                kernel_size = 3
                if dilation > 1:
                    # Matches SOTBottleneck, which halves the dilation of its
                    # 3x3 conv when it receives a downsample branch.
                    dd = dilation // 2
                    padding = dd
                else:
                    dd = 1
                    padding = 0
            downsample = nn.Sequential(
                build_conv_layer(
                    conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dd,
                    bias=False),
                build_norm_layer(norm_cfg, planes * block.expansion)[1])

        layers = []
        if downsample_first:
            layers.append(
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=stride,
                    dilation=dilation,
                    downsample=downsample,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
            inplanes = planes * block.expansion
            for _ in range(1, num_blocks):
                layers.append(
                    block(
                        inplanes=inplanes,
                        planes=planes,
                        stride=1,
                        dilation=dilation,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        **kwargs))

        else:  # downsample_first=False is for HourglassModule
            for _ in range(num_blocks - 1):
                layers.append(
                    block(
                        inplanes=inplanes,
                        planes=inplanes,
                        stride=1,
                        dilation=dilation,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        **kwargs))
            layers.append(
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=stride,
                    dilation=dilation,
                    downsample=downsample,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
        super(SOTResLayer, self).__init__(*layers)
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.