Program Listing for File generic.h
↰ Return to documentation for file (src/layers/generic.h)
#pragma once
#include "common/definitions.h"
#include "graph/expression_operators.h"
#include "marian.h"
#include "data/shortlist.h"
#include "layers/factory.h"
namespace marian {
namespace mlp {
// Activation selector for mlp layers. mlp::Dense reads it from the integer option
// "activation" and casts it back to this enum, so the numeric value of each
// enumerator (its position in this list, starting at 0) must stay stable.
enum class act : int {
  linear,     // 0: Dense combines its outputs with plus()
  tanh,       // 1
  sigmoid,    // 2
  ReLU,       // 3
  LeakyReLU,  // 4
  PReLU,      // 5
  swish       // 6
};
} // namespace mlp
} // namespace marian
namespace marian {
/**
 * Common base of the layer classes in this file. It stores the expression graph
 * and the options object passed to the constructor and offers typed accessors
 * for reading options.
 */
class LayerBase {
protected:
  Ptr<ExpressionGraph> graph_;  // graph handed to the constructor
  Ptr<Options> options_;        // options object queried by opt()

public:
  LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}

  /**
   * Reads option `key` as type T by forwarding to Options::get<T>(key).
   * The key is taken by const reference so a lookup does not copy the string.
   */
  template <typename T>
  T opt(const std::string& key) const {
    return options_->get<T>(key);
  }

  /**
   * Reads option `key` as type T, forwarding `defaultValue` to
   * Options::get<T>(key, defaultValue).
   */
  template <typename T>
  T opt(const std::string& key, const T& defaultValue) const {
    return options_->get<T>(key, defaultValue);
  }
};
/**
 * Interface of a layer that maps one input expression to one output expression.
 */
struct IUnaryLayer {
  virtual ~IUnaryLayer() {}

  /** Applies the layer to a single input. */
  virtual Expr apply(Expr) = 0;

  /**
   * Convenience overload for a list of inputs. The default implementation only
   * supports exactly one element and forwards it to apply(Expr); implementers
   * that handle several inputs override this method.
   * Aborts if the list is empty or holds more than one element.
   */
  virtual Expr apply(const std::vector<Expr>& es) {
    // front() on an empty vector is undefined behaviour, so reject it explicitly.
    ABORT_IF(es.empty(), "No inputs");
    ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
    return apply(es.front());
  }
};
/**
 * Interface of an object that can be given a shortlist.
 */
struct IHasShortList {
  // Virtual destructor so that implementers can safely be destroyed through a
  // pointer to this interface, as with the other interfaces in this file.
  virtual ~IHasShortList() {}

  /** Hands `shortlist` to the implementer. */
  virtual void setShortlist(Ptr<data::Shortlist> shortlist) = 0;

  /** Reset hook; exact semantics are defined by the implementer (presumably drops the shortlist - confirm). */
  virtual void clear() = 0;
};
/**
 * Interface of an embedding layer. All lookups are const member functions.
 */
struct IEmbeddingLayer {
  /**
   * Embeds a sub-batch.
   * @return tuple of (embeddings, mask)
   */
  virtual std::tuple<Expr, Expr> apply(Ptr<data::SubBatch> subBatch) const = 0;

  /** Embeds the words in `embIdx`; `shape` is passed along by the caller. */
  virtual Expr apply(const Words& embIdx, const Shape& shape) const = 0;

  /** Alternative to apply(): embeds directly from word indices. */
  virtual Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const = 0;

  virtual ~IEmbeddingLayer() = default;
};
/**
 * Base class holding the state shared by encoder and decoder layers: a name
 * prefix, embedding-related settings, and a cache of embedding layers.
 * Construction and destruction are protected, so the class is only usable
 * through derived classes.
 */
class EncoderDecoderLayerBase : public LayerBase {
protected:
  // NOTE: members are initialized in declaration order; keep this order in sync
  // with the constructor's initializer list.
  const std::string prefix_;        // option "prefix", falling back to the constructor argument
  const bool embeddingFix_;
  const float dropoutEmbeddings_;   // this drops out full embedding vectors
  const bool inference_;            // option "inference", default false
  const size_t batchIndex_;         // option "index", falling back to the constructor argument
  mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_;  // (lazily created)

  /**
   * @param graph             forwarded to LayerBase
   * @param options           forwarded to LayerBase; also read for "prefix", "inference" and "index"
   * @param prefix            default used when the "prefix" option is absent
   * @param batchIndex        default used when the "index" option is absent
   * @param dropoutEmbeddings stored as dropoutEmbeddings_
   * @param embeddingFix      stored as embeddingFix_
   */
  EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph,
                          Ptr<Options> options,
                          const std::string& prefix,
                          size_t batchIndex,
                          float dropoutEmbeddings,
                          bool embeddingFix)
      : LayerBase(graph, options),
        prefix_(options->get<std::string>("prefix", prefix)),
        embeddingFix_(embeddingFix),
        dropoutEmbeddings_(dropoutEmbeddings),
        inference_(options->get<bool>("inference", false)),
        batchIndex_(options->get<size_t>("index", batchIndex)) {}

  virtual ~EncoderDecoderLayerBase() = default;

private:
  // Factory helpers; defined outside this header.
  Ptr<IEmbeddingLayer> createEmbeddingLayer() const;
  Ptr<IEmbeddingLayer> createULREmbeddingLayer() const;

