Program Listing for File shape.h
Return to documentation for file (src/functional/shape.h)
#pragma once
#include <cstdint>
#include <string>
#include "common/shape.h"
#include "functional/array.h"
namespace marian {
namespace functional {
// Number of dimensions used to instantiate the default functional Shape
// (see the Shape typedef at the end of this file). The abort messages in the
// ConstantShape constructors ask for this value to be raised when an input
// shape has more dimensions than the instantiated ConstantShape.
#define CONST_SHAPE_DIMS 4
// attempts at low-level slicing and proper views, not integrated yet
#if 0
// Largest int value; serves as the "until the end" marker for slices.
const int MAX_INT = std::numeric_limits<int>::max();

// Describes a range over one dimension as (begin, end, stride).
// A default-constructed slice spans from 0 to END with stride 1.
struct Slice {
  static const int END{MAX_INT}; // fix

  int begin{0};
  int end{END};
  int stride{1};

  // Full range: begin 0, end END, stride 1.
  Slice()
      : begin(0), end(END), stride(1) {}

  // Single position i, expressed as the range from i to i + 1.
  Slice(int i)
      : begin(i), end(i + 1), stride(1) {}

  // Explicit begin and end with an optional stride (defaults to 1).
  Slice(int b, int e, int s = 1)
      : begin(b), end(e), stride(s) {}

  // Builds a slice from a braced list of 0 to 3 integers, interpreted as
  // {}, {index}, {begin, end} or {begin, end, stride}. Longer lists abort.
  Slice(const std::initializer_list<int>& l) {
    std::vector<int> items(l);
    const size_t count = items.size();
    if(count == 0) {
      begin = 0;
      end = END;
      stride = 1;
    } else if(count == 1) {
      begin = items[0];
      end = items[0] + 1;
      stride = 1;
    } else if(count == 2) {
      begin = items[0];
      end = items[1];
      stride = 1;
    } else if(count == 3) {
      begin = items[0];
      end = items[1];
      stride = items[2];
    } else {
      ABORT("Too many elements in slice: {}", items.size());
    }
  }
};

// Default-constructed slice (begin 0, end END, stride 1).
const Slice All;
#endif
// Shape descriptor with a compile-time number of dimensions N.
//
// Stores the extent of every dimension together with element strides,
// broadcast strides, the total number of elements and a base offset. It
// converts between flat positions, per-dimension coordinates and strided
// indices. Shapes constructed from fewer than N dimensions are padded on the
// left with dimensions of extent 1.
template <const int N>
struct ConstantShape {
  Array<int, N> shape_;    // extent of each dimension
  Array<int, N> stride_;   // stride of each dimension, multiplied with coordinates in index()
  Array<int, N> bstride_;  // like stride_, but 0 for every dimension of extent 1; used by bindex()
  size_t elements_{1};     // product of all extents
  size_t offset_{0};       // base offset added by index()

  // @TODO: review all these constructors

  // Creates a shape in which every dimension has extent 1.
  HOST_DEVICE ConstantShape() {
    shape_.fill(1);
    stride_.fill(1);
    bstride_.fill(0);
  }

  // Member-wise copy.
  HOST_DEVICE ConstantShape(const ConstantShape& shape)
      : shape_(shape.shape_),
        stride_(shape.stride_),
        bstride_(shape.bstride_),
        elements_(shape.elements_),
        offset_(shape.offset_) {}

  // Creates a shape from an array of M <= N extents. The extents are placed in
  // the last M dimensions; the leading N - M dimensions get extent 1. Strides
  // and the element count are derived from the resulting extents.
  template <size_t M>
  HOST_DEVICE ConstantShape(const Array<int, M>& shape) {
    ABORT_IF(M > (size_t)N, "Recompile with CONST_SHAPE_DIMS >= {}", M);

    const size_t pad = (size_t)N - M;  // number of leading dimensions to fill with 1
    std::copy(shape.begin(), shape.end(), shape_.begin() + pad);
    if(pad)
      std::fill_n(shape_.begin(), pad, 1);

    updateStrides();
    updateElements();
  }

  // Creates a shape from explicit extents, strides and a base offset.
  // The given strides are kept as they are (updateStrides() is not called).
  HOST_DEVICE ConstantShape(const Array<int, N>& shape,
                            const Array<int, N>& stride,
                            size_t offset)
      : shape_(shape), stride_(stride), offset_(offset) {
    // bstride_ is not set anywhere else in this constructor, so derive it
    // here with the same rule updateStrides() uses: 0 for dimensions of
    // extent 1, otherwise the regular stride.
    for(int i = 0; i < N; ++i)
      bstride_[i] = shape_[i] == 1 ? 0 : stride_[i];
    updateElements();
  }

