1 #ifndef CAFFE2_CORE_OPERATOR_GRADIENT_H_ 2 #define CAFFE2_CORE_OPERATOR_GRADIENT_H_ 4 #include "caffe2/core/registry.h" 5 #include "caffe2/proto/caffe2.pb.h" 6 #include "caffe2/utils/proto_utils.h" 21 inline bool IsDense()
const {
24 inline bool IsSparse()
const {
25 return (indices_.size() || values_.size());
27 inline bool IsEmpty()
const {
28 return (!IsDense() && !IsSparse());
36 vector<OperatorDef> ops_;
37 vector<GradientWrapper> g_input_;
41 const vector<OperatorDef>& ops,
42 const vector<GradientWrapper>& v)
43 : ops_(ops), g_input_(v) {}
49 const OperatorDef& def,
50 const vector<GradientWrapper>& g_output)
51 : def_(def), g_output_(g_output), g_input_(def.input_size()){};
53 virtual bool CopyDeviceOption()
const {
56 virtual bool CopyEngine()
const {
59 virtual bool CopyArguments()
const {
75 vector<OperatorDef> new_defs = GetGradientDefs();
76 for (
auto& opdef : new_defs) {
77 opdef.set_is_gradient_op(
true);
82 const OperatorDef& Def()
const {
87 virtual vector<OperatorDef> GetGradientDefs() {
88 CAFFE_NOT_IMPLEMENTED;
97 string I(
const int i) {
98 CAFFE_ENFORCE((i >= 0) && (i < def_.input().size()));
101 string O(
const int i) {
102 CAFFE_ENFORCE((i >= 0) && (i < def_.output().size()));
103 return def_.output(i);
105 string GI(
const int i) {
107 !g_input_.at(i).IsSparse(),
110 " already set to sparse.");
111 g_input_.at(i).dense_ = GradientName(def_.input(i));
112 return GradientName(def_.input(i));
114 string GI_I(
const int i) {
116 !g_input_.at(i).IsDense(),
119 " already set to dense.");
120 g_input_.at(i).indices_ = GradientSliceIndices(def_.input(i));
121 return GradientSliceIndices(def_.input(i));
123 string GI_V(
const int i) {
125 !g_input_.at(i).IsDense(),
128 " already set to dense.");
129 g_input_.at(i).values_ = GradientSliceValues(def_.input(i));
130 return GradientSliceValues(def_.input(i));
132 string GO(
const int i) {
134 g_output_.at(i).IsDense(),
135 "Gradient of output ",
137 " is either sparse or not provided.");
138 return g_output_.at(i).dense_;
140 string GO_I(
const int i) {
142 g_output_.at(i).IsSparse(),
143 "Gradient of output ",
145 " is either dense or not provided.");
146 return g_output_.at(i).indices_;
148 string GO_V(
const int i) {
150 g_output_.at(i).IsSparse(),
151 "Gradient of output ",
153 "is either dense or not provided.");
154 return g_output_.at(i).values_;
157 return g_output_.at(i);
161 void SetDense(
const int i,
const string& name) {
163 !g_input_.at(i).IsSparse(),
166 " already set to sparse.");
167 g_input_.at(i).dense_ = name;
169 void SetSparse(
const int i,
const string& indices,
const string& values) {
171 !g_input_.at(i).IsDense(),
174 " already set to dense.");
175 g_input_.at(i).indices_ = indices;
176 g_input_.at(i).values_ = values;
183 template <
class... Args>
185 return vector<OperatorDef>{CreateOperatorDef(args...)};
194 CaffeMap<string, string> m;
195 for (
auto& out : op.output()) {
196 if (IsGradientBlob(out)) {
197 m[out] = out.substr(0, out.length() - 5);
206 static string GradientName(
const string& name) {
207 return name +
"_grad";
210 static bool IsGradientBlob(
const string& name) {
211 return name.length() > 5 && name.find(
"_grad") == name.length() - 5;
214 static string GradientNameToParam(
const string& name) {
215 CHECK(IsGradientBlob(name));
216 return name.substr(0, name.length() - 5);
219 static string GradientSliceIndices(
const string& name) {
220 return name +
"_grad_indices";
223 static string GradientSliceValues(
const string& name) {
224 return name +
"_grad_values";
230 const OperatorDef& def_;
231 const vector<GradientWrapper>& g_output_;
232 vector<GradientWrapper> g_input_;
245 using GradientMakerBase::GradientMakerBase;
246 vector<OperatorDef> GetGradientDefs()
override {
247 return vector<OperatorDef>();
258 using GradientMakerBase::GradientMakerBase;
261 false,
"One should not call gradient for operator ", def_.type(),
".");
273 using GradientMakerBase::GradientMakerBase;
279 " should have a gradient but is not implemented yet.");
283 CAFFE_DECLARE_REGISTRY(
287 const vector<GradientWrapper>&);
289 #define REGISTER_GRADIENT(name, ...) \ 290 CAFFE_REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__) 291 #define REGISTER_GRADIENT_STR(str_name, ...) \ 292 CAFFE_REGISTER_TYPED_CLASS(GradientRegistry, str_name, __VA_ARGS__) 295 #define NO_GRADIENT(name) REGISTER_GRADIENT(name, NoGradient) 300 #define SHOULD_NOT_DO_GRADIENT(name) \ 301 REGISTER_GRADIENT(name, ThrowInTheTowelIfGradientIsCalled) 303 #define GRADIENT_NOT_IMPLEMENTED_YET(name) \ 304 REGISTER_GRADIENT(name, GradientNotImplementedYet) 310 const OperatorDef& def,
311 const vector<GradientWrapper>& g_output);
315 #endif // CAFFE2_CORE_OPERATOR_GRADIENT_H_
static vector< OperatorDef > SingleGradientDef(const Args &... args)
a helper function to allow one to create one single operator def, which is usually the case for many simple operators.
static CaffeMap< string, string > MatchGradsToParams(const OperatorDef &op)
Returns a map from each gradient blob name produced by the operator to the parameter blob it is the gradient of (the name with the "_grad" suffix stripped).
A helper class to indicate that the gradient mechanism is not ready.
GradientOpsMeta Get() override
Returns the gradient ops meta.
A helper class to indicate that the operator should have no gradient.
GradientOpsMeta Get() override
Returns the gradient ops meta.
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.
GradientOpsMeta GetGradientForOp(const OperatorDef &def, const vector< GradientWrapper > &g_output)
Gets the GradientOpsMeta for the given operator def.
A helper class to indicate that the operator does not need gradient computation.
virtual GradientOpsMeta Get()
Returns the gradient ops meta.