// Program listing for file model.h (src/examples/mnist/model.h)
#pragma once
#include <iomanip>
#include <iostream>
#include <memory>
#include "common/definitions.h"
#include "graph/expression_graph.h"
#include "models/costs.h"
#include "models/model_base.h"
#include "layers/loss.h"
#include "examples/mnist/dataset.h"
namespace marian {
namespace models {
// @TODO: looking at this file, simplify the new RationalLoss idea. Here it gets too complicated
// Cross-entropy training cost for the MNIST example model.
// Builds the model, compares its output logits against the batch labels
// and returns the summed cross-entropy wrapped in a MultiRationalLoss so
// the generic marian training loop can consume it.
class MNISTCrossEntropyCost : public ICost {
public:
  MNISTCrossEntropyCost() {}

  // Virtual destructor: instances are held and deleted polymorphically
  // through the ICost interface (matches sibling MNISTLogsoftmax).
  virtual ~MNISTCrossEntropyCost() {}

  // Builds the model on `graph` for `batch` (optionally clearing the graph
  // first) and returns the batch-summed CE loss paired with the label count,
  // which the loss framework uses for normalization.
  Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
                               Ptr<ExpressionGraph> graph,
                               Ptr<data::Batch> batch,
                               bool clearGraph = true) override {
    auto top = model->build(graph, batch, clearGraph).getLogits();

    auto vfLabels = std::static_pointer_cast<data::DataBatch>(batch)->labels();
    // Labels arrive as floats; cross_entropy expects integer class indices.
    std::vector<IndexType> vLabels(vfLabels.begin(), vfLabels.end());
    auto labels = graph->indices(vLabels);

    // Top-level training node: per-example CE summed over the batch (axis 0).
    auto loss = sum(cross_entropy(top, labels), /*axis =*/ 0);

    auto multiLoss = New<SumMultiRationalLoss>();
    multiLoss->push_back({loss, (float)vLabels.size()});
    return multiLoss;
  }
};
// Scoring wrapper for the MNIST example model: produces normalized
// log-probabilities by applying log-softmax to the model's raw logits.
class MNISTLogsoftmax : public ILogProb {
public:
  MNISTLogsoftmax() = default;
  virtual ~MNISTLogsoftmax() = default;

  // Builds the model on the given graph/batch and normalizes its output.
  Logits apply(Ptr<IModel> model,
               Ptr<ExpressionGraph> graph,
               Ptr<data::Batch> batch,
               bool clearGraph = true) override {
    auto logits = model->build(graph, batch, clearGraph);
    return logits.applyUnaryFunction(logsoftmax);
  }
};
// Fully-connected feed-forward classifier for MNIST
// (784 -> 2048 -> 2048 -> 10 with ReLU + dropout), exposed via the
// generic IModel interface so the example training code can drive it.
class MnistFeedForwardNet : public IModel {
public:
// Dataset type consumed by this model (used by the example batching code).
typedef data::MNISTData dataset_type;

// Extra constructor arguments are accepted for interface compatibility
// with other model factories and ignored.
template <class... Args>
MnistFeedForwardNet(Ptr<Options> options, Args... /*args*/)
: options_(options), inference_(options->get<bool>("inference", false)) {}
virtual ~MnistFeedForwardNet(){}
// Builds the network on `graph` for `batch` and returns the raw
// (unnormalized) output logits. The `clean` flag is ignored here;
// apply() unconditionally clears the graph itself.
virtual Logits build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool /*clean*/ = false) override {
return Logits(apply(graph, batch, inference_));
}
// Loading a serialized MNIST model is intentionally unsupported;
// logs at critical level instead of performing any I/O.
void load(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/, bool) override {
LOG(critical, "Loading MNIST model is not supported");
}
// Saving is likewise unsupported (both overloads).
void save(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/, bool) override {
LOG(critical, "Saving MNIST model is not supported");
}
void save(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/) {
LOG(critical, "Saving MNIST model is not supported");
}
// Batch statistics collection is unsupported; returns nullptr after logging.
Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> /*graph*/,
size_t /*multiplier*/) {
LOG(critical, "Collecting stats in MNIST model is not supported");
return nullptr;
}
// Resets the expression graph before a fresh build.
virtual void clear(Ptr<ExpressionGraph> graph) override { graph->clear(); };
protected:
Ptr<Options> options_;
// Set from the "inference" option at construction time.
const bool inference_{false};
// Constructs the forward graph and returns the final affine layer's output
// (raw logits, no softmax). Layer sizes are fixed to {784, 2048, 2048, 10}.
// NOTE(review): the `inference` parameter is ignored, so the dropout nodes
// below are inserted even when inference_ is true — confirm whether marian's
// dropout op disables itself in inference mode or whether this is a bug.
virtual Expr apply(Ptr<ExpressionGraph> g,
Ptr<data::Batch> batch,
bool /*inference*/ = false) {
const std::vector<int> dims = {784, 2048, 2048, 10};
// Start with an empty expression graph
clear(g);
// Create an input layer of shape batchSize x numFeatures and populate it
// with training features
auto features
= std::static_pointer_cast<data::DataBatch>(batch)->features();
auto x = g->constant({(int)batch->size(), dims[0]},
inits::fromVector(features));
// Construct hidden layers.
// Note the ordering inside the loop: layer i is pushed BEFORE weights[i]
// and biases[i] are created, so `weights.back()`/`biases.back()` used when
// building layer i refer to the parameters created in iteration i-1.
std::vector<Expr> layers, weights, biases;
for(size_t i = 0; i < dims.size() - 1; ++i) {
int in = dims[i];
int out = dims[i + 1];
if(i == 0) {
// Create a dropout node as the parent of x,
// and place that dropout node as the value of layers[0]
// (dropout rate 0.2 is hard-coded).
layers.emplace_back(dropout(x, 0.2));
} else {
// Multiply the matrix in layers[i-1] by the matrix in weights[i-1]
// Take the result, and perform matrix addition on biases[i-1].
// Wrap the result in rectified linear activation function,
// and finally wrap that in a dropout node
layers.emplace_back(dropout(
relu(affine(layers.back(), weights.back(), biases.back())), 0.2));
}
// Construct a weight node for the outgoing connections from layer i
weights.emplace_back(
g->param("W" + std::to_string(i), {in, out}, inits::glorotUniform()));
// Construct a bias node. These weights are initialized to zero
biases.emplace_back(
g->param("b" + std::to_string(i), {1, out}, inits::zeros()));
}
// Perform matrix multiplication and addition for the last layer
// (no activation/dropout: the raw logits are returned).
auto last = affine(layers.back(), weights.back(), biases.back());
return last;
}
};
} // namespace models
} // namespace marian