Program Listing for File expression_graph.cpp
↰ Return to documentation for file (src/graph/expression_graph.cpp)
#include "graph/expression_graph.h"
#include "tensors/tensor_operators.h"
#include <sstream>
namespace marian {
ExpressionGraph::ExpressionGraph(bool inference)
    : inferenceOnly_(inference),
      backend_(nullptr) {}

void ExpressionGraph::setDevice(DeviceId deviceId, Ptr<Device> device) {
  if(!backend_) {
    backend_ = BackendByDeviceId(deviceId, Config::seed);
    auto params = New<Parameters>(defaultElementType_);
    params->init(backend_);
    paramsByElementType_[defaultElementType_] = params;

    if(device)
      tensors_ = New<Tensors>(backend_, device);
    else
      tensors_ = New<Tensors>(backend_);
  }
}
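// Usage sketch (not part of the original file): a graph is bound to a device
// once, which creates the backend, the default parameter set and the tensor
// allocator. The {index, type} initializer for DeviceId and the DeviceType
// enum follow the public Marian API; the call site itself is illustrative.
//
//   auto graph = New<ExpressionGraph>(/*inference=*/false);
//   graph->setDevice({0, DeviceType::gpu});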
Expr ExpressionGraph::add(Expr node) {
  auto found = tensors_->findOrRemember(node);
  if(found) {
    return found;
  } else {
    node->setId(count_++);

    // record in forward graph
    nodesForward_.push_back(node);

    // record in backward graph if training, and keep track of roots
    if(!inferenceOnly_ && node->trainable()) {
      nodesBackward_.push_back(node);
      topNodes_.insert(node); // opportunistically record all new nodes as roots (gets removed once consumed)
    }
    if(topNodes_.count(node)) // only erase children of nodes that are themselves in the topNodes_ list
      for(auto child : node->children())
        topNodes_.erase(child); // this child is consumed and therefore not a root

    return node;
  }
}
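// Illustration (not part of the original file): add() deduplicates nodes via
// findOrRemember() and maintains topNodes_ as the set of current roots.
// Parameter and initializer names below follow the public Marian API and are
// illustrative only.
//
//   auto a = graph->param("a", {2, 2}, inits::glorotUniform());  // trainable: recorded as a root
//   auto b = graph->param("b", {2, 2}, inits::glorotUniform());  // likewise a root
//   auto c = a + b;  // c becomes the root; a and b are erased from topNodes_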
// Called on every checkpoint, in backward order
void createSubtape(Expr node) {
  auto subtape = New<std::list<Expr>>();
  for(auto child : node->children()) {
    if(child->isCheckpoint()) {
      /* do not descend */
    } else {
      if(child->getSubtape()) {
        /* already visited */
      } else {
        createSubtape(child);
        subtape->splice(subtape->end(), *(child->getSubtape()));
      }
    }
  }

  if(!node->isCheckpoint())
    subtape->push_back(node);

  node->setSubtape(subtape);
}
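// Illustration (not part of the original file): for a chain of nodes
//
//   x (checkpoint) -> f -> g -> y (checkpoint)
//
// createSubtape(y) descends through g and f, stops at the checkpoint x, and
// records the subtape [f, g] on y. These are exactly the nodes whose forward
// values can be freed after the first pass and recomputed from x when the
// backward pass reaches y.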
void ExpressionGraph::forwardNext() {
  // @TODO: check if allocation works properly
  tensors_->clearShorttermMemory();

  if(checkpointing_) {
    for(auto top : topNodes_)
      top->markCheckpoint();

    auto it = nodesBackward_.rbegin();
    while(it != nodesBackward_.rend()) {
      auto v = *it;
      if(v->isCheckpoint())
        createSubtape(v);
      it++;
    }

    // To avoid recomputation of the range from the last checkpoint to the top,
    // turn all nodes on the last subtape into checkpoints and clear the subtape.
    // @TODO: put this into a special backprop function? Needs to know that we are done with adding nodes
    for(auto top : topNodes_) {
      if(top->getSubtape()) {
        for(auto& node : *top->getSubtape())
          node->markCheckpoint();
        top->getSubtape()->clear();
      }
    }
  }

  forward(nodesForward_, /*finalPass=*/!checkpointing_); // if checkpointing is enabled, this is not the final pass
}
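// Usage sketch (not part of the original file): gradient checkpointing is
// switched on before the graph is built; setCheckpointing() is assumed here
// as the class-interface setter behind the checkpointing_ flag read above.
//
//   auto graph = New<ExpressionGraph>();
//   graph->setCheckpointing(true);
//   // ... build the network, then:
//   graph->forward();   // reaches forwardNext(); subtape values are freed early
//   graph->backward();  // re-runs each subtape with finalPass=true (see below)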
void ExpressionGraph::forward(std::list<Expr>& forwardTape, bool finalPass) {
  while(!forwardTape.empty()) {
    auto v = forwardTape.front();
    v->allocate();
    v->init();
    for(auto& child : v->children())
      ABORT_IF(!child->val(), "De-allocated child {} {} of {} {}", child->getId(), child->type(), v->getId(), v->type());
    v->forward();

    if(v->trainable() && throwNaN_) {
      bool isNaN = false, isInf = false;
      checkNaN(v->val(), isNaN, isInf);
      if(isNaN || isInf) {
        LOG(critical, "Detected NaN ({}) or Inf ({}) in value (forward pass)", isNaN, isInf);
        LOG(critical, "\tType: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
            v->type(), v->shape(), v->name(), v->getId(), v->hash());
        LOG(critical, "Children: {}", v->children().size());
        for(auto&& child : v->children()) {
          LOG(critical, "\tType: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
              child->type(), child->shape(), child->name(), child->getId(), child->hash());
        }
      }
    }

    if(v->marked_for_debug()) {
      Logger log = spdlog::get("general");
      if(log) {
        LOG(info, "Debug: {} op={}", v->debug_message(), v->type());
        LOG(info, v->val()->debug());
      } else {
        std::cerr << "Debug: " << v->debug_message() << " op=" << v->type() << std::endl;
        std::cerr << v->val()->debug() << std::endl;
      }
    }

    if(inferenceOnly_)
      v->children().clear();

    // If checkpointing is disabled, keep the memory for forward signals for all nodes.
    // If checkpointing is enabled:
    // (a) In the forward pass before the backward pass, free the memory for the nodes
    //     in the subtape to save memory.
    // (b) In the forward calls during the backward pass, keep the memory in the current
    //     subtape to accelerate gradient computation.
    if(checkpointing_ && !finalPass) {
      auto subtape = v->getSubtape();
      if(subtape) {
        for(auto& node : *subtape) {
          node->free();
        }
      }
    }

    forwardTape.pop_front();
  }
}
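// Usage sketch (not part of the original file): the debug branch above is
// driven by the free function debug() from the expression-operator API, which
// marks a node and attaches a message; the op names are illustrative.
//
//   auto h = affine(x, W, b);
//   debug(h, "hidden state");  // sets marked_for_debug() and debug_message()
//   graph->forward();          // logs "Debug: hidden state op=affine" plus the tensor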
void ExpressionGraph::backward(bool reset, float clipValue) {
  if(topNodes_.size() > 1) {
    LOG(info, "There is more than one ({}) topmost node for the backward pass:", topNodes_.size());
    for(auto node : topNodes_) {
      LOG(info,
          "\tType: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
          node->type(),
          node->shape(),
          node->name(),
          node->getId(),
          node->hash());
    }
    ABORT("Aborting");
  }

  // allocates memory and initialises gradients for parameters
  for(auto kvParams : paramsByElementType_) {
    kvParams.second->allocateBackward();
    if(reset)
      kvParams.second->set_zero_adjoint();
  }

  // for top nodes: allocates memory and initialises gradients to 1
  for(auto&& v : topNodes_)
    v->init_dependent();

  topNodes_.clear();
  tensors_->clearShorttermMemory();

  bool firstNaN = true;
  while(!nodesBackward_.empty()) {
    auto v = nodesBackward_.back(); // get the last element
    nodesBackward_.pop_back();      // remove the last element

    // for non-top nodes: allocates memory and initialises gradients to 0
    for(auto&& child : v->children())
      if(child->trainable() && child->type() != "param")
        child->set_zero_adjoint();

    // if using gradient checkpointing,
    // recompute the forward pass from the checkpoint to the root
    if(checkpointing_ && v->getSubtape()) {
      forward(*v->getSubtape(), /*finalPass=*/true);
    }

    if(v->trainable() && v->marked_for_debug()) {
      Logger log = spdlog::get("general");
      if(log) {
        LOG(info, "Debug Grad: {} op={}", v->debug_message(), v->type());
        LOG(info, v->grad()->debug());
      } else {
        std::cerr << "Debug Grad: " << v->debug_message() << " op=" << v->type() << std::endl;
        std::cerr << v->grad()->debug() << std::endl;
      }
    }

    if(v->trainable() && clipValue != 0) {
      using namespace functional;
      Element(_1 = clip(_1, clipValue), v->grad());
    }

    if(v->trainable())
      v->backward();

    if(throwNaN_ && firstNaN) {
      for(auto&& child : v->children()) {
        if(child->trainable()) {
          bool isNaN = false, isInf = false;
          checkNaN(child->grad(), isNaN, isInf);
          if(isNaN || isInf) {
            LOG(critical, "Detected NaN ({}) or Inf ({}) in gradient (backward pass) of child node", isNaN, isInf);
            LOG(critical, "Child - Type: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
                child->type(), child->shape(), child->name(), child->getId(), child->hash());
            LOG(critical, "Parent - Type: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
                v->type(), v->shape(), v->name(), v->getId(), v->hash());
            firstNaN = false;
          }
        }
      }
    }

    v->children().clear();
  }
}
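// Usage sketch (not part of the original file): a typical training step pairs
// one forward pass with one backward pass over a single top node. Note that
// clipValue clips each gradient element, not the global gradient norm.
//
//   auto loss = ...;  // single scalar root, e.g. a cross-entropy cost
//   graph->forward();
//   graph->backward(/*reset=*/true, /*clipValue=*/0.f);  // 0 disables element clipping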
Expr ExpressionGraph::dropoutMask(float prob, const Shape& shape, Type valueType) {
  return constant(shape, inits::dropout(prob), valueType);
}

Expr ExpressionGraph::dropoutMask(float prob, const Shape& shape) {
  return constant(shape, inits::dropout(prob), defaultElementType_);
}
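// Usage sketch (not part of the original file): the mask is an ordinary
// constant node filled by inits::dropout(prob), so it can be applied to any
// expression of a broadcast-compatible shape; the dimensions are illustrative.
//
//   auto mask = graph->dropoutMask(0.1f, {1, dimModel});
//   auto y = x * mask;  // element-wise, broadcasting over the first dimension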
void ExpressionGraph::checkNaN(Tensor t, bool& isNaN, bool& isInf) {
  IsNaN(t, allocator(), isNaN, isInf);
}
void ExpressionGraph::save(std::vector<io::Item>& ioItems, Type saveElementType) {
  // sorted by type in std::map
  for(auto kvParams : paramsByElementType_) {
    // sorted by name in std::map
    for(auto p : kvParams.second->getMap()) {
      std::string pName = p.first;

      if(!namespace_.empty()) {
        if(pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
          pName = pName.substr(namespace_.size() + 2);
      }

      Tensor val = p.second->val();

      io::Item item;
      val->get(item, pName);
      item.convert(saveElementType);
      ioItems.emplace_back(std::move(item));
    }
  }
}
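// Usage sketch (not part of the original file): gathering all parameters for
// serialization while down-converting to a compact on-disk type; io::saveItems
// is assumed from the marian::io interface.
//
//   std::vector<io::Item> items;
//   graph->save(items, Type::float16);  // strips this graph's namespace_ prefix
//   io::saveItems("model.npz", items);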
} // namespace marian