Class GRUNematus

Inheritance Relationships

Base Type

public marian::rnn::Cell

Class Documentation

class GRUNematus : public marian::rnn::Cell

GRU unit supporting Nematus-like layer normalization.

If no layer normalization is used, the unit is equivalent to marian::rnn::GRU and uses the same kernel.
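To make the description concrete, here is a minimal scalar C++ sketch of one GRU step with optional Nematus-style layer normalization. It is not Marian's implementation. It works on a single example with std::vector instead of Expr graphs, and GRUParams, gruNematusStep and the helper functions are hypothetical. Several details are assumptions based on the Nematus convention as we understand it: the reset gate comes before the update gate, the new state is h = z*h_prev + (1 - z)*candidate, and each projection is normalized separately, matching the *_lns_/*_lnb_ scale/bias attribute pairs. In particular, adding b and bx before normalizing is a guess.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;  // Mat[row][col]

// Row vector times matrix: y = x * M.
Vec vecMat(const Vec& x, const Mat& M) {
  assert(x.size() == M.size());
  Vec y(M.empty() ? 0 : M[0].size(), 0.f);
  for (std::size_t i = 0; i < M.size(); ++i)
    for (std::size_t j = 0; j < y.size(); ++j) y[j] += x[i] * M[i][j];
  return y;
}

Vec add(Vec a, const Vec& b) {
  for (std::size_t i = 0; i < a.size(); ++i) a[i] += b[i];
  return a;
}

// Layer normalization with a learned per-unit scale (gain) and bias.
Vec layerNorm(const Vec& x, const Vec& scale, const Vec& bias, float eps = 1e-9f) {
  const float n = static_cast<float>(x.size());
  float mean = 0.f, var = 0.f;
  for (float v : x) mean += v;
  mean /= n;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= n;
  Vec y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    y[i] = scale[i] * (x[i] - mean) / std::sqrt(var + eps) + bias[i];
  return y;
}

float sigmoid(float v) { return 1.f / (1.f + std::exp(-v)); }

// Hypothetical parameter bundle; names mirror the protected attributes.
struct GRUParams {
  Mat W, U;    // gate projections, output width 2*dim laid out as [reset | update]
  Vec b;       // gate bias, width 2*dim
  Mat Wx, Ux;  // candidate projections, output width dim
  Vec bx;      // candidate bias, width dim
  Vec W_lns, W_lnb, U_lns, U_lnb;      // LN scale/bias for the gate projections
  Vec Wx_lns, Wx_lnb, Ux_lns, Ux_lnb;  // LN scale/bias for the candidate projections
  bool layerNorm = false;
};

// One recurrent step: new hidden state from input x and previous state h.
Vec gruNematusStep(const GRUParams& p, const Vec& x, const Vec& h) {
  const std::size_t dim = h.size();
  Vec xW = add(vecMat(x, p.W), p.b);  // assumption: biases added before LN
  Vec xWx = add(vecMat(x, p.Wx), p.bx);
  Vec hU = vecMat(h, p.U);
  Vec hUx = vecMat(h, p.Ux);
  if (p.layerNorm) {  // each projection gets its own normalization
    xW = layerNorm(xW, p.W_lns, p.W_lnb);
    xWx = layerNorm(xWx, p.Wx_lns, p.Wx_lnb);
    hU = layerNorm(hU, p.U_lns, p.U_lnb);
    hUx = layerNorm(hUx, p.Ux_lns, p.Ux_lnb);
  }
  Vec out(dim);
  for (std::size_t i = 0; i < dim; ++i) {
    float r = sigmoid(xW[i] + hU[i]);              // reset gate
    float z = sigmoid(xW[dim + i] + hU[dim + i]);  // update gate
    float cand = std::tanh(xWx[i] + r * hUx[i]);   // candidate state
    out[i] = z * h[i] + (1.f - z) * cand;          // Nematus-style interpolation
  }
  return out;
}
```

With layerNorm = false the step reduces to a standard GRU update, in line with the remark that the unit then behaves like marian::rnn::GRU.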

Public Functions

GRUNematus(Ptr<ExpressionGraph> graph, Ptr<Options> options)
virtual State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr)
virtual std::vector<Expr> applyInput(std::vector<Expr> inputs)
virtual State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr)
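The three public functions suggest a split between a state-independent and a state-dependent computation. The sketch below assumes that apply(inputs, state, mask) amounts to applyState(applyInput(inputs), state, mask). Under that assumption, applyInput computes the input projections, which do not depend on the recurrent state and can therefore be batched over all time steps, while applyState performs the sequential recurrent update. ToyCell, runSequence and the scalar recurrence are hypothetical stand-ins. The mask argument is left out because its exact semantics are not documented here.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical toy cell with scalar state: h_t = tanh(w * x_t + b + u * h_{t-1}).
// It mirrors the applyInput / applyState split; the mask argument is omitted.
struct ToyCell {
  float w, b, u;

  // State-independent part: can be computed for every time step at once.
  std::vector<float> applyInput(const std::vector<float>& xs) const {
    std::vector<float> xWs;
    xWs.reserve(xs.size());
    for (float x : xs) xWs.push_back(w * x + b);
    return xWs;
  }

  // State-dependent part: must run sequentially, one step at a time.
  float applyState(float xW, float h) const { return std::tanh(xW + u * h); }

  // A single step, composed from the two halves.
  float apply(float x, float h) const { return applyState(applyInput({x})[0], h); }
};

// Whole sequence: one batched applyInput call, then a loop over applyState.
std::vector<float> runSequence(const ToyCell& cell, const std::vector<float>& xs, float h0) {
  std::vector<float> xWs = cell.applyInput(xs);
  std::vector<float> states;
  float h = h0;
  for (float xW : xWs) {
    h = cell.applyState(xW, h);
    states.push_back(h);
  }
  return states;
}
```

Calling applyInput once per sequence turns many small matrix products into one large one. The recurrent part cannot be batched across time, because each step needs the previous state.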

Protected Attributes

Expr UUx_
Expr WWx_
Expr bbx_
Expr U_
Expr W_
Expr b_
Expr Ux_
Expr Wx_
Expr bx_
Expr W_lns_
Expr W_lnb_
Expr Wx_lns_
Expr Wx_lnb_
Expr U_lns_
Expr U_lnb_
Expr Ux_lns_
Expr Ux_lnb_
bool encoder_
bool final_
bool transition_
bool layerNorm_
float dropout_
Expr dropMaskX_
Expr dropMaskS_
Expr fakeInput_
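The attributes UUx_, WWx_ and bbx_ read like column-wise concatenations of U_/Ux_, W_/Wx_ and b_/bx_. That interpretation is an assumption, though it is consistent with the remark that without layer normalization the unit shares the kernel of marian::rnn::GRU. The sketch checks the identity that makes such fusion valid: multiplying by [W | Wx] and adding [b | bx] yields the concatenation of the two separate affine projections, so one matrix product can replace two. concatCols, concat and affine are hypothetical helpers on std::vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;  // Mat[row][col]

// Column-wise concatenation [A | B]; both must have the same number of rows.
Mat concatCols(const Mat& A, const Mat& B) {
  assert(A.size() == B.size());
  Mat C = A;
  for (std::size_t i = 0; i < C.size(); ++i)
    C[i].insert(C[i].end(), B[i].begin(), B[i].end());
  return C;
}

// Vector concatenation [a | b].
Vec concat(Vec a, const Vec& b) {
  a.insert(a.end(), b.begin(), b.end());
  return a;
}

// Affine projection: y = x * M + bias.
Vec affine(const Vec& x, const Mat& M, const Vec& bias) {
  assert(x.size() == M.size());
  assert(M.empty() || M[0].size() == bias.size());
  Vec y = bias;
  for (std::size_t i = 0; i < M.size(); ++i)
    for (std::size_t j = 0; j < y.size(); ++j) y[j] += x[i] * M[i][j];
  return y;
}
```

For example, with input size 2 and hidden size 1, W is 2x2 (two gate columns) and Wx is 2x1. affine(x, concatCols(W, Wx), concat(b, bx)) then equals concat(affine(x, W, b), affine(x, Wx, bx)). The same identity would apply to UUx_ for the recurrent projection of the state.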