Program Listing for File cells.h
↰ Return to documentation for file (src/rnn/cells.h)
#pragma once
#include "marian.h"
#include "layers/generic.h"
#include "rnn/types.h"
#include <algorithm>
#include <chrono>
#include <cstdio>
#include <iomanip>
#include <string>
namespace marian {
namespace rnn {
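// Plain tanh RNN cell: output = tanh(x * W + s * U + b), with optional
// layer normalization on the input and state projections and reusable
// (variational) dropout masks on input and state.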
class Tanh : public Cell {
private:
Expr U_, W_, b_;
Expr gamma1_;
Expr gamma2_;
bool layerNorm_;
float dropout_;
Expr dropMaskX_;
Expr dropMaskS_;
public:
Tanh(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
int dimInput = options_->get<int>("dimInput");
int dimState = options_->get<int>("dimState");
std::string prefix = options_->get<std::string>("prefix");
layerNorm_ = options_->get<bool>("layer-normalization", false);
dropout_ = options_->get<float>("dropout", 0);
U_ = graph->param(
prefix + "_U", {dimState, dimState}, inits::glorotUniform());
if(dimInput)
W_ = graph->param(
prefix + "_W", {dimInput, dimState}, inits::glorotUniform());
b_ = graph->param(prefix + "_b", {1, dimState}, inits::zeros());
if(dropout_ > 0.0f) {
if(dimInput)
dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
dropMaskS_ = graph->dropoutMask(dropout_, {1, dimState});
}
if(layerNorm_) {
if(dimInput)
gamma1_ = graph->param(
prefix + "_gamma1", {1, dimState}, inits::fromValue(1.f));
gamma2_ = graph->param(
prefix + "_gamma2", {1, dimState}, inits::fromValue(1.f));
}
}
State apply(std::vector<Expr> inputs, State states, Expr mask = nullptr) {
return applyState(applyInput(inputs), states, mask);
}
std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
Expr input;
if(inputs.size() == 0)
return {};
else if(inputs.size() > 1)
input = concatenate(inputs, /*axis =*/ -1);
else
input = inputs.front();
input = dropout(input, dropMaskX_);
auto xW = dot(input, W_);
if(layerNorm_)
xW = layerNorm(xW, gamma1_);
return {xW};
}
State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) override {
Expr recState = state.output;
auto stateDropped = dropout(recState, dropMaskS_);
auto sU = dot(stateDropped, U_);
if(layerNorm_)
sU = layerNorm(sU, gamma2_);
Expr output;
if(xWs.empty())
output = tanh(sU, b_);
else {
output = tanh(xWs.front(), sU, b_);
}
if(mask)
return {output * mask, state.cell};
else
return {output, state.cell};
}
};
/******************************************************************************/
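// Plain ReLU RNN cell: output = relu(x * W + s * U + b). The recurrent
// matrix U_ is initialized to the identity (inits::eye()), the IRNN-style
// initialization commonly used to stabilize ReLU recurrences.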
class ReLU : public Cell {
private:
Expr U_, W_, b_;
Expr gamma1_;
Expr gamma2_;
bool layerNorm_;
float dropout_;
Expr dropMaskX_;
Expr dropMaskS_;
public:
ReLU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
int dimInput = options_->get<int>("dimInput");
int dimState = options_->get<int>("dimState");
std::string prefix = options_->get<std::string>("prefix");
layerNorm_ = options_->get<bool>("layer-normalization", false);
dropout_ = options_->get<float>("dropout", 0);
U_ = graph->param(prefix + "_U", {dimState, dimState}, inits::eye());
if(dimInput)
W_ = graph->param(
prefix + "_W", {dimInput, dimState}, inits::glorotUniform());
b_ = graph->param(prefix + "_b", {1, dimState}, inits::zeros());
if(dropout_ > 0.0f) {
if(dimInput)
dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
dropMaskS_ = graph->dropoutMask(dropout_, {1, dimState});
}
if(layerNorm_) {
if(dimInput)
gamma1_ = graph->param(prefix + "_gamma1", {1, dimState}, inits::ones());
gamma2_ = graph->param(prefix + "_gamma2", {1, dimState}, inits::ones());
}
}
State apply(std::vector<Expr> inputs, State states, Expr mask = nullptr) {
return applyState(applyInput(inputs), states, mask);
}
std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
Expr input;
if(inputs.size() == 0)
return {};
else if(inputs.size() > 1)
input = concatenate(inputs, /*axis =*/ -1);
else
input = inputs.front();
input = dropout(input, dropMaskX_);
auto xW = dot(input, W_);
if(layerNorm_)
xW = layerNorm(xW, gamma1_);
return {xW};
}
State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) override {
Expr recState = state.output;
auto stateDropped = dropout(recState, dropMaskS_);
auto sU = dot(stateDropped, U_);
if(layerNorm_)
sU = layerNorm(sU, gamma2_);
Expr output;
if(xWs.empty())
output = relu(sU + b_);
else {
output = relu(xWs.front() + sU + b_);
}
if(mask)
return {output * mask, state.cell};
else
return {output, state.cell};
}
};
/******************************************************************************/
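// Fused element-wise GRU update, implemented elsewhere in Marian. Takes
// {state, xW, sU, b[, mask]}; `final` appears to select the bias/reset-gate
// ordering used for final (Nematus-style) GRU layers.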
Expr gruOps(const std::vector<Expr>& nodes, bool final = false);
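// Standard GRU cell. The split U/Ux, W/Wx, b/bx parameter blocks mirror the
// Amun/Nematus model layout (see the @TODO below) and are concatenated once
// at construction so that a single matrix product serves all three gates.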
class GRU : public Cell {
protected:
std::string prefix_;
Expr U_, W_, b_;
Expr gamma1_;
Expr gamma2_;
bool final_;
bool layerNorm_;
float dropout_;
Expr dropMaskX_;
Expr dropMaskS_;
Expr fakeInput_;
public:
GRU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
int dimInput = opt<int>("dimInput");
int dimState = opt<int>("dimState");
std::string prefix = opt<std::string>("prefix");
layerNorm_ = opt<bool>("layer-normalization", false);
dropout_ = opt<float>("dropout", 0);
final_ = opt<bool>("final", false);
auto U = graph->param(
prefix + "_U", {dimState, 2 * dimState}, inits::glorotUniform());
auto Ux = graph->param(
prefix + "_Ux", {dimState, dimState}, inits::glorotUniform());
U_ = concatenate({U, Ux}, /*axis =*/ -1);
if(dimInput > 0) {
auto W = graph->param(
prefix + "_W", {dimInput, 2 * dimState}, inits::glorotUniform());
auto Wx = graph->param(
prefix + "_Wx", {dimInput, dimState}, inits::glorotUniform());
W_ = concatenate({W, Wx}, /*axis =*/ -1);
}
auto b = graph->param(prefix + "_b", {1, 2 * dimState}, inits::zeros());
auto bx = graph->param(prefix + "_bx", {1, dimState}, inits::zeros());
b_ = concatenate({b, bx}, /*axis =*/ -1);
// @TODO use this and adjust Amun model type saving and loading
// U_ = graph->param(prefix + "_U", {dimState, 3 * dimState},
//                   inits::glorotUniform());
// W_ = graph->param(prefix + "_W", {dimInput, 3 * dimState},
//                   inits::glorotUniform());
// b_ = graph->param(prefix + "_b", {1, 3 * dimState},
//                   inits::zeros());
if(dropout_ > 0.0f) {
if(dimInput)
dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
dropMaskS_ = graph->dropoutMask(dropout_, {1, dimState});
}
if(layerNorm_) {
if(dimInput)
gamma1_ = graph->param(
prefix + "_gamma1", {1, 3 * dimState}, inits::fromValue(1.f));
gamma2_ = graph->param(
prefix + "_gamma2", {1, 3 * dimState}, inits::fromValue(1.f));
}
}
virtual State apply(std::vector<Expr> inputs,
State state,
Expr mask = nullptr) {
return applyState(applyInput(inputs), state, mask);
}
virtual std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
Expr input;
if(inputs.size() == 0)
return {};
else if(inputs.size() > 1)
input = concatenate(inputs, /*axis =*/ -1);
else
input = inputs[0];
input = dropout(input, dropMaskX_);
auto xW = dot(input, W_);
if(layerNorm_)
xW = layerNorm(xW, gamma1_);
return {xW};
}
virtual State applyState(std::vector<Expr> xWs,
State state,
Expr mask = nullptr) override {
auto stateOrig = state.output;
auto stateDropped = dropout(stateOrig, dropMaskS_);
auto sU = dot(stateDropped, U_);
if(layerNorm_)
sU = layerNorm(sU, gamma2_);
Expr xW;
if(xWs.empty()) {
if(!fakeInput_ || fakeInput_->shape() != sU->shape())
fakeInput_ = sU->graph()->constant(sU->shape(), inits::zeros());
xW = fakeInput_;
} else {
xW = xWs.front();
}
auto output = mask ? gruOps({stateOrig, xW, sU, b_, mask}, final_)
: gruOps({stateOrig, xW, sU, b_}, final_);
return {output, state.cell}; // no cell state, hence copy
}
};
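// GRU variant reproducing Nematus semantics, including its layer
// normalization placement and deep-transition layers: a transition cell
// receives no input, so a zero-valued fake input stands in for x * W.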
class GRUNematus : public Cell {
protected:
// Concatenated Us and Ws, used unless layer normalization enabled
Expr UUx_, WWx_, bbx_;
// Parameters used if layer normalization enabled
Expr U_, W_, b_;
Expr Ux_, Wx_, bx_;
// Layer normalization parameters
Expr W_lns_, W_lnb_;
Expr Wx_lns_, Wx_lnb_;
Expr U_lns_, U_lnb_;
Expr Ux_lns_, Ux_lnb_;
// Whether it is an encoder or decoder
bool encoder_;
// Whether it is an RNN final layer or hidden layer
bool final_;
// Whether it is a transition layer
bool transition_;
// Use layer normalization
bool layerNorm_;
// Dropout probability
float dropout_;
Expr dropMaskX_;
Expr dropMaskS_;
// Fake input with zeros replaces W and Wx in transition cells
Expr fakeInput_;
public:
GRUNematus(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
int dimInput = opt<int>("dimInput");
int dimState = opt<int>("dimState");
auto prefix = opt<std::string>("prefix");
encoder_ = prefix.substr(0, 7) == "encoder";
transition_ = opt<bool>("transition", false);
layerNorm_ = opt<bool>("layer-normalization", false);
dropout_ = opt<float>("dropout", 0);
final_ = opt<bool>("final", false);
auto U = graph->param(
prefix + "_U", {dimState, 2 * dimState}, inits::glorotUniform());
auto Ux = graph->param(
prefix + "_Ux", {dimState, dimState}, inits::glorotUniform());
if(layerNorm_) {
U_ = U;
Ux_ = Ux;
} else {
UUx_ = concatenate({U, Ux}, /*axis =*/ -1);
}
if(dimInput > 0) {
auto W = graph->param(
prefix + "_W", {dimInput, 2 * dimState}, inits::glorotUniform());
auto Wx = graph->param(
prefix + "_Wx", {dimInput, dimState}, inits::glorotUniform());
if(layerNorm_) {
W_ = W;
Wx_ = Wx;
} else {
WWx_ = concatenate({W, Wx}, /*axis =*/ -1);
}
}
auto b = graph->param(prefix + "_b", {1, 2 * dimState}, inits::zeros());
auto bx = graph->param(prefix + "_bx", {1, dimState}, inits::zeros());
if(layerNorm_) {
b_ = b;
bx_ = bx;
// in specific cases we need to pass bx to the kernel
if(encoder_ && transition_) {
auto b0 = graph->constant({1, 2 * dimState}, inits::zeros());
bbx_ = concatenate({b0, bx}, /*axis =*/ -1);
} else {
bbx_ = graph->constant({1, 3 * dimState}, inits::zeros());
}
} else {
bbx_ = concatenate({b, bx}, /*axis =*/ -1);
}
if(dropout_ > 0.0f) {
if(dimInput)
dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
dropMaskS_ = graph->dropoutMask(dropout_, {1, dimState});
}
if(layerNorm_) {
if(dimInput) {
W_lns_ = graph->param(
prefix + "_W_lns", {1, 2 * dimState}, inits::fromValue(1.f));
W_lnb_
= graph->param(prefix + "_W_lnb", {1, 2 * dimState}, inits::zeros());
Wx_lns_ = graph->param(
prefix + "_Wx_lns", {1, 1 * dimState}, inits::fromValue(1.f));
Wx_lnb_
= graph->param(prefix + "_Wx_lnb", {1, 1 * dimState}, inits::zeros());
}
U_lns_ = graph->param(
prefix + "_U_lns", {1, 2 * dimState}, inits::fromValue(1.f));
U_lnb_ = graph->param(prefix + "_U_lnb", {1, 2 * dimState}, inits::zeros());
Ux_lns_ = graph->param(
prefix + "_Ux_lns", {1, 1 * dimState}, inits::fromValue(1.f));
Ux_lnb_
= graph->param(prefix + "_Ux_lnb", {1, 1 * dimState}, inits::zeros());
}
}
virtual State apply(std::vector<Expr> inputs,
State state,
Expr mask = nullptr) {
return applyState(applyInput(inputs), state, mask);
}
virtual std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
Expr input;
if(inputs.size() == 0)
return {};
else if(inputs.size() > 1)
input = concatenate(inputs, /*axis =*/ -1);
else
input = inputs[0];
input = dropout(input, dropMaskX_);
Expr xW;
if(layerNorm_) {
Expr W; // RUH_1_ in Amun
Expr Wx; // RUH_2_ in Amun
if(final_) {
W = dot(input, W_);
Wx = dot(input, Wx_);
} else {
W = affine(input, W_, b_);
Wx = affine(input, Wx_, bx_);
}
W = layerNorm(W, W_lns_, W_lnb_, NEMATUS_LN_EPS);
Wx = layerNorm(Wx, Wx_lns_, Wx_lnb_, NEMATUS_LN_EPS);
xW = concatenate({W, Wx}, /*axis =*/ -1);
} else {
xW = dot(input, WWx_);
}
return {xW};
}
virtual State applyState(std::vector<Expr> xWs,
State state,
Expr mask = nullptr) override {
// a transition cell receives no input, and only transition cells do
assert(transition_ == xWs.empty());
auto stateOrig = state.output;
auto stateDropped = dropout(stateOrig, dropMaskS_);
Expr sU;
if(layerNorm_) {
Expr U; // Temp_1_ in Amun
Expr Ux; // Temp_2_ in Amun
if(encoder_) {
U = layerNorm(dot(stateDropped, U_), U_lns_, U_lnb_, NEMATUS_LN_EPS);
Ux = layerNorm(
dot(stateDropped, Ux_), Ux_lns_, Ux_lnb_, NEMATUS_LN_EPS);
if(transition_) {
U = U + b_;
}
} else {
if(final_ || transition_) {
U = affine(stateDropped, U_, b_);
Ux = affine(stateDropped, Ux_, bx_);
} else {
U = dot(stateDropped, U_);
Ux = dot(stateDropped, Ux_);
}
U = layerNorm(U, U_lns_, U_lnb_, NEMATUS_LN_EPS);
Ux = layerNorm(Ux, Ux_lns_, Ux_lnb_, NEMATUS_LN_EPS);
}
sU = concatenate({U, Ux}, /*axis =*/ -1);
} else {
sU = dot(stateDropped, UUx_);
}
Expr xW;
if(transition_) {
if(!fakeInput_ || fakeInput_->shape() != sU->shape())
fakeInput_ = sU->graph()->constant(sU->shape(), inits::zeros());
xW = fakeInput_;
} else {
xW = xWs.front();
}
auto output = mask ? gruOps({stateOrig, xW, sU, bbx_, mask}, final_)
: gruOps({stateOrig, xW, sU, bbx_}, final_);
return {output, state.cell}; // no cell state, hence copy
}
};
/******************************************************************************/
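// Fused element-wise LSTM kernels, implemented elsewhere in Marian:
// lstmOpsC computes the next cell state from {cell, xW, sU, b[, mask]},
// lstmOpsO computes the next hidden state from the updated cell state.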
Expr lstmOpsC(const std::vector<Expr>& nodes);
Expr lstmOpsO(const std::vector<Expr>& nodes);
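// LSTM with all four gates packed into single {., 4 * dimState} parameter
// blocks, so one matrix product per step feeds the fused kernels above.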
class FastLSTM : public Cell {
protected:
std::string prefix_;
Expr U_, W_, b_;
Expr gamma1_;
Expr gamma2_;
bool layerNorm_;
float dropout_;
Expr dropMaskX_;
Expr dropMaskS_;
Expr fakeInput_;
public:
FastLSTM(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
int dimInput = opt<int>("dimInput");
int dimState = opt<int>("dimState");
std::string prefix = opt<std::string>("prefix");
layerNorm_ = opt<bool>("layer-normalization", false);
dropout_ = opt<float>("dropout", 0);
U_ = graph->param(
prefix + "_U", {dimState, 4 * dimState}, inits::glorotUniform());
if(dimInput)
W_ = graph->param(
prefix + "_W", {dimInput, 4 * dimState}, inits::glorotUniform());
b_ = graph->param(prefix + "_b", {1, 4 * dimState}, inits::zeros());
if(dropout_ > 0.0f) {
if(dimInput)
dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
dropMaskS_ = graph->dropoutMask(dropout_, {1, dimState});
}
if(layerNorm_) {
if(dimInput)
gamma1_ = graph->param(
prefix + "_gamma1", {1, 4 * dimState}, inits::fromValue(1.f));
gamma2_ = graph->param(
prefix + "_gamma2", {1, 4 * dimState}, inits::fromValue(1.f));
}
}
virtual State apply(std::vector<Expr> inputs,
State state,
Expr mask = nullptr) {
return applyState(applyInput(inputs), state, mask);
}
virtual std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
Expr input;
if(inputs.size() == 0)
return {};
else if(inputs.size() > 1) {
input = concatenate(inputs, /*axis =*/ -1);
} else
input = inputs.front();
input = dropout(input, dropMaskX_);
auto xW = dot(input, W_);
if(layerNorm_)
xW = layerNorm(xW, gamma1_);
return {xW};
}
virtual State applyState(std::vector<Expr> xWs,
State state,
Expr mask = nullptr) override {
auto recState = state.output;
auto cellState = state.cell;
auto recStateDropped = dropout(recState, dropMaskS_);
auto sU = dot(recStateDropped, U_);
if(layerNorm_)
sU = layerNorm(sU, gamma2_);
Expr xW;
if(xWs.empty()) {
if(!fakeInput_ || fakeInput_->shape() != sU->shape())
fakeInput_ = sU->graph()->constant(sU->shape(), inits::zeros());
xW = fakeInput_;
} else {
xW = xWs.front();
}
// dc/dp where p = W_i, U_i, ..., but without index o
auto nextCellState = mask ? lstmOpsC({cellState, xW, sU, b_, mask})
: lstmOpsC({cellState, xW, sU, b_});
// dh/dp dh/dc where p = W_o, U_o, b_o
auto nextRecState = lstmOpsO({nextCellState, xW, sU, b_});
return {nextRecState, nextCellState};
}
};
using LSTM = FastLSTM;
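// Usage sketch (illustrative only; the option keys follow the constructors
// above, but in Marian cells are normally built via the factories in
// rnn/constructors.h rather than instantiated directly):
//
//   auto cellOpts = New<Options>("prefix", "decoder_cell1",  // hypothetical values
//                                "dimInput", 512, "dimState", 1024,
//                                "dropout", 0.1f, "layer-normalization", true);
//   auto cell = New<LSTM>(graph, cellOpts);
//   State next = cell->apply({embedding}, prevState, mask);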
/******************************************************************************/
// Experimental cells, use with care
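// Multiplicative wrapper (cf. mLSTM, Krause et al. 2016): the recurrent
// state is replaced by the element-wise product of an input projection and
// a state projection before the wrapped cell's own update runs.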
template <class CellType>
class Multiplicative : public CellType {
protected:
Expr Um_, Wm_, bm_, bwm_;
Expr gamma1m_, gamma2m_;
public:
Multiplicative(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: CellType(graph, options) {
int dimInput = options->get<int>("dimInput");
int dimState = options->get<int>("dimState");
std::string prefix = options->get<std::string>("prefix");
Um_ = graph->param(
prefix + "_Um", {dimState, dimState}, inits::glorotUniform());
Wm_ = graph->param(
prefix + "_Wm", {dimInput, dimState}, inits::glorotUniform());
bm_ = graph->param(prefix + "_bm", {1, dimState}, inits::zeros());
bwm_ = graph->param(prefix + "_bwm", {1, dimState}, inits::zeros());
if(CellType::layerNorm_) {
gamma1m_ = graph->param(
prefix + "_gamma1m", {1, dimState}, inits::fromValue(1.f));
gamma2m_ = graph->param(
prefix + "_gamma2m", {1, dimState}, inits::fromValue(1.f));
}
}
virtual std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
ABORT_IF(inputs.empty(), "Multiplicative cell expects input");
Expr input;
if(inputs.size() > 1) {
input = concatenate(inputs, /*axis =*/ -1);
} else
input = inputs.front();
auto xWs = CellType::applyInput({input});
auto xWm = affine(input, Wm_, bwm_);
if(CellType::layerNorm_)
xWm = layerNorm(xWm, gamma1m_);
xWs.push_back(xWm);
return xWs;
}
virtual State applyState(std::vector<Expr> xWs,
State state,
Expr mask = nullptr) override {
auto xWm = xWs.back();
xWs.pop_back();
auto sUm = affine(state.output, Um_, bm_);
if(CellType::layerNorm_)
sUm = layerNorm(sUm, gamma2m_);
auto mstate = xWm * sUm;
return CellType::applyState(xWs, State({mstate, state.cell}), mask);
}
};
using MLSTM = Multiplicative<LSTM>;
using MGRU = Multiplicative<GRU>;
/******************************************************************************/
// SlowLSTM and TestLSTM exist to compare the efficient gradient kernels
// against a naive but correct LSTM implementation.
class SlowLSTM : public Cell {
private:
Expr Uf_, Wf_, bf_;
Expr Ui_, Wi_, bi_;
Expr Uo_, Wo_, bo_;
Expr Uc_, Wc_, bc_;
public:
SlowLSTM(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
int dimInput = options_->get<int>("dimInput");
int dimState = options_->get<int>("dimState");
std::string prefix = options->get<std::string>("prefix");
Uf_ = graph->param(
prefix + "_Uf", {dimState, dimState}, inits::glorotUniform());
Wf_ = graph->param(
prefix + "_Wf", {dimInput, dimState}, inits::glorotUniform());
bf_ = graph->param(prefix + "_bf", {1, dimState}, inits::zeros());
Ui_ = graph->param(
prefix + "_Ui", {dimState, dimState}, inits::glorotUniform());
Wi_ = graph->param(
prefix + "_Wi", {dimInput, dimState}, inits::glorotUniform());
bi_ = graph->param(prefix + "_bi", {1, dimState}, inits::zeros());
Uc_ = graph->param(
prefix + "_Uc", {dimState, dimState}, inits::glorotUniform());
Wc_ = graph->param(
prefix + "_Wc", {dimInput, dimState}, inits::glorotUniform());
bc_ = graph->param(prefix + "_bc", {1, dimState}, inits::zeros());
Uo_ = graph->param(
prefix + "_Uo", {dimState, dimState}, inits::glorotUniform());
Wo_ = graph->param(
prefix + "_Wo", {dimInput, dimState}, inits::glorotUniform());
bo_ = graph->param(prefix + "_bo", {1, dimState}, inits::zeros());
}
State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
return applyState(applyInput(inputs), state, mask);
}
std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
ABORT_IF(inputs.empty(), "Slow LSTM expects input");
Expr input;
if(inputs.size() > 1)
input = concatenate(inputs, /*axis =*/ -1);
else
input = inputs.front();
auto xWf = dot(input, Wf_);
auto xWi = dot(input, Wi_);
auto xWo = dot(input, Wo_);
auto xWc = dot(input, Wc_);
return {xWf, xWi, xWo, xWc};
}
State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) override {
auto recState = state.output;
auto cellState = state.cell;
auto sUf = affine(recState, Uf_, bf_);
auto sUi = affine(recState, Ui_, bi_);
auto sUo = affine(recState, Uo_, bo_);
auto sUc = affine(recState, Uc_, bc_);
auto f = sigmoid(xWs[0] + sUf);
auto i = sigmoid(xWs[1] + sUi);
auto o = sigmoid(xWs[2] + sUo);
auto c = tanh(xWs[3] + sUc);
auto nextCellState = f * cellState + i * c;
auto maskedCellState = mask ? mask * nextCellState : nextCellState;
auto nextState = o * tanh(maskedCellState);
auto maskedState = mask ? mask * nextState : nextState;
return {maskedState, maskedCellState};
}
};
/******************************************************************************/
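// Same gate layout as SlowLSTM, but with parameters concatenated into
// single tensors and the update routed through the fused lstmOpsC/lstmOpsO
// kernels, so the two can be checked against each other gate for gate.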
class TestLSTM : public Cell {
private:
Expr U_, W_, b_;
public:
TestLSTM(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
int dimInput = options_->get<int>("dimInput");
int dimState = options_->get<int>("dimState");
std::string prefix = options->get<std::string>("prefix");
auto Uf = graph->param(
prefix + "_Uf", {dimState, dimState}, inits::glorotUniform());
auto Wf = graph->param(
prefix + "_Wf", {dimInput, dimState}, inits::glorotUniform());
auto bf = graph->param(prefix + "_bf", {1, dimState}, inits::zeros());
auto Ui = graph->param(
prefix + "_Ui", {dimState, dimState}, inits::glorotUniform());
auto Wi = graph->param(
prefix + "_Wi", {dimInput, dimState}, inits::glorotUniform());
auto bi = graph->param(prefix + "_bi", {1, dimState}, inits::zeros());
auto Uc = graph->param(
prefix + "_Uc", {dimState, dimState}, inits::glorotUniform());
auto Wc = graph->param(
prefix + "_Wc", {dimInput, dimState}, inits::glorotUniform());
auto bc = graph->param(prefix + "_bc", {1, dimState}, inits::zeros());
auto Uo = graph->param(
prefix + "_Uo", {dimState, dimState}, inits::glorotUniform());
auto Wo = graph->param(
prefix + "_Wo", {dimInput, dimState}, inits::glorotUniform());
auto bo = graph->param(prefix + "_bo", {1, dimState}, inits::zeros());
U_ = concatenate({Uf, Ui, Uc, Uo}, /*axis =*/ -1);
W_ = concatenate({Wf, Wi, Wc, Wo}, /*axis =*/ -1);
b_ = concatenate({bf, bi, bc, bo}, /*axis =*/ -1);
}
State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
return applyState(applyInput(inputs), state, mask);
}
std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
ABORT_IF(inputs.empty(), "Test LSTM expects input");
Expr input;
if(inputs.size() > 1)
input = concatenate(inputs, /*axis =*/ -1);
else
input = inputs.front();
auto xW = dot(input, W_);
return {xW};
}
State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) override {
auto recState = state.output;
auto cellState = state.cell;
auto sU = dot(recState, U_);
auto xW = xWs.front();
// dc/dp where p = W_i, U_i, ..., but without index o
auto nextCellState = mask ? lstmOpsC({cellState, xW, sU, b_, mask})
: lstmOpsC({cellState, xW, sU, b_});
// dh/dp dh/dc where p = W_o, U_o, b_o
auto nextRecState = mask ? lstmOpsO({nextCellState, xW, sU, b_, mask})
: lstmOpsO({nextCellState, xW, sU, b_});
return {nextRecState, nextCellState};
}
};
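// Simple Recurrent Unit (cf. Lei et al. 2017): all matrix products depend
// only on the input and can be batched across time steps; the recurrence
// itself is element-wise highway gating of the cell state, followed by a
// tanh output highway controlled by the reset gate r.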
class SRU : public Cell {
private:
Expr W_;
Expr Wr_, br_;
Expr Wf_, bf_;
float dropout_;
Expr dropMaskX_;
bool layerNorm_;
Expr gamma_, gammaf_, gammar_;
public:
SRU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
int dimInput = opt<int>("dimInput");
int dimState = opt<int>("dimState");
std::string prefix = opt<std::string>("prefix");
ABORT_IF(dimInput != dimState,
"For SRU, state and input dimensions have to be equal");
dropout_ = opt<float>("dropout", 0);
layerNorm_ = opt<bool>("layer-normalization", false);
W_ = graph->param(
prefix + "_W", {dimInput, dimInput}, inits::glorotUniform());
Wf_ = graph->param(
prefix + "_Wf", {dimInput, dimInput}, inits::glorotUniform());
bf_ = graph->param(prefix + "_bf", {1, dimInput}, inits::zeros());
Wr_ = graph->param(
prefix + "_Wr", {dimInput, dimInput}, inits::glorotUniform());
br_ = graph->param(prefix + "_br", {1, dimInput}, inits::zeros());
if(dropout_ > 0.0f) {
dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
}
if(layerNorm_) {
if(dimInput)
gamma_ = graph->param(prefix + "_gamma", {1, dimState}, inits::ones());
gammar_ = graph->param(prefix + "_gammar", {1, dimState}, inits::ones());
gammaf_ = graph->param(prefix + "_gammaf", {1, dimState}, inits::ones());
}
}
State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
return applyState(applyInput(inputs), state, mask);
}
std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
ABORT_IF(inputs.empty(), "SRU expects input");
Expr input;
if(inputs.size() > 1)
input = concatenate(inputs, /*axis =*/ -1);
else
input = inputs.front();
auto inputDropped = dropout(input, dropMaskX_);
Expr x, f, r;
if(layerNorm_) {
x = layerNorm(dot(inputDropped, W_), gamma_);
f = layerNorm(dot(inputDropped, Wf_), gammaf_, bf_);
r = layerNorm(dot(inputDropped, Wr_), gammar_, br_);
} else {
x = dot(inputDropped, W_);
f = affine(inputDropped, Wf_, bf_);
r = affine(inputDropped, Wr_, br_);
}
return {x, f, r, input};
}
State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) override {
auto recState = state.output;
auto cellState = state.cell;
auto x = xWs[0];
auto f = xWs[1];
auto r = xWs[2];
auto input = xWs[3];
auto nextCellState = highway(cellState, x, f); // rename to "gate"?
auto nextState = highway(tanh(nextCellState), input, r);
auto maskedCellState = mask ? mask * nextCellState : nextCellState;
auto maskedState = mask ? mask * nextState : nextState;
return {maskedState, maskedCellState};
}
};
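// Simplified SRU: drops SRU's reset gate and output highway and applies
// relu to the gated cell state (cf. the SSRU used for fast decoding in
// Kim et al. 2019).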
class SSRU : public Cell {
private:
Expr W_;
Expr Wf_, bf_;
float dropout_;
Expr dropMaskX_;
bool layerNorm_;
Expr gamma_, gammaf_;
public:
SSRU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
int dimInput = options_->get<int>("dimInput");
int dimState = options_->get<int>("dimState");
std::string prefix = options->get<std::string>("prefix");
ABORT_IF(dimInput != dimState,
"For SSRU, state and input dimensions have to be equal");
dropout_ = opt<float>("dropout", 0);
layerNorm_ = opt<bool>("layer-normalization", false);
W_ = graph->param(
prefix + "_W", {dimInput, dimInput}, inits::glorotUniform());
Wf_ = graph->param(
prefix + "_Wf", {dimInput, dimInput}, inits::glorotUniform());
bf_ = graph->param(prefix + "_bf", {1, dimInput}, inits::zeros());
if(dropout_ > 0.0f) {
dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
}
if(layerNorm_) {
if(dimInput)
gamma_ = graph->param(prefix + "_gamma", {1, dimState}, inits::ones());
gammaf_ = graph->param(prefix + "_gammaf", {1, dimState}, inits::ones());
}
}
State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
return applyState(applyInput(inputs), state, mask);
}
std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
ABORT_IF(inputs.empty(), "SSRU expects input");
Expr input;
if(inputs.size() > 1)
input = concatenate(inputs, /*axis =*/ -1);
else
input = inputs.front();
auto inputDropped = dropout(input, dropMaskX_);
Expr x, f;
if(layerNorm_) {
x = layerNorm(dot(inputDropped, W_), gamma_);
f = layerNorm(dot(inputDropped, Wf_), gammaf_, bf_);
} else {
x = dot(inputDropped, W_);
f = affine(inputDropped, Wf_, bf_);
}
return {x, f};
}
State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) override {
auto recState = state.output;
auto cellState = state.cell;
auto x = xWs[0];
auto f = xWs[1];
auto nextCellState = highway(cellState, x, f); // rename to "gate"?
auto nextState = relu(nextCellState);
auto maskedCellState = mask ? mask * nextCellState : nextCellState;
auto maskedState = mask ? mask * nextState : nextState;
return {maskedState, maskedCellState};
}
};
// class LSSRU : public Cell {
// private:
// Expr W_;
// Expr Wf_, bf_;
// float dropout_;
// Expr dropMaskX_;
// public:
// LSSRU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
// int dimInput = options_->get<int>("dimInput");
// int dimState = options_->get<int>("dimState");
// std::string prefix = options->get<std::string>("prefix");
// ABORT_IF(dimInput != dimState, "For SRU state and input dims have to be equal");
// dropout_ = opt<float>("dropout", 0);
// W_ = graph->param(prefix + "_W",
// {dimInput, dimInput},
// inits::glorotUniform());
// Wf_ = graph->param(prefix + "_Wf",
// {dimInput, dimInput},
// inits::glorotUniform());
// bf_ = graph->param(
// prefix + "_bf", {1, dimInput}, inits::zeros());
// if(dropout_ > 0.0f) {
// dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
// }
// }
// State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
// return applyState(applyInput(inputs), state, mask);
// }
// std::vector<Expr> applyInput(std::vector<Expr> inputs) {
// ABORT_IF(inputs.empty(), "Slow SRU expects input");
// Expr input;
// if(inputs.size() > 1)
// input = concatenate(inputs, /*axis =*/ -1);
// else
// input = inputs.front();
// auto inputDropped = dropout(input, dropMaskX_);
// auto x = dot(inputDropped, W_);
// auto f = affine(inputDropped, Wf_, bf_);
// return {x, f};
// }
// State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) {
// auto recState = state.output;
// auto cellState = state.cell;
// auto x = xWs[0];
// auto f = xWs[1];
// auto nextCellState = highwayLinear(cellState, x, f, 2.f); // rename to "gate"?
// auto nextState = relu(nextCellState);
// //auto nextState = nextCellState;
// auto maskedCellState = mask ? mask * nextCellState : nextCellState;
// auto maskedState = mask ? mask * nextState : nextState;
// return {maskedState, maskedCellState};
// }
// };
} // namespace rnn
} // namespace marian