# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmrazor.models.builder import OPS


def _assert_forward_shape(op_cfg, inputs, channels=16, height=32):
    """Build ``op_cfg`` via the ``OPS`` registry and check one forward pass.

    Args:
        op_cfg (dict): Config passed to ``OPS.build``; must contain ``type``.
        inputs (torch.Tensor): Tensor fed to the built op.
        channels (int): Expected size of output dimension 1.
        height (int): Expected size of output dimension 2.

    Raises:
        AssertionError: If output dimension 1 or 2 has an unexpected size.
    """
    # Read the type up front so the failure message does not depend on
    # whether the registry touches the config while building.
    op_type = op_cfg['type']

    op = OPS.build(op_cfg)

    # test forward
    outputs = op(inputs)
    assert outputs.size(1) == channels, (
        f'{op_type}: expected {channels} at dim 1, got {outputs.size(1)}')
    assert outputs.size(2) == height, (
        f'{op_type}: expected {height} at dim 2, got {outputs.size(2)}')


def test_shuffle_series():
    """Forward a random tensor through every shuffle-series op config."""
    tensor = torch.randn(16, 16, 32, 32)

    # test ShuffleBlock_7x7, ShuffleBlock_5x5 and ShuffleBlock_3x3
    for kernel_size in (7, 5, 3):
        shuffle_block = dict(
            type='ShuffleBlock',
            in_channels=16,
            out_channels=16,
            kernel_size=kernel_size,
            stride=1)
        _assert_forward_shape(shuffle_block, tensor)

    # test ShuffleXception
    shuffle_xception = dict(
        type='ShuffleXception', in_channels=16, out_channels=16, stride=1)
    _assert_forward_shape(shuffle_xception, tensor)


def test_darts_series():
    """Forward a random tensor through every DARTS-series op config."""
    tensor = torch.randn(16, 16, 32, 32)

    # test avg pool bn and max pool bn
    for pool_type in ('avg', 'max'):
        pool_bn = dict(
            type='DartsPoolBN',
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            pool_type=pool_type,
            stride=1)
        _assert_forward_shape(pool_bn, tensor)

    # test DartsSepConv
    # NOTE(review): this case used to appear twice with an identical config;
    # the redundant copy was dropped. If the second copy was meant to cover a
    # different setting (e.g. another kernel size), add it here explicitly.
    sep_conv = dict(
        type='DartsSepConv',
        in_channels=16,
        out_channels=16,
        kernel_size=3,
        stride=1)
    _assert_forward_shape(sep_conv, tensor)

    # test DartsDilConv
    dil_conv = dict(
        type='DartsDilConv',
        in_channels=16,
        out_channels=16,
        kernel_size=3,
        stride=1)
    _assert_forward_shape(dil_conv, tensor)

    # test DartsSkipConnect
    skip_connect = dict(
        type='DartsSkipConnect', in_channels=16, out_channels=16, stride=1)
    _assert_forward_shape(skip_connect, tensor)