Shortcuts

Source code for mmtrack.models.roi_heads.bbox_heads.selsa_bbox_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmdet.models import HEADS, ConvFCBBoxHead

from mmtrack.models import build_aggregator


@HEADS.register_module()
class SelsaBBoxHead(ConvFCBBoxHead):
    """Selsa bbox head.

    This module is proposed in
    "Sequence Level Semantics Aggregation for Video Object Detection".
    `SELSA <https://arxiv.org/abs/1907.06390>`_.

    Compared with :class:`ConvFCBBoxHead`, this head builds one aggregator
    per shared fc layer. After each shared fc, the key-frame proposal
    features receive a residual update ``x + aggregator(x, ref_x)`` computed
    from the key-frame and reference-frame proposal features.

    Args:
        aggregator (dict): Configuration of aggregator. The same config is
            used to build a separate aggregator for every shared fc layer.
    """

    def __init__(self, aggregator, *args, **kwargs):
        super(SelsaBBoxHead, self).__init__(*args, **kwargs)
        # One independently-built aggregator per shared fc layer; index i is
        # paired with self.shared_fcs[i] in `forward`.
        self.aggregator = nn.ModuleList([
            build_aggregator(aggregator) for _ in range(self.num_shared_fcs)
        ])
        # Non-in-place ReLU. NOTE(review): presumably chosen because the
        # pre-activation `ref_x` / `x` were just consumed by the aggregator
        # and autograd may still need them; an in-place ReLU could corrupt
        # those saved tensors. TODO confirm.
        self.inplace_false_relu = nn.ReLU(inplace=False)

    def forward(self, x, ref_x):
        """Compute `cls_score` and `bbox_pred` of the key frame proposals.

        `x` and `ref_x` go through the same shared convs / fcs. After each
        shared fc, `x` is updated residually with the aggregator output
        ``self.aggregator[i](x, ref_x)``. `ref_x` only serves as aggregation
        context: it is not passed to the cls / reg branches.

        Args:
            x (Tensor): of shape [N, C, H, W]. N is the number of key frame
                proposals.
            ref_x (Tensor): of shape [M, C, H, W]. M is the number of
                reference frame proposals.

        Returns:
            tuple(cls_score, bbox_pred): The predicted score of classes and
            the predicted regression offsets. `cls_score` is None when
            `self.with_cls` is False, and `bbox_pred` is None when
            `self.with_reg` is False.
        """
        # shared part
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)
                ref_x = conv(ref_x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)
                ref_x = self.avg_pool(ref_x)

            x = x.flatten(1)
            ref_x = ref_x.flatten(1)

            # Indexing (rather than zip) keeps a length mismatch between
            # shared_fcs and aggregator a loud IndexError.
            for i, fc in enumerate(self.shared_fcs):
                x = fc(x)
                ref_x = fc(ref_x)
                # Residual aggregation of reference-frame context into x.
                x = x + self.aggregator[i](x, ref_x)
                ref_x = self.inplace_false_relu(ref_x)
                x = self.inplace_false_relu(x)

        # separate branches (cls first, then reg, as before)
        x_cls = self._forward_task_branch(x, self.cls_convs, self.cls_fcs)
        x_reg = self._forward_task_branch(x, self.reg_convs, self.reg_fcs)

        cls_score = self.fc_cls(x_cls) if self.with_cls else None
        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
        return cls_score, bbox_pred

    def _forward_task_branch(self, feat, convs, fcs):
        """Run one task-specific (cls or reg) branch on the shared feature.

        Args:
            feat (Tensor): Output of the shared part.
            convs (Iterable[nn.Module]): Branch conv layers, applied first.
            fcs (Iterable[nn.Module]): Branch fc layers, each followed by
                `self.relu`.

        Returns:
            Tensor: The branch feature, flattened to 2-D if it was not
            already.
        """
        for conv in convs:
            feat = conv(feat)
        # Only spatial (>2-D) features need pooling / flattening before fcs.
        if feat.dim() > 2:
            if self.with_avg_pool:
                feat = self.avg_pool(feat)
            feat = feat.flatten(1)
        for fc in fcs:
            feat = self.relu(fc(feat))
        return feat
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.