from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, workspace, cnn, utils
from caffe2.python.recurrent import LSTM
from caffe2.proto import caffe2_pb2

import argparse
import logging
import numpy as np

from datetime import datetime

'''
This script takes a text file as input and uses a recurrent neural network
to learn to predict the next character in a sequence.
'''

logging.basicConfig()
log = logging.getLogger("char_rnn")
log.setLevel(logging.DEBUG)
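The training loop further down turns raw text into one-hot character vectors plus next-character targets. A minimal standalone sketch of that encoding, using a toy corpus that is not part of the script:

import numpy as np

text = "hello"                                      # toy corpus
vocab = sorted(set(text))                           # ['e', 'h', 'l', 'o']
char_to_idx = {ch: i for i, ch in enumerate(vocab)}
D = len(vocab)

# One-hot inputs for "hell"; targets are the indices of the following chars "ello"
inputs = np.zeros([len(text) - 1, D], dtype=np.float32)
targets = np.zeros([len(text) - 1], dtype=np.int32)
for t, ch in enumerate(text[:-1]):
    inputs[t, char_to_idx[ch]] = 1
    targets[t] = char_to_idx[text[t + 1]]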
# The default set() here is intentional: it accumulates net names across
# calls, like a global variable, so each net is created only once.
def CreateNetOnce(net, created_names=set()):  # noqa
    name = net.Name()
    if name not in created_names:
        created_names.add(name)
        workspace.CreateNet(net)
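A minimal, self-contained sketch of the same mutable-default-argument trick CreateNetOnce relies on: the default set is built once at function-definition time, so it survives across calls and acts as a cache. The helper name run_once is hypothetical and only for illustration.

def run_once(name, _seen=set()):
    """Return True only the first time a given name is passed in."""
    if name in _seen:
        return False
    _seen.add(name)
    return True

assert run_once("char_rnn") is True    # first call: not cached yet
assert run_once("char_rnn") is False   # second call: remembered by _seen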
class CharRNN(object):
    def __init__(self, args):
        self.seq_length = args.seq_length
        self.batch_size = args.batch_size
        self.iters_to_report = args.iters_to_report
        self.hidden_size = args.hidden_size

        with open(args.train_data) as f:
            self.text = f.read()

        self.char_to_idx = {ch: idx for idx, ch in enumerate(set(self.text))}
        self.idx_to_char = {idx: ch for ch, idx in self.char_to_idx.items()}
        self.D = len(self.char_to_idx)

        print("Input has {} characters. Total input size: {}".format(
            self.D, len(self.text)))

    def CreateModel(self):
        log.debug("Start training")
        model = cnn.CNNModelHelper(name="char_rnn")

        input_blob, seq_lengths, hidden_init, cell_init, target = \
            model.net.AddExternalInputs(
                'input_blob', 'seq_lengths', 'hidden_init',
                'cell_init', 'target')

        hidden_output_all, self.hidden_output, _, self.cell_state = LSTM(
            model, input_blob, seq_lengths, (hidden_init, cell_init),
            self.D, self.hidden_size, scope="LSTM")
        output = model.FC(hidden_output_all, None, dim_in=self.hidden_size,
                          dim_out=self.D, axis=2)

        # axis=2: the first two dimensions are T (time) and N (batch size),
        # treated together as one big batch of size T * N
        softmax = model.Softmax(output, 'softmax', axis=2)

        softmax_reshaped, _ = model.Reshape(
            softmax, ['softmax_reshaped', '_'], shape=[-1, self.D])

        # Keep a forward-only copy of the net for text generation, taken
        # before the loss and backward operators are added
        self.forward_net = core.Net(model.net.Proto())

        xent = model.LabelCrossEntropy([softmax_reshaped, target], 'xent')
        loss = model.AveragedLoss(xent, 'loss')
        model.AddGradientOperators([loss])
        ITER = model.Iter("iter")
        LR = model.LearningRate(
            ITER, "LR", base_lr=-0.1 * self.seq_length,
            policy="step", stepsize=1, gamma=0.9999)
        ONE = model.param_init_net.ConstantFill(
            [], "ONE", shape=[1], value=1.0)

        # Plain SGD: each parameter is updated in place as
        # param <- 1.0 * param + LR * grad, with LR negative
        for param in model.params:
            param_grad = model.param_to_grad[param]
            model.WeightedSum([param, ONE, param_grad, LR], param)

        self.model = model
        self.predictions = softmax
        self.loss = loss
    def _idx_at_pos(self, pos):
        return self.char_to_idx[self.text[pos]]
    def TrainModel(self):
        log.debug("Training model")
        workspace.RunNetOnce(self.model.param_init_net)

        # Start as though every character were predicted uniformly
        smooth_loss = -np.log(1.0 / self.D) * self.seq_length
        last_n_iter = 0
        last_n_loss = 0.0
        num_iter = 0
        N = len(self.text)

        # Split the text into batch_size contiguous blocks, one per batch instance
        text_block_positions = np.zeros(self.batch_size, dtype=np.int32)
        text_block_size = N // self.batch_size
        text_block_starts = range(0, N, text_block_size)
        text_block_sizes = [text_block_size] * self.batch_size
        text_block_sizes[self.batch_size - 1] += N % text_block_size
        assert sum(text_block_sizes) == N

        last_time = datetime.now()
        progress = 0
        while True:
            # One-hot inputs [seq_length, batch_size, D]; targets are next-char indices
            input = np.zeros(
                [self.seq_length, self.batch_size, self.D]).astype(np.float32)
            target = np.zeros(
                [self.seq_length * self.batch_size]).astype(np.int32)
            for e in range(self.batch_size):
                for i in range(self.seq_length):
                    pos = text_block_starts[e] + text_block_positions[e]
                    input[i][e][self._idx_at_pos(pos)] = 1
                    target[i * self.batch_size + e] = \
                        self._idx_at_pos((pos + 1) % N)
                    text_block_positions[e] = (
                        text_block_positions[e] + 1) % text_block_sizes[e]
                    progress += 1

            workspace.FeedBlob('input_blob', input)
            workspace.FeedBlob('target', target)
            CreateNetOnce(self.model.net)
            workspace.RunNet(self.model.net.Name())

            num_iter += 1
            last_n_iter += 1
            if num_iter % self.iters_to_report == 0:
                new_time = datetime.now()
                print("Characters Per Second: {}".format(
                    int(progress / (new_time - last_time).total_seconds())))
                print("Iterations Per Second: {}".format(
                    int(last_n_iter /
                        (new_time - last_time).total_seconds())))
                last_time = new_time
                progress = 0
                print("{} Iteration {} {}".format(
                    '-' * 10, num_iter, '-' * 10))

            # Exponentially smoothed loss for reporting
            loss = workspace.FetchBlob(self.loss) * self.seq_length
            smooth_loss = 0.999 * smooth_loss + 0.001 * loss
            last_n_loss += loss

            if num_iter % self.iters_to_report == 0:
                log.debug("Loss since last report: {}"
                          .format(last_n_loss / last_n_iter))
                log.debug("Smooth loss: {}".format(smooth_loss))
                last_n_loss = 0.0
                last_n_iter = 0
    def GenerateText(self, num_characters, ch):
        # Feed a fake sequence of length 1, num_characters times; after each
        # step, sample the next character from the output probabilities
        CreateNetOnce(self.forward_net)
        text = '' + ch

        for _i in range(num_characters):
            workspace.FeedBlob(
                "seq_lengths", np.array([1] * self.batch_size, dtype=np.int32))
            input = np.zeros([1, self.batch_size, self.D]).astype(np.float32)
            input[0][0][self.char_to_idx[ch]] = 1
            workspace.FeedBlob("input_blob", input)
            workspace.RunNet(self.forward_net.Name())

            p = workspace.FetchBlob(self.predictions)
            next = np.random.choice(self.D, p=p[0][0])
            ch = self.idx_to_char[next]
            text += ch

        print(text)
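The np.random.choice call above is what makes generation stochastic: it draws one character index from the softmax distribution instead of always taking the argmax. A minimal standalone sketch of that step, with a made-up four-character vocabulary:

import numpy as np

vocab = ['a', 'b', 'c', 'd']              # hypothetical vocabulary, D = 4
probs = np.array([0.1, 0.6, 0.2, 0.1])    # softmax output for one time step
next_idx = np.random.choice(len(vocab), p=probs)
print(vocab[next_idx])                    # usually 'b', but any character can appear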
@utils.debug
def main():
    parser = argparse.ArgumentParser(
        description="Caffe2: Char RNN Training"
    )
    parser.add_argument("--train_data", type=str, default=None,
                        help="Path to training data in a text file format",
                        required=True)
    parser.add_argument("--seq_length", type=int, default=25,
                        help="One training example sequence length")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Training batch size")
    parser.add_argument("--iters_to_report", type=int, default=500,
                        help="How often to report loss and generate text")
    parser.add_argument("--hidden_size", type=int, default=100,
                        help="Dimension of the hidden representation")
    parser.add_argument("--gpu", action="store_true",
                        help="If set, training is going to use GPU 0")

    args = parser.parse_args()
    device = core.DeviceOption(
        caffe2_pb2.CUDA if args.gpu else caffe2_pb2.CPU, 0)
    with core.DeviceScope(device):
        model = CharRNN(args)
        model.CreateModel()
        model.TrainModel()


if __name__ == '__main__':
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=2'])
    main()
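For experimentation it can be handy to drive the same entry points without the command line. A small sketch, not part of the original script: argparse.Namespace stands in for the parsed flags, "input.txt" is a placeholder path, and the remaining values mirror the defaults declared above.

from argparse import Namespace

args = Namespace(train_data="input.txt", seq_length=25, batch_size=1,
                 iters_to_report=500, hidden_size=100, gpu=False)
model = CharRNN(args)
model.CreateModel()
model.TrainModel()   # loops until interrupted, reporting every iters_to_report iterations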