Program Listing for File shape.h

Return to documentation for file (src/common/shape.h)

#pragma once

#include <algorithm>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "common/hash.h"
#include "common/logging.h"

namespace marian {

// Python-like slice/index descriptor: the triple (begin, end, stride).
// Negative begin/end values and the END sentinel are stored as given here;
// Shape::slice() is what resolves them against a concrete axis length.
struct Slice
{
  // Sentinel for `end` meaning "up to the end of the axis" (see Shape::slice()).
  // constexpr rather than plain const: under C++17 the member is then implicitly
  // inline, so odr-using it (e.g. binding it to a const int&) needs no separate
  // out-of-class definition. INT_MAX comes from <climits>.
  static constexpr int END = INT_MAX;

  // Fully specified slice.
  Slice(int b, int e, int s) : begin(b), end(e), stride(s) {}
  // Slice [b, e) with stride 1.
  Slice(int b, int e) : Slice(b, e, 1) {}
  // Whole range: begin 0, end END, stride 1.
  Slice() : Slice(0, END) {}
  // Single index i, expressed as the slice [i, i + 1) with stride 1.
  explicit Slice(int i) : Slice(i, i + 1) {}

  // Member-wise copying is all that is needed, so use the compiler-generated versions.
  Slice(const Slice& other) = default;
  Slice& operator=(const Slice& other) = default;

  // Assigning an int turns this slice into the single-index slice [i, i + 1).
  Slice& operator=(int i) {
    begin  = i;
    end    = i + 1;
    stride = 1;
    return *this;
  }

  // Two slices are equal when all three fields match.
  bool operator==(const Slice& other) const {
    return begin == other.begin && end == other.end && stride == other.stride;
  }
  bool operator!=(const Slice& other) const { return !(*this == other); }

  /*const*/ int begin, end, stride;
};
using Slices = std::vector<Slice>;

// List of integer dimension sizes, stored in a std::vector<int>.
//
// Axes can be addressed Python-style: a non-negative index counts from the
// first dimension, a negative index counts from the last one (-1 is the last).
// A default-constructed Shape holds the single dimension {1}.
struct Shape {
private:
  std::vector<int> shape_;

  // Translates a possibly negative axis index into a position in shape_ and
  // aborts when it is out of range. dim() and stride() both go through this,
  // so an invalid index is reported the same way by either.
  inline size_t checkedIndex(int i) const {
    if(i >= 0) {
      ABORT_IF(i >= (int)size(),
               "Index {} is out of bounds, shape {} has {} dimension",
               i, std::string(*this), size());
      return (size_t)i;
    } else {
      ABORT_IF((int)size() + i < 0,
               "Negative index {} is out of bounds, shape {} has {} dimension",
               i, std::string(*this), size());
      return (size_t)((int)size() + i);
    }
  }

public:
  // Default shape: one dimension of size 1.
  Shape() : shape_({1}) {}

  // Shape with exactly the listed dimensions.
  Shape(std::initializer_list<int> il) : shape_(il) {}

  // Takes over the given vector as the dimension list.
  Shape(std::vector<int>&& shape) : shape_(std::move(shape)) {}

  // Copy operations are declared explicitly, so the compiler generates no move
  // operations for Shape; constructing or assigning from an rvalue Shape copies it.
  Shape(const Shape& shape) = default;

  Shape& operator=(const Shape& p) = default;

  // Number of dimensions.
  inline size_t size() const { return shape_.size(); }

  // Changes the number of dimensions; dimensions added at the end get size 1.
  void resize(size_t n) { shape_.resize(n, 1); } // @TODO: this should respect shape semantics? Currently behaves like vector which is the wrong way around.

  // Raw access to the dimension array.
  const int* data() const { return shape_.data(); }
  int* data() { return shape_.data(); }

  // Sets dimension i (negative i counts from the back) to val; aborts on a bad index.
  inline void set(int    i, int val) { dim(i) = val; }
  inline void set(size_t i, int val) { dim(i) = val; }
  inline void set(int    i, size_t val) { dim(i) = (int)val; }
  inline void set(size_t i, size_t val) { dim(i) = (int)val; }

  // Reference to the size of dimension i (negative i counts from the back).
  // Aborts when i does not name an existing dimension.
  inline int& dim(int i) { return shape_[checkedIndex(i)]; }
  inline const int& dim(int i) const { return shape_[checkedIndex(i)]; }

  // size_t overloads; the index is converted to int before the lookup.
  inline       int& dim(size_t i)       { return dim(int(i)); }
  inline const int& dim(size_t i) const { return dim(int(i)); }

  // Read-only element access by value, with the same index rules as dim().
  inline int operator[](int i) const { return dim(i); }
  inline int operator[](int i)       { return dim(i); }
  inline int operator[](size_t i) const { return dim(i); }
  inline int operator[](size_t i)       { return dim(i); }

  // Size of the last dimension.
  inline int back() const { return shape_.back(); }
  inline int& back() { return shape_.back(); }

  // Stride of axis i: the product of the sizes of all dimensions after it,
  // so the last axis has stride 1. Negative i counts from the back.
  // Aborts on an out-of-range index instead of reading past the dimension list.
  inline int stride(int i) const {
    const size_t pos = checkedIndex(i);
    int product = 1;
    for(size_t j = pos + 1; j < shape_.size(); ++j)
      product *= shape_[j];
    return product;
  }

  // Product of all dimension sizes, accumulated in type T.
  template<typename T = int> // using a template so that FactoredSegmenter, which uses this as well, can pass size_t
  inline T elements() const {
    T el = 1;
    for(auto s : shape_)
      el *= (T)s;
    return el;
  }

