Program Listing for File prod_sparse.cpp¶
↰ Return to documentation for file (src/tensors/gpu/prod_sparse.cpp
)
#ifdef _MSC_VER
#pragma warning(disable: 4505) // warning C4505: '__float2half_rz': unreferenced local function has been removed (missing 'static inline')
#endif
#include <cublas_v2.h>
#include <cusparse.h>
// clang-format off
#include "tensors/gpu/prod.h"
#include "tensors/gpu/backend.h"
#include "tensors/gpu/cuda_helpers.h"
// clang-format on
// what a nightmare
#if CUDA_VERSION >= 11000
#include "tensors/gpu/prod_sparse_cu11.h"
#else
#include "tensors/gpu/prod_sparse_cu10.h"
#endif
namespace marian {
namespace gpu {
void CSRProd(marian::Tensor C,
Ptr<Allocator> allocator,
const marian::Tensor& S_values,
const marian::Tensor& S_indices,
const marian::Tensor& S_offsets,
const marian::Tensor& D,
bool transS,
bool swapOperands,
float beta) {
if(S_values->type() == Type::float32 && D->type() == Type::float32) {
TypedSparseGemm</*ElementType=*/float>::CSRProd(C, allocator, S_values, S_indices, S_offsets, D, transS, swapOperands, beta);
#if COMPILE_FP16
} else if(S_values->type() == Type::float16 && D->type() == Type::float16) {
TypedSparseGemm</*ElementType=*/half>::CSRProd(C, allocator, S_values, S_indices, S_offsets, D, transS, swapOperands, (half)beta);
#endif
} else {
ABORT("Types {} and {} are not supported for sparse GEMM operations", S_values->type(), D->type());
}
}
}
}