Program Listing for File element.cu
Return to documentation for file (src/tensors/gpu/element.cu)
#include "tensors/gpu/element.h"
#include "functional/array.h"
#include "functional/functional.h"
#include "functional/tensor.h"
#include "functional/tmp.h"
#include "tensors/gpu/cuda_helpers.h"
namespace marian {
namespace gpu {
// Element-wise kernel: applies `functor` across K tensors and writes the result into tensors[0].
//
// Template parameters:
//   K         - number of tensors, including the output at position 0.
//   broadcast - when true, the index into each input tensor is derived from the output
//               position via shape().dims() / shape().bindex(); when false, every tensor
//               is addressed with the same linear index as the output.
//   Functor   - callable evaluated through functional::applyWithCast.
//   T         - element type of the tensors.
//
// Launch layout: 1-D grid of 1-D blocks. Any grid size is accepted; each thread starts at its
// global thread id and advances by the total number of launched threads until `length` is passed.
template <size_t K, bool broadcast, class Functor, typename T>
__global__ void gElement(
    Functor functor,
    functional::Array<functional::Tensor<T>, K> tensors) {
  const int length = tensors[0].shape().elements();

  functional::Array<int, functional::Shape::size()> dims;
  functional::Array<int, K> indices;

  // The loop position is kept in 64 bits. With a 32-bit counter, advancing by the grid stride
  // can wrap to a negative value when `length` is within one stride of INT_MAX; the negative
  // value still compares less than `length` and would be used as an out-of-bounds index.
  const long long stride = (long long)blockDim.x * gridDim.x;
  const long long first  = (long long)blockDim.x * blockIdx.x + threadIdx.x;

  for(long long pos = first; pos < length; pos += stride) {
    const int index = (int)pos;  // safe: 0 <= pos < length, and length is an int

    indices.fill(index);

    if(broadcast) {
      // Decompose the output position into per-dimension coordinates, then let every input
      // shape turn those coordinates into its own linear index.
      tensors[0].shape().dims(index, dims);
      for(int i = 1; i < K; ++i)
        indices[i] = tensors[i].shape().bindex(dims);
    }

    // This performs the internal application of the functor in float32 regardless of the input type.
    // It seems there are no speed penalties but improved precision.
    tensors[0].data()[index] = (T)functional::applyWithCast<float>(functor, tensors, indices);
  }
}
// Host-side launcher for gElement with a fixed element type T.
//
// Selects the device that owns `out`, derives a 1-D launch configuration from the number of
// output elements (capped by MAX_THREADS / MAX_BLOCKS), wraps `out` and `tensors...` into
// functional tensors, and launches either the broadcasting or the non-broadcasting kernel
// variant depending on whether any input shape differs from the output shape.
//
// Parameters:
//   functor - element-wise operation to apply.
//   out     - output tensor; its shape determines the iteration space.
//   tensors - input tensors.
template <typename T, class Functor, class... Tensors>
void ElementTyped(Functor functor, Tensor out, Tensors... tensors) {
  //matchOrAbort<T>(out->type()); // @TODO: figure out undefined reference

  cudaSetDevice(out->getDeviceId().no);

  int length = out->shape().elements();

  // An empty output means there is nothing to compute. Returning here also prevents the
  // integer division by zero below (threads would be 0) and a launch with an empty grid.
  if(length == 0)
    return;

  int threads = std::min(MAX_THREADS, length);
  // Ceiling division: one extra block when length is not a multiple of threads.
  int blocks  = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));

  constexpr size_t K = sizeof...(tensors) + 1;
  functional::Array<functional::Tensor<T>, K> gTensors = {out, tensors...};

  // Broadcasting is required as soon as one input shape differs from the output shape.
  bool broadcast = false;
  for(int i = 1; i < K; ++i)
    broadcast = broadcast || gTensors[0].shape() != gTensors[i].shape();

  if(broadcast)
    gElement<K, true><<<blocks, threads>>>(functor, gTensors);
  else
    gElement<K, false><<<blocks, threads>>>(functor, gTensors);
}
// Type-dispatching entry point for element-wise GPU operations.
//
// Calls checkCommonType on all tensors (presumably verifying that they share one element
// type - confirm against its definition), then forwards to the ElementTyped instantiation
// matching the element type of `out`. float32 and float64 are always handled; float16 is
// handled only when COMPILE_FP16 is enabled. Every other type aborts.
template <class Functor, class... Tensors>
void Element(Functor functor, Tensor out, Tensors... tensors) {
  checkCommonType(out, tensors...);

  const auto outType = out->type();

  if(outType == Type::float32) {
    ElementTyped<float>(functor, out, tensors...);
    return;
  }

  if(outType == Type::float16) {
#if COMPILE_FP16
    ElementTyped<half>(functor, out, tensors...);
#else
    ABORT("FP16 not supported with chosen current hardware or CUDA version");
#endif
    return;
  }

  if(outType == Type::float64) {
    ElementTyped<double>(functor, out, tensors...);
    return;
  }

  ABORT("Type {} not yet supported", outType);
}
#include "tensors/gpu/element.inc"
} // namespace gpu
} // namespace marian