// Program listing for file expression_operators.h
//
// Return to documentation for file (src/graph/expression_operators.h)

#pragma once
#include "graph/expression_graph.h"
#include "graph/node_initializers.h"

namespace marian {

// Declaration only; the definition is not part of this header.
// Takes an expression and an optional message (defaults to the empty string)
// and returns an expression.
// NOTE(review): presumably attaches a debug message to node `a` - confirm
// against the definition.
Expr debug(Expr a, const std::string& message = "");

// Declaration only; takes an expression and returns an expression.
// NOTE(review): presumably marks `a` as a gradient-checkpointing node -
// confirm against the definition.
Expr checkpoint(Expr a);

// Function type that takes one Expr and returns an Expr. The alias name
// suggests it is meant for activation functions.
typedef Expr(ActivationFunction)(Expr);

// Callable that receives an output expression `out` and the list of input
// expressions `in`; it returns nothing.
typedef std::function<void(Expr out, const std::vector<Expr>& in)> LambdaNodeFunctor;

// Declaration only. Takes the input nodes, a shape, a type, a functor `fwd`
// and an optional `hash` (default 0); returns an expression.
// NOTE(review): presumably creates a node of the given shape/type whose
// forward step runs `fwd` - confirm against the definition.
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, size_t hash = 0);

// Overload of the declaration before it with an additional functor `bwd`.
// NOTE(review): presumably `bwd` implements the backward step - confirm.
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash = 0);

// Callable that receives a single expression and returns nothing.
typedef std::function<void(Expr)> LambdaNodeCallback;
// Declaration only. NOTE(review): presumably arranges for `call` to be invoked
// in connection with `node`; when exactly is not visible here - confirm.
Expr callback(Expr node, LambdaNodeCallback call);

// Declarations only; the definitions are not part of this header.
// Each function named after an activation comes in two flavours: one taking a
// single expression and one taking a vector of expressions.
// NOTE(review): how the vector overloads combine several inputs (presumably
// they are summed before the non-linearity is applied) is not visible here -
// confirm against the definitions.

// NOTE(review): presumably the sum of all `nodes` - confirm.
Expr plus(const std::vector<Expr>& nodes);

// Sigmoid of a single expression.
Expr sigmoid(Expr a);

// Sigmoid taking several input nodes.
Expr sigmoid(const std::vector<Expr>& nodes);

// Swish of a single expression.
Expr swish(Expr a);

// Swish taking several input nodes.
Expr swish(const std::vector<Expr>& nodes);

// GELU of a single expression.
Expr gelu(Expr a);

// GELU taking several input nodes (the parameter is left unnamed here).
Expr gelu(const std::vector<Expr>&);

// tanh taking several input nodes; this is the overload the variadic tanh
// template in this header forwards to.
Expr tanh(const std::vector<Expr>& nodes);

// Convenience overload of tanh for an arbitrary number of arguments.
// The arguments are packed into a std::vector<Expr> (so each one must be
// convertible to Expr) and forwarded to tanh(const std::vector<Expr>&).
// Returns whatever that overload returns.
template <typename... Args>
Expr tanh(Args... args) {
  std::vector<Expr> nodes{args...};
  return tanh(nodes);
}

// Declarations only; the definitions are not part of this header.
// As with the other activations, each function has a single-expression form
// and a form taking a vector of expressions.
// NOTE(review): how the vector forms combine their inputs is not visible
// here - confirm against the definitions.

// ReLU of a single expression.
Expr relu(Expr a);

// ReLU taking several input nodes.
Expr relu(const std::vector<Expr>& nodes);

// Leaky ReLU of a single expression (no slope parameter is exposed here).
Expr leakyrelu(Expr a);

// Leaky ReLU taking several input nodes.
Expr leakyrelu(const std::vector<Expr>& nodes);

