Program Listing for File constructors.h
↰ Return to documentation for file (src/layers/constructors.h)
#pragma once
#include "layers/embedding.h"
#include "layers/factory.h"
#include "layers/generic.h"
#include "layers/output.h"
namespace marian {
namespace mlp {
/**
 * Abstract factory interface for layers exposed as IUnaryLayer.
 *
 * Derived factories implement construct() to build their concrete layer
 * on the expression graph they are given.
 */
struct LayerFactory : public Factory {
  /** Builds the layer on `graph` and returns it as an IUnaryLayer. */
  virtual Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) = 0;
};
/**
 * LayerFactory that produces Dense layers configured with this factory's options.
 */
class DenseFactory : public LayerFactory {
public:
  /** Creates a Dense layer on `graph`, handing it this factory's options. */
  Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override {
    Ptr<IUnaryLayer> layer = New<Dense>(graph, options_);
    return layer;
  }

  /**
   * Returns a default-constructed DenseFactory into whose options this
   * factory's options have been merged.
   */
  DenseFactory clone() {
    DenseFactory copy;
    copy.options_->merge(options_);
    return copy;
  }
};
// Accumulator wrapper around DenseFactory.
using dense = Accumulator<DenseFactory>;
/**
 * Abstract factory interface for layers exposed as IUnaryLogitLayer.
 *
 * Inherits the constructors of Factory; derived factories implement
 * construct() to build their concrete layer on the given graph.
 */
struct LogitLayerFactory : public Factory {
  using Factory::Factory;

  /** Builds the layer on `graph` and returns it as an IUnaryLogitLayer. */
  virtual Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) = 0;
};
/**
 * LogitLayerFactory that produces Output layers.
 *
 * In addition to the inherited options it remembers the name of a graph
 * parameter to tie (transposed) to the output layer, and a shortlist; both
 * are handed to the Output layer when it is constructed.
 */
class OutputFactory : public LogitLayerFactory {
  using LogitLayerFactory::LogitLayerFactory;

protected:
  std::string tiedTransposedName_;  // name looked up in the graph by construct()
  Ptr<data::Shortlist> shortlist_;  // forwarded to the constructed Output layer

public:
  /**
   * Records `tied` as the name of the parameter to tie to, then returns an
   * Accumulator built from this factory.
   */
  Accumulator<OutputFactory> tieTransposed(const std::string& tied) {
    tiedTransposedName_ = tied;
    Accumulator<OutputFactory> accumulated(*this);
    return accumulated;
  }

  /** Stores the shortlist that construct() will pass to the Output layer. */
  void setShortlist(Ptr<data::Shortlist> shortlist) {
    shortlist_ = shortlist;
  }

  /**
   * Creates an Output layer on `graph`, ties it to whatever
   * graph->get(tiedTransposedName_) returns, and sets the stored shortlist.
   */
  Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override {
    auto layer = New<Output>(graph, options_);
    layer->tieTransposed(graph->get(tiedTransposedName_));
    layer->setShortlist(shortlist_);
    return layer;
  }

  /**
   * Returns a default-constructed OutputFactory carrying this factory's
   * merged options, tied-parameter name and shortlist.
   */
  OutputFactory clone() {
    OutputFactory copy;
    copy.tiedTransposedName_ = tiedTransposedName_;
    copy.shortlist_          = shortlist_;
    copy.options_->merge(options_);
    return copy;
  }
};
// Accumulator wrapper around OutputFactory.
using output = Accumulator<OutputFactory>;
/**
 * A sequence of unary layers applied one after another.
 *
 * The input is given to the first layer; each following layer receives the
 * previous layer's result. applyAsLogits() does the same but invokes
 * applyAsLogits() on the last layer, which must then be an IUnaryLogitLayer.
 * Shortlist operations are forwarded to the last layer when it implements
 * IHasShortList.
 *
 * Applying an MLP that holds no layers aborts with an explicit message
 * instead of indexing into the empty layer list.
 */