  // Decomposes the flat index i into one coordinate per dimension, written to d
  // (d is resized to size()). Coordinate j equals (i / stride(j)) % dim(j).
  inline void dims(int i, std::vector<int>& d) const {
    d.resize(shape_.size());

    // Walk from the last dimension to the first, peeling one coordinate off per
    // step; dividing i down as we go replaces a precomputed table of strides.
    for(size_t j = shape_.size(); j-- > 0;) {
      d[j] = i % shape_[j];
      i /= shape_[j];
    }
  }

  // Iteration over the dimension sizes, first to last (begin/end) or last to first (rbegin/rend).
  auto begin() -> decltype(shape_.begin()) { return shape_.begin(); }
  auto begin() const -> decltype(shape_.begin()) { return shape_.begin(); }

  auto end() -> decltype(shape_.end()) { return shape_.end(); }
  auto end() const -> decltype(shape_.end()) { return shape_.end(); }

  auto rbegin() -> decltype(shape_.rbegin()) { return shape_.rbegin(); }
  auto rbegin() const -> decltype(shape_.rbegin()) { return shape_.rbegin(); }

  auto rend() -> decltype(shape_.rend()) { return shape_.rend(); }
  auto rend() const -> decltype(shape_.rend()) { return shape_.rend(); }

  // Shapes are equal when they have the same number of dimensions and the same sizes.
  bool operator==(const Shape& other) const {
    return size() == other.size() && std::equal(begin(), end(), other.begin());
  }

  bool operator!=(const Shape& other) const { return !(*this == other); }

  // Formats the shape as "shape=AxBxC size=N", where N is elements().
  // Reads shape_ directly instead of going through dim(): the bounds-check
  // messages in dim() format the shape via this function, so on a shape without
  // dimensions a checked access here would recurse without end.
  std::string toString() const {
    std::stringstream strm;
    strm << "shape=";
    for(size_t i = 0; i < shape_.size(); ++i) {
      if(i > 0)
        strm << "x";
      strm << shape_[i];
    }
    strm << " size=" << elements();
    return strm.str();
  }

  // Streams the toString() representation.
  friend std::ostream& operator<<(std::ostream& strm, const Shape& shape) {
    strm << shape.toString();
    return strm;
  }

  // Conversion to the same text that operator<< writes.
  operator std::string() const {
    std::stringstream ss;
    ss << *this;
    return ss.str();
  }

  // Converts a negative axis index into its non-negative equivalent by adding size().
  // The result is not range-checked.
  int axis(int ax) const {
    if(ax < 0)
      return (int)size() + ax;
    else
      return ax;
  }

  // Resolves a slice against axis ax: a negative begin or end is offset by the
  // length of that axis, and an end equal to Slice::END becomes the axis length.
  // The stride is passed through unchanged and the result is not range-checked.
  Slice slice(Slice slice, int ax) const { // interpret negative and special values in Slice
    int n = dim(ax);
    if (slice.begin < 0)
      slice.begin += n;
    if (slice.end < 0)
      slice.end += n;
    else if (slice.end == Slice::END)
      slice.end = n;
    return slice;
  }

  // Computes the shape that all given shapes broadcast to. Shapes are aligned at
  // their last dimension; per dimension the sizes must be equal or one of them
  // must be 1 (otherwise this aborts), and the larger size is kept.
  static Shape broadcast(const std::vector<Shape>& shapes) {
    size_t maxDims = 0;
    for(auto& s : shapes)
      if(s.size() > maxDims)
        maxDims = s.size();

    // Start from maxDims dimensions of size 1 and widen them shape by shape.
    Shape shape;
    shape.resize(maxDims);

    for(auto& s : shapes) {
      // -i addresses the i-th dimension counted from the back in both shapes.
      for(int i = 1; i <= (int)s.size(); ++i) {
        ABORT_IF(shape[-i] != s[-i] && shape[-i] != 1 && s[-i] != 1,
                 "Shapes {} and {} cannot be broadcast",
                 (std::string)shape,
                 (std::string)s);
        shape.set(-i, std::max(shape[-i], s[-i]));
      }
    }
    return shape;
  }

  // Convenience overload: forwards the list as a std::vector<T> to the matching
  // vector overload.
  template <typename T>
  static Shape broadcast(const std::initializer_list<T>& il) {
    return broadcast(std::vector<T>(il));
  }

  // Same broadcasting rule as the std::vector<Shape> overload, for elements that
  // are pointer-like objects exposing shape() through operator->.
  template <typename T>
  static Shape broadcast(const std::vector<T>& nodes) {
    size_t maxDims = 0;
    for(auto& n : nodes)
      if(n->shape().size() > maxDims)
        maxDims = n->shape().size();

    Shape shape;
    shape.resize(maxDims);

    for(auto& node : nodes) {
      const Shape& shapen = node->shape();
      for(int i = 1; i <= (int)shapen.size(); ++i) {
        ABORT_IF(shape[-i] != shapen[-i] && shape[-i] != 1 && shapen[-i] != 1,
                 "Shapes {} and {} cannot be broadcasted",
                 (std::string)shape,
                 (std::string)shapen);
        shape.set(-i, std::max(shape[-i], shapen[-i]));
      }
    }
    return shape;
  }

  // Hash over all dimension sizes: seeded from the first one, with the rest
  // mixed in order via util::hash_combine. A shape without dimensions hashes
  // to 0 rather than reading a first element that does not exist.
  size_t hash() const {
    if(shape_.empty())
      return 0;

    size_t seed = util::hash<int>()(shape_[0]);
    for(size_t i = 1; i < shape_.size(); ++i)
      util::hash_combine(seed, shape_[i]);
    return seed;
  }
};
}  // namespace marian