Class CellFactory

Inheritance Relationships

Base Type

public marian::Factory

Derived Type

marian::rnn::StackedCellFactory

Class Documentation

class CellFactory : public marian::Factory

Base class for constructing RNN cells.

An RNN cell processes a single timestep rather than whole batches of input sequences. Marian provides nine RNN cell types: gru, gru-nematus, lstm, mlstm, mgru, tanh, relu, sru, and ssru.
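
To illustrate what "a single timestep" means, here is a minimal, self-contained sketch. It does not use Marian's classes: `ToyCell`, `make_toy_cell`, and `run_sequence` are hypothetical names invented for this example, the state is a single scalar, and only two of the nine type strings are handled. The cell advances the state by one step, and the caller loops over the sequence.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for an RNN cell (not marian::rnn::Cell):
// one call to step() advances the state by exactly one timestep.
struct ToyCell {
  std::function<double(double)> activation;
  double w_input;  // weight applied to the current input
  double w_state;  // weight applied to the previous state

  double step(double prev_state, double input) const {
    assert(activation && "activation must be set before stepping");
    return activation(w_input * input + w_state * prev_state);
  }
};

// Hypothetical factory function: selects the activation from a type string,
// loosely mirroring how a cell type such as "tanh" or "relu" is chosen by name.
std::shared_ptr<ToyCell> make_toy_cell(const std::string& type) {
  auto cell = std::make_shared<ToyCell>();
  cell->w_input = 1.0;
  cell->w_state = 0.5;
  if(type == "tanh")
    cell->activation = [](double x) { return std::tanh(x); };
  else if(type == "relu")
    cell->activation = [](double x) { return x > 0.0 ? x : 0.0; };
  else
    throw std::invalid_argument("unknown toy cell type: " + type);
  return cell;
}

// The caller, not the cell, iterates over the timesteps of a sequence.
double run_sequence(const ToyCell& cell, const std::vector<double>& inputs) {
  double state = 0.0;
  for(double x : inputs)
    state = cell.step(state, x);
  return state;
}
```

For example, `run_sequence(*make_toy_cell("relu"), {1.0, 2.0})` calls `step` twice, once per element of the sequence.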

Subclassed by marian::rnn::StackedCellFactory

Public Functions

virtual Ptr<Cell> construct(Ptr<ExpressionGraph> graph)
CellFactory clone()
virtual void add_input(std::function<Expr(Ptr<rnn::RNN>)> func)
virtual void add_input(Expr input)

Protected Attributes

std::vector<std::function<Expr(Ptr<rnn::RNN>)>> inputs_
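
The two `add_input` overloads and the `inputs_` attribute suggest that both a fixed expression and a function of the RNN can be stored in the same vector of callables. The sketch below shows one plausible way to do that; it is an assumption for illustration, not Marian's actual implementation. `ToyExpr`, `ToyRnn`, `ToyInputCollector`, and `resolve` are hypothetical stand-ins, with `ToyExpr` and `ToyRnn` taking the place of `Expr` and `rnn::RNN`.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-ins for Expr and rnn::RNN.
struct ToyExpr { double value; };
struct ToyRnn  { double scale; };

class ToyInputCollector {
public:
  // Deferred input: evaluated later, once an RNN object is available.
  void add_input(std::function<ToyExpr(const ToyRnn&)> func) {
    assert(func && "input function must be callable");
    inputs_.push_back(std::move(func));
  }

  // Fixed input: wrapped in a callable that ignores its argument,
  // so both kinds of input share one storage type.
  void add_input(ToyExpr input) {
    inputs_.push_back([input](const ToyRnn&) { return input; });
  }

  // Evaluate every stored input against a concrete RNN, in insertion order.
  std::vector<ToyExpr> resolve(const ToyRnn& rnn) const {
    std::vector<ToyExpr> out;
    out.reserve(inputs_.size());
    for(const auto& f : inputs_)
      out.push_back(f(rnn));
    return out;
  }

private:
  std::vector<std::function<ToyExpr(const ToyRnn&)>> inputs_;
};
```

Wrapping the fixed value in a lambda keeps the storage uniform: code that consumes the inputs only has to call each entry and does not need to distinguish the two cases.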