1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/open-mmlab-mmrazor

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Клонировать/Скачать
mobilenet_series.py 4.1 КБ
Копировать Редактировать Исходные данные Просмотреть построчно История
pppppM Отправлено 3 лет назад ef1637e
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcls.models.utils import SELayer
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from ..builder import OPS
from .base import BaseOP
@OPS.register_module()
class MBBlock(BaseOP):
"""Mobilenet block for Searchable backbone.
Args:
kernel_size (int): Size of the convolving kernel.
expand_ratio (int): The input channels' expand factor of the depthwise
convolution.
se_cfg (dict, optional): Config dict for se layer. Defaults to None,
which means no se layer.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
drop_path_rate (float): stochastic depth rate. Defaults to 0.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
Returns:
Tensor: The output tensor.
"""
def __init__(self,
kernel_size,
expand_ratio,
se_cfg=None,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
drop_path_rate=0.,
with_cp=False,
**kwargs):
super(MBBlock, self).__init__(**kwargs)
self.with_res_shortcut = (
self.stride == 1 and self.in_channels == self.out_channels)
assert self.stride in [1, 2]
self.kernel_size = kernel_size
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.with_cp = with_cp
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.with_se = se_cfg is not None
self.mid_channels = self.in_channels * expand_ratio
self.with_expand_conv = (self.mid_channels != self.in_channels)
if self.with_se:
assert isinstance(se_cfg, dict)
if self.with_expand_conv:
self.expand_conv = ConvModule(
in_channels=self.in_channels,
out_channels=self.mid_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.depthwise_conv = ConvModule(
in_channels=self.mid_channels,
out_channels=self.mid_channels,
kernel_size=kernel_size,
stride=self.stride,
padding=kernel_size // 2,
groups=self.mid_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if self.with_se:
self.se = SELayer(self.mid_channels, **se_cfg)
self.linear_conv = ConvModule(
in_channels=self.mid_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
"""
def _inner_forward(x):
out = x
if self.with_expand_conv:
out = self.expand_conv(out)
out = self.depthwise_conv(out)
if self.with_se:
out = self.se(out)
out = self.linear_conv(out)
if self.with_res_shortcut:
return x + self.drop_path(out)
else:
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://gitlife.ru/oschina-mirror/open-mmlab-mmrazor.git
git@gitlife.ru:oschina-mirror/open-mmlab-mmrazor.git
oschina-mirror
open-mmlab-mmrazor
open-mmlab-mmrazor
v0.3.0