.. _program_listing_file_src_functional_tmp.h:

Program Listing for File tmp.h
==============================

|exhale_lsh| :ref:`Return to documentation for file <file_src_functional_tmp.h>` (``src/functional/tmp.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   // TMP here stands for Template Meta-Programming

   #pragma once

   #include "functional/array.h"
   #include "functional/defs.h"
   #include "functional/tensor.h"

   namespace marian {
   namespace functional {

   // This struct and its specializations are never used directly, only through apply and applyWithCast below.
   // K-ary application of Functor: each of the K input elements is cast to AccType before Functor is applied.
   template <size_t K, class Functor, class AccType>
   struct FApply {};

   template <class Functor, class AccType>
   struct FApply<1, Functor, AccType> {
     // indices is an array of offsets into multiple tensors: indices[i] is the offset into in[i],
     // for up to arity K.
     template <typename ElementType>
     HOST_DEVICE_INLINE static AccType apply(
         Functor functor,
         functional::Array<functional::Tensor<ElementType>, 1>& in,
         const functional::Array<int, 1>& indices) {
       return functor((AccType)in[0].data()[indices[0]]);
     }

     // Same as above, but every input tensor is read at the same offset `index`.
     template <typename ElementType>
     HOST_DEVICE_INLINE static AccType apply(
         Functor functor,
         functional::Array<functional::Tensor<ElementType>, 1>& in,
         int index) {
       return functor((AccType)in[0].data()[index]);
     }
   };

   template <class Functor, class AccType>
   struct FApply<2, Functor, AccType> {
     template <typename ElementType>
     HOST_DEVICE_INLINE static AccType apply(
         Functor functor,
         functional::Array<functional::Tensor<ElementType>, 2>& in,
         const functional::Array<int, 2>& indices) {
       return functor((AccType)in[0].data()[indices[0]],
                      (AccType)in[1].data()[indices[1]]);
     }

     template <typename ElementType>
     HOST_DEVICE_INLINE static AccType apply(
         Functor functor,
         functional::Array<functional::Tensor<ElementType>, 2>& in,
         int index) {
       return functor((AccType)in[0].data()[index],
                      (AccType)in[1].data()[index]);
     }
   };

   template <class Functor, class AccType>
   struct FApply<3, Functor, AccType> {
     template <typename ElementType>
     HOST_DEVICE_INLINE static AccType apply(
         Functor functor,
         functional::Array<functional::Tensor<ElementType>, 3>& in,
         const functional::Array<int, 3>& indices) {
       return functor((AccType)in[0].data()[indices[0]],
                      (AccType)in[1].data()[indices[1]],
                      (AccType)in[2].data()[indices[2]]);
     }

     template <typename ElementType>
     HOST_DEVICE_INLINE static AccType apply(
         Functor functor,
         functional::Array<functional::Tensor<ElementType>, 3>& in,
         int index) {
       return functor((AccType)in[0].data()[index],
                      (AccType)in[1].data()[index],
                      (AccType)in[2].data()[index]);
     }
   };

   template <class Functor, class AccType>
   struct FApply<4, Functor, AccType> {
     template <typename ElementType>
     HOST_DEVICE_INLINE static AccType apply(
         Functor functor,
         functional::Array<functional::Tensor<ElementType>, 4>& in,
         const functional::Array<int, 4>& indices) {
       return functor((AccType)in[0].data()[indices[0]],
                      (AccType)in[1].data()[indices[1]],
                      (AccType)in[2].data()[indices[2]],
                      (AccType)in[3].data()[indices[3]]);
     }

     template <typename ElementType>
     HOST_DEVICE_INLINE static AccType apply(
         Functor functor,
         functional::Array<functional::Tensor<ElementType>, 4>& in,
         int index) {
       return functor((AccType)in[0].data()[index],
                      (AccType)in[1].data()[index],
                      (AccType)in[2].data()[index],
                      (AccType)in[3].data()[index]);
     }
   };

   template <class Functor, class AccType>
   struct FApply<5, Functor, AccType> {
     template <typename ElementType>
     HOST_DEVICE_INLINE static AccType apply(
         Functor functor,
         functional::Array<functional::Tensor<ElementType>, 5>& in,
         const functional::Array<int, 5>& indices) {
       return functor((AccType)in[0].data()[indices[0]],
                      (AccType)in[1].data()[indices[1]],
                      (AccType)in[2].data()[indices[2]],
                      (AccType)in[3].data()[indices[3]],
                      (AccType)in[4].data()[indices[4]]);
     }

     template <typename ElementType>
     HOST_DEVICE_INLINE static AccType apply(
         Functor functor,
         functional::Array<functional::Tensor<ElementType>, 5>& in,
         int index) {
       return functor((AccType)in[0].data()[index],
                      (AccType)in[1].data()[index],
                      (AccType)in[2].data()[index],
                      (AccType)in[3].data()[index],
                      (AccType)in[4].data()[index]);
     }
   };

   /******************************************************************************/
   // Applying functor to sets of K tensors

   // Functor is applied to the same type as the input ElementType, no casting required.
   // Each tensor in[i] is read at its own offset indices[i].
   template <size_t K, class Functor, typename ElementType>
   HOST_DEVICE_INLINE ElementType apply(Functor functor,
                                        functional::Array<functional::Tensor<ElementType>, K>& in,
                                        const functional::Array<int, K>& indices) {
     return FApply<K, Functor, ElementType>::apply(functor, in, indices);
   }

   // Functor is applied to the same type as the input ElementType, no casting required.
   // All K tensors are read at the same offset `index`.
   template <size_t K, class Functor, typename ElementType>
   HOST_DEVICE_INLINE ElementType apply(Functor functor,
                                        functional::Array<functional::Tensor<ElementType>, K>& in,
                                        int index) {
     return FApply<K, Functor, ElementType>::apply(functor, in, index);
   }

   // ElementType and AccType are potentially different: elements are cast to AccType before the
   // functor is applied. This is useful when accumulating e.g. 16-bit into 32-bit and we want to
   // cast to 32-bit before the functor is applied. L2-Norm is a good use-case since the square can
   // be large. AccType only appears in the return type, so it cannot be deduced and must be given
   // explicitly, e.g. applyWithCast<K, Functor, AccType>(functor, in, indices).
   template <size_t K, class Functor, class AccType, typename ElementType>
   HOST_DEVICE_INLINE AccType applyWithCast(Functor functor,
                                            functional::Array<functional::Tensor<ElementType>, K>& in,
                                            const functional::Array<int, K>& indices) {
     return FApply<K, Functor, AccType>::apply(functor, in, indices);
   }

   // Same as above, but all K tensors are read at the same offset `index`.
   template <size_t K, class Functor, class AccType, typename ElementType>
   HOST_DEVICE_INLINE AccType applyWithCast(Functor functor,
                                            functional::Array<functional::Tensor<ElementType>, K>& in,
                                            int index) {
     return FApply<K, Functor, AccType>::apply(functor, in, index);
   }

   /******************************************************************************/

   // @TODO: Rename this. It is a reduction loop.
   // Loop<n, N, K> iterates over dimension N - n of an N-dimensional index range and recurses into
   // Loop<n - 1, N, K> for the remaining inner dimensions. Loop<1, N, K> handles the innermost
   // dimension N - 1 and applies the functor (with cast to AccType) to the K tensors.
   //   pAcc   : per-tensor element offsets accumulated by the enclosing (outer) dimensions
   //   length : number of iterations along each dimension
   //   dim    : starting coordinate along each dimension
   // Per-tensor offsets advance by in[j].shape().bstride(d) per step along dimension d.
   // NOTE(review): every recursion level starts its own accumulator from aggInit and the result is
   // folded into the parent level, so aggInit presumably must be an identity element of aggFunctor
   // (e.g. 0 for a sum) — confirm against callers.
   template <size_t n, size_t N, size_t K>
   struct Loop {
     template <class Functor, class AggFunctor, typename ElementType, typename AccType>
     HOST_DEVICE_INLINE static AccType result(
         Functor functor, AccType aggInit, AggFunctor aggFunctor,
         functional::Array<functional::Tensor<ElementType>, K>& in,
         const functional::Array<int, K>& pAcc,
         const functional::Array<int, N>& length,
         const functional::Array<int, N>& dim) {
       AccType agg = aggInit;
       functional::Array<int, K> acc;
       for(int i = 0; i < length[N - n]; ++i) {
         // offset of tensor j at coordinate (dim[N - n] + i) along dimension N - n
         for(size_t j = 0; j < K; ++j) {
           acc[j] = pAcc[j] + (dim[N - n] + i) * in[j].shape().bstride(N - n);
         }
         agg = aggFunctor(agg, Loop<n - 1, N, K>::result(functor, aggInit, aggFunctor, in, acc, length, dim));
       }
       return agg;
     }
   };

   // Innermost dimension (N - 1): applies the functor at each offset instead of recursing.
   template <size_t N, size_t K>
   struct Loop<1, N, K> {
     template <class Functor, class AggFunctor, typename ElementType, typename AccType>
     HOST_DEVICE_INLINE static AccType result(
         Functor functor, AccType aggInit, AggFunctor aggFunctor,
         functional::Array<functional::Tensor<ElementType>, K>& in,
         const functional::Array<int, K>& pAcc,
         const functional::Array<int, N>& length,
         const functional::Array<int, N>& dim) {
       AccType agg = aggInit;
       functional::Array<int, K> acc;
       for(int i = 0; i < length[N - 1]; ++i) {
         for(size_t j = 0; j < K; ++j) {
           acc[j] = pAcc[j] + (dim[N - 1] + i) * in[j].shape().bstride(N - 1);
         }
         agg = aggFunctor(agg, applyWithCast<K, Functor, AccType>(functor, in, acc));
       }
       return agg;
     }
   };

   // Entry point of the reduction: starts from offsets acc = {0} for the K tensors and walks
   // dimensions 0 .. N-1 (dimension 0 outermost), combining functor results with aggFunctor
   // starting from aggInit.
   template <size_t N, size_t K, class Functor, class AggFunctor, typename ElementType, typename AccType>
   HOST_DEVICE_INLINE AccType loops(Functor functor, AccType aggInit, AggFunctor aggFunctor,
                                    functional::Array<functional::Tensor<ElementType>, K>& in,
                                    const functional::Array<int, N>& length,
                                    const functional::Array<int, N>& dim) {
     functional::Array<int, K> acc = {0};
     return Loop<N, N, K>::result(functor, aggInit, aggFunctor, in, acc, length, dim);
   }

   }  // namespace functional
   }  // namespace marian