Caffe2 - Python API
A deep learning, cross platform ML framework
device_checker.py
1 
3 import numpy as np
4 import copy
5 from caffe2.python import workspace
6 
7 
class DeviceChecker(object):
    """A device checker in Python to check consistency across multiple devices.

    The same operator or net is run once per entry in ``device_options`` and
    the fetched outputs of every run are compared against those of the first
    entry using ``np.allclose``.

    This is not the most efficient way to check devices, as the Python
    interface will involve a lot of copy back and forth operations. Use at
    your own risk.
    """

    # Scratch workspace both checks switch into while they run.
    _WORKSPACE_NAME = "_device_check_"

    def __init__(self, threshold, device_options):
        """Stores the tolerance and the device options to compare.

        Inputs:
          threshold: tolerance handed to np.allclose as both atol and rtol.
          device_options: the device options to run on. Results obtained
              with the first entry are the reference for all the others.
        """
        self._threshold = threshold
        self._device_options = device_options

    def CheckSimple(self, op, inputs, outputs_to_check,
                    input_device_options=None):
        """Checks the operator with different device implementations.

        The operator is deep-copied first, so the caller's op is left
        untouched.

        Inputs:
          op: the operator to be checked.
          inputs: the input data in numpy arrays; the i-th entry is fed to
              op.input[i].
          outputs_to_check: indices into op.output of the outputs to check
              between devices.
          input_device_options: a mapping from input name to a device to use
              (instead of self._device_options)
        Outputs:
          boolean: True if it passes, False if it does not pass.
        """
        op = copy.deepcopy(op)
        input_device_options = input_device_options or {}
        # Entering the checker workspace
        old_ws_name = workspace.CurrentWorkspace()
        results = []
        workspace.SwitchWorkspace(self._WORKSPACE_NAME, True)
        try:
            for device_option in self._device_options:
                for input_idx, arr in enumerate(inputs):
                    input_name = op.input[input_idx]
                    workspace.FeedBlob(
                        input_name, np.array(arr),
                        input_device_options.get(input_name, device_option))
                op.device_option.CopyFrom(device_option)
                workspace.RunOperatorOnce(op)
                results.append(
                    [workspace.FetchBlob(op.output[idx])
                     for idx in outputs_to_check])
                # Everything is done, reset the workspace.
                workspace.ResetWorkspace()
            # After running on all devices, check correctness
            output_names = [op.output[idx] for idx in outputs_to_check]
            return self._results_match(results, output_names)
        finally:
            # Always hand the caller's workspace back, even if a run failed.
            workspace.SwitchWorkspace(old_ws_name)

    def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None):
        """Checks a network by inspecting all of its intermediate results, and
        see if things match.

        Note that the device option of every op in ``net`` is overwritten in
        place; the net is not copied.

        Inputs:
          net: the network to be checked.
          inputs: a mapping from blob name to the data to feed before each
              run. Defaults to no inputs.
          blobs_to_check: names of the blobs to compare between devices.
              Defaults to every output of every op in the net.
          ignore: blob names to leave out of the comparison. Defaults to
              none.
        Outputs:
          boolean: True if it passes, False if it does not pass.
        """
        # None sentinels instead of mutable default arguments.
        if inputs is None:
            inputs = {}
        if ignore is None:
            ignore = set()
        old_ws_name = workspace.CurrentWorkspace()
        results = []
        if blobs_to_check is None:
            blobs_to_check = [
                name for net_op in net.op for name in net_op.output]
        blobs_to_check = [b for b in blobs_to_check if b not in ignore]
        workspace.SwitchWorkspace(self._WORKSPACE_NAME, True)
        try:
            for device_option in self._device_options:
                for name, arr in inputs.items():
                    workspace.FeedBlob(name, arr, device_option)
                for net_op in net.op:
                    net_op.device_option.CopyFrom(device_option)
                workspace.RunNetOnce(net)
                results.append(
                    [workspace.FetchBlob(name) for name in blobs_to_check]
                )
            # After running on all devices, check correctness
            return self._results_match(results, blobs_to_check)
        finally:
            # Always hand the caller's workspace back, even if a run failed.
            workspace.SwitchWorkspace(old_ws_name)

    def _results_match(self, results, names):
        """Compares each run's outputs against those of the first run.

        Every mismatch is reported on stdout; the comparison does not stop
        at the first failure.

        Inputs:
          results: one list of fetched outputs per device option, in the
              same order as the device options.
          names: the blob name of each output, used only for reporting.
        Outputs:
          boolean: True if every output of every run is within the threshold
          of the corresponding output of the first run.
        """
        success = True
        for i in range(1, len(results)):
            for j, name in enumerate(names):
                x = results[i][j]
                y = results[0][j]
                if np.allclose(x, y,
                               atol=self._threshold, rtol=self._threshold):
                    continue
                # Coerce only for reporting so the diagnostics cannot fail
                # on values that are not already ndarrays.
                x_arr = np.asarray(x)
                y_arr = np.asarray(y)
                print('Failure in checking device option {}'
                      ' and output {}. The outputs are:'
                      .format(i, name))
                print(x_arr.flatten())
                print(y_arr.flatten())
                print(np.max(np.abs(x_arr - y_arr)))
                success = False
        return success
def ResetWorkspace(root_folder=None)
Definition: workspace.py:130
SwitchWorkspace
Definition: workspace.py:30
def CheckSimple(self, op, inputs, outputs_to_check, input_device_options=None)
def RunNetOnce(net)
Definition: workspace.py:160
CurrentWorkspace
Definition: workspace.py:24
def FeedBlob(name, arr, device_option=None)
Definition: workspace.py:229
def RunOperatorOnce(operator)
Definition: workspace.py:148
def FetchBlob(name)
Definition: workspace.py:276
def CheckNet(self, net, inputs={}, blobs_to_check=None, ignore=set())