Program Listing for File dataset.h
Return to documentation for file (src/examples/mnist/dataset.h)
#pragma once
#include <algorithm>
#include <cassert>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "common/definitions.h"
#include "common/shape.h"
#include "data/batch.h"
#include "data/dataset.h"
#include "data/vocab.h"
namespace marian {
namespace data {
/// One dense float vector; an Example is a list of these.
using Data = std::vector<float>;
/// A vector of IndexType values.
using Labels = std::vector<IndexType>;
/**
 * A single example: a std::vector<Data> (one float vector per field) that also
 * carries the id it was constructed with. A default-constructed Example has no
 * fields and id == SIZE_MAX.
 */
struct Example : public std::vector<Data> {
  size_t id = SIZE_MAX;  // overwritten by the two-argument constructor

  Example() {}
  Example(std::vector<Data>&& data, size_t id)
      : std::vector<Data>(std::move(data)), id(id) {}

  /// Returns the id assigned at construction (SIZE_MAX for a default Example).
  size_t getId() const { return id; }
};
typedef std::vector<Example> Examples;
typedef Examples::const_iterator ExampleIterator;
/**
 * A zero-initialized float buffer of shape.elements() values plus its Shape.
 *
 * The buffer is held through a Ptr<Data>; copying an Input copies that pointer,
 * not the values (assuming Ptr is a shared smart pointer -- confirm in
 * common/definitions.h).
 */
class Input {
private:
  Shape shape_;
  Ptr<Data> data_;

public:
  typedef Data::iterator iterator;
  typedef Data::const_iterator const_iterator;

  // Intentionally not explicit: callers pass a Shape where an Input is expected,
  // e.g. DataBatch::push_back(Shape({...})) in Dataset::toBatch.
  Input(const Shape& shape) : shape_(shape), data_(new Data(shape_.elements(), 0.0f)) {}

  Data::iterator begin() { return data_->begin(); }
  Data::iterator end() { return data_->end(); }
  Data::const_iterator begin() const { return data_->cbegin(); }
  Data::const_iterator end() const { return data_->cend(); }

  /// Mutable access to the underlying buffer.
  Data& data() { return *data_; }
  /// Read-only access, so the buffer is reachable through a const Input
  /// (e.g. via DataBatch::inputs() const).
  const Data& data() const { return *data_; }

  Shape shape() const { return shape_; }
  /// Number of floats in the buffer.
  size_t size() const { return data_->size(); }
};
/**
 * Batch of dense float inputs. features() refers to the first input and
 * labels() to the last one.
 */
class DataBatch : public Batch {
private:
  std::vector<Input> inputs_;

public:
  std::vector<Input>& inputs() { return inputs_; }
  const std::vector<Input>& inputs() const { return inputs_; }

  /// Appends an input. The parameter is taken by value, so move it into place
  /// rather than copying it a second time.
  void push_back(Input input) { inputs_.push_back(std::move(input)); }

  virtual std::vector<Ptr<Batch>> split(size_t /*n*/, size_t /*sizeLimit*/) override {
    ABORT("Not implemented");
  }

  /// Buffer of the first input. Aborts if the batch has no inputs
  /// (indexing an empty vector would be undefined behavior).
  Data& features() {
    ABORT_IF(inputs_.empty(), "DataBatch has no inputs");
    return inputs_.front().data();
  }

  /// Buffer of the last input. Aborts if the batch has no inputs.
  Data& labels() {
    ABORT_IF(inputs_.empty(), "DataBatch has no inputs");
    return inputs_.back().data();
  }

  /// Number of rows: dimension 0 of the first input's shape, or 0 for a batch
  /// without inputs (e.g. one built from an empty example list).
  size_t size() const override {
    if(inputs_.empty())
      return 0;
    return inputs_.front().shape()[0];
  }

  void setGuidedAlignment(std::vector<WordAlignment>&&) override {
    ABORT("Guided alignment in DataBatch is not implemented");
  }

  void setDataWeights(const std::vector<float>&) override {
    ABORT("Data weighting in DataBatch is not implemented");
  }
};
/**
 * In-memory dataset of Examples that can be shuffled and packed into
 * DataBatch objects. Concrete datasets fill examples_ in loadData().
 */
class Dataset : public DatasetBase<Example, ExampleIterator, DataBatch>, public RNGEngine {
protected:
  Examples examples_;

public:
  Dataset(const std::vector<std::string>& paths, Ptr<Options> options)
      : DatasetBase(paths, options) {}

  /// Populates examples_; implemented by concrete datasets.
  virtual void loadData() = 0;

  iterator begin() override { return ExampleIterator(examples_.begin()); }
  iterator end() override { return ExampleIterator(examples_.end()); }

  /// Shuffles examples_ in place with the engine eng_ inherited from RNGEngine.
  void shuffle() override { std::shuffle(examples_.begin(), examples_.end(), eng_); }

