Caffe2 - Python API
A deep learning, cross-platform ML framework
batch_distill_lr_loss.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import schema
9 from caffe2.python.layers.layers import (
10  ModelLayer,
11 )
12 from caffe2.python.layers.tags import (
13  Tags
14 )
15 import numpy as np
16 
17 
class BatchDistillLRLoss(ModelLayer):
    """Training-only loss layer mixing a true-label and a teacher-label term.

    The layer reads three fields from its input record -- ``label``,
    ``teacher_label`` and ``prediction`` -- and writes a single float32
    scalar blob equal to::

        AveragedLoss((1 - teacherWeight) * LabelCrossEntropy(two_class, label))
        + AveragedLoss(teacherWeight * CrossEntropy(prediction, teacher_label))

    where ``two_class`` is the output of ``MakeTwoClass(prediction)``.
    """

    def __init__(
            self, model, input_record,
            name='batch_distill_lr_loss', teacherWeight=0.0, **kwargs):
        """Validate the inputs and declare the output blob.

        Args:
            model: Model the layer is attached to; its ``net`` is used to
                allocate the output blob name.
            input_record: Schema record that must contain the fields
                ``teacher_label``, ``label`` and ``prediction``.
            name: Layer name; also the prefix of the output blob name.
            teacherWeight: Weight of the teacher term, in ``[0, 1]``. The
                true-label term is weighted by ``1 - teacherWeight``.
            **kwargs: Forwarded unchanged to ``ModelLayer.__init__``.

        Raises:
            AssertionError: If ``teacherWeight`` is outside ``[0, 1]`` or the
                input record lacks one of the required fields.
        """
        super(BatchDistillLRLoss, self).__init__(
            model, name, input_record, **kwargs)

        assert 0 <= teacherWeight <= 1, (
            'teacherWeight=%0.2f should be in [0, 1]' % teacherWeight
        )
        self._teacherWeight = teacherWeight

        # NOTE(review): the head of this statement
        # (`assert schema.is_schema_subset(schema.Struct(`) was missing from
        # the reviewed text and is reconstructed from the surviving
        # arguments -- confirm against the upstream file.
        assert schema.is_schema_subset(
            schema.Struct(
                ('teacher_label', schema.Scalar()),
                ('label', schema.Scalar()),
                ('prediction', schema.Scalar())
            ),
            input_record
        )
        self.tags.update({Tags.TRAIN_ONLY})

        # NOTE(review): `self.output_schema = schema.Scalar(` is likewise
        # reconstructed; add_ops() reads self.output_schema.field_blobs().
        self.output_schema = schema.Scalar(
            np.float32,
            model.net.NextScopedBlob(name + '_output'))

    def add_ops(self, net):
        """Append the operators that compute the blended loss to ``net``.

        Args:
            net: Net that receives the operators and allocates the
                intermediate blob names.
        """
        label = self.input_record.label()
        # Cast the label to int32 unless the schema already declares it so.
        if self.input_record.label.field_type() != np.int32:
            label = net.Cast(
                label, net.NextScopedBlob('int32_label'), to='int32')

        teacher_label = self.input_record.teacher_label()

        class_probabilities = net.MakeTwoClass(
            self.input_record.prediction(),
            net.NextScopedBlob('two_class_predictions')
        )

        true_xent = net.LabelCrossEntropy(
            [class_probabilities, label],
            net.NextScopedBlob('cross_entropy')
        )
        # NOTE(review): the teacher term consumes the raw prediction blob,
        # not the two-class probabilities used for the true-label term --
        # confirm this asymmetry is intended.
        teacher_xent = net.CrossEntropy(
            [self.input_record.prediction(), teacher_label],
            net.NextScopedBlob('teacher_cross_entropy')
        )

        # Weight each per-example term before averaging over the batch.
        scaled_true_xent = net.Scale(
            true_xent,
            net.NextScopedBlob('scaled_cross_entropy'),
            scale=1.0 - self._teacherWeight,
        )
        scaled_teacher_xent = net.Scale(
            teacher_xent,
            net.NextScopedBlob('scaled_teacher_cross_entropy'),
            scale=self._teacherWeight,
        )

        true_loss = net.AveragedLoss(
            scaled_true_xent,
            net.NextScopedBlob('true_loss')
        )
        teacher_loss = net.AveragedLoss(
            scaled_teacher_xent,
            net.NextScopedBlob('teacher_loss')
        )

        # The sum of both averaged terms is written to the layer's output blob.
        net.Add(
            [true_loss, teacher_loss],
            self.output_schema.field_blobs()
        )
def input_record(self)
Definition: layers.py:149
def is_schema_subset(schema, original_schema)
Definition: schema.py:985