1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/open-mmlab-mmrazor

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Это зеркальный репозиторий, синхронизируется ежедневно с исходного репозитория.
Клонировать/Скачать
train.py 7.6 КБ
Копировать Редактировать Исходные данные Просмотреть построчно История
pppppM Отправлено 3 лет назад 49f1bee
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (EpochBasedRunner, Fp16OptimizerHook, OptimizerHook,
build_runner)
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmdet.utils import get_root_logger
from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
from mmrazor.core.hooks import DistSamplerSeedHook
from mmrazor.core.optimizer import build_optimizers
from mmrazor.utils import find_latest_checkpoint
def set_random_seed(seed, deterministic=False):
"""Import `set_random_seed` function here was deprecated in v0.3 and will
be removed in v0.5.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set ``torch.backends.cudnn.deterministic``
to True and ``torch.backends.cudnn.benchmark`` to False.
Default: False.
"""
warnings.warn(
'Deprecated in v0.3 and will be removed in v0.5, '
'please import `set_random_seed` directly from `mmrazor.apis`',
category=DeprecationWarning)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def train_mmdet_model(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
"""Copy from mmdetection and modify some codes.
This is an ugly implementation, and will be deprecated in the future. In
the future, there will be only one train api and no longer distinguish
between mmclassificaiton, mmsegmentation or mmdetection.
"""
logger = get_root_logger(cfg.log_level)
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
if 'imgs_per_gpu' in cfg.data:
logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
'Please use "samples_per_gpu" instead')
if 'samples_per_gpu' in cfg.data:
logger.warning(
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
f'={cfg.data.imgs_per_gpu} is used in this experiments')
else:
logger.warning(
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
f'{cfg.data.imgs_per_gpu} in this experiments')
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
'type']
data_loader = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# cfg.gpus will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed,
runner_type=runner_type,
persistent_workers=cfg.data.get('persistent_workers', False))
for ds in dataset
]
# put model on gpus
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
if cfg.get('use_ddp_wrapper', False):
# Difference from mmdetection.
# In some algorithms, the ``optimizer.step()`` is executed in
# ``train_step``. To rebuilt reducer buckets rightly, there need to
# use DistributedDataParallelWrapper.
model = DistributedDataParallelWrapper(
model,
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
# Sets the ``find_unused_parameters`` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
# build optimizers
# Difference from mmdetection.
# In some algorithms, there will be multi optimizers.
optimizer = build_optimizers(model, cfg.optimizer)
# build runner
if 'runner' not in cfg:
cfg.runner = {
'type': 'EpochBasedRunner',
'max_epochs': cfg.total_epochs
}
warnings.warn(
'config is now expected to have a ``runner`` section, '
'please set ``runner`` in your config.', UserWarning)
else:
if 'total_epochs' in cfg:
assert cfg.total_epochs == cfg.runner.max_epochs
runner = build_runner(
cfg.runner,
default_args=dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta))
# an ugly workaround to make .log and .log.json filenames the same
runner.timestamp = timestamp
# fp16 setting
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
optimizer_config = Fp16OptimizerHook(
**cfg.optimizer_config, **fp16_cfg, distributed=distributed)
elif distributed and cfg.optimizer_config is not None \
and 'type' not in cfg.optimizer_config:
optimizer_config = OptimizerHook(**cfg.optimizer_config)
else:
optimizer_config = cfg.optimizer_config
# register hooks
runner.register_training_hooks(
cfg.lr_config,
optimizer_config,
cfg.checkpoint_config,
cfg.log_config,
cfg.get('momentum_config', None),
custom_hooks_config=cfg.get('custom_hooks', None))
if distributed:
if isinstance(runner, EpochBasedRunner):
# Difference from mmclassification.
# MMRazor's ``DistSamplerSeedHook`` could process multi dataloaders
runner.register_hook(DistSamplerSeedHook())
# register eval hooks
if validate:
# Support batch_size > 1 in validation
val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
if val_samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg.data.val.pipeline = replace_ImageToTensor(
cfg.data.val.pipeline)
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
val_dataloader = build_dataloader(
val_dataset,
samples_per_gpu=val_samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
resume_from = None
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loader, cfg.workflow)

Комментарий ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://gitlife.ru/oschina-mirror/open-mmlab-mmrazor.git
git@gitlife.ru:oschina-mirror/open-mmlab-mmrazor.git
oschina-mirror
open-mmlab-mmrazor
open-mmlab-mmrazor
v0.3.0