Program Listing for File expanded_gemm.h

Source: src/tensors/cpu/fbgemm/expanded_gemm.h

#pragma once

#include "graph/node.h"
#include "packed_gemm.h"
#include "tensors/cpu/integer_common.h"

#if USE_FBGEMM
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-variable"
#endif

#include "3rd_party/fbgemm/include/fbgemm/FbgemmFP16.h"

#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

using namespace fbgemm;
// @TODO: don't use using namespace ...; in header files. Just don't. [UG]
#endif  // USE_FBGEMM

namespace marian {
namespace cpu {
namespace variant {

// Enumeration for the matrix operand used in the pack functions
// A matrix - 0, B matrix - 1
enum class PackMatrix : uint8_t {
  A = 0x00,
  B = 0x01
};

// Pack a matrix into a cache-efficient block format, converting (quantizing) its elements to fp16 in the process
// PackMatrix packMat_: the type of packed matrix - A or B matrix
// bool transpose_: whether the matrix is transposed before packing
// int nrow_: the number of rows
// int ncol_: the number of columns
// int kernel_ncol_blocks_: the number of column blocks processed at once by the GEMM kernel
// int brow_: the number of rows in a block
// int bcol_: the number of columns in a block
// int last_brow_: the number of rows in the last block
// int nbrow_: the number of blocks along the row dimension
// int nbcol_: the number of blocks along the column dimension
// uint64_t packsize_: the size of the packed matrix
//                    (the number of fp16 elements + padding (1024) + extra temporary memory (256))
struct FbgemmPacked16PackNodeOp : public UnaryNodeOp {
  PackMatrix packMat_;
  bool transpose_;
  int nrow_;
  int ncol_;
  int kernel_ncol_blocks_;
  int brow_;
  int bcol_;
  int last_brow_;
  int nbrow_;
  int nbcol_;
  uint64_t packsize_;

  FbgemmPacked16PackNodeOp(Expr a, PackMatrix packMat, bool transpose)
      : UnaryNodeOp(a, newShape(a, transpose), Type::uint8),
        packMat_(packMat),
        transpose_(transpose) {
    if(packMat != PackMatrix::B)
      ABORT("Only prepacking of B (weight matrix) is supported");
    if(!memoize_)
      ABORT("Only constant weight node can be packed");
  }

  NodeOps forwardOps() override {
#if USE_FBGEMM
    return {NodeOp(fbgemmPacked16Pack(val_,
                                      child(0)->val()->data(),
                                      transpose_,
                                      nrow_,
                                      ncol_,
                                      kernel_ncol_blocks_,
                                      brow_,
                                      bcol_,
                                      last_brow_,
                                      nbrow_,
                                      nbcol_,
                                      packsize_))
    };
#else // USE_FBGEMM
    ABORT("FbgemmPacked16PackNodeOp can only be used with FBGEMM enabled.");
    return { NodeOp(0) };
#endif  // USE_FBGEMM
  }

  NodeOps backwardOps() override {
    ABORT("FbgemmPacked16PackNodeOp only available for inference");
    return {NodeOp(0)};
  }

  const std::string type() override { return "packMatFp16"; }

  Shape newShape(Expr a, bool transpose) {
#if USE_FBGEMM
    auto shapeMat = a->shape();
    // Should be 2D - weight matrix
    ABORT_IF(shapeMat.size() != 2,
             "Weight matrix should be 2D");
    fbgemmPacked16PackInfo(shapeMat,
                           transpose,
                           nrow_,
                           ncol_,
                           kernel_ncol_blocks_,
                           brow_,
                           bcol_,
                           last_brow_,
                           nbrow_,
                           nbcol_,
                           packsize_);

    Shape outShape({(int)packsize_});
    return outShape;
#else
    a; transpose; // silence unused-parameter warnings
    ABORT("Packed GEMM requires a build with USE_FBGEMM enabled");
    return Shape();
#endif  // USE_FBGEMM
  }
};

// Pack a matrix into a cache-efficient block format, quantizing its elements to int8 in the process
// PackMatrix packMat_: the type of packed matrix - A or B matrix
// marian::Type packType_: the target packed type - packed8avx2 or packed8avx512
// bool transpose_: whether the matrix is transposed before packing
// int nrow_: the number of rows
// int ncol_: the number of columns
// uint64_t packsize_: the size of the packed matrix
//                    (the size of the int8 packed B from fbgemm::PackAWithQuantRowOffset + quantization scale, offset and zero point)
// float quantizeRange_: the value range used for quantization (0.f: derive the range from the data)
struct FbgemmPacked8PackNodeOp : public UnaryNodeOp {
  PackMatrix packMat_;
  marian::Type packType_;
  bool transpose_;
  int nrow_;
  int ncol_;
  uint64_t packsize_;
  float quantizeRange_;

  FbgemmPacked8PackNodeOp(Expr a,
                          PackMatrix packMat,
                          marian::Type packType,
                          bool transpose,
                          float quantizeRange)
      : UnaryNodeOp(a, newShape(a, packType, transpose), Type::uint8),
        packMat_(packMat),
        packType_(packType),
        transpose_(transpose),
        quantizeRange_(quantizeRange) {
    if(packMat != PackMatrix::B)
      ABORT("Only prepacking of B (weight matrix) is supported");
    if(!memoize_)
      ABORT("Only constant weight node can be packed");
  }

  NodeOps forwardOps() override {
#if USE_FBGEMM
    return {NodeOp(fbgemmPacked8Pack(val_,
                                     child(0)->val()->data(),
                                     packType_,
                                     transpose_,
                                     nrow_,
                                     ncol_,
                                     packsize_,
                                     quantizeRange_))
    };
#else // USE_FBGEMM
    ABORT("FbgemmPacked8PackNodeOp can only be used with FBGEMM enabled.");
    return { NodeOp(0) };
#endif  // USE_FBGEMM
  }

  NodeOps backwardOps() override {
    ABORT("FbgemmPacked8PackNodeOp only available for inference");
    return {NodeOp(0)};
  }

  const std::string type() override { return "packMatInt8"; }

#if USE_FBGEMM
  Shape newShape(Expr a, marian::Type packType, bool transpose) {
    fbgemmPacked8PackInfo(
        a->shape(),
        packType,
        transpose,
        nrow_,
        ncol_,
        packsize_);
    Shape outShape({(int)packsize_});
    return outShape;
  }
#else
  Shape newShape(Expr /*a*/, marian::Type /*packType*/, bool /*transpose*/) {
    ABORT("Packed GEMM requires a build with USE_FBGEMM enabled");
    return Shape();
  }
#endif  // USE_FBGEMM
};


// Affine transform (matrix multiplication) using a packed B matrix
// float scalar_: scalar multiplier (currently unused)
// size_t m_: the number of rows in A and C
// size_t n_: the number of columns in B and C
// size_t k_: the number of columns in A and the number of rows in B
// bool transA_: transpose A
// bool transB_: transpose B
class FbgemmPacked16AffineNodeOp : public NaryNodeOp {
private:
  size_t m_;
  size_t n_;
  size_t k_;
  bool transA_;
  bool transB_;

public:
  FbgemmPacked16AffineNodeOp(const std::vector<Expr>& nodes, Shape bShape, bool transA, bool transB, float /*scalar*/)
    : NaryNodeOp(nodes, newShape(nodes[0], bShape, transA, transB), Type::float32)/*, scalar_(scalar)*/ {
    transA_ = transA;
    transB_ = transB;
    m_ = nodes[0]->shape().elements() / nodes[0]->shape()[-1];
    k_ = nodes[0]->shape().back();
    if(transA)
      std::swap(m_, k_);

    size_t l = bShape.elements() / bShape[-1];
    n_ = bShape[-1];
    if(transB)
      std::swap(l, n_);
  }

  Shape newShape(Expr a, Shape bShape, bool transA, bool transB) {
    auto shapeA = a->shape();
    if(transA) {
      shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
      shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
    }

    auto shapeB = bShape;
    if(transB) {
      shapeB.set(shapeB.size() - 2, bShape[shapeB.size() - 1]);
      shapeB.set(shapeB.size() - 1, bShape[shapeB.size() - 2]);
    }

    Shape outShape = shapeA;
    outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
    ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
             "Matrix product requires inner dimensions to match");
    return outShape;
  }
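
  // Shape sketch (illustrative values, not from the source): with A of shape
  // {batch, 4, 512}, an unpacked B of shape {512, 1024}, and no transposition,
  // the inner dimensions match (512 == 512) and the result shape is
  // {batch, 4, 1024}.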

  NodeOps forwardOps() override {
#if USE_FBGEMM
    return {
      NodeOp(fbgemmPacked16Gemm(val_,
                                child(0)->val(),
                                child(1)->val(),
                                children().size() > 2 ? child(2)->val() : nullptr, // pass only if it has a bias
                                m_,
                                n_,
                                transA_))
    };
#else // USE_FBGEMM
    ABORT("FbgemmPacked16AffineNodeOp can only be used with FBGEMM enabled.");
    return { NodeOp(0) };
#endif  // USE_FBGEMM
  }

  NodeOps backwardOps() override {
    ABORT("Only used for inference");
    return {NodeOp(0)};
  }

  const std::string type() override { return "gemmPacked16"; }
};

// Affine transform (matrix multiplication) using a packed B matrix.
// In particular, this GEMM performs the multiplication in quantized 8-bit integers.
// float scalar_: scalar multiplier (currently unused)
// size_t m_: the number of rows in A and C
// size_t n_: the number of columns in B and C
// size_t k_: the number of columns in A and the number of rows in B
// bool transA_: transpose A
// bool transB_: transpose B
class FbgemmPacked8AffineNodeOp : public NaryNodeOp {
private:
  size_t m_;
  size_t n_;
  size_t k_;
  bool transA_;
  bool transB_;

#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-private-field"
#endif

  Type elementType_;

#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif


public:
  FbgemmPacked8AffineNodeOp(Type elementType,
                            const std::vector<Expr>& nodes,
                            Shape bShape,
                            bool transA,
                            bool transB,
                            float /*scalar*/)
      : NaryNodeOp(nodes, newShape(nodes[0], bShape, transA, transB), Type::float32),
        elementType_(elementType) {
    transA_ = transA;
    transB_ = transB;
    m_ = nodes[0]->shape().elements() / nodes[0]->shape()[-1];
    k_ = nodes[0]->shape().back();
    if(transA)
      std::swap(m_, k_);

    size_t l = bShape.elements() / bShape[-1];
    n_ = bShape[-1];
    if(transB)
      std::swap(l, n_);
  }

  Shape newShape(Expr a, Shape bShape, bool transA, bool transB) {
    auto shapeA = a->shape();
    if(transA) {
      shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
      shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
    }

    auto shapeB = bShape;
    if(transB) {
      shapeB.set(shapeB.size() - 2, bShape[shapeB.size() - 1]);
      shapeB.set(shapeB.size() - 1, bShape[shapeB.size() - 2]);
    }

    Shape outShape = shapeA;
    outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
    ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
             "Matrix product requires inner dimensions to match");
    return outShape;
  }

  NodeOps forwardOps() override {
    NodeOps nodeOps;
#if USE_FBGEMM
    // Do addBias only if it has a bias term
    if (children().size() > 2) {
      nodeOps = { NodeOp(fbgemmPacked8Gemm(elementType_,
                                           val_,
                                           child(0)->val(),
                                           child(1)->val(),
                                           m_,
                                           n_,
                                           k_,
                                           transA_,
                                           transB_);
                       marian::cpu::integer::AddBias(val_, child(2)->val())) };
    } else {
      nodeOps = { NodeOp(fbgemmPacked8Gemm(elementType_,
                                           val_,
                                           child(0)->val(),
                                           child(1)->val(),
                                           m_,
                                           n_,
                                           k_,
                                           transA_,
                                           transB_)) };
    }
#else // USE_FBGEMM
    ABORT("FbgemmPacked8AffineNodeOp can only be used with FBGEMM enabled.");
#endif  // USE_FBGEMM

    return nodeOps;
  }

  NodeOps backwardOps() override {
    ABORT("Only used for inference");
    return {NodeOp(0)};
  }

  const std::string type() override { return "gemmPacked8"; }
};

static inline Expr affine(Type elementType,
                          Expr a,
                          Expr b,
                          Shape bShape,
                          Expr c,
                          bool transA,
                          bool transB,
                          float scalar) {
  std::vector<Expr> nodes = {a, b, c};

  if (elementType == Type::packed16)
    return Expression<FbgemmPacked16AffineNodeOp>(nodes, bShape, transA, transB, scalar);
  else if (isPacked(elementType) && sizeOf(elementType) == 1)
    return Expression<cpu::variant::FbgemmPacked8AffineNodeOp>(
        elementType, nodes, bShape, transA, transB, scalar);
  else {
    ABORT("Only int8 and fp16 are available. {}", elementType);
    return nullptr;
  }
}

static inline Expr pack(Type elementType, Expr a, PackMatrix packMat, bool transpose, float quantizeRange = 0.f) {
  if (elementType == Type::packed16)
    return Expression<FbgemmPacked16PackNodeOp>(a, packMat, transpose);
  else if (isPacked(elementType) && sizeOf(elementType) == 1)
    return Expression<cpu::variant::FbgemmPacked8PackNodeOp>(a, packMat, elementType, transpose, quantizeRange);
  else {
    ABORT("Only int8 and fp16 are available. {}", elementType);
    return nullptr;
  }
}
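
// Usage sketch for pack() (illustrative only; `graph` and `w` are assumptions,
// not part of this header). A constant weight matrix is prepacked once into
// the int8 block format; quantizeRange = 0.f leaves the quantization range to
// be derived from the data:
//
//   auto w       = graph->constant({512, 1024}, inits::glorotUniform());
//   auto packedW = cpu::variant::pack(Type::packed8avx2, w, PackMatrix::B,
//                                     /*transpose=*/false, /*quantizeRange=*/0.f);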

static inline Expr dot(Type elementType, Expr a, Expr b, Shape bShape, bool transA, bool transB, float scalar) {
  std::vector<Expr> nodes = {a, b};

  if (elementType == Type::packed16)
    return Expression<FbgemmPacked16AffineNodeOp>(nodes, bShape, transA, transB, scalar);
  else if (isPacked(elementType) && sizeOf(elementType) == 1)
    return Expression<cpu::variant::FbgemmPacked8AffineNodeOp>(
        elementType, nodes, bShape, transA, transB, scalar);
  else {
    ABORT("Only int8 and fp16 are available. {}", elementType);
    return nullptr;
  }
}
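
// Usage sketch for dot()/affine() (illustrative only; `x`, `bias`, and `packedW`
// follow from the pack() example above). The original, unpacked shape of B must
// be passed explicitly, since the packed node itself only carries the flat
// buffer shape {packsize_}:
//
//   auto y = cpu::variant::dot(Type::packed8avx2, x, packedW, {512, 1024},
//                              /*transA=*/false, /*transB=*/false, /*scalar=*/1.0f);
//   auto z = cpu::variant::affine(Type::packed8avx2, x, packedW, {512, 1024}, bias,
//                                 /*transA=*/false, /*transB=*/false, /*scalar=*/1.0f);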

}  // namespace variant
}  // namespace cpu
}  // namespace marian