Program Listing for File node_initializers.h

Return to documentation for file (src/graph/node_initializers.h)

// TODO: move to backend, into graph/
#pragma once

#include "common/config.h"
#include "tensors/tensor.h"
#include "tensors/tensor_operators.h"

#include <functional>
#include <random>

namespace marian {

class ExpressionGraph; // Forward declaration
namespace inits {

// Abstract interface for objects that fill a tensor with initial values.
// Subclasses provide apply(); the free functions declared further down in this
// header return such objects wrapped in Ptr<NodeInitializer>.
class NodeInitializer {
public:
  // Performs the initialization on tensor `t`; must be overridden.
  virtual void apply(Tensor t) = 0;

  // Virtual so that derived initializers are destroyed correctly via a base handle.
  virtual ~NodeInitializer() = default;

  // Stores `allocator` in allocator_ so that derived classes can reach it.
  void setAllocator(Ptr<Allocator> allocator) {
    allocator_ = allocator;
  }

protected:
  // Allocator handle assigned through setAllocator().
  // NOTE(review): Weak<> is presumably a non-owning (weak-pointer) alias — confirm.
  Weak<Allocator> allocator_;
};

// Factory for a placeholder initializer; takes no arguments.
// NOTE(review): the name suggests it leaves the tensor untouched — inferred from
// the name only, confirm against the definition.
Ptr<NodeInitializer> dummy();

// Wraps a caller-supplied callable that receives a Tensor in a NodeInitializer.
// `func` is taken as an rvalue std::function: pass a lambda/temporary, or
// std::move an existing std::function object at the call site.
Ptr<NodeInitializer> fromLambda(std::function<void(Tensor)>&& func);

// Overload that additionally receives an element type.
// NOTE(review): presumably `func` is run on a temporary tensor of `intermediateType`
// whose contents are then converted into the target tensor — confirm against the definition.
Ptr<NodeInitializer> fromLambda(std::function<void(Tensor)>&& func, Type intermediateType);

// Factory parameterized by a single float; zeros() and ones() in this header
// forward 0.0f and 1.0f to it.
// NOTE(review): presumably sets every tensor element to `value` — confirm against the definition.
Ptr<NodeInitializer> fromValue(float value);

// Convenience wrapper: returns fromValue(0.0f).
// Declared `inline` instead of `static`: a non-inline static function defined in a
// header gives every including translation unit its own private copy and triggers
// unused-function warnings wherever it is not called.
inline Ptr<NodeInitializer> zeros() { return fromValue(0.0f); }

// Convenience wrapper: returns fromValue(1.0f).
// Declared `inline` instead of `static`: a non-inline static function defined in a
// header gives every including translation unit its own private copy and triggers
// unused-function warnings wherever it is not called.
inline Ptr<NodeInitializer> ones() { return fromValue(1.0f); }

// Factory parameterized by `value` (defaults to 1).
// NOTE(review): the name suggests an identity-like matrix with `value` on the
// diagonal — inferred from the name only, confirm against the definition.
Ptr<NodeInitializer> eye(float value = 1.f);

// Factory parameterized by `mean` (default 0) and `stddev` (default 1).
// NOTE(review): presumably fills the tensor with normally distributed random
// values — inferred from the name and parameters, confirm against the definition.
Ptr<NodeInitializer> normal(float mean = 0.f, float stddev = 1.f);

// Factory parameterized by bounds `a` (default 0) and `b` (default 1).
// NOTE(review): presumably fills the tensor with uniformly distributed random
// values between a and b; whether the bounds are inclusive is not visible here — confirm.
Ptr<NodeInitializer> uniform(float a = 0.f, float b = 1.f);

// Factory parameterized by probability `p` (required), `scale` (default 1) and
// `shift` (default 0).
// NOTE(review): presumably draws Bernoulli(p) samples and applies scale/shift to
// them; the exact formula is not visible here — confirm against the definition.
Ptr<NodeInitializer> bernoulli(float p, float scale = 1.f, float shift = 0.f);

// Factory with flags `fanIn`/`fanOut` (both default false) and `scale` (default 1).
// NOTE(review): presumably Glorot/Xavier-style uniform initialization where the
// flags select which fan dimension(s) determine the range — confirm against the definition.
Ptr<NodeInitializer> glorotUniform(bool fanIn = false, bool fanOut = false, float scale = 1.f);

// Factory with flags `fanIn`/`fanOut` (both default false) and `scale` (default 1);
// same parameter list as glorotUniform().
// NOTE(review): presumably the normal-distribution variant of Glorot/Xavier
// initialization — confirm against the definition.
Ptr<NodeInitializer> glorotNormal(bool fanIn = false, bool fanOut = false, float scale = 1.f);

// Factory parameterized by a dropout probability (no default).
// NOTE(review): presumably produces a dropout mask; whether the mask is rescaled
// is not visible here — confirm against the definition.
Ptr<NodeInitializer> dropout(float dropoutProbability);

// Factory parameterized by `eps` (default 1e-5).
// NOTE(review): presumably fills the tensor with Gumbel noise, with `eps` guarding
// the logarithms numerically — inferred from the name, confirm against the definition.
Ptr<NodeInitializer> gumbel(float eps = 1e-5f);

// Builds an initializer from a std::vector<T>, taken by const reference.
// Only declared here: the template definition (and the set of T it is
// instantiated for) is not visible in this header.
template <typename T>
Ptr<NodeInitializer> fromVector(const std::vector<T>& v);

// Rvalue overload: accepts a temporary or std::move'd vector.
// NOTE(review): presumably lets the initializer take over the buffer instead of
// copying it — confirm against the definition.
template <typename T>
Ptr<NodeInitializer> fromVector(std::vector<T>&& v);

// Builds an initializer from a pair of vectors (size_t entries, float entries).
// The pair is taken by non-const lvalue reference, so temporaries cannot be passed.
// NOTE(review): presumably `first` holds element indices and `second` the matching
// values — confirm against the definition.
Ptr<NodeInitializer> fromSparseVector(std::pair<std::vector<size_t>, std::vector<float>>& v);

// Builds an initializer from an io::Item passed by const reference.
// NOTE(review): presumably copies the item's stored data into the tensor; the
// lifetime requirements on `item` are not visible here — confirm against the definition.
Ptr<NodeInitializer> fromItem(const io::Item& item);

// Builds an initializer from an existing Tensor.
// NOTE(review): presumably copies the contents of `tensor` into the target
// tensor — confirm against the definition.
Ptr<NodeInitializer> fromTensor(Tensor tensor);

// Builds an initializer from a file path plus two integer dimensions and an
// optional `normalize` flag (default false).
// NOTE(review): presumably loads word2vec-format embeddings with `dimVoc` rows and
// `dimEmb` columns, optionally normalizing them — confirm against the definition.
Ptr<NodeInitializer> fromWord2vec(const std::string& file,
                                  int dimVoc,
                                  int dimEmb,
                                  bool normalize = false);

// Factory parameterized by an integer `start`.
// NOTE(review): presumably fills the tensor with sinusoidal position embeddings,
// with `start` as the first position index — confirm against the definition.
Ptr<NodeInitializer> sinusoidalPositionEmbeddings(int start);

// Factory template parameterized by `begin`, `end` and `step` (default: 1 converted to T).
// Only declared here; the definition is not visible in this header.
// NOTE(review): presumably fills the tensor with begin, begin+step, ...; whether
// `end` is exclusive is not visible here — confirm against the definition.
template <typename T>
Ptr<NodeInitializer> range(T begin, T end, T step = (T)1);

}  // namespace inits

}  // namespace marian