Слияние кода завершено, страница обновится автоматически
# Copyright (c) OpenMMLab. All rights reserved.
try:
import mmdet
except (ImportError, ModuleNotFoundError):
mmdet = None
if mmdet is None:
raise RuntimeError('mmdet is not installed')
import argparse
import os
import os.path as osp
import time
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import init_dist, load_checkpoint, wrap_fp16_model
from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmdet.utils import collect_env, get_root_logger
from mmrazor.core import build_searcher
from mmrazor.models import build_algorithm
from mmrazor.utils import setup_multi_processes
def parse_args():
parser = argparse.ArgumentParser(
description='MMClsArchitecture search subnet')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--work-dir',
help='working direction is '
'to save search result and log')
parser.add_argument('--resume-from', help='resume searching')
parser.add_argument(
'--fuse-conv-bn',
action='store_true',
help='Whether to fuse conv and bn, this will slightly increase'
'the inference speed')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# set multi-process settings
setup_multi_processes(cfg)
# import modules from string list.
if cfg.get('custom_imports', None):
from mmcv.utils import import_modules_from_strings
import_modules_from_strings(**cfg['custom_imports'])
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
model_cfg = cfg.algorithm.architecture.model
model_cfg.pretrained = None
if model_cfg.get('neck'):
if isinstance(model_cfg.neck, list):
for neck_cfg in model_cfg.neck:
if neck_cfg.get('rfp_backbone'):
if neck_cfg.rfp_backbone.get('pretrained'):
neck_cfg.rfp_backbone.pretrained = None
elif model_cfg.neck.get('rfp_backbone'):
if model_cfg.neck.rfp_backbone.get('pretrained'):
model_cfg.neck.rfp_backbone.pretrained = None
# in case the test dataset is concatenated
samples_per_gpu = 1
if isinstance(cfg.data.test, dict):
cfg.data.test.test_mode = True
samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
if samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg.data.test.pipeline = replace_ImageToTensor(
cfg.data.test.pipeline)
elif isinstance(cfg.data.test, list):
for ds_cfg in cfg.data.test:
ds_cfg.test_mode = True
samples_per_gpu = max(
[ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
if samples_per_gpu > 1:
for ds_cfg in cfg.data.test:
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
# build the dataloader
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
# build the model and load checkpoint
model_cfg.train_cfg = None
algorithm = build_algorithm(cfg.algorithm)
model = algorithm.architecture.model
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
checkpoint = load_checkpoint(
algorithm, args.checkpoint, map_location='cpu')
if args.fuse_conv_bn:
model = fuse_conv_bn(model)
# old versions did not save class info in checkpoints, this walkaround is
# for backward compatibility
if 'CLASSES' in checkpoint.get('meta', {}):
model.CLASSES = checkpoint['meta']['CLASSES']
else:
model.CLASSES = dataset.CLASSES
if not distributed:
algorithm = MMDataParallel(algorithm, device_ids=[0])
test_fn = single_gpu_test
else:
algorithm = MMDistributedDataParallel(
algorithm.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
test_fn = multi_gpu_test
logger.info('build search...')
searcher = build_searcher(
cfg.searcher,
default_args=dict(
algorithm=algorithm,
dataloader=data_loader,
test_fn=test_fn,
work_dir=cfg.work_dir,
logger=logger,
resume_from=args.resume_from))
logger.info('start search...')
searcher.search()
if __name__ == '__main__':
main()
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Комментарий ( 0 )