Caffe2 - Python API
A deep learning, cross-platform ML framework
recurrent.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
from caffe2.python.scope import CurrentNameScope
from caffe2.python.cnn import CNNModelHelper

def recurrent_net(
    net, cell_net, inputs, initial_cell_inputs,
    links, timestep=None, scope=None, outputs_with_grads=(0,),
    recompute_blobs_on_backward=None,
):
    '''
    net: the main net that the RecurrentNetwork operator should be added to

    cell_net: cell_net which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one
    input is supported. It has to be in the format T x N x (D1...Dk), where
    T is the length of the sequence, N is the batch size and (D1...Dk) are
    the remaining dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary mapping cell_net input names at timestep t + 1 to
    cell_net output names at timestep t. Currently we assume that each
    output becomes an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided,
    "timestep" is used.

    scope: internal blobs are scoped in the format
        <scope_name>/<blob_name>
    If not provided, a scope name is generated automatically.

    outputs_with_grads: position indices of output blobs which will receive
    an error gradient (from outside the recurrent network) during
    backpropagation

    recompute_blobs_on_backward: a list of blobs that will be recomputed
    during the backward pass, and thus need not be stored for each forward
    timestep.
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    # Validate scoping
    for einp in cell_net.Proto().external_input:
        assert einp.startswith(CurrentNameScope()), \
            '''
            Cell net external inputs are not properly scoped, use
            AddScopedExternalInputs() when creating them
            '''

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine the inputs that are considered to be references, i.e. those
    # that are not referred to in inputs or initial_cell_inputs.
    known_inputs = list(map(str, input_blobs + initial_input_blobs))
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # Compute the backward pass of the cell net
    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
        cell_net.Proto().op, inner_outputs_map)
    backward_mapping = {str(k): v for k, v in backward_mapping.items()}
    backward_cell_net = core.Net("RecurrentBackwardStep")
    del backward_cell_net.Proto().op[:]

    if recompute_blobs_on_backward is not None:
        # Insert operators to re-compute the specified blobs.
        # They are added in the same order as in the forward pass, thus
        # the order is correct.
        recompute_blobs_on_backward = set(
            [str(b) for b in recompute_blobs_on_backward]
        )
        for op in cell_net.Proto().op:
            if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                backward_cell_net.Proto().op.extend([op])
                assert set(op.output).issubset(recompute_blobs_on_backward), \
                    'Outputs {} are produced by op but not recomputed: {}'.format(
                        set(op.output) - recompute_blobs_on_backward,
                        op
                    )
    else:
        recompute_blobs_on_backward = set()

    backward_cell_net.Proto().op.extend(backward_ops)
    # Compute blobs used but not defined in the backward pass
    backward_ssa, backward_blob_versions = core.get_ssa(
        backward_cell_net.Proto())
    undefined = core.get_undefined_blobs(backward_ssa)

    # Also add to the output list the intermediate outputs of fwd_step that
    # are used by the backward pass.
    ssa, blob_versions = core.get_ssa(cell_net.Proto())
    scratches = [
        blob for (blob, ver) in blob_versions.items()
        if ver > 0 and
        blob in undefined and
        blob not in cell_net.Proto().external_output]
    backward_cell_net.Proto().external_input.extend(scratches)

    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'
    backward_cell_net.Proto().type = 'simple'

    # Internal arguments used by the RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At timestep t we know that the corresponding data block is at
    # position t + offset in the recurrent_states tensor.
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world.
    # Format: (internal_blob, external_blob, offset).
    # A negative offset counts from the end, a positive one from the
    # beginning.
    aliases = []

    # Recurrent states hold the inputs (and outputs) of the cell net over time
    recurrent_states = []
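    # For illustration (hypothetical names): a single recurrent input
    # 'hidden_t_prev' whose cell output is 'hidden_t' yields, via the loop
    # below,
    #   forward_links: ('hidden_t_prev', '<scope>/hidden_t_prev_states', 0)
    #                  ('hidden_t',      '<scope>/hidden_t_prev_states', 1)
    #   aliases:       ('<scope>/hidden_t_prev_states', 'hidden_t_all',  1)
    #                  ('<scope>/hidden_t_prev_states', 'hidden_t_last', -1)
    # i.e. at timestep t the cell reads slice t of the state tensor and
    # writes slice t + 1, while the aliases expose the full state history
    # and the last state as external outputs.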
    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time,
        # or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))
        backward_links.append((cell_output + "_grad", states_grad, 1))

        backward_cell_net.Proto().external_input.append(
            str(cell_output) + "_grad")
        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        recurrent_input_grad = cell_input + "_grad"
        if not backward_blob_versions.get(recurrent_input_grad, 0):
            # If nobody writes to this recurrent input gradient, we still
            # need to make sure it gets to the states grad blob. We do this
            # with a backward link, which triggers an alias.
            # This logic is used, for example, in the SumOp case.
            backward_links.append(
                (backward_mapping[cell_input], states_grad, 0))
        else:
            backward_links.append((cell_input + "_grad", states_grad, 0))

    for reference in references:
        # Similar to above, in the case of a SumOp we need to write the
        # parameter gradient to an external blob. In this case we can be
        # sure that reference + "_grad" is a correct parameter name, as we
        # know what the RecurrentNetworkOp gradient schema looks like.
        reference_grad = reference + "_grad"
        if (reference in backward_mapping and
                reference_grad != str(backward_mapping[reference])):
            # We can use an Alias because after each timestep the RNN op
            # adds the value of reference_grad into an _acc blob, which
            # accumulates gradients for the corresponding parameter across
            # timesteps. Then at the end of the RNN op these two are swapped
            # and the reference_grad blob becomes a real blob instead of
            # an alias.
            backward_cell_net.Alias(
                backward_mapping[reference], reference_grad)

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))
        backward_links.append((
            backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
        ))
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_input)
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Split into separate lists so we can pass them to C++,
    # where we assemble them back together.
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    backward_link_internal, backward_link_external, backward_link_offset = \
        unpack_triple(backward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    params = [x for x in references if x in backward_mapping]
    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        param=[all_inputs.index(p) for p in params],
        alias_src=alias_src,
        alias_dst=[str(a) for a in alias_dst],
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=[
            all_inputs.index(i) for i in recurrent_inputs
        ],
        link_internal=[str(x) for x in link_internal],
        link_external=[str(x) for x in link_external],
        link_offset=link_offset,
        backward_link_internal=[str(x) for x in backward_link_internal],
        backward_link_external=[str(x) for x in backward_link_external],
        backward_link_offset=backward_link_offset,
        step_net=str(cell_net.Proto()),
        backward_step_net=str(backward_cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=[
            str(b) for b in recompute_blobs_on_backward
        ],
    )
    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation.
    return results[:-1]

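# A minimal sketch of how recurrent_net is typically wired up (the MILSTM
# helper below is the concrete in-tree example). All names here are
# hypothetical: the cell net is assumed to consume a per-timestep input
# 'input_t' and a recurrent input 'hidden_t_prev', and to expose 'hidden_t'
# as an external output.
#
#   hidden_all, hidden_last = recurrent_net(
#       net=model.net,
#       cell_net=step_net,                          # one timestep of compute
#       inputs=[(input_t, input_blob)],             # T x N x D input sequence
#       initial_cell_inputs=[(hidden_t_prev, hidden_init)],  # state at t = 0
#       links={hidden_t_prev: hidden_t},            # fed back as next input
#       scope='my_rnn',
#   )
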
def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
           scope, outputs_with_grads=(0,), memory_optimization=False,
           forget_bias=0.0):
    '''
    Adds the MI (multiplicative integration) flavor of the standard LSTM
    recurrent network operator to a model.
    See https://arxiv.org/pdf/1606.06630.pdf

    model: CNNModelHelper object new operators would be added to

    input_blob: the input sequence in the format T x N x D
    where T is the sequence size, N the batch size and D the input dimension

    seq_lengths: blob containing sequence lengths which would be passed to
    the LSTMUnit operator

    initial_states: a tuple of (hidden_input_blob, cell_input_blob),
    which are going to be inputs to the cell net on the first iteration

    dim_in: input dimension

    dim_out: output dimension

    outputs_with_grads: position indices of output blobs which will receive
    an external error gradient during backpropagation

    memory_optimization: if enabled, the LSTM step is recomputed on the
    backward pass, so forward activations do not need to be stored for each
    timestep. Saves memory at the cost of extra computation.
    '''
    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        return "{}/{}".format(str(scope), str(name))

    """ initial bulk fully-connected """
    input_blob = model.FC(
        input_blob, s('i2h'), dim_in=dim_in, dim_out=4 * dim_out, axis=2)

    """ the step net """
    step_model = CNNModelHelper(name='milstm_cell', param_model=model)
    input_t, timestep, cell_t_prev, hidden_t_prev = (
        step_model.net.AddScopedExternalInputs(
            'input_t', 'timestep', 'cell_t_prev', 'hidden_t_prev'))
    # hU^T
    # Shape: [1, batch_size, 4 * hidden_size]
    prev_t = step_model.FC(
        hidden_t_prev, s('prev_t'), dim_in=dim_out,
        dim_out=4 * dim_out, axis=2)
    # defining MI parameters
    alpha = step_model.param_init_net.ConstantFill(
        [],
        [s('alpha')],
        shape=[4 * dim_out],
        value=1.0
    )
    beta1 = step_model.param_init_net.ConstantFill(
        [],
        [s('beta1')],
        shape=[4 * dim_out],
        value=1.0
    )
    beta2 = step_model.param_init_net.ConstantFill(
        [],
        [s('beta2')],
        shape=[4 * dim_out],
        value=1.0
    )
    b = step_model.param_init_net.ConstantFill(
        [],
        [s('b')],
        shape=[4 * dim_out],
        value=0.0
    )
    model.params.extend([alpha, beta1, beta2, b])
    # alpha * (xW^T * hU^T)
    # Shape: [1, batch_size, 4 * hidden_size]
    alpha_tdash = step_model.net.Mul(
        [prev_t, input_t],
        s('alpha_tdash')
    )
    # Shape: [batch_size, 4 * hidden_size]
    alpha_tdash_rs, _ = step_model.net.Reshape(
        alpha_tdash,
        [s('alpha_tdash_rs'), s('alpha_tdash_old_shape')],
        shape=[-1, 4 * dim_out],
    )
    alpha_t = step_model.net.Mul(
        [alpha_tdash_rs, alpha],
        s('alpha_t'),
        broadcast=1,
        use_grad_hack=1
    )
    # beta1 * hU^T
    # Shape: [batch_size, 4 * hidden_size]
    prev_t_rs, _ = step_model.net.Reshape(
        prev_t,
        [s('prev_t_rs'), s('prev_t_old_shape')],
        shape=[-1, 4 * dim_out],
    )
    beta1_t = step_model.net.Mul(
        [prev_t_rs, beta1],
        s('beta1_t'),
        broadcast=1,
        use_grad_hack=1
    )
    # beta2 * xW^T
    # Shape: [batch_size, 4 * hidden_size]
    input_t_rs, _ = step_model.net.Reshape(
        input_t,
        [s('input_t_rs'), s('input_t_old_shape')],
        shape=[-1, 4 * dim_out],
    )
    beta2_t = step_model.net.Mul(
        [input_t_rs, beta2],
        s('beta2_t'),
        broadcast=1,
        use_grad_hack=1
    )
    # Add 'em all up
    gates_tdash = step_model.net.Sum(
        [alpha_t, beta1_t, beta2_t],
        s('gates_tdash')
    )
    gates_t = step_model.net.Add(
        [gates_tdash, b],
        s('gates_t'),
        broadcast=1,
        use_grad_hack=1
    )
    # Shape: [1, batch_size, 4 * hidden_size]
    gates_t_rs, _ = step_model.net.Reshape(
        gates_t,
        [s('gates_t_rs'), s('gates_t_old_shape')],
        shape=[1, -1, 4 * dim_out],
    )

    hidden_t, cell_t = step_model.net.LSTMUnit(
        [hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep],
        [s('hidden_t'), s('cell_t')],
        forget_bias=forget_bias,
    )
    step_model.net.AddExternalOutputs(cell_t, hidden_t)

    """ recurrent network """
    (hidden_input_blob, cell_input_blob) = initial_states
    output, last_output, all_states, last_state = recurrent_net(
        net=model.net,
        cell_net=step_model.net,
        inputs=[(input_t, input_blob)],
        initial_cell_inputs=[
            (hidden_t_prev, hidden_input_blob),
            (cell_t_prev, cell_input_blob),
        ],
        links={
            hidden_t_prev: hidden_t,
            cell_t_prev: cell_t,
        },
        timestep=timestep,
        scope=scope,
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=[gates_t] if memory_optimization else None
    )
    return output, last_output, all_states, last_state
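
# A minimal usage sketch of MILSTM. The model name, blob names and dimensions
# below are hypothetical (none of them come from this file); the sketch only
# shows the expected call pattern: the input sequence is T x N x dim_in,
# seq_lengths is a length-N vector, and each initial state is 1 x N x dim_out.
def _example_milstm_usage():
    model = CNNModelHelper(name='milstm_example')
    # Hypothetical external inputs; a real workspace would feed them blobs of
    # shape T x N x 64, N, 1 x N x 128 and 1 x N x 128 respectively.
    input_blob, seq_lengths, hidden_init, cell_init = (
        model.net.AddExternalInputs(
            'input_blob', 'seq_lengths', 'hidden_init', 'cell_init'))
    return MILSTM(
        model, input_blob, seq_lengths, (hidden_init, cell_init),
        dim_in=64, dim_out=128, scope='milstm_example',
        memory_optimization=True)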