Caffe2 - Python API
A deep learning, cross platform ML framework
rnn_cell.RNNCell Class Reference
Inheritance diagram for rnn_cell.RNNCell (derived classes):
rnn_cell.LSTMCell, rnn_cell.LSTMWithAttentionCell, rnn_cell.MILSTMCell, rnn_cell.MILSTMWithAttentionCell

Public Member Functions

def __init__ (self, name)
def scope (self, name)
def apply_over_sequence (self, model, inputs, seq_lengths, initial_states, outputs_with_grads=None)
def apply (self, model, input_t, seq_lengths, states, timestep)
def prepare_input (self, model, input_blob)
def get_state_names (self)

Public Attributes

name
recompute_blobs

Detailed Description

Base class for writing recurrent / stateful operations.

Subclasses must implement three methods: _apply, prepare_input and get_state_names.
In return, the base class provides the apply_over_sequence method, which
applies the recurrent operation over a sequence of any length.

Definition at line 20 of file rnn_cell.py.
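The division of labour described above (subclass supplies _apply, prepare_input and get_state_names; base class supplies the loop over timesteps) can be sketched without Caffe2 at all. This is a minimal, framework-free sketch under stated assumptions: plain Python lists stand in for blobs, there is no model or net, seq_lengths is ignored, and ToyRNNCell and RunningSumCell are hypothetical names. It shows the pattern only, not how rnn_cell.py actually builds its operators.

```python
# Framework-free sketch of the RNNCell contract.
# Plain Python lists stand in for Caffe2 blobs; there is no model/net and
# seq_lengths is ignored. ToyRNNCell and RunningSumCell are hypothetical names.


class ToyRNNCell:
    def __init__(self, name):
        self.name = name

    def apply_over_sequence(self, inputs, initial_states):
        """Run the cell over every timestep (the part the base class provides).

        inputs: list of per-timestep inputs (length = sequence length)
        initial_states: tuple with one entry per name in get_state_names()
        Returns the list of state tuples produced after each timestep.
        """
        prepared = self.prepare_input(inputs)  # input-only work, done once
        states = tuple(initial_states)
        history = []
        for timestep, input_t in enumerate(prepared):
            states = self._apply(input_t, states, timestep)
            history.append(states)
        return history

    # The three methods a subclass must implement:
    def _apply(self, input_t, states, timestep):
        raise NotImplementedError

    def prepare_input(self, inputs):
        raise NotImplementedError

    def get_state_names(self):
        raise NotImplementedError


class RunningSumCell(ToyRNNCell):
    """Keeps a single recurrent state: the running sum of doubled inputs."""

    def prepare_input(self, inputs):
        # Depends only on the input, so it is computed for all steps up front.
        return [2 * x for x in inputs]

    def _apply(self, input_t, states, timestep):
        (total,) = states
        return (total + input_t,)

    def get_state_names(self):
        return [self.name + "/sum"]


cell = RunningSumCell("toy")
print(cell.apply_over_sequence([1, 2, 3], initial_states=(0,)))
# → [(2,), (6,), (12,)]
```

Because the subclass never writes a loop, the same cell works for a sequence of any length; only the three hooks change between cell types.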

Member Function Documentation

◆ get_state_names()

def rnn_cell.RNNCell.get_state_names (self)
Return the names of the recurrent states.
The apply_over_sequence method requires them in order to allocate
recurrent states for all steps with meaningful names.

Definition at line 111 of file rnn_cell.py.
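To illustrate why the state names matter, the sketch below shows how a driver could turn the list returned by get_state_names into one clearly named slot per state per step. This is a hypothetical sketch: the function allocate_step_states, the example state names, and the "<state>_step<k>" naming scheme are assumptions made for illustration, not the names Caffe2 actually generates.

```python
# Hypothetical illustration of why a cell exposes get_state_names():
# with the names in hand, a driver can allocate one clearly named slot per
# state per step. The "<state>_step<k>" scheme is an assumption, not Caffe2's.


def allocate_step_states(state_names, num_steps):
    """Return {state_name: [slot names]} with num_steps + 1 slots per state.

    Slot 0 holds the initial state; slot k holds the state after step k.
    """
    return {
        name: ["{}_step{}".format(name, k) for k in range(num_steps + 1)]
        for name in state_names
    }


# Example state names for a hypothetical two-state (hidden, cell) recurrence.
slots = allocate_step_states(["my_lstm/hidden", "my_lstm/cell"], num_steps=2)
print(slots["my_lstm/cell"])
# → ['my_lstm/cell_step0', 'my_lstm/cell_step1', 'my_lstm/cell_step2']
```

Without the names, the driver would have to fall back on anonymous positional slots, which makes the resulting per-step blobs hard to inspect or debug.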

◆ prepare_input()

def rnn_cell.RNNCell.prepare_input (self, model, input_blob)
If some operations in the _apply method depend only on the input and
not on the recurrent states, they can be computed in advance here.

model: CNNModelHelper object to which the new operators are added

input_blob: either the whole input sequence with shape
(sequence_length, batch_size, input_dim) or a single input with shape
(1, batch_size, input_dim).

Definition at line 98 of file rnn_cell.py.
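The payoff of prepare_input can be seen with an input-to-hidden projection, a typical operation that depends only on the input. The sketch below is a framework-free illustration under stated assumptions: nested Python lists stand in for blobs, and matmul and project_sequence are hypothetical helpers, not Caffe2 operators. It shows that projecting a whole (sequence_length, batch_size, input_dim) input in one pass gives the same result as projecting each timestep separately inside the recurrent loop.

```python
# Sketch of the idea behind prepare_input: a projection x_t . W depends only on
# the input, so it can be applied to the whole (T, N, D) sequence in one pass
# instead of once per timestep. Nested lists stand in for Caffe2 blobs;
# matmul and project_sequence are hypothetical helpers.


def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both as nested lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def project_sequence(input_seq, weight):
    """Apply the projection to every timestep of a (T, N, D) input at once."""
    # Flatten (T, N, D) -> (T*N, D), multiply once, then restore (T, N, H).
    T, N = len(input_seq), len(input_seq[0])
    flat = [row for step in input_seq for row in step]
    projected = matmul(flat, weight)
    return [projected[t * N:(t + 1) * N] for t in range(T)]


# T=2 timesteps, N=2 batch entries, D=2 input features, H=1 output feature
inputs = [[[1, 0], [0, 1]],
          [[2, 3], [1, 1]]]
W = [[10], [100]]

batched = project_sequence(inputs, W)
per_step = [matmul(x_t, W) for x_t in inputs]  # what _apply would redo each step
print(batched)              # → [[[10], [100]], [[320], [110]]]
print(batched == per_step)  # → True
```

The same helper also handles the single-input case with shape (1, batch_size, input_dim), since that is just a sequence with T = 1.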

