Program Listing for File expression_graph_onnx_serialization.cpp

Return to documentation for file (src/onnx/expression_graph_onnx_serialization.cpp)

#ifdef USE_ONNX

#include "onnx/expression_graph_onnx_exporter.h"
#include "graph/expression_operators.h"
#include "graph/node_operators_unary.h"
#include "graph/node_operators_binary.h"
#include "common/version.h"
#define AuxillaryParseTableField AuxiliaryParseTableField  // in protobuf 3.12, the generated source has a spelling error
#include "3rd_party/onnx/protobuf/onnx-ml.pb-wrapper.h"
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <algorithm>

namespace marian {

  // collection of helper functions for accessing and converting Expr properties
  // This class is a friend of all node-op classes whose attributes we need to access.
  class SerializationHelpers {
  public:
    // helper for accessing class members in Marian's polymorphic node classes
    // If 'e' is of type NNaryNodeOp, execute getFn() on it and return true; otherwise return false.
    template<class NNaryNodeOp, typename F>
    static bool tryGetAttributes(Expr e, const F& getFn) {
      auto np = std::dynamic_pointer_cast<NNaryNodeOp>(e);
      if (!np)
        return false;
      getFn(np);
      return true;
    }

    template<class NNaryNodeOp>
    static bool tryGetScalarAttribute(Expr e, float& scalar) {
      return tryGetAttributes<NNaryNodeOp>(e, [&](IPtr<NNaryNodeOp> np) { scalar = np->scalar_; });
    }

    template<class NNaryNodeOp>
    static bool tryGetMatMulAttributes(Expr e, bool& transA, bool& transB, float& scalar) {
      return tryGetAttributes<NNaryNodeOp>(e, [&](IPtr<NNaryNodeOp> np) {
        transA = np->transA_;
        transB = np->transB_;
        scalar = np->scalar_;
      });
    }

    template<class NNaryNodeOp>
    static bool tryGetEpsilonAttribute(Expr e, float& eps) {
      return tryGetAttributes<NNaryNodeOp>(e, [&](IPtr<NNaryNodeOp> np) { eps = np->eps_; });
    }

    template<class NNaryNodeOp>
    static bool tryGetAxisAttribute(Expr e, size_t& axis) {
      return tryGetAttributes<NNaryNodeOp>(e, [&](IPtr<NNaryNodeOp> np) { axis = (size_t)e->shape().axis(np->axis_); });
    }

    template<class NNaryNodeOp>
    static bool tryGetAxesAttribute(Expr e, std::vector<size_t>& axes) {
      return tryGetAttributes<NNaryNodeOp>(e, [&](IPtr<NNaryNodeOp> np) {
        axes.clear();
        for (auto ax : np->axes_)
          axes.push_back((size_t)e->shape().axis(ax));
      });
    }

    template<class NNaryNodeOp>
    static bool tryGetShiftAttributes(Expr e, std::vector<int>& shift, float& padValue) {
      return tryGetAttributes<NNaryNodeOp>(e, [&](IPtr<NNaryNodeOp> np) {
        shift.assign(np->shift_.begin(), np->shift_.end());
        padValue = np->padValue_;
      });
    }

    template<class NNaryNodeOp>
    static bool tryGetSliceAttribute(Expr e, Slice& slice) {
      return tryGetAttributes<NNaryNodeOp>(e, [&](IPtr<NNaryNodeOp> np) { slice = np->slice_; });
    }

    template<class NNaryNodeOp>
    static bool tryGetReshapeeAttributePtr(Expr e, Expr*& ep) {
      return tryGetAttributes<NNaryNodeOp>(e, [&](IPtr<NNaryNodeOp> np) { ep = &np->reshapee_; });
    }

    template<class NNaryNodeOp>
    static bool tryGetStepNodeAttributePtr(Expr e, Expr*& ep) {
      return tryGetAttributes<NNaryNodeOp>(e, [&](IPtr<NNaryNodeOp> np) { ep = &np->stepNode_; });
    }

    template<class NNaryNodeOp>
    static bool tryGetMaskAttributePtr(Expr e, Expr*& ep) {
      return tryGetAttributes<NNaryNodeOp>(e, [&](IPtr<NNaryNodeOp> np) { ep = &np->mask_; });
    }

    // call this for mandatory parameters, e.g. tryGetMaskAttributePtr(...) || fail("message", ...)
    template<typename... Args>
    static bool fail(Args&&... args) {
      ABORT(std::forward<Args>(args)...);
    }
    static bool fail() { return fail("an attempt to access a Marian node attribute unexpectedly failed due to a type mismatch"); }
  };
  using E = SerializationHelpers;
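  // Typical usage (cf. the macro-op expansion below): attributes are fetched with a
  // fallback chain that aborts on an unexpected type mismatch, e.g.:
  //   float scalar{};
  //   E::tryGetScalarAttribute<ScalarAddNodeOp>(v, scalar) || E::fail();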

  struct InputsMap : public std::map<Expr, Expr> {
    Expr operator()(Expr e) const {
      auto iter = find(e); // redirect input if found
      if (iter != end())
        e = iter->second;
      return e;
    }
  };
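  // Applying the map as a function redirects a designated input Expr to its replacement
  // node and is the identity for everything else, e.g.:
  //   child = inputsMap(child); // child is unchanged unless it is a designated input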

  // helper for rebuildNodesForward()
  static void addNodeAndChildren(Expr node, std::list<Expr>& nodesForward, std::set<Expr>& visited, const InputsMap& inputsMap)
  {
    // check if this is an input
    // In that case, we generate a replacement node instead, which has no children and thus terminates the recursion.
    // All nodes that reference this input are, however, unmodified.
    // The tape is now inconsistent. The consumer of this tape must perform child mapping.
    auto replacementNode = inputsMap(node);
    if (replacementNode != node)
      node = replacementNode;
    // recursion terminates if we already visited a node
    // (Input mapping is taken into account already.)
    auto res = visited.insert(node);
    if (!res.second) // already in visited set: done
      return;
    for (auto& child : node->children()) // children come before node itself
      addNodeAndChildren(child, nodesForward, visited, inputsMap);
    nodesForward.push_back(node);
  }
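  // Note: The post-order recursion (children are appended before the node itself) is what
  // keeps the rebuilt nodesForward_ tape topologically sorted.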

  // rebuild nodesForward_ from a graph given by its set of roots
  // Also replaces the inputs by constants, but does not redirect references (leaving an invalid tape--must be corrected on the fly by the caller!).
  void ExpressionGraphONNXExporter::rebuildNodesForward(const InputsMap& inputsMap,
                                                        const std::vector<std::pair<std::string, Expr>>& outputDefs) {
    nodesForward_.clear();
    std::set<Expr> visited;
    for (auto& outputDef : outputDefs)
      addNodeAndChildren(outputDef.second, nodesForward_, visited, inputsMap);
  }

  class NodeReferenceRedirector {
    std::map<Expr, Expr> nodeMap; // [orig node] -> replacement nodes

  public:
    void addRedirect(const Expr& whichNode, const Expr& withWhichNode) {
      nodeMap[whichNode] = withWhichNode;
    }

    // in-place redirect an Expr reference, i.e. look up the redirect and replace the original with it
    void redirectReference(Expr& child) const {
      auto iter = nodeMap.find(child);
      if (iter != nodeMap.end()) {
        child = iter->second;       // redirect child to the replacement node
        ABORT_IF(nodeMap.find(child) != nodeMap.end(), "Nested macro expansion??");
      }
    };

    // redirect all references (=children and more in special cases)
    void redirectAllReferencesIn(Expr v) const {
      // redirect all children
      auto& children = v->children(); // this is a mutable reference
      for (auto& child : children) {  // child is a mutable reference
        redirectReference(child);
      }
      // redirect additional references that some nodes hold
      Expr* ep{};
      if (E::tryGetReshapeeAttributePtr<ReshapeNodeOp>   (v, ep) ||
          //E::tryGetStepNodeAttributePtr<StepNodeOp>      (v, ep) ||    // @TODO: review all of these and update the names
          E::tryGetMaskAttributePtr<PoolingWithMaskingOp>(v, ep)) {
        redirectReference(*ep);
      }
    }
  };

  static Expr newConstant(Expr v, Shape shape, float val, std::string suffix) {
    auto expr = v->graph()->constant(shape, inits::fromVector(std::vector<float>(shape.elements(), val)));
    expr->set_name("const_" + v->type() + "_" + std::to_string(v->getId()) + "_" + suffix);
    // Note: By convention, all constants should be named const_ something (and all data inputs data_),
    // to distinguish them from trainable weight tensors.
    return expr;
  }
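  // Example: expanding a "scalar_add" node with id 42 (illustrative) produces a constant
  // named "const_scalar_add_42_scalar".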

  // unroll higher-level operations for which no ONNX equivalent exists
  // This updates the functionDefs' root nodes in-place.
  // Note: This appends to nodesForward_ in-place. Some meta-information, like root node, is not updated correctly.
  void ExpressionGraphONNXExporter::expandMacroOpsForONNX(std::map<std::string, std::pair<std::vector<std::pair<std::string, Expr>>, std::vector<std::pair<std::string, Expr>> >>& functionDefs) {
    LOG(info, "[graph] Expanding macro ops into primitives. Current graph size is {}", nodesForward_.size());
    NodeReferenceRedirector nodeReferenceRedirector;
    // clear memoization cache, as it removes some children for ops that have not changed since last inference
    tensors_->clearLongtermMemory();
    // Note: Expansions add to the existing tape in-place, but we disallow nested expansions,
    // i.e. expanding newly created nodes, because otherwise the nodeReferenceRedirector
    // becomes very complicated, since those new nodes are no longer topo-sorted.
    // The for loop below also iterates over newly created nodes, but those must not
    // trigger another expansion; that case is caught in redirectReference() above.
    auto beg = nodesForward_.begin();
    auto end = nodesForward_.end();
    for (auto vi = beg; vi != end; ++vi) {
      auto& v = *vi;
      // redirect all children of this node, in case they got mapped in this process
      nodeReferenceRedirector.redirectAllReferencesIn(v);
      // expand macro ops
      Expr n;
#if 0 // For GC ONNX, some ops are still missing. Map these first.
      // @BUGBUG: These operators are not up-to-date
      if (v->type() == "highway") {
        // Replace Sigmoid by Softmax. The only sigmoid in the system comes from highway.
        auto y = v->child(0);       // something like [B, H, T, dim]
        auto x = v->child(1);
        auto t = v->child(2);
        auto shape = x->shape();
        ABORT_IF(y->shape() != shape || t->shape() != shape, "unexpected highway shapes??");
        // Softmax([x,0]) = (Sigmoid(x), 1-Sigmoid(x))
        // Softmax([x,y]) = e^x / (e^x + e^y)
        // Sigmoid(x) = e^x / (e^x + e^0)
        auto shape1 = Shape{shape.elements() / shape.back(), shape.back(), 1};
        t = reshape(t, shape1);
        auto tAug = concatenate({t, newConstant(v, t->shape(), 0.0f, "zero_row")}, -1); // [(B*H*T, dim, 2)]
        auto s = softmax(tAug, /*axis=*/-1); // = (Sigmoid(t), 1-Sigmoid(t)) : [(B*H*T, dim, 2)]
        s = swapAxes(s, 0, -1);              // step() only supports axis=0
        auto sy = step(s, 0, /*axis=*/0);
        auto sx = step(s, 1, /*axis=*/0);
        sy = swapAxes(sy, 0, -1);
        sx = swapAxes(sx, 0, -1);
        sy = reshape(sy, shape);
        sx = reshape(sx, shape);
        n = sy * y + sx * x;
        //LOG(info, "OVERWRITING highway, {} -> {} -> {} -> back", std::string(shape), std::string(shape1), std::string(tAug->shape()));
      }
      else if (v->type() == "sum") {
        // replace ReduceSum by a matrix product with a vector of ones
        auto x = v->child(0);
        auto shape = x->shape();
        size_t lastAxis = shape.size() - 1;
        size_t axis;
        E::tryGetAxisAttribute<SumNodeOp>(v, axis) || E::fail();
        if (axis != lastAxis)   // bring axis to be reduced into last dimension so that we can MatMul
          x = swapAxes(x, (int)axis, (int)lastAxis);
        auto ones = newConstant(v, {x->shape().back(), 1}, 1.0f, "ones");
        n = dot(x, ones);       // [..., D] * [D, 1] = [..., 1]
        if (axis != lastAxis)   // and swap it back
          n = swapAxes(n, (int)axis, (int)lastAxis);
        //LOG(info, "OVERWRITING sum {}/{}, {} -> {} -> . -> {}", axis, lastAxis, std::string(shape), std::string(x->shape()), std::string(n->shape()));
      }
      else if (v->type() == "layer_normalization") {
        // layerNorm along last axis
        auto x = v->child(0);
        auto s = v->child(1);
        auto b = v->child(2);
        auto vecDim = x->shape().back();
        // for summing up elements, we use MatMul
        auto onesOverDim = newConstant(v, {vecDim, 1}, 1.0f / vecDim, "ones_over_dim");
        // compute mean and variance
        auto mean = dot(x, onesOverDim);
        auto x0 = x - mean;
        auto var = dot(x0 * x0, onesOverDim);
        // variance-normalize
        float epsilon;
        E::tryGetEpsilonAttribute<LayerNormalizationOp>(v, epsilon) || E::fail();
        auto sigma = sqrt(newConstant(v, {}, epsilon, "epsilon") + var);
        auto xnorm = x0 / sigma;
        // and final scale/bias
        n = xnorm * s + b;
        //LOG(info, "OVERWRITING layerNorm {} -> {}", std::string(x->shape()), std::string(mean->shape()));
      }
      else
#endif
      if (v->type() == "scalar_add") {
        float scalar{};
        E::tryGetScalarAttribute<ScalarAddNodeOp>(v, scalar) || E::fail();
        n = v->child(0) + newConstant(v, {}, scalar, "scalar");
      }
      else if (v->type() == "scalar_mult") {
        float scalar{};
        E::tryGetScalarAttribute<ScalarMultNodeOp>(v, scalar) || E::fail();
        n = v->child(0) * newConstant(v, {}, scalar, "scalar");
      }
      else if (v->type() == "square") {
        auto x = v->child(0);
        n = x * x;
      }
#if 0  // @BUGBUG: not supported for now, since we don't aim at training. This requires a function called select() which no longer exists.
      else if (v->type() == "x-ent") {
        auto x = v->child(0); // logits : some_shape + (num_classes,)
        auto y = v->child(1); // indices: some_shape + (1,)
        // C = sum_{v in V}(-logsoftmax(A) * delta(v, i) = -logsoftmax(A)[i]
        auto xShape = x->shape();
        // note: indices are flattened into a vector
        auto yShape = xShape;   // true shape of y -> result shape
        yShape.back() = 1;
        auto nl = logsoftmax(x);
        //nl->debug("nl");
#if 1   // ONNX has no batched select/gather, so we must fake it.
        // We first flatten the batch to a vector.
        nl = flatten(nl); // now x: (totalWords, vocabSize), while  y: (totalWords,)
        // Then we create a constant with offsets into this vector
        auto vocabSize = xShape.back();
        auto totalWords = xShape.elements() / vocabSize; // total batch size across batch and length dimension
        std::vector<unsigned int> offs;
        for (size_t i = 0; i < totalWords; i++)
          offs.push_back((unsigned int)(i * vocabSize));
        auto offsExpr = v->graph()->indices(offs);
        offsExpr->set_name("const_" + v->type() + "_offsets_" + std::to_string(v->getId()));
        // Now form indices into the flattened vector using the offsets
        y = y + offsExpr; // -> [y0, y1 + V, y2 + 2V, ...]
        // Now we can select with this.
        n = -select(nl, y, /*axis=*/-1);
        n = reshape(n, yShape);
        //LOG(info, "x-ent: {}, {} -> {}", std::string(x->shape()), std::string(y->shape()), std::string(n->shape()));
#else   // better version, but unfortunately neither Marian nor ONNX support batched select/gather
        y = reshape(y, yShape);
        n = -select(nl, y, /*axis=*/-1); // @TODO: update if we ever add axis_ to x-ent
#endif
      }
#endif
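      // highway(y, x, t) = sigmoid(t) * y + (1 - sigmoid(t)) * x, rebuilt from primitives
      // below; the literal 1 is materialized as an explicit constant node for ONNX export.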
      else if (v->type() == "highway") {
        auto y = v->child(0);
        auto x = v->child(1);
        auto t = v->child(2);
        auto s = sigmoid(t);
        auto oneExpr = newConstant(v, {}, 1.0f, "one");
        n = s * y + (oneExpr - s) * x;
      }
      else if ( v->type() == "bdot" ||
               (v->type() == "dot"  /*  && (v->child(0)->shape().size() != 2 || v->child(1)->shape().size() != 2)*/) ||
               (v->type() == "affine" && (v->child(0)->shape().size() != 2 || v->child(1)->shape().size() != 2 || v->child(2)->shape().size() > 2))) {
        // ONNX MatMul behaves like Numpy matmul, and therefore implements batched semantics.
        // ONNX MatMul has no transA/B/scale parameters, so we must handle those as explicit operations.
        // affine() could also be ONNX Gemm, but that does not support outer ranks, so we just expand it into dot().
        // @TODO: ^^ we can just reshape(). Code is already below, but ONNX Gemm always crashes, so this is disabled for now.
        auto a = v->child(0);
        auto b = v->child(1);
        bool transA{}, transB{}; float scalar{}; // (gcc complains without the initializers, which I think is a compiler bug)
        E::tryGetMatMulAttributes<DotNodeOp>       (v, transA, transB, scalar) ||
        E::tryGetMatMulAttributes<DotBatchedNodeOp>(v, transA, transB, scalar) ||
        E::tryGetMatMulAttributes<AffineNodeOp>    (v, transA, transB, scalar) || E::fail();
        //LOG(info, "{} {}={}x{} trans = {}, {} and scalar = {}",
        //          v->type(), std::string(v->shape()), std::string(a->shape()), std::string(b->shape()), transA, transB, scalar);
        if (transA || transB || scalar != 1.0f ||
            (v->type() == "affine" && (a->shape().size() != 2 || b->shape().size() != 2 || v->child(2)->shape().size() > 2))) {
          //LOG(info, "patching {} {}={}x{} due to trans = {}, {} and scalar = {}",
          //          v->type(), std::string(v->shape()), std::string(a->shape()), std::string(b->shape()), transA, transB, scalar);
          if (transA) {  // note: we don't optimize for this since it does not happen in present models
            a = swapAxes(a, -1, -2);
            transA = false;
          }
          // @BUGBUG: Gemm always crashes with ONNX runtime. So we can't do this optimization.
          //if (v->type() != "bdot" && b->shape().size() == 2) {        // [A,B,C,I,J] x [J,K] --> reshape into regular matrix product
          //  ABORT_IF(transA, "Transposition not mapped away??");
          //  a = reshape(a, Shape({ a->shape().elements() / a->shape()[-1], a->shape()[-1] }));  // now it's a regular matrix product, can use Gemm
          //}
          /*else*/ if (transB) {  // not a regular matrix product: cannot use Gemm, so must transpose manually
            b = swapAxes(b, -1, -2);
            transB = false;
          }
          float extraScalar = 1.0f;
          if (v->type() == "bdot") {  // this maps to ONNX MatMul
            extraScalar = scalar;     // must add extra scale operation at the end
            scalar = 1.0f;            // we cannot scale in ONNX MatMul
            ABORT_IF(transA || transB || scalar != 1.0f, "Transposition and/or scalar not mapped away??");
            n = bdot(a, b, transA, transB, scalar);
          }
          else { // dot, affine
            // @BUGBUG: Gemm always crashes with ONNX runtime. So we can't do this optimization.
            //if (a->shape().size() != 2 || b->shape().size() != 2) {  // not ONNX MatMul: must use explicit scale operation
              extraScalar = scalar;
              scalar = 1.0f;
            //}
            n = dot(a, b, transA, transB, scalar);
            //LOG(info, "{} {} x {} -> {}", v->type(), std::string(a->shape()), std::string(b->shape()), std::string(n->shape()));
            if (v->type() == "affine")
              n = n + v->child(2);
          }
          //if (v->type() == "affine")
          //  LOG(info, "{} + {} -> {}", v->type(), std::string(v->child(2)->shape()), std::string(n->shape()));
          if (extraScalar != 1.0f)
            n = n * newConstant(v, {}, extraScalar, "scalar");
          if (n->shape() != v->shape())
            n = reshape(n, v->shape());  // if we did some shaping to get a regular matrix product, reshape it back
        }
      }
      else if (v->type() == "affine" && v->children().size() > 3) {
        // affine() may have a redundant vector of ones, which we strip here
        // This then becomes Gemm.
        v->children().resize(3);
        ABORT("affine() can presently not stripped of its additional ones vector. Need to fix Marian first to run with this.");
        // Note: Cannot recreate affine() as a new node, because that will get that fourth axis again.
        // @BUGBUG: This will crash.
      }
#if 0 // @BUGBUG: select() no longer exists. Likely some other ops are missing now.
      else if (v->type() == "select") {
        // select maps to Gather, and is limited to non-batched and the last axis
        size_t axis;
        E::tryGetAxisAttribute<SelectNodeOp>(v, axis) || E::fail();
        auto data    = v->child(0);
        auto indices = v->child(1);
        auto dataShape = data->shape();
        auto dataRank = dataShape.size();
        auto indicesShape = indices->shape();
        auto indicesRank = indicesShape.size();
        auto indicesDim = indicesShape[(int)axis - (int)dataShape.size()];
        ABORT_IF(indicesShape.elements() != indicesDim, "ONNX does not support batched select()");
        if (indicesRank != 1 || axis != dataRank - 1) {
          if (indicesRank != 1)
            indices = flatten(indices); // (batched Gather is not supported)
          if (axis != dataRank - 1)
            data = swapAxes(data, (int)axis, (int)dataRank - 1); // swap select axis to back
          n = select(data, indices, -1);
          if (axis != dataRank - 1)
            n = swapAxes(n, (int)axis, (int)dataRank - 1);
        }
      }
#endif
      else if (v->type() == "layer_normalization" &&
               (v->child(0)->shape().size() != 3 || v->child(1)->shape().size() != 1 || (v->children().size() > 2 && v->child(2)->shape().size() != 1))) {
        // ONNX InstanceNormalization is layer norm for shapes (N, C, D, ...) where N and C are
        // batch dimensions, and D... all share normalization statistics ("mean and variance are
        // computed per instance per channel").
        // Marian layer_normalization normalizes along axis -1.
        // Hence, if the input rank is != 3, we must temporarily reshape.
        // Also, ONNX expects scale and bias to contain C values (one for each c), while Marian
        // shares scale and bias along C but uses vectors of dim D. Hence, we must apply them manually.
        // This op gets replaced by a sequence that includes the same op, but with
        // gamma and beta being scalars, which is invalid for Marian.
        // (This will fail if layerNorm is applied to a scalar, which makes no sense.)
        auto x = v->child(0);
        auto s = v->child(1);
        auto b = v->children().size() > 2 ? v->child(2) : nullptr;  // beta is optional
        auto outShape = x->shape();
        auto vecDim = outShape[-1];
        x = reshape(x, {outShape.elements() / vecDim, 1, vecDim}); // -> (N, C, D)
        ABORT_IF((s->shape().size() > 1 && s->shape()[-1] != s->shape().elements()) ||
                 (b && b->shape().size() > 1 && b->shape()[-1] != b->shape().elements()),
                 "scale and bias must be vectors or single rows");
        s = flatten(s);
        if (b)
          b = flatten(b);
        //LOG(info, "layer_normalization reshaped from {} to {}", std::string(outShape), std::string(x->shape()));
        float epsilon;
        E::tryGetEpsilonAttribute<LayerNormalizationOp>(v, epsilon) || E::fail();
        //LOG(info, "LNORM {}, {}, {} vs. {}, {}", std::string(x->shape()), std::string(oneExpr->shape()), std::string(zeroExpr->shape()), std::string(s->shape()), std::string(b->shape()));
        n = layerNorm(x, newConstant(v, {1}, 1.0f, "one"), newConstant(v, {1}, 0.0f, "zero"), epsilon);
        n = n * s;
        if (b)
          n = n + b;
        n = reshape(n, outShape);
      }
      else if (v->type() == "const" && v->name().find("dropout_mask_") == 0) {
        // This is a randomly generated mask. We must replace this by RandomUniform.
        // This is done in 3 steps:
        //  - We expand v as (uniform < keepProb) * scale; but because Marian has no "<", we use "-" instead for now. @HACKHACK 1
        //  - The uniform for now is a constant, which later gets converted to ONNX RandomUniform(0,1).  @HACKHACK 2
        //  - The "-" with left arg of v gets patched to become ONNX Less. @HACKHACK 1 fix-up
        auto pString = v->name();
        pString.erase(0, pString.find_last_of('_') + 1);
        float dropProb = std::stof(pString);
        //LOG(info, "Found dropProb constant {} -> {}", v->name(), dropProb);
        float keepProb = 1.f - dropProb;
        float scale = 1.f / keepProb;
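        // (inverted dropout: scaling the kept activations by 1/keepProb preserves the expected value)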
        auto uniformExpr = v->graph()->constant(v->shape(), inits::zeros());
        uniformExpr->set_name("opRandomUniform_" + std::to_string(v->getId())); // not using newConstant because of special node name
        // (uniform(0,1) < keepProb) * scale
        n = (uniformExpr - newConstant(v, {}, keepProb, "keepProb")) * newConstant(v, {}, scale, "scale");
        // @HACKHACK 1: Marian has no "less than", so we use "-" instead. Must patch that back later.
        // @HACKHACK 2: We use a specially-named constant as the placeholder for uniform(0,1).
      }

      if (n) {
        // copy key properties
        if (v->name() != n->name()) // (this tests for the empty name)
          n->set_name(v->name() + "_expanded"); // (this branch is actually never taken presently)
        n->setTrainable(v->trainable());
        // register mapping
        nodeReferenceRedirector.addRedirect(v, n);
        LOG(info, "[graph] Macro op {} expanded with new root op {}", v->type(), n->type());
      }
    }
    for (auto& functionDef : functionDefs) {
      for (auto& output : functionDef.second.second)  // redirect outputs: a root may also have been a macro op
        nodeReferenceRedirector.redirectReference(output.second);
      for (auto& output : functionDef.second.first)   // redirect inputs: inputs may be the outputs of other functions
        nodeReferenceRedirector.redirectReference(output.second);
    }

    // Since we appended the expanded ops to the end of nodesForward_, the tape must be brought
    // back into topologically sorted order; rebuildNodesForward() below does that per function.
    LOG(info, "[graph] After creating expanded nodes, we now have {} nodes", nodesForward_.size());
  }

  using namespace onnx; // all -Proto classes come from here

  const std::string LENGTH_AXIS_NAME = "SOURCE_LENGTH";  // the source length is a named (dynamic) axis with this name

  // C++ port of a subset of https://github.com/onnx/onnx/blob/master/onnx/helper.py
  static ValueInfoProto makeValueInfoProto(std::string name, TensorProto_DataType dataType, std::vector<size_t> shape, size_t sentinelDim) {
    ValueInfoProto valueInfo;
    valueInfo.set_name(name);
    auto* valueInfoType = valueInfo.mutable_type();
    auto* valueInfoTensorType = valueInfoType->mutable_tensor_type();
    valueInfoTensorType->set_elem_type(dataType);
    auto* valueInfoTensorTypeShape = valueInfoTensorType->mutable_shape();
    for (auto dim : shape)
      if (dim == sentinelDim)
        valueInfoTensorTypeShape->add_dim()->set_dim_param(LENGTH_AXIS_NAME);
      else
        valueInfoTensorTypeShape->add_dim()->set_dim_value(dim);
    return valueInfo;
  }
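  // Example (assuming a sentinelDim of 7919): shape {7919, 2, 512} becomes the ONNX tensor
  // shape [SOURCE_LENGTH, 2, 512], with SOURCE_LENGTH being the named dynamic axis.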

  template<typename T> // note: for now, must pass the matching dataType (not checked)
  static TensorProto makeTensorProto(std::string name, TensorProto_DataType dataType, std::vector<size_t> shape, std::vector<T> vals) {
    TensorProto tensor;
    tensor.set_name(name);
    tensor.set_data_type(dataType);
    for (auto dim : shape)
      tensor.add_dims(dim);
#if 0   // @HACKHACK for debugging: keep files small during debugging, so that we can load and view those files easily
    *tensor.mutable_raw_data() = std::string((char*)vals.data(), (char*)(vals.data() + std::min(size_t(10), vals.size())));
#else
    *tensor.mutable_raw_data() = std::string((char*)vals.data(), (char*)(vals.data() + vals.size()));
#endif
    return tensor;
  }

  static inline void addAttribute(NodeProto& node, std::string name, std::vector<size_t> val) {
    AttributeProto* attribute = node.add_attribute();
    attribute->set_name(name);
    attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
    for (auto i : val)
      attribute->add_ints(i);
  }
  static inline void addAttribute(NodeProto& node, std::string name, std::vector<int> val) {
    AttributeProto* attribute = node.add_attribute();
    attribute->set_name(name);
    attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
    for (auto i : val)
      attribute->add_ints(i);
  }
  static inline void addAttribute(NodeProto& node, std::string name, std::string val) {
    AttributeProto* attribute = node.add_attribute();
    attribute->set_name(name);
    attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_STRING);
    attribute->set_s(val);
  }
  static inline void addAttribute(NodeProto& node, std::string name, float val) {
    AttributeProto* attribute = node.add_attribute();
    attribute->set_name(name);
    attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT);
    attribute->set_f(val);
  }
  static inline void addAttribute(NodeProto& node, std::string name, int val) {
      AttributeProto* attribute = node.add_attribute();
      attribute->set_name(name);
      attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
      attribute->set_i(val);
  }
  static inline void addAttribute(NodeProto& node, std::string name, size_t val) {
      AttributeProto* attribute = node.add_attribute();
      attribute->set_name(name);
      attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
      attribute->set_i(val);
  }
  static inline void addAttribute(NodeProto& node, std::string name, bool val) {
    AttributeProto* attribute = node.add_attribute();
    attribute->set_name(name);
    attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
    attribute->set_i(val ? 1 : 0); // bool is stored as int in ONNX
  }
  static void addAttributes(NodeProto&) { // end of recursion
  }
  template<typename T, typename... Attributes>
  static void addAttributes(NodeProto& node, std::string name, T val, Attributes&&... moreAttributes) {
    addAttribute(node, name, val);
    addAttributes(node, std::forward<Attributes>(moreAttributes)...);
  }

  template <typename... Attributes>
  static NodeProto makeNode(std::string opType, std::string nodeName,
      std::vector<std::string> inputs, std::vector<std::string> outputs,
      Attributes&&... attributes) {
    NodeProto node;
    node.mutable_op_type()->assign(opType);
    for (auto input : inputs)
      node.add_input(input);
    for (auto output : outputs)
      node.add_output(output);
    if (!nodeName.empty())
      node.set_name(nodeName);
    addAttributes(node, std::forward<Attributes>(attributes)...);
    return node;
  }
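  // Attributes can be passed inline as name/value pairs; an illustrative call (not taken
  // from this file) would be:
  //   makeNode("Transpose", name, {inputName}, {name}, "perm", std::vector<size_t>{1, 0});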

  static GraphProto makeGraph(const std::vector<NodeProto>& nodes, std::string name,
                              const std::vector<ValueInfoProto>& inputs,
                              const std::vector<ValueInfoProto>& outputs,
                              const std::vector<TensorProto>& initializers,
                              const std::vector<ValueInfoProto>& valueInfos) {
    GraphProto graph;
    for (auto& node : nodes)
      *graph.add_node() = node;
    graph.set_name(name);
    for (auto& input : inputs)
      *graph.add_input() = input;
    for (auto& output : outputs)
      *graph.add_output() = output;
    for (auto& initializer: initializers)
      *graph.add_initializer() = initializer;
    for (auto& valueInfo : valueInfos)
#if 0 // add some as explicit outputs for debugging
      if (valueInfo.name() == "opReshape_292" || valueInfo.name() == "opPad_294")
        *graph.add_output() = valueInfo;
      else
#endif
      *graph.add_value_info() = valueInfo;
    return graph;
  }

  static ModelProto makeModel(const GraphProto& graph, std::string producerName) {
    ModelProto model;
    model.set_ir_version(IR_VERSION);
    model.set_producer_name(producerName);
    model.mutable_graph()->CopyFrom(graph);
#define OPSET_IMPORT_VERSION 11
    model.add_opset_import()->set_version(OPSET_IMPORT_VERSION);
    return model;
  }

  static std::string mapExprOp(Expr e) {
    const static std::map<std::string, std::string> opMap = {
      {"+"                      , "Add"},
      {"-"                      , "Sub"},
      {"*"                      , "Mul"},
      {"/"                      , "Div"},
      {"negate"                 , "Neg"},
      {"ReLU"                   , "Relu"},
      {"reshape"                , "Reshape"},
      {"affine"                 , "Gemm"},    // @TODO: is this just a hack, or meant to be used for this? It is not really standard GEMM semantics.
      {"bdot"                   , "MatMul"},
      {"dot"                    , "MatMul"},
      {"sigmoid"                , "Sigmoid"},
      {"sqrt"                   , "Sqrt"},
      {"sin"                    , "Sin"},
      {"cos"                    , "Cos"},
      {"tan"                    , "Tan"},
      {"layer_normalization"    , "InstanceNormalization"},
      {"softmax"                , "Softmax"},
      {"logsoftmax"             , "LogSoftmax"},
      {"sum"                    , "ReduceSum"},
      {"transpose"              , "Transpose"},
      {"concat"                 , "Concat"},
      {"sliceView"              , "Slice"},
      {"shift"                  , "Pad"},
      {"rows"                   , "Gather"},
      {"select"                 , "Gather"},
      // The following are never emitted to ONNX. Keep our original type names to avoid special-casing lots of code.
      {"const"                  , "const"},
      {"param"                  , "param"}
    };
    auto iter = opMap.find(e->type());
    ABORT_IF(iter == opMap.end(), "ONNX export of operation {} is presently not supported", e->type());
    return iter->second;
  }

  // get a unique name for an Expr. Either an actual name, or OP_ID if not named.
  // 'nameOverrides' overrides that name. This is used for inputs and outputs.
  static std::string getExprName(Expr e, const std::map<Expr, std::string>& nameOverrides) {
    if (nameOverrides.find(e) != nameOverrides.end())
      return nameOverrides.at(e);
    std::string name = e->name();
    if (name == "none") // Marian assigns "none" to denote an unassigned name
      name = (e->type() == "const" ? "" : "op") + mapExprOp(e) + "_" + std::to_string(e->getId());
    // For 'const', do not prefix "op", so that all internal constants in the system
    // (i.e. not input data) have a prefix "const_" to distinguish them from weight tensors.
    return name;
  }

  // convert Marian shape into vector<size_t>
  static std::vector<size_t> getExprShape(Expr e) {
    const auto& shape = e->shape();
    return std::vector<size_t>(shape.begin(), shape.end());
  }

  // get TensorProto_DataType for an Expr
  // Note: We map Marian uint32_t to ONNX signed integers because those are only used
  // for indices for Gather operations, where Marian requires unsigned and ONNX signed.
  static TensorProto_DataType getExprDataType(Expr expr) {
    switch (expr->value_type()) {
      case marian::Type::float32: return TensorProto_DataType::TensorProto_DataType_FLOAT;
      //case marian::Type::uint32:  //return TensorProto_DataType::TensorProto_DataType_UINT32;
      case marian::Type::uint32:  // uint32 becomes ONNX INT32 as well (see above)
      case marian::Type::int32:   return TensorProto_DataType::TensorProto_DataType_INT32;
      default: ABORT("Tensor type not supported yet");
    }
  }

  // convert a Marian constant to an ONNX TensorProto
  static TensorProto makeExprTensorProto(Expr expr, const std::map<Expr, std::string>& nameOverrides) {
    auto dataType = getExprDataType(expr);
    auto name     = getExprName    (expr, nameOverrides);
    auto shape    = getExprShape   (expr);
    switch(expr->value_type()) {
    case marian::Type::float32: { // @TODO: template this?
      std::vector<float> valBuf;
      expr->val()->get(valBuf);
      return makeTensorProto(name, dataType, shape, valBuf);
    }
    case marian::Type::uint32: {
      std::vector<uint32_t> valBuf; // note: uint32_t still get passed to ONNX as signed INT32 (cf. getExprDataType())
      expr->val()->get(valBuf);
      return makeTensorProto(name, dataType, shape, valBuf);
    }
    case marian::Type::int32: {
      std::vector<int32_t> valBuf;
      expr->val()->get(valBuf);
      return makeTensorProto(name, dataType, shape, valBuf);
    }
    default:
      ABORT("Tensor type not supported yet");
    }
  }

  static void logNode(const NodeProto& node, const std::vector<size_t>& shape, size_t sentinelDim) {
    std::string s = node.name() + " = " + node.op_type() + "(";
    auto addComma = [&]() { if (s.back() != '(' && s.back() != '[') s += ", "; };
    for (int i = 0; i < node.input_size(); i++) {
      auto inputName = node.input(i);
      addComma();
      s += inputName;
    }
    for (int i = 0; i < node.attribute_size(); i++) {
      auto attribute = node.attribute(i);
      addComma();
      s += attribute.name() + "=?";
    }
    s += (") : [");
    for (auto dim : shape) {
      addComma();
      if (dim == sentinelDim)
          s += LENGTH_AXIS_NAME;
      else
          s += std::to_string(dim);
    }
    s.push_back(']');
    LOG(info, s);
  }

  // convert a Marian Expr to an ONNX node
  // This function needs inputs and initializers because the special case of Reshape needs
  // to create an extra input with initializer.
  static void addExprNode(Expr expr, std::vector<NodeProto>& nodes, std::vector<ValueInfoProto>& inputs,
                          std::vector<TensorProto>& initializers,
                          const std::map<Expr, std::string>& nameOverrides, const InputsMap& inputsMap,
                          size_t sentinelDim) {
    // get all children
    // These may reference inputs, and hence must be mapped right here.
    // The original child in this case is not on the tape.
    auto children = expr->children();
    for (auto& child : children)
      child = inputsMap(child);

    // inputs are referenced by their node names (also when they are leaves)
    std::vector<std::string> inputNames;
    for (const auto& child : children)
      inputNames.push_back(getExprName(child, nameOverrides));

    auto name = getExprName(expr, nameOverrides); // node name is used as both output name and node name
    auto op = mapExprOp(expr);

    //if (op == "MatMul" && expr->child(0)->shape().size() == 2 && expr->child(1)->shape().size() == 2) {
    //  op = "Gemm";
    //}

#if 1 // workaround for onnxruntime which does not handle Pad correctly
    if (op == "Pad") {
      // Implement Pad as Slice >> Concat
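      // For a shift of 1 along axis 0, this is equivalent to
      //   Concat([zeros(1, dims...), input[0:-1, ...]], axis=0)
      // i.e. one slice of zeros at the front, with the last time step dropped.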
      std::vector<int> shifts;
      float padValue{}; // (compiler bug: without initialization, I get an uninit warning, yet it is correctly set)
      E::tryGetShiftAttributes<ShiftNodeOp>(expr, shifts, padValue) || E::fail();
      ABORT_IF(shifts[0] != 1, "can only shift by one");
      for (size_t i = 1; i < shifts.size(); i++)
        ABORT_IF(shifts[i] != 0, "can only shift along first axis");
      auto shape = getExprShape(children[0]);
      // Slice [0:-1,:,:]
      auto sliceName = name + "_Slice";
      auto sliceNode = makeNode("Slice", sliceName, inputNames, {sliceName});
      addAttribute(sliceNode, "axes",   std::vector<size_t>{0});
      addAttribute(sliceNode, "starts", std::vector<size_t>{0});
      addAttribute(sliceNode, "ends",   std::vector<size_t>{shape[0] - 1}); // drop last step
      nodes.push_back(sliceNode);
      LOG(info, "Pad slice op {}", sliceName);
      // create a padding constant
      auto paddingName = "const_" + name + "_Padding";
      shape[0] = 1;
      size_t n = 1;
      for (auto& dim : shape)
        n *= dim;
      std::vector<float> zeros(n);
      inputs.      push_back(makeValueInfoProto(paddingName, TensorProto_DataType::TensorProto_DataType_FLOAT, shape, sentinelDim));
      initializers.push_back(makeTensorProto   (paddingName, TensorProto_DataType::TensorProto_DataType_FLOAT, shape, zeros));
      LOG(info, "Pad constant {}", paddingName);
      // Concat([paddingNode, sliceNode], axis=0)
      auto node = makeNode("Concat", name, {paddingName, sliceName}, {name});
      addAttribute(node, "axis", 0);
      nodes.push_back(node);
      LOG(info, "Pad concat op {}", name);
      return;
    }
#endif

    auto node = makeNode(op, name, inputNames, {name});
    //LOG(info, "NODE {} {} -> {}", name, expr->type(), E::mapExprOp(expr));

    // add attributes needed by some operators

    // fix up inputs
    if (node.op_type() == "Reshape") { // Reshape requires the shape itself to be a tensor.
      auto shapeInputName = "const_" + getExprName(expr, {}) + "_shape_attr";
      *node.add_input() = shapeInputName;
      // create a new input and a new initializer
      auto shape = getExprShape(expr);
      auto shape64 = std::vector<int64_t>(shape.begin(), shape.end());
      for (auto& dim : shape64)
        if (dim == (int64_t)sentinelDim)
          dim = -1;  // means that this one is inferred at runtime
      std::vector<size_t> shapeShape{shape.size()}; // ONNX Reshape requires shape in INT64
      inputs.      push_back(makeValueInfoProto(shapeInputName, TensorProto_DataType::TensorProto_DataType_INT64, shapeShape, sentinelDim));
      initializers.push_back(makeTensorProto   (shapeInputName, TensorProto_DataType::TensorProto_DataType_INT64, shapeShape, shape64));
      std::string s = shapeInputName;
      for (auto& dim : shape64)
        s += " " + std::to_string(dim);
      LOG(info, s);
    }
    // axis attribute
    size_t axis{};
    std::vector<size_t> axes;
    if (E::tryGetAxisAttribute<ConcatenateNodeOp>(expr, axis)// ||
        //E::tryGetAxisAttribute<SelectNodeOp>(expr, axis)
        ) { // axis_ -> 'axis'
      addAttribute(node, "axis", axis);
    }
    else if (E::tryGetAxisAttribute<ReduceNodeOp>(expr, axis) ||
             E::tryGetAxisAttribute<SliceViewNodeOp>(expr, axis)) { // {axis_} -> 'axes'
      addAttribute(node, "axes", std::vector<size_t>{axis});
    }
    else if (E::tryGetAxesAttribute<TransposeNodeOp>(expr, axes)) { // here, the axes are called 'perm'
      addAttribute(node, "perm", axes);
    }
    else if (node.op_type() == "Softmax" || node.op_type() == "LogSoftmax") {
      // Note: ONNX (Log)Softmax is not along an axis; rather along all axes >= given axis (they get flattened).
      addAttribute(node, "axis", expr->shape().size()-1); // Marian softmax defaults to last axis. @TODO: update if we ever add an axis_ parameter.
    }
    else if (expr->type() == "rows") { // becomes Gather
      // Example, adopted from ONNX docs:
      //  axis = 0
      //  data = [ [1.0, 1.2], [2.3, 3.4], [4.5, 5.7], ]
      //  indices = [ 0, 1, 1, 2, ]
      //  output = [  [1.0, 1.2], [2.3, 3.4], [2.3, 3.4], [4.5, 5.7], ]
      ABORT_IF(expr->shape().size() != 2, "Unexpected input shape for rows()");
      addAttribute(node, "axis", 0);
    }
    // slice attributes (starts, ends)
    Slice slice;
    if (E::tryGetSliceAttribute<SliceViewNodeOp>(expr, slice)) {
      addAttribute(node, "starts", std::vector<size_t>{(size_t)slice.begin});
      addAttribute(node, "ends"  , std::vector<size_t>{(size_t)slice.end});
      addAttribute(node, "steps" , std::vector<size_t>{(size_t)slice.stride});
    }
    // shift attributes (shift, padValue)
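    // Example: shift = {1, 0, 0} on a rank-3 tensor yields ONNX pads = {1, 0, 0, -1, 0, 0},
    // i.e. insert one padValue slice at the front of axis 0 and remove one slice at the end.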
    std::vector<int> shifts;
    float padValue{}; // (compiler bug: without initialization, I get an uninit warning, yet it is correctly set)
    if (E::tryGetShiftAttributes<ShiftNodeOp>(expr, shifts, padValue)) {
      std::vector<int> pads;
      for (auto shift : shifts)
        pads.push_back(shift);   // shift = #padValues to insert at front (or, for shift < 0, to remove at front)
      for (auto shift : shifts)
        pads.push_back(-shift);  // and #values to remove at end (or, for shift < 0, to insert at end)
      ABORT_IF(pads.size() != 2 * expr->shape().size(), "Unexpected number of shift dimensions");
      addAttribute(node, "pads", pads);
      addAttribute(node, "value", padValue);
      addAttribute(node, "mode", std::string("constant"));
    }

    // matmul attributes
    bool transA, transB;
    float scalar;
    // @BUGBUG: I cannot get Gemm to work, ONNX runtime always crashes. So we will NEVER get here.
    if (node.op_type() == "Gemm") {  // we get here for affine() or dot()
      // Note: We only get here if Gemm can implement this configuration.
      ABORT_IF(children[0]->shape().size() != 2 || children[1]->shape().size() != 2 ||
               (children.size() > 2 && children[2]->shape().size() > 2),
               "Gemm unexpectedly used for non-matrix inputs");
      E::tryGetMatMulAttributes<AffineNodeOp>(expr, transA, transB, scalar) ||
      E::tryGetMatMulAttributes<DotNodeOp>   (expr, transA, transB, scalar) || E::fail();
      /*if (transA)        */ addAttribute(node, "transA", transA ? 1 : 0);
      /*if (transB)        */ addAttribute(node, "transB", transB ? 1 : 0);
      /*if (scalar != 1.0f)*/ addAttribute(node, "alpha", scalar);
      //addAttribute(node, "beta", 0.0f);
    }
    else if (E::tryGetMatMulAttributes<DotNodeOp>       (expr, transA, transB, scalar) ||
             E::tryGetMatMulAttributes<DotBatchedNodeOp>(expr, transA, transB, scalar)) {
      // transpose/scalar not supported by ONNX MatMul, must have been expanded before we get here
      ABORT_IF(transA || transB || scalar != 1.0f, "Unexpected transpose or scalar attributes for {}", expr->type());
    }
    // epsilon attribute
    float epsilon;
    if (E::tryGetEpsilonAttribute<LayerNormalizationOp>(expr, epsilon)) {
      addAttribute(node, "epsilon", epsilon);
    }
    // dropout patches
    if (node.op_type() == "Sub" && children[0]->type() == "const" && children[0]->name().find("opRandomUniform_") == 0) {
      // @HACKHACK 1: For dropout, we route a "<" operation through a Marian "-" because it has no "<".
      *node.mutable_op_type() = "Less";
      // Note: Since this is a hack, we don't bother to fix up the node name, which is still opSub_ID.
    }
    else if (expr->type() == "const" && expr->name().find("opRandomUniform_") == 0) {
      // @HACKHACK 2: The dropout weight, which is a 'const' in Marian, acts as a placeholder for
      // a RandomUniform operation. In place of a 'const', we generate a uniform(0,1) node
      // of the same shape.
      *node.mutable_op_type() = "RandomUniform";
      addAttribute(node, "shape", getExprShape(expr));
    }
    nodes.push_back(node);
  }

  // serialize the nodesForward_ of a graph right after build() into an ONNX-formatted file
  // We declare this to be ONNX operator set 11 (cf. OPSET_IMPORT_VERSION above). @TODO: Which ONNX version does this correspond to?
  // The nodes must only contain operations supported by ONNX, so the caller must first call
  // expandMacroOpsForONNX().
  // One batch axis can be variable-length. It is recognized via a hack: by a special
  // dimension value that otherwise never naturally occurs, e.g. a larger prime number.
  // We will not recognize derivatives of this value, such as value+1 or value times another dimension.
  // @TODO: This presently does not support variable batch dimensions. How does ONNX handle them?
  // @TODO: How to handle guided alignment? That's another input. Name? Shape?
  // This is based on the simple example in
  // https://github.com/onnx/onnx/blob/master/onnx/examples/make_model.ipynb
  void ExpressionGraphONNXExporter::serializeToONNX(const std::string& fileRoot, FunctionDefs&& functionDefs, size_t sentinelDim) {
    GOOGLE_PROTOBUF_VERIFY_VERSION;

    // @TODO: expansion must deal with multiple sub-tapes (encoder, init)
    // expand Marian macro operations such as "highway" or "scalar_add" that ONNX does not have
    // After this, nodesForward_ is not topologically sorted.
    expandMacroOpsForONNX(functionDefs);

    for (const auto& functionDef : functionDefs) {
      const auto& graphName = functionDef.first;
      const auto& inputDefs  = functionDef.second.first;
      const auto& outputDefs = functionDef.second.second;

      // some stats
      LOG(info, "[onnx] Exporting graph {}", graphName);

      std::map<Expr, std::string> nameOverrides; // we implant input and output names dynamically (instead of setting the name in Expr)

      // clear memoization caches
      tensors_->clearShorttermMemory();
      tensors_->clearLongtermMemory();

      // create new dummy const nodes for all function arguments
      // These nodes will be replaced in rebuildNodesForward() and act as recursion stops.
      // The actual child references are NOT replaced.
      // Also, we collect the nameOverrides for all input and output nodes.
      InputsMap inputsMap;
      for (auto& inputDef : inputDefs) {
        const auto& input = inputDef.second;
        ABORT_IF(inputsMap.find(input) != inputsMap.end(), "Duplicate inputDef expr??");
        auto arg = constant(input->shape(), inits::zeros(), input->value_type());
        inputsMap[input] = arg;
        nameOverrides[arg] = inputDef.first;
      }
      for (const auto& outputDef : outputDefs)
          nameOverrides[inputsMap(outputDef.second)] = outputDef.first;

      // regenerate nodesForward_ from the roots, only for the function under consideration
      // This redirects all items in inputsMap in the graph and in outputDefs as well.
      // I.e. actual inputs are already replaced by Constants on the tape, but other nodes'
      // references are not!
      // All references from this point on have to be run through inputsMap().
      rebuildNodesForward(inputsMap, outputDefs);
      LOG(info, "[graph] Topologically sorted, garbage-collected graph has size {}", nodesForward_.size());

      // sanity check: is the tape consistent, assuming the inputsMap?
      std::set<Expr> nodesOnTape;
      for (const auto& e : nodesForward_)
        nodesOnTape.insert(e);
      for (const auto& e : nodesForward_) for (const auto& c : e->children()) {
        if (nodesOnTape.find(c) == nodesOnTape.end())
          LOG(info, "Redirected child: {}, {}", c->getId(), c->name());
        ABORT_IF(nodesOnTape.find(inputsMap(c)) == nodesOnTape.end(),
                 "Node {} {} refers to child {} {} that is off tape??", e->getId(), e->name(), c->getId(), c->name());
      }

      // sanity check: did we consume all expected inputs?
      std::set<Expr> mappedInputSet;  // set of replacement Exprs (those constants) for inputs
      for (auto ee : inputsMap)
        mappedInputSet.insert(ee.second);
      std::set<Expr> seenMappedInputs;
      for (const auto& expr : nodesForward_) {
        ABORT_IF(inputsMap.find(expr) != inputsMap.end(), "An input node (id={}) was not mapped??", expr->getId());
        if (mappedInputSet.find(expr) != mappedInputSet.end())
          seenMappedInputs.insert(expr);
      }
      for (auto e : mappedInputSet)
        if (seenMappedInputs.find(e) == seenMappedInputs.end()) {
          LOG(info, "WARNING: Input {} not consumed in input graph", nameOverrides[e]);
          nodesForward_.push_back(e);
        }
        //ABORT_IF(seenMappedInputs.find(e) == seenMappedInputs.end(), "Input node {} not found in input graph??", nameOverrides[e]);

      // output set -- these nodes are exported differently
      std::set<Expr> outputsSet;
      for (const auto& outputDef : outputDefs)
        outputsSet.insert(inputsMap(outputDef.second));

      std::vector<ValueInfoProto> inputsParamsAndConstants; // parameters and constants all are considered inputs, just with initializers

      // Create the nodes -> array of NodeProto
      std::vector<NodeProto> nodes;
      std::vector<TensorProto> initializers; // constants are inputs with initializers that hold their values. They go here.
      std::vector<ValueInfoProto> shapeInfos; // expected shapes of operations (for diagnostics only)
      std::vector<ValueInfoProto> outputs; // outputs' shapes
      for(const auto& expr : nodesForward_) {
        //LOG(info, "exporting node name {} op {} ({})", getExprName(expr), E::mapExprOp(expr), expr->children().size());
        if (expr->type() == "param" ||
            (expr->type() == "const" && expr->name().find("opRandomUniform_") != 0)) { // leaves are not nodes in ONNX (except for the uniform placeholder @HACKHACK 2)
          //LOG(info, "exporting leaf name {} op {} ({})", getExprName(expr), E::mapExprOp(expr), expr->children().size());
          auto shape = getExprShape(expr);
          inputsParamsAndConstants.push_back(makeValueInfoProto(getExprName(expr, nameOverrides), getExprDataType(expr), shape, sentinelDim));
          // don't create an initializers entry for inputs
          if (std::any_of(inputsMap.begin(), inputsMap.end(), [&](const std::pair<Expr, Expr>& inputMap) {
                return inputMap.second == expr;
              })) { // skip designated inputs
            ABORT_IF(expr->type() != "const", "Data inputs must be 'const' nodes");
            //LOG(info, "No initializer for data-input node {}", getExprName(expr));
            continue;
          }
          // run initializers, to realize value of consts (params already got theirs)
          expr->allocate();
          expr->init();
          expr->forward();
          ABORT_IF(!expr->val(), "Leaf '{}' of type {} unexpectedly lacks a value despite trying really hard", expr->name(), expr->type());
          initializers.push_back(makeExprTensorProto(expr, nameOverrides));
          continue;      // parameters must become initializers, name=input name
        }
        addExprNode(expr, nodes, inputsParamsAndConstants, initializers, nameOverrides, inputsMap, sentinelDim);
        logNode(nodes.back(), getExprShape(expr), sentinelDim);

        auto valueInfo = makeValueInfoProto(nodes.back().name(), getExprDataType(expr), getExprShape(expr), sentinelDim);
        if (outputsSet.find(expr) != outputsSet.end())
          outputs.push_back(valueInfo);
        //else // we add expected-shape information, to more easily be able to track down where it may fail
        //  shapeInfos.push_back(valueInfo);
      }

      //LOG(info, "total nodes: {}, incl. {} inputs, {} op shapes", nodesForward_.size(), inputs.size(), shapeInfos.size());

      // @TODO: write a log message with the inputs and output names (the function signature)

      // Create the graph -> GraphProto
      auto graphDef = makeGraph(nodes, graphName, inputsParamsAndConstants, outputs, initializers, shapeInfos);

      // Create the model -> ModelProto
      auto modelDef = makeModel(graphDef, /*producer_name=*/"Marian " + buildVersion());

      // save it
      auto filename = fileRoot + "." + graphName + ".onnx";
      auto s = modelDef.SerializeAsString();
      ABORT_IF(s.empty(), "Failed to serialize ONNX graph to string buffer for {}", filename);
      std::ofstream o(filename, std::ios::binary);
      ABORT_IF(o.fail(), "Failed to create ONNX model file {}", filename);
      o.write(s.data(), s.size());
      o.close();
      ABORT_IF(o.fail(), "Failed to write ONNX model to {}", filename);
      LOG(info, "[onnx] ONNX graph '{}' written to {}", graphName, filename);
    }

    // tape has been destroyed many times, so clear it for good
    nodesForward_.clear();
  }

  Expr ExpressionGraphONNXExporter::tryFindForwardNodeByName(const std::string& nodeName) const {
    auto iter = std::find_if(nodesForward_.begin(), nodesForward_.end(), [&](Expr node) {return node->name() == nodeName; });
    if (iter == nodesForward_.end())
      return nullptr;
    else
      return *iter;
  }

}  // namespace marian

#endif // USE_ONNX