Program Listing for File classifier.h
↰ Return to documentation for file (src/models/classifier.h)
#pragma once
#include "marian.h"
#include "models/states.h"
#include "layers/constructors.h"
#include "layers/factory.h"
namespace marian {
/**
 * Abstract base class for classifiers built on top of LayerBase.
 *
 * Holds an options handle plus three settings read from it at construction
 * time (keys "prefix", "inference" and "index"). Derived classes must
 * implement apply() and clear().
 */
class ClassifierBase : public LayerBase {
  using LayerBase::LayerBase;

protected:
  // Options handle dereferenced by opt(). It is set by the two-argument
  // constructor below; constructors inherited through the using-declaration
  // above leave it default-constructed.
  Ptr<Options> options_;
  const std::string prefix_{"classifier"};
  const bool inference_{false};
  // NOTE(review): the in-class default here is 0, while the two-argument
  // constructor falls back to 1 when "index" is absent -- confirm intended.
  const size_t batchIndex_{0};

public:
  /**
   * Construct from an expression graph and an options object.
   *
   * Both arguments are forwarded to LayerBase. The options handle is also
   * stored in this class's own options_ member so that opt() does not
   * dereference a default-constructed pointer. The remaining members are
   * read from the options with these fallbacks:
   *   "prefix"    -> "classifier"
   *   "inference" -> false
   *   "index"     -> 1
   */
  ClassifierBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : LayerBase(graph, options),
        options_(options),  // listed first to match declaration order
        prefix_(options->get<std::string>("prefix", "classifier")),
        inference_(options->get<bool>("inference", false)),
        batchIndex_(options->get<size_t>("index", 1)) {} // assume that training input has batch index 0 and labels has 1

  virtual ~ClassifierBase() {}

  /**
   * Pure virtual entry point implemented by derived classes. Takes an
   * expression graph, a corpus batch and a list of encoder states, and
   * returns a classifier state.
   */
  virtual Ptr<ClassifierState> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, const std::vector<Ptr<EncoderState>>&) = 0;

  /**
   * Look up the option stored under `key` and return it as type T.
   * Forwards to options_->get<T>(key), so options_ must be set.
   */
  template <typename T>
  T opt(const std::string& key) const {
    return options_->get<T>(key);
  }

  // Should be used to clear any batch-wise temporary objects if present
  virtual void clear() = 0;
};
}