Program Listing for File expression_operators.cpp

Source file: src/graph/expression_operators.cpp

#include "graph/expression_operators.h"
#include "common/definitions.h"
#include "layers/constructors.h"

#include "graph/node_operators.h"
#include "graph/node_operators_binary.h"
#include "graph/node_operators_unary.h"
#include "graph/node_operators_tuple.h"

#include "graph/auto_tuner.h"
#include "tensors/cpu/intgemm_interface.h"
#include "tensors/cpu/fbgemm/expanded_gemm.h"

#if USE_FBGEMM
#include "fbgemm/Utils.h"
#endif

namespace marian {

Expr debug(Expr a, const std::string& message) {
  a->debug(message);
  return a;
}

Expr checkpoint(Expr a) {
  a->markCheckpoint();
  return a;
}

Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
            LambdaNodeFunctor fwd, size_t hash) {
  return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
}

Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
            LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash) {
  return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd, hash);
}

Expr callback(Expr node, LambdaNodeCallback call) {
  return Expression<CallbackNodeOp>(node, call);
}

// logistic function. Note: scipy name is expit()
Expr sigmoid(Expr a) {
  return Expression<SigmoidNodeOp>(a);
}

Expr relu(Expr a) {
  return Expression<ReLUNodeOp>(a);
}

Expr leakyrelu(Expr a) {
  return Expression<PReLUNodeOp>(0.01f, a);
}

Expr prelu(Expr a, float alpha) {
  return Expression<PReLUNodeOp>(alpha, a);
}

Expr clip(Expr a, float c) {
  if(c == 0)
    return a;
  else
    return Expression<ClipNodeOp>(a, c);
}

Expr log(Expr a) {
  return Expression<LogNodeOp>(a);
}

Expr exp(Expr a) {
  return Expression<ExpNodeOp>(a);
}

Expr sin(Expr a) {
  return Expression<SinNodeOp>(a);
}

Expr cos(Expr a) {
  return Expression<CosNodeOp>(a);
}

Expr tan(Expr a) {
  return Expression<TanNodeOp>(a);
}

Expr swish(Expr a) {
  return Expression<SwishNodeOp>(a);
}

Expr gelu(Expr a) {
  return Expression<SwishNodeOp>(a, 1.702f);
}
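
// Note on the constant above: SwishNodeOp with beta = 1.702 is assumed to compute
// x * sigmoid(1.702 * x), the well-known sigmoid approximation of GELU
// (Hendrycks & Gimpel, 2016), so gelu() reuses the swish kernel instead of a dedicated op.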

Expr operator-(Expr a) {
  return Expression<NegNodeOp>(a);
}

Expr softmax(Expr a, int axis /*=-1*/)
{
  // @TODO: move axis parameter down into the kernel
  if (axis != -1)
  {
    return swapAxes(softmax(swapAxes(a,
                                     axis, -1),
                            /*axis=*/-1),
                    axis, -1);
  }
  return Expression<SoftmaxNodeOp>(a);
}

Expr softmax(Expr a, Expr zeroOneMask, int axis /*=-1*/) {
  // This computes half of the lowest (most negative) value representable in the input type,
  // expressed as a float. For Type::float16 that is the lowest fp16 value cast to float.
  // Dividing by 2 leaves some tolerance and protects against overflow.
  float smallestFloat = NumericLimits<float>(a->value_type()).lowest / 2.f;
  auto logMask = (1.f - zeroOneMask) * smallestFloat;
  return softmax(a + logMask, axis);
}
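
// Numeric sketch of the masking above (illustrative values): for float32 inputs,
// smallestFloat is roughly -1.7e38. A mask entry of 0 therefore adds about -1.7e38 to the
// corresponding logit, so exp() underflows to zero and that position receives ~0 probability;
// a mask entry of 1 adds 0 and leaves the logit unchanged. E.g. (placeholder expressions):
//   auto p = softmax(scores, mask);  // probability mass only where mask == 1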

// @TODO: add mask
Expr logsoftmax(Expr a) {
  return Expression<LogSoftmaxNodeOp>(a);
}

/*********************************************************/

Expr operator+(Expr a, Expr b) {
  return Expression<PlusNodeOp>(a, b);
}

Expr operator-(Expr a, Expr b) {
  return Expression<MinusNodeOp>(a, b);
}

Expr operator*(Expr a, Expr b) {
  return Expression<MultNodeOp>(a, b);
}

Expr operator/(Expr a, Expr b) {
  return Expression<DivNodeOp>(a, b);
}

Expr logaddexp(Expr a, Expr b) {
  return Expression<LogAddExpNodeOp>(a, b);
}

Expr2 topk(Expr a, int k, int axis, bool descending) {
  // only supports topk along last dimension, hence transpose if required
  a = swapAxes(a, axis, -1);                              // non-op if axes are the same
  auto topkVal = Expression<TopKNodeOp>(a, k, -1, descending); // axis=-1 is OK now as we swapped
  auto topkIdx = std::dynamic_pointer_cast<TopKNodeOp>(topkVal)->tupleView(); // get a view on the top-k values
  return std::make_tuple(swapAxes(topkVal, axis, -1), swapAxes(topkIdx, axis, -1)); // non-op if axes are the same
}
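
// Usage sketch for topk() (hypothetical names; assumes Expr2 is a std::tuple<Expr, Expr>):
//   Expr vals, idxs;
//   std::tie(vals, idxs) = topk(scores, /*k=*/5, /*axis=*/-1, /*descending=*/true);
// vals holds the five largest entries along the last axis and idxs their indices; idxs is
// derived from a tuple view on the same TopKNodeOp, so the underlying search runs only once.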

Expr2 argmax(Expr a, int axis) {
  return topk(a, 1, axis, /*descending=*/true);
}

Expr2 argmin(Expr a, int axis) {
  return topk(a, 1, axis, /*descending=*/false);
}

Expr maximum(Expr a, Expr b) {
  return Expression<MaximumNodeOp>(a, b);
}

// @TODO: implement version without constant
Expr maximum(float a, Expr b) {
  auto aExpr = b->graph()->constant({}, inits::fromValue(a));
  return Expression<MaximumNodeOp>(aExpr, b);
}

Expr maximum(Expr a, float b) {
  return maximum(b, a);
}

Expr minimum(Expr a, Expr b) {
  return Expression<MinimumNodeOp>(a, b);
}

// @TODO: implement version without constant
Expr minimum(float a, Expr b) {
  auto aExpr = b->graph()->constant({}, inits::fromValue(a));
  return Expression<MinimumNodeOp>(aExpr, b);
}

Expr minimum(Expr a, float b) {
  return minimum(b, a);
}

Expr abs(Expr a) {
  return Expression<AbsNodeOp>(a);
}

Expr lt(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, -1, false); }
Expr eq(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b,  0, false); }
Expr gt(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b,  1, false); }
Expr ge(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, -1,  true); }
Expr ne(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b,  0,  true); }
Expr le(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b,  1,  true); }

Expr lt(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b, -1, false); }
Expr eq(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b,  0, false); }
Expr gt(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b,  1, false); }
Expr ge(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b, -1,  true); }
Expr ne(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b,  0,  true); }
Expr le(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b,  1,  true); }

Expr lt(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()), -1, false); }
Expr eq(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()),  0, false); }
Expr gt(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()),  1, false); }
Expr ge(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()), -1,  true); }
Expr ne(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()),  0,  true); }
Expr le(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()),  1,  true); }

/*********************************************************/

Expr operator+(Expr a, float b) {
  if (b == 0)
    return a;
  else
    return Expression<ScalarAddNodeOp>(a, b);
}

Expr operator+(float a, Expr b) {
  if (a == 0)
    return b;
  else
    return Expression<ScalarAddNodeOp>(b, a);
}

Expr operator-(Expr a, float b) {
  if (b == 0)
    return a;
  else
    return Expression<ScalarAddNodeOp>(a, -b);
}

Expr operator-(float a, Expr b) {
  if (a == 0)
    return -b;
  else
    return Expression<ScalarAddNodeOp>(-b, a);
}

Expr operator*(float a, Expr b) {
  if (a == 1.0f)
    return b;
  else
    return Expression<ScalarMultNodeOp>(b, a);
}

Expr operator*(Expr a, float b) {
  if (b == 1.0f)
    return a;
  else
    return Expression<ScalarMultNodeOp>(a, b);
}

Expr operator/(Expr a, float b) {
  return a * (1.f / b);
}

// TODO: efficient version of this without constant()
Expr operator/(float a, Expr b) {
  auto aExpr = b->graph()->constant({}, inits::fromValue(a));
  return aExpr / b;
}

// Expr pow(float a, Expr b) {
//  return Expression<Scalar1PowNodeOp>(a, b);
//
//}
//
// Expr pow(Expr a, float b) {
//  return Expression<Scalar2PowNodeOp>(a, b);
//
//}
//
// Expr pow(Expr a, Expr b) {
//  return Expression<PowNodeOp>(a, b);
//}

/*********************************************************/

Expr concatenate(const std::vector<Expr>& concats, int ax) {
  if(concats.size() == 1)
    return concats[0];
  return Expression<ConcatenateNodeOp>(concats, ax);
}

Expr repeat(Expr a, size_t repeats, int ax) {
  if(repeats == 1)
    return a;
  return concatenate(std::vector<Expr>(repeats, a), ax);
}

Expr reshape(Expr a, Shape shape) {
  if (a->shape() == shape)
    return a;
  return Expression<ReshapeNodeOp>(a, shape);
}

// @TODO: remove this if it turns out that we can train FP16 without it
Expr clipGradient(Expr a, float clipValue) {
  // don't create node if no clipping
  return clipValue != 0.f ? Expression<ClipGradientNodeOp>(a, clipValue) : a;
}

Expr atleast_1d(Expr a) {
  return atleast_nd(a, 1);
}

Expr atleast_2d(Expr a) {
  return atleast_nd(a, 2);
}

Expr atleast_3d(Expr a) {
  return atleast_nd(a, 3);
}

Expr atleast_4d(Expr a) {
  return atleast_nd(a, 4);
}

Expr atleast_nd(Expr a, size_t dims) {
  if(a->shape().size() >= dims)
    return a;

  Shape nShape;
  nShape.resize(dims);
  for(int i = 1; i <= (int)a->shape().size(); ++i)
    nShape.set(-i, a->shape()[-i]);

  return reshape(a, nShape);
}
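
// Examples for atleast_nd() (assuming Shape::resize() initializes new leading axes to 1):
//   a with shape {3, 4}    -> atleast_4d(a) has shape {1, 1, 3, 4}
//   a with shape {2, 3, 4} -> atleast_2d(a) is returned unchanged
// Only singleton axes are prepended; existing axes are never dropped or reordered.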

Expr flatten(Expr a) {
  Shape shape = {a->shape().elements()};
  return Expression<ReshapeNodeOp>(a, shape);
}

Expr flatten_2d(Expr a) {
  Shape shape = {a->shape().elements() / a->shape()[-1], a->shape()[-1]};
  return Expression<ReshapeNodeOp>(a, shape);
}

Expr stopGradient(Expr a) {
  // implemented as a dummy reshape that is not trainable
  auto res = Expression<ReshapeNodeOp>(a, a->shape());
  res->setTrainable(false);
  return res;
}

// gather() -- gather arbitrary elements along an axis; batched or non-batched
Expr gather(Expr a, int axis, Expr indices) {
  return Expression<GatherNodeOp>(a, axis, indices);
}

// scatter() -- scatter arbitrary elements along an axis; batched or non-batched
// This is the reverse operation to gather.
Expr scatter(Expr a, int axis, Expr indices, Expr source) {
  return Expression<ScatterNodeOp>(a, axis, indices, source);
}


// index_select() -- gather arbitrary elements along an axis from an unbatched
// input 'a'. Indices are specified as a 1D vector.
// This is used e.g. for embedding lookup.
// Note: To use a batch of index vectors, reshape them into a single vector,
// call index_select(), then reshape the result back. Reshapes are cheap.
// This function has the same semantics as the PyTorch operation of the same name.
Expr index_select(Expr a, int axis, Expr indices) {
  ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor");
  // We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor.
  auto rank = a->shape().size();
  if (rank == 2) {
    if (axis == 0 || axis == -2)
      return Expression<RowsNodeOp>(a, indices);
    else if (axis == -1 || axis == 1)
      return Expression<ColsNodeOp>(a, indices);
  }
  // Delegate to gather() for any other axis or non-matrix input.
  Shape shape;
  shape.resize(a->shape().size());
  shape.set(axis, indices->shape()[0]);
  indices = reshape(indices, shape); // move index to axis
  return gather(a, axis, indices);
}

Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices) {
  auto indexExpr = a->graph()->indices(indices);
  return index_select(a, axis, indexExpr);
}
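
// Usage sketch for index_select() (hypothetical names): a typical embedding lookup gathers
// rows of an embedding matrix E with shape {vocabSize, dimEmb} by a vector of word ids:
//   std::vector<IndexType> ids = {5, 42, 7};
//   auto selected = index_select(E, /*axis=*/0, ids);  // shape {3, dimEmb}
// With a 2-D input and axis 0 this takes the specialized RowsNodeOp path above.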

static Expr sliceCopy(Expr a, int axis, const Slice& slice) { // copy a Slice via gather()
  ABORT_IF(slice.stride < 0, "Negative strides are not supported yet");
  ABORT_IF(slice.begin == slice.end, "Empty slices are not allowed"); // @TODO: Or are they?
  std::vector<IndexType> indices;
  indices.reserve((slice.end - slice.begin - 1) / slice.stride + 1);
  for (int i = slice.begin; i < slice.end; i += slice.stride)
    indices.push_back((IndexType)i);
  return gather(a, axis, a->graph()->indices(indices, a, axis));
}

static Expr sliceView(Expr a, int axis, const Slice& slice) { // view a slice (must be memory-consecutive)
  return Expression<SliceViewNodeOp>(a, axis, slice);
}

// slice() -- gather a slice along an axis (step size > 1 allowed)
Expr slice(Expr a, int axis, Slice slice) { // numpy __getslice__ semantics, but with axis parameter
  const auto& shape = a->shape();
  axis  = shape.axis(axis);         // normalize negative axis
  slice = shape.slice(slice, axis); // normalize negative slice values
  if (slice.begin == 0 && slice.end == shape[axis] && slice.stride == 1)
    return a; // it's a no-op
#if 1 // until strided views are supported, non-consecutive slices are implemented via gather()
  if (slice.stride != 1)
    return sliceCopy(a, axis, slice);
  for (int i = 0; i < axis; ++i) {
    if (shape[i] != 1)  // this makes it non-consecutive
      return sliceCopy(a, axis, slice);
  }
#endif
  return sliceView(a, axis, slice);
}
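
// Behavior sketch for slice() (assuming Slice(begin, end, stride) construction; names are
// placeholders):
//   slice(x,  0, Slice(0, 5))     // stride 1, no preceding axes -> cheap SliceViewNodeOp
//   slice(x,  0, Slice(1, 4, 2))  // stride != 1 -> materialized via sliceCopy()/gather()
//   slice(x, -1, Slice(0, 5))     // last axis: a view only if all preceding axes have
//                                 // size 1, otherwise it is copied as well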

Expr sum(Expr a, int ax) {
  if(a->shape()[ax] == 1) // nothing to reduce, sum of itself is a
    return a;
  return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::sum);
}

Expr mean(Expr a, int ax) {
  if(a->shape()[ax] == 1) // nothing to reduce, mean of itself is a
    return a;
  return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::mean);
}

Expr std(Expr a, int ax) {
  if(a->shape()[ax] == 1) // nothing to reduce, std(a) = 0
    return a - a;
  return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::rms);
}

Expr var(Expr a, int ax) {
  if(a->shape()[ax] == 1) // nothing to reduce, var(a) = 0
    return a - a;
  return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::meanSqr);
}

Expr max(Expr a, int ax) {
  if(a->shape()[ax] == 1) // nothing to reduce, max of itself is a
    return a;
  return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::max);
}

Expr min(Expr a, int ax) {
  if(a->shape()[ax] == 1) // nothing to reduce, min of itself is a
    return a;
  return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::min);
}

Expr prod(Expr a, int ax) {
  if(a->shape()[ax] == 1) // nothing to reduce, prod of itself is a
    return a;
  return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::prod);
}

// log(sum(exp(a)))
Expr logsumexp(Expr a, int ax) {
  if(a->shape()[ax] == 1) // nothing to reduce, log(sum(exp(a))) = log(exp(a)) = a
    return a;
  return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::logSumExp);
}

Expr scalar_product(Expr a, Expr b, int ax) {
  return Expression<ScalarProductNodeOp>(a, b, ax);
}

Expr weighted_average(Expr in, Expr weights, int ax) {
  auto p = scalar_product(in, weights, ax);
  auto s = sum(weights, ax);
  return p / s;
}
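
// Equivalently, weighted_average(in, w, ax) computes sum(in * w, ax) / sum(w, ax), assuming
// scalar_product() fuses the elementwise product and the reduction over ax into one node.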

Expr dot(Expr a, Expr b, bool transA, bool transB, float scale) {
  auto device = a->graph()->getDeviceId().type;
  // added support for packed GEMM API (fp16, int8)
  Type aElementType = a->value_type();
  Type bElementType = b->value_type();

  // Currently only true when the command line options
  // --optimize --cpu-threads=N with N > 0 are set.
  if(device == DeviceType::cpu) {
    if(isFloat(aElementType) && isFloat(bElementType)) {
      if(b->memoize() && (a->graph()->getBackend()->getGemmType() == GemmType::FbFp16Packed ||
        a->graph()->getBackend()->getGemmType() == GemmType::FbInt8Packed)) {
#if USE_FBGEMM
        if(a->graph()->getBackend()->getGemmType() == GemmType::FbFp16Packed) {
          auto packedB = cpu::variant::pack(
              marian::Type::packed16, b, cpu::variant::PackMatrix::B, transB);
          return cpu::variant::dot(marian::Type::packed16,
              a, packedB, b->shape(), transA, transB, scale);
        } else {
          float quantizeRange = b->graph()->getBackend()->getQuantizeRange();
          if(fbgemm::fbgemmHasAvx512Support()) {
            auto packedB = cpu::variant::pack(marian::Type::packed8avx512,
                                              b,
                                              cpu::variant::PackMatrix::B,
                                              transB,
                                              quantizeRange);
            return cpu::variant::dot(marian::Type::packed8avx512,
                a, packedB, b->shape(), transA, transB, scale);
          } else if(fbgemm::fbgemmHasAvx2Support()) {
            auto packedB = cpu::variant::pack(marian::Type::packed8avx2,
                                              b,
                                              cpu::variant::PackMatrix::B,
                                              transB,
                                              quantizeRange);
            return cpu::variant::dot(marian::Type::packed8avx2,
                a, packedB, b->shape(), transA, transB, scale);
          } else {
            ABORT(
                "AVX2 is not available. At least, AVX2 is needed to use fbgemm-based packed "
                "GEMM");
          }
        }
#else
        ABORT("Packed GEMM is not available in this build");
#endif  // USE_FBGEMM
      } else {
        return Expression<DotNodeOp>(
          a, b, transA, transB, scale);
      }
    } else if(isFloat(aElementType) && isIntgemm(bElementType)) {
      return cpu::integer::affineOrDot(a, b, nullptr, transA, transB, scale);
    } else if(isFloat(aElementType) && isPacked(bElementType)) {
#if USE_FBGEMM
      // 07/10/2019 - Use packed GEMM only if the CPU architecture supports AVX2.
      // AVX2 support is detected by one of fbgemm's submodules, cpuinfo
      // (https://github.com/pytorch/cpuinfo). It looks at the CPU registers
      // (https://github.com/pytorch/cpuinfo/blob/master/src/x86/isa.c#L391);
      // this lookup is executed only once and the result is cached inside FBGEMM.
      if(fbgemm::fbgemmHasAvx2Support()) {
        // This variant of dot product can handle matrix multiplications with packed8 and packed16 weight matrix (B).
        return cpu::variant::dot(b->value_type(),
                                 a,
                                 b,
                                 b->shape(),
                                 transA,
                                 transB,
                                 scale);
      } else {
        ABORT("AVX2 is not available. At least, AVX2 is needed to use fbgemm-based packed GEMM");
      }
#else
      ABORT("Packed GEMM is not available in this build");
#endif  // USE_FBGEMM
    } else {
      ABORT("Combination of types A: {} B: {} not supported", aElementType, bElementType);
    }
  } else {
    return Expression<DotNodeOp>(a, b, transA, transB, scale);
  }
}

Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
  return Expression<DotBatchedNodeOp>(a, b, transA, transB, scale);
}

Expr bdot_legacy(Expr a, Expr b, bool transA, bool transB, float scale) {
  return Expression<DotBatchedLegacyNodeOp>(a, b, transA, transB, scale);
}

Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
  // general version, MKL, CBlas or CUDA

  int rows = a->shape().elements() / a->shape()[-1];
  Expr ones = a->graph()->ones({ rows, 1 });
  std::vector<Expr> nodes = { a, b, bias, ones };
  return Expression<AffineNodeOp>(nodes, transA, transB, scale);
}

// This operation used to support auto-tuning. We have removed it for now due to complexity, but plan to revisit it in the future.
// The last branch with the auto-tuner is:
// youki/packed-model-pr-backup1031
// https://machinetranslation.visualstudio.com/Marian/_git/marian-dev?version=GByouki%2Fpacked-model-pr-backup1031
// SHA: 3456a7ed1d1608cfad74cd2c414e7e8fe141aa52
Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
  auto device = a->graph()->getDeviceId().type;

  Type aElementType = a->value_type();
  Type bElementType = b->value_type();

  if(device == DeviceType::cpu) {
    if(isFloat(aElementType) && isFloat(bElementType)) {
      if(a->graph()->getBackend()->isOptimized()) {
        if(b->memoize() && (a->graph()->getBackend()->getGemmType() == GemmType::FbFp16Packed ||
          a->graph()->getBackend()->getGemmType() == GemmType::FbInt8Packed)) {
#if USE_FBGEMM
          if(a->graph()->getBackend()->getGemmType() == GemmType::FbFp16Packed) {
            auto packedB = cpu::variant::pack(
                marian::Type::packed16, b, cpu::variant::PackMatrix::B, transB);
            return cpu::variant::affine(marian::Type::packed16,
                a, packedB, b->shape(), bias, transA, transB, scale);
          } else {
            float quantizeRange = b->graph()->getBackend()->getQuantizeRange();
            if(fbgemm::fbgemmHasAvx512Support()) {
              auto packedB = cpu::variant::pack(marian::Type::packed8avx512,
                                                b,
                                                cpu::variant::PackMatrix::B,
                                                transB,
                                                quantizeRange);
              return cpu::variant::affine(marian::Type::packed8avx512,
                  a, packedB, b->shape(), bias, transA, transB, scale);
            } else if(fbgemm::fbgemmHasAvx2Support()) {
              auto packedB = cpu::variant::pack(marian::Type::packed8avx2,
                                                b,
                                                cpu::variant::PackMatrix::B,
                                                transB,
                                                quantizeRange);
              return cpu::variant::affine(marian::Type::packed8avx2,
                  a, packedB, b->shape(), bias, transA, transB, scale);
            } else {
              ABORT(
                  "AVX2 is not available. At least, AVX2 is needed to use fbgemm-based packed "
                  "GEMM");
            }
          }
#else
          ABORT("Packed GEMM is not available in this build");
#endif  // USE_FBGEMM
        } else {
          return affineDefault(a, b, bias, transA, transB, scale);
        }
      } else {
        return affineDefault(a, b, bias, transA, transB, scale);
      }
    } else if(isFloat(aElementType) && isIntgemm(bElementType)) {
      return cpu::integer::affineOrDot(a, b, bias, transA, transB, scale);
    } else if(isFloat(aElementType) && isPacked(bElementType)) {
#if USE_FBGEMM
      // 07/10/2019 - Use packed GEMM only if the CPU architecture supports AVX2.
      // AVX2 support is detected by one of fbgemm's submodules, cpuinfo
      // (https://github.com/pytorch/cpuinfo). It looks at the CPU registers
      // (https://github.com/pytorch/cpuinfo/blob/master/src/x86/isa.c#L391);
      // this lookup is executed only once and the result is cached inside FBGEMM.
      if(fbgemm::fbgemmHasAvx2Support()) {
        // This variant of affine product can handle matrix multiplications with packed8 and packed16 weight matrix (B).
        return cpu::variant::affine(b->value_type(),
                                    a,
                                    b,
                                    b->shape(),
                                    bias,
                                    transA,
                                    transB,
                                    scale);
      } else {
        ABORT("AVX2 is not available. At least, AVX2 is needed to use fbgemm-based packed GEMM");
      }
#else
      ABORT("Packed GEMM is not available in this build");
#endif  // USE_FBGEMM
    } else {
      ABORT("Combination of types A: {} B: {} not supported", aElementType, bElementType);
    }
  } else {
    // Default GEMM
    ABORT_IF(!isFloat(aElementType) || !isFloat(bElementType),
             "GPU-based GEMM only supports float types, you have A: {} and B: {}",
             aElementType, bElementType);
    return affineDefault(a, b, bias, transA, transB, scale);
  }
}

Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
  auto graph = a->graph();

  if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu)
    return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale);
  else
    return relu(affine(a, b, bias, transA, transB, scale));
}

// @TODO: Not a great place to check this
#if CUDA_VERSION < 11000
// multiply a CSR matrix A with a matrix B
// A[i,j] is at A_values[A_offsets[i]+k], where k is the position of j in A_indices[A_offsets[i]:A_offsets[i+1]]
// @TODO: Define a proper sparse tensor type.
Expr csr_dot(const Shape& A_shape, Expr A_values, Expr A_indices, Expr A_offsets, Expr B, bool transA /*= false*/) {
  if(A_values->value_type() == Type::float16)
    LOG_ONCE(warn, "Using a very slow version of sparse matrix operations with an explicit cast to {}. Use CUDA 11.0 or higher.", Type::float16);
  return cast(Expression<CSRDotNodeOp>(A_shape, cast(A_values, Type::float32), A_indices, A_offsets, cast(B, Type::float32), transA, /*swapOperands=*/false), A_values->value_type());
}

// multiply a matrix A with a CSR matrix B
// @TODO: Define a proper sparse tensor type.
Expr dot_csr(Expr A, const Shape& B_shape, Expr B_values, Expr B_indices, Expr B_offsets, bool transB /*= false*/) {
  if(B_values->value_type() == Type::float16)
    LOG_ONCE(warn, "Using a very slow version of sparse matrix operations with an explicit cast to {}. Use CUDA 11.0 or higher.", Type::float16);
  return cast(Expression<CSRDotNodeOp>(B_shape, cast(B_values, Type::float32), B_indices, B_offsets, cast(A, Type::float32), transB, /*swapOperands=*/true), B_values->value_type());
}
#else
// multiply a CSR matrix A with a matrix B
// A[i,j] is at A_values[A_offsets[i]+k], where k is the position of j in A_indices[A_offsets[i]:A_offsets[i+1]]
// @TODO: Define a proper sparse tensor type.
Expr csr_dot(const Shape& A_shape, Expr A_values, Expr A_indices, Expr A_offsets, Expr B, bool transA /*= false*/) {
  // @TODO: implement this without cast
  return Expression<CSRDotNodeOp>(A_shape, A_values, A_indices, A_offsets, B, transA, /*swapOperands=*/false);
}

// multiply a matrix A with a CSR matrix B
// @TODO: Define a proper sparse tensor type.
Expr dot_csr(Expr A, const Shape& B_shape, Expr B_values, Expr B_indices, Expr B_offsets, bool transB /*= false*/) {
  return Expression<CSRDotNodeOp>(B_shape, B_values, B_indices, B_offsets, A, transB, /*swapOperands=*/true);
}
#endif
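
// Worked example of the CSR layout used by csr_dot()/dot_csr() (illustrative values only):
// the 2x3 matrix
//     [[10,  0, 20],
//      [ 0, 30,  0]]
// is stored as
//   A_values  = {10, 20, 30}   // non-zero entries, row by row
//   A_indices = { 0,  2,  1}   // column index of each non-zero entry
//   A_offsets = { 0,  2,  3}   // row i occupies positions [A_offsets[i], A_offsets[i+1])
// so A[i,j] = A_values[A_offsets[i] + k], where k is the position of j within
// A_indices[A_offsets[i] : A_offsets[i+1]], as described in the comment on csr_dot() above.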


// swap the last two axes
// @TODO: change to swapAxes(a, -1, -2)
Expr transpose(Expr a) {
  std::vector<int> axes(a->shape().size());
  for(int i = 0; i < axes.size(); ++i) {
    axes[i] = i;
  }
  if(axes.size() > 1) {
    axes[axes.size() - 1] = (int)axes.size() - 2;
    axes[axes.size() - 2] = (int)axes.size() - 1;
  }
  return Expression<TransposeNodeOp>(a, axes);
}

Expr transpose(Expr a, const std::vector<int>& axes) {
  return Expression<TransposeNodeOp>(a, axes);
}

Expr swapAxes(Expr x, int axis1, int axis2)
{
  const auto& shape = x->shape();
  axis1 = shape.axis(axis1);
  axis2 = shape.axis(axis2);
  if (axis1 == axis2)
    return x;
  if (shape[axis1] == 1 || shape[axis2] == 1) { // can we use a reshape instead?
    if (axis1 > axis2)
      std::swap(axis1, axis2);
    bool canReshape = true;
    for (int ax = axis1 + 1; ax < axis2 && canReshape; ax++)
      canReshape &= (shape[ax] == 1);
    if (canReshape) {
      auto newShape = shape;
      newShape.set(axis1, shape[axis2]);
      newShape.set(axis2, shape[axis1]);
      //LOG(info, "SwapAxes() did a reshape from {} to {}", shape.toString(), newShape.toString());
      return reshape(x, newShape);
    }
  }
  // TODO: This is code dup from transpose(x). Implement transpose(x) as swapAxes(x, -1, -2)
  std::vector<int> axes(shape.size());
  for (int i = 0; i < axes.size(); ++i) // @TODO: use std::iota()
    axes[i] = i;
  std::swap(axes[axis1], axes[axis2]);
  return transpose(x, axes);
}
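
// Example of the reshape shortcut in swapAxes() (hypothetical shape): for x with shape
// {4, 1, 1, 7}, swapAxes(x, 1, 3) exchanges a singleton axis with axis 3, and every axis
// strictly between them also has size 1, so the memory layout is unchanged and the call
// reduces to reshape(x, {4, 7, 1, 1}); no TransposeNodeOp is created.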

Expr cast(Expr a, Type type) {
  if(a->value_type() == type) {
    return a; // it's the correct type already, so nothing to do here
  } else {
    return Expression<CastNodeOp>(a, type);
  }
}

Expr cross_entropy(Expr logits, Expr indices, float labelSmoothingAlpha, Type outputType) {
  return Expression<CrossEntropyNodeOp>(logits, indices, labelSmoothingAlpha, outputType);
}

// Unlikelihood loss based on https://arxiv.org/abs/1908.04319
Expr unlikelihood(Expr logits, Expr indices) {
  int dimBatch = logits->shape()[-2];
  int dimTime  = logits->shape()[-3];

  // @TODO: fix this outside of this function in decoder.h etc.
  auto indicesWithLayout = reshape(indices, {1, dimTime, dimBatch, 1});

  // This is currently implemented with multiple ops; it might be worth adding a dedicated operator, as for cross_entropy
  return -log(gather(1.f - softmax(logits), /*axis=*/-1, indicesWithLayout));
}
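
// In formula form, unlikelihood() returns -log(1 - p(y)) per position, where p(y) is the
// model probability of the token selected by `indices`; this is the unlikelihood term of
// the paper cited above, here applied to the provided indices only.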

Expr plus(const std::vector<Expr>& nodes) {
  ABORT_IF(nodes.size() > 1, "Not implemented");
  return nodes[0];
}

Expr swish(const std::vector<Expr>& nodes) {
  ABORT_IF(nodes.size() > 1, "Not implemented");
  return swish(nodes[0]);
}

Expr gelu(const std::vector<Expr>& nodes) {
  ABORT_IF(nodes.size() > 1, "Not implemented");
  return gelu(nodes[0]);
}

Expr tanh(const std::vector<Expr>& nodes) {
  return Expression<TanhNodeOp>(nodes);
}

Expr sigmoid(const std::vector<Expr>&) {
  ABORT("Not implemented");
}

Expr relu(const std::vector<Expr>& nodes) {
  ABORT_IF(nodes.size() > 1, "Not implemented");
  return relu(nodes[0]);
}

Expr leakyrelu(const std::vector<Expr>&) {
  ABORT("Not implemented");
}

Expr prelu(const std::vector<Expr>&, float /*alpha*/) {
  ABORT("Not implemented");
}

Expr sqrt(Expr a, float eps) {
  return Expression<SqrtNodeOp>(a, eps);
}

Expr square(Expr a) {
  return Expression<SquareNodeOp>(a);
}

Expr layerNorm(Expr x,
               Expr gamma,
               Expr beta /*= nullptr*/,
               float eps /*= 1e-9*/) {

  // layerNorm accumulates in float, so small eps is fine
  std::vector<Expr> nodes = {x, gamma};
  if(beta)
    nodes.push_back(beta);
  return Expression<LayerNormalizationOp>(nodes, eps);
}
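
// Computation sketch for layerNorm() (standard layer normalization over the last axis,
// up to the exact placement of eps):
//   y = gamma * (x - mean(x)) / sqrt(var(x) + eps) + beta   // beta treated as 0 if absent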

Expr rmsNorm(Expr x,
             Expr gamma,
             Expr beta /*= nullptr*/,
             float eps /*= 1e-9*/) {

  // rmsNorm accumulates in float, so small eps is fine
  std::vector<Expr> nodes = {x, gamma};
  if(beta)
    nodes.push_back(beta);
  return Expression<RMSNormalizationOp>(nodes, eps);
}
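
// Computation sketch for rmsNorm() (RMSNorm, Zhang & Sennrich 2019, over the last axis,
// up to the exact placement of eps):
//   y = gamma * x / sqrt(mean(x^2) + eps)   // plus beta if given; no mean subtraction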

Expr highway(Expr y, Expr x, Expr t) {
  std::vector<Expr> nodes = {y, x, t};
  return Expression<HighwayNodeOp>(nodes);
}

Expr highway(const std::string prefix, Expr x) {
  // clang-format off
  size_t outDim = x->shape()[-1];
  auto graph = x->graph();
  auto g = mlp::dense()
      ("prefix", prefix + "_highway_d1")
      ("dim", outDim)
      ("activation", (int)mlp::act::sigmoid)
      .construct(graph)->apply(x);
  auto relued = mlp::dense()
      ("prefix", prefix + "_highway_d2")
      ("dim", outDim)
      ("activation", (int)mlp::act::ReLU)
      .construct(graph)->apply(x);
  return (g * relued) + ((1 - g) * x);
  // clang-format on
}

Expr shift(Expr a, Shape shift, float padValue) {
  return Expression<ShiftNodeOp>(a, shift, padValue);
}

#ifdef CUDA_FOUND
#ifdef CUDNN

Expr avg_pooling(Expr x,
                 int height,
                 int width,
                 int padHeight,
                 int padWidth,
                 int strideHeight,
                 int strideWidth) {
  return Expression<PoolingOp>(
      x, height, width, padHeight, padWidth, strideHeight, strideWidth, "avg");
}

Expr max_pooling(Expr x,
                 int height,
                 int width,
                 int padHeight,
                 int padWidth,
                 int strideHeight,
                 int strideWidth) {
  return Expression<PoolingOp>(
      x, height, width, padHeight, padWidth, strideHeight, strideWidth, "max");
}

Expr convert2cudnnFormat(Expr x) {
  int numWords = x->shape()[0];
  int numExamples = x->shape()[1];
  int embSize = x->shape()[2];

  std::vector<IndexType> newIndeces;
  for(int b = 0; b < numExamples; ++b) {
    for(int t = 0; t < numWords; ++t) {
      newIndeces.push_back((t * numExamples) + b);
    }
  }

  auto xRows = reshape(x, {x->shape()[0] * x->shape()[1], x->shape()[2]});

  Shape outShape({numExamples, 1, numWords, embSize});
  return reshape(rows(xRows, newIndeces), outShape);
}

Expr convertFromcudnnFormat(Expr x) {
  int batchDim = x->shape()[0];
  int sentenceDim = x->shape()[2];
  int embSize = x->shape()[3];

  auto reshapedX = reshape(x, {batchDim * sentenceDim, embSize});

  std::vector<IndexType> newIndeces;
  for(int t = 0; t < sentenceDim; ++t) {
    for(int b = 0; b < batchDim; ++b) {
      newIndeces.push_back(b * sentenceDim + t);
    }
  }

  Shape shape({batchDim, sentenceDim, embSize});
  return reshape(rows(reshapedX, newIndeces), shape);
}

Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven) {
  return Expression<PoolingWithMaskingOp>(x, mask, width, isEven);
}

#endif
#endif
}  // namespace marian