Program Listing for File intgemm_interface.h

Return to documentation for file (src/tensors/cpu/intgemm_interface.h)

#pragma once

#include "graph/node.h"
#include "graph/node_operators_unary.h"
#include "integer_common.h"

namespace marian {

namespace cpu {
namespace integer {

#if COMPILE_CPU
/*
 * Prepare an activation matrix into intgemm8/16 format. For now the activation matrix is just quantized.
 * Expr a: The input activation tensor
 */
template<Type vtype>
static inline Expr prepareA(Expr a) {
  auto nodeOp = [](Expr out, const std::vector<Expr>& children) {
    Expr in = children[0];
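    // Compute the quantization multiplier that maps the float32 activations into the
    // integer range of the chosen intgemm width (for the 8-bit variants this is
    // typically derived from the maximum absolute value of the tensor).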
    auto quantMult = computeQuantMult<vtype>(in->val());
    typedef typename intgemm_<vtype>::type Integer;
    intgemm_<vtype>::width::PrepareA(in->val()->data(), /*input*/
                                     out->val()->data<Integer>(), /*output*/
                                     quantMult, /*Quant Mult*/
                                     rows(in->val()),
                                     cols(in->val()));
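    // Remember the multiplier together with the prepared tensor so that the matrix
    // product can later undo the quantization of its result.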
    getQuantMult<vtype>(out->val()) = quantMult;
  };

  return lambda({a}, a->shape(), vtype, nodeOp);
}
#endif

/*
 * This computes A*B (+ bias if available) in intgemm.
 * Expr a: The activation matrix in intgemm format
 * Expr b: The parameter matrix in intgemm format
 * Expr bias: The bias
 * bool transA: transpose input A if true
 * bool transB: unused here (@TODO remove?)
 * float scale: scale the output by `scale`
 * The template argument controls whether we're doing 16-bit or 8-bit integers.
 * It can be Type::intgemm8 or Type::intgemm16, or any of their hardware-specific variants.
 */
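// Illustrative sketch (the names `x`, `Wquant` and `bias` are assumed here, not defined in
// this file): for a float32 activation `x` and a parameter node `Wquant` already stored as
// Type::intgemm8avx2, a call such as
//   Expr y = affineOrDotTyped<Type::intgemm8avx2>(x, Wquant, bias, /*transA=*/false, /*transB=*/false, /*scale=*/1.0f);
// yields a float32 expression whose last dimension equals Wquant->shape()[-1].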
template<Type vtype>
static inline Expr affineOrDotTyped(Expr a, Expr bQuant, Expr bias, bool transA, bool /*transB*/, float scale) {
#if COMPILE_CPU
  ABORT_IF(!isFloat(a->value_type()), "Intgemm expects type of A to be float32 not {}", a->value_type());
  ABORT_IF(!isIntgemm(bQuant->value_type()), "Intgemm expects type of B to be a variant of intgemm not {}", bQuant->value_type());

  auto aQuant = prepareA<vtype>(transA ? transpose(a) : a); // A is still float32 (checked above), so it gets quantized here

  // determine the output shape m x n for A: m x k and B: k x n
  // since we transpose A beforehand we don't need to take care of transposed shapes here
  Shape outShape = aQuant->shape();
  outShape.set(-1, bQuant->shape()[-1]);

  // wrap the multiply functions to be executed in the forward step of a Lambda node
  auto dotOrAffineNodeOp = [=](Expr out, const std::vector<Expr>& children) {
    Expr aQuant = children[0];
    Expr bQuant = children[1];
    Expr bias   = children.size() > 2 ? children[2] : nullptr;
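    // The order of `children` mirrors the vector assembled below: {aQuant, bQuant[, bias]}.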

    // when we arrive here, A and B are already quantized, so just get the multipliers
    float aQuantMult = getQuantMult<vtype>(aQuant->val());
    float bQuantMult = getQuantMult<vtype>(bQuant->val());

    float unquant_mult = 1.0f / (aQuantMult * bQuantMult);
    unquant_mult = unquant_mult * scale;
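    // The integer product approximates (A * aQuantMult) * (B * bQuantMult); scaling it by
    // 1 / (aQuantMult * bQuantMult) recovers the float product A*B, and the extra `scale`
    // factor is folded into the same unquantization step.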

    typedef typename intgemm_<vtype>::type Integer;
    if(bias) { // dispatch a multiply with integrated bias addition i.e affine(...)
      intgemm_<vtype>::width::Multiply(/*A=*/aQuant->val()->data<Integer>(),
                                       /*B=*/bQuant->val()->data<Integer>(),
                                       rows(aQuant->val()),
                                       cols(aQuant->val()),
                                       cols(bQuant->val()),
                                       intgemm::callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, /*bias=*/bias->val()->data(), /*output=*/out->val()->data()));
    } else { // dispatch a multiply without bias addition i.e dot(...)
      intgemm_<vtype>::width::Multiply(/*A=*/aQuant->val()->data<Integer>(),
                                       /*B=*/bQuant->val()->data<Integer>(),
                                       rows(aQuant->val()),
                                       cols(aQuant->val()),
                                       cols(bQuant->val()),
                                       intgemm::callbacks::UnquantizeAndWrite(unquant_mult, /*output=*/out->val()->data()));
    }
  };

  std::vector<Expr> children = {aQuant, bQuant};
  if(bias)
    children.push_back(bias);

  return lambda(children, outShape, Type::float32, dotOrAffineNodeOp); // inference-only Lambda node
#else
  (void)a; (void)bQuant; (void)bias; (void)transA; (void)scale; // silence unused-parameter warnings when CPU compilation is disabled
  ABORT("You need to enable CPU compilation to use this feature. Use cmake .. -DCOMPILE_CPU=ON");
#endif
}

// Dispatch correct hardware-agnostic or hardware-specific matrix multiplies
static inline Expr affineOrDot(Expr a, Expr bQuant, Expr bias, bool transA, bool transB, float scale) {
  Type bQuantElementType = bQuant->value_type();
  static const bool pass = cpu::integer::passOrAbort(bQuantElementType);
  (void)pass; // `pass` is declared static so that passOrAbort is only ever run once, during initialization.
  switch(bQuantElementType) {
    //case Type::intgemm8 :  // The generic case selects CPU automatically, but we set all the types manually anyways.
    //  return cpu::integer::affineOrDotTyped<Type::intgemm8>(a, bQuant, bias, transA, transB, scale);
    case Type::intgemm8ssse3 :
      return cpu::integer::affineOrDotTyped<Type::intgemm8ssse3>(a, bQuant, bias, transA, transB, scale);
    case Type::intgemm8avx2 :
      return cpu::integer::affineOrDotTyped<Type::intgemm8avx2>(a, bQuant, bias, transA, transB, scale);
    case Type::intgemm8avx512 :
      return cpu::integer::affineOrDotTyped<Type::intgemm8avx512>(a, bQuant, bias, transA, transB, scale);
    case Type::intgemm8avx512vnni :
      return cpu::integer::affineOrDotTyped<Type::intgemm8avx512vnni>(a, bQuant, bias, transA, transB, scale);
    //case Type::intgemm16 :  // The generic case selects CPU automatically, but we set all the types manually anyways.
    //  return cpu::integer::affineOrDotTyped<Type::intgemm16>(a, bQuant, bias, transA, transB, scale);
    case Type::intgemm16sse2 :
      return cpu::integer::affineOrDotTyped<Type::intgemm16sse2>(a, bQuant, bias, transA, transB, scale);
    case Type::intgemm16avx2 :
      return cpu::integer::affineOrDotTyped<Type::intgemm16avx2>(a, bQuant, bias, transA, transB, scale);
    case Type::intgemm16avx512 :
      return cpu::integer::affineOrDotTyped<Type::intgemm16avx512>(a, bQuant, bias, transA, transB, scale);
    default:
      ABORT("Unsupported type {} for Intgemm type??", bQuantElementType);
  }
}
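
// Illustrative caller-side sketch (names are assumed, not part of this header): an affine
// layer whose weight `Wquant` has already been packed into an intgemm type would typically call
//   Expr y = cpu::integer::affineOrDot(x, Wquant, b, /*transA=*/false, /*transB=*/false, /*scale=*/1.0f);
// and the switch above selects the matching hardware-specific instantiation.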

}  // namespace integer
}  // namespace cpu
}  // namespace marian