Program Listing for File node.cpp
Return to documentation for file (src/graph/node.cpp)
#include "graph/node.h"
#include <utility>
#include "graph/auto_tuner.h"
#include "graph/expression_graph.h"
#include "tensors/backend.h"
namespace marian {
// Lazily allocates this node's value tensor. If val_ is already set this is a
// no-op; otherwise the owning graph is asked to allocate forward storage for
// this node.
void Node::allocate() {
  if(val_)
    return;  // value tensor already present
  graph()->allocateForward(this);
}
// Returns this node's value (val_) and adjoint (adj_) tensors to the owning
// graph and resets both members to nullptr.
// Nothing is released when destroy_ is false (views are not freed), or when
// the node is not attached to a graph. @TODO: better naming for destroy_.
void Node::free() {
  if(!destroy_ || !graph())
    return;

  if(val_) {
    graph()->free(val_);
    val_ = nullptr;
  }

  if(adj_) {
    graph()->free(adj_);
    adj_ = nullptr;
  }
}
// Makes sure this node has an adjoint tensor. If adj_ is missing, backward
// storage is allocated through the graph and filled with 1.f. An adjoint that
// already exists is left untouched.
void Node::init_dependent() {
  if(adj_)
    return;
  graph()->allocateBackward(this);
  // relies on allocateBackward(this) having populated adj_
  adj_->set(1.f);
}
// Makes sure this node has an adjoint tensor. If adj_ is missing, backward
// storage is allocated through the graph and filled with 0.f. An adjoint that
// already exists is left untouched (it is NOT re-zeroed).
void Node::set_zero_adjoint() {
  if(adj_)
    return;
  graph()->allocateBackward(this);
  // relies on allocateBackward(this) having populated adj_
  adj_->set(0.f);
}
// Returns the scalar held by this node's value tensor by delegating to
// val_->scalar(). val_ is dereferenced without a null check, so the value
// must already be allocated.
float Node::scalar() {
  const auto& value = val_;
  return value->scalar();
}
// Returns the backend of the graph this node belongs to.
Ptr<Backend> Node::getBackend() {
  auto owner = graph();
  return owner->getBackend();
}
// Executes this node's forward operations. When an auto-tuner recorder is
// attached (see record()), the work is bracketed by start(recorderHash_) and
// stop(recorderHash_, recorderStop_).
void Node::forward() {
  // Open the timing window (skipped when no recorder is attached).
  if(recorder_ != nullptr)
    recorder_->start(recorderHash_);

  runForward(forwardOps());

  // stop() is always called here and receives the stored stop flag.
  if(recorder_ != nullptr)
    recorder_->stop(recorderHash_, recorderStop_);
}
// Executes this node's backward operations. When an auto-tuner recorder is
// attached (see record()), the work is preceded by start(recorderHash_).
void Node::backward() {
  if(recorder_ != nullptr)
    recorder_->start(recorderHash_);

  runBackward(backwardOps());

  // Unlike forward(), stop() is invoked only when recorderStop_ is set.
  // NOTE(review): confirm this asymmetry with forward() is intentional.
  const bool shouldStop = recorder_ != nullptr && recorderStop_;
  if(shouldStop)
    recorder_->stop(recorderHash_, recorderStop_);
}
// Attaches an auto-tuner recorder to this node. forward() and backward()
// consult it to bracket their work with start()/stop() calls.
// @param recorder      recorder to attach; a null pointer makes forward() and
//                      backward() skip all recorder calls
// @param recorderHash  hash passed to the recorder's start()/stop()
// @param stop          flag stored in recorderStop_ and used by
//                      forward()/backward() when calling stop()
void Node::record(Ptr<AutoTunerRecorder> recorder,
                  size_t recorderHash,
                  bool stop) {
  // `recorder` is a by-value sink parameter: move it into the member instead
  // of copy-assigning, which avoids an extra atomic refcount inc/dec.
  recorder_ = std::move(recorder);
  recorderHash_ = recorderHash;
  recorderStop_ = stop;
}
} // namespace marian