// Parametric ReLU of a single expression; `alpha` defaults to 0.01.
// NOTE(review): presumably `alpha` is the slope for negative inputs - confirm.
Expr prelu(Expr a, float alpha = 0.01);

// Parametric ReLU taking several input nodes (vector parameter unnamed);
// `alpha` defaults to 0.01.
Expr prelu(const std::vector<Expr>&, float alpha = 0.01);

// Exponentiation and Logarithmic functions
// Declarations only; each takes one expression and returns an expression.
// NOTE(review): presumably applied element-wise; the logarithm base is not
// visible here (presumably natural) - confirm against the definitions.
Expr log(Expr a);

Expr exp(Expr a);

// Trigonometric functions
// Declarations only; each takes one expression and returns an expression.
// NOTE(review): presumably applied element-wise - confirm.
Expr sin(Expr a);

Expr cos(Expr a);

Expr tan(Expr a);

// Unary minus on an expression (declaration only).
Expr operator-(Expr a);


// Binary arithmetic operators (declarations only).
// Every operator is declared for the three operand combinations Expr-Expr,
// float-Expr and Expr-float, and returns an expression.
// NOTE(review): presumably element-wise, with the float operand treated as a
// scalar; broadcasting rules are not visible here - confirm.

// Addition.
Expr operator+(Expr a, Expr b);
Expr operator+(float a, Expr b);
Expr operator+(Expr a, float b);

// Subtraction.
Expr operator-(Expr a, Expr b);
Expr operator-(float a, Expr b);
Expr operator-(Expr a, float b);

// Multiplication.
Expr operator*(Expr a, Expr b);
Expr operator*(float a, Expr b);
Expr operator*(Expr a, float b);

// Division.
Expr operator/(Expr a, Expr b);
Expr operator/(float a, Expr b);
Expr operator/(Expr a, float b);

// Declarations only; the definitions are not part of this header.

// Square root of an expression; `eps` defaults to 0.
// NOTE(review): presumably `eps` is added before taking the root for
// numerical stability - confirm against the definition.
Expr sqrt(Expr a, float eps = 0.f);

// Square of an expression.
Expr square(Expr a);

// Absolute value of an expression.
Expr abs(Expr a);

// The pow overloads below are commented out, i.e. not part of the API declared here.
// Expr pow(Expr a, Expr b);
// Expr pow(float a, Expr b);
// Expr pow(Expr a, float b);

// NOTE(review): presumably log(exp(a) + exp(b)) - confirm.
Expr logaddexp(Expr a, Expr b);

/**
 * Element-wise min/max
 * Performs an element-wise min/max comparison between expressions.
 * @see min, max for axis level operations
 * @see MinimumNodeOp, MaximumNodeOp
 * @todo implement version without ExpressionGraph::constant.
 */

// Element-wise maximum of two operands (declarations only).
// Declared for Expr-Expr, float-Expr and Expr-float operands.
// NOTE(review): the float overloads presumably turn the float into a graph
// constant first - confirm against the definitions.
Expr maximum(Expr a, Expr b);

Expr maximum(float a, Expr b);

Expr maximum(Expr a, float b);

// Element-wise minimum of two operands (declarations only), with the same
// three operand combinations as maximum.
Expr minimum(Expr a, Expr b);

Expr minimum(float a, Expr b);

Expr minimum(Expr a, float b);

// A pair of expressions; topk, argmax and argmin are declared to return it.
typedef std::tuple<Expr, Expr> Expr2;

// Returns element number I of an Expr2 (I is 0 or 1; any other value fails
// to compile inside std::get).
template <int I>
Expr get(Expr2 tuple) {
  Expr selected = std::get<I>(tuple);
  return selected;
}

// Declarations only. All three return an Expr2, i.e. a tuple of two
// expressions whose elements can be read with get<0>() / get<1>().
// NOTE(review): presumably the tuple is (values, indices) - confirm against
// the definitions.

// Top-k selection along `axis`; `descending` defaults to true.
Expr2 topk(Expr a, int k, int axis, bool descending = true);

