Program Listing for File io.cpp¶
↰ Return to documentation for file (src/common/io.cpp
)
#include "common/io.h"
#include "3rd_party/cnpy/cnpy.h"
#include "common/shape.h"
#include "common/types.h"
#include "common/binary.h"
#include "common/io_item.h"
namespace marian {
namespace io {
bool isNpz(const std::string& fileName) {
return fileName.size() >= 4
&& fileName.substr(fileName.length() - 4) == ".npz";
}
bool isBin(const std::string& fileName) {
return fileName.size() >= 4
&& fileName.substr(fileName.length() - 4) == ".bin";
}
void getYamlFromNpz(YAML::Node& yaml,
const std::string& varName,
const std::string& fileName) {
auto item = cnpy::npz_load(fileName, varName);
if(item->size() > 0)
yaml = YAML::Load(item->data());
}
void getYamlFromBin(YAML::Node& yaml,
const std::string& varName,
const std::string& fileName) {
auto item = binary::getItem(fileName, varName);
if(item.size() > 0)
yaml = YAML::Load(item.data());
}
void getYamlFromModel(YAML::Node& yaml,
const std::string& varName,
const std::string& fileName) {
if(io::isNpz(fileName)) {
io::getYamlFromNpz(yaml, varName, fileName);
} else if(io::isBin(fileName)) {
io::getYamlFromBin(yaml, varName, fileName);
} else {
ABORT("Unknown model file format for file {}", fileName);
}
}
void getYamlFromModel(YAML::Node& yaml,
const std::string& varName,
const void* ptr) {
auto item = binary::getItem(ptr, varName);
if(item.size() > 0)
yaml = YAML::Load(item.data());
}
// Load YAML from item
void getYamlFromModel(YAML::Node& yaml,
const std::string& varName,
const std::vector<Item>& items) {
for(auto& item : items) {
if(item.name == varName) {
yaml = YAML::Load(item.data());
return;
}
}
}
void addMetaToItems(const std::string& meta,
const std::string& varName,
std::vector<io::Item>& items) {
Item item;
item.name = varName;
// increase size by 1 to add \0
item.shape = Shape({(int)meta.size() + 1});
item.bytes.resize(item.shape.elements());
std::copy(meta.begin(), meta.end(), item.bytes.begin());
// set string terminator
item.bytes.back() = '\0';
item.type = Type::int8;
items.push_back(item);
}
void loadItemsFromNpz(const std::string& fileName, std::vector<Item>& items) {
auto numpy = cnpy::npz_load(fileName);
for(auto it : numpy) {
ABORT_IF(
it.second->fortran_order, "Numpy item '{}' is not stored in row-major order", it.first);
Shape shape;
shape.resize(it.second->shape.size());
for(size_t i = 0; i < it.second->shape.size(); ++i)
shape.set(i, (size_t)it.second->shape[i]);
Item item;
item.name = it.first;
item.shape = shape;
char npzType = it.second->type;
int wordSize = it.second->word_size;
if (npzType == 'f' && wordSize == 2) item.type = Type::float16;
else if(npzType == 'f' && wordSize == 4) item.type = Type::float32;
else if(npzType == 'f' && wordSize == 8) item.type = Type::float64;
else if(npzType == 'i' && wordSize == 1) item.type = Type::int8;
else if(npzType == 'i' && wordSize == 2) item.type = Type::int16;
else if(npzType == 'i' && wordSize == 4) item.type = Type::int32;
else if(npzType == 'i' && wordSize == 8) item.type = Type::uint64;
else if(npzType == 'u' && wordSize == 1) item.type = Type::uint8;
else if(npzType == 'u' && wordSize == 2) item.type = Type::uint16;
else if(npzType == 'u' && wordSize == 4) item.type = Type::uint32;
else if(npzType == 'u' && wordSize == 8) item.type = Type::uint64;
else ABORT("Numpy item '{}' type '{}' with size {} not supported", it.first, npzType, wordSize);
item.bytes.swap(it.second->bytes);
items.emplace_back(std::move(item));
}
}
std::vector<Item> loadItems(const std::string& fileName) {
std::vector<Item> items;
if(isNpz(fileName)) {
loadItemsFromNpz(fileName, items);
} else if(isBin(fileName)) {
binary::loadItems(fileName, items);
} else {
ABORT("Unknown model file format for file {}", fileName);
}
return items;
}
std::vector<Item> loadItems(const void* ptr) {
std::vector<Item> items;
binary::loadItems(ptr, items, false);
return items;
}
std::vector<Item> mmapItems(const void* ptr) {
std::vector<Item> items;
binary::loadItems(ptr, items, true);
return items;
}
// @TODO: make cnpy and our wrapper talk to each other in terms of types
// or implement our own saving routines for npz based on npy, probably better.
void saveItemsNpz(const std::string& fileName, const std::vector<Item>& items) {
std::vector<cnpy::NpzItem> npzItems;
for(auto& item : items) {
std::vector<unsigned int> shape(item.shape.begin(), item.shape.end());
char type;
if (item.type == Type::float16) type = cnpy::map_type(typeid(float)); // becomes 'f', correct size is given below
else if(item.type == Type::float32) type = cnpy::map_type(typeid(float));
else if(item.type == Type::float64) type = cnpy::map_type(typeid(double));
else if(item.type == Type::int8) type = cnpy::map_type(typeid(int8_t));
else if(item.type == Type::int16) type = cnpy::map_type(typeid(int16_t));
else if(item.type == Type::int32) type = cnpy::map_type(typeid(int32_t));
else if(item.type == Type::int64) type = cnpy::map_type(typeid(int64_t));
else if(item.type == Type::uint8) type = cnpy::map_type(typeid(uint8_t));
else if(item.type == Type::uint16) type = cnpy::map_type(typeid(uint16_t));
else if(item.type == Type::uint32) type = cnpy::map_type(typeid(uint32_t));
else if(item.type == Type::uint64) type = cnpy::map_type(typeid(uint64_t));
else ABORT("Other types ({}) not supported", item.type);
npzItems.emplace_back(item.name, item.bytes, shape, type, sizeOf(item.type));
}
cnpy::npz_save(fileName, npzItems);
}
void saveItems(const std::string& fileName, const std::vector<Item>& items) {
if(isNpz(fileName)) {
saveItemsNpz(fileName, items);
} else if(isBin(fileName)) {
binary::saveItems(fileName, items);
} else {
ABORT("Unknown file format for file {}", fileName);
}
}
} // namespace io
} // namespace marian