  // Creates a shape from a marian::Shape with at most N dimensions. The
  // extents are placed in the trailing dimensions; leading dimensions get
  // extent 1. Aborts if the input has more than N dimensions.
  ConstantShape(const marian::Shape& shape) {
    size_t filled = shape.size();
    ABORT_IF(filled > (size_t)N,
             "Recompile with CONST_SHAPE_DIMS >= " + std::to_string(filled));

    const size_t pad = (size_t)N - filled;  // number of leading dimensions to fill with 1
    std::copy(shape.begin(), shape.end(), shape_.begin() + pad);
    if(pad)
      std::fill_n(shape_.begin(), pad, 1);

    updateStrides();
    updateElements();
  }

  // Recomputes stride_ and bstride_ from shape_. The last dimension gets
  // stride 1 and every earlier dimension gets the product of all later
  // extents. bstride_ equals stride_ except for dimensions of extent 1,
  // which get 0.
  // @TODO: do we need bstrides at all?
  HOST_DEVICE_INLINE void updateStrides() {
    stride_[N - 1] = 1;
    bstride_[N - 1] = shape_[N - 1] == 1 ? 0 : stride_[N - 1];

    for(int i = N - 2; i >= 0; --i) {
      stride_[i] = stride_[i + 1] * shape_[i + 1];
      bstride_[i] = shape_[i] == 1 ? 0 : stride_[i];
    }
  }

  // Recomputes elements_ as the product of all extents.
  HOST_DEVICE_INLINE void updateElements() {
    elements_ = 1;
    for(int i = 0; i < N; ++i)
      elements_ *= shape_[i];
  }

  // Sets the extent of dimension i to dim and recomputes strides and the
  // element count. Note that this overwrites any strides passed in through
  // the (shape, stride, offset) constructor.
  HOST_DEVICE_INLINE void set(int i, int dim) {
    shape_[i] = dim;
    updateStrides();
    updateElements();
  }

  // Extent of dimension i.
  HOST_DEVICE_INLINE const int& dim(int i) const { return shape_[i]; }

  // Extent of the last dimension.
  HOST_DEVICE_INLINE const int& back() const { return dim(N - 1); }

  // Same as dim(i).
  HOST_DEVICE_INLINE const int& operator[](int i) const { return dim(i); }

  // Stride of dimension i.
  HOST_DEVICE_INLINE const int& stride(int i) const { return stride_[i]; }

  // Broadcast stride of dimension i.
  HOST_DEVICE_INLINE const int& bstride(int i) const { return bstride_[i]; }

  // Number of dimensions (always N).
  HOST_DEVICE_INLINE static constexpr size_t size() { return N; }

  // Total number of elements, converted to int.
  HOST_DEVICE_INLINE int elements() const { return (int)elements_; }

  // The following functions iterate over shape dimensions and use recursive
  // templates. They unroll over a compile-time defined number of dimensions.

  // Struct for recurrent template calls over shape dimensions,
  // version for K > 0
  template <const int K, const int D> struct I {
    // Sum of dims[k] * stride[k] for k = K down to 0.
    HOST_DEVICE_INLINE static int index(const Array<int, D>& dims,
                                        const Array<int, D>& stride) {
      return dims[K] * stride[K] + I<K-1, D>::index(dims, stride);
    }

    // Decomposes the flat position si into coordinates (dimension K varies
    // fastest) and returns the sum of coordinate * stride over dimensions
    // K down to 0.
    HOST_DEVICE_INLINE static int index(int si,
                                        const Array<int, D>& shape,
                                        const Array<int, D>& stride) {
      return (si % shape[K]) * stride[K] + I<K-1, D>::index(si / shape[K], shape, stride);
    }

    // Writes the coordinates of the flat position si into dims[K] .. dims[0].
    HOST_DEVICE_INLINE static void dims(int si,
                                        Array<int, D>& dims,
                                        const Array<int, D>& shape) {
      dims[K] = si % shape[K];
      I<K-1, D>::dims(si / shape[K], dims, shape);
    }
  };

  // Struct for recurrent template calls over shape dimensions,
  // specialization for K == 0
  template <const int D> struct I<0, D> {
    HOST_DEVICE_INLINE static int index(const Array<int, D>& dims,
                                        const Array<int, D>& stride) {
      return dims[0] * stride[0];
    }

    HOST_DEVICE_INLINE static int index(int si,
                                        const Array<int, D>& shape,
                                        const Array<int, D>& stride) {
      return (si % shape[0]) * stride[0];
    }

    HOST_DEVICE_INLINE static void dims(int si,
                                        Array<int, D>& dims,
                                        const Array<int, D>& shape) {
      dims[0] = si % shape[0];
    }
  };

