5 from caffe2.python
import workspace
9 """A device checker in Python to check consistency across multiple devices. 11 This is not the most efficient way to check devices, as the Python interface 12 will involve a lot of copy back and forth operations. Use at your own risk. 15 def __init__(self, threshold, device_options):
20 input_device_options=None):
21 """Checks the operator with different device implementations. 24 op: the operator to be checked. 25 inputs: the input data in numpy arrays. 26 outputs_to_check: the outputs to check between devices. 27 input_device_options: a mapping from input name to a device to use 28 (instead of self._device_options) 30 boolean: True if it passes, False if it does not pass. 32 op = copy.deepcopy(op)
33 input_device_options = input_device_options
or {}
39 for i, arr
in enumerate(inputs):
41 op.input[i], np.array(arr),
42 input_device_options.get(op.input[i], device_option))
43 op.device_option.CopyFrom(device_option)
47 for idx
in outputs_to_check])
53 for j
in range(len(outputs_to_check)):
56 if not np.allclose(x, y,
58 print(
'Failure in checking device option {}' 59 ' and output {}. The outputs are:' 60 .format(i, op.output[outputs_to_check[j]]))
63 print(np.max(np.abs(x - y)))
71 def CheckNet(self, net, inputs={}, blobs_to_check=None, ignore=set()):
72 """Checks a network by inspecting all of its intermediate results, and 77 if blobs_to_check
is None:
78 blobs_to_check = sum([list(op.output)
for op
in net.op], [])
79 blobs_to_check = [b
for b
in blobs_to_check
if b
not in ignore]
82 for name, arr
in inputs.items():
86 op.device_option.CopyFrom(device_option)
93 for i
in range(1, len(results)):
94 for j
in range(len(blobs_to_check)):
97 if not np.allclose(x, y,
99 print(
'Failure in checking device option {}' 100 ' and output {}. The outputs are:' 101 .format(i, blobs_to_check[j]))
104 print(np.max(np.abs(x - y)))
def ResetWorkspace(root_folder=None)
def CheckSimple(self, op, inputs, outputs_to_check, input_device_options=None)
def FeedBlob(name, arr, device_option=None)
def RunOperatorOnce(operator)
def CheckNet(self, net, inputs={}, blobs_to_check=None, ignore=set())