// Program listing for file expression_operators.h
// ↰ Return to documentation for file (src/graph/expression_operators.h)
#pragma once
#include "graph/expression_graph.h"
#include "graph/node_initializers.h"
namespace marian {
Expr debug(Expr a, const std::string& message = "");
Expr checkpoint(Expr a);
typedef Expr(ActivationFunction)(Expr);
typedef std::function<void(Expr out, const std::vector<Expr>& in)> LambdaNodeFunctor;
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, size_t hash = 0);
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash = 0);
typedef std::function<void(Expr)> LambdaNodeCallback;
Expr callback(Expr node, LambdaNodeCallback call);
// ---------------------------------------------------------------------------
// Activation functions.
// Most come as a pair of overloads: one taking a single Expr and one taking a
// std::vector<Expr>. How the list overloads combine their inputs is defined in
// the implementation file (NOTE(review): presumably the inputs are summed
// first, cf. plus() -- confirm).
// ---------------------------------------------------------------------------
Expr plus(const std::vector<Expr>& nodes);
Expr sigmoid(Expr a);
Expr sigmoid(const std::vector<Expr>& nodes);
Expr swish(Expr a);
Expr swish(const std::vector<Expr>& nodes);
Expr gelu(Expr a);
Expr gelu(const std::vector<Expr>& nodes);

// tanh has no single-Expr overload: the variadic template packs all of its
// arguments into a std::vector<Expr> and forwards to the list overload, so
// both tanh(x) and tanh(x, y) end up there.
Expr tanh(const std::vector<Expr>& nodes);
template <typename... Args>
Expr tanh(Args... args) {
  std::vector<Expr> nodes{args...};
  return tanh(nodes);
}

Expr relu(Expr a);
Expr relu(const std::vector<Expr>& nodes);
Expr leakyrelu(Expr a);
Expr leakyrelu(const std::vector<Expr>& nodes);
// `alpha` defaults to 0.01. It is written as a float literal to match the
// parameter type and the other float defaults in this header (0.f, 1.f),
// instead of a double that is implicitly narrowed.
Expr prelu(Expr a, float alpha = 0.01f);
Expr prelu(const std::vector<Expr>& nodes, float alpha = 0.01f);
// Exponential and logarithm (single Expr in, single Expr out).
Expr log(Expr a);
Expr exp(Expr a);

// Trigonometric functions (single Expr in, single Expr out).
Expr sin(Expr a);
Expr cos(Expr a);
Expr tan(Expr a);

// Unary minus, enabling the expression `-a` on an Expr.
Expr operator-(Expr a);
/*********************************************************/
// Binary arithmetic. Each of + - * / has three overloads: Expr op Expr, and
// mixed forms with a float scalar on the left or on the right.
Expr operator+(Expr a, Expr b);
Expr operator+(float a, Expr b);
Expr operator+(Expr a, float b);

Expr operator-(Expr a, Expr b);
Expr operator-(float a, Expr b);
Expr operator-(Expr a, float b);

Expr operator*(Expr a, Expr b);
Expr operator*(float a, Expr b);
Expr operator*(Expr a, float b);

Expr operator/(Expr a, Expr b);
Expr operator/(float a, Expr b);
Expr operator/(Expr a, float b);

// `eps` defaults to 0.f. NOTE(review): presumably a stabilizing offset for the
// square root -- confirm in the implementation.
Expr sqrt(Expr a, float eps = 0.f);
Expr square(Expr a);
Expr abs(Expr a);

// pow is not declared in this header; the intended overloads are kept
// commented out for reference.
// Expr pow(Expr a, Expr b);
// Expr pow(float a, Expr b);
// Expr pow(Expr a, float b);

Expr logaddexp(Expr a, Expr b);
// Element-wise min/max.
// Performs an element-wise min/max comparison between expressions. Besides the
// Expr-Expr form, each has overloads taking a float scalar on either side.
// @see min, max for axis level operations
// @see MinimumNodeOp, MaximumNodeOp
// @todo implement version without ExpressionGraph::constant.
Expr maximum(Expr a, Expr b);
Expr maximum(float a, Expr b);
Expr maximum(Expr a, float b);

