Program Listing for File approx.h
↰ Return to documentation for file (src/functional/approx.h)
#pragma once
#include "functional/defs.h"
namespace marian {
namespace functional {
// Approximates any unary float function on the interval
// [offset - radius, offset + radius] with `pieces` linear segments of equal
// width (2 * radius / pieces). Outside that interval the approximation is
// held constant at the function's value at the nearest interval end, and the
// gradient there is 0.
//
// Example:
//   static Approx<10, 0, 100> approxSigmoid(stableSigmoid);
//   float y = approxSigmoid(x);
//
// creates a functor for the range [-10, 10] with piecewise linear
// approximations of a sigmoid: 100 pieces, step 0.2.
// Evaluation is one table lookup plus one multiply-add, which is quite fast
// on the CPU.
//
// approxSigmoid.grad(x) returns the slope of the segment containing x, i.e.
// the derivative of the piecewise linear approximation.
//
// When used as a local variable, use the static keyword so the tables are
// built only once.
//
// Template parameters:
//   radius - half-width of the approximated interval (must be > 0)
//   offset - center of the approximated interval
//   pieces - number of linear segments (must be > 0)
template <int radius = 5, int offset = 0, int pieces = 10>
struct Approx {
  static_assert(radius > 0, "Approx: radius must be positive");
  static_assert(pieces > 0, "Approx: pieces must be positive");

  // Slot i in 1..pieces stores segment y = a[i] * x + b[i], which is
  // valid on [domain(i - 1), domain(i)]. Slot 0 (below the interval) and
  // slot pieces + 1 (above it) hold constants: a == 0 and b == the
  // function value at the respective interval end.
  float a[pieces + 2];
  float b[pieces + 2];

  // Builds the tables by sampling `function` at the pieces + 1 equally
  // spaced points domain(0) .. domain(pieces). `function` must be callable
  // with a float and return something convertible to float.
  template <typename Function>
  Approx(const Function& function) {
    for(int i = 1; i <= pieces; ++i) {
      float x0 = domain(i - 1);
      float x1 = domain(i);
      float y0 = function(x0);
      float y1 = function(x1);
      a[i] = (y1 - y0) / (x1 - x0);
      b[i] = y0 - a[i] * x0;
    }
    a[0] = 0;
    b[0] = function(domain(0));
    a[pieces + 1] = 0;
    b[pieces + 1] = function(domain(pieces));
  }

  // Maps x to its table slot:
  //   - 0 when x <= offset - radius;
  //   - pieces + 1 when x >= offset + radius (NaN also lands here, because
  //     both comparisons are false);
  //   - otherwise the 1-based index of the segment containing x.
  // The bounds must include `offset`. Comparing against +/-radius alone
  // gives a negative, out-of-bounds index for offset > 0.
  HOST_DEVICE_INLINE int index(float x) const {
    if(x <= offset - radius)
      return 0;
    if(x < offset + radius) // +1 because slot 0 holds the value for x <= offset - radius
      return int((x + radius - offset) / ((2.f * radius) / pieces) + 1);
    return pieces + 1;
  }

  // Returns the i-th sample point, offset - radius + i * step, where
  // step = 2 * radius / pieces. domain(0) and domain(pieces) are the
  // interval ends.
  HOST_DEVICE_INLINE float domain(int i) const {
    return i * ((2.f * radius) / pieces) + offset - radius;
  }

  // Evaluates the piecewise linear approximation at x.
  HOST_DEVICE_INLINE float operator()(float x) const {
    int i = index(x);
    return a[i] * x + b[i];
  }

  // Returns the derivative of the approximation at x: the segment slope
  // inside the interval, 0 outside it.
  HOST_DEVICE_INLINE float grad(float x) const {
    int i = index(x);
    return a[i];
  }
};
} // namespace functional
} // namespace marian