3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.python
import schema
9 from caffe2.python.layers.layers
import (
12 from caffe2.python.layers.tags
import (
21 self, model, input_record,
22 name='batch_distill_lr_loss', teacherWeight=0.0, **kwargs):
24 super(BatchDistillLRLoss, self).__init__(model, name, input_record, **kwargs)
26 assert teacherWeight >= 0
and teacherWeight <= 1, (
27 'teacherWeight=%0.2f should be in [0, 1]' % teacherWeight
39 self.
tags.update({Tags.TRAIN_ONLY})
43 model.net.NextScopedBlob(name +
'_output'))
45 def add_ops(self, net):
48 label = net.Cast(label, net.NextScopedBlob(
'int32_label'), to=
'int32')
52 class_probabilities = net.MakeTwoClass(
54 net.NextScopedBlob(
'two_class_predictions')
57 true_xent = net.LabelCrossEntropy(
58 [class_probabilities, label],
59 net.NextScopedBlob(
'cross_entropy')
61 teacher_xent = net.CrossEntropy(
63 net.NextScopedBlob(
'teacher_cross_entropy')
66 scaled_true_xent = net.Scale(
68 net.NextScopedBlob(
'scaled_cross_entropy'),
71 scaled_teacher_xent = net.Scale(
73 net.NextScopedBlob(
'scaled_teacher_cross_entropy'),
77 true_loss = net.AveragedLoss(
79 net.NextScopedBlob(
'true_loss')
81 teacher_loss = net.AveragedLoss(
83 net.NextScopedBlob(
'teacher_loss')
87 [true_loss, teacher_loss],
def is_schema_subset(schema, original_schema)