Expr minimum(Expr a, Expr b);
Expr minimum(float a, Expr b);
Expr minimum(Expr a, float b);
typedef std::tuple<Expr, Expr> Expr2;
template <int I>
Expr get(Expr2 tuple) { return std::get<I>(tuple); }
Expr2 topk(Expr a, int k, int axis, bool descending = true);
Expr2 argmax(Expr a, int axis);
Expr2 argmin(Expr a, int axis);
/*
 * Comparisons: lt, le, gt, ge, eq, ne.
 * Every comparison has three overloads:
 *   op(Expr, Expr)   -- Expr-Expr comparison
 *   op(float, Expr)  -- float on the left
 *   op(Expr, float)  -- float on the right
 * In the float forms the float is promoted to a @ref ExpressionGraph::constant
 * and the Expr-Expr method is used.
 */
Expr lt(Expr a, Expr b);
Expr lt(float a, Expr b);
Expr lt(Expr a, float b);

Expr le(Expr a, Expr b);
Expr le(float a, Expr b);
Expr le(Expr a, float b);

Expr gt(Expr a, Expr b);
Expr gt(float a, Expr b);
Expr gt(Expr a, float b);

Expr ge(Expr a, Expr b);
Expr ge(float a, Expr b);
Expr ge(Expr a, float b);

Expr eq(Expr a, Expr b);
Expr eq(float a, Expr b);
Expr eq(Expr a, float b);

Expr ne(Expr a, Expr b);
Expr ne(float a, Expr b);
Expr ne(Expr a, float b);
// ---------------------------------------------------------------------------
// Matrix products.
// Shared optional arguments of dot, bdot, bdot_legacy, affine, affineWithRelu:
//   transA, transB -- default false; presumably request transposition of the
//                     corresponding operand (NOTE(review): confirm).
//   scalar         -- default 1.f; presumably a multiplier on the product
//                     (NOTE(review): confirm).
// affine and affineWithRelu additionally take a `bias` Expr.
// ---------------------------------------------------------------------------
Expr dot(Expr a, Expr b, bool transA = false, bool transB = false, float scalar = 1.f);
Expr bdot(Expr a, Expr b, bool transA = false, bool transB = false, float scalar = 1.f);
Expr bdot_legacy(Expr a, Expr b, bool transA = false, bool transB = false, float scalar = 1.f);
Expr affine(Expr a, Expr b, Expr bias, bool transA = false, bool transB = false, float scalar = 1.f);
Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA = false, bool transB = false, float scalar = 1.f);

// Products with one operand given as CSR components (its Shape plus values,
// indices and offsets Exprs):
//   csr_dot -- the CSR operand is A (left); optional transA, default false.
//   dot_csr -- the CSR operand is B (right); optional transB, default false.
Expr csr_dot(const Shape& A_shape, Expr Avalues, Expr Aindices, Expr Aoffsets, Expr B, bool transA = false);
Expr dot_csr(Expr A, const Shape& B_shape, Expr B_values, Expr B_indices, Expr B_offsets, bool transB = false);
// ---------------------------------------------------------------------------
// Shape and type manipulation.
// ---------------------------------------------------------------------------
// transpose: the one-argument form uses the implementation's default axis
// order (NOTE(review): confirm which axes it swaps); the second form takes an
// explicit permutation `axes`. swapAxes exchanges axis1 and axis2.
Expr transpose(Expr a);
Expr transpose(Expr a, const std::vector<int>& axes);
Expr swapAxes(Expr x, int axis1, int axis2);

// Converts to `type` (default Type::float32).
Expr cast(Expr a, Type type = Type::float32);

// Axis-parameterized ops; `ax` defaults to 0 for both.
Expr concatenate(const std::vector<Expr>& concats, int ax = 0);
Expr repeat(Expr a, size_t repeats, int ax = 0);

Expr reshape(Expr a, Shape shape);

// Clipping by threshold `c` -- clip for the value, clipGradient presumably for
// the gradient (NOTE(review): confirm).
Expr clip(Expr a, float c);
Expr clipGradient(Expr a, float c);

