Program Listing for File validator.h
↰ Return to documentation for file (src/examples/mnist/validator.h)
#pragma once
#include "common/options.h"
#include "data/batch_generator.h"
#include "graph/expression_graph.h"
#include "models/model_base.h"
#include "training/validator.h"
#include "examples/mnist/dataset.h"
using namespace marian;
namespace marian {
/**
 * Validator that reports classification accuracy for the MNIST example.
 *
 * For every batch produced by the batch generator it builds the model on the
 * first graph, runs a forward pass, and counts the samples whose
 * highest-scoring class index equals the reference label. The reported value
 * is the number of correct samples divided by the number of samples seen.
 */
class MNISTAccuracyValidator : public Validator<data::MNISTData, models::IModel> {
public:
  /**
   * Creates the batch generator and builds the model from the options.
   *
   * @param options options passed to the base Validator and to
   *                models::createModelFromOptions
   */
  MNISTAccuracyValidator(Ptr<Options> options) : Validator(std::vector<Ptr<Vocab>>(), options, false) {
    createBatchGenerator(/*isTranslating=*/false);
    builder_ = models::createModelFromOptions(options, models::usage::translation);
  }

  virtual ~MNISTAccuracyValidator() {}

  /** Not supported for the MNIST example; only logs a warning. */
  virtual void keepBest(const std::vector<Ptr<ExpressionGraph>>& /*graphs*/) override {
    LOG(warn, "Keeping best model for MNIST examples is not supported");
  }

  /** @return the validator type name, "accuracy" */
  std::string type() override { return "accuracy"; }

protected:
  /**
   * Computes the accuracy over all batches of the batch generator.
   *
   * @param graphs expression graphs; only graphs[0] is used
   * @return correct predictions divided by processed samples, or 0 when no
   *         sample was processed
   */
  virtual float validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) override {
    float correct = 0;
    size_t samples = 0;

    for(auto batch : *batchGenerator_) {
      auto probs = builder_->build(graphs[0], batch, true).getLogits();
      graphs[0]->forward();

      std::vector<float> scores;
      probs->val()->get(scores);

      correct += countCorrect(scores, batch->labels());
      samples += batch->size();
    }

    // Without this guard an empty validation set would yield 0/0 = NaN.
    if(samples == 0)
      return 0.f;

    return correct / static_cast<float>(samples);
  }

private:
  /**
   * Counts the samples whose arg-max score matches the reference label.
   *
   * `probs` is read as labels.size() consecutive windows of
   * probs.size() / labels.size() scores each, one window per sample.
   *
   * @param probs  flattened scores for all samples
   * @param labels reference class index of every sample, stored as float
   * @return number of correctly classified samples; 0 when `labels` is empty
   *         or `probs` holds fewer values than `labels`
   */
  float countCorrect(const std::vector<float>& probs, const std::vector<float>& labels) {
    // Guard the division below: an empty label vector would divide by zero.
    if(labels.empty())
      return 0.f;

    const size_t numLabels = probs.size() / labels.size();
    // Fewer scores than samples: there is no complete window to evaluate.
    if(numLabels == 0)
      return 0.f;

    float numCorrect = 0;
    // Iterate per sample rather than striding over probs.size(), so that a
    // trailing partial window can never be read past the end of either vector.
    for(size_t sample = 0; sample < labels.size(); ++sample) {
      const auto first = probs.begin() + sample * numLabels;
      const auto best = std::max_element(first, first + numLabels);
      const auto pred = std::distance(first, best);
      if(static_cast<float>(pred) == labels[sample])
        ++numCorrect;
    }
    return numCorrect;
  }
};
} // namespace marian