Слияние кода завершено, страница обновится автоматически
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcls.models.utils import SELayer
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from ..builder import OPS
from .base import BaseOP
@OPS.register_module()
class MBBlock(BaseOP):
"""Mobilenet block for Searchable backbone.
Args:
kernel_size (int): Size of the convolving kernel.
expand_ratio (int): The input channels' expand factor of the depthwise
convolution.
se_cfg (dict, optional): Config dict for se layer. Defaults to None,
which means no se layer.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
drop_path_rate (float): stochastic depth rate. Defaults to 0.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
Returns:
Tensor: The output tensor.
"""
def __init__(self,
kernel_size,
expand_ratio,
se_cfg=None,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
drop_path_rate=0.,
with_cp=False,
**kwargs):
super(MBBlock, self).__init__(**kwargs)
self.with_res_shortcut = (
self.stride == 1 and self.in_channels == self.out_channels)
assert self.stride in [1, 2]
self.kernel_size = kernel_size
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.with_cp = with_cp
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.with_se = se_cfg is not None
self.mid_channels = self.in_channels * expand_ratio
self.with_expand_conv = (self.mid_channels != self.in_channels)
if self.with_se:
assert isinstance(se_cfg, dict)
if self.with_expand_conv:
self.expand_conv = ConvModule(
in_channels=self.in_channels,
out_channels=self.mid_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.depthwise_conv = ConvModule(
in_channels=self.mid_channels,
out_channels=self.mid_channels,
kernel_size=kernel_size,
stride=self.stride,
padding=kernel_size // 2,
groups=self.mid_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if self.with_se:
self.se = SELayer(self.mid_channels, **se_cfg)
self.linear_conv = ConvModule(
in_channels=self.mid_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
"""
def _inner_forward(x):
out = x
if self.with_expand_conv:
out = self.expand_conv(out)
out = self.depthwise_conv(out)
if self.with_se:
out = self.se(out)
out = self.linear_conv(out)
if self.with_res_shortcut:
return x + self.drop_path(out)
else:
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )