Program Listing for File types.h¶
↰ Return to documentation for file (src/rnn/types.h
)
#pragma once
#include "marian.h"
#include <iostream>
#include <vector>
namespace marian {
namespace rnn {
struct State {
Expr output;
Expr cell;
State select(const std::vector<IndexType>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
int beamSize, bool isBatchMajor) const {
return{ select(output, selIdx, beamSize, isBatchMajor),
select(cell, selIdx, beamSize, isBatchMajor) };
}
// this function is also called by Logits
static Expr select(Expr sel, // [beamSize, dimTime, dimBatch, dimDepth] or [beamSize, dimBatch, dimTime, dimDepth] (dimTime = 1 for RNN)
const std::vector<IndexType>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
int beamSize, bool isBatchMajor)
{
if (!sel)
return sel; // keep nullptr untouched
sel = atleast_4d(sel);
int dimBatch = (int)selIdx.size() / beamSize;
int dimDepth = sel->shape()[-1];
int dimTime = isBatchMajor ? sel->shape()[-2] : sel->shape()[-3];
ABORT_IF(dimTime != 1 && !isBatchMajor, "unexpected time extent for RNN state"); // (the reshape()/rows() trick won't work in this case)
int numCols = isBatchMajor ? dimDepth * dimTime : dimDepth;
// @TODO: Can this complex operation be more easily written using index_select()?
sel = reshape(sel, { sel->shape().elements() / numCols, numCols }); // [beamSize * dimBatch, dimDepth] or [beamSize * dimBatch, dimTime * dimDepth]
sel = rows(sel, selIdx);
sel = reshape(sel, { beamSize, isBatchMajor ? dimBatch : dimTime, isBatchMajor ? dimTime : dimBatch, dimDepth });
return sel;
}
};
class States {
private:
std::vector<State> states_;
public:
States() {}
States(const std::vector<State>& states) : states_(states) {}
States(size_t num, State state) : states_(num, state) {}
std::vector<State>::iterator begin() { return states_.begin(); }
std::vector<State>::iterator end() { return states_.end(); }
std::vector<State>::const_iterator begin() const { return states_.begin(); }
std::vector<State>::const_iterator end() const { return states_.end(); }
Expr outputs() {
std::vector<Expr> outputs;
for(auto s : states_)
outputs.push_back(atleast_3d(s.output));
if(outputs.size() > 1)
return concatenate(outputs, /*axis =*/ -3);
else
return outputs[0];
}
State& operator[](size_t i) { return states_[i]; };
const State& operator[](size_t i) const { return states_[i]; };
State& back() { return states_.back(); }
const State& back() const { return states_.back(); }
State& front() { return states_.front(); }
const State& front() const { return states_.front(); }
size_t size() const { return states_.size(); };
void push_back(const State& state) { states_.push_back(state); }
// create updated set of states that reflect reordering and dropping of hypotheses
States select(const std::vector<IndexType>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
int beamSize, bool isBatchMajor) const {
States selected;
for(auto& state : states_)
selected.push_back(state.select(selIdx, beamSize, isBatchMajor));
return selected;
}
void reverse() { std::reverse(states_.begin(), states_.end()); }
void clear() { states_.clear(); }
};
class Cell;
class CellInput;
class Stackable : public std::enable_shared_from_this<Stackable> {
protected:
Ptr<Options> options_;
public:
Stackable(Ptr<Options> options) : options_(options) {}
// required for dynamic_pointer_cast to detect polymorphism
virtual ~Stackable() {}
template <typename Cast>
inline Ptr<Cast> as() {
return std::dynamic_pointer_cast<Cast>(shared_from_this());
}
template <typename Cast>
inline bool is() {
return as<Cast>() != nullptr;
}
Ptr<Options> getOptions() { return options_; }
template <typename T>
T opt(const std::string& key) {
return options_->get<T>(key);
}
template <typename T>
T opt(const std::string& key, T defaultValue) {
return options_->get<T>(key, defaultValue);
}
virtual void clear() = 0;
};
class CellInput : public Stackable {
public:
CellInput(Ptr<Options> options) : Stackable(options) {}
virtual Expr apply(State) = 0;
virtual int dimOutput() = 0;
};
class RNN;
class Cell : public Stackable {
protected:
std::vector<std::function<Expr(Ptr<rnn::RNN>)>> lazyInputs_;
public:
Cell(Ptr<Options> options) : Stackable(options) {}
State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
return applyState(applyInput(inputs), state, mask);
}
virtual std::vector<Expr> getLazyInputs(Ptr<rnn::RNN> parent) {
std::vector<Expr> inputs;
for(auto lazy : lazyInputs_)
inputs.push_back(lazy(parent));
return inputs;
}
virtual void setLazyInputs(
std::vector<std::function<Expr(Ptr<rnn::RNN>)>> lazy) {
lazyInputs_ = lazy;
}
virtual std::vector<Expr> applyInput(std::vector<Expr> inputs) = 0;
virtual State applyState(std::vector<Expr>, State, Expr = nullptr) = 0;
virtual void clear() override {}
};
class MultiCellInput : public CellInput {
protected:
std::vector<Ptr<CellInput>> inputs_;
public:
MultiCellInput(const std::vector<Ptr<CellInput>>& inputs,
Ptr<Options> options)
: CellInput(options), inputs_(inputs) {}
void push_back(Ptr<CellInput> input) { inputs_.push_back(input); }
virtual Expr apply(State state) override {
std::vector<Expr> outputs;
for(auto input : inputs_)
outputs.push_back(input->apply(state));
if(outputs.size() > 1)
return concatenate(outputs, /*axis =*/ -1);
else
return outputs[0];
}
virtual int dimOutput() override {
int sum = 0;
for(auto input : inputs_)
sum += input->dimOutput();
return sum;
}
virtual void clear() override {
for(auto i : inputs_)
i->clear();
}
};
class StackedCell : public Cell {
protected:
std::vector<Ptr<Stackable>> stackables_;
std::vector<Expr> lastInputs_;
public:
StackedCell(Ptr<ExpressionGraph>, Ptr<Options> options) : Cell(options) {}
StackedCell(Ptr<ExpressionGraph>,
Ptr<Options> options,
const std::vector<Ptr<Stackable>>& stackables)
: Cell(options), stackables_(stackables) {}
void push_back(Ptr<Stackable> stackable) { stackables_.push_back(stackable); }
virtual std::vector<Expr> applyInput(std::vector<Expr> inputs) override {
// lastInputs_ = inputs;
return stackables_[0]->as<Cell>()->applyInput(inputs);
}
virtual State applyState(std::vector<Expr> mappedInputs,
State state,
Expr mask = nullptr) override {
State hidden
= stackables_[0]->as<Cell>()->applyState(mappedInputs, state, mask);
;
for(size_t i = 1; i < stackables_.size(); ++i) {
if(stackables_[i]->is<Cell>()) {
auto hiddenNext
= stackables_[i]->as<Cell>()->apply(lastInputs_, hidden, mask);
lastInputs_.clear();
hidden = hiddenNext;
} else {
lastInputs_.push_back(stackables_[i]->as<CellInput>()->apply(hidden));
// lastInputs_ = { stackables_[i]->as<CellInput>()->apply(hidden) };
}
}
return hidden;
};
Ptr<Stackable> operator[](int i) { return stackables_[i]; }
Ptr<Stackable> at(int i) { return stackables_[i]; }
virtual void clear() override {
for(auto s : stackables_)
s->clear();
}
virtual std::vector<Expr> getLazyInputs(Ptr<rnn::RNN> parent) override {
ABORT_IF(!stackables_[0]->is<Cell>(),
"First stackable should be of type Cell");
return stackables_[0]->as<Cell>()->getLazyInputs(parent);
}
virtual void setLazyInputs(
std::vector<std::function<Expr(Ptr<rnn::RNN>)>> lazy) override {
ABORT_IF(!stackables_[0]->is<Cell>(),
"First stackable should be of type Cell");
stackables_[0]->as<Cell>()->setLazyInputs(lazy);
}
};
} // namespace rnn
} // namespace marian