1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/open-mmlab-mmrazor

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Это зеркальный репозиторий, синхронизируется ежедневно с исходного репозитория.
Клонировать/Скачать
split_checkpoint.py 2.2 КБ
Копировать Редактировать Исходные данные Просмотреть построчно История
pppppM Отправлено 3 лет назад cb5cb6d
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from mmcv import Config
from mmcv.runner import load_checkpoint, save_checkpoint
from mmrazor.models import build_algorithm
from mmrazor.models.pruners.utils import SwitchableBatchNorm2d
def parse_args():
parser = argparse.ArgumentParser(description='Split a slimmable trained'
'model checkpoint')
parser.add_argument('config', type=str, help='path of train config file')
parser.add_argument('checkpoint', type=str, help='checkpoint path')
parser.add_argument(
'--channel-cfgs',
nargs='+',
help='The path of the channel configs. '
'The order should be the same as that of train.')
parser.add_argument('--output-dir', type=str, default='')
args = parser.parse_args()
return args
def convert_bn(module, bn_ind):
def traverse(module):
for name, child in module.named_children():
if isinstance(child, SwitchableBatchNorm2d):
setattr(module, name, child.bns[bn_ind])
else:
traverse(child)
traverse(module)
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
cfg.merge_from_dict(dict(algorithm=dict(channel_cfg=args.channel_cfgs)))
for i, channel_cfg in enumerate(args.channel_cfgs):
algorithm = build_algorithm(cfg.algorithm)
load_checkpoint(algorithm, args.checkpoint, map_location='cpu')
convert_bn(algorithm, i)
for module in algorithm.modules():
if hasattr(module, 'out_mask'):
del module.out_mask
if hasattr(module, 'in_mask'):
del module.in_mask
assert algorithm.with_pruner, \
'The algorithm should has attr pruner. Please check your ' \
'config file.'
algorithm.pruner.deploy_subnet(algorithm.architecture,
algorithm.channel_cfg[i])
filename = osp.join(args.output_dir, f'checkpoint_{i + 1}.pth')
save_checkpoint(algorithm, filename)
print(f'Successfully split the original checkpoint `{args.checkpoint}` to '
f'{len(args.channel_cfgs)} different checkpoints.')
if __name__ == '__main__':
main()

Комментарий ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://gitlife.ru/oschina-mirror/open-mmlab-mmrazor.git
git@gitlife.ru:oschina-mirror/open-mmlab-mmrazor.git
oschina-mirror
open-mmlab-mmrazor
open-mmlab-mmrazor
v0.3.0