.. _program_listing_file_src_models_encoder_decoder.h:

Program Listing for File encoder_decoder.h
==========================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_models_encoder_decoder.h>` (``src/models/encoder_decoder.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. NOTE(review): this page is an exhale-generated listing of
   src/models/encoder_decoder.h. Its line breaks and every ``<...>`` span
   (C++ template arguments and the :ref: target above) were lost during text
   extraction and have been restored from the upstream Marian sources.
   Regenerate the page with exhale to confirm the listing matches the header
   exactly; do not hand-edit the code below, it must mirror the source file.

.. code-block:: cpp

   #pragma once

   #include "marian.h"

   #include "models/decoder.h"
   #include "models/encoder.h"
   #include "models/model_base.h"
   #include "models/states.h"

   namespace marian {

   class IEncoderDecoder : public models::IModel {
   public:
     virtual ~IEncoderDecoder() {}

     virtual void load(Ptr<ExpressionGraph> graph,
                       const std::vector<io::Item>& items,
                       bool markedReloaded = true) = 0;

     virtual void load(Ptr<ExpressionGraph> graph,
                       const std::string& name,
                       bool markedReloaded = true) override = 0;

     virtual void mmap(Ptr<ExpressionGraph> graph,
                       const void* ptr,
                       bool markedReloaded = true) = 0;

     virtual void save(Ptr<ExpressionGraph> graph,
                       const std::string& name,
                       bool saveTranslatorConfig = false) override = 0;

     virtual void clear(Ptr<ExpressionGraph> graph) override = 0;

     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::Batch> batch,
                          bool clearGraph = true) override = 0;

     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::CorpusBatch> batch,
                          bool clearGraph = true) = 0;

     virtual Ptr<DecoderState> startState(Ptr<ExpressionGraph> graph,
                                          Ptr<data::CorpusBatch> batch) = 0;

     virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
                                    Ptr<DecoderState> state,
                                    const std::vector<IndexType>& hypIndices,   // [beamIndex * activeBatchSize + batchIndex]
                                    const Words& words,                         // [beamIndex * activeBatchSize + batchIndex]
                                    const std::vector<IndexType>& batchIndices, // [batchIndex]
                                    int beamSize) = 0;

     virtual Ptr<Options> getOptions() = 0;

     virtual void setShortlistGenerator(
         Ptr<const data::ShortlistGenerator> shortlistGenerator) = 0;

     virtual Ptr<data::Shortlist> getShortlist() = 0;

     virtual data::SoftAlignment getAlignment() = 0;
   };

   class EncoderDecoder : public IEncoderDecoder, public LayerBase {
   protected:
     Ptr<const data::ShortlistGenerator> shortlistGenerator_;

     const std::string prefix_;
     const bool inference_{ false };

     std::vector<Ptr<EncoderBase>> encoders_;
     std::vector<Ptr<DecoderBase>> decoders_;

     std::set<std::string> modelFeatures_;

     Config::YamlNode getModelParameters();
     std::string getModelParametersAsString();

     virtual void createDecoderConfig(const std::string& name);

   public:
     typedef data::Corpus dataset_type;

     EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options);

     virtual Ptr<Options> getOptions() override { return options_; }

     std::vector<Ptr<EncoderBase>>& getEncoders();

     void push_back(Ptr<EncoderBase> encoder);

     std::vector<Ptr<DecoderBase>>& getDecoders();

     void push_back(Ptr<DecoderBase> decoder);

     virtual void load(Ptr<ExpressionGraph> graph,
                       const std::vector<io::Item>& items,
                       bool markedReloaded = true) override;

     virtual void load(Ptr<ExpressionGraph> graph,
                       const std::string& name,
                       bool markedReloaded = true) override;

     virtual void mmap(Ptr<ExpressionGraph> graph,
                       const void* ptr,
                       bool markedReloaded = true) override;

     virtual void save(Ptr<ExpressionGraph> graph,
                       const std::string& name,
                       bool saveTranslatorConfig = false) override;

     virtual void clear(Ptr<ExpressionGraph> graph) override;

     template <typename T>
     T opt(const std::string& key) {
       return options_->get<T>(key);
     }

     template <typename T>
     T opt(const std::string& key, const T& def) {
       return options_->get<T>(key, def);
     }

     template <typename T>
     void set(std::string key, T value) {
       options_->set(key, value);
     }

     virtual void setShortlistGenerator(
         Ptr<const data::ShortlistGenerator> shortlistGenerator) override {
       shortlistGenerator_ = shortlistGenerator;
     };

     virtual Ptr<data::Shortlist> getShortlist() override {
       return decoders_[0]->getShortlist();
     };

     // convert alignment tensors that live GPU-side into a CPU-side vector of vectors
     virtual data::SoftAlignment getAlignment() override {
       data::SoftAlignment softAlignments;
       auto alignments = decoders_[0]->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
       for(auto alignment : alignments) { // [beam depth, max src length, batch size, 1]
         softAlignments.push_back({});
         alignment->val()->get(softAlignments.back());
       }
       return softAlignments; // [tgt index][beam depth * max src length * batch size]
     };

     /*********************************************************************/

     virtual Ptr<DecoderState> startState(Ptr<ExpressionGraph> graph,
                                          Ptr<data::CorpusBatch> batch) override;

     virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
                                    Ptr<DecoderState> state,
                                    const std::vector<IndexType>& hypIndices,
                                    const Words& words,
                                    const std::vector<IndexType>& batchIndices,
                                    int beamSize) override;

     virtual Ptr<DecoderState> stepAll(Ptr<ExpressionGraph> graph,
                                       Ptr<data::CorpusBatch> batch,
                                       bool clearGraph = true);

     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::CorpusBatch> batch,
                          bool clearGraph = true) override;

     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::Batch> batch,
                          bool clearGraph = true) override;
   };

   } // namespace marian