Caffe2 - Python API
A deep learning, cross-platform ML framework
batch_lr_loss.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, schema
9 from caffe2.python.layers.layers import (
10  ModelLayer,
11 )
12 from caffe2.python.layers.tags import (
13  Tags
14 )
15 import numpy as np
16 
17 
class BatchLRLoss(ModelLayer):
    """Training-only layer that emits a batch logistic-regression loss.

    The input record must contain ``label`` and ``prediction`` scalar fields
    and may additionally contain a ``weight`` field. The layer exposes a
    single ``np.float32`` scalar as its output schema; ``add_ops`` fills it
    with the per-example label cross entropy reduced over the batch.
    """

    def __init__(self, model, input_record, name='batch_lr_loss',
                 average_loss=True, **kwargs):
        """Validate the input record and declare the output blob.

        Args:
            model: Layer model helper; its ``net`` is used to allocate the
                scoped output blob name.
            input_record: Schema record that must include ``label`` and
                ``prediction`` scalars (checked with
                ``schema.is_schema_subset``).
            name: Layer name; the output blob is named ``name + '_output'``.
            average_loss: When true the loss is reduced with ``AveragedLoss``,
                otherwise with ``ReduceFrontSum`` (see ``add_ops``).
            **kwargs: Forwarded unchanged to ``ModelLayer.__init__``.

        Raises:
            AssertionError: If ``input_record`` lacks the required fields.
        """
        super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)

        self.average_loss = average_loss

        # NOTE(review): the opening of this call was missing from the listing
        # and has been restored from its surviving arguments -- confirm
        # against the upstream file.
        assert schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar()),
                ('prediction', schema.Scalar())
            ),
            input_record
        )
        self.tags.update({Tags.TRAIN_ONLY})

        # NOTE(review): the assignment target was likewise restored; add_ops
        # reads self.output_schema.field_blobs(), so the scalar built here
        # must be stored under that attribute.
        self.output_schema = schema.Scalar(
            np.float32,
            model.net.NextScopedBlob(name + '_output'))

    # This should be a bit more complicated than it is right now
    def add_ops(self, net):
        """Add the loss operators for this layer to ``net``.

        The prediction is passed through ``MakeTwoClass``, the label is cast
        to int32 when its declared base type differs, and both feed
        ``LabelCrossEntropy``. If the input record has a ``weight`` field the
        cross entropy is multiplied by it (cast to float32 if needed, with
        gradients stopped). The result is written to the output schema blobs
        via ``AveragedLoss`` or ``ReduceFrontSum`` depending on
        ``self.average_loss``.

        Args:
            net: The net to which operators are appended.
        """
        # Presumably expands each probability p into the pair (1 - p, p) so
        # the binary problem can reuse the multi-class op -- confirm against
        # the MakeTwoClass operator documentation.
        class_probabilities = net.MakeTwoClass(
            self.input_record.prediction.field_blobs(),
            net.NextScopedBlob('two_class_predictions')
        )
        label = self.input_record.label.field_blobs()
        if self.input_record.label.field_type().base != np.int32:
            # Keep `label` a list so it can be concatenated below.
            label = [net.Cast(
                label,
                net.NextScopedBlob('int32_label'),
                to=core.DataType.INT32)]
        # LabelCrossEntropyGradientOp does not output gradient for the label

        xent = net.LabelCrossEntropy(
            [class_probabilities] + label,
            net.NextScopedBlob('cross_entropy'),
        )
        if 'weight' in self.input_record.fields:
            weight_blob = self.input_record.weight()
            if self.input_record.weight.field_type().base != np.float32:
                weight_blob = net.Cast(
                    weight_blob,
                    weight_blob + '_float32',
                    to=core.DataType.FLOAT
                )
            # In-place StopGradient: the weights scale the loss but must not
            # receive gradients themselves.
            weight_blob = net.StopGradient([weight_blob], [weight_blob])
            xent = net.Mul(
                [xent, weight_blob],
                net.NextScopedBlob('weighted_cross_entropy'),
            )

        if self.average_loss:
            net.AveragedLoss(xent, self.output_schema.field_blobs())
        else:
            net.ReduceFrontSum(xent, self.output_schema.field_blobs())
def input_record(self)
Definition: layers.py:149
def is_schema_subset(schema, original_schema)
Definition: schema.py:985