Program Listing for File expression_graph_onnx_exporter.cpp

#ifdef USE_ONNX

#include "onnx/expression_graph_onnx_exporter.h"

#include "models/model_factory.h"
#include "models/encoder_decoder.h"
#include "data/corpus_base.h"
#include "tensors/cpu/fbgemm/expression_graph_packable.h"

#include <memory>

namespace marian {
  // The goal is to export three functions:
  //  - encode_source(): encodes the source
  //                     output: encoder_state
  //  - decode_first(): resets decoder state and performs the first decoding step
  //                    main output: log prob vector for step 0
  //  - decode_next(): performs a subsequent decoding step (called repeatedly)
  //                   main output: log prob vector
  // This is done by generating the tape for encoding followed by the first two decoding steps.
  // As we do this, we remember the Exprs on the tape that are the inputs and the outputs
  // of the three functions.
  // Since not all Marian operations have a 1:1 ONNX counterpart, the tape then gets rewritten
  // so that it consists only of operations that ONNX has.
  // Now we cut out three sub-graphs from the tape. Each sub-graph represents one
  // of the three functions. The sub-graph is delimited by the inputs and outputs we remembered above.
  // Limitations:
  //  - Inner recurrences, e.g. an RNN encoder, are not supported, since we cannot export control flow to ONNX.
  //  - Dynamic objects that depend on the input are not supported.
  //    For example constants whose shape depends on the input length.
  //    That's why we had to change the sinusoidal embeddings from a constant to a computation.
  //  - The input length is represented by a "unique" dimension value (97). This is brittle.
  //    That dimension value must not occur naturally in the model.
  //    That dimension must also not be used in dimension calculations.
  //    E.g. the exporter does not recognize if a constant is added to it, or if it gets multiplied.
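  //
  // At inference time, a consumer of the exported model would call the three functions
  // roughly like this (a pseudocode sketch for a single encoder; the exact input/output
  // names are defined by the function descriptors assembled below):
  //   encoder_context_0 = encode_source(data_0, data_0_mask, data_0_posrange)
  //   logits, states... = decode_first(data_1_posrange, encoder_context_0, data_0_mask)
  //   until </s> is predicted:
  //     logits, states... = decode_next(prev_word, data_1_posrange, encoder_context_0, data_0_mask, states...)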
  void ExpressionGraphONNXExporter::exportToONNX(const std::string& modelToPrefix, Ptr<Options> modelOptions, const std::vector<std::string>& vocabPaths)
  {
    auto graph = shared_from_this();

    // get the model and the vocabularies
    auto model = std::dynamic_pointer_cast<IEncoderDecoder>(models::createModelFromOptions(modelOptions, models::usage::translation));
    std::vector<Ptr<Vocab>> vocabs;
    for (auto vocabPath : vocabPaths) {
      Ptr<Vocab> vocab = New<Vocab>(modelOptions, vocabs.size());
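      // load the full vocabulary (INT_MAX imposes no cap on the vocabulary size)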
      vocab->load(vocabPath, INT_MAX);
      vocabs.emplace_back(vocab);
    }
    setInference(true);  // note: must also set "inference" parameter on options

    // if we must suppress <unk>, we do that by patching the bias
    const auto trgUnkId = vocabs.back()->getUnkId();
    int unkColId = -1;
    if (trgUnkId != Word::NONE && !modelOptions->get<bool>("allow-unk", false)) { // do we need to suppress unk?
      unkColId = trgUnkId.toWordIndex(); // what's the raw index of unk in the log prob vector?
      // find the bias
      const std::string outputBiasName = "decoder_ff_logit_out_b";
      auto outputBias = graph->get(outputBiasName);
      auto outputBiasVal = outputBias->val();
      std::vector<float> outputBiasVec;
      outputBiasVal->get(outputBiasVec);
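      // a -inf bias entry drives the <unk> logit to -inf, so softmax assigns <unk> probability 0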
      outputBiasVec[unkColId] = -std::numeric_limits<float>::infinity();
      outputBiasVal->set(outputBiasVec);
    }

    // the input length is represented by a value that hopefully is not used elsewhere
    const size_t sentinelDim = 97;  // who uses prime numbers as dimensions anyways!
    size_t numEncoders = vocabs.size() - 1;  // @TODO: test this exporter for >1 encoder

    // some helper functions
    auto extractInputByName = [&](const std::string& name) {
      auto expr = tryFindForwardNodeByName(name);
      ABORT_IF(!expr, "Unexpectedly could not find input node named {}", name);
      expr->set_name("none"); // and nuke the name, as it will be created again in step()
      return std::make_pair(name, expr);
    };
    auto extractEmbeddingInputs = [&](bool forEncoder) {
      // embedding inputs must be found by name, since Marian does not clearly separate batch and Expr version of the batch
      std::vector<std::pair<std::string, Expr>> embeddingInputs;
      for (size_t i = 0; i < numEncoders; i++) {
        std::string inputName = "data_" + std::to_string(i);
        embeddingInputs.push_back(extractInputByName(inputName));
        if (forEncoder) {
          embeddingInputs.push_back(extractInputByName(inputName + "_mask"));
          embeddingInputs.push_back(extractInputByName(inputName + "_posrange"));
        }
      }
      return embeddingInputs;
    };
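    // Flatten the decoder state into a list of Exprs (output and cell for each layer).
    // These are returned as outputs of one decoding step and fed back as inputs to the next.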
    auto extractStates = [&](Ptr<DecoderState> decoderState) {
      std::vector<Expr> states;  // all decoder-state Exprs in a long list
      for (const auto& d : decoderState->getStates()) {
        states.push_back(d.output);
        states.push_back(d.cell);
      }
      return states;
    };

    // run a fake batch through the encoder (this goes into encode_source()) and create decoder start states
    // This adds the operations to the tape.
    std::vector<Ptr<data::SubBatch>> subBatches;
    for (size_t batchIndex = 0; batchIndex < numEncoders; batchIndex++) {
      auto sb = New<data::SubBatch>(/*size=*/1, /*width=*/sentinelDim, vocabs[batchIndex]);
      // set word indices to random values
      std::transform(sb->data().begin(), sb->data().end(), sb->data().begin(),
        [&](Word) -> Word { return vocabs[batchIndex]->randWord(); });
      // mask: no items are masked out
      std::fill(sb->mask().begin(), sb->mask().end(), 1.f);
      subBatches.push_back(std::move(sb));
    }
    auto batch = New<data::CorpusBatch>(subBatches);
    auto startState = model->startState(graph, batch);

    // fish out the embedding inputs by name and neutralize the names
    // These constitute the inputs for the graph we are cutting out for encode_source().
    auto encoderEmbeddingInputs = extractEmbeddingInputs(/*forEncoder=*/true);
    std::vector<std::pair<std::string, Expr>> encoderContexts;
    for (const auto& e : startState->getEncoderStates())
      encoderContexts.push_back(std::make_pair("encoder_context_" + std::to_string(encoderContexts.size()), e->getContext()));
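    // These contexts are the outputs of encode_source() and are fed back in as
    // inputs to both decode functions below.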

    // run it further until the first prediction --> decode_first()
    // This adds more operations to the tape.
    auto decodeFirstState = model->step(graph, startState, /*hypIndices=*/{},
      /*words=*/{}, /*batchIndices=*/{ 0 }, /*beamSize=*/1);
    auto decodeFirstPosRangeInput = extractInputByName("data_" + std::to_string(numEncoders) + "_posrange");

    // run it further until the next prediction --> decode_next()
    // This adds more operations to the tape.
    auto decodeNextState = model->step(graph, decodeFirstState, /*hypIndices=*/{},
      /*words=*/{ vocabs.back()->randWord() }, /*batchIndices=*/{ 0 }, /*beamSize=*/1);
    auto decodeNextEmbeddingInput = extractEmbeddingInputs(/*forEncoder=*/false);
    auto decodeNextPosRangeInput = extractInputByName("data_" + std::to_string(numEncoders) + "_posrange");

    ABORT_IF(encoderContexts.size() != numEncoders, "Unexpected mismatch in number of encoders??");

    // create a descriptor for the three functions, which consists of
    //  - function name
    //  - list of inputs and outputs, as name-Expr pairs
    FunctionDefs functionDefs;

    std::vector<std::pair<std::string, Expr>> inputs;
    std::vector<std::pair<std::string, Expr>> outputs;

    // descriptor for encode_source(data_0, data_0_mask, data_0_posrange) -> encoder_context_0
    inputs = encoderEmbeddingInputs;
    outputs = encoderContexts;
    functionDefs["encode_source"] = std::make_pair(std::move(inputs), std::move(outputs));
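    // note: the std::move above leaves inputs/outputs empty (in practice), so both vectors can be refilled for the next function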

    // descriptor for decode_first(data_1_posrange, encoder_context_0, data_0_mask) -> first_logits, first_decoder_state_0, first_decoder_state_1, ...
    inputs.emplace_back(decodeFirstPosRangeInput);
    for (size_t i = 0; i < numEncoders; i++) {
      inputs.emplace_back(encoderContexts[i]);
      inputs.emplace_back(encoderEmbeddingInputs[3 * i + 1]); // the mask (each encoder contributed 3 entries: data, mask, posrange)
    }
    outputs.emplace_back(std::make_pair("first_logits", decodeFirstState->getLogProbs().getLogits()));
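    // the states are numbered from 0; the -1 skips the first_logits entry pushed above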
    for (const auto& dss : extractStates(decodeFirstState))
      outputs.emplace_back(std::make_pair("first_decoder_state_" + std::to_string(outputs.size()-1), dss));
    functionDefs["decode_first"] = std::make_pair(std::move(inputs), std::move(outputs));

    // descriptor for decode_next(prev_word, data_1_posrange, encoder_context_0, data_0_mask, decoder_state_0, decoder_state_1, ...) -> next_logits, next_decoder_state_0, next_decoder_state_1, ...
    inputs.emplace_back(std::make_pair("prev_word", decodeNextEmbeddingInput[0].second));
    inputs.emplace_back(decodeNextPosRangeInput);
    for (size_t i = 0; i < numEncoders; i++) {
      inputs.emplace_back(encoderContexts[i]);
      inputs.emplace_back(encoderEmbeddingInputs[3 * i + 1]); // the mask, as in decode_first
    }
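    // number the states from 0 by subtracting the entries pushed above (prev_word, posrange, and context+mask per encoder)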
    for (const auto& dss : extractStates(decodeFirstState))
      inputs.emplace_back(std::make_pair("decoder_state_" + std::to_string(inputs.size() - (numEncoders*2 + 2)), dss));
    outputs.emplace_back(std::make_pair("next_logits", decodeNextState->getLogProbs().getLogits()));
    for (const auto& dss : extractStates(decodeNextState))
      outputs.emplace_back(std::make_pair("next_decoder_state_" + std::to_string(outputs.size() - 1), dss));
    functionDefs["decode_next"] = std::make_pair(std::move(inputs), std::move(outputs));

    // now export the sub-graph as given by the function descriptor
    serializeToONNX(modelToPrefix, std::move(functionDefs), sentinelDim);
  }
}

#endif