Class ConvolutionWrapper

Inheritance Relationships

Base Type

- public marian::CUDNNWrapper
Class Documentation

class ConvolutionWrapper : public marian::CUDNNWrapper

Public Functions

ConvolutionWrapper(const Shape &kernelShape, const Shape &biasShape, int hPad = 1, int wPad = 1, int hStride = 1, int wStride = 1)
void getOutputShape(const Shape &xShape, Shape &shape)
~ConvolutionWrapper()
void forward(Tensor x, Tensor Kernels, Tensor bias, Tensor y)
void backward(Tensor x, Tensor xGrad, Tensor kernels, Tensor kernelGrad, Tensor biasGrad, Tensor yGrad)