// Arg-max along `axis`.
Expr2 argmax(Expr a, int axis);

// Arg-min along `axis`.
Expr2 argmin(Expr a, int axis);

/** Expr-Expr comparisons */
// Comparison of two expressions (declarations only); each returns an expression.
// Names follow the usual abbreviations: lt <, eq ==, gt >, ge >=, ne !=, le <=.
// NOTE(review): presumably element-wise; how the result is encoded is not
// visible here - confirm against the definitions.
Expr lt(Expr a, Expr b);
Expr eq(Expr a, Expr b);
Expr gt(Expr a, Expr b);
Expr ge(Expr a, Expr b);
Expr ne(Expr a, Expr b);
Expr le(Expr a, Expr b);

/**
 * Float-Expr comparisons
 * Floats are promoted to a @ref ExpressionGraph::constant and use the Expr-Expr methods
 */
// Comparisons with a float on the left-hand side (declarations only).
// Per this header's notes, the float is promoted to an
// ExpressionGraph::constant and the Expr-Expr comparison is used.
Expr lt(float a, Expr b);
Expr eq(float a, Expr b);
Expr gt(float a, Expr b);
Expr ge(float a, Expr b);
Expr ne(float a, Expr b);
Expr le(float a, Expr b);

// Comparisons with a float on the right-hand side (declarations only).
Expr lt(Expr a, float b);
Expr eq(Expr a, float b);
Expr gt(Expr a, float b);
Expr ge(Expr a, float b);
Expr ne(Expr a, float b);
Expr le(Expr a, float b);

// Matrix-product style operations (declarations only; the definitions are not
// part of this header). In all of them `transA` / `transB` default to false
// and, where present, `scalar` defaults to 1.
// NOTE(review): presumably transA/transB request transposition of the
// respective operand and `scalar` scales the product - confirm against the
// definitions.

// Product of `a` and `b`.
Expr dot(Expr a,
         Expr b,
         bool transA = false,
         bool transB = false,
         float scalar = 1.f);

// NOTE(review): presumably a batched product - confirm.
Expr bdot(Expr a,
          Expr b,
          bool transA = false,
          bool transB = false,
          float scalar = 1.f);

// Same parameter list as bdot. NOTE(review): the name suggests an older
// implementation kept for compatibility - confirm.
Expr bdot_legacy(Expr a,
                 Expr b,
                 bool transA = false,
                 bool transB = false,
                 float scalar = 1.f);

// Like dot, with an additional `bias` operand.
// NOTE(review): presumably computes the product and adds `bias` - confirm.
Expr affine(Expr a,
            Expr b,
            Expr bias,
            bool transA = false,
            bool transB = false,
            float scalar = 1.f);

// Same parameter list as affine. NOTE(review): presumably applies ReLU to the
// affine result - confirm.
Expr affineWithRelu(Expr a,
                    Expr b,
                    Expr bias,
                    bool transA = false,
                    bool transB = false,
                    float scalar = 1.f);

// Product where the left operand is given by a shape plus values / indices /
// offsets expressions and the right operand `B` is a regular expression.
// NOTE(review): the name suggests the left operand is in CSR sparse format - confirm.
Expr csr_dot(const Shape& A_shape, Expr Avalues, Expr Aindices, Expr Aoffsets, Expr B, bool transA = false);

// Counterpart of csr_dot with the roles swapped: `A` is a regular expression
// and the right operand is given by shape plus values / indices / offsets.
Expr dot_csr(Expr A, const Shape& B_shape, Expr B_values, Expr B_indices, Expr B_offsets, bool transB = false);

// Shape and type manipulation (declarations only; the definitions are not
// part of this header).

// Transpose without explicit axes.
// NOTE(review): which axes are swapped by default is not visible here - confirm.
Expr transpose(Expr a);

// Transpose with an explicit axis order given by `axes`.
Expr transpose(Expr a, const std::vector<int>& axes);

