Shortcuts

Source code for mmtrack.models.backbones.mixformer_backbone

# Copyright (c) OpenMMLab. All rights reserved.
import logging
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import BaseModule
from mmdet.models.builder import BACKBONES

from mmtrack.core.utils.misc import ntuple

# Module-level helper built from ``ntuple(2)``. ``ConvEmbed`` runs its
# ``patch_size`` argument through it and uses the result as the conv kernel
# size.
# NOTE(review): presumably expands a scalar ``x`` to ``(x, x)`` -- confirm
# against ``mmtrack.core.utils.misc.ntuple``.
to_2tuple = ntuple(2)


class LayerNormAutofp32(nn.LayerNorm):
    """LayerNorm whose normalisation is always computed in float32.

    The input is cast to ``torch.float32`` before being handed to
    ``nn.LayerNorm.forward`` and the result is cast back to the dtype the
    caller passed in, so fp16 inputs yield fp16 outputs.
    """

    def forward(self, x: torch.Tensor):
        # Remember the caller's dtype so it can be restored afterwards.
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)


class QuickGELU(nn.Module):
    """Sigmoid-based approximation of the GELU activation.

    Computes ``x * sigmoid(1.702 * x)``. GELU itself is introduced in
    `Gaussian Error Linear Units <https://arxiv.org/abs/1606.08415v4>`_.
    """

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x.mul(1.702))
        return x.mul(gate)


