.. _program_listing_file_src_layers_constructors.h:

Program Listing for File constructors.h
=======================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_layers_constructors.h>` (``src/layers/constructors.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

..
   Review note: the extracted copy of this page had lost its line structure and
   every angle-bracketed span (the ref target above and all C++ template
   argument lists). They have been reconstructed from the surrounding names;
   please confirm the template arguments against src/layers/constructors.h.

.. code-block:: cpp

   #pragma once

   #include "layers/embedding.h"
   #include "layers/factory.h"
   #include "layers/generic.h"
   #include "layers/output.h"

   namespace marian {
   namespace mlp {

   // Abstract factory for a single layer that can be stacked inside an MLP.
   // construct() builds the layer for the given graph.
   struct LayerFactory : public Factory {
     virtual Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) = 0;
   };

   // Factory that builds a Dense layer from the options accumulated on it.
   class DenseFactory : public LayerFactory {
   public:
     // Creates the Dense layer from this factory's current options.
     Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override {
       return New<Dense>(graph, options_);
     }

     // Returns a new DenseFactory with this factory's options merged into it.
     DenseFactory clone() {
       DenseFactory aClone;
       aClone.options_->merge(options_);
       return aClone;
     }
   };

   typedef Accumulator<DenseFactory> dense;

   // Abstract factory for a layer that returns Logits; kept separate from
   // LayerFactory because construct() has a different return type.
   struct LogitLayerFactory : public Factory {
     using Factory::Factory;
     virtual Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) = 0;
   };

   // Factory for the output layer. Besides the options it remembers the name of
   // the parameter to tie (transposed) and an optional shortlist; both are handed
   // to the layer in construct().
   class OutputFactory : public LogitLayerFactory {
     using LogitLayerFactory::LogitLayerFactory;

   protected:
     std::string tiedTransposedName_;
     Ptr<data::Shortlist> shortlist_;

   public:
     // Records the name of the parameter to tie and returns an accumulator copy of
     // this factory so that calls can be chained.
     Accumulator<OutputFactory> tieTransposed(const std::string& tied) {
       tiedTransposedName_ = tied;
       return Accumulator<OutputFactory>(*this);
     }

     // Stores the shortlist that construct() passes on to the output layer.
     void setShortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; }

     // Builds the output layer, ties it to graph->get(tiedTransposedName_) and
     // forwards the stored shortlist. The lookup is performed unconditionally,
     // i.e. also when no name was set via tieTransposed().
     Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override {
       auto output = New<Output>(graph, options_);
       output->tieTransposed(graph->get(tiedTransposedName_));
       output->setShortlist(shortlist_);
       return output;
     }

     // Returns a new OutputFactory carrying merged options, the tied-parameter
     // name and the shortlist of this factory.
     OutputFactory clone() {
       OutputFactory aClone;
       aClone.options_->merge(options_);
       aClone.tiedTransposedName_ = tiedTransposedName_;
       aClone.shortlist_ = shortlist_;
       return aClone;
     }
   };

   typedef Accumulator<OutputFactory> output;

   // A stack of unary layers applied one after another. The last layer may be an
   // IUnaryLogitLayer, in which case applyAsLogits() can be used; shortlist calls
   // are forwarded to the last layer if it implements IHasShortList.
   class MLP : public IUnaryLogitLayer, public IHasShortList {
   protected:
     Ptr<ExpressionGraph> graph_;
     Ptr<Options> options_;

     std::vector<Ptr<IUnaryLayer>> layers_;

   public:
     MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}

     // Feeds av to the first layer (as a single Expr when av has exactly one
     // element, otherwise as the whole vector) and chains the result through the
     // remaining layers. Aborts if no layer has been added.
     Expr apply(const std::vector<Expr>& av) override {
       // layers_[0] on an empty vector would be undefined behaviour.
       ABORT_IF(layers_.empty(), "MLP::apply() was called on an MLP without any layers");

       Expr output;
       if(av.size() == 1)
         output = layers_[0]->apply(av[0]);
       else
         output = layers_[0]->apply(av);

       for(size_t i = 1; i < layers_.size(); ++i)
         output = layers_[i]->apply(output);

       return output;
     }

     // Aborts if no layer has been added or if the last layer is not an
     // IUnaryLogitLayer.
     Logits applyAsLogits(const std::vector<Expr>& av) override {
       // layers_.back() and layers_.size() - 1 below both require a non-empty stack.
       ABORT_IF(layers_.empty(), "MLP::applyAsLogits() was called on an MLP without any layers");

       // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different
       // return type
       auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back());
       ABORT_IF(
           !lastLayer,
           "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
       if(layers_.size() == 1) {
         if(av.size() == 1)
           return lastLayer->applyAsLogits(av[0]);
         else
           return lastLayer->applyAsLogits(av);
       } else {
         Expr output;
         if(av.size() == 1)
           output = layers_[0]->apply(av[0]);
         else
           output = layers_[0]->apply(av);
         // All intermediate layers use apply(); only the last one yields Logits.
         for(size_t i = 1; i < layers_.size() - 1; ++i)
           output = layers_[i]->apply(output);
         return lastLayer->applyAsLogits(output);
       }
     }

     // Single-input conveniences; both delegate to the vector overloads.
     Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); }
     Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); }

     // Appends a layer to the end of the stack.
     void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); }
     void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); }

     // Forwards the shortlist to the last layer; aborts if that layer does not
     // implement IHasShortList (or if there is no layer at all).
     void setShortlist(Ptr<data::Shortlist> shortlist) override final {
       auto p = tryAsHasShortlist();
       ABORT_IF(
           !p,
           "setShortlist() called on an MLP with an output layer that does not support short lists");
       p->setShortlist(shortlist);
     }

     // Forwards clear() to the last layer if it implements IHasShortList;
     // otherwise does nothing.
     void clear() override final {
       auto p = tryAsHasShortlist();
       if(p)
         p->clear();
     }

   private:
     // Returns the last layer as IHasShortList, or null if the stack is empty or
     // the last layer does not implement that interface.
     Ptr<IHasShortList> tryAsHasShortlist() const {
       if(layers_.empty())  // back() on an empty vector would be undefined behaviour
         return nullptr;
       return std::dynamic_pointer_cast<IHasShortList>(layers_.back());
     }
   };

   // Collects layer factories and builds an MLP from them on construct().
   class MLPFactory : public Factory {
     using Factory::Factory;

   private:
     std::vector<Ptr<LayerFactory>> layers_;

   public:
     // Creates the MLP, then for every registered layer factory merges this
     // factory's options into it, constructs the layer and appends it.
     Ptr<MLP> construct(Ptr<ExpressionGraph> graph) {
       auto mlp = New<MLP>(graph, options_);
       // Bind by reference: copying the shared pointer per iteration is not needed.
       for(const auto& layer : layers_) {
         layer->mergeOpts(options_);
         mlp->push_back(layer->construct(graph));
       }
       return mlp;
     }

     // Registers a copy of the given layer factory and returns an accumulator copy
     // of this factory so that calls can be chained.
     template <class LF>
     Accumulator<MLPFactory> push_back(const LF& lf) {
       layers_.push_back(New<LF>(lf));
       return Accumulator<MLPFactory>(*this);
     }

     // Special case for last layer, which may be a IUnaryLogitLayer. Requires some hackery,
     // which will go away if we get rid of the abstract factories, and instead just construct
     // all layers immediately, which is my long-term goal for Marian.
   private:
     // Adapter that exposes a factory with a different construct() return type
     // through the LayerFactory interface.
     template <class WrappedFactory>
     class AsLayerFactory : public LayerFactory {
       WrappedFactory us;

     public:
       AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
       Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
         // static_pointer_cast never turns a non-null pointer into null, so the
         // check below only fires if the wrapped factory itself returned null.
         auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
         ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
         return p;
       }
     };
     // Wraps a factory in AsLayerFactory via the converting constructor.
     template <class WrappedFactory>
     static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) {
       return wrapped;
     }

   public:
     // Overload for an output factory: wraps it in AsLayerFactory and registers it
     // through the generic push_back() above.
     Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) {
       push_back(AsLayerFactory<OutputFactory>(lf));
       // layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
       return Accumulator<MLPFactory>(*this);
     }
   };

   typedef Accumulator<MLPFactory> mlp;
   }  // namespace mlp

   typedef ConstructingFactory<Embedding> EmbeddingFactory;
   typedef ConstructingFactory<ULREmbedding> ULREmbeddingFactory;

   typedef Accumulator<EmbeddingFactory> embedding;
   typedef Accumulator<ULREmbeddingFactory> ulr_embedding;
   }  // namespace marian