Слияние кода завершено, страница обновится автоматически
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from mmcv.runner import BaseModule
from mmcv.utils import import_modules_from_strings
def function_wrapper(ctx, method, method_str):
"""Pass teacher's outputs to student."""
def wrapper(*args, **kwargs):
# record inputs
ctx.method_args[method_str] = args
ctx.method_kwargs[method_str] = kwargs
# TODO cover more usecases, not only pass teacher's outputs to
# student.
if ctx.is_teacher:
# execute the raw function
outputs = method(*args, **kwargs)
# record outputs
ctx.method_return[method_str] = outputs
else:
# modify student's outputs to be same with teacher
outputs = ctx.method_return[method_str]
return outputs
return wrapper
class FunctionContext():
"""Function context manager for rewrite function.
Args:
ctx (ConversionContext): The distiller's overall context manager.
method (str): The name of the function to rewrite.
"""
def __init__(self, ctx, method, import_module=None):
self.ctx = ctx
self.import_module = import_modules_from_strings(import_module)
self.method_str = method
self.method_exec_str = f'self.import_module.{method}'
def _set_method(self, method):
"""Modify a function."""
exec(f'{self.method_exec_str} = method')
def __enter__(self):
"""Rewrite the function."""
self.method_impl = eval(self.method_exec_str)
if self.method_impl:
self._set_method(
function_wrapper(self.ctx, self.method_impl, self.method_str,
self.align_mode))
def __exit__(self, exc_type, exc_value, traceback):
"""Restore the function."""
if self.method_impl:
self._set_method(self.method_impl)
class ConversionContext():
"""Context manager for record functions' inputs or outputs."""
def __init__(self, hooks):
# save functions' inputs
self.method_args = dict()
self.method_kwargs = dict()
# save functions' outputs
self.method_return = dict()
# Each function will have a sub context manager, the function will be
# rewritten when enter the sub context manager.
self.hooks = []
self.is_teacher = True
for hook in hooks:
self.hooks.append(FunctionContext(self, **hook))
def __enter__(self):
"""Enter every sub context managers."""
for hook in self.hooks:
hook.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
"""Exit every sub context managers."""
for hook in self.hooks:
hook.__exit__(exc_type, exc_value, traceback)
class BaseDistiller(BaseModule, metaclass=ABCMeta):
"""Base Distiller.
In the distillation algorithm, some intermediate results of the teacher
need to be obtained and passed to the student.
For nn.Module's outputs, obtained by pytorch forward hook.
For python function's outputs, obtained by a specific context manager.
Args:
align_methods (dict): The details of the functions which outputs need
to be obtained.
"""
def __init__(self, align_methods=None, **kwargs):
super(BaseDistiller, self).__init__(**kwargs)
if align_methods is None:
self.context_manager = None
else:
# To obtain the python function's outputs, there will build a
# specific context manager. When enter the context manager, the
# functions will be rewrite. The context manager could record
# inputs or outputs of the functions , and pass from teachr to
# student. When exit the context manager, the rewritten functions
# will restore.
self.context_manager = ConversionContext(align_methods)
@abstractmethod
def prepare_from_student(self, student):
"""Register forward hooks to students and teachers."""
pass
@abstractmethod
def teacher_forward_output_hook(self, module, inputs, outputs):
"""Save the teacher output."""
pass
@abstractmethod
def student_forward_output_hook(self, module, inputs, outputs):
"""Save the student output."""
pass
def reset_ctx_teacher_mode(self, mode=True):
if self.context_manager is not None:
self.context_manager.is_teacher = mode
@abstractmethod
def exec_teacher_forward(self, data):
"""Execute the teacher's forward function."""
pass
@abstractmethod
def exec_student_forward(self, student, data):
"""Execute the student's forward function."""
pass
@abstractmethod
def compute_distill_loss(self, data):
"""Compute distill loss according teacher's outputs and student's
outputs."""
pass
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Комментарий ( 0 )