3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.python
import core, scope
9 from caffe2.proto
import caffe2_pb2
def Iter(model, blob_out, **kwargs):
    """Create an iteration-counter blob and append an ``Iter`` op to the net.

    ``blob_out`` is initialised in ``model.param_init_net`` by a
    ``ConstantFill`` producing a shape-``[1]`` INT64 tensor of value 0, placed
    on CPU device 0.  An ``Iter`` op that uses ``blob_out`` as both its input
    and its output is then added to ``model.net``.

    Args:
        model: model helper exposing ``param_init_net`` and ``net``.
        blob_out: name/reference of the counter blob.
        **kwargs: extra operator arguments forwarded to both the
            ``ConstantFill`` and the ``Iter`` op.  A caller-supplied
            ``device_option`` is discarded.

    Returns:
        Whatever ``model.net.Iter`` returns.
    """
    # The fill op below sets its own device_option; leaving the caller's in
    # kwargs would pass that keyword twice.
    kwargs.pop('device_option', None)
    cpu_option = core.DeviceOption(caffe2_pb2.CPU, 0)
    model.param_init_net.ConstantFill(
        [],
        blob_out,
        shape=[1],
        value=0,
        dtype=core.DataType.INT64,
        device_option=cpu_option,
        **kwargs
    )
    return model.net.Iter(blob_out, blob_out, **kwargs)
def Accuracy(model, blob_in, blob_out, **kwargs):
    """Append an ``Accuracy`` op comparing predictions against labels.

    The target device is ``kwargs['device_option']`` when supplied, otherwise
    the current device scope.  When that device is not CPU and
    ``top_k > 1`` is requested, both inputs are first copied to host with
    ``CopyGPUToCPU`` and the op is built with a CPU device option (the code
    treats ``top_k > 1`` as CPU-only).  Otherwise the op is added directly on
    ``blob_in``.

    Args:
        model: model helper exposing ``net``.
        blob_in: two-element sequence ``[predictions, labels]``.  On the host
            copy path ``+ "_host"`` is applied to each entry, so they are
            assumed to be blob names (or objects supporting ``+ str``).
        blob_out: output blob for the accuracy value.
        **kwargs: operator arguments forwarded to the ``Accuracy`` op
            (e.g. ``top_k``, ``device_option``).

    Returns:
        None.
    """
    if 'device_option' in kwargs:
        dev = kwargs['device_option']
    else:
        dev = scope.CurrentDeviceScope()
    is_cpu = dev is None or dev.device_type == caffe2_pb2.CPU

    if not is_cpu and 'top_k' in kwargs and kwargs['top_k'] > 1:
        # Run on host: copy predictions and labels off the device first.
        pred_host = model.net.CopyGPUToCPU(blob_in[0], blob_in[0] + "_host")
        label_host = model.net.CopyGPUToCPU(blob_in[1], blob_in[1] + "_host")

        # Override the caller's device_option instead of passing it twice;
        # `device_option=...` together with `**kwargs` that also holds the key
        # raises TypeError.
        host_kwargs = dict(kwargs)
        host_kwargs['device_option'] = core.DeviceOption(caffe2_pb2.CPU, 0)
        model.net.Accuracy([pred_host, label_host], blob_out, **host_kwargs)
    else:
        # Forward kwargs so top_k (and any explicit device_option) actually
        # reach the op instead of being silently dropped.
        model.net.Accuracy(blob_in, blob_out, **kwargs)
def DeviceOption(device_type, cuda_gpu_id=0, random_seed=None)