// Swaps the two given axes of `x`.
Expr swapAxes(Expr x, int axis1, int axis2);

// Casts `a` to `type`; the target type defaults to Type::float32.
Expr cast(Expr a, Type type = Type::float32);

// Concatenates the given expressions along axis `ax` (default 0).
Expr concatenate(const std::vector<Expr>& concats, int ax = 0);

// Repeats `a` `repeats` times along axis `ax` (default 0).
Expr repeat(Expr a, size_t repeats, int ax = 0);

// Reshapes `a` to `shape`.
Expr reshape(Expr a, Shape shape);

// NOTE(review): presumably clips values to the range [-c, c] - confirm.
Expr clip(Expr a, float c);

// NOTE(review): presumably clips the gradient flowing through `a` rather than
// its values - confirm.
Expr clipGradient(Expr a, float c);

// NOTE(review): the atleast_* functions presumably return `a` viewed with at
// least the stated number of dimensions - confirm.
Expr atleast_1d(Expr a);

Expr atleast_2d(Expr a);

Expr atleast_3d(Expr a);

Expr atleast_4d(Expr a);

// General form taking the number of dimensions as a parameter.
Expr atleast_nd(Expr a, size_t dims);

// Creates a constant on the graph that owns `a`, using the shape and value
// type of `a` and the given node initializer `init` to fill it.
static inline Expr constant_like(Expr a, const Ptr<inits::NodeInitializer>& init) {
  return a->graph()->constant(a->shape(), init, a->value_type());
}

// Creates a constant shaped and typed like `a`, initialized from the values
// in `v`. The vector is only borrowed: it is passed to inits::fromVector as a
// const lvalue (std::move on a const reference cannot move, so none is used).
template<typename ElementType>
Expr constant_like(Expr a, const std::vector<ElementType>& v) {
  return constant_like(a, inits::fromVector(v));
}

// Overload for temporaries / explicitly moved vectors. Inside the function
// the named parameter `v` is an lvalue, so std::move is needed to hand it on
// as an rvalue; this lets inits::fromVector take over the buffer instead of
// copying it (assuming it offers an rvalue overload - otherwise this simply
// binds to the const-reference version as before).
template<typename ElementType>
Expr constant_like(Expr a, std::vector<ElementType>&& v) {
  return constant_like(a, inits::fromVector(std::move(v)));
}

// Declarations only; the definitions are not part of this header.

// NOTE(review): presumably reshapes `a` into a one-dimensional expression - confirm.
Expr flatten(Expr a);

// NOTE(review): presumably reshapes `a` into a two-dimensional expression;
// which dimensions are merged is not visible here - confirm.
Expr flatten_2d(Expr a);

// NOTE(review): presumably returns `a` with gradient propagation blocked - confirm.
Expr stopGradient(Expr a);

// Gathers from `a` along `axis` using `indices`.
Expr gather(Expr a, int axis, Expr indices);

// Scatters `source` into `a` along `axis` at the positions in `indices`
// (described in the disabled notes of this header as the reverse operation
// to gather).
Expr scatter(Expr a, int axis, Expr indices, Expr source);

// Disabled design notes and usage sketch for scatter; nothing between #if 0
// and #endif is compiled.
#if 0
 // Reverse operation to gather: a is the expression into which values from b
 // are inserted at the positions given by indices along axis.
 // with broadcasting

 auto knn = get<0>(KNN->apply(query, k)); // [beam, time, batch, k]

 auto W = reshape(gather(Wt_, -2, flatten(knn)), {beam * time * batch, k, dim});
 auto b = reshape(gather(b_,  -1, flatten(knn)), {beam * time * batch, 1, k });
 query       = reshape(query, {beam * time * batch, 1, dim});
 auto logits = bdot(query, W, false, true); // [beam * time * batch, 1, k]
 logits      = reshape(logits + b, {beam, time, batch, k}); // @TODO: add baffine node

 auto shape = indices.shape();
 shape.set(-1, 32000);
 auto output = grep->constant(shape, inits::lowest(), logits->value_type());
 output = scatter(output, -1, indices, logits);

 // auto a = graph->constant({2,2,5,32000}, inits::fromValue(minimal))
 // scatter(a, -1, indices, values)
 // PyTorch does for out-of-place scatter: out = a.scatter(-1, indices, values)
