Program Listing for File rnn.h
↰ Return to documentation for file (src/rnn/rnn.h)
#pragma once
#include "layers/generic.h"
#include "marian.h"
#include "rnn/cells.h"
#include "rnn/types.h"
#include <algorithm>
#include <chrono>
#include <cstdio>
#include <iomanip>
#include <string>
namespace marian {
namespace rnn {
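
// Traversal direction over the time axis for a recurrent layer. Note that
// this file only checks for dir::backward; the alternating variants are
// interpreted by callers.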
enum struct dir : int {
  forward,
  backward,
  alternating_forward,
  alternating_backward
};
}
} // namespace marian
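
// Registers the enum with the YAML layer so it can be (de)serialized as int.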
YAML_REGISTER_TYPE(marian::rnn::dir, int)
namespace marian {
namespace rnn {
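
// Abstract interface shared by single- and multi-layer RNNs: transduction
// over an input sequence (optionally masked, optionally from a given start
// state) plus access to the underlying cells and their final states.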
class BaseRNN {
protected:
  Ptr<ExpressionGraph> graph_;
  Ptr<Options> options_;

public:
  BaseRNN(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : graph_(graph), options_(options) {}

  virtual ~BaseRNN() {}

  virtual Expr transduce(Expr, Expr = nullptr) = 0;
  virtual Expr transduce(Expr, State, Expr = nullptr) = 0;
  virtual Expr transduce(Expr, States, Expr = nullptr) = 0;
  virtual States lastCellStates() = 0;
  virtual void push_back(Ptr<Cell>) = 0;
  virtual Ptr<Cell> at(int i) = 0;
  virtual Ptr<Options> getOptions() { return options_; }
};
class RNN;
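
// One recurrent layer: owns a single cell and unrolls it over the time axis.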
class SingleLayerRNN : public BaseRNN {
private:
  Ptr<Cell> cell_;
  dir direction_;
  States last_;
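
  // Unrolls the cell over the time axis (-3). The input projections are
  // precomputed once for all steps via applyInput(); the loop then applies
  // the recurrence step by step, walking the time axis in reverse when
  // direction_ == dir::backward and re-reversing the outputs afterwards.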
  States apply(const Expr input,
               const States initialState,
               const Expr mask = nullptr) {
    last_.clear();

    State state = initialState.front();

    cell_->clear();
    auto xWs = cell_->applyInput({input});

    auto timeSteps = input->shape()[-3];

    States outputs;
    for(int i = 0; i < timeSteps; ++i) {
      int j = i;
      if(direction_ == dir::backward)
        j = timeSteps - i - 1;

      std::vector<Expr> steps(xWs.size());
      std::transform(xWs.begin(), xWs.end(), steps.begin(), [j](Expr e) {
        return slice(e, -3, j);
      });

      if(mask)
        state = cell_->applyState(steps, state, slice(mask, -3, j));
      else
        state = cell_->applyState(steps, state);

      outputs.push_back(state);
    }

    if(direction_ == dir::backward)
      outputs.reverse();

    last_.push_back(outputs.back());

    return outputs;
  }
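
  // Convenience overload: starts from a zero output (and zero cell state)
  // of shape {1, batch, dimState}, with dimState read from the cell options.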
  States apply(const Expr input, const Expr mask = nullptr) {
    auto graph = input->graph();
    int dimBatch = input->shape()[-2];
    int dimState = cell_->getOptions()->get<int>("dimState");

    auto output = graph->zeros({1, dimBatch, dimState});
    Expr cell = output;
    State startState{output, cell};

    return apply(input, States({startState}), mask);
  }
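
  // Private constructor: layers are created by RNN::push_back, which is
  // granted access via the friend declaration below.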
  SingleLayerRNN(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : BaseRNN(graph, options),
        direction_((dir)options->get<int>("direction", (int)dir::forward)) {}

public:
  friend RNN;

  virtual ~SingleLayerRNN() {}

  // @TODO: benchmark whether this concatenation is a good idea
  virtual Expr transduce(Expr input, Expr mask = nullptr) override {
    return apply(input, mask).outputs();
  }

  virtual Expr transduce(Expr input, States states, Expr mask = nullptr) override {
    return apply(input, states, mask).outputs();
  }

  virtual Expr transduce(Expr input, State state, Expr mask = nullptr) override {
    return apply(input, States({state}), mask).outputs();
  }

  States lastCellStates() override { return last_; }

  void push_back(Ptr<Cell> cell) override { cell_ = cell; }

  virtual Ptr<Cell> at(int i) override {
    ABORT_IF(i > 0, "SingleRNN only has one cell");
    return cell_;
  }
};
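
// Multi-layer RNN: a stack of SingleLayerRNN objects, one per pushed cell.
// With "skip" enabled, residual connections add each layer's input to its
// output ("skipFirst" extends this to the first layer as well).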
class RNN : public BaseRNN, public std::enable_shared_from_this<RNN> {
private:
  bool skip_;
  bool skipFirst_;
  std::vector<Ptr<SingleLayerRNN>> rnns_;

public:
  RNN(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : BaseRNN(graph, options),
        skip_(options->get("skip", false)),
        skipFirst_(options->get("skipFirst", false)) {}
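
  // Adds one layer: each cell is wrapped in its own SingleLayerRNN,
  // configured from the cell's own options.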
  void push_back(Ptr<Cell> cell) override {
    auto rnn
        = Ptr<SingleLayerRNN>(new SingleLayerRNN(graph_, cell->getOptions()));
    rnn->push_back(cell);
    rnns_.push_back(rnn);
  }
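
  // Feeds the input through the layer stack. Cells may provide additional
  // lazy inputs, which are concatenated to the layer input along the last
  // axis; with skip connections enabled, each layer's input is added to its
  // output before being passed on to the next layer.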
  Expr transduce(Expr input, Expr mask = nullptr) override {
    ABORT_IF(rnns_.empty(), "0 layers in RNN");

    Expr output;
    Expr layerInput = input;
    for(size_t i = 0; i < rnns_.size(); ++i) {
      auto lazyInput = layerInput;

      auto cell = rnns_[i]->at(0);
      auto lazyInputs = cell->getLazyInputs(shared_from_this());
      if(!lazyInputs.empty()) {
        lazyInputs.push_back(layerInput);
        lazyInput = concatenate(lazyInputs, /*axis =*/ -1);
      }

      auto layerOutput = rnns_[i]->transduce(lazyInput, mask);
      if(skip_ && (skipFirst_ || i > 0))
        output = layerOutput + layerInput;
      else
        output = layerOutput;

      layerInput = output;
    }
    return output;
  }
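
  // Variant with per-layer start states: states[i] initializes layer i.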
  Expr transduce(Expr input, States states, Expr mask = nullptr) override {
    ABORT_IF(rnns_.empty(), "0 layers in RNN");

    Expr output;
    Expr layerInput = input;
    for(size_t i = 0; i < rnns_.size(); ++i) {
      Expr lazyInput;

      auto cell = rnns_[i]->at(0);
      auto lazyInputs = cell->getLazyInputs(shared_from_this());
      if(!lazyInputs.empty()) {
        lazyInputs.push_back(layerInput);
        lazyInput = concatenate(lazyInputs, /*axis =*/ -1);
      } else {
        lazyInput = layerInput;
      }

      auto layerOutput
          = rnns_[i]->transduce(lazyInput, States({states[i]}), mask);
      if(skip_ && (skipFirst_ || i > 0))
        output = layerOutput + layerInput;
      else
        output = layerOutput;

      layerInput = output;
    }
    return output;
  }
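
  // Variant with a single start state shared by every layer.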
  Expr transduce(Expr input, State state, Expr mask = nullptr) override {
    ABORT_IF(rnns_.empty(), "0 layers in RNN");

    Expr output;
    Expr layerInput = input;
    for(size_t i = 0; i < rnns_.size(); ++i) {
      auto lazyInput = layerInput;

      auto cell = rnns_[i]->at(0);
      auto lazyInputs = cell->getLazyInputs(shared_from_this());
      if(!lazyInputs.empty()) {
        lazyInputs.push_back(layerInput);
        lazyInput = concatenate(lazyInputs, /*axis =*/ -1);
      }

      auto layerOutput = rnns_[i]->transduce(lazyInput, States({state}), mask);
      if(skip_ && (skipFirst_ || i > 0))
        output = layerOutput + layerInput;
      else
        output = layerOutput;

      layerInput = output;
    }
    return output;
  }
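
  // Collects the final cell state of each layer, bottom to top.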
  States lastCellStates() override {
    States temp;
    for(auto rnn : rnns_)
      temp.push_back(rnn->lastCellStates().back());
    return temp;
  }

  virtual Ptr<Cell> at(int i) override { return rnns_[i]->at(0); }
};
} // namespace rnn
} // namespace marian
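
For orientation, here is a minimal usage sketch. It is hypothetical: the cell type, its constructor signature, and any option keys beyond "dimState" and "direction" (the two this header actually reads) are assumptions for illustration; real cell classes live in rnn/cells.h and are normally assembled through the factories in rnn/constructors.h.

// Hypothetical sketch, not part of rnn.h.
// Given: Ptr<ExpressionGraph> graph; Expr input, mask.
using namespace marian;

auto options = New<Options>();
options->set("dimState", 512);                      // hidden size, read in SingleLayerRNN::apply
options->set("direction", (int)rnn::dir::forward);  // time-axis traversal

auto stack = New<rnn::RNN>(graph, options);         // must be heap-allocated (see below)
stack->push_back(New<rnn::GRU>(graph, options));    // assumed cell type and constructor

// input has shape {time, batch, dim}; the mask may be nullptr
Expr context = stack->transduce(input, mask);
rnn::States last = stack->lastCellStates();

Note that RNN derives from std::enable_shared_from_this and calls shared_from_this() inside transduce, so an RNN must be owned by a Ptr (as New provides); a stack-allocated instance would not work.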