  // Strided index for the given coordinates: offset_ plus the sum of
  // dims[k] * stride_[k] over all dimensions.
  HOST_DEVICE_INLINE int index(const Array<int, N>& dims) const {
    return (int)offset_ + I<N-1, N>::index(dims, stride_);
  }

  // Strided index for the flat position si: si is decomposed into coordinates
  // according to shape_ (last dimension fastest), which are then combined with
  // stride_ and offset_.
  HOST_DEVICE_INLINE int index(int si) const {
    return (int)offset_ + I<N-1, N>::index(si, shape_, stride_);
  }

  // Writes the coordinates of the flat position si into dims.
  HOST_DEVICE_INLINE void dims(int si, Array<int, N>& dims) const {
    I<N-1, N>::dims(si, dims, shape_);
  }

  // Index for the given coordinates computed with the broadcast strides, so
  // coordinates along dimensions of extent 1 do not contribute.
  // NOTE(review): unlike index(), offset_ is not added here - confirm whether
  // that is intended (see the open question below).
  HOST_DEVICE_INLINE int bindex(const Array<int, N>& dims) const {
    int i = 0;
    // ?? : return offset_ + I<N-1, N>::index(d, bstride_);
    for(int j = 0; j < N; ++j)
      i += dims[j] * bstride_[j];
    return i;
  }

  // Two shapes compare equal when all extents match; strides, offset and
  // element count are not compared.
  // @TODO: should this check all the members?
  HOST_DEVICE_INLINE bool operator==(const ConstantShape& other) const {
    for(int i = 0; i < N; ++i)
      if(shape_[i] != other[i])
        return false;
    return true;
  }

  HOST_DEVICE_INLINE bool operator!=(const ConstantShape& other) const {
    return !(*this == other);
  }

  // Formats the shape as "shape=AxBx... size=<elements>".
  std::string toString() const {
    std::stringstream strm;
    // @TODO: add more information
    strm << "shape=" << (*this)[0];
    for(int i = 1; i < N; ++i)
      strm << "x" << (*this)[i];
    strm << " size=" << elements();
    return strm.str();
  }

  // @TODO: attempts at proper slicing. Works but not integrated anywhere. To be revisited.
#if 0
  // Performs numpy-like slicing on a given shape object. The number
  // of slices corresponds to the number of dimensions.
  HOST_DEVICE_INLINE ConstantShape<N> slice(const Array<Slice, N>& slices) {
    // @TODO: add various checks
    Array<int, N> offsets;
    Array<int, N> shape;
    Array<int, N> stride;
    for(int i = 0; i < N; ++i) {
      int beg = slices[i].begin;
      // restrict maximum value to actual shape size if larger than shape size
      int end = slices[i].end < shape_[i] ? slices[i].end : shape_[i];
      int str = slices[i].stride;
      // collect starting points for all coordinates
      offsets[i] = beg;
      // when calculating the new shape, take into account stride
      // TODO: std::ceil does not work on the GPU
      shape[i] = std::ceil((end - beg) / (float) str);
      // new stride is just old stride multiplied by slice stride
      stride[i] = str * stride_[i];
    }
    // map offset coordinates into single offset index
    int offset = index(offsets);
    return ConstantShape<N>(shape, stride, offset);
  }

  // non-contiguous slices cannot be reshaped! need to be copied
  // template <const int D>
  // HOST_DEVICE_INLINE ConstantShape<D> reshape(const ConstantShape<D>& other) const {
  // // @TODO: add various checks
  // #ifndef __CUDA_ARCH__
  // ABORT_IF(elements() != other.elements(),
  // "Reshaping operation requires matching number of elements");
  // #endif
  // Array<int, D> stride;
  // for(int i = 0; i < D; ++i) {
  // stride[i] = /*other.stride_[i] **/ stride_[i];
  // }
  // stride[D - 1] = stride_[N - 1];
  // for(int i = 2; i < D + 1; ++i) {
  // stride[D - i] = stride[D - i + 1] * stride_[N - i + 1] * shape_[D - i + 1];
  // }
  // return ConstantShape<D>(other.shape_, stride, offset_);
  // }
#endif

  // Streams the result of toString().
  friend std::ostream& operator<<(std::ostream& strm, const ConstantShape<N>& shape) {
    strm << shape.toString();
    return strm;
  }
};
// Default functional shape: a ConstantShape with CONST_SHAPE_DIMS dimensions.
using Shape = ConstantShape<CONST_SHAPE_DIMS>;
} // namespace functional
} // namespace marian