1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/open-mmlab-mmrazor

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Это зеркальный репозиторий, синхронизируется ежедневно с исходного репозитория.
Клонировать/Скачать
cwd_pspnet.py 2.8 КБ
Копировать Редактировать Исходные данные Просмотреть построчно История
whcao Отправлено 3 лет назад f7fbe77
norm_cfg = dict(type='BN',requires_grad=True)
student = dict(
type = 'mmseg.EncoderDecoder',
backbone=dict(
type='mmseg.ResNetV1c',
depth=18,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='mmseg.PSPHead',
in_channels=512,
in_index=3,
channels=128,
pool_scales=(1, 2, 3, 6),
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='mmseg.CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
),
auxiliary_head=dict(
type='mmseg.FCNHead',
in_channels=256,
in_index=2,
channels=64,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='mmseg.CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)
),
train_cfg=dict(),
test_cfg=dict(mode='whole')
)
teacher = dict(
type = 'mmseg.EncoderDecoder',
# init_cfg=dict(
# type='Pretrained',
# checkpoint='pspnet_r101-d8_512x1024_80k_cityscapes_20200606_112211-e1e1100f.pth'),
backbone=dict(
type='mmseg.ResNetV1c',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True
),
decode_head=dict(
type='mmseg.PSPHead',
in_channels=2048,
in_index=3,
channels=512,
pool_scales=(1, 2, 3, 6),
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='mmseg.CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
),
)
algorithm = dict(
type = 'GeneralDistill',
architecture = dict(
type = 'MMSegArchitecture',
model = student,
),
distiller = dict(
type = 'SingleTeacherDistiller',
teacher = teacher,
teacher_trainable = False,
components = [
dict(
student_module = 'decode_head.conv_seg',
teacher_module = 'decode_head.conv_seg',
losses = [
dict(
type='ChannelWiseDivergence',
name='loss_cwd_logits',
tau = 5,
loss_weight=3,
)
]
)
]
),
)

Комментарий ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://gitlife.ru/oschina-mirror/open-mmlab-mmrazor.git
git@gitlife.ru:oschina-mirror/open-mmlab-mmrazor.git
oschina-mirror
open-mmlab-mmrazor
open-mmlab-mmrazor
v0.1.0