Слияние кода завершено, страница обновится автоматически
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmrazor.models.builder import OPS
def test_shuffle_series():
tensor = torch.randn(16, 16, 32, 32)
# test ShuffleBlock_7x7
shuffle_block_7x7 = dict(
type='ShuffleBlock',
in_channels=16,
out_channels=16,
kernel_size=7,
stride=1)
op = OPS.build(shuffle_block_7x7)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
# test ShuffleBlock_5x5
shuffle_block_5x5 = dict(
type='ShuffleBlock',
in_channels=16,
out_channels=16,
kernel_size=5,
stride=1)
op = OPS.build(shuffle_block_5x5)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
# test ShuffleBlock_3x3
shuffle_block_3x3 = dict(
type='ShuffleBlock',
in_channels=16,
out_channels=16,
kernel_size=3,
stride=1)
op = OPS.build(shuffle_block_3x3)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
# test ShuffleXception
shuffle_xception = dict(
type='ShuffleXception', in_channels=16, out_channels=16, stride=1)
op = OPS.build(shuffle_xception)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
def test_darts_series():
tensor = torch.randn(16, 16, 32, 32)
# test avg pool bn
avg_pool_bn = dict(
type='DartsPoolBN',
in_channels=16,
out_channels=16,
kernel_size=3,
pool_type='avg',
stride=1)
op = OPS.build(avg_pool_bn)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
# test max pool bn
max_pool_bn = dict(
type='DartsPoolBN',
in_channels=16,
out_channels=16,
kernel_size=3,
pool_type='max',
stride=1)
op = OPS.build(max_pool_bn)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
# test DartsSepConv
sep_conv = dict(
type='DartsSepConv',
in_channels=16,
out_channels=16,
kernel_size=3,
stride=1)
op = OPS.build(sep_conv)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
# test DartsSepConv
sep_conv = dict(
type='DartsSepConv',
in_channels=16,
out_channels=16,
kernel_size=3,
stride=1)
op = OPS.build(sep_conv)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
# test DartsDilConv
dil_conv = dict(
type='DartsDilConv',
in_channels=16,
out_channels=16,
kernel_size=3,
stride=1)
op = OPS.build(dil_conv)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
# test DartsSkipConnect
skip_connect = dict(
type='DartsSkipConnect', in_channels=16, out_channels=16, stride=1)
op = OPS.build(skip_connect)
# test forward
outputs = op(tensor)
assert outputs.size(1) == 16 and outputs.size(2) == 32
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Комментарий ( 0 )