1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/open-mmlab-mmrazor

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Клонировать/Скачать
switchable_bn.py 1.4 КБ
Копировать Редактировать Исходные данные Просмотреть построчно История
whcao Отправлено 3 лет назад da9c7a3
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
class SwitchableBatchNorm2d(nn.Module):
"""Employs independent batch normalization for different switches in a
slimmable network.
To train slimmable networks, ``SwitchableBatchNorm2d`` privatizes all
batch normalization layers for each switch in a slimmable network.
Compared with the naive training approach, it solves the problem of feature
aggregation inconsistency between different switches by independently
normalizing the feature mean and variance during testing.
Args:
max_num_features (int): The maximum ``num_features`` among BatchNorm2d
in all the switches.
num_bns (int): The number of different switches in the slimmable
networks.
"""
def __init__(self, max_num_features, num_bns):
super(SwitchableBatchNorm2d, self).__init__()
self.max_num_features = max_num_features
# number of BatchNorm2d in a SwitchableBatchNorm2d
self.num_bns = num_bns
bns = []
for _ in range(num_bns):
bns.append(nn.BatchNorm2d(max_num_features))
self.bns = nn.ModuleList(bns)
# When switching bn we should switch index simultaneously
self.index = 0
def forward(self, input):
"""Forward computation according to the current switch of the slimmable
networks."""
return self.bns[self.index](input)

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://gitlife.ru/oschina-mirror/open-mmlab-mmrazor.git
git@gitlife.ru:oschina-mirror/open-mmlab-mmrazor.git
oschina-mirror
open-mmlab-mmrazor
open-mmlab-mmrazor
v0.3.0