Program Listing for File auto_tuner.h
↰ Return to documentation for file (src/graph/auto_tuner.h)
#pragma once
#include "common/timer.h"
#include <chrono>
#include <cstddef>
#include <functional>
#include <limits>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
namespace marian {
// Interface through which timing of an operation is reported, identified by a hash.
// start() marks the beginning of a measured region and stop() its end.
class AutoTunerRecorder {
public:
  // Virtual destructor so that implementations can be destroyed safely
  // through a pointer to this interface.
  virtual ~AutoTunerRecorder() = default;

  // Called before the measured work for the operation identified by 'hash'.
  virtual void start(size_t hash) = 0;

  // Called after the measured work for the operation identified by 'hash'.
  // The unnamed flag is interpreted by the implementation; AutoTuner only
  // records a measurement when it is true.
  virtual void stop(size_t hash, bool) = 0;
};
// Chooses among several registered algorithms with the same signature by timing them.
//
// Every call to run() executes the algorithm selected by choose(). While any
// registered algorithm has fewer than 'collectStatMax' recorded runs, that
// algorithm is selected so that more timings can be gathered. Once all of them
// have enough runs, the one with the smallest accumulated time is selected and
// remembered for every registered hash.
//
// Timings are not taken by run() itself; they are supplied through start()/stop()
// (the AutoTunerRecorder interface), which code outside this class calls.
//
// Precondition: at least one algorithm must have been insert()ed before run().
template <typename Return, typename... Args>
class AutoTuner : public AutoTunerRecorder {
private:
  using Algorithm = std::function<Return(Args...)>;

  // When the autotuner decides the fastest algorithm for a specific tensor operation
  // (e.g. GEMM), it keeps selecting each algorithm until 'collectStatMax' timed runs
  // have been recorded for it; stop() does not record more than that many runs.
  const size_t collectStatMax = 50;

  // Non-null only between a start() that created it and the stop() that consumed it.
  // NOTE(review): presumably CPUTimer begins measuring on construction - confirm
  // against common/timer.h.
  UPtr<timer::CPUTimer> timer_;

  // This structure holds a hash key and an algorithm function
  // (e.g. int16, packed gemm, mkl gemm) for a specific operation size.
  // hash: a unique hash key for each operation size
  //       (e.g. m, n, k, transpose A, transpose B, bias size for GEMM)
  // algorithm: a function that holds an algorithm
  struct HashedAlgorithm {
    size_t hash;
    Algorithm algorithm;
  };

  // This structure represents the collected statistics.
  // time: total accumulated time of this operator execution with the given algorithm
  // runs: number of times this algorithm was executed and recorded
  struct Stat {
    double time;
    size_t runs;
  };

  // Timing statistics, keyed by algorithm hash.
  std::unordered_map<size_t, Stat> stats_;

  // Final decisions: algorithm hash -> index into algorithms_ of the chosen algorithm.
  // NOTE(review): clear()/insert() do not touch this map, so a stored index is only
  // meaningful if callers re-insert the same algorithms in the same order - confirm.
  std::unordered_map<size_t, size_t> done_;

  // Currently registered candidates, in insertion order.
  std::vector<HashedAlgorithm> algorithms_;

  // Returns the index into algorithms_ of the algorithm to execute next.
  size_t choose() {
    size_t best = 0;
    double bestTime = std::numeric_limits<double>::max();

    for(size_t i = 0; i < algorithms_.size(); ++i) {
      // A decision was already made for this hash: reuse it.
      auto doneIt = done_.find(algorithms_[i].hash);
      if(doneIt != done_.end())
        return doneIt->second;

      auto it = stats_.find(algorithms_[i].hash);
      if(it != stats_.end()) {
        auto& stat = it->second;

        // collect more stats
        if(stat.runs < collectStatMax)
          return i;

        // Strict '<' keeps the earliest inserted algorithm on ties.
        if(stat.time < bestTime) {
          bestTime = stat.time;
          best = i;
        }
      } else {
        // collect more stats
        return i;
      }
    }

    // Every candidate has enough runs: record the winner for all registered hashes.
    for(auto& a : algorithms_)
      done_[a.hash] = best;

    return best;
  }

public:
  // Registers a candidate algorithm together with its hash.
  void insert(const HashedAlgorithm& ha) { algorithms_.push_back(ha); }

  // Removes all registered candidates; collected statistics and decisions are kept.
  void clear() { algorithms_.clear(); }

  // Executes the algorithm selected by choose() and returns its result.
  // The by-value arguments are forwarded rather than copied a second time.
  Return run(Args... args) {
    return algorithms_[choose()].algorithm(std::forward<Args>(args)...);
  }

  // Creates the timer for 'hash' unless a timer already exists or a final
  // decision has already been recorded for 'hash'.
  void start(size_t hash) override {
    if(!timer_ && done_.count(hash) == 0)
      timer_.reset(new timer::CPUTimer());
  }

  // Records the elapsed time for 'hash' and discards the timer.
  // Does nothing when 'stop' is false (an existing timer is left in place),
  // when a final decision already exists for 'hash', or when no timer is active.
  void stop(size_t hash, bool stop) override {
    if(!stop || done_.count(hash) != 0)
      return;

    // stop() without a matching start(): nothing was measured, and
    // dereferencing the empty pointer would be undefined behaviour.
    if(!timer_)
      return;

    timer_->stop();
    auto seconds = timer_->elapsed();

    auto it = stats_.find(hash);
    if(it != stats_.end()) {
      // Never record more than collectStatMax runs per hash.
      if(it->second.runs < collectStatMax) {
        it->second.time += seconds;
        it->second.runs += 1;
      }
    } else {
      stats_.emplace(hash, Stat({seconds, 1}));
    }

    timer_.reset(nullptr);
  }
};
} // namespace marian