Program Listing for File io_item.h

Return to documentation for file (src/common/io_item.h)

#pragma once

#include "common/shape.h"
#include "common/types.h"

#include <string>
#include <vector>

namespace marian {
namespace io {

/// A named, typed, shaped blob of model data.
///
/// The payload lives either in `bytes` (owned by the item) or, when `mapped` is true, behind
/// `ptr` (memory the item does not own). Always access the payload through `data()`, which
/// picks the correct source.
struct Item {
  std::vector<char> bytes;   // owned payload; consulted only when !mapped
  const char* ptr{nullptr};  // external payload; consulted only when mapped
  bool mapped{false};        // true => data() returns ptr and `bytes` is ignored

  std::string name;
  Shape shape;
  Type type{Type::float32};

  /// Returns a pointer to the payload: `ptr` for mapped items, `bytes.data()` otherwise.
  const char* data() const {
    if(mapped)
      return ptr;
    else
      return bytes.data();
  }

  /// Payload size in bytes as given by requiredBytes(shape, type).
  /// Computed from shape and type only, so it does not depend on bytes.size()
  /// (which append() may have padded up to a multiple of 256 bytes).
  size_t size() const { // @TODO: review this again for 256-bytes boundary alignment
    return requiredBytes(shape, type);
  }

  /// Extend this item with data and shape from `other`, creating a flattened concatenation.
  ///
  /// Requirements (aborts otherwise): this item is not memory-mapped, both items have the same
  /// type, both shapes are flat (all elements in the last dimension), `other` is not this item,
  /// and a non-mapped `other` holds at least as many bytes as its shape requires. `other` itself
  /// may be mapped; its payload is read through other.data().
  ///
  /// Afterwards `bytes` is zero-padded up to a multiple of 256 bytes; that padding is stripped
  /// again by the next append().
  void append(const Item& other) {
    ABORT_IF(mapped, "Memory-mapped items cannot be appended");
    ABORT_IF(type != other.type, "Only item of same type can be appended");

    // abort if any of the shapes is not a flat array, i.e. the number of elements in the
    // last dimension has to correspond to the number of bytes.
    ABORT_IF(shape[-1] != shape.elements(), "1 - Only flat items can be appended : {}", shape);
    ABORT_IF(other.shape[-1] != other.shape.elements(), "2 - Only flat items can be appended: {}", other.shape);

    // Appending from our own buffer would hand vector::insert a range into *this (undefined
    // behaviour), and shape.set() below would also change other.shape mid-operation.
    ABORT_IF(&other == this, "An item cannot be appended to itself");

    // Compute all sizes before mutating anything.
    const size_t ownBytes   = shape.elements() * sizeOf(type);              // ignore padding if any
    const size_t otherBytes = other.shape.elements() * sizeOf(other.type);  // ignore padding if any
    ABORT_IF(!other.mapped && other.bytes.size() < otherBytes,
             "Item {} holds fewer bytes than its shape requires", other.name);

    // cut to size (get rid of padding if any) to make append operation work correctly
    bytes.resize(ownBytes);

    shape.set(-1, shape.elements() + other.shape.elements());

    // read through data() so that a memory-mapped `other` is handled correctly
    const char* src = other.data();
    bytes.insert(bytes.end(), src, src + otherBytes);

    // grow to align to 256 bytes boundary (will be undone when more pieces are appended).
    // Integer round-up: converting the size to float loses precision beyond 2^24 and can
    // round down, which would make resize() truncate real payload bytes.
    constexpr size_t kAlignment = 256;
    bytes.resize((bytes.size() + kAlignment - 1) / kAlignment * kAlignment);
  }

  /// Re-encode the payload element-wise from `From` to `To`, storing the result in `bytes`.
  ///
  /// The element count is size() / sizeof(From), so any padding in `bytes` is dropped. Input is
  /// read through data(), so mapped items are supported; the converted payload is always owned,
  /// so the item is no longer mapped afterwards. Does not update `type` (convert() does that).
  template <typename From, typename To>
  void convertFromTo() {
    const size_t elements = size() / sizeof(From);
    std::vector<char> newBytes(elements * sizeof(To));

    const From* in = reinterpret_cast<const From*>(data());
    To* out = reinterpret_cast<To*>(newBytes.data());
    for(size_t i = 0; i < elements; ++i)
      out[i] = static_cast<To>(in[i]);

    bytes.swap(newBytes);
    // the converted data now lives in `bytes`; stop data() from returning the old buffer
    mapped = false;
    ptr = nullptr;
  }

  /// Convert the payload from the current `type` (float32 or float16) to element type `T`.
  /// Aborts for any other source type. Does not update `type`.
  template <typename T>
  void convertTo() {
    if(type == Type::float32)
      convertFromTo<float, T>();
    else if(type == Type::float16)
      convertFromTo<HalfFloat, T>();
    else
      ABORT("convert from type {} not implemented", type);
  }

  /// Convert the payload in place to `toType` (float32 or float16) and update `type`.
  /// No-op when the item already has that type; aborts for unsupported target types.
  void convert(Type toType) {
    if(type == toType)
      return;

    if(toType == Type::float32)
      convertTo<float>();
    else if(toType == Type::float16)
      convertTo<float16>();
    else
      ABORT("convert to type {} not implemented", toType);

    type = toType;
  }
};

}  // namespace io
}  // namespace marian