Слияние кода завершено, страница обновится автоматически
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import time
import warnings
import mmcv
import torch
from mmcls.apis import multi_gpu_test, single_gpu_test
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import collect_env, get_root_logger
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import init_dist, load_checkpoint
from mmrazor.core import build_searcher
from mmrazor.models import build_algorithm
from mmrazor.utils import setup_multi_processes
# TODO import `wrap_fp16_model` from mmcv and delete them from mmcls
try:
from mmcv.runner import wrap_fp16_model
except ImportError:
warnings.warn('wrap_fp16_model from mmcls will be deprecated.'
'Please install mmcv>=1.1.4.')
from mmcls.core import wrap_fp16_model
def parse_args():
parser = argparse.ArgumentParser(
description='MMClsArchitecture search subnet')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--work-dir',
help='working direction is '
'to save search result and log')
parser.add_argument(
'--resume-from', type=str, help='the checkpoint file to resume from')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--device',
choices=['cpu', 'cuda'],
default='cuda',
help='device used for testing')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def main():
args = parse_args()
cfg = mmcv.Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# set multi-process settings
setup_multi_processes(cfg)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
cfg.model.pretrained = None
cfg.data.test.test_mode = True
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
# build the dataloader
dataset = build_dataset(cfg.data.test)
# the extra round_up data will be removed during gpu/cpu collect
data_loader = build_dataloader(
dataset,
samples_per_gpu=cfg.data.samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False,
round_up=True)
# build the algorithm and load checkpoint
algorithm = build_algorithm(cfg.algorithm)
model = algorithm.architecture.model
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
checkpoint = load_checkpoint(
algorithm, args.checkpoint, map_location='cpu')
if 'CLASSES' in checkpoint.get('meta', {}):
CLASSES = checkpoint['meta']['CLASSES']
else:
from mmcls.datasets import ImageNet
warnings.simplefilter('once')
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use imagenet by default.')
CLASSES = ImageNet.CLASSES
if not distributed:
if args.device == 'cpu':
algorithm = algorithm.cpu()
else:
algorithm = MMDataParallel(algorithm, device_ids=[0])
model.CLASSES = CLASSES
test_fn = single_gpu_test
else:
algorithm = MMDistributedDataParallel(
algorithm.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
test_fn = multi_gpu_test
logger.info('build search...')
searcher = build_searcher(
cfg.searcher,
default_args=dict(
algorithm=algorithm,
dataloader=data_loader,
test_fn=test_fn,
work_dir=cfg.work_dir,
logger=logger,
resume_from=args.resume_from))
logger.info('start search...')
searcher.search()
if __name__ == '__main__':
main()
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Комментарий ( 0 )