  /**
   * Packs a list of examples into one DataBatch.
   *
   * Field i of every example becomes one row of input i. Input i has shape
   * {batchSize, maxDims[i]}, where maxDims[i] is the longest field i in the
   * batch. Shorter fields, and fields an example does not have at all, are
   * zero-padded so every row stays aligned with its example.
   */
  batch_ptr toBatch(const Examples& batchVector) override {
    const int batchSize = (int)batchVector.size();

    // Widest length seen for each field index across the batch.
    std::vector<int> maxDims;
    for(const auto& ex : batchVector) {
      if(maxDims.size() < ex.size())
        maxDims.resize(ex.size(), 0);
      for(size_t i = 0; i < ex.size(); ++i)
        maxDims[i] = std::max(maxDims[i], (int)ex[i].size());
    }

    batch_ptr batch(new DataBatch());
    std::vector<Input::iterator> iterators;  // write position inside each input
    iterators.reserve(maxDims.size());
    for(int m : maxDims) {
      batch->push_back(Shape({batchSize, m}));
      iterators.push_back(batch->inputs().back().begin());
    }

    for(const auto& ex : batchVector) {
      // Visit every field index, not only the ones this example has: each row
      // must advance by its full width. Otherwise an example with fewer fields
      // would shift all later rows of the fields it lacks.
      for(size_t i = 0; i < maxDims.size(); ++i) {
        const auto rowEnd = iterators[i] + maxDims[i];
        auto out = iterators[i];
        if(i < ex.size())
          out = std::copy(ex[i].begin(), ex[i].end(), out);
        std::fill(out, rowEnd, 0.0f);  // zero-pad the rest of the row
        iterators[i] = rowEnd;
      }
    }
    return batch;
  }
};
class MNISTData : public Dataset {
private:
const int IMAGE_MAGIC_NUMBER;
const int LABEL_MAGIC_NUMBER;
public:
MNISTData(std::vector<std::string> paths,
std::vector<Ptr<Vocab>> /*vocabs*/ = {},
Ptr<Options> options = nullptr)
: Dataset(paths, options), IMAGE_MAGIC_NUMBER(2051), LABEL_MAGIC_NUMBER(2049) {
loadData();
}
virtual ~MNISTData(){}
void loadData() override {
ABORT_IF(paths_.size() != 2, "Paths to MNIST data files are not specified");
auto features = ReadImages(paths_[0]);
auto labels = ReadLabels(paths_[1]);
ABORT_IF(features.size() != labels.size(), "Features do not match labels");
for(size_t i = 0; i < features.size(); ++i) {
examples_.emplace_back(std::vector<Data>{ features[i], labels[i] }, i);
}
}
Example next() override { return Example(); } //@TODO: this return was added to fix a warning. Is it correct?
private:
typedef unsigned char uchar;
int reverseInt(int i) {
unsigned char c1, c2, c3, c4;
c1 = i & 255, c2 = (i >> 8) & 255, c3 = (i >> 16) & 255,
c4 = (i >> 24) & 255;
return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
}
std::vector<Data> ReadImages(const std::string &full_path) {
std::ifstream file(full_path);
ABORT_IF(!file.is_open(), "Cannot open file `" + full_path + "`!");
int magic_number = 0;
file.read((char *)&magic_number, sizeof(magic_number));
magic_number = reverseInt(magic_number);
ABORT_IF(magic_number != IMAGE_MAGIC_NUMBER, "Invalid MNIST image file!");
int number_of_images;
int n_rows = 0;
int n_cols = 0;
file.read((char *)&number_of_images, sizeof(number_of_images));
number_of_images = reverseInt(number_of_images);
file.read((char *)&n_rows, sizeof(n_rows));
n_rows = reverseInt(n_rows);
file.read((char *)&n_cols, sizeof(n_cols));
n_cols = reverseInt(n_cols);
int imgSize = n_rows * n_cols;
std::vector<Data> dataset(number_of_images);
for(int i = 0; i < number_of_images; ++i) {
dataset[i] = Data(imgSize, 0);
for(int j = 0; j < imgSize; j++) {
unsigned char pixel = 0;
file.read((char *)&pixel, sizeof(pixel));
dataset[i][j] = pixel / 255.0f;
}
}
return dataset;
}
std::vector<Data> ReadLabels(const std::string &full_path) {
std::ifstream file(full_path);
if(!file.is_open())
throw std::runtime_error("Cannot open file `" + full_path + "`!");
int magic_number = 0;
file.read((char *)&magic_number, sizeof(magic_number));
magic_number = reverseInt(magic_number);
if(magic_number != LABEL_MAGIC_NUMBER)
throw std::runtime_error("Invalid MNIST label file!");
int number_of_labels;
file.read((char *)&number_of_labels, sizeof(number_of_labels));
number_of_labels = reverseInt(number_of_labels);
std::vector<Data> dataset(number_of_labels);
for(int i = 0; i < number_of_labels; i++) {
dataset[i] = Data(1, 0.0f);
unsigned char label;
file.read((char *)&label, 1);
dataset[i][0] = label;
}
return dataset;
}
};
} // namespace data
} // namespace marian