Expr scatter(Expr a, int axis, Expr indices, Expr b);
#endif


// Declarations only. Selects from `a` along `axis` using `indices`, which are
// given as an expression. The rows() and cols() helpers in this header
// forward to these overloads with axis 0 and -1 respectively.
Expr index_select(Expr a, int axis, Expr indices);

// Same operation with the indices supplied as a vector of IndexType.
Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices);

// Shorthand for index_select(a, 0, indices): selects along axis 0 of `a`
// using indices given as an expression.
static inline Expr rows(Expr a, Expr indices) {
  return index_select(a, 0, indices);
}

// Shorthand for index_select(a, 0, indexVector): selects along axis 0 of `a`
// using indices given as a vector of IndexType.
static inline Expr rows(Expr a, const std::vector<IndexType>& indexVector) {
  return index_select(a, 0, indexVector);
}

// Shorthand for index_select(a, -1, indices): selects along axis -1 of `a`
// using indices given as an expression.
static inline Expr cols(Expr a, Expr indices) {
  return index_select(a, -1, indices);
}

// Shorthand for index_select(a, -1, indexVector): selects along axis -1 of
// `a` using indices given as a vector of IndexType.
static inline Expr cols(Expr a, const std::vector<IndexType>& indexVector) {
  return index_select(a, -1, indexVector);
}

// Declaration only. Takes an expression, an axis and a Slice object and
// returns an expression; the slice(int index) and narrow() helpers in this
// header forward to it.
Expr slice(Expr a, int axis, Slice slice);

// Convenience overload: wraps `index` in a Slice and forwards to
// slice(Expr, int, Slice) for the given axis.
static inline Expr slice(Expr a, int axis, int index) {
  return slice(a, axis, Slice(index));
}

// Slices `a` along `axis` using Slice(start, start + length).
// Both bounds are converted from size_t to int because Slice is constructed
// from ints here; values that do not fit into int are not checked.
// NOTE(review): presumably the end bound is exclusive, so `length` elements
// are selected - confirm against Slice.
static inline Expr narrow(Expr a, int axis, size_t start, size_t length) {
  return slice(a, axis, Slice(static_cast<int>(start), static_cast<int>(start + length)));
}


// Aggregations
// Declarations only; each takes an expression and an axis `ax` and returns an
// expression. Only sum and mean provide a default axis (0); the others
// require the axis to be passed explicitly.
// NOTE(review): presumably each reduces `a` along `ax`; the resulting shape
// is not visible here - confirm against the definitions.

Expr sum(Expr a, int ax = 0);

Expr mean(Expr a, int ax = 0);

// Standard deviation.
Expr std(Expr a, int ax);

// Variance.
Expr var(Expr a, int ax);

// Axis-level maximum (the element-wise variant is `maximum`).
Expr max(Expr a, int ax);

// Axis-level minimum (the element-wise variant is `minimum`).
Expr min(Expr a, int ax);

// Product.
Expr prod(Expr a, int ax);

// NOTE(review): presumably log of the sum of exponentials along `ax` - confirm.
Expr logsumexp(Expr a, int ax);

// Declarations only; the definitions are not part of this header.

// Softmax of `x`; `axis` defaults to -1.
Expr softmax(Expr x, int axis = -1);

// Softmax with an additional `zeroOneMask` operand; `axis` defaults to -1.
// NOTE(review): presumably positions where the mask is 0 are excluded - confirm.
Expr softmax(Expr a, Expr zeroOneMask, int axis = -1);

