from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
from caffe2.python.scope import CurrentNameScope
from caffe2.python.cnn import CNNModelHelper
def recurrent_net(
        net, cell_net, inputs, initial_cell_inputs,
        links, timestep=None, scope=None, outputs_with_grads=(0,),
        recompute_blobs_on_backward=None,
):
    '''
    net: the main net the operator should be added to

    cell_net: cell net which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one
    input is supported. It has to be in a format T x N x (D1...Dk) where T is
    the length of the sequence, N is the batch size and (D1...Dk) are the
    remaining dimensions.

    initial_cell_inputs: inputs of the cell_net for timestep 0.
        Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary mapping cell_net input names at moment t+1 to
        output names at moment t. Currently we assume that each output becomes
        an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided,
        "timestep" is used.

    scope: internal blobs are scoped in the format <scope_name>/<blob_name>.
        If not provided, a scope name is generated automatically.

    outputs_with_grads: position indices of output blobs which will receive
        error gradients (from outside the recurrent network) during
        backpropagation.

    recompute_blobs_on_backward: a list of blobs that will be recomputed
        for the backward pass, and thus need not be stored for each forward
        timestep.

    A minimal usage sketch follows the function body.
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    for einp in cell_net.Proto().external_input:
        assert einp.startswith(CurrentNameScope()), \
            '''
            Cell net external inputs are not properly scoped, use
            AddScopedExternalInputs() when creating them
            '''

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # Scope blobs manually so that internal and external blob names stay
        # consistent between the main net and the step nets.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))
    known_inputs = list(map(str, input_blobs + initial_input_blobs))
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    # Every other external input of the cell net is treated as a reference
    # (e.g. a parameter shared across timesteps).
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs]
    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # Compute the backward pass of the cell net
    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
        cell_net.Proto().op, inner_outputs_map)
    backward_mapping = {str(k): v for k, v in backward_mapping.items()}
    backward_cell_net = core.Net("RecurrentBackwardStep")
    del backward_cell_net.Proto().op[:]
    if recompute_blobs_on_backward is not None:
        # Insert operators to re-compute the specified blobs for the backward
        # pass. They are added in the same order as in the forward pass, so
        # the ordering stays correct.
        recompute_blobs_on_backward = set(
            [str(b) for b in recompute_blobs_on_backward]
        )
        for op in cell_net.Proto().op:
            if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                backward_cell_net.Proto().op.extend([op])
                assert set(op.output).issubset(recompute_blobs_on_backward), \
                    'Outputs {} are produced by an op but not recomputed: {}'.format(
                        set(op.output) - recompute_blobs_on_backward,
                        op
                    )
    else:
        recompute_blobs_on_backward = set()
    backward_cell_net.Proto().op.extend(backward_ops)
    # Compute blobs used but not defined in the backward pass
    backward_ssa, backward_blob_versions = core.get_ssa(
        backward_cell_net.Proto())
    undefined = core.get_undefined_blobs(backward_ssa)

    # Also feed the backward step the intermediate outputs of the forward
    # step that it uses.
    ssa, blob_versions = core.get_ssa(cell_net.Proto())
    scratches = [
        blob for (blob, ver) in blob_versions.items()
        if ver > 0 and
        blob in undefined and
        blob not in cell_net.Proto().external_output]
    backward_cell_net.Proto().external_input.extend(scratches)
    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'
    backward_cell_net.Proto().type = 'simple'

    # Links are (internal_blob, recurrent_state, offset) triples: at moment t
    # the corresponding data lives at position t + offset of the state tensor.
    forward_links = []
    backward_links = []

    # Aliases expose outputs to the external world as
    # (internal_blob, external_blob, offset); a negative offset counts from
    # the end, a positive one from the beginning.
    aliases = []

    # Recurrent states hold the inputs/outputs of the cell net over time
    recurrent_states = []
    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # The recurrent state is (T + 1) x ...; it stores all inputs and
        # outputs of the cell net over time (or their gradients during the
        # backward pass).
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))
        backward_links.append((cell_output + "_grad", states_grad, 1))

        backward_cell_net.Proto().external_input.append(
            str(cell_output) + "_grad")
        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        recurrent_input_grad = cell_input + "_grad"
        if not backward_blob_versions.get(recurrent_input_grad, 0):
            # Nothing writes to this recurrent input gradient, so route the
            # corresponding recurrent output gradient into the states grad.
            backward_links.append(
                (backward_mapping[cell_input], states_grad, 0))
        else:
            backward_links.append((cell_input + "_grad", states_grad, 0))
    for reference in references:
        # A parameter gradient may be written to a differently named blob
        # (e.g. when it is accumulated by a Sum op); alias it to
        # reference + "_grad", the name the RecurrentNetwork gradient schema
        # expects.
        reference_grad = reference + "_grad"
        if (reference in backward_mapping and
                reference_grad != str(backward_mapping[reference])):
            backward_cell_net.Alias(
                backward_mapping[reference], reference_grad)
    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))
        backward_links.append((
            backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
        ))
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_input)
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_output)
    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Split the triples into separate lists so they can be passed to the
    # RecurrentNetwork operator as flat arguments.
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    backward_link_internal, backward_link_external, backward_link_offset = \
        unpack_triple(backward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)
    params = [x for x in references if x in backward_mapping.keys()]
    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]
    global _workspace_seq
    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        param=map(all_inputs.index, params),
        alias_src=map(str, alias_src),
        alias_dst=map(str, alias_dst),
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=map(all_inputs.index, recurrent_inputs),
        link_internal=map(str, link_internal),
        link_external=map(str, link_external),
        link_offset=link_offset,
        backward_link_internal=map(str, backward_link_internal),
        backward_link_external=map(str, backward_link_external),
        backward_link_offset=backward_link_offset,
        step_net=str(cell_net.Proto()),
        backward_step_net=str(backward_cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=map(str, recompute_blobs_on_backward)
    )
    # The last output is the list of step workspaces, which is only needed
    # internally for gradient propagation.
    return results[:-1]
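
# A minimal usage sketch for recurrent_net (illustrative only: the blob names,
# the trivial Sum cell and the "example_rnn" scope below are hypothetical and
# not part of this module):
#
#   from caffe2.python import core
#
#   cell = core.Net("example_cell")
#   h_prev, x_t = cell.AddScopedExternalInputs("h_prev", "x_t")
#   h_t = cell.Sum([h_prev, x_t], "h_t")
#   cell.AddExternalOutputs(h_t)
#
#   main = core.Net("example_main")
#   h_all, h_last = recurrent_net(
#       net=main,
#       cell_net=cell,
#       inputs=[(x_t, "input_sequence")],          # T x N x D sequence blob
#       initial_cell_inputs=[(h_prev, "h_init")],  # initial state blob
#       links={str(h_prev): str(h_t)},
#       scope="example_rnn",
#   )
#
# h_all then holds the state for every timestep and h_last only the final one.
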
def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
           scope, outputs_with_grads=(0,), memory_optimization=False,
           forget_bias=0.0):
    '''
    Adds the MI (multiplicative integration) flavor of the standard LSTM
    recurrent network operator to a model.
    See https://arxiv.org/pdf/1606.06630.pdf

    model: CNNModelHelper object new operators would be added to

    input_blob: the input sequence in a format T x N x D,
    where T is the sequence length, N the batch size and D the input dimension

    seq_lengths: blob containing the sequence lengths, passed to LSTMUnit

    initial_states: a tuple of (hidden_input_blob, cell_input_blob)
    which are going to be inputs to the cell net on the first iteration

    dim_in: input dimension

    dim_out: output dimension

    outputs_with_grads: position indices of output blobs which will receive
    external error gradients during backpropagation

    memory_optimization: if enabled, the LSTM step is recomputed on the
    backward pass, so forward activations need not be stored for each
    timestep. Saves memory at the cost of extra computation.

    A minimal usage sketch follows the function body.
    '''
    def s(name):
        # Scope blobs manually due to our internal/external blob relationships.
        return "{}/{}".format(str(scope), str(name))
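    # Sketch of the gate pre-activation assembled by the step net below
    # (elementwise products written with '*'):
    #
    #   gates_t = alpha * (W x_t) * (U h_{t-1})
    #             + beta1 * (U h_{t-1}) + beta2 * (W x_t) + b
    #
    # W x_t is the bulk 'i2h' projection computed once over the whole
    # sequence, U h_{t-1} is the per-step 'prev_t' projection, and the result
    # feeds LSTMUnit in place of a vanilla LSTM's additive pre-activation.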
    """ initial bulk fully-connected """
    input_blob = model.FC(
        input_blob, s('i2h'), dim_in=dim_in, dim_out=4 * dim_out, axis=2)

    """ the step net """
    step_model = CNNModelHelper(name='milstm_cell', param_model=model)
    input_t, timestep, cell_t_prev, hidden_t_prev = (
        step_model.net.AddScopedExternalInputs(
            'input_t', 'timestep', 'cell_t_prev', 'hidden_t_prev'))
    # U h_{t-1}; shape: [1, batch_size, 4 * dim_out]
    prev_t = step_model.FC(
        hidden_t_prev, s('prev_t'), dim_in=dim_out,
        dim_out=4 * dim_out, axis=2)
    # MI parameters, one value per gate channel
    alpha = step_model.param_init_net.ConstantFill(
        [], [s('alpha')], shape=[4 * dim_out], value=1.0)
    beta1 = step_model.param_init_net.ConstantFill(
        [], [s('beta1')], shape=[4 * dim_out], value=1.0)
    beta2 = step_model.param_init_net.ConstantFill(
        [], [s('beta2')], shape=[4 * dim_out], value=1.0)
    b = step_model.param_init_net.ConstantFill(
        [], [s('b')], shape=[4 * dim_out], value=0.0)
    model.params.extend([alpha, beta1, beta2, b])
    # alpha * (W x_t * U h_{t-1}); shape: [1, batch_size, 4 * dim_out]
    alpha_tdash = step_model.net.Mul(
        [prev_t, input_t],
        s('alpha_tdash')
    )
    # Shape: [batch_size, 4 * dim_out]
    alpha_tdash_rs, _ = step_model.net.Reshape(
        alpha_tdash,
        [s('alpha_tdash_rs'), s('alpha_tdash_old_shape')],
        shape=[-1, 4 * dim_out],
    )
    alpha_t = step_model.net.Mul(
        [alpha_tdash_rs, alpha],
        s('alpha_t'),
        broadcast=1
    )
    # beta1 * U h_{t-1}; shape: [batch_size, 4 * dim_out]
    prev_t_rs, _ = step_model.net.Reshape(
        prev_t,
        [s('prev_t_rs'), s('prev_t_old_shape')],
        shape=[-1, 4 * dim_out],
    )
    beta1_t = step_model.net.Mul(
        [prev_t_rs, beta1],
        s('beta1_t'),
        broadcast=1
    )
    # beta2 * W x_t; shape: [batch_size, 4 * dim_out]
    input_t_rs, _ = step_model.net.Reshape(
        input_t,
        [s('input_t_rs'), s('input_t_old_shape')],
        shape=[-1, 4 * dim_out],
    )
    beta2_t = step_model.net.Mul(
        [input_t_rs, beta2],
        s('beta2_t'),
        broadcast=1
    )
    # Sum the three terms and add the bias
    gates_tdash = step_model.net.Sum(
        [alpha_t, beta1_t, beta2_t],
        s('gates_tdash')
    )
    gates_t = step_model.net.Add(
        [gates_tdash, b],
        s('gates_t'),
        broadcast=1
    )
    # Shape: [1, batch_size, 4 * dim_out]
    gates_t_rs, _ = step_model.net.Reshape(
        gates_t,
        [s('gates_t_rs'), s('gates_t_old_shape')],
        shape=[1, -1, 4 * dim_out],
    )

    hidden_t, cell_t = step_model.net.LSTMUnit(
        [hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep],
        [s('hidden_t'), s('cell_t')],
        forget_bias=forget_bias,
    )
    step_model.net.AddExternalOutputs(cell_t, hidden_t)
    """ recurrent network """
    (hidden_input_blob, cell_input_blob) = initial_states
    output, last_output, all_states, last_state = recurrent_net(
        net=model.net,
        cell_net=step_model.net,
        inputs=[(input_t, input_blob)],
        initial_cell_inputs=[
            (hidden_t_prev, hidden_input_blob),
            (cell_t_prev, cell_input_blob),
        ],
        links={
            hidden_t_prev: hidden_t,
            cell_t_prev: cell_t,
        },
        timestep=timestep,
        scope=scope,
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=[gates_t] if memory_optimization else None
    )
    return output, last_output, all_states, last_state
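
# A minimal usage sketch for MILSTM (illustrative only; the blob names and
# dimensions below are hypothetical):
#
#   model = CNNModelHelper(name="milstm_example")
#   hidden_all, hidden_last, cell_all, cell_last = MILSTM(
#       model, "input_sequence", "seq_lengths",
#       ("hidden_init", "cell_init"),
#       dim_in=64, dim_out=128, scope="milstm",
#       memory_optimization=True,
#   )
#
# Here "input_sequence" is a T x N x 64 blob, "seq_lengths" holds one length
# per batch element, and the two initial state blobs are 1 x N x 128.
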
# Function index:
#   GetBackwardPass(cls, operators, ys)
#   get_undefined_blobs(ssa)
#   MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
#          scope, outputs_with_grads=(0,), memory_optimization=False,
#          forget_bias=0.0)
#   get_ssa(net, blob_versions=None)
#   recurrent_net(net, cell_net, inputs, initial_cell_inputs, links,
#                 timestep=None, scope=None, outputs_with_grads=(0,),
#                 recompute_blobs_on_backward=None)