Caffe2 - Python API
A deep learning, cross-platform ML framework
layer_test_util.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from collections import namedtuple
9 
10 from caffe2.python import (
11  core,
12  layer_model_instantiator,
13  layer_model_helper,
14  schema,
15  test_util,
16  workspace,
17 )
18 import numpy as np
19 
20 
# Expected shape of a single operator in a net, as checked by
# assertNetContainOps: ``type`` is compared with the operator's type, while
# ``input`` and ``output`` are handed to assertBlobsEqual (a list of blob
# names, or None to skip the check).
OpSpec = namedtuple("OpSpec", "type input output")
22 
23 
25 
def setUp(self):
    """Create the layer model that the other helpers operate on.

    The model is named 'test_model' and is built with an input feature
    schema holding one ``float_features`` scalar of dtype ``np.float32``
    and shape ``(32,)``, plus an empty trainer-extra schema.

    NOTE(review): the extracted listing dropped the line that opens the
    constructor call. ``layer_model_helper.LayerModelHelper`` is restored
    here from the upstream Caffe2 API -- confirm against the repository.
    """
    super(LayersTestCase, self).setUp()
    input_feature_schema = schema.Struct(
        ('float_features', schema.Scalar((np.float32, (32,)))),
    )
    trainer_extra_schema = schema.Struct()

    self.model = layer_model_helper.LayerModelHelper(
        'test_model',
        input_feature_schema=input_feature_schema,
        trainer_extra_schema=trainer_extra_schema)
37 
def new_record(self, schema_obj):
    """Create a record for ``schema_obj`` on the test model's net.

    Thin wrapper that forwards ``self.model.net`` and ``schema_obj`` to
    ``schema.NewRecord`` and returns its result unchanged.
    """
    return schema.NewRecord(self.model.net, schema_obj)
40 
def get_training_nets(self):
    """Build the training nets directly from the model's layers.

    We don't use
    layer_model_instantiator.generate_training_nets_forward_only()
    here because it includes initialization of global constants, which
    makes testing tricky.

    Returns:
        A ``(train_init_net, train_net)`` tuple of freshly created
        ``core.Net`` objects to which every layer has added its operators.
    """
    train_net = core.Net('train_net')
    train_init_net = core.Net('train_init_net')
    for layer in self.model.layers:
        layer.add_operators(train_net, train_init_net)
    return train_init_net, train_net
53 
def get_predict_net(self):
    """Return the predict net generated for the test model.

    NOTE(review): the method body was missing from the extracted listing.
    It is restored as a call to
    ``layer_model_instantiator.generate_predict_net`` on ``self.model`` --
    confirm against the upstream source.
    """
    return layer_model_instantiator.generate_predict_net(self.model)
56 
def run_train_net(self):
    """Generate the model's training nets and run each of them once.

    The model's output schema is replaced with an empty ``schema.Struct``
    before the nets are generated. The init net is run first, then the
    train net.

    NOTE(review): the right-hand side of the assignment was missing from
    the extracted listing. It is restored as a call to
    ``layer_model_instantiator.generate_training_nets`` on ``self.model``
    -- confirm against the upstream source.
    """
    self.model.output_schema = schema.Struct()
    train_init_net, train_net = \
        layer_model_instantiator.generate_training_nets(self.model)
    workspace.RunNetOnce(train_init_net)
    workspace.RunNetOnce(train_net)
63 
def assertBlobsEqual(self, spec_blobs, op_blobs):
    """
    Check an operator's blobs against the expected ones.

    spec_blobs can either be None or a list of blob names. If it's None,
    then no assertion is performed. The elements of the list can be None,
    in that case, it means that position will not be checked.

    When spec_blobs is a list, its length must equal len(op_blobs).
    """
    if spec_blobs is None:
        return
    self.assertEqual(len(spec_blobs), len(op_blobs))
    for spec_blob, op_blob in zip(spec_blobs, op_blobs):
        if spec_blob is None:
            # Wildcard position: any blob is accepted here.
            continue
        self.assertEqual(spec_blob, op_blob)
77 
def assertNetContainOps(self, net, op_specs):
    """
    Given a net and a list of OpSpec's, check that the net matches the spec.

    The net must hold exactly len(op_specs) operators. Each operator is
    compared with the spec at the same position: its type must be equal,
    and its inputs and outputs are checked with assertBlobsEqual.

    Returns the operator list read from net.Proto().op.
    """
    ops = net.Proto().op
    self.assertEqual(len(op_specs), len(ops))
    for op, op_spec in zip(ops, op_specs):
        self.assertEqual(op_spec.type, op.type)
        self.assertBlobsEqual(op_spec.input, op.input)
        self.assertBlobsEqual(op_spec.output, op.output)
    return ops
def generate_training_nets(model, include_tags=None)
def NewRecord(net, schema)
Definition: schema.py:908
def RunNetOnce(net)
Definition: workspace.py:160
def assertBlobsEqual(self, spec_blobs, op_blobs)
def assertNetContainOps(self, net, op_specs)
def generate_predict_net(model, include_tags=None)