.. _program_listing_file_src_examples_mnist_validator.h:

Program Listing for File validator.h
====================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_examples_mnist_validator.h>` (``src/examples/mnist/validator.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include "common/options.h"
   #include "data/batch_generator.h"
   #include "graph/expression_graph.h"
   #include "models/model_base.h"
   #include "training/validator.h"

   #include "examples/mnist/dataset.h"

   using namespace marian;

   namespace marian {

   /**
    * @brief Validator that reports classification accuracy for the MNIST example.
    *
    * Runs the model over every validation batch, takes the arg-max of each row of
    * output scores as the predicted class, and compares it with the gold label.
    * The reported metric is (#correct / #samples). Keeping the best model is not
    * supported for this example (see keepBest()).
    */
   class MNISTAccuracyValidator : public Validator<data::MNISTData, models::IModel> {
   public:
     /**
      * @param options Options used both for the base validator and to build the
      *                model via models::createModelFromOptions(..., usage::translation).
      *
      * The base is constructed with an empty vocabulary list and `false` as its
      * third argument (presumably lower-is-better = false, since a higher accuracy
      * is better -- confirm against training/validator.h).
      */
     MNISTAccuracyValidator(Ptr<Options> options)
         : Validator(std::vector<Ptr<Vocab>>(), options, false) {
       createBatchGenerator(/*isTranslating=*/false);
       builder_ = models::createModelFromOptions(options, models::usage::translation);
     }

     virtual ~MNISTAccuracyValidator() {}

     /// Not supported for the MNIST example; only logs a warning.
     void keepBest(const std::vector<Ptr<ExpressionGraph>>& /*graphs*/) override {
       LOG(warn, "Keeping best model for MNIST examples is not supported");
     }

     /// Name of the metric reported by this validator.
     std::string type() override { return "accuracy"; }

   protected:
     /**
      * @brief Computes accuracy over the whole validation set.
      *
      * Only graphs[0] is used. Returns 0 (with a warning) instead of NaN when the
      * validation set yields no samples.
      */
     float validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) override {
       ABORT_IF(graphs.empty(), "MNISTAccuracyValidator requires at least one expression graph");

       float correct = 0;
       size_t samples = 0;

       for(auto batch : *batchGenerator_) {
         auto logits = builder_->build(graphs[0], batch, true).getLogits();
         graphs[0]->forward();

         // Raw scores are sufficient: arg-max is unchanged by a softmax.
         std::vector<float> scores;
         logits->val()->get(scores);

         correct += countCorrect(scores, batch->labels());
         samples += batch->size();
       }

       // An empty validation set would otherwise produce 0/0 = NaN.
       if(samples == 0) {
         LOG(warn, "MNIST validation set produced no samples; reporting accuracy 0");
         return 0.f;
       }
       return correct / float(samples);
     }

   private:
     /**
      * @brief Counts rows whose arg-max score equals the gold label.
      *
      * @param probs  Scores laid out row by row: labels.size() consecutive rows of
      *               probs.size() / labels.size() class scores each.
      * @param labels One gold class index (stored as float) per row.
      * @return Number of correctly classified rows (as float).
      *
      * Aborts if probs is not a positive exact multiple of labels in size; the
      * previous floor division silently read past the end of both vectors.
      */
     float countCorrect(const std::vector<float>& probs, const std::vector<float>& labels) const {
       if(labels.empty())
         return 0.f;  // no rows to score; also avoids division by zero below

       const size_t numLabels = probs.size() / labels.size();
       ABORT_IF(numLabels == 0 || numLabels * labels.size() != probs.size(),
                "Score vector size {} is not a positive multiple of the number of labels {}",
                probs.size(),
                labels.size());

       float numCorrect = 0;
       for(size_t row = 0; row < labels.size(); ++row) {
         auto first = probs.begin() + row * numLabels;
         auto last = first + numLabels;
         // std::max_element returns the first maximum, so ties go to the lowest class index.
         auto pred = std::distance(first, std::max_element(first, last));
         if(static_cast<float>(pred) == labels[row])
           ++numCorrect;
       }
       return numCorrect;
     }
   };

   }  // namespace marian