class MixedAttentionModule(nn.Module):
    """Mixed Attention Module (MAM) proposed in MixFormer.

    It is the core design for simultaneous feature extraction
    and target information integration. Please refer to
    `MixFormer <https://arxiv.org/abs/2203.11082>`_ for more details.

    The attention is asymmetric: template and online-template queries only
    attend to template/online-template keys and values, whereas search
    queries attend to the keys and values of all three inputs.

    Args:
        dim_in (int): Input dimension of this module.
        dim_out (int): Output dimension of this module.
        num_heads (int): Number of heads in multi-head attention mechanism.
        qkv_bias (bool): Add bias when projecting to qkv tokens.
            Default: False
        attn_drop (float): A Dropout layer on attn_output_weight.
            Default: 0.0
        proj_drop (float): A Dropout layer after attention operation.
            Default: 0.0
        method (str): Method for building the projection layer. Choices are
            ['dw_bn', 'avg', 'identity']. Default: 'dw_bn'. With ``'avg'``
            the query projection is built as ``'identity'``.
        kernel_size (int): Kernel size of the projection layer. Default: 3
        stride_kv (int): Stride of the key/value projection layer. Default: 1
        stride_q (int): Stride of the query projection layer. Default: 1
        padding_kv (int): Padding number of the key/value projection layer.
            Default: 1
        padding_q (int): Padding number of the query projection layer.
            Default: 1
        norm_cfg (dict): Norm layer config. Default: dict(type='BN')
    """

    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 method='dw_bn',
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        # Attention scores are scaled by the full output width, not by the
        # per-head width.
        self.scale = dim_out**-0.5
        self.norm_cfg = norm_cfg

        # Queries are never average-pooled: 'avg' falls back to 'identity'.
        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q, stride_q,
            'identity' if method == 'avg' else method)
        self.conv_proj_k = self._build_projection(dim_in, dim_out, kernel_size,
                                                  padding_kv, stride_kv,
                                                  method)
        self.conv_proj_v = self._build_projection(dim_in, dim_out, kernel_size,
                                                  padding_kv, stride_kv,
                                                  method)

        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self, dim_in, dim_out, kernel_size, padding, stride,
                          method):
        """Build qkv projection according to `method` argument.

        Args:
            dim_in (int): input dimension
            dim_out (int): output dimension (not used by any of the
                currently supported methods)
            kernel_size (int): kernel size of convolution
            padding (int): padding of convolution
            stride (int): stride of convolution
            method (str): description of projection method.
                ``'dw_bn'``: Apply a depth-wise convolution layer followed
                    by the norm layer built from ``self.norm_cfg``.
                ``'avg'``: Apply an avgpool2d on the input.
                ``'identity'``: No transformation on the input.

        Returns:
            nn.Sequential | None: the projection, or ``None`` for
            ``'identity'``.

        Raises:
            ValueError: if ``method`` is not one of the supported choices.
        """
        if method == 'dw_bn':
            proj = nn.Sequential(
                OrderedDict([
                    ('conv',
                     nn.Conv2d(
                         dim_in,
                         dim_in,
                         kernel_size=kernel_size,
                         padding=padding,
                         stride=stride,
                         bias=False,
                         groups=dim_in)),
                    # ``build_norm_layer`` result is used directly as the
                    # (name, module) entry of the OrderedDict.
                    build_norm_layer(self.norm_cfg, dim_in),
                ]))
        elif method == 'avg':
            proj = nn.Sequential(
                OrderedDict([
                    ('avg',
                     nn.AvgPool2d(
                         kernel_size=kernel_size,
                         padding=padding,
                         stride=stride,
                         ceil_mode=True)),
                ]))
        elif method == 'identity':
            proj = None
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj

    @staticmethod
    def _to_map(tokens, height, width):
        """Reshape ``(B, H*W, C)`` tokens to a ``(B, C, H, W)`` map."""
        return rearrange(
            tokens, 'b (h w) c -> b c h w', h=height, w=width).contiguous()

    @staticmethod
    def _to_tokens(feat, proj=None):
        """Apply the optional projection, then flatten to ``(B, H*W, C)``."""
        if proj is not None:
            feat = proj(feat)
        return rearrange(feat, 'b c h w -> b (h w) c').contiguous()

    @staticmethod
    def _to_merged_tokens(feat, proj=None):
        """Apply the optional projection, then fold the batch axis into the
        token axis, giving ``(1, B*H*W, C)``."""
        if proj is not None:
            feat = proj(feat)
        return rearrange(
            feat, 'b c h w -> (b h w) c').contiguous().unsqueeze(0)

    def _split_heads(self, q, k, v):
        """Apply the linear q/k/v projections and split the channel axis
        into heads: ``(B, T, H*D) -> (B, H, T, D)``."""
        q = rearrange(
            self.proj_q(q), 'b t (h d) -> b h t d',
            h=self.num_heads).contiguous()
        k = rearrange(
            self.proj_k(k), 'b t (h d) -> b h t d',
            h=self.num_heads).contiguous()
        v = rearrange(
            self.proj_v(v), 'b t (h d) -> b h t d',
            h=self.num_heads).contiguous()
        return q, k, v

    def _attend(self, q, k, v):
        """Scaled dot-product attention on head-split tensors; the heads are
        merged back so the result is ``(B, T_q, H*D)``."""
        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)
        out = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
        return rearrange(out, 'b h t d -> b t (h d)').contiguous()

    def _forward_conv_with_sizes(self, x, t_h, t_w, s_h, s_w):
        """Project the concatenated tokens to q/k/v and report how many
        leading tokens belong to the two templates.

        Returns:
            tuple: ``(q, k, v, q_mt_len, kv_mt_len)`` where ``q_mt_len`` /
            ``kv_mt_len`` are the number of template + online-template
            tokens at the start of ``q`` and of ``k``/``v`` respectively.
        """
        template, online_template, search = torch.split(
            x, [t_h * t_w, t_h * t_w, s_h * s_w], dim=1)
        # Each projection is applied in the order template, online
        # template, search.
        feature_maps = (
            self._to_map(template, t_h, t_w),
            self._to_map(online_template, t_h, t_w),
            self._to_map(search, s_h, s_w),
        )

        t_q, ot_q, s_q = [
            self._to_tokens(feat, self.conv_proj_q) for feat in feature_maps
        ]
        q = torch.cat([t_q, ot_q, s_q], dim=1)

        t_k, ot_k, s_k = [
            self._to_tokens(feat, self.conv_proj_k) for feat in feature_maps
        ]
        k = torch.cat([t_k, ot_k, s_k], dim=1)

        t_v, ot_v, s_v = [
            self._to_tokens(feat, self.conv_proj_v) for feat in feature_maps
        ]
        v = torch.cat([t_v, ot_v, s_v], dim=1)

        # k and v projections are built with identical settings, so the
        # template token count measured on k also applies to v.
        q_mt_len = t_q.shape[1] + ot_q.shape[1]
        kv_mt_len = t_k.shape[1] + ot_k.shape[1]
        return q, k, v, q_mt_len, kv_mt_len

    def forward_conv(self, x, t_h, t_w, s_h, s_w):
        """Projecting the input to qkv tokens.

        Args:
            x (Tensor): concatenated template, online template and search
                tokens, shape (B, 2*t_h*t_w+s_h*s_w, C)
            t_h (int): template feature height
            t_w (int): template feature width
            s_h (int): search region feature height
            s_w (int): search region feature width

        Returns:
            tuple[Tensor]: ``(q, k, v)``, each the concatenation of the
            projected template, online template and search tokens.
        """
        q, k, v, _, _ = self._forward_conv_with_sizes(x, t_h, t_w, s_h, s_w)
        return q, k, v

    def forward_conv_test(self, x, s_h, s_w):
        """Project search tokens to q/k/v at test time.

        The search keys/values are appended to the template keys/values
        cached on the module by :meth:`set_online`, which therefore has to
        be called first.

        Args:
            x (Tensor): search tokens of shape (B, s_h*s_w, C)
            s_h (int): search region feature height
            s_w (int): search region feature width
        """
        search = self._to_map(x, s_h, s_w)

        q = self._to_tokens(search, self.conv_proj_q)

        k = self._to_tokens(search, self.conv_proj_k)
        k = torch.cat([self.t_k, self.ot_k, k], dim=1)

        v = self._to_tokens(search, self.conv_proj_v)
        v = torch.cat([self.t_v, self.ot_v, v], dim=1)

        return q, k, v

    def forward(self, x, t_h, t_w, s_h, s_w):
        """Asymmetric mixed attention.

        Args:
            x (Tensor): concatenated feature of template and search,
                shape (B, 2*t_h*t_w+s_h*s_w, C)
            t_h (int): template feature height
            t_w (int): template feature width
            s_h (int): search region feature height
            s_w (int): search region feature width

        Returns:
            Tensor: attended tokens, template part first and search part
            last, after the output projection and dropout.
        """
        # Always run the projection step: it already copes with ``None``
        # (identity) projections, so no guard is needed here.
        q, k, v, q_mt_len, kv_mt_len = self._forward_conv_with_sizes(
            x, t_h, t_w, s_h, s_w)
        q, k, v = self._split_heads(q, k, v)

        # Split points come from the projected tensors themselves, so any
        # stride / kernel / padding and non-square inputs are handled.
        q_mt, q_s = q[:, :, :q_mt_len], q[:, :, q_mt_len:]
        k_mt = k[:, :, :kv_mt_len]
        v_mt = v[:, :, :kv_mt_len]

        # template attention: templates only see template keys/values
        x_mt = self._attend(q_mt, k_mt, v_mt)
        # search region attention: search sees every key/value
        x_s = self._attend(q_s, k, v)

        x = torch.cat([x_mt, x_s], dim=1)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def forward_test(self, x, s_h, s_w):
        """Attention for the search region only (test pipeline).

        Uses the template keys/values cached by :meth:`set_online`.

        Args:
            x (Tensor): search tokens of shape (B, s_h*s_w, C)
            s_h (int): search region feature height
            s_w (int): search region feature width
        """
        q_s, k, v = self.forward_conv_test(x, s_h, s_w)
        q_s, k, v = self._split_heads(q_s, k, v)

        x = self._attend(q_s, k, v)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def set_online(self, x, t_h, t_w):
        """Run attention among template tokens and cache their keys/values.

        The projected keys and values are stored as ``self.t_k``,
        ``self.ot_k``, ``self.t_v`` and ``self.ot_v`` for later use by
        :meth:`forward_conv_test`.

        Args:
            x (Tensor): template tokens followed by the online-template
                tokens; the first ``t_h*t_w`` tokens are the template and
                everything after is reshaped as ``(b h w) c`` online
                templates after dropping the leading batch axis.
            t_h (int): template feature height
            t_w (int): template feature width
        """
        template = self._to_map(x[:, :t_h * t_w], t_h, t_w)
        online_template = rearrange(
            x[:, t_h * t_w:].squeeze(0), '(b h w) c -> b c h w', h=t_h,
            w=t_w).contiguous()

        t_q = self._to_tokens(template, self.conv_proj_q)
        ot_q = self._to_merged_tokens(online_template, self.conv_proj_q)
        q = torch.cat([t_q, ot_q], dim=1)

        self.t_k = self._to_tokens(template, self.conv_proj_k)
        self.ot_k = self._to_merged_tokens(online_template, self.conv_proj_k)
        k = torch.cat([self.t_k, self.ot_k], dim=1)

        self.t_v = self._to_tokens(template, self.conv_proj_v)
        self.ot_v = self._to_merged_tokens(online_template, self.conv_proj_v)
        v = torch.cat([self.t_v, self.ot_v], dim=1)

        q, k, v = self._split_heads(q, k, v)
        x = self._attend(q, k, v)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class MixFormerAttentionBlock(nn.Module):
    """Pre-norm transformer block: mixed attention, FFN and skip connections.

    Args:
        dim_in (int): channel width fed to ``norm1`` and the attention.
        dim_out (int): channel width produced by the attention and used by
            ``norm2`` and the FFN.
        num_heads (int): number of attention heads.
        mlp_channel_ratio (float): FFN hidden width as a multiple of
            ``dim_out``.
        qkv_bias (bool): forwarded to ``MixedAttentionModule``.
        drop (float): dropout used for the attention output projection and
            inside the FFN.
        attn_drop (float): dropout on the attention weights.
        path_drop_probs (float): drop-path probability; values ``<= 0`` use
            ``nn.Identity`` instead of ``DropPath``.
        act_layer (nn.Module): accepted but not read by this block; the FFN
            is always built with ``act_cfg=dict(type='GELU')``.
        norm_layer (nn.Module): factory for ``norm1`` / ``norm2``.
        norm_cfg (dict): forwarded to ``MixedAttentionModule``.
        **kwargs: forwarded to ``MixedAttentionModule``.
    """

    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 mlp_channel_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 path_drop_probs=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super().__init__()
        hidden_channels = int(dim_out * mlp_channel_ratio)

        # Sub-modules are registered in the order norm1, attn, drop_path,
        # norm2, mlp.
        self.norm1 = norm_layer(dim_in)
        self.attn = MixedAttentionModule(
            dim_in,
            dim_out,
            num_heads,
            qkv_bias,
            attn_drop,
            drop,
            norm_cfg=norm_cfg,
            **kwargs)

        if path_drop_probs > 0.:
            self.drop_path = DropPath(path_drop_probs)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim_out)

        # The skip connection around the FFN is added by this block itself.
        self.mlp = FFN(
            embed_dims=dim_out,
            feedforward_channels=hidden_channels,
            num_fcs=2,
            act_cfg=dict(type='GELU'),
            ffn_drop=drop,
            add_identity=False,
        )

    def _residual_and_ffn(self, shortcut, attn_out):
        """Add the attention output onto ``shortcut``, then apply the
        normalised FFN with its own skip connection."""
        hidden = shortcut + self.drop_path(attn_out)
        return hidden + self.drop_path(self.mlp(self.norm2(hidden)))

    def forward(self, x, t_h, t_w, s_h, s_w):
        """
        Args:
            x (Tensor): concatenated feature of template and search,
                shape (B, 2*t_h*t_w+s_h*s_w, C)
            t_h (int): template feature height
            t_w (int): template feature width
            s_h (int): search region feature height
            s_w (int): search region feature width
        """
        attn_out = self.attn(self.norm1(x), t_h, t_w, s_h, s_w)
        return self._residual_and_ffn(x, attn_out)

    def forward_test(self, x, s_h, s_w):
        """Test-time pass over search tokens only, using the attention
        module's ``forward_test``."""
        attn_out = self.attn.forward_test(self.norm1(x), s_h, s_w)
        return self._residual_and_ffn(x, attn_out)

    def set_online(self, x, t_h, t_w):
        """Test-time pass over template tokens, using the attention
        module's ``set_online``."""
        attn_out = self.attn.set_online(self.norm1(x), t_h, t_w)
        return self._residual_and_ffn(x, attn_out)


