1 #include "caffe2/core/tensor.h" 3 #include "caffe2/core/blob_stats.h" 9 "If set, keeps memory when a tensor is shrinking its size.");
13 CAFFE_KNOWN_TYPE(Tensor<CPUContext>);
// TensorPrinter constructor (partial listing).
// NOTE(review): this block is an extracted listing with embedded source line
// numbers and missing lines -- e.g. the parameter that declares `limit`, the
// code guarding the file open, and the call that consumes the error-message
// arguments are not visible. Comments describe only the visible code.
15 TensorPrinter::TensorPrinter(
16 const std::string& tensor_name,
17 const std::string& file_name,
// to_file_ is true exactly when a non-empty file_name was supplied.
19 : to_file_(!file_name.empty()),
// A zero/false `limit` falls back to k_limit_default_.
20 limit_(limit ? limit : k_limit_default_),
// Name emitted by MetaStr() in its "Tensor <name> ..." prefix.
21 tensor_name_(tensor_name) {
// Open file_name for output, discarding any existing contents (trunc).
25 log_file_.reset(
new std::ofstream(
26 file_name, std::ofstream::out | std::ofstream::trunc));
// The fragments below suggest an open failure is reported with a message
// starting "Failed to open TensorPrinter file " and including the stream's
// rdstate(); the check itself is missing here -- TODO confirm.
29 "Failed to open TensorPrinter file ",
32 log_file_->rdstate());
// TensorPrinter destructor (partial listing).
// Only the check for a non-null log_file_ is visible; the statements it
// guards (presumably closing the file) are missing -- TODO confirm.
36 TensorPrinter::~TensorPrinter() {
37 if (log_file_.get()) {
// Builds a description of `tensor` of the form
//   "Tensor <tensor_name_> of type <tensor.meta().name()>. Dims: (d0,d1,...,"
// where every dimension is followed by a comma.
// NOTE(review): the loop's closing brace and anything appended after the
// dimension list are missing from this listing (and lines carry embedded
// source line numbers); only the visible part is described here.
42 std::string TensorPrinter::MetaStr(
const Tensor<CPUContext>& tensor) {
43 std::stringstream meta_stream;
44 meta_stream <<
"Tensor " << tensor_name_ <<
" of type " 45 << tensor.meta().name() <<
". Dims: (";
// Append each dimension followed by ",".
46 for (
const auto dim : tensor.dims()) {
47 meta_stream << dim <<
",";
50 return meta_stream.str();
53 static CaffeMap<CaffeTypeId, TypeCall> type_call_registry_ {
54 {TypeMeta::Id<Tensor<CPUContext>>(), GetTensorType<Tensor<CPUContext>>}
// Looks up `id` in type_call_registry_.
// NOTE(review): the not-found branch and the final return are missing from
// this listing -- TODO confirm what is returned for an unregistered id
// (presumably nullptr) and that the found case returns f->second.
57 TypeCall GetTypeCallFunction(CaffeTypeId
id) {
58 auto f = type_call_registry_.find(
id);
// Unregistered id: handling not visible here.
59 if (f == type_call_registry_.end()) {
65 void RegisterTypeCallFunction(CaffeTypeId
id, TypeCall c) {
66 type_call_registry_[id] = c;
69 static CaffeMap<CaffeTypeId, ShapeCall> shape_call_registry_ {
70 {TypeMeta::Id<Tensor<CPUContext>>(), GetTensorShape<Tensor<CPUContext>>}
// Looks up `id` in shape_call_registry_.
// NOTE(review): the not-found branch and the final return are missing from
// this listing -- TODO confirm what is returned for an unregistered id
// (presumably nullptr) and that the found case returns f->second.
73 ShapeCall GetShapeCallFunction(CaffeTypeId
id) {
74 auto f = shape_call_registry_.find(
id);
// Unregistered id: handling not visible here.
75 if (f == shape_call_registry_.end()) {
81 void RegisterShapeCallFunction(CaffeTypeId
id, ShapeCall c) {
82 shape_call_registry_[id] = c;
// BlobStatGetter for TensorCPU blobs: sizeBytes() computes a byte count for
// the TensorCPU stored in `blob`.
// NOTE(review): the closing braces and the function's return statement
// (presumably `return nbytes;`) are missing from this listing.
87 struct TensorCPUStatGetter : BlobStatGetter {
88 size_t sizeBytes(
const Blob& blob)
const override {
89 const auto& tensor = blob.Get<TensorCPU>();
// Start from the tensor's own byte count.
90 auto nbytes = tensor.nbytes();
// For non-empty std::string tensors, also add each element's size() --
// presumably because nbytes() covers the std::string objects but not the
// characters they own. TODO confirm.
91 if (nbytes > 0 && tensor.IsType<std::string>()) {
92 const auto* data = tensor.data<std::string>();
93 for (
size_t i = 0; i < tensor.size(); ++i) {
94 nbytes += data[i].size();
100 REGISTER_BLOB_STAT_GETTER(TensorCPU, TensorCPUStatGetter);
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.
Commandline flags support for Caffe2.