// Rank helpers: atleast_1d..atleast_4d are fixed-N variants of
// atleast_nd(a, dims).
Expr atleast_1d(Expr a);
Expr atleast_2d(Expr a);
Expr atleast_3d(Expr a);
Expr atleast_4d(Expr a);
Expr atleast_nd(Expr a, size_t dims);
// Creates a constant node in a's graph with a's shape and value type, whose
// contents are produced by `init`.
static inline Expr constant_like(Expr a, const Ptr<inits::NodeInitializer>& init) {
  return a->graph()->constant(a->shape(), init, a->value_type());
}

// Same as above, with the values taken from the vector `v` via
// inits::fromVector.
// - lvalue overload: `v` is passed on as-is. std::move on a const reference
//   would still copy, so it is not used here.
// - rvalue overload: `v` is handed on as an rvalue so it can be moved rather
//   than copied. NOTE(review): the move only pays off if inits::fromVector has
//   an rvalue overload -- confirm in node_initializers.h.
template<typename ElementType>
Expr constant_like(Expr a, const std::vector<ElementType>& v) { return constant_like(a, inits::fromVector(v)); }
template<typename ElementType>
Expr constant_like(Expr a, std::vector<ElementType>&& v) { return constant_like(a, inits::fromVector(std::move(v))); }
// Flattening helpers: flatten and flatten_2d.
Expr flatten(Expr a);
Expr flatten_2d(Expr a);

// Returns an Expr through which gradients presumably do not flow
// (NOTE(review): confirm in the implementation).
Expr stopGradient(Expr a);

// gather reads `a` at `indices` along `axis`; scatter is its counterpart,
// writing `source` into `a` at `indices` along `axis`.
Expr gather(Expr a, int axis, Expr indices);
Expr scatter(Expr a, int axis, Expr indices, Expr source);
// Disabled by `#if 0`: design notes and a usage sketch for scatter, kept for
// reference only -- nothing in this region is compiled. The active declaration
// is scatter(Expr a, int axis, Expr indices, Expr source) above.
#if 0
// Reverse operation to gather: `a` is the expression into which the values from
// `b` are inserted, at positions `indices` along `axis` (with broadcasting).
// Usage sketch:
auto knn = get<0>(KNN->apply(query, k)); // [beam, time, batch, k]
auto W = reshape(gather(Wt_, -2, flatten(knn)), {beam * time * batch, k, dim});
auto b = reshape(gather(b_, -1, flatten(knn)), {beam * time * batch, 1, k });
query = reshape(query, {beam * time * batch, 1, dim});
auto logits = bdot(query, W, false, true); // [beam * time * batch, 1, k]
logits = reshape(logits + b, {beam, time, batch, k}); // @TODO: add baffine node
auto shape = indices.shape();
shape.set(-1, 32000);
// NOTE(review): `grep` on the next line looks like a typo for `graph`; harmless
// while this region is disabled.
auto output = grep->constant(shape, inits::lowest(), logits->value_type());
output = scatter(output, -1, indices, logits);
// auto a = graph->constant({2,2,5,32000}, inits::fromValue(minimal))
// scatter(a, -1, indices, values)
// For comparison, PyTorch's out-of-place scatter is: out = a.scatter(-1, indices, values)
Expr scatter(Expr a, int axis, Expr indices, Expr b);
#endif
// Selects entries of `a` at `indices` along `axis`; the indices can be given as
// an Expr or as a std::vector<IndexType>.
Expr index_select(Expr a, int axis, Expr indices);
Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices);

// rows(...) is index_select along axis 0; cols(...) is index_select along
// axis -1 (NOTE(review): presumably the last axis -- confirm negative-axis
// handling). Both accept either index representation.
static inline Expr rows(Expr a, Expr indices) { return index_select(a, 0, indices); }
static inline Expr rows(Expr a, const std::vector<IndexType>& indexVector) { return index_select(a, 0, indexVector); }

static inline Expr cols(Expr a, Expr indices) { return index_select(a, -1, indices); }
static inline Expr cols(Expr a, const std::vector<IndexType>& indexVector) { return index_select(a, -1, indexVector); }
// Takes the part of `a` described by `slice` along `axis`.
Expr slice(Expr a, int axis, Slice slice);

// Convenience overload: slice described by a single `index` along `axis`.
static inline Expr slice(Expr a, int axis, int index) {
  return slice(a, axis, Slice(index));
}

