Program Listing for File mnist_ffnn.cpp

Return to documentation for file (src/examples/mnist/mnist_ffnn.cpp)

#include <algorithm>
#include <chrono>
#include <cstdio>
#include <iomanip>
#include <string>

#include "marian.h"

#include "examples/mnist/model.h"
#include "examples/mnist/training.h"
#include "training/graph_group_async.h"
#include "training/graph_group_singleton.h"
#include "training/graph_group_sync.h"

// Default MNIST training data: {images file, labels file}. main() installs
// these as "train-sets" only when the user did not supply any.
// NOTE(review): the paths are relative to the current working directory and
// appear to assume the binary runs from a directory one level below the
// repository root (e.g. a build dir) — confirm.
const std::vector<std::string> TRAIN_SET{
    "../src/examples/mnist/train-images-idx3-ubyte",
    "../src/examples/mnist/train-labels-idx1-ubyte",
};

// Default MNIST validation data: {images file, labels file}. main() installs
// these as "valid-sets" only when the user did not supply any. Same
// working-directory assumption as TRAIN_SET.
const std::vector<std::string> VALID_SET{
    "../src/examples/mnist/t10k-images-idx3-ubyte",
    "../src/examples/mnist/t10k-labels-idx1-ubyte",
};

using namespace marian;

int main(int argc, char** argv) {
  auto options = parseOptions(argc, argv, cli::mode::training, false);

  if(!options->hasAndNotEmpty("train-sets"))
    options->set("train-sets", TRAIN_SET);
  if(!options->hasAndNotEmpty("valid-sets"))
    options->set("valid-sets", VALID_SET);

  if(options->get<std::string>("type") != "mnist-lenet")
    options->set("type", "mnist-ffnn");

  auto devices = Config::getDevices(options);

  if(devices.size() == 1)
    New<TrainMNIST<SingletonGraph>>(options)->run();
  else if(options->get<bool>("sync-sgd"))
    New<TrainMNIST<SyncGraphGroup>>(options)->run();
  else
    New<TrainMNIST<AsyncGraphGroup>>(options)->run();

  return 0;
}