.. _program_listing_file_src_translator_scorers.h:

Program Listing for File scorers.h
==================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_translator_scorers.h>` (``src/translator/scorers.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. NOTE(review): this page appears to be generated (see the exhale_lsh
   substitution above) as a verbatim listing of src/translator/scorers.h.
   The code below must stay identical to that header, including every
   template argument inside angle brackets; regenerate the page from the
   header rather than editing the listing by hand.

.. code-block:: cpp

   #pragma once

   #include "marian.h"

   #include "data/shortlist.h"
   #include "models/model_factory.h"
   #include "3rd_party/mio/mio.hpp"

   namespace marian {

   class ScorerState {
   public:
     virtual ~ScorerState(){}

     virtual Logits getLogProbs() const = 0;

     virtual void blacklist(Expr /*totalCosts*/, Ptr<data::CorpusBatch> /*batch*/){};
   };

   class Scorer {
   protected:
     std::string name_;
     float weight_;

   public:
     Scorer(const std::string& name, float weight)
         : name_(name), weight_(weight) {}

     virtual ~Scorer(){}

     std::string getName() { return name_; }
     float getWeight() { return weight_; }

     virtual void clear(Ptr<ExpressionGraph>) = 0;
     virtual Ptr<ScorerState> startState(Ptr<ExpressionGraph>,
                                         Ptr<data::CorpusBatch>)
         = 0;
     virtual Ptr<ScorerState> step(Ptr<ExpressionGraph>,
                                   Ptr<ScorerState>,
                                   const std::vector<IndexType>&,
                                   const Words&,
                                   const std::vector<IndexType>& batchIndices,
                                   int beamSize)
         = 0;

     virtual void init(Ptr<ExpressionGraph>) {}

     virtual void setShortlistGenerator(
         Ptr<const data::ShortlistGenerator> /*shortlistGenerator*/){};
     virtual Ptr<data::Shortlist> getShortlist() { return nullptr; };
     virtual std::vector<float> getAlignment() { return {}; };
   };

   class ScorerWrapperState : public ScorerState {
   protected:
     Ptr<DecoderState> state_;

   public:
     ScorerWrapperState(Ptr<DecoderState> state) : state_(state) {}
     virtual ~ScorerWrapperState() {}

     virtual Ptr<DecoderState> getState() { return state_; }

     virtual Logits getLogProbs() const override { return state_->getLogProbs(); };

     virtual void blacklist(Expr totalCosts, Ptr<data::CorpusBatch> batch) override {
       state_->blacklist(totalCosts, batch);
     }
   };

   // class to wrap IEncoderDecoder in a Scorer interface
   class ScorerWrapper : public Scorer {
   private:
     Ptr<IEncoderDecoder> encdec_;
     std::string fname_;
     std::vector<io::Item> items_;
     const void* ptr_;

   public:
     ScorerWrapper(Ptr<models::IModel> encdec,
                   const std::string& name,
                   float weight,
                   std::vector<io::Item>& items)
         : Scorer(name, weight),
           encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
           items_(items),
           ptr_{0} {}

     ScorerWrapper(Ptr<models::IModel> encdec,
                   const std::string& name,
                   float weight,
                   const std::string& fname)
         : Scorer(name, weight),
           encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
           fname_(fname),
           ptr_{0} {}

     ScorerWrapper(Ptr<models::IModel> encdec,
                   const std::string& name,
                   float weight,
                   const void* ptr)
         : Scorer(name, weight),
           encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
           ptr_{ptr} {}

     virtual ~ScorerWrapper() {}

     virtual void init(Ptr<ExpressionGraph> graph) override {
       graph->switchParams(getName());
       if(!items_.empty())
         encdec_->load(graph, items_);
       else if(ptr_)
         encdec_->mmap(graph, ptr_);
       else
         encdec_->load(graph, fname_);
     }

     virtual void clear(Ptr<ExpressionGraph> graph) override {
       graph->switchParams(getName());
       encdec_->clear(graph);
     }

     virtual Ptr<ScorerState> startState(Ptr<ExpressionGraph> graph,
                                         Ptr<data::CorpusBatch> batch) override {
       graph->switchParams(getName());
       return New<ScorerWrapperState>(encdec_->startState(graph, batch));
     }

     virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,
                                   Ptr<ScorerState> state,
                                   const std::vector<IndexType>& hypIndices,
                                   const Words& words,
                                   const std::vector<IndexType>& batchIndices,
                                   int beamSize) override {
       graph->switchParams(getName());
       auto wrapperState = std::dynamic_pointer_cast<ScorerWrapperState>(state);
       auto newState = encdec_->step(graph, wrapperState->getState(), hypIndices, words, batchIndices, beamSize);
       return New<ScorerWrapperState>(newState);
     }

     virtual void setShortlistGenerator(
         Ptr<const data::ShortlistGenerator> shortlistGenerator) override {
       encdec_->setShortlistGenerator(shortlistGenerator);
     };

     virtual Ptr<data::Shortlist> getShortlist() override {
       return encdec_->getShortlist();
     };

     virtual std::vector<float> getAlignment() override {
       // This is called during decoding, where alignments only exist for the last time step. Hence front().
       // This makes as copy. @TODO: It should be OK to return this as a const&.
       return encdec_->getAlignment().front(); // [beam depth * max src length * batch size]
     }
   };

   Ptr<Scorer> scorerByType(const std::string& fname,
                            float weight,
                            std::vector<io::Item> items,
                            Ptr<Options> options);

   Ptr<Scorer> scorerByType(const std::string& fname,
                            float weight,
                            const std::string& model,
                            Ptr<Options> config);

   std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options);
   std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<std::vector<io::Item>> models);

   Ptr<Scorer> scorerByType(const std::string& fname,
                            float weight,
                            const void* ptr,
                            Ptr<Options> config);

   std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<const void*>& ptrs);
   std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<mio::mmap_source>& mmaps);

   }  // namespace marian