class MLP : public IUnaryLogitLayer, public IHasShortList {
protected:
  Ptr<ExpressionGraph> graph_;
  Ptr<Options> options_;
  std::vector<Ptr<IUnaryLayer>> layers_;  // applied in order, index 0 first

public:
  MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}

  /**
   * Runs all layers in order on `av` and returns the last layer's result.
   * Aborts if no layer has been added.
   */
  Expr apply(const std::vector<Expr>& av) override {
    ABORT_IF(layers_.empty(), "MLP::apply() was called on an MLP without any layers");

    Expr output = applyFirstLayer(av);
    for(size_t i = 1; i < layers_.size(); ++i)
      output = layers_[i]->apply(output);
    return output;
  }

  /**
   * Same as apply() except for the last layer, on which applyAsLogits() is
   * invoked because it has a different return type.
   * Aborts if no layer has been added or if the last layer is not an
   * IUnaryLogitLayer.
   */
  Logits applyAsLogits(const std::vector<Expr>& av) override {
    // layers_.back() on an empty vector is undefined behaviour, so check first.
    ABORT_IF(layers_.empty(), "MLP::applyAsLogits() was called on an MLP without any layers");

    auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back());
    ABORT_IF(
        !lastLayer,
        "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");

    if(layers_.size() == 1) {
      // The only layer is also the last one: feed it the raw input directly.
      if(av.size() == 1)
        return lastLayer->applyAsLogits(av[0]);
      else
        return lastLayer->applyAsLogits(av);
    }

    Expr output = applyFirstLayer(av);
    // Intermediate layers only; the last one is handled separately below.
    for(size_t i = 1; i < layers_.size() - 1; ++i)
      output = layers_[i]->apply(output);
    return lastLayer->applyAsLogits(output);
  }

  /** Single-expression convenience overloads; they wrap `e` into a one-element vector. */
  Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); }
  Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); }

  /** Appends a layer to the end of the sequence. */
  void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); }
  void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); }

  /**
   * Forwards the shortlist to the last layer.
   * Aborts if there is no last layer or it does not implement IHasShortList.
   */
  void setShortlist(Ptr<data::Shortlist> shortlist) override final {
    auto p = tryAsHasShortlist();
    ABORT_IF(
        !p,
        "setShortlist() called on an MLP with an output layer that does not support short lists");
    p->setShortlist(shortlist);
  }

  /**
   * Forwards clear() to the last layer if it implements IHasShortList;
   * otherwise (including when there are no layers) does nothing.
   */
  void clear() override final {
    auto p = tryAsHasShortlist();
    if(p)
      p->clear();
  }

private:
  /**
   * Applies layer 0 to the input: a single expression is passed as such,
   * anything else is passed as the whole vector.
   * Callers must have verified that layers_ is not empty.
   */
  Expr applyFirstLayer(const std::vector<Expr>& av) {
    if(av.size() == 1)
      return layers_[0]->apply(av[0]);
    return layers_[0]->apply(av);
  }

  /**
   * Returns the last layer viewed as IHasShortList, or a null pointer when
   * there are no layers or the last layer does not implement that interface.
   */
  Ptr<IHasShortList> tryAsHasShortlist() const {
    if(layers_.empty())
      return nullptr;
    return std::dynamic_pointer_cast<IHasShortList>(layers_.back());
  }
};
/**
 * Factory that collects layer factories and builds an MLP from them.
 *
 * Layer factories are stored in insertion order; construct() merges this
 * factory's options into each of them and appends the layer each one
 * constructs to a new MLP.
 */
class MLPFactory : public Factory {
  using Factory::Factory;

private:
  std::vector<Ptr<LayerFactory>> layers_;  // layer factories, in the order they were pushed

public:
  /** Builds an MLP on `graph` containing one layer per stored layer factory. */
  Ptr<MLP> construct(Ptr<ExpressionGraph> graph) {
    auto result = New<MLP>(graph, options_);
    for(const auto& layerFactory : layers_) {
      layerFactory->mergeOpts(options_);
      result->push_back(layerFactory->construct(graph));
    }
    return result;
  }

  /**
   * Stores a copy of the layer factory `lf` and returns an Accumulator built
   * from this factory.
   */
  template <class LF>
  Accumulator<MLPFactory> push_back(const LF& lf) {
    Ptr<LF> stored = New<LF>(lf);
    layers_.push_back(stored);
    return Accumulator<MLPFactory>(*this);
  }

  // Special handling for the last layer, which may be an IUnaryLogitLayer.
  // The original author notes that this needs some hackery, which would go
  // away if the abstract factories were dropped in favour of constructing
  // all layers immediately (stated as a long-term goal for Marian).
private:
  /**
   * Adapter that holds a copy of another factory and presents it as a
   * LayerFactory by casting what that factory constructs to IUnaryLayer.
   */
  template <class WrappedFactory>
  class AsLayerFactory : public LayerFactory {
    WrappedFactory inner_;

  public:
    AsLayerFactory(const WrappedFactory& wrapped) : inner_(wrapped) {}

    /** Constructs via the wrapped factory and returns the result as an IUnaryLayer. */
    Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
      auto asUnary = std::static_pointer_cast<IUnaryLayer>(inner_.construct(graph));
      ABORT_IF(!asUnary, "Attempted to cast a Factory to LayerFactory that isn't one");
      return asUnary;
    }
  };

  /** Wraps `wrapped` into an AsLayerFactory adapter. */
  template <class WrappedFactory>
  static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) {
    return AsLayerFactory<WrappedFactory>(wrapped);
  }

public:
  /**
   * Overload for an accumulated OutputFactory: wraps it in an AsLayerFactory
   * adapter, stores that via the generic push_back(), and returns an
   * Accumulator built from this factory.
   */
  Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) {
    AsLayerFactory<OutputFactory> adapted(lf);
    push_back(adapted);
    // layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
    return Accumulator<MLPFactory>(*this);
  }
};
// Accumulator wrapper around MLPFactory.
using mlp = Accumulator<MLPFactory>;
} // namespace mlp
// ConstructingFactory instantiations for the two embedding layer types.
using EmbeddingFactory    = ConstructingFactory<Embedding>;
using ULREmbeddingFactory = ConstructingFactory<ULREmbedding>;

// Accumulator wrappers around the embedding factories.
using embedding     = Accumulator<EmbeddingFactory>;
using ulr_embedding = Accumulator<ULREmbeddingFactory>;
} // namespace marian