oschina-mirror/open-mmlab-mmrazor · darts_series.py (v0.3.0)
https://gitlife.ru/oschina-mirror/open-mmlab-mmrazor.git
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks import DropPath

from ..builder import OPS
from .base import BaseOP


@OPS.register_module()
class DartsPoolBN(BaseOP):
    """Max/average pooling followed by batch normalization."""

    def __init__(self,
                 pool_type,
                 kernel_size=3,
                 norm_cfg=dict(type='BN'),
                 use_drop_path=False,
                 **kwargs):
        super(DartsPoolBN, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.norm_cfg = norm_cfg
        if pool_type == 'max':
            self.pool = nn.MaxPool2d(self.kernel_size, self.stride, 1)
        elif pool_type == 'avg':
            self.pool = nn.AvgPool2d(
                self.kernel_size, self.stride, 1, count_include_pad=False)
        else:
            raise ValueError(f'Unsupported pool_type: {pool_type}')
        self.bn = build_norm_layer(self.norm_cfg, self.out_channels)[1]
        if use_drop_path:
            self.drop_path = DropPath()
        else:
            self.drop_path = None

    def forward(self, x):
        out = self.pool(x)
        out = self.bn(out)
        if self.drop_path is not None:
            out = self.drop_path(out)
        return out


@OPS.register_module()
class DartsDilConv(BaseOP):
    """Dilated depthwise-separable convolution (ReLU-conv-BN order)."""

    def __init__(self,
                 kernel_size,
                 use_drop_path=False,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super(DartsDilConv, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.norm_cfg = norm_cfg
        self.dilation = 2
        assert self.kernel_size in [3, 5]
        assert self.stride in [1, 2]
        self.conv1 = nn.Sequential(
            nn.ReLU(),
            # Depthwise dilated conv; the padding keeps the spatial size
            # unchanged at stride 1.
            nn.Conv2d(
                self.in_channels,
                self.in_channels,
                self.kernel_size,
                self.stride, (self.kernel_size // 2) * self.dilation,
                dilation=self.dilation,
                groups=self.in_channels,
                bias=False),
            # Pointwise conv mapping in_channels to out_channels.
            nn.Conv2d(
                self.in_channels, self.out_channels, 1, stride=1, bias=False),
            # BN width matches the pointwise conv output.
            build_norm_layer(self.norm_cfg, self.out_channels)[1])
        if use_drop_path:
            self.drop_path = DropPath()
        else:
            self.drop_path = None

    def forward(self, x):
        out = self.conv1(x)
        if self.drop_path is not None:
            out = self.drop_path(out)
        return out


@OPS.register_module()
class DartsSepConv(BaseOP):
    """Separable convolution: two stacked ReLU-sepconv-BN blocks."""

    def __init__(self,
                 kernel_size,
                 use_drop_path=False,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super(DartsSepConv, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.norm_cfg = norm_cfg
        assert self.kernel_size in [3, 5]
        assert self.stride in [1, 2]
        # The first block applies the stride and keeps the channel width.
        self.conv1 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(
                self.in_channels,
                self.in_channels,
                self.kernel_size,
                self.stride,
                self.kernel_size // 2,
                groups=self.in_channels,
                bias=False),
            nn.Conv2d(
                self.in_channels, self.in_channels, 1, stride=1, bias=False),
            build_norm_layer(self.norm_cfg, self.in_channels)[1])
        # The second block runs at stride 1 and maps to out_channels; the
        # grouped conv requires out_channels to be divisible by in_channels.
        self.conv2 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                1,
                self.kernel_size // 2,
                groups=self.in_channels,
                bias=False),
            nn.Conv2d(
                self.out_channels, self.out_channels, 1, stride=1, bias=False),
            build_norm_layer(self.norm_cfg, self.out_channels)[1])
        if use_drop_path:
            self.drop_path = DropPath()
        else:
            self.drop_path = None

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        if self.drop_path is not None:
            out = self.drop_path(out)
        return out


@OPS.register_module()
class DartsSkipConnect(BaseOP):
    """Identity connection.

    When stride > 1, the feature map is reduced by a factorized pointwise
    (1x1, stride 2) convolution instead.
    """

    def __init__(self,
                 use_drop_path=False,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super(DartsSkipConnect, self).__init__(**kwargs)
        self.norm_cfg = norm_cfg
        if self.stride > 1:
            self.relu = nn.ReLU()
            # Two stride-2 1x1 convs each produce half of out_channels; the
            # second one sees the input shifted by one pixel so the two
            # branches sample complementary spatial positions.
            self.conv1 = nn.Conv2d(
                self.in_channels,
                self.out_channels // 2,
                1,
                stride=2,
                padding=0,
                bias=False)
            self.conv2 = nn.Conv2d(
                self.in_channels,
                self.out_channels // 2,
                1,
                stride=2,
                padding=0,
                bias=False)
            self.bn = build_norm_layer(self.norm_cfg, self.out_channels)[1]
        if use_drop_path:
            self.drop_path = DropPath()
        else:
            self.drop_path = None

    def forward(self, x):
        if self.stride > 1:
            x = self.relu(x)
            out = torch.cat(
                [self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
            out = self.bn(out)
            if self.drop_path is not None:
                out = self.drop_path(out)
        else:
            out = x
        return out


@OPS.register_module()
class DartsZero(BaseOP):
    """The zero op: outputs all zeros, downsampled when stride > 1."""

    def __init__(self, **kwargs):
        super(DartsZero, self).__init__(**kwargs)

    def forward(self, x):
        if self.stride == 1:
            return x.mul(0.)
        # Strided slicing downsamples first so the output shape matches the
        # other candidate ops at the same stride.
        return x[:, :, ::self.stride, ::self.stride].mul(0.)
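
Usage sketch (not part of the file above): each class registers itself in the
OPS registry, so an op can be built from a config dict. This is a minimal
example assuming the registry is importable as mmrazor.models.builder.OPS
(consistent with the file's relative import) and that BaseOP accepts
in_channels, out_channels and stride keyword arguments, as the attribute
accesses in the code suggest.

import torch

from mmrazor.models.builder import OPS

op = OPS.build(
    dict(
        type='DartsSepConv',
        kernel_size=3,
        in_channels=16,
        out_channels=16,
        stride=1))
x = torch.randn(2, 16, 32, 32)
out = op(x)
assert out.shape == x.shape  # stride 1 preserves the spatial size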
