// Program listing for file node_initializers.h
// Source path: src/graph/node_initializers.h
// TODO: move to backend, into graph/
#pragma once
#include "common/config.h"
#include "tensors/tensor.h"
#include "tensors/tensor_operators.h"
#include <functional>
#include <random>
namespace marian {
class ExpressionGraph; // Forward declaration
namespace inits {
class NodeInitializer {
protected:
Weak<Allocator> allocator_;
public:
virtual void apply(Tensor t) = 0;
void setAllocator(Ptr<Allocator> allocator) { allocator_ = allocator; }
virtual ~NodeInitializer() {}
};
// ---------------------------------------------------------------------------
// Factory functions. Each returns a fresh initializer; the actual filling
// logic is implemented in the corresponding .cpp file.
// ---------------------------------------------------------------------------

/// Placeholder initializer (its apply() behavior is defined in the .cpp).
Ptr<NodeInitializer> dummy();

/// Wrap an arbitrary callable as an initializer; `func` presumably receives
/// the target tensor when apply() runs.
Ptr<NodeInitializer> fromLambda(std::function<void(Tensor)>&& func);

/// Like fromLambda(func), but `func` presumably fills an intermediate tensor of
/// element type `intermediateType`, which is then converted into the target
/// tensor's type — TODO confirm in the .cpp.
Ptr<NodeInitializer> fromLambda(std::function<void(Tensor)>&& func, Type intermediateType);

/// Fill the tensor with the constant `value` (presumably every element).
Ptr<NodeInitializer> fromValue(float value);

// zeros()/ones() are defined in this header, so they must be `inline`, not
// `static`: a static definition gives every including translation unit its
// own private copy (triggering -Wunused-function where unused) and makes any
// use from an inline function in another header an ODR violation.

/// Shorthand for fromValue(0.0f).
inline Ptr<NodeInitializer> zeros() { return fromValue(0.0f); }

/// Shorthand for fromValue(1.0f).
inline Ptr<NodeInitializer> ones() { return fromValue(1.0f); }

/// Identity-style initializer scaled by `value` (presumably `value` on the
/// main diagonal and 0 elsewhere).
Ptr<NodeInitializer> eye(float value = 1.f);

/// Random values from a normal distribution with the given `mean` and `stddev`.
Ptr<NodeInitializer> normal(float mean = 0.f, float stddev = 1.f);

/// Random values from a uniform distribution between `a` and `b`
/// (interval endpoints' inclusivity not visible here).
Ptr<NodeInitializer> uniform(float a = 0.f, float b = 1.f);

/// Bernoulli(`p`) samples, presumably mapped as sample * `scale` + `shift`.
Ptr<NodeInitializer> bernoulli(float p, float scale = 1.f, float shift = 0.f);

/// Glorot/Xavier-style uniform initialization multiplied by `scale`; the
/// `fanIn`/`fanOut` flags presumably select fan-in-only or fan-out-only
/// variants — TODO confirm in the .cpp.
Ptr<NodeInitializer> glorotUniform(bool fanIn = false, bool fanOut = false, float scale = 1.f);

/// Normal-distribution counterpart of glorotUniform() (same parameters).
Ptr<NodeInitializer> glorotNormal(bool fanIn = false, bool fanOut = false, float scale = 1.f);

/// Dropout mask for the given drop probability.
Ptr<NodeInitializer> dropout(float dropoutProbability);

/// Gumbel noise; `eps` presumably guards the log() calls against zero.
Ptr<NodeInitializer> gumbel(float eps = 1e-5f);
/// Copy the contents of `v` into the tensor (presumably element-wise; size
/// agreement with the tensor is checked in the .cpp, if at all — confirm).
template <typename T>
Ptr<NodeInitializer> fromVector(const std::vector<T>& v);

/// Same as the const& overload, but takes ownership of `v` to avoid a copy.
template <typename T>
Ptr<NodeInitializer> fromVector(std::vector<T>&& v);

/// Initialize from a sparse (indices, values) pair.
/// NOTE(review): `v` is a non-const lvalue reference, so temporaries cannot be
/// passed; switching to const& must be done together with the .cpp definition.
Ptr<NodeInitializer> fromSparseVector(std::pair<std::vector<size_t>, std::vector<float>>& v);

/// Initialize from a serialized model item.
Ptr<NodeInitializer> fromItem(const io::Item& item);

/// Initialize by copying from an existing tensor.
Ptr<NodeInitializer> fromTensor(Tensor tensor);

/// Load embeddings from the word2vec file `file` for a vocabulary of `dimVoc`
/// entries with `dimEmb` dimensions each; `normalize` presumably rescales each
/// vector to unit length — TODO confirm.
Ptr<NodeInitializer> fromWord2vec(const std::string& file,
                                  int dimVoc,
                                  int dimEmb,
                                  bool normalize = false);

/// Sinusoidal position embeddings, presumably with positions counted from `start`.
Ptr<NodeInitializer> sinusoidalPositionEmbeddings(int start);

/// Arithmetic sequence begin, begin + step, ... bounded by `end`
/// (whether `end` is included is defined in the .cpp).
template <typename T>
Ptr<NodeInitializer> range(T begin, T end, T step = static_cast<T>(1));
} // namespace inits
} // namespace marian