// Slice along `axis` built from begin = start and end = start + length,
// converted to int as Slice requires. NOTE(review): assumes Slice(begin, end)
// is half-open, i.e. `length` elements -- confirm.
static inline Expr narrow(Expr a, int axis, size_t start, size_t length) {
  const int begin = static_cast<int>(start);
  const int end = static_cast<int>(start + length);
  return slice(a, axis, Slice(begin, end));
}
/*********************************************************/
// Aggregations along axis `ax`. sum and mean default to ax = 0; the other
// aggregations require the axis explicitly.
Expr sum(Expr a, int ax = 0);
Expr mean(Expr a, int ax = 0);
// A function named `std` is legal here: lookup of a name followed by `::`
// only considers namespaces and types, so later `std::` qualifications are
// unaffected.
Expr std(Expr a, int ax);
Expr var(Expr a, int ax);
Expr max(Expr a, int ax);
Expr min(Expr a, int ax);
Expr prod(Expr a, int ax);
Expr logsumexp(Expr a, int ax);
// Softmax along `axis` (default -1). The second overload additionally takes a
// zero/one mask (NOTE(review): presumably masked-out positions are excluded --
// confirm).
Expr softmax(Expr x, int axis = -1);
Expr softmax(Expr a, Expr zeroOneMask, int axis = -1);
Expr logsoftmax(Expr a);

// Losses. cross_entropy's `labelSmoothingAlpha` defaults to 0.f and its
// `outputType` to Type::float32.
Expr cross_entropy(Expr a, Expr b, float labelSmoothingAlpha = 0.f, Type outputType = Type::float32);
Expr unlikelihood(Expr a, Expr b);

// Axis-parameterized products/averages; `ax` defaults to 0.
Expr scalar_product(Expr a, Expr b, int ax = 0);
Expr weighted_average(Expr in, Expr weights, int ax = 0);

// Normalization. `beta` is optional (defaults to nullptr). `eps` defaults to
// 1e-9, written as a float literal to match the parameter type instead of a
// double that is implicitly narrowed.
Expr layerNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9f);
Expr rmsNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9f);

Expr highway(Expr y, Expr x, Expr t);
// NOTE(review): `prefix` is a by-value `const std::string`. A
// `const std::string&` would avoid a copy, but the change must be made together
// with the definition in the implementation file, so it is left as is here.
Expr highway(const std::string prefix, Expr x);
// Applies a precomputed dropout `mask` to `x` by element multiplication
// (x * mask). A null mask means "no dropout": `x` is returned unchanged.
static inline Expr dropout(Expr x, Expr mask) {
  return mask ? x * mask : x;
}

// Gets a mask of the given `shape` from x's graph (graph->dropoutMask with
// probability `dropProb`) and applies it. A probability of exactly 0 skips
// mask creation and returns `x` unchanged.
static inline Expr dropout(Expr x, float dropProb, Shape shape) {
  if(dropProb != 0) {
    auto mask = x->graph()->dropoutMask(dropProb, shape);
    return dropout(x, mask);
  }
  return x;
}

// Convenience overload: the mask shape is x's own shape. x->shape() is only
// queried when dropout is actually applied.
static inline Expr dropout(Expr x, float dropProb) {
  return dropProb == 0 ? x : dropout(x, dropProb, x->shape());
}
// Shifts `x` by the per-axis offsets in `shift`. `padValue` defaults to 0
// (NOTE(review): presumably the fill value for vacated positions -- confirm).
Expr shift(Expr x, Shape shift, float padValue = 0.f);

// Layout conversions to and from the cuDNN format (details in the
// implementation file).
Expr convert2cudnnFormat(Expr x);
Expr convertFromcudnnFormat(Expr x);

// Pooling with a `height` x `width` window. Padding defaults to 0 and stride
// to 1 in both dimensions.
Expr avg_pooling(Expr x, int height, int width,
                 int padHeight = 0, int padWidth = 0,
                 int strideHeight = 1, int strideWidth = 1);
Expr max_pooling(Expr x, int height, int width,
                 int padHeight = 0, int padWidth = 0,
                 int strideHeight = 1, int strideWidth = 1);

// Pooling of the given `width` that also takes a `mask` Expr; `isEven`
// defaults to false.
Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven = false);
} // namespace marian