# Copyright (c) OpenMMLab. All rights reserved. import torch from mmrazor.models.builder import MUTABLES def test_one_shot_op(): oneshot_choice_op = dict( type='OneShotOP', space_id='test', num_chosen=1, choices=dict( shuffle_3x3=dict(type='ShuffleBlock', kernel_size=3), shuffle_5x5=dict(type='ShuffleBlock', kernel_size=5), shuffle_7x7=dict(type='ShuffleBlock', kernel_size=7), shuffle_xception=dict(type='ShuffleXception'), ), choice_args=dict(in_channels=16, out_channels=16, stride=1)) model = MUTABLES.build(oneshot_choice_op) tensor = torch.randn(16, 16, 32, 32) # test forward outputs = model(tensor) assert outputs.size(1) == 16 and outputs.size(2) == 32 def test_differentiable_op(): oneshot_choice_op = dict( type='DifferentiableOP', space_id='test', num_chosen=1, with_arch_param=True, choices=dict( zero=dict(type='DartsZero'), skip_connect=dict(type='DartsSkipConnect'), dil_conv_3x3=dict(type='DartsDilConv', kernel_size=3), dil_conv_5x5=dict(type='DartsDilConv', kernel_size=5), sep_conv_3x3=dict(type='DartsSepConv', kernel_size=3), sep_conv_5x5=dict(type='DartsSepConv', kernel_size=5), ), choice_args=dict(in_channels=16, out_channels=16, stride=2)) model = MUTABLES.build(oneshot_choice_op) arch_param = model.build_arch_param() tensor = torch.randn(16, 16, 32, 32) # test forward outputs = model(tensor, arch_param) assert outputs.size(1) == 16 and outputs.size(2) == 16