Caffe2 - Python API
A deep learning, cross-platform ML framework
char_rnn.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, workspace, cnn, utils
from caffe2.python.recurrent import LSTM
from caffe2.proto import caffe2_pb2


import argparse
import logging
import numpy as np
from datetime import datetime

'''
This script takes a text file as input and uses a recurrent neural network
to learn to predict the next character in a sequence.
'''
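
# Example invocation (the training text path below is only illustrative; any
# plain-text file works, passed via the required --train_data flag):
#
#   python char_rnn.py --train_data shakespeare.txt --seq_length 25 \
#       --batch_size 32 --hidden_size 100 --iters_to_report 500
#
# Add --gpu to place the model on GPU 0 instead of the CPU.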

logging.basicConfig()
log = logging.getLogger("char_rnn")
log.setLevel(logging.DEBUG)


# The default set() here is intentional: it accumulates values across calls,
# like a global variable
def CreateNetOnce(net, created_names=set()):
    name = net.Name()
    if name not in created_names:
        created_names.add(name)
        workspace.CreateNet(net)

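# Note on CreateNetOnce above: because the default `created_names` set is
# shared across calls, repeated calls for the same net (as in the training
# loop below) only pay the workspace.CreateNet cost once; later calls are
# no-ops.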

class CharRNN(object):
    def __init__(self, args):
        self.seq_length = args.seq_length
        self.batch_size = args.batch_size
        self.iters_to_report = args.iters_to_report
        self.hidden_size = args.hidden_size

        with open(args.train_data) as f:
            self.text = f.read()

        self.vocab = list(set(self.text))
        self.char_to_idx = {ch: idx for idx, ch in enumerate(self.vocab)}
        self.idx_to_char = {idx: ch for idx, ch in enumerate(self.vocab)}
        self.D = len(self.char_to_idx)

        print("Input has {} characters. Total input size: {}".format(
            len(self.vocab), len(self.text)))

    def CreateModel(self):
        log.debug("Start training")
        model = cnn.CNNModelHelper(name="char_rnn")

        input_blob, seq_lengths, hidden_init, cell_init, target = \
            model.net.AddExternalInputs(
                'input_blob',
                'seq_lengths',
                'hidden_init',
                'cell_init',
                'target',
            )

        hidden_output_all, self.hidden_output, _, self.cell_state = LSTM(
            model, input_blob, seq_lengths, (hidden_init, cell_init),
            self.D, self.hidden_size, scope="LSTM")
        output = model.FC(hidden_output_all, None, dim_in=self.hidden_size,
                          dim_out=self.D, axis=2)

        # axis is 2 because the first two dimensions are T (time) and
        # N (batch size); we treat them as one big batch of size T * N
        softmax = model.Softmax(output, 'softmax', axis=2)

        softmax_reshaped, _ = model.Reshape(
            softmax, ['softmax_reshaped', '_'], shape=[-1, self.D])

        # Create a copy of the current net. We will use it on the forward
        # pass, where we don't need the loss and backward operators
        self.forward_net = core.Net(model.net.Proto())

        xent = model.LabelCrossEntropy([softmax_reshaped, target], 'xent')
        # The loss is averaged both across the batch and through time.
        # That's why the learning rate below is multiplied by self.seq_length
        loss = model.AveragedLoss(xent, 'loss')
        model.AddGradientOperators([loss])

        # Hand-made SGD update. Normally one can use helper functions
        # to build an optimizer
        ITER = model.Iter("iter")
        LR = model.LearningRate(
            ITER, "LR",
            base_lr=-0.1 * self.seq_length,
            policy="step", stepsize=1, gamma=0.9999)
        ONE = model.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)

        # Update the weights for each of the model parameters
        for param in model.params:
            param_grad = model.param_to_grad[param]
            model.WeightedSum([param, ONE, param_grad, LR], param)
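        # WeightedSum computes param <- 1.0 * param + LR * param_grad, and LR
        # is negative (base_lr = -0.1 * self.seq_length above), so this is
        # plain gradient descent: param <- param - |LR| * grad.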

        self.model = model
        self.predictions = softmax
        self.loss = loss

        self.prepare_state = core.Net("prepare_state")
        self.prepare_state.Copy(self.hidden_output, hidden_init)
        self.prepare_state.Copy(self.cell_state, cell_init)
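        # prepare_state copies the hidden/cell state produced by the previous
        # run back into the init blobs, so the recurrent state is carried
        # across consecutive runs of the training and generation nets.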

    def _idx_at_pos(self, pos):
        return self.char_to_idx[self.text[pos]]

    def TrainModel(self):
        log.debug("Training model")

        workspace.RunNetOnce(self.model.param_init_net)

        # As though we predict the same probability for each character
        smooth_loss = -np.log(1.0 / self.D) * self.seq_length
        last_n_iter = 0
        last_n_loss = 0.0
        num_iter = 0
        N = len(self.text)

        # We split the text into batch_size pieces. Each piece will be used
        # only by the corresponding batch entry during the training process
        text_block_positions = np.zeros(self.batch_size, dtype=np.int32)
        text_block_size = N // self.batch_size
        text_block_starts = range(0, N, text_block_size)
        text_block_sizes = [text_block_size] * self.batch_size
        text_block_sizes[self.batch_size - 1] += N % self.batch_size
        assert sum(text_block_sizes) == N
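        # For example, with N = 1000 characters and batch_size = 4, the blocks
        # start at positions 0, 250, 500 and 750; any remainder from the
        # integer division (N % batch_size) is folded into the last block.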

        # Write to the output states, which will be copied into the input
        # states within the loop below
        workspace.FeedBlob(self.hidden_output, np.zeros(
            [1, self.batch_size, self.hidden_size], dtype=np.float32
        ))
        workspace.FeedBlob(self.cell_state, np.zeros(
            [1, self.batch_size, self.hidden_size], dtype=np.float32
        ))
        workspace.CreateNet(self.prepare_state)

        # We iterate over the text in a loop many times. Each time we pick a
        # seq_length segment and feed it to the LSTM as a sequence
        last_time = datetime.now()
        progress = 0
        while True:
            workspace.FeedBlob(
                "seq_lengths",
                np.array([self.seq_length] * self.batch_size,
                         dtype=np.int32)
            )
            workspace.RunNet(self.prepare_state.Name())

            input = np.zeros(
                [self.seq_length, self.batch_size, self.D]
            ).astype(np.float32)
            target = np.zeros(
                [self.seq_length * self.batch_size]
            ).astype(np.int32)

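            # input is a one-hot tensor of shape [seq_length, batch_size, D];
            # for each position the target is the index of the character that
            # follows it in the text.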
            for e in range(self.batch_size):
                for i in range(self.seq_length):
                    pos = text_block_starts[e] + text_block_positions[e]
                    input[i][e][self._idx_at_pos(pos)] = 1
                    target[i * self.batch_size + e] =\
                        self._idx_at_pos((pos + 1) % N)
                    text_block_positions[e] = (
                        text_block_positions[e] + 1) % text_block_sizes[e]
                    progress += 1

            workspace.FeedBlob('input_blob', input)
            workspace.FeedBlob('target', target)

            CreateNetOnce(self.model.net)
            workspace.RunNet(self.model.net.Name())

            num_iter += 1
            last_n_iter += 1

            if num_iter % self.iters_to_report == 0:
                new_time = datetime.now()
                print("Characters Per Second: {}".format(
                    int(progress / (new_time - last_time).total_seconds())
                ))
                print("Iterations Per Second: {}".format(
                    int(self.iters_to_report /
                        (new_time - last_time).total_seconds())
                ))

                last_time = new_time
                progress = 0

                print("{} Iteration {} {}".
                      format('-' * 10, num_iter, '-' * 10))

            loss = workspace.FetchBlob(self.loss) * self.seq_length
            smooth_loss = 0.999 * smooth_loss + 0.001 * loss
            last_n_loss += loss
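            # smooth_loss is an exponential moving average of the per-sequence
            # loss, while last_n_loss accumulates the raw loss between reports.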

            if num_iter % self.iters_to_report == 0:
                self.GenerateText(500, np.random.choice(self.vocab))

                log.debug("Loss since last report: {}"
                          .format(last_n_loss / last_n_iter))
                log.debug("Smooth loss: {}".format(smooth_loss))

                last_n_loss = 0.0
                last_n_iter = 0

    def GenerateText(self, num_characters, ch):
        # Given a starting symbol, we feed a fake sequence of size 1 to
        # our RNN num_characters times. After each step we use the output
        # probabilities to pick the next character to feed to the network.
        # That same character also becomes part of the output
        CreateNetOnce(self.forward_net)

        text = '' + ch
        for i in range(num_characters):
            workspace.FeedBlob(
                "seq_lengths", np.array([1] * self.batch_size, dtype=np.int32))
            workspace.RunNet(self.prepare_state.Name())

            input = np.zeros([1, self.batch_size, self.D]).astype(np.float32)
            input[0][0][self.char_to_idx[ch]] = 1

            workspace.FeedBlob("input_blob", input)
            workspace.RunNet(self.forward_net.Name())

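            # self.predictions is the softmax blob of shape
            # [1, batch_size, self.D]; the next character is sampled from the
            # distribution of the first batch entry.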
            p = workspace.FetchBlob(self.predictions)
            next = np.random.choice(self.D, p=p[0][0])

            ch = self.idx_to_char[next]
            text += ch

        print(text)


@utils.debug
def main():
    parser = argparse.ArgumentParser(
        description="Caffe2: Char RNN Training"
    )
    parser.add_argument("--train_data", type=str, default=None,
                        help="Path to training data in a text file format",
                        required=True)
    parser.add_argument("--seq_length", type=int, default=25,
                        help="One training example sequence length")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Training batch size")
    parser.add_argument("--iters_to_report", type=int, default=500,
                        help="How often to report loss and generate text")
    parser.add_argument("--hidden_size", type=int, default=100,
                        help="Dimension of the hidden representation")
    parser.add_argument("--gpu", action="store_true",
                        help="If set, training is going to use GPU 0")

    args = parser.parse_args()

    device = core.DeviceOption(
        caffe2_pb2.CUDA if args.gpu else caffe2_pb2.CPU, 0)
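    # Every operator created inside this scope runs on the selected device:
    # GPU 0 when --gpu is passed, otherwise the CPU.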
    with core.DeviceScope(device):
        model = CharRNN(args)
        model.CreateModel()
        model.TrainModel()


if __name__ == '__main__':
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=2'])
    main()