Program Listing for File attention_constructors.h

Return to documentation for file (src/rnn/attention_constructors.h)

#pragma once

#include "marian.h"

#include "layers/factory.h"
#include "rnn/attention.h"
#include "rnn/constructors.h"
#include "rnn/types.h"

namespace marian {
namespace rnn {

class AttentionFactory : public InputFactory {
protected:
  Ptr<EncoderState> state_;

public:
//  AttentionFactory(Ptr<ExpressionGraph> graph) : InputFactory(graph) {}

  Ptr<CellInput> construct(Ptr<ExpressionGraph> graph) override {
    ABORT_IF(!state_, "EncoderState not set");
    return New<Attention>(graph, options_, state_);
  }

  Accumulator<AttentionFactory> set_state(Ptr<EncoderState> state) {
    state_ = state;
    return Accumulator<AttentionFactory>(*this);
  }

  int dimAttended() {
    ABORT_IF(!state_, "EncoderState not set");
    return state_->getAttended()->shape()[1];
  }
};

typedef Accumulator<AttentionFactory> attention;
}  // namespace rnn
}  // namespace marian