# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

from mmcv import Config
from mmcv.runner import load_checkpoint, save_checkpoint

from mmrazor.models import build_algorithm
from mmrazor.models.pruners.utils import SwitchableBatchNorm2d


def parse_args():
    parser = argparse.ArgumentParser(description='Split a slimmable trained'
                                     'model checkpoint')
    parser.add_argument('config', type=str, help='path of train config file')
    parser.add_argument('checkpoint', type=str, help='checkpoint path')
    parser.add_argument(
        '--channel-cfgs',
        nargs='+',
        help='The path of the channel configs. '
        'The order should be the same as that of train.')
    parser.add_argument('--output-dir', type=str, default='')
    args = parser.parse_args()

    return args


def convert_bn(module, bn_ind):

    def traverse(module):
        for name, child in module.named_children():
            if isinstance(child, SwitchableBatchNorm2d):
                setattr(module, name, child.bns[bn_ind])
            else:
                traverse(child)

    traverse(module)


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    cfg.merge_from_dict(dict(algorithm=dict(channel_cfg=args.channel_cfgs)))

    for i, channel_cfg in enumerate(args.channel_cfgs):
        algorithm = build_algorithm(cfg.algorithm)
        load_checkpoint(algorithm, args.checkpoint, map_location='cpu')
        convert_bn(algorithm, i)
        for module in algorithm.modules():
            if hasattr(module, 'out_mask'):
                del module.out_mask
            if hasattr(module, 'in_mask'):
                del module.in_mask
        assert algorithm.with_pruner, \
            'The algorithm should has attr pruner. Please check your ' \
            'config file.'
        algorithm.pruner.deploy_subnet(algorithm.architecture,
                                       algorithm.channel_cfg[i])
        filename = osp.join(args.output_dir, f'checkpoint_{i + 1}.pth')
        save_checkpoint(algorithm, filename)

    print(f'Successfully split the original checkpoint `{args.checkpoint}` to '
          f'{len(args.channel_cfgs)} different checkpoints.')


if __name__ == '__main__':
    main()