# Copyright (c) OpenMMLab. All rights reserved. import argparse import numpy as np import torch from mmcv import Config, DictAction from mmcv.cnn.utils import get_model_complexity_info from mmrazor.models import build_algorithm def parse_args(): parser = argparse.ArgumentParser(description='Get model flops and params') parser.add_argument('config', help='config file path') parser.add_argument( '--shape', type=int, nargs='+', default=[224, 224], help='input image size') parser.add_argument( '--cfg-options', nargs='+', action=DictAction, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file. If the value to ' 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') parser.add_argument( '--size-divisor', type=int, default=32, help='Pad the input image, the minimum size that is divisible ' 'by size_divisor, -1 means do not pad the image.') args = parser.parse_args() return args def main(): args = parse_args() if len(args.shape) == 1: h = w = args.shape[0] elif len(args.shape) == 2: h, w = args.shape else: raise ValueError('invalid input shape') orig_shape = (3, h, w) divisor = args.size_divisor if divisor > 0: h = int(np.ceil(h / divisor)) * divisor w = int(np.ceil(w / divisor)) * divisor input_shape = (3, h, w) cfg = Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) algorithm = build_algorithm(cfg.algorithm) if torch.cuda.is_available(): algorithm.cuda() algorithm.eval() if hasattr(algorithm.architecture, 'forward_dummy'): algorithm.architecture.forward = algorithm.architecture.forward_dummy else: raise NotImplementedError( 'FLOPs counter is currently not currently supported with {}'. format(algorithm.architecture.__class__.__name__)) flops, params = get_model_complexity_info(algorithm.architecture, input_shape) split_line = '=' * 30 if divisor > 0 and \ input_shape != orig_shape: print(f'{split_line}\nUse size divisor set input shape ' f'from {orig_shape} to {input_shape}\n') print(f'{split_line}\nInput shape: {input_shape}\n' f'Flops: {flops}\nParams: {params}\n{split_line}') print('!!!Please be cautious if you use the results in papers. ' 'You may need to check if all ops are supported and verify that the ' 'flops computation is correct.') if __name__ == '__main__': main()