Class MnistFeedForwardNet

Inheritance Relationships

Base Type

  • public IModel

Derived Type

  • public marian::models::MnistLeNet

Class Documentation

class MnistFeedForwardNet : public IModel

Subclassed by marian::models::MnistLeNet

Public Types

typedef data::MNISTData dataset_type
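The `dataset_type` typedef ties the model to the data source it expects (`data::MNISTData`). Generic code can then name that type through the model instead of hard-coding it. Below is a minimal sketch of this idiom. It uses hypothetical stand-ins (`FakeMNISTData`, `FakeModel`, `makeDataset`), not Marian's actual data classes or loading code, and `std::shared_ptr` in place of Marian's `Ptr`:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for a dataset class such as data::MNISTData.
class FakeMNISTData {
public:
  explicit FakeMNISTData(std::string path) : path_(std::move(path)) {
    assert(!path_.empty());  // a data source needs a location to read from
  }
  const std::string& path() const { return path_; }

private:
  std::string path_;
};

// Hypothetical stand-in for a model that, like MnistFeedForwardNet,
// publishes the dataset it expects through a nested typedef.
struct FakeModel {
  typedef FakeMNISTData dataset_type;
};

// Generic helper: constructs whichever dataset type the model declares,
// so the caller never names the dataset class directly.
template <class Model>
std::shared_ptr<typename Model::dataset_type> makeDataset(const std::string& path) {
  return std::make_shared<typename Model::dataset_type>(path);
}

// Compile-time check that the typedef resolves to the expected class.
static_assert(std::is_same<FakeModel::dataset_type, FakeMNISTData>::value,
              "FakeModel::dataset_type should be FakeMNISTData");
```

Swapping the model type passed to `makeDataset` is then enough to change which dataset class gets constructed.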

Public Functions

template<class ...Args>
MnistFeedForwardNet(Ptr<Options> options, Args...)
virtual ~MnistFeedForwardNet()
virtual Logits build(Ptr<ExpressionGraph> graph, Ptr<data::Batch> batch, bool = false)
void load(Ptr<ExpressionGraph>, const std::string&, bool)
void save(Ptr<ExpressionGraph>, const std::string&, bool)
void save(Ptr<ExpressionGraph>, const std::string&)
Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph>, size_t)
virtual void clear(Ptr<ExpressionGraph> graph)
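The templated constructor takes a `Ptr<Options>` followed by any number of extra arguments. Because the parameter pack is unnamed, those extra arguments cannot be used inside the constructor. One plausible reason for such a signature is to let a generic factory pass the same argument list to every model type, whether or not a given model needs the extras. Below is a minimal sketch of that pattern. It uses hypothetical stand-ins (`FakeOptions`, `SketchNet`, `createModel`) and `std::shared_ptr` in place of `Ptr`; it is not Marian's factory code:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for Options: holds a single setting.
struct FakeOptions {
  std::string modelType;
};

// Mirrors the constructor shape: the options pointer is stored, and any
// further arguments are accepted by value and then discarded.
class SketchNet {
public:
  template <class... Args>
  explicit SketchNet(std::shared_ptr<FakeOptions> options, Args...)
      : options_(std::move(options)) {
    assert(options_ != nullptr);  // the options pointer must be valid
  }

  const std::string& modelType() const { return options_->modelType; }

private:
  std::shared_ptr<FakeOptions> options_;
};

// Hypothetical generic factory: forwards whatever extra arguments it is
// given, and each model type decides whether to use or ignore them.
template <class Model, class... Args>
std::unique_ptr<Model> createModel(std::shared_ptr<FakeOptions> options,
                                   Args&&... args) {
  return std::unique_ptr<Model>(
      new Model(std::move(options), std::forward<Args>(args)...));
}
```

Calling `createModel<SketchNet>(opts, 42, std::string("extra"))` compiles and constructs the model, with `42` and `"extra"` silently ignored.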

Protected Functions

virtual Expr apply(Ptr<ExpressionGraph> g, Ptr<data::Batch> batch, bool = false)

Builds an expression graph representing a feed-forward classifier.

Return

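the expression (Expr) for the classifier's output node, created in the graph g

Parameters
  • g: the expression graph in which the classifier's nodes are created

  • batch: a batch of training or testing examples

  • (unnamed bool, default false): whether to create the classifier for training or for inference only

The number of nodes in each layer of the feed-forward classifier (dims) is not a parameter of this signature; it is determined inside the model.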
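To illustrate what a feed-forward classifier computes, here is a minimal sketch that evaluates one example eagerly with plain `std::vector` arithmetic. By contrast, the method described above builds the equivalent operations as nodes in an expression graph. Several choices here are assumptions, not taken from Marian: the ReLU activation on hidden layers, the constant weight initialization, and the names `DenseLayer`, `makeLayers`, `forward` and `argmax`. The last layer has no activation, so its outputs are unnormalized scores (logits):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// One fully connected layer: out = W * in + b, with W stored row-major
// as outDim rows of inDim columns.
struct DenseLayer {
  std::size_t inDim, outDim;
  std::vector<float> W;  // outDim * inDim weights
  std::vector<float> b;  // outDim biases
};

// Creates one layer per consecutive pair in dims (e.g. {784, 256, 10}).
// Every weight is set to w and every bias to bias; real training would
// initialize randomly, but constants keep the sketch deterministic.
std::vector<DenseLayer> makeLayers(const std::vector<std::size_t>& dims,
                                   float w, float bias) {
  assert(dims.size() >= 2);
  std::vector<DenseLayer> layers;
  for (std::size_t i = 0; i + 1 < dims.size(); ++i) {
    layers.push_back({dims[i], dims[i + 1],
                      std::vector<float>(dims[i] * dims[i + 1], w),
                      std::vector<float>(dims[i + 1], bias)});
  }
  return layers;
}

// Forward pass for a single example: ReLU after each hidden layer,
// no activation after the last layer, so the result is logits.
std::vector<float> forward(const std::vector<DenseLayer>& layers,
                           std::vector<float> x) {
  for (std::size_t l = 0; l < layers.size(); ++l) {
    const DenseLayer& layer = layers[l];
    assert(x.size() == layer.inDim);
    std::vector<float> y(layer.outDim);
    for (std::size_t r = 0; r < layer.outDim; ++r) {
      float acc = layer.b[r];
      for (std::size_t c = 0; c < layer.inDim; ++c)
        acc += layer.W[r * layer.inDim + c] * x[c];
      const bool hidden = (l + 1 < layers.size());
      y[r] = hidden ? std::max(acc, 0.0f) : acc;
    }
    x = std::move(y);
  }
  return x;
}

// Index of the largest logit, i.e. the predicted class.
std::size_t argmax(const std::vector<float>& v) {
  assert(!v.empty());
  return static_cast<std::size_t>(
      std::max_element(v.begin(), v.end()) - v.begin());
}
```

For MNIST, the input size would be 784 (28 × 28 pixels) and the output size 10 (one logit per digit). `argmax` then turns the logits into a predicted label.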

Protected Attributes

Ptr<Options> options_
const bool inference_ = {false}