// Program listing for file cudnn_wrappers.h
// (Return to documentation for file src/tensors/gpu/cudnn_wrappers.h)
#pragma once
#include <iostream>
#include "common/shape.h"
#include "tensors/tensor.h"
#ifdef CUDNN
#include <cudnn.h>
namespace marian {
/**
 * Base class for the cuDNN-backed wrappers.
 *
 * Holds the `cudnnHandle_t` that derived classes use. It also provides helpers
 * to fill a cuDNN tensor descriptor from a Tensor or a Shape. The constructor
 * and destructor are defined out of line. NOTE(review): presumably they
 * create and destroy `cudnnHandle_`; confirm in the .cpp.
 *
 * The class stores a raw C handle. A compiler-generated copy would let two
 * objects share (and both release) the same handle, so copy construction and
 * copy assignment are deleted.
 */
class CUDNNWrapper {
public:
  CUDNNWrapper();
  virtual ~CUDNNWrapper();

  // Non-copyable: a member-wise copy would duplicate the raw cuDNN handle.
  CUDNNWrapper(const CUDNNWrapper&) = delete;
  CUDNNWrapper& operator=(const CUDNNWrapper&) = delete;

protected:
  /// Fill `desc` so that it describes tensor `x`.
  void setCudnnTensor(cudnnTensorDescriptor_t& desc, Tensor x);
  /// Fill `desc` so that it describes a tensor of the given `shape`.
  void setCudnnTensor(cudnnTensorDescriptor_t& desc, const Shape& shape);

  /// cuDNN library handle used by derived wrappers for their cuDNN calls.
  cudnnHandle_t cudnnHandle_;
};
class ConvolutionWrapper : public CUDNNWrapper {
public:
ConvolutionWrapper(const Shape& kernelShape,
const Shape& biasShape,
int hPad = 0,
int wPad = 0,
int hStride = 1,
int wStride = 1);
void getOutputShape(const Shape& xShape, Shape& shape);
virtual ~ConvolutionWrapper();
void forward(Tensor x, Tensor Kernels, Tensor bias, Tensor y);
void backward(Tensor x,
Tensor xGrad,
Tensor kernels,
Tensor kernelGrad,
Tensor biasGrad,
Tensor yGrad);
protected:
void setConvDescriptor(int hPad, int wPad, int hStride, int wStride);
void setKernelDescriptor(const Shape& shape);
protected:
cudnnConvolutionDescriptor_t convDesc_;
cudnnFilterDescriptor_t kernelDesc_;
cudnnTensorDescriptor_t biasDesc_;
};
class PoolingWrapper : public CUDNNWrapper {
public:
PoolingWrapper(int height,
int width,
int padHeight,
int padWidth,
int strideHeight,
int strideWidth,
std::string mode);
void getOutputShape(const Shape& xShape, Shape& shape);
void forward(Tensor x, Tensor y);
void backward(Tensor x, Tensor xGrad, Tensor y, Tensor yGrad);
virtual ~PoolingWrapper();
protected:
void setPoolingDescriptor(int height,
int width,
int padHeight,
int padWidth,
int strideHeight,
int strideWidth);
protected:
cudnnPoolingDescriptor_t poolingDesc_;
cudnnPoolingMode_t poolingMode_;
};
} // namespace marian
#else
namespace marian {
/**
 * Stub base class used when the code is compiled without CUDNN defined.
 *
 * It declares the same public constructor/destructor as the cuDNN-enabled
 * CUDNNWrapper, so code that names the wrappers compiles in both
 * configurations. The definitions live out of line. NOTE(review): presumably
 * they report that cuDNN is unavailable; confirm in the .cpp.
 *
 * Copying is deleted because the cuDNN-enabled counterpart stores a raw
 * `cudnnHandle_t` that copies must not share. Rejecting copies here as well
 * means copying code fails to compile in every build, not only in cuDNN builds.
 */
class CUDNNWrapper {
public:
  CUDNNWrapper();
  virtual ~CUDNNWrapper();

  CUDNNWrapper(const CUDNNWrapper&) = delete;
  CUDNNWrapper& operator=(const CUDNNWrapper&) = delete;
};
class ConvolutionWrapper : public CUDNNWrapper {
public:
ConvolutionWrapper(const Shape& kernelShape,
const Shape& biasShape,
int hPad = 1,
int wPad = 1,
int hStride = 1,
int wStride = 1);
void getOutputShape(const Shape& xShape, Shape& shape);
virtual ~ConvolutionWrapper();
void forward(Tensor x, Tensor Kernels, Tensor bias, Tensor y);
void backward(Tensor x,
Tensor xGrad,
Tensor kernels,
Tensor kernelGrad,
Tensor biasGrad,
Tensor yGrad);
};
/**
 * Stub PoolingWrapper for builds compiled without CUDNN.
 *
 * It only declares the interface shared with the cuDNN-enabled PoolingWrapper
 * (construction, output-shape query, forward and backward passes). The
 * definitions are out of line.
 */
class PoolingWrapper : public CUDNNWrapper {
public:
  // --- lifetime -------------------------------------------------------------
  PoolingWrapper(int height, int width,
                 int padHeight, int padWidth,
                 int strideHeight, int strideWidth,
                 std::string mode);
  virtual ~PoolingWrapper();

  // --- shape query ----------------------------------------------------------
  void getOutputShape(const Shape& xShape, Shape& shape);

  // --- computation ----------------------------------------------------------
  void forward(Tensor x, Tensor y);
  void backward(Tensor x, Tensor xGrad, Tensor y, Tensor yGrad);
};
} // namespace marian
#endif