Program Listing for File element.h

Return to documentation for file (src/tensors/cpu/element.h)

#pragma once

#include "tensors/tensor.h"

namespace marian {
namespace cpu {

// Function in this header are supposed to execute element-wise operations
// (passed in as a Functor) on arbitrary numbers of tensors. The templates
// are required to implement correct broadcasting of operations across
// a fixed-at-compile-time but in principle arbitrary number of dimensions.

// @TODO: generalize to vector operations, possible using specializations

// single loop over outer dimension. Recursively creates nested loops
// down to inner dimension and to single elements. Since this is based
// on strides, it correctly broadcasts to all dimensions without additional
// computation.
// Compiler optimizes this to single construct with nested(?) loops.

namespace F = marian::functional;

template <size_t I = 0>
struct E {
  template <size_t numArg, class Functor, typename ElementType>
  static inline void element(
      const Functor& functor,
      F::Array<F::Tensor<ElementType>, numArg>& tensors,
      F::Array<int, numArg> indices) {
    const auto& shape = tensors[0].shape();

    // loop over outer-most dimension
    for(int i = 0; i < shape[I]; ++i) {
      // call loop for next-inner dimension
      E<I + 1>::element(functor, tensors, indices);

      // increase index for current dimension by stride or 0 if broadcasting.
      // bstride(i) is look-up value, either equal to stride if the
      // corresponding dim is larger 1 or 0 if the dim is 1.
      for(size_t k = 0; k < numArg; ++k) {
        //int stride = tensors[k].shape().stride(I);
        //indices[k] += stride == 1 ? 0 : stride;
        indices[k] += tensors[k].shape().bstride(I);
      }
    }
  }
};

// specialization for inner-most single element (recursive stopping criterion)
// using const reference for indices here to avoid copying. No loop.
template <>
struct E<F::Shape::size()> {
  template <size_t numArg, class Functor, typename ElementType>
  static inline void element(
      const Functor& functor,
      F::Array<F::Tensor<ElementType>, numArg>& tensors,
      const F::Array<int, numArg>& indices) {
    // just apply the function for all indexed elements across all tensors
    // @TODO: use converting operator[] on tensor
    tensors[0].data()[indices[0]] = F::apply(functor, tensors, indices);
  }
};

template <typename ElementType, class Functor, class... Tensors>
void element(const Functor& functor, marian::Tensor out, Tensors... tensors) {

  // Number of input tensors + 1 (output tensor)
  constexpr size_t argNum = sizeof...(tensors) + 1;
  // create and initialize indices to 0, one index per tensor
  F::Array<int, argNum> indices;
  indices.fill(0);

  // call elementwise operation going from outer-most dimension
  // to inner-most element.
  F::Array<F::Tensor<ElementType>, argNum> gTensors = {out, tensors...};
  E<0>::element(functor, gTensors, indices);
}

// Dispatch elementwise functions with float element type based on number of
// elements. If dividable by 8 and AVX2 is available (TODO: check this?) use
// AVX2 specific intrinsics. Similar for 4 and AVX. TODO: Add AVX512 support.
template <class Functor, class... Tensors>
void elementFloat(const Functor& functor, marian::Tensor out, Tensors... tensors) {
#ifndef __CUDACC__
  std::vector<marian::Tensor> ts({out, tensors...});
  bool div8 = true;
  bool div4 = true;

  for(auto t : ts) {
    if(t->shape()[-1] % 8 != 0)
      div8 = false;
    if(t->shape()[-1] % 4 != 0) {
      div4 = false;
      break;
    }
  }

  if(div8) {
    // std::cerr << "8: " << functor.to_string() << std::endl;
#ifdef __AVX__
    element<float32x8>(functor, out, tensors...);
    return;
#endif
  }

  if(div4) {
    // std::cerr << "4: " << functor.to_string() << std::endl;
    element<float32x4>(functor, out, tensors...);
    return;
  }
#endif
  // std::cerr << "1: " << functor.to_string() << std::endl;
  element<float>(functor, out, tensors...);
}

// main call to function executing element-wise operation
template <class Functor, class... Tensors>
void Element(const Functor& functor, marian::Tensor out, Tensors... tensors) {
  switch(out->type()) {
    case Type::float32: elementFloat(functor, out, tensors...); break;
    //case Type::uint32:  element<uint32_t>(functor, out, tensors...); break;
    default: ABORT("Unsupported type for element-wise operation: {}", out->type()); break;
  }
}

}  // namespace cpu
}  // namespace marian