Caffe2 - Python API
A deep learning, cross-platform ML framework
batch_mse_loss.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import schema
9 from caffe2.python.layers.layers import (
10  ModelLayer,
11 )
12 from caffe2.python.layers.tags import (
13  Tags
14 )
15 import numpy as np
16 
17 
class BatchMSELoss(ModelLayer):
    """Training-only layer that emits a batch-level squared-error loss.

    The input record must expose ``label`` and ``prediction`` scalar fields.
    The loss is written to a single ``np.float32`` scalar blob published as
    ``output_schema``.
    """

    def __init__(self, model, input_record, name='batch_mse_loss', **kwargs):
        """Validate the input schema and declare the output blob.

        Args:
            model: Model the layer is attached to; ``model.net`` is used to
                allocate the scoped name of the output blob.
            input_record: Schema record that must contain ``label`` and
                ``prediction`` scalar fields.
            name: Layer name; also the prefix of the output blob name.
            **kwargs: Forwarded unchanged to ``ModelLayer.__init__``.

        Raises:
            AssertionError: If ``input_record`` does not contain the
                required ``label`` / ``prediction`` fields.
        """
        super(BatchMSELoss, self).__init__(model, name, input_record, **kwargs)

        # The layer reads exactly these two fields in add_ops, so reject any
        # record that does not provide them.
        required_fields = schema.Struct(
            ('label', schema.Scalar()),
            ('prediction', schema.Scalar())
        )
        assert schema.is_schema_subset(required_fields, input_record), (
            'BatchMSELoss requires an input record with `label` and '
            '`prediction` scalar fields'
        )

        # A loss is only meaningful while training.
        self.tags.update({Tags.TRAIN_ONLY})

        # Single float32 scalar holding the loss produced by add_ops.
        self.output_schema = schema.Scalar(
            np.float32,
            model.net.NextScopedBlob(name + '_output'))

    def add_ops(self, net):
        """Append the loss operators to ``net``.

        Emits Squeeze -> StopGradient -> SquaredL2Distance -> AveragedLoss
        and writes the result into the blob(s) of ``self.output_schema``.

        Args:
            net: Net to which the operators are added.
        """
        # Drop axis 1 of the prediction.
        # NOTE(review): assumes prediction is shaped (batch, 1) - confirm.
        prediction = net.Squeeze(
            self.input_record.prediction(),
            net.NextScopedBlob('squeezed_prediction'),
            dims=[1]
        )

        # The label is a fixed target: route it through StopGradient so it
        # is excluded from back-propagation.
        label = net.StopGradient(
            self.input_record.label(),
            net.NextScopedBlob('stopped_label')
        )

        # Per-example squared L2 distance between label and prediction.
        l2dist = net.SquaredL2Distance(
            [label, prediction],
            net.NextScopedBlob('l2')
        )

        # Average the per-example distances into the declared output blob.
        net.AveragedLoss(l2dist, self.output_schema.field_blobs())
def input_record(self)
Definition: layers.py:149
def is_schema_subset(schema, original_schema)
Definition: schema.py:985