Program Listing for File communicator.cpp

Return to documentation for file (src/training/communicator.cpp)

#include "training/communicator.h"
#include "common/utils.h"

#if defined(CUDA_FOUND) && defined(USE_NCCL)
#include "training/communicator_nccl.h"
#endif

#if MPI_FOUND
#include "mpi.h"
#endif

namespace marian {

ShardingMode getShardingMode(Ptr<Options> options, Ptr<IMPIWrapper> mpi) {
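  // Read the --sharding option; when running without MPI, or with a single MPI
  // process, the mode always resolves to global sharding.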
  auto shardOpt = options->get<std::string>("sharding", "global");
  if(shardOpt == "global" || !mpi || mpi->numMPIProcesses() == 1)
    return ShardingMode::global;
  else if(shardOpt == "local")
    return ShardingMode::local;
  else
    ABORT("Unknown sharding mode {}", shardOpt);
}

Ptr<ICommunicator> createCommunicator(
  const std::vector<Ptr<ExpressionGraph>>& graphs,
  bool noNccl, ShardingMode shardingMode, Ptr<IMPIWrapper> mpi) {
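  // Select the communicator implementation: NCCL is used when it was compiled in,
  // --no-nccl was not given, and every graph lives on a GPU; otherwise we fall back
  // to the DefaultCommunicator.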
  mpi; // (suppress unused-parameter warnings in some build configurations)
#if defined(CUDA_FOUND) && defined(USE_NCCL)
  if(noNccl) {
    LOG(warn, "[comm] NCCL communicator overridden");
    return New<DefaultCommunicator>(graphs, mpi);
  }

  // if at least one of the devices is not a gpu, fall-back to default
  for(auto& graph : graphs) {
    if(graph->getBackend()->getDeviceId().type == DeviceType::cpu) {
      return New<DefaultCommunicator>(graphs, mpi);
    }
  }

  size_t d = graphs.size();
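  // d is a power of two exactly when (d & (d - 1)) == 0 (for d > 0)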
  if((d & (d - 1)) != 0) {
    LOG(warn,
        "[comm] Number of devices {} is not a power of 2 and communication "
        "might be slow with NCCL",
        d);
    LOG(warn, "[comm] You can switch off NCCL with the --no-nccl option");
  }

  // the actual implementation is inside communicator.cu
  return New<NCCLCommunicator>(graphs, shardingMode, mpi);
#else // no CUDA or no NCCL
  noNccl; shardingMode; // (unused)
  return New<DefaultCommunicator>(graphs, mpi);
#endif
}

std::string IMPIWrapper::idStr() const { // helper to identify the node in logs
  std::string hostname; int pid;
  std::tie(hostname, pid) = utils::hostnameAndProcessId();
  return hostname + ":" + std::to_string(pid)
         + " MPI rank " + std::to_string(myMPIRank())
         + " out of " + std::to_string(numMPIProcesses());
}

#if MPI_FOUND
// wrapper for MPI calls
// Since MPI can only be initialized once, only one instance of this class can exist.
class MPIWrapper : public IMPIWrapper
{
  int my_rank_;         // MPI rank of this node
  int comm_world_size_; // Number of nodes in MPI world (cluster)

  void handleError(int mpiRetval, const char* exprString) const { // call this with the return value of all MPI calls to report errors
    if (mpiRetval != MPI_SUCCESS) {
      char errStr[MPI_MAX_ERROR_STRING + 1] = { 0 };
      int resultLen = 0;
      MPI_Error_string(mpiRetval, &errStr[0], &resultLen);
      errStr[resultLen] = 0; // (@TODO: needed?)
      ABORT("MPI call failed with code {} '{}' on node {}: {}", mpiRetval, errStr, my_rank_, exprString); // @TODO: also log host name, which is involved on Windows
    }
  }
#define HANDLE_MPI_ERROR(expr) (handleError(expr, #expr)) // call through a macro so we can also log the failed expression itself

public:
  MPIWrapper(bool multiThreaded) {
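    // Initialize MPI for this process: request a threading level, report errors via
    // return codes, and tag log lines with the MPI rank so output from different
    // processes can be told apart.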
    int requiredThreadingMode = multiThreaded ? MPI_THREAD_MULTIPLE : MPI_THREAD_FUNNELED; // FUNNELED means only the main thread will make MPI calls

    int argc = 1; char* argv[] = { const_cast<char*>("this.exe") }; char** argvp = argv; // dummy argc/argv since MPI_Init needs something here
    int providedThreadingMode;
    HANDLE_MPI_ERROR(MPI_Init_thread(&argc, &argvp, MPI_THREAD_MULTIPLE, &providedThreadingMode)); // request the highest threading level here; the actual requirement is checked below
    MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN); // have errors reported as return codes

    MPI_Comm_size(MPI_COMM_WORLD, &comm_world_size_);
    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank_);

    ABORT_IF(comm_world_size_ > 1 && providedThreadingMode < requiredThreadingMode,
      "Your version of MPI does not support multi-threaded communication.");

    // patch logging pattern to include the MPI rank, so that we can associate error messages with nodes
    if (numMPIProcesses() > 1) {
      std::string rankStr = std::to_string(MPIWrapper::myMPIRank());
      std::string maxRankStr = std::to_string(MPIWrapper::numMPIProcesses() - 1);
      while (rankStr.size() < maxRankStr.size()) // pad so that logs across MPI processes line up nicely
        rankStr.insert(rankStr.begin(), ' ');
      switchToMultinodeLogging(rankStr);
    }

    // log hostnames in order, and test
    for (size_t r = 0; r < numMPIProcesses(); r++) {
      MPIWrapper::barrier();
      if (r == MPIWrapper::myMPIRank() && MPIWrapper::numMPIProcesses() > 1) {
        std::string hostname; int pid;
        std::tie(hostname, pid) = utils::hostnameAndProcessId();
        LOG(info, "[mpi] Initialized as rank {} out of {} processes on {} as process {}",
                  MPIWrapper::myMPIRank(), MPIWrapper::numMPIProcesses(), hostname, pid);
      }
      MPIWrapper::barrier();
    }
  }

  virtual size_t myMPIRank()       const override { return (size_t)my_rank_; }
  virtual size_t numMPIProcesses() const override { return (size_t)comm_world_size_; }

  virtual void barrier(MPI_Comm comm = MPI_COMM_WORLD) const override {
    HANDLE_MPI_ERROR(MPI_Barrier(comm));
  }

  virtual void bCast(void* buf, size_t count, MPI_Datatype datatype, size_t rootRank, MPI_Comm comm = MPI_COMM_WORLD) const override {
    // MPI_Bcast only supports int-sized counts; here and in the functions below we cycle
    // through the data in chunks until all elements have been sent if count exceeds INT_MAX.

    // get the data type size in bytes
    int datatypeSize;
    HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));

    // get the limit for int count
    size_t limit = (size_t)std::numeric_limits<int>::max();
    size_t remaining = count, offset = 0;

    // while there are elements that we have not sent yet, loop until all has been sent in chunks of at most `limit`.
    while(remaining > 0) {
      int intCount = (int)std::min(remaining, limit);
      HANDLE_MPI_ERROR(MPI_Bcast((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)rootRank, comm));
      offset    += (size_t)intCount;
      remaining -= (size_t)intCount;
    }
  }

  virtual void sSend(void* buf, size_t count, MPI_Datatype datatype, size_t destRank, int tag, MPI_Comm comm) const override {
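    // synchronous send (MPI_Ssend), chunked like bCast() so counts larger than INT_MAX can be sent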
    int datatypeSize;
    HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));

    size_t limit = (size_t)std::numeric_limits<int>::max();
    size_t remaining = count, offset = 0;
    while(remaining > 0) {
      int intCount = (int)std::min(remaining, limit);
      HANDLE_MPI_ERROR(MPI_Ssend((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)destRank, tag, comm));
      offset    += (size_t)intCount;
      remaining -= (size_t)intCount;
    }
  }

  virtual void recv(void* buf, size_t count, MPI_Datatype datatype, size_t sourceRank, int tag, MPI_Comm comm, MPI_Status* status) const override {
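    // blocking receive, chunked like bCast() so counts larger than INT_MAX can be received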
    int datatypeSize;
    HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));

    size_t limit = (size_t)std::numeric_limits<int>::max();
    size_t remaining = count, offset = 0;
    while(remaining > 0) {
      int intCount = (int)std::min(remaining, limit);
      HANDLE_MPI_ERROR(MPI_Recv((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)sourceRank, tag, comm, status));
      offset    += (size_t)intCount;
      remaining -= (size_t)intCount;
    }
  }

  virtual void allReduce(const void* sendbuf, void* recvbuf, size_t count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) const override {
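    // element-wise reduction across all MPI processes, chunked like bCast() so counts larger than INT_MAX work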
    if (sendbuf == recvbuf)
      sendbuf = MPI_IN_PLACE; // MSMPI requires this

    int datatypeSize;
    HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));

    size_t limit = (size_t)std::numeric_limits<int>::max();
    size_t remaining = count, offset = 0;
    while(remaining > 0) {
      int intCount = (int)std::min(remaining, limit);
      // do not offset the special MPI_IN_PLACE marker; only a real send buffer is advanced per chunk
      void* sendChunk = (sendbuf == MPI_IN_PLACE)
          ? MPI_IN_PLACE
          : (void*)((char*)sendbuf + offset * (size_t)datatypeSize);
      HANDLE_MPI_ERROR(MPI_Allreduce(sendChunk, (char*)recvbuf + offset * (size_t)datatypeSize, intCount, datatype, op, comm));
      offset    += (size_t)intCount;
      remaining -= (size_t)intCount;
    }
  }

  virtual void finalize() override {
    HANDLE_MPI_ERROR(MPI_Finalize());
  }
};
#endif

// dummy MPI wrapper that implements only one process without actual operations
// This is used when not compiling under MPI.
class FakeMPIWrapper : public IMPIWrapper
{
public:
  FakeMPIWrapper(bool) {
    LOG(info, "[comm] Compiled without MPI support. Running as a single process on {}", utils::hostnameAndProcessId().first);
  }
  virtual ~FakeMPIWrapper() {}
  virtual size_t myMPIRank()       const override { return 0; }
  virtual size_t numMPIProcesses() const override { return 1; }

#pragma warning(push)
#pragma warning(disable: 4100) // unreferenced formal parameter
  // most functions are no-ops when applied to a single process
  virtual void barrier(MPI_Comm comm) const override {
    comm;
  }
  virtual void bCast(void* buf, size_t count, MPI_Datatype datatype, size_t rootRank, MPI_Comm comm) const override {
    buf; count; datatype; rootRank; comm;
  }
  virtual void sSend(void* buf, size_t count, MPI_Datatype datatype, size_t destRank, int tag, MPI_Comm comm) const override {
    buf; count; datatype; destRank; tag; comm;
  }
  virtual void recv(void* buf, size_t count, MPI_Datatype datatype, size_t sourceRank, int tag, MPI_Comm comm, MPI_Status* status) const override {
    buf; count; datatype; sourceRank; tag; comm;
    // @TODO: fill in status
    ABORT_IF(status != MPI_STATUS_IGNORE, "FakeMPIWrapper::recv() does not yet implement returning a status object");
  }
  virtual void allReduce(const void* sendbuf, void* recvbuf, size_t count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) const override {
    count; datatype; op; comm;
    // @TODO: There is only one place where this is called with sendbuf != recvbuf, which is sync multi-node.
    //        I think that can be changed to use the same buffer. Then this API could be reduced to a single
    //        buffer parameter and this error check could be removed.
    ABORT_IF(sendbuf != recvbuf, "FakeMPIWrapper::allReduce() only implemented for in-place operation"); // otherwise it's not a no-op, we must copy data
  }
#pragma warning(pop)
  virtual void finalize() override { }
};

// create instance of the singleton MPI wrapper
static Ptr<IMPIWrapper> s_mpi;    // singleton instance of MPI wrapper
static size_t s_mpiUseCount;      // how many times has this wrapper been instantiated?
static bool s_mpiIsMultiThreaded; // multi-threading mode of this instance

Ptr<IMPIWrapper> initMPI(bool multiThreaded) {
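  // MPI can only be initialized once per process, so the wrapper is a reference-counted
  // singleton: every initMPI() call must be paired with a finalizeMPI() call, and all
  // callers must request the same multi-threading mode.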
  if (!s_mpi) {
#if MPI_FOUND
    s_mpi = New<MPIWrapper>(multiThreaded);
#else
    s_mpi = New<FakeMPIWrapper>(multiThreaded);
#endif
    s_mpiIsMultiThreaded = multiThreaded;
  }
  else {
    ABORT_IF(s_mpiIsMultiThreaded != multiThreaded, "attempted to reinitialize MPI with different multi-threading mode");
  }
  s_mpiUseCount++;
  return s_mpi;
}

void finalizeMPI(Ptr<IMPIWrapper>&& mpi) {
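  // Release one reference to the singleton; only the last outstanding reference
  // actually calls MPI_Finalize().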
  ABORT_IF(mpi == nullptr || mpi != s_mpi, "attempted to finalize an inconsistent MPI instance. This should not be possible.");
  mpi = nullptr; // destruct caller's handle
  ABORT_IF(s_mpiUseCount == 0, "finalize called too many times. This should not be possible.");
  if (s_mpiUseCount == 1) { // last call finalizes MPI, i.e. tells MPI that we successfully completed computation
    ABORT_IF(s_mpi.use_count() != 1, "dangling reference to MPI??"); // caller kept another shared_ptr to this instance
    s_mpi->finalize(); // signal successful completion to MPI
    s_mpi = nullptr;   // release the singleton instance upon last finalization
  }
  s_mpiUseCount--;
}

}  // namespace marian