Template Struct TypedSparseGemm

Struct Documentation

template<typename ElementType>
struct TypedSparseGemm

Public Static Functions

static cudaDataType getCudaDataType(const float *)
static cudaDataType getCudaDataType(const half *)
static void CSRProd(marian::Tensor C, Ptr<Allocator> allocator, const marian::Tensor &S_values, const marian::Tensor &S_indices, const marian::Tensor &S_offsets, const marian::Tensor &D, bool transS, ElementType beta)
static void CSRProd(marian::Tensor C, Ptr<Allocator> allocator, const marian::Tensor &S_values, const marian::Tensor &S_indices, const marian::Tensor &S_offsets, const marian::Tensor &D, bool transS, bool swapOperands, ElementType beta)