// Log-softmax of `a` (no axis parameter is exposed here).
Expr logsoftmax(Expr a);

// Cross-entropy of `a` and `b`; label smoothing defaults to 0 and the output
// type defaults to Type::float32.
// NOTE(review): which argument holds predictions and which holds labels is
// not visible here - confirm.
Expr cross_entropy(Expr a, Expr b, float labelSmoothingAlpha = 0.f, Type outputType = Type::float32);

// NOTE(review): presumably an unlikelihood loss over `a` and `b` - confirm.
Expr unlikelihood(Expr a, Expr b);

// Scalar product of `a` and `b` along axis `ax` (default 0).
Expr scalar_product(Expr a, Expr b, int ax = 0);

// Weighted average of `in` with `weights` along axis `ax` (default 0).
Expr weighted_average(Expr in, Expr weights, int ax = 0);

// Layer normalization of `x` with scale `gamma`; `beta` is optional
// (defaults to nullptr) and `eps` defaults to 1e-9.
Expr layerNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9);

// RMS normalization with the same parameter list and defaults as layerNorm.
Expr rmsNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9);

// NOTE(review): presumably a highway connection combining `y` and `x` using
// gate `t` - confirm.
Expr highway(Expr y, Expr x, Expr t);

// Highway overload identified by a name `prefix` (taken by value).
// NOTE(review): presumably the prefix is used to name or look up the
// parameters of the layer - confirm.
Expr highway(const std::string prefix, Expr x);

// Applies a precomputed dropout mask.
// Returns x * mask when `mask` is set (tests true); otherwise returns `x`
// unchanged.
static inline Expr dropout(Expr x, Expr mask) {
  if(mask)
    return x * mask;
  return x;
}

// Applies dropout with probability `dropProb`, using a mask of the given shape.
// When `dropProb` is exactly 0 no mask is created and `x` is returned as is.
// Otherwise the mask is requested from the graph that owns `x` via
// dropoutMask(dropProb, shape) and applied with dropout(x, mask).
static inline Expr dropout(Expr x, float dropProb, Shape shape) {
  if(dropProb == 0)
    return x;
  auto graph = x->graph();
  auto mask = graph->dropoutMask(dropProb, shape);
  return dropout(x, mask);
}

// Applies dropout with probability `dropProb`, using a mask with the same
// shape as `x`. When `dropProb` is exactly 0, `x` is returned unchanged
// without querying its shape.
static inline Expr dropout(Expr x, float dropProb) {
  if(dropProb == 0)
    return x;
  return dropout(x, dropProb, x->shape());
}

// Declarations only; the definitions are not part of this header.

// Shifts `x` by the offsets in `shift`; `padValue` defaults to 0.
// NOTE(review): presumably vacated positions are filled with `padValue` - confirm.
Expr shift(Expr x, Shape shift, float padValue = 0);

// NOTE(review): presumably rearranges `x` into the layout expected by cuDNN - confirm.
Expr convert2cudnnFormat(Expr x);

// NOTE(review): presumably the inverse of convert2cudnnFormat - confirm.
Expr convertFromcudnnFormat(Expr x);

// Average pooling with a window of `height` x `width`. Padding defaults to 0
// and stride defaults to 1 in both directions.
Expr avg_pooling(Expr x,
                 int height,
                 int width,
                 int padHeight = 0,
                 int padWidth = 0,
                 int strideHeight = 1,
                 int strideWidth = 1);

// Max pooling with the same parameter list and defaults as avg_pooling.
Expr max_pooling(Expr x,
                 int height,
                 int width,
                 int padHeight = 0,
                 int padWidth = 0,
                 int strideHeight = 1,
                 int strideWidth = 1);

// Pooling of `x` that takes a `mask` and a window `width`; `isEven` defaults
// to false.
// NOTE(review): the meaning of `isEven` is not visible here - confirm against
// the definition.
Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven = false);

}  // namespace marian