Class ConvolutionWrapper
Defined in File cudnn_wrappers.h
Inheritance Relationships
Base Type
public marian::CUDNNWrapper (Class CUDNNWrapper)
Class Documentation
class ConvolutionWrapper : public marian::CUDNNWrapper
Public Functions
ConvolutionWrapper(const Shape &kernelShape, const Shape &biasShape, int hPad = 1, int wPad = 1, int hStride = 1, int wStride = 1)
~ConvolutionWrapper()
void forward(Tensor x, Tensor Kernels, Tensor bias, Tensor y)
void backward(Tensor x, Tensor xGrad, Tensor kernels, Tensor kernelGrad, Tensor biasGrad, Tensor yGrad)
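The constructor's hPad, wPad, hStride and wStride parameters determine the spatial size of the output y. The sketch below is not Marian code: it assumes symmetric zero padding and a dilation of 1, and applies the output-size formula that cuDNN documents for cudnnGetConvolution2dForwardOutputDim. The helper name convOutputDim is hypothetical. Under these assumptions, the defaults (hPad = wPad = 1, unit strides) keep the spatial size unchanged for a 3 x 3 kernel.

```cpp
#include <cassert>

// Hypothetical helper (not part of Marian): output extent of one spatial
// dimension for a convolution with symmetric zero padding and dilation 1,
// i.e. 1 + (inputDim + 2 * pad - kernelDim) / stride (integer division).
int convOutputDim(int inputDim, int kernelDim, int pad, int stride) {
  assert(stride > 0);
  assert(inputDim + 2 * pad >= kernelDim);  // kernel must fit the padded input
  return 1 + (inputDim + 2 * pad - kernelDim) / stride;
}
```

For example, a 5-pixel-wide input with a 3-wide kernel and wPad = 1 stays 5 wide with wStride = 1 and shrinks to 3 with wStride = 2.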
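To make the data flow of forward concrete, here is a naive CPU reference of a 2D convolution followed by a per-output-channel bias. It is a sketch under stated assumptions, not what ConvolutionWrapper executes (the wrapper dispatches to cuDNN on the GPU): the function name convForward, the N x C x H x W layout of x, the K x C x R x S layout of the kernels, one bias value per output channel, and cross-correlation (no kernel flip) are all assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference (not Marian code): y = conv2d(x, kernels) + bias.
// Assumed layouts: x is N x C x H x W, kernels is K x C x R x S, bias has K
// entries, y is N x K x P x Q. Like cuDNN's CUDNN_CROSS_CORRELATION mode, the
// kernel is not flipped. Input positions outside the image read as zero.
std::vector<float> convForward(const std::vector<float>& x, int N, int C, int H, int W,
                               const std::vector<float>& kernels, int K, int R, int S,
                               const std::vector<float>& bias,
                               int hPad, int wPad, int hStride, int wStride) {
  assert(hStride > 0 && wStride > 0);
  const int P = 1 + (H + 2 * hPad - R) / hStride;  // output height
  const int Q = 1 + (W + 2 * wPad - S) / wStride;  // output width
  std::vector<float> y(static_cast<std::size_t>(N) * K * P * Q, 0.0f);
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k)
      for (int p = 0; p < P; ++p)
        for (int q = 0; q < Q; ++q) {
          float acc = bias[k];  // bias is added once per output position
          for (int c = 0; c < C; ++c)
            for (int r = 0; r < R; ++r)
              for (int s = 0; s < S; ++s) {
                const int h = p * hStride - hPad + r;
                const int w = q * wStride - wPad + s;
                if (h < 0 || h >= H || w < 0 || w >= W) continue;  // zero padding
                acc += x[((n * C + c) * H + h) * W + w] *
                       kernels[((k * C + c) * R + r) * S + s];
              }
          y[((n * K + k) * P + p) * Q + q] = acc;
        }
  return y;
}
```

With a 3 x 3 all-ones input and kernel, a bias of 0.5 and the default padding and strides, the centre output is 9.5, while each corner output is 4.5 because padding leaves only four in-range taps there.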
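backward receives yGrad and produces gradients for the input, the kernels and the bias. The naive CPU sketch below shows what those three gradients are for a 2D convolution plus per-channel bias. It is hypothetical: the name convBackward, its parameter order, the N x C x H x W input layout, the K x C x R x S kernel layout, cross-correlation, and the choice to add into (rather than overwrite) the gradient buffers are assumptions, not taken from cudnn_wrappers.h.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference (not Marian code) for the gradients of
// y = conv2d(x, kernels) + bias, with x as N x C x H x W, kernels as
// K x C x R x S, bias and biasGrad with K entries, yGrad as N x K x P x Q,
// and no kernel flip (cross-correlation). Gradients are ADDED into xGrad,
// kernelGrad and biasGrad (an assumption), so zero them first if needed.
void convBackward(const std::vector<float>& x, std::vector<float>& xGrad,
                  int N, int C, int H, int W,
                  const std::vector<float>& kernels, std::vector<float>& kernelGrad,
                  int K, int R, int S,
                  std::vector<float>& biasGrad,
                  const std::vector<float>& yGrad,
                  int hPad, int wPad, int hStride, int wStride) {
  assert(hStride > 0 && wStride > 0);
  const int P = 1 + (H + 2 * hPad - R) / hStride;  // output height
  const int Q = 1 + (W + 2 * wPad - S) / wStride;  // output width
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k)
      for (int p = 0; p < P; ++p)
        for (int q = 0; q < Q; ++q) {
          const float g = yGrad[((n * K + k) * P + p) * Q + q];
          biasGrad[k] += g;  // d(y)/d(bias[k]) is 1 at every output position
          for (int c = 0; c < C; ++c)
            for (int r = 0; r < R; ++r)
              for (int s = 0; s < S; ++s) {
                const int h = p * hStride - hPad + r;
                const int w = q * wStride - wPad + s;
                if (h < 0 || h >= H || w < 0 || w >= W) continue;  // padded zeros get no gradient
                const std::size_t xi = ((n * C + c) * H + h) * W + w;
                const std::size_t ki = ((k * C + c) * R + r) * S + s;
                kernelGrad[ki] += g * x[xi];  // input value times upstream gradient
                xGrad[xi] += g * kernels[ki];  // kernel weight times upstream gradient
              }
        }
}
```

A single loop nest is used so that each upstream gradient g is read once and scattered to all three results; in particular, biasGrad[k] is simply the sum of yGrad over the batch and spatial positions of channel k.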