.. _program_listing_file_src_examples_mnist_model.h:

Program Listing for File model.h
================================

|exhale_lsh| :ref:`Return to documentation for file <file_src_examples_mnist_model.h>` (``src/examples/mnist/model.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include <memory>
   #include <string>
   #include <vector>

   #include "common/definitions.h"
   #include "graph/expression_graph.h"
   #include "models/costs.h"
   #include "models/model_base.h"
   #include "layers/loss.h"

   #include "examples/mnist/dataset.h"

   namespace marian {
   namespace models {

   // @TODO: looking at this file, simplify the new RationalLoss idea. Here it gets too complicated
   class MNISTCrossEntropyCost : public ICost {
   public:
     MNISTCrossEntropyCost() {}

     Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
                                  Ptr<ExpressionGraph> graph,
                                  Ptr<data::Batch> batch,
                                  bool clearGraph = true) override {
       auto top = model->build(graph, batch, clearGraph).getLogits();

       auto vfLabels = std::static_pointer_cast<data::DataBatch>(batch)->labels();

       // convert float to IndexType
       std::vector<IndexType> vLabels(vfLabels.begin(), vfLabels.end());
       auto labels = graph->indices(vLabels);

       // Define a top-level node for training
       // use CE loss
       auto loss = sum(cross_entropy(top, labels), /*axis =*/ 0);
       auto multiLoss = New<SumMultiRationalLoss>();
       multiLoss->push_back({loss, (float)vLabels.size()});
       return multiLoss;
     }
   };

   class MNISTLogsoftmax : public ILogProb {
   public:
     MNISTLogsoftmax() {}
     virtual ~MNISTLogsoftmax(){}

     Logits apply(Ptr<IModel> model,
                  Ptr<ExpressionGraph> graph,
                  Ptr<data::Batch> batch,
                  bool clearGraph = true) override {
       auto top = model->build(graph, batch, clearGraph);
       return top.applyUnaryFunction(logsoftmax);
     }
   };

   class MnistFeedForwardNet : public IModel {
   public:
     typedef data::MNISTData dataset_type;

     template <class... Args>
     MnistFeedForwardNet(Ptr<Options> options, Args... /*args*/)
         : options_(options), inference_(options->get<bool>("inference", false)) {}

     virtual ~MnistFeedForwardNet(){}

     virtual Logits build(Ptr<ExpressionGraph> graph,
                          Ptr<data::Batch> batch,
                          bool /*clean*/ = false) override {
       return Logits(apply(graph, batch, inference_));
     }

     void load(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/, bool) override {
       LOG(critical, "Loading MNIST model is not supported");
     }

     void save(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/, bool) override {
       LOG(critical, "Saving MNIST model is not supported");
     }

     void save(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/) {
       LOG(critical, "Saving MNIST model is not supported");
     }

     Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> /*graph*/, size_t /*multiplier*/) {
       LOG(critical, "Collecting stats in MNIST model is not supported");
       return nullptr;
     }

     virtual void clear(Ptr<ExpressionGraph> graph) override { graph->clear(); };

   protected:
     Ptr<Options> options_;
     const bool inference_{false};

     virtual Expr apply(Ptr<ExpressionGraph> g, Ptr<data::Batch> batch, bool /*inference*/ = false) {
       const std::vector<int> dims = {784, 2048, 2048, 10};

       // Start with an empty expression graph
       clear(g);

       // Create an input layer of shape batchSize x numFeatures and populate it
       // with training features
       auto features = std::static_pointer_cast<data::DataBatch>(batch)->features();
       auto x = g->constant({(int)batch->size(), dims[0]},
                            inits::fromVector(features));

       // Construct hidden layers
       std::vector<Expr> layers, weights, biases;

       for(size_t i = 0; i < dims.size() - 1; ++i) {
         int in = dims[i];
         int out = dims[i + 1];

         if(i == 0) {
           // Create a dropout node as the parent of x,
           // and place that dropout node as the value of layers[0]
           layers.emplace_back(dropout(x, 0.2));
         } else {
           // Multiply the matrix in layers[i-1] by the matrix in weights[i-1]
           // Take the result, and perform matrix addition on biases[i-1].
           // Wrap the result in rectified linear activation function,
           // and finally wrap that in a dropout node
           layers.emplace_back(dropout(
               relu(affine(layers.back(), weights.back(), biases.back())),
               0.2));
         }

         // Construct a weight node for the outgoing connections from layer i
         weights.emplace_back(
             g->param("W" + std::to_string(i), {in, out}, inits::glorotUniform()));

         // Construct a bias node. These weights are initialized to zero
         biases.emplace_back(
             g->param("b" + std::to_string(i), {1, out}, inits::zeros()));
       }

       // Perform matrix multiplication and addition for the last layer
       auto last = affine(layers.back(), weights.back(), biases.back());

       return last;
     }
   };
   }  // namespace models
   }  // namespace marian
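
The snippet below is not part of ``model.h``; it is a minimal sketch of how ``MnistFeedForwardNet`` might be driven for a single forward pass, assuming marian's ``ExpressionGraph`` and ``Options`` APIs. The MNIST data pipeline (``data::MNISTData`` plus a batch generator) is omitted and stands in as the hypothetical ``mnistBatch`` argument.

.. code-block:: cpp

   // Usage sketch (not part of model.h): build the 784 -> 2048 -> 2048 -> 10
   // network on a CPU graph and evaluate it for one batch.
   #include "examples/mnist/model.h"

   using namespace marian;

   Logits forwardPass(Ptr<data::Batch> mnistBatch) {
     // Set up a CPU expression graph with a small workspace.
     auto graph = New<ExpressionGraph>(/*inference=*/true);
     graph->setDevice({0, DeviceType::cpu});
     graph->reserveWorkspaceMB(128);

     // The model only reads the "inference" flag from its options.
     auto options = New<Options>();
     options->set("inference", true);

     // build() constructs the network nodes; forward() evaluates them.
     auto mnist = New<models::MnistFeedForwardNet>(options);
     auto logits = mnist->build(graph, mnistBatch);
     graph->forward();
     return logits;
   }

The two wrapper classes in the listing plug the same model into marian's cost framework: ``MNISTCrossEntropyCost`` attaches a summed cross-entropy training loss to the logits, while ``MNISTLogsoftmax`` converts them into log-probabilities for scoring.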