1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/open-mmlab-mmrazor

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Это зеркальный репозиторий, синхронизируется ежедневно с исходного репозитория.
Клонировать/Скачать
test_pruner.py 2.7 КБ
Копировать Редактировать Исходные данные Просмотреть построчно История
whcao Отправлено 3 лет назад f7fbe77
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import pytest
import torch
from mmrazor.models.builder import ARCHITECTURES, PRUNERS
def test_ratio_pruner():
model_cfg = dict(
type='mmcls.ImageClassifier',
backbone=dict(
type='mmcls.ResNet',
depth=18,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='mmcls.GlobalAveragePooling'),
head=dict(
type='mmcls.LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
architecture_cfg = dict(
type='MMClsArchitecture',
model=model_cfg,
)
pruner_cfg = dict(
type='RatioPruner',
ratios=[1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0])
imgs = torch.randn(16, 3, 224, 224)
label = torch.randint(0, 1000, (16, ))
architecture = ARCHITECTURES.build(architecture_cfg)
pruner = PRUNERS.build(pruner_cfg)
pruner.prepare_from_supernet(architecture)
assert hasattr(pruner, 'channel_spaces')
# test set_min_channel
pruner_cfg_ = deepcopy(pruner_cfg)
pruner_cfg_['ratios'].insert(0, 0)
pruner_ = PRUNERS.build(pruner_cfg_)
architecture_ = ARCHITECTURES.build(architecture_cfg)
pruner_.prepare_from_supernet(architecture_)
with pytest.raises(AssertionError):
# Output channels should be a positive integer not zero
pruner_.set_min_channel()
# test set_max_channel
pruner.set_max_channel()
for name, module in architecture.model.named_modules():
if hasattr(module, 'in_mask'):
assert module.in_mask.sum() == module.in_mask.numel()
if hasattr(module, 'out_mask'):
assert module.out_mask.sum() == module.out_mask.numel()
# test channel bins
pruner.set_min_channel()
channel_bins_dict = pruner.get_max_channel_bins(max_channel_bins=4)
pruner.set_channel_bins(channel_bins_dict, 4)
for name, module in architecture.model.named_modules():
if hasattr(module, 'in_mask'):
assert module.in_mask.sum() == module.in_mask.numel()
if hasattr(module, 'out_mask'):
assert module.out_mask.sum() == module.out_mask.numel()
# test making groups logic
subnet_dict = pruner.sample_subnet()
assert isinstance(subnet_dict, dict)
pruner.set_subnet(subnet_dict)
subnet_dict = pruner.export_subnet()
assert isinstance(subnet_dict, dict)
pruner.deploy_subnet(architecture, subnet_dict)
losses = architecture(imgs, return_loss=True, gt_label=label)
assert losses['loss'].item() > 0

Комментарий ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://gitlife.ru/oschina-mirror/open-mmlab-mmrazor.git
git@gitlife.ru:oschina-mirror/open-mmlab-mmrazor.git
oschina-mirror
open-mmlab-mmrazor
open-mmlab-mmrazor
v0.1.0