1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/open-mmlab-mmrazor

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Это зеркальный репозиторий, синхронизируется ежедневно с исходного репозитория.
Клонировать/Скачать
inference.py 2.1 КБ
Копировать Редактировать Исходные данные Просмотреть построчно История
pppppM Отправлено 3 лет назад 49f1bee
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, Optional, Union
import mmcv
from mmcv.runner import load_checkpoint
from torch import nn
from mmrazor.models import build_algorithm
def init_mmcls_model(config: Union[str, mmcv.Config],
checkpoint: Optional[str] = None,
device: str = 'cuda:0',
cfg_options: Optional[Dict] = None) -> nn.Module:
"""Initialize a mmcls model from config file.
Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
cfg_options (dict): cfg_options to override some settings in the used
config.
Returns:
nn.Module: The constructed classifier.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
if cfg_options is not None:
config.merge_from_dict(cfg_options)
model_cfg = config.algorithm.architecture.model
model_cfg.pretrained = None
algorithm = build_algorithm(config.algorithm)
model = algorithm.architecture.model
if checkpoint is not None:
# Mapping the weights to GPU may cause unexpected video memory leak
# which refers to https://github.com/open-mmlab/mmdetection/pull/6405
checkpoint = load_checkpoint(algorithm, checkpoint, map_location='cpu')
if 'CLASSES' in checkpoint.get('meta', {}):
model.CLASSES = checkpoint['meta']['CLASSES']
else:
from mmcls.datasets import ImageNet
warnings.simplefilter('once')
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use imagenet by default.')
model.CLASSES = ImageNet.CLASSES
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model

Комментарий ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://gitlife.ru/oschina-mirror/open-mmlab-mmrazor.git
git@gitlife.ru:oschina-mirror/open-mmlab-mmrazor.git
oschina-mirror
open-mmlab-mmrazor
open-mmlab-mmrazor
v0.3.0