# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.models.builder import ALGORITHMS
from .general_distill import GeneralDistill


@ALGORITHMS.register_module()
class AlignMethodDistill(GeneralDistill):
    """Distillation algorithm whose training step runs under the distiller's
    context manager.

    Apart from ``train_step``, which is wrapped in
    ``self.distiller.context_manager``, this class behaves exactly like
    :class:`GeneralDistill`.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to
            :class:`GeneralDistill`. Positional arguments are not accepted.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train_step(self, data, optimizer):
        """Run the parent's training step inside the distiller's context.

        NOTE(review): what entering ``self.distiller.context_manager`` does
        is defined by the distiller, not here. Confirm against the distiller
        implementation.

        Args:
            data: Passed through unchanged to ``GeneralDistill.train_step``.
            optimizer: Passed through unchanged to
                ``GeneralDistill.train_step``.

        Returns:
            Exactly what ``GeneralDistill.train_step`` returns.
        """
        ctx = self.distiller.context_manager
        with ctx:
            step_result = super().train_step(data, optimizer)
        return step_result