public:
  /**
   * Returns an embedding layer; `ulr` selects the ULR variant.
   * Defined outside this header (presumably creates and caches the layer in
   * embeddingLayers_ on first use - confirm against the implementation).
   */
  Ptr<IEmbeddingLayer> getEmbeddingLayer(bool ulr = false) const;
};
namespace mlp {
/**
 * Dense layer configured entirely through options:
 *  - "prefix"                 name prefix of all parameters (required)
 *  - "dim"                    output dimension (required)
 *  - "layer-normalization"    apply layer normalization per input (default false)
 *  - "nematus-normalization"  use the Nematus-style variant of layer normalization (default false)
 *  - "activation"             an mlp::act value stored as int (default act::linear)
 *
 * Every input gets its own weight matrix and bias; the per-input results are
 * then handed together to the selected activation function.
 */
class Dense : public LayerBase, public IUnaryLayer {
public:
  Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {}

  /**
   * Applies the layer to one or more inputs. Aborts if `inputs` is empty.
   * Parameter names get the input's index appended only when there is more
   * than one input.
   */
  Expr apply(const std::vector<Expr>& inputs) override {
    ABORT_IF(inputs.empty(), "No inputs");

    const auto prefix         = opt<std::string>("prefix");
    const auto outDim         = opt<int>("dim");
    const bool withLayerNorm  = opt<bool>("layer-normalization", false);
    const bool withNematusLN  = opt<bool>("nematus-normalization", false);
    const auto activation
        = static_cast<act>(opt<int>("activation", static_cast<int>(act::linear)));

    const bool severalInputs = inputs.size() > 1;
    std::vector<Expr> projected;

    for(size_t idx = 0; idx < inputs.size(); ++idx) {
      const Expr& in = inputs[idx];
      const std::string suffix = severalInputs ? std::to_string(idx) : std::string();

      Expr W = graph_->param(prefix + "_W" + suffix, {in->shape()[-1], outDim}, inits::glorotUniform());
      Expr b = graph_->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());

      if(!withLayerNorm) {
        projected.push_back(affine(in, W, b));
      } else if(withNematusLN) {
        // Nematus variant: normalize the full affine output with separate scale and bias.
        auto ln_s = graph_->param(prefix + "_ln_s" + suffix, {1, outDim}, inits::fromValue(1.f));
        auto ln_b = graph_->param(prefix + "_ln_b" + suffix, {1, outDim}, inits::zeros());
        projected.push_back(layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS));
      } else {
        // Default variant: normalize the bias-free product; b is passed to layerNorm instead.
        auto gamma = graph_->param(prefix + "_gamma" + suffix, {1, outDim}, inits::fromValue(1.0));
        projected.push_back(layerNorm(dot(in, W), gamma, b));
      }
    }

    // clang-format off
    switch(activation) {
      case act::tanh:      return tanh(projected);
      case act::sigmoid:   return sigmoid(projected);
      case act::ReLU:      return relu(projected);
      case act::LeakyReLU: return leakyrelu(projected);
      case act::PReLU:     return prelu(projected);
      case act::swish:     return swish(projected);
      case act::linear:    // fall through: linear and unknown values both use plus()
      default:             return plus(projected);
    }
    // clang-format on
  }

  /** Single-input convenience overload; wraps the input in a one-element list. */
  Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); }
};
} // namespace mlp
// --- A few layers with built-in parameters that are created on the fly, without a proper layer object.
// @TODO: change to a proper layer object
static inline std::function<Expr(Expr)> activationByName(const std::string& actName) {
if (actName == "relu")
return (ActivationFunction*)relu;
else if (actName == "swish")
return (ActivationFunction*)swish;
else if (actName == "gelu")
return (ActivationFunction*)gelu;
else if (actName == "sigmoid")
return (ActivationFunction*)sigmoid;
else if (actName == "") // return identity function if activation name is empty
return [](Expr x) { return x; };
ABORT("Invalid activation name '{}'", actName);
}
// Like affine(), but with built-in parameters, activation, and dropout.
// Creates (or looks up) the parameters prefix + "_W" + suffix and prefix + "_b" + suffix
// in x's graph, applies them to x, then the activation named by actName
// (see activationByName(); empty means none), and finally dropout with dropProb.
static inline Expr denseInline(Expr x,
                               std::string prefix,
                               std::string suffix,
                               int outDim,
                               Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
                               std::string actName = "",
                               float dropProb = 0.0f) {
  auto graph   = x->graph();
  auto weights = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, initFn);
  auto bias    = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());

  // speed optimization for inference, @TODO: handle better in future layer framework
  if(actName == "relu")
    return dropout(affineWithRelu(x, weights, bias), dropProb);  // @TODO: check for inference?

  auto projected = affine(x, weights, bias);
  auto activated = activationByName(actName)(projected);
  return dropout(activated, dropProb);  // @TODO: check for inference?
}
// Layer normalization with built-in parameters: creates (or looks up)
// prefix + "_ln_scale" + suffix (initialized to ones) and
// prefix + "_ln_bias" + suffix (initialized to zeros), both of shape
// {1, last dimension of x}, and forwards to marian::layerNorm with 1e-6f
// as the last argument.
static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
  auto graph = x->graph();
  const int lastDim = x->shape()[-1];
  auto gamma = graph->param(prefix + "_ln_scale" + suffix, {1, lastDim}, inits::ones());
  auto beta  = graph->param(prefix + "_ln_bias" + suffix, {1, lastDim}, inits::zeros());
  return marian::layerNorm(x, gamma, beta, 1e-6f);
}
// RMS normalization with a built-in parameter: creates (or looks up)
// prefix + "_rms_scale" + suffix (initialized to ones) of shape
// {1, last dimension of x} and forwards to marian::rmsNorm without a bias
// (nullptr) and with 1e-6f as the last argument.
static inline Expr rmsNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
  const int lastDim = x->shape()[-1];
  auto gamma = x->graph()->param(prefix + "_rms_scale" + suffix, {1, lastDim}, inits::ones());
  return marian::rmsNorm(x, gamma, nullptr, 1e-6f);
}
} // namespace marian