class ConvEmbed(nn.Module):
    """Convolutional patch embedding with optional token normalisation.

    Args:
        patch_size (int): patch size, used as the convolution kernel size
            after being passed through ``to_2tuple``
        in_chans (int): number of input channels
        embed_dim (int): embedding dimension
        stride (int): stride of convolution layer
        padding (int): number of padding
        norm_layer (nn.Module): factory called with ``embed_dim`` to build
            the normalization layer; a falsy value disables normalisation
    """

    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=stride,
            padding=padding)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Embed ``x`` with the convolution and return a ``(B, C, H, W)``
        map; the norm, when present, is applied on ``(B, H*W, C)`` tokens.
        """
        x = self.proj(x)
        _, _, height, width = x.shape

        # Normalisation works on the channel-last token layout.
        tokens = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.norm:
            tokens = self.norm(tokens)

        return rearrange(
            tokens, 'b (h w) c -> b c h w', h=height, w=width).contiguous()


class ConvVisionTransformerLayer(BaseModule):
    """One stage of ConvVisionTransformer: a conv patch embedding followed
    by a stack of mixed-attention blocks.

    Args:
        patch_size (int): patch size of ConvEmbed module
        patch_stride (int): patch stride of ConvEmbed module
        patch_padding (int): padding of ConvEmbed module
        in_chans (int): number of input channels
        embed_dim (int): embedding dimension
        depth (int): number of attention blocks
        num_heads (int): number of heads in multi-head attention operation
        mlp_channel_ratio (int): hidden dim ratio of FFN
        qkv_bias (bool): qkv bias
        drop_rate (float): drop rate after patch embed
        attn_drop_rate (float): drop rate in attention
        path_drop_probs (float): maximum drop-path rate; block ``j`` gets
            the ``j``-th value of a linear ramp from 0 to this value
        act_layer (nn.Module): activation passed on to every attention block
        norm_layer (nn.Module): normalization layer used in the patch
            embedding and in the attention blocks
        init (str): weight init method; ``'xavier'`` selects xavier uniform,
            anything else truncated normal
        norm_cfg (dict): normalization layer config, passed unchanged to
            every attention block. Note that the default here is ``False``.
        **kwargs: passed on to every attention block.
    """

    def __init__(self,
                 patch_size=16,
                 patch_stride=16,
                 patch_padding=0,
                 in_chans=3,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_channel_ratio=4.,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 path_drop_probs=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 norm_cfg=False,
                 **kwargs):
        super().__init__()
        self.init = init
        self.embed_dim = embed_dim
        self.num_features = embed_dim

        self.patch_embed = ConvEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            stride=patch_stride,
            padding=patch_padding,
            embed_dim=embed_dim,
            norm_layer=norm_layer)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: one drop-path rate per block
        drop_path_schedule = torch.linspace(0, path_drop_probs,
                                            depth).tolist()

        self.blocks = nn.ModuleList([
            MixFormerAttentionBlock(
                dim_in=embed_dim,
                dim_out=embed_dim,
                num_heads=num_heads,
                mlp_channel_ratio=mlp_channel_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                path_drop_probs=block_drop_prob,
                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_cfg=norm_cfg,
                **kwargs) for block_drop_prob in drop_path_schedule
        ])

    def init_weights(self):
        """Apply the initialiser selected by ``self.init`` to all
        sub-modules."""
        if self.init == 'xavier':
            initializer = self._init_weights_xavier
        else:
            initializer = self._init_weights_trunc_normal
        self.apply(initializer)

    @staticmethod
    def _zero_linear_bias(m):
        """Zero the bias of a Linear layer when it has one."""
        if m.bias is not None:
            logging.info('=> init bias of Linear to zeros')
            nn.init.constant_(m.bias, 0)

    @staticmethod
    def _reset_norm(m):
        """Reset LayerNorm / BatchNorm2d to bias 0 and weight 1; other
        module types are left untouched."""
        if isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _init_weights_trunc_normal(self, m):
        """Truncated-normal (std 0.02) init for Linear weights."""
        if not isinstance(m, nn.Linear):
            self._reset_norm(m)
            return
        logging.info('=> init weight of Linear from trunc norm')
        trunc_normal_(m.weight, std=0.02)
        self._zero_linear_bias(m)

    def _init_weights_xavier(self, m):
        """Xavier-uniform init for Linear weights."""
        if not isinstance(m, nn.Linear):
            self._reset_norm(m)
            return
        logging.info('=> init weight of Linear from xavier uniform')
        nn.init.xavier_uniform_(m.weight)
        self._zero_linear_bias(m)

    def forward(self, template, online_template, search):
        """
        Args:
            template (Tensor): template features of shape (B, C, H, W)
            online template (Tensor): online template features
                of shape (B, C, H, W)
            search (Tensor): search features of shape (B, C, H, W)

        Returns:
            tuple[Tensor]: template, online template and search feature
            maps in (B, C, H, W) layout after the attention blocks.
        """
        embedded = [
            self.patch_embed(feat)
            for feat in (template, online_template, search)
        ]
        _, _, t_H, t_W = embedded[0].size()
        _, _, s_H, s_W = embedded[2].size()

        # Flatten each map to tokens and join them along the token axis.
        tokens = [
            rearrange(feat, 'b c h w -> b (h w) c').contiguous()
            for feat in embedded
        ]
        x = self.pos_drop(torch.cat(tokens, dim=1))

        for blk in self.blocks:
            x = blk(x, t_H, t_W, s_H, s_W)

        template, online_template, search = torch.split(
            x, [t_H * t_W, t_H * t_W, s_H * s_W], dim=1)
        template = rearrange(
            template, 'b (h w) c -> b c h w', h=t_H, w=t_W).contiguous()
        online_template = rearrange(
            online_template, 'b (h w) c -> b c h w', h=t_H,
            w=t_W).contiguous()
        search = rearrange(
            search, 'b (h w) c -> b c h w', h=s_H, w=s_W).contiguous()

        return template, online_template, search

    def forward_test(self, search):
        """Test-time pass for the search image only; every block is run
        through its ``forward_test``."""
        search = self.patch_embed(search)
        _, _, s_H, s_W = search.size()

        x = rearrange(search, 'b c h w -> b (h w) c').contiguous()
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk.forward_test(x, s_H, s_W)

        return rearrange(x, 'b (h w) c -> b c h w', h=s_H, w=s_W)

    def set_online(self, template, online_template):
        """Test-time pass for the template images; every block is run
        through its ``set_online``.

        The online templates are folded into the token axis of a single
        batch element before being concatenated with the template tokens.
        """
        template = self.patch_embed(template)
        online_template = self.patch_embed(online_template)
        _, _, t_H, t_W = template.size()
        num_template_tokens = t_H * t_W

        template = rearrange(template, 'b c h w -> b (h w) c').contiguous()
        online_template = rearrange(
            online_template, 'b c h w -> (b h w) c').unsqueeze(0).contiguous()
        x = self.pos_drop(torch.cat([template, online_template], dim=1))

        for blk in self.blocks:
            x = blk.set_online(x, t_H, t_W)

        template = rearrange(
            x[:, :num_template_tokens], 'b (h w) c -> b c h w', h=t_H, w=t_W)
        online_template = rearrange(
            x[:, num_template_tokens:].squeeze(0),
            '(b h w) c -> b c h w',
            h=t_H,
            w=t_W)

        return template, online_template


@BACKBONES.register_module()
class ConvVisionTransformer(BaseModule):
    """Vision Transformer with support for patch or hybrid CNN input stage.

    This backbone refers to the implementation of
    `CvT: <https://arxiv.org/abs/2103.15808>`_.

    Args:
        in_chans (int): number of input channels
        act_layer (nn.Module): activate function used in FFN
        norm_layer (nn.Module): normalization layer used in attention block
        init (str): weight init method
        num_stages (int): number of backbone stages
        patch_size (List[int]): patch size of each stage
        patch_stride (List[int]): patch stride of each stage
        patch_padding (List[int]): patch padding of each stage
        dim_embed (List[int]): embedding dimension of each stage
        num_heads (List[int]): number of heads in multi-head
            attention operation of each stage
        depth (List[int]): number of attention blocks of each stage
        mlp_channel_ratio (List[int]): hidden dim ratio of FFN of each stage
        attn_drop_rate (List[float]): attn drop rate of each stage
        drop_rate (List[float]): drop rate of each stage
        path_drop_probs (List[float]): drop path of each stage
        qkv_bias (List[bool]): qkv bias of each stage
        qkv_proj_method (List[str]): qkv project method of each stage
        kernel_qkv (List[int]): kernel size for qkv projection of each stage
        padding_kv/q (List[int]): padding size for kv/q projection
            of each stage
        stride_kv/q (List[int]): stride for kv/q project of each stage
        norm_cfg (dict): normalization layer config
    """

    def __init__(self,
                 in_chans=3,
                 act_layer=QuickGELU,
                 norm_layer=partial(LayerNormAutofp32, eps=1e-5),
                 init='trunc_norm',
                 num_stages=3,
                 patch_size=[7, 3, 3],
                 patch_stride=[4, 2, 2],
                 patch_padding=[2, 1, 1],
                 dim_embed=[64, 192, 384],
                 num_heads=[1, 3, 6],
                 depth=[1, 4, 16],
                 mlp_channel_ratio=[4, 4, 4],
                 attn_drop_rate=[0.0, 0.0, 0.0],
                 drop_rate=[0.0, 0.0, 0.0],
                 path_drop_probs=[0.0, 0.0, 0.1],
                 qkv_bias=[True, True, True],
                 qkv_proj_method=['dw_bn', 'dw_bn', 'dw_bn'],
                 kernel_qkv=[3, 3, 3],
                 padding_kv=[1, 1, 1],
                 stride_kv=[2, 2, 2],
                 padding_q=[1, 1, 1],
                 stride_q=[1, 1, 1],
                 norm_cfg=dict(type='BN', requires_grad=False)):
        super().__init__()
        self.num_stages = num_stages

        # Stages are registered as attributes ``stage0``, ``stage1``, ...
        # Each stage consumes the embedding width of the previous one.
        for i in range(self.num_stages):
            kwargs = {
                'patch_size': patch_size[i],
                'patch_stride': patch_stride[i],
                'patch_padding': patch_padding[i],
                'embed_dim': dim_embed[i],
                'depth': depth[i],
                'num_heads': num_heads[i],
                'mlp_channel_ratio': mlp_channel_ratio[i],
                'qkv_bias': qkv_bias[i],
                'drop_rate': drop_rate[i],
                'attn_drop_rate': attn_drop_rate[i],
                'path_drop_probs': path_drop_probs[i],
                'method': qkv_proj_method[i],
                'kernel_size': kernel_qkv[i],
                'padding_q': padding_q[i],
                'padding_kv': padding_kv[i],
                'stride_kv': stride_kv[i],
                'stride_q': stride_q[i],
                'norm_cfg': norm_cfg,
            }

            stage = ConvVisionTransformerLayer(
                in_chans=in_chans,
                init=init,
                act_layer=act_layer,
                norm_layer=norm_layer,
                **kwargs)
            setattr(self, f'stage{i}', stage)

            in_chans = dim_embed[i]

        # ``norm`` and ``head`` are created here but are not used by any of
        # the forward methods of this class.
        last_dim = dim_embed[-1]
        self.norm = norm_layer(last_dim)
        self.head = nn.Linear(last_dim, 1000)

    def forward(self, template, online_template, search):
        """Forward-pass method in train pipeline.

        Args:
            template (Tensor): template images of shape (B, C, H, W)
            online template (Tensor): online template images
                of shape (B, C, H, W)
            search (Tensor): search images of shape (B, C, H, W)

        Returns:
            tuple[Tensor]: template and search features of the last stage;
            the online template features are not returned.
        """
        for i in range(self.num_stages):
            template, online_template, search = getattr(self, f'stage{i}')(
                template, online_template, search)

        return template, search

    def forward_test(self, search):
        """Forward-pass method for search image in test pipeline. The model
        forwarding strategies are different between train and test. In test
        pipeline, we call ``forward_test()`` method which only takes in
        search image when tracker is tracking current frame. This approach
        reduces computational overhead and thus increases tracking speed.

        Args:
            search (Tensor): search images of shape (B, C, H, W)

        Returns:
            tuple[Tensor]: the template features cached by
            :meth:`set_online` and the search features of the last stage.

        Raises:
            AttributeError: if :meth:`set_online` has not been called yet.
        """
        # Fail early with a clear message instead of an obscure missing
        # attribute error later on; the exception type stays AttributeError.
        template = getattr(self, 'template', None)
        if template is None:
            raise AttributeError(
                '`set_online()` must be called before `forward_test()` '
                'so that the template features are available.')

        for i in range(self.num_stages):
            search = getattr(self, f'stage{i}').forward_test(search)
        return template, search

    def set_online(self, template, online_template):
        """Forward-pass method for template image in test pipeline. The
        model forwarding strategies are different between train and test.
        In test pipeline, we call ``set_online()`` method which only takes
        in template images when tracker is initialized or is updating
        online template. This approach reduces computational overhead and
        thus increases tracking speed.

        The template features of the last stage are stored in
        ``self.template``; nothing is returned.

        Args:
            template (Tensor): template images of shape (B, C, H, W)
            online template (Tensor): online template images
                of shape (B, C, H, W)
        """
        for i in range(self.num_stages):
            template, online_template = getattr(self, f'stage{i}').set_online(
                template, online_template)
        self.template = template
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.