1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/open-mmlab-mmrazor

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Клонировать/Скачать
general_distill.py 2 КБ
Копировать Редактировать Исходные данные Просмотреть построчно История
Haian Huang(深度眸) Отправлено 3 лет назад 57a5549
# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.models.builder import ALGORITHMS
from mmrazor.models.utils import add_prefix
from .base import BaseAlgorithm
@ALGORITHMS.register_module()
class GeneralDistill(BaseAlgorithm):
"""General Distillation Algorithm.
Args:
with_student_loss (bool): Whether to use student loss.
Defaults to True.
with_teacher_loss (bool): Whether to use teacher loss.
Defaults to False.
"""
def __init__(self,
with_student_loss=True,
with_teacher_loss=False,
**kwargs):
super(GeneralDistill, self).__init__(**kwargs)
self.with_student_loss = with_student_loss
self.with_teacher_loss = with_teacher_loss
def train_step(self, data, optimizer):
""""""
losses = dict()
if self.with_teacher_loss:
teacher_losses = self.distiller.exec_teacher_forward(data)
teacher_losses = add_prefix(teacher_losses, 'teacher')
losses.update(teacher_losses)
else:
# Just to be able to trigger the forward hooks that
# have been registered
_ = self.distiller.exec_teacher_forward(data)
if self.with_student_loss:
student_losses = self.distiller.exec_student_forward(
self.architecture, data)
student_losses = add_prefix(student_losses, 'student')
losses.update(student_losses)
else:
# Just to be able to trigger the forward hooks that
# have been registered
_ = self.distiller.exec_student_forward(self.architecture, data)
distill_losses = self.distiller.compute_distill_loss(data)
distill_losses = add_prefix(distill_losses, 'distiller')
losses.update(distill_losses)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
return outputs

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://gitlife.ru/oschina-mirror/open-mmlab-mmrazor.git
git@gitlife.ru:oschina-mirror/open-mmlab-mmrazor.git
oschina-mirror
open-mmlab-mmrazor
open-mmlab-mmrazor
v0.3.0