diff --git a/.clang-format b/.clang-format index e941415f..71fc4868 100644 --- a/.clang-format +++ b/.clang-format @@ -27,7 +27,6 @@ BreakBeforeBraces: Allman # ConstructorInitializerAllOnOneLineOrOnePerLine: false BreakConstructorInitializers: BeforeComma ConstructorInitializerIndentWidth: 0 -BreakInheritanceList: BeforeComma #AllowShortBlocksOnASingleLine: Always AllowShortBlocksOnASingleLine: true AllowShortCaseLabelsOnASingleLine: false diff --git a/.cscs-ci/container/build.Containerfile b/.cscs-ci/container/build.Containerfile new file mode 100644 index 00000000..fe3e707f --- /dev/null +++ b/.cscs-ci/container/build.Containerfile @@ -0,0 +1,20 @@ +ARG DEPS_IMAGE +FROM $DEPS_IMAGE + +COPY . /oomph +WORKDIR /oomph + +ARG BACKEND +ARG NUM_PROCS +RUN spack -e ci build-env oomph -- \ + cmake -G Ninja -B build \ + -DCMAKE_BUILD_TYPE=Debug \ + -DOOMPH_WITH_TESTING=ON \ + -DOOMPH_WITH_$(echo $BACKEND | tr '[:lower:]' '[:upper:]')=ON \ + -DOOMPH_USE_BUNDLED_LIBS=ON \ + -DOOMPH_USE_BUNDLED_HWMALLOC=OFF \ + -DMPIEXEC_EXECUTABLE="" \ + -DMPIEXEC_NUMPROC_FLAG="" \ + -DMPIEXEC_PREFLAGS="" \ + -DMPIEXEC_POSTFLAGS="" && \ + spack -e ci build-env oomph -- cmake --build build -j$NUM_PROCS diff --git a/.cscs-ci/container/deps.Containerfile b/.cscs-ci/container/deps.Containerfile new file mode 100644 index 00000000..f5867ac5 --- /dev/null +++ b/.cscs-ci/container/deps.Containerfile @@ -0,0 +1,24 @@ +ARG BASE_IMAGE +FROM $BASE_IMAGE + +ARG SPACK_SHA +RUN mkdir -p /opt/spack && \ + curl -fLsS "https://api.github.com/repos/spack/spack/tarball/$SPACK_SHA" | tar --strip-components=1 -xz -C /opt/spack + +ENV PATH="/opt/spack/bin:$PATH" + +ARG SPACK_PACKAGES_SHA +RUN mkdir -p /opt/spack-packages && \ + curl -fLsS "https://api.github.com/repos/spack/spack-packages/tarball/$SPACK_PACKAGES_SHA" | tar --strip-components=1 -xz -C /opt/spack-packages + +RUN spack repo remove --scope defaults:base builtin && \ + spack repo add --scope site /opt/spack-packages/repos/spack_repo/builtin + +ARG 
SPACK_ENV_FILE +COPY $SPACK_ENV_FILE /spack_environment/spack.yaml + +ARG NUM_PROCS +RUN spack external find --all && \ + spack env create ci /spack_environment/spack.yaml && \ + spack -e ci concretize -f && \ + spack -e ci install --jobs $NUM_PROCS --fail-fast --only=dependencies diff --git a/.cscs-ci/default.yaml b/.cscs-ci/default.yaml new file mode 100644 index 00000000..e8208abd --- /dev/null +++ b/.cscs-ci/default.yaml @@ -0,0 +1,192 @@ +include: + - remote: 'https://gitlab.com/cscs-ci/recipes/-/raw/master/templates/v2/.ci-ext.yml' + +variables: + BASE_IMAGE: jfrog.svc.cscs.ch/docker-group-csstaff/alps-images/ngc-pytorch:26.01-py3-alps4-dev + SPACK_SHA: v1.1.1 + SPACK_PACKAGES_SHA: 5f24787b5cd3c2356d9a8188b989ceb5307299c6 # https://github.com/msimberg/spack-packages/tree/oomph-nccl + FF_TIMESTAMPS: true + +.build_deps_template: + timeout: 1 hour + before_script: + - echo $DOCKERHUB_TOKEN | podman login docker.io -u $DOCKERHUB_USERNAME --password-stdin || true + - export DOCKERFILE_SHA=`sha256sum .cscs-ci/container/deps.Containerfile | head -c 16` + - export ENV_FILE_SHA=`sha256sum ${SPACK_ENV_FILE} | head -c 16` + - export CONFIG_TAG=`echo $DOCKERFILE_SHA-$BASE_IMAGE-$SPACK_SHA-$SPACK_PACKAGES_SHA-$ENV_FILE_SHA | sha256sum - | head -c 16` + - export PERSIST_IMAGE_NAME=$CSCS_REGISTRY_PATH/public/oomph-spack-deps-$BACKEND:$CONFIG_TAG + - echo -e "CONFIG_TAG=$CONFIG_TAG" >> base-${BACKEND}.env + - echo -e "DEPS_IMAGE=$PERSIST_IMAGE_NAME" >> base-${BACKEND}.env + variables: + DOCKERFILE: .cscs-ci/container/deps.Containerfile + DOCKER_BUILD_ARGS: '["BASE_IMAGE", "SPACK_SHA", "SPACK_PACKAGES_SHA", "SPACK_ENV_FILE"]' + SPACK_ENV_FILE: .cscs-ci/spack/$BACKEND.yaml + artifacts: + reports: + dotenv: base-${BACKEND}.env + +build_deps_nccl: + variables: + BACKEND: nccl + extends: + - .container-builder-cscs-gh200 + - .build_deps_template + +build_deps_mpi: + variables: + BACKEND: mpi + extends: + - .container-builder-cscs-gh200 + - .build_deps_template + +build_deps_ucx: 
+ variables: + BACKEND: ucx + extends: + - .container-builder-cscs-gh200 + - .build_deps_template + +# TODO: Libfabric tests are currently failing on Alps and need to be fixed. +# build_deps_libfabric: +# variables: +# BACKEND: libfabric +# extends: +# - .container-builder-cscs-gh200 +# - .build_deps_template + +.build_template: + extends: .container-builder-cscs-gh200 + timeout: 15 minutes + before_script: + - echo $DOCKERHUB_TOKEN | podman login docker.io -u $DOCKERHUB_USERNAME --password-stdin || true + - export PERSIST_IMAGE_NAME=$CSCS_REGISTRY_PATH/public/oomph-build-$BACKEND:$CI_COMMIT_SHA + - echo -e "BUILD_IMAGE=$PERSIST_IMAGE_NAME" >> build-${BACKEND}.env + variables: + DOCKERFILE: .cscs-ci/container/build.Containerfile + DOCKER_BUILD_ARGS: '["DEPS_IMAGE", "BACKEND"]' + artifacts: + reports: + dotenv: build-${BACKEND}.env + +build_nccl: + variables: + BACKEND: nccl + extends: .build_template + needs: + - job: build_deps_nccl + artifacts: true + +build_mpi: + variables: + BACKEND: mpi + extends: .build_template + needs: + - job: build_deps_mpi + artifacts: true + +build_ucx: + variables: + BACKEND: ucx + extends: .build_template + needs: + - job: build_deps_ucx + artifacts: true + +# TODO: Libfabric tests are currently failing on Alps and need to be fixed. 
+# build_libfabric: +# variables: +# BACKEND: libfabric +# extends: .build_template +# needs: +# - job: build_deps_libfabric +# artifacts: true + +.test_template_base: + extends: .container-runner-clariden-gh200 + variables: + SLURM_JOB_NUM_NODES: 1 + SLURM_GPUS_PER_TASK: 1 + SLURM_TIMELIMIT: '5:00' + SLURM_PARTITION: normal + SLURM_MPI_TYPE: pmix + SLURM_NETWORK: disable_rdzv_get + SLURM_LABELIO: 1 + SLURM_UNBUFFEREDIO: 1 + PMIX_MCA_psec: native + PMIX_MCA_gds: "^shmem2" + USE_MPI: NO + +.test_serial_template: + extends: .test_template_base + variables: + SLURM_NTASKS: 1 + script: + - ctest --test-dir /oomph/build -L "serial" --output-on-failure --timeout 60 --parallel 8 + +.test_parallel_template: + extends: .test_template_base + variables: + SLURM_NTASKS: 4 + script: + # All ranks write to ctest files in Testing, but this can deadlock when + # writing inside the container. + - if [[ "${SLURM_PROCID}" == 0 ]]; then rm -rf /oomph/build/Testing; mkdir /tmp/Testing; ln -s /tmp/Testing /oomph/build/Testing; fi + - sleep 1 + - ctest --test-dir /oomph/build -L "parallel-ranks-4" --output-on-failure --timeout 60 + +test_serial_nccl: + extends: .test_serial_template + needs: + - job: build_nccl + artifacts: true + image: $BUILD_IMAGE + +test_parallel_nccl: + extends: .test_parallel_template + needs: + - job: build_nccl + artifacts: true + image: $BUILD_IMAGE + +test_serial_mpi: + extends: .test_serial_template + needs: + - job: build_mpi + artifacts: true + image: $BUILD_IMAGE + +test_parallel_mpi: + extends: .test_parallel_template + needs: + - job: build_mpi + artifacts: true + image: $BUILD_IMAGE + +test_serial_ucx: + extends: .test_serial_template + needs: + - job: build_ucx + artifacts: true + image: $BUILD_IMAGE + +test_parallel_ucx: + extends: .test_parallel_template + needs: + - job: build_ucx + artifacts: true + image: $BUILD_IMAGE + +# TODO: Libfabric tests are currently failing on Alps and need to be fixed. 
+# test_serial_libfabric: +# extends: .test_serial_template +# needs: +# - job: build_libfabric +# artifacts: true +# image: $BUILD_IMAGE + +# test_parallel_libfabric: +# extends: .test_parallel_template +# needs: +# - job: build_libfabric +# artifacts: true +# image: $BUILD_IMAGE diff --git a/.cscs-ci/spack/libfabric.yaml b/.cscs-ci/spack/libfabric.yaml new file mode 100644 index 00000000..fac7f88f --- /dev/null +++ b/.cscs-ci/spack/libfabric.yaml @@ -0,0 +1,6 @@ +spack: + specs: + - oomph@main backend=libfabric +cuda + view: false + concretizer: + unify: true diff --git a/.cscs-ci/spack/mpi.yaml b/.cscs-ci/spack/mpi.yaml new file mode 100644 index 00000000..d59aab13 --- /dev/null +++ b/.cscs-ci/spack/mpi.yaml @@ -0,0 +1,6 @@ +spack: + specs: + - oomph@main backend=mpi +cuda + view: false + concretizer: + unify: true diff --git a/.cscs-ci/spack/nccl.yaml b/.cscs-ci/spack/nccl.yaml new file mode 100644 index 00000000..94f0dd31 --- /dev/null +++ b/.cscs-ci/spack/nccl.yaml @@ -0,0 +1,6 @@ +spack: + specs: + - oomph@main backend=nccl +cuda + view: false + concretizer: + unify: true diff --git a/.cscs-ci/spack/ucx.yaml b/.cscs-ci/spack/ucx.yaml new file mode 100644 index 00000000..51377dd8 --- /dev/null +++ b/.cscs-ci/spack/ucx.yaml @@ -0,0 +1,6 @@ +spack: + specs: + - oomph@main backend=ucx +cuda + view: false + concretizer: + unify: true diff --git a/CMakeLists.txt b/CMakeLists.txt index 90a582d1..6a2b8926 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -80,6 +80,11 @@ include(oomph_ucx) # --------------------------------------------------------------------- include(oomph_libfabric) +# --------------------------------------------------------------------- +# oomph NCCL variant +# --------------------------------------------------------------------- +include(oomph_nccl) + # --------------------------------------------------------------------- # main src subdir # --------------------------------------------------------------------- @@ -142,6 +147,7 @@ install( 
    ${CMAKE_CURRENT_BINARY_DIR}/oomphConfig.cmake
    ${CMAKE_CURRENT_BINARY_DIR}/oomphConfigVersion.cmake
    ${CMAKE_CURRENT_LIST_DIR}/cmake/FindLibfabric.cmake
+    ${CMAKE_CURRENT_LIST_DIR}/cmake/FindNCCL.cmake
    ${CMAKE_CURRENT_LIST_DIR}/cmake/FindUCX.cmake
    ${CMAKE_CURRENT_LIST_DIR}/cmake/FindPMIx.cmake
  DESTINATION
diff --git a/README.md b/README.md
index 69ec02b7..43a47018 100644
--- a/README.md
+++ b/README.md
@@ -4,9 +4,9 @@
 # OOMPH
 
 **Oomph** is a library for enabling high performance point-to-point, asynchronous communication over
-different fabrics. It leverages the ubiquitos MPI library as well as UCX and Libfabric. Both
-device and host memory are supported. Under the hood it uses
-[hwmalloc](https://github.com/ghex-org/hwmalloc) for memory registration.
+different fabrics. It leverages the ubiquitous MPI library as well as UCX and Libfabric. Both
+device and host memory are supported. A subset of functionality is also supported with NCCL. Under
+the hood it uses [hwmalloc](https://github.com/ghex-org/hwmalloc) for memory registration.
 
 **selling points**
 - lightweight, fast
@@ -136,6 +136,65 @@ comm.progress();
 // or progress until some event is triggered
 while(!completed) { comm.progress(); }
 ```
+
+### Groups
+
+Communicators expose group functionality as provided by NCCL (with
+ncclGroupStart and ncclGroupEnd). For non-NCCL backends the group functionality
+is a no-op. For NCCL, using the group functionality can be both a requirement
+to avoid deadlocks (communication within a group can make progress
+independently, while outside of a group communication is ordered) and for
+performance (a single device kernel is submitted for an NCCL group).
+
+Groups are created by explicitly starting and ending the group:
+
+```cpp
+comm.start_group();
+oomph::send_request sreq = comm.send(smsg, 1, 0);
+oomph::recv_request rreq = comm.recv(rmsg, 1, 0);
+comm.end_group();
+
+// With NCCL, no progress will be made until after the group ends
+sreq.wait();
+rreq.wait();
+```
+
+### Stream awareness
+
+Some backend implementations can schedule communication on a GPU stream.
+Currently only the NCCL backend makes use of this. All other backends ignore
+the stream argument. To query if a backend is stream-aware use the
+`is_stream_aware` member query on a communicator. The stream can be passed as
+an optional last parameter to `send` or `recv`:
+
+```cpp
+if (comm.is_stream_aware()) {
+    // Schedule communication on the default CUDA stream if the backend is
+    // stream aware
+    cudaStream_t stream = 0;
+    oomph::send_request req = comm.send(msg, 1, 0, stream);
+}
+```
+
+### NCCL restrictions
+
+NCCL has significantly different semantics from MPI, libfabric, and UCX which
+is reflected in a number of restrictions on how the NCCL communicator can be
+used:
+
+- Tags are not supported by NCCL and ignored by the backend. Communication
+  order on different ranks must match (except within NCCL groups where there is
+  some flexibility). This also means that e.g. recv should not be called before
+  send unless within an NCCL group.
+- The `thread_safe` option for the NCCL communicator is not supported because
+  of the above ordering restrictions.
+- Cancellation is not supported.
+- `wait` and `progress` are disallowed when an NCCL group is active as no
+  progress can be made until an NCCL group is ended and submitted.
+
+The NCCL backend is primarily designed for use in GHEX where these differences
+can be hidden from the user.
+
 ## Acknowledgments
 
 This work was financially supported by the PRACE project funded in part by the EU's Horizon 2020 Research and Innovation programme (2014-2020) under grant agreement 823767.
diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake new file mode 100644 index 00000000..15c56896 --- /dev/null +++ b/cmake/FindNCCL.cmake @@ -0,0 +1,78 @@ +# From https://github.com/pytorch/gloo/blob/main/cmake/Modules/Findnccl.cmake. + +# Try to find NCCL +# +# The following variables are optionally searched for defaults +# NCCL_ROOT_DIR: Base directory where all NCCL components are found +# NCCL_INCLUDE_DIR: Directory where NCCL header is found +# NCCL_LIB_DIR: Directory where NCCL library is found +# +# The following are set after configuration is done: +# NCCL_FOUND +# NCCL_INCLUDE_DIRS +# NCCL_LIBRARIES +# +# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks +# install NCCL in the same location as the CUDA toolkit. +# See https://github.com/caffe2/caffe2/issues/1601 + +set(NCCL_ROOT_DIR $ENV{NCCL_ROOT_DIR} CACHE PATH "Folder contains NVIDIA NCCL") + +find_path(NCCL_INCLUDE_DIR + NAMES nccl.h + HINTS + ${NCCL_INCLUDE_DIR} + ${NCCL_ROOT_DIR} + ${NCCL_ROOT_DIR}/include + ${CUDA_TOOLKIT_ROOT_DIR}/include) + +if(DEFINED ENV{USE_STATIC_NCCL} AND NOT "$ENV{USE_STATIC_NCCL}" STREQUAL "") + message(STATUS "USE_STATIC_NCCL detected. 
Linking against static NCCL library") + set(_use_static_nccl ON) + set(NCCL_LIBNAME "libnccl_static.a") +else() + set(_use_static_nccl OFF) + set(NCCL_LIBNAME "nccl") +endif() + +find_library(NCCL_LIBRARY + NAMES ${NCCL_LIBNAME} + HINTS + ${NCCL_LIB_DIR} + ${NCCL_ROOT_DIR} + ${NCCL_ROOT_DIR}/lib + ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu + ${NCCL_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib64) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY) + +if (NCCL_FOUND) + set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIR}/nccl.h") + message(STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}") + file (STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED + REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1) + if (NCCL_MAJOR_VERSION_DEFINED) + string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" "" + NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED}) + message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}") + endif() + set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR}) + set(NCCL_LIBRARIES ${NCCL_LIBRARY}) + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) + + if(NOT TARGET NCCL::nccl AND NCCL_FOUND) + if(_use_static_nccl) + add_library(NCCL::nccl STATIC IMPORTED) + else() + add_library(NCCL::nccl SHARED IMPORTED) + endif() + set_target_properties(NCCL::nccl PROPERTIES + IMPORTED_LOCATION ${NCCL_LIBRARIES} + INTERFACE_INCLUDE_DIRECTORIES ${NCCL_INCLUDE_DIRS} + ) + endif() +endif() + diff --git a/cmake/oomphConfig.cmake.in b/cmake/oomphConfig.cmake.in index 6044d714..cd391e30 100644 --- a/cmake/oomphConfig.cmake.in +++ b/cmake/oomphConfig.cmake.in @@ -19,5 +19,8 @@ if (@OOMPH_WITH_LIBFABRIC@) #set(LIBFABRIC_INCLUDE_DIR @ULIBFABRIC_INCLUDE_DIRS@) find_dependency(Libfabric) endif() +if (@OOMPH_WITH_NCCL@) + find_dependency(NCCL) +endif() 
include(${CMAKE_CURRENT_LIST_DIR}/oomph-targets.cmake) diff --git a/cmake/oomph_nccl.cmake b/cmake/oomph_nccl.cmake new file mode 100644 index 00000000..909280f2 --- /dev/null +++ b/cmake/oomph_nccl.cmake @@ -0,0 +1,19 @@ +# set all NCCL related options and values + +#------------------------------------------------------------------------------ +# Enable NCCL support +#------------------------------------------------------------------------------ +set(OOMPH_WITH_NCCL OFF CACHE BOOL "Build with NCCL backend") + +if (OOMPH_WITH_NCCL) + find_package(CUDAToolkit REQUIRED) + find_package(NCCL REQUIRED) + add_library(oomph_nccl SHARED) + add_library(oomph::nccl ALIAS oomph_nccl) + oomph_shared_lib_options(oomph_nccl) + target_link_libraries(oomph_nccl PUBLIC NCCL::nccl CUDA::cudart) + install(TARGETS oomph_nccl + EXPORT oomph-targets + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) +endif() diff --git a/include/oomph/communicator.hpp b/include/oomph/communicator.hpp index 71d9908c..0d73448c 100644 --- a/include/oomph/communicator.hpp +++ b/include/oomph/communicator.hpp @@ -98,6 +98,8 @@ class communicator return m_state->m_shared_scheduled_recvs->load(); } + bool is_stream_aware() const noexcept; + bool is_ready() const noexcept { return (scheduled_sends() == 0) && (scheduled_recvs() == 0) && @@ -143,6 +145,9 @@ class communicator } #endif + void start_group(); + void end_group(); + // no callback versions // ==================== @@ -150,33 +155,33 @@ class communicator // ---- template - recv_request recv(message_buffer& msg, rank_type src, tag_type tag) + recv_request recv(message_buffer& msg, rank_type src, tag_type tag, void* stream = nullptr) { assert(msg); return recv(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), src, tag, - util::unique_function([](rank_type, tag_type) {})); + util::unique_function([](rank_type, tag_type) {}), stream); } // shared_recv // ----------- template - shared_recv_request 
shared_recv(message_buffer& msg, rank_type src, tag_type tag) + shared_recv_request shared_recv(message_buffer& msg, rank_type src, tag_type tag, void* stream = nullptr) { assert(msg); return shared_recv(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), src, tag, - util::unique_function([](rank_type, tag_type) {})); + util::unique_function([](rank_type, tag_type) {}), stream); } // send // ---- template - send_request send(message_buffer const& msg, rank_type dst, tag_type tag) + send_request send(message_buffer const& msg, rank_type dst, tag_type tag, void* stream = nullptr) { assert(msg); return send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), dst, tag, - util::unique_function([](rank_type, tag_type) {})); + util::unique_function([](rank_type, tag_type) {}), stream); } // send_multi @@ -184,7 +189,7 @@ class communicator template send_multi_request send_multi(message_buffer const& msg, rank_type const* neighs, - std::size_t neighs_size, tag_type tag) + std::size_t neighs_size, tag_type tag, void* stream = nullptr) { assert(msg); auto mrs = m_state->make_multi_request_state(neighs_size); @@ -192,21 +197,21 @@ class communicator { send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), neighs[i], tag, util::unique_function( - [mrs](rank_type, tag_type) { --(mrs->m_counter); })); + [mrs](rank_type, tag_type) { --(mrs->m_counter); }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer const& msg, - std::vector const& neighs, tag_type tag) + std::vector const& neighs, tag_type tag, void* stream = nullptr) { - return send_multi(msg, neighs.data(), neighs.size(), tag); + return send_multi(msg, neighs.data(), neighs.size(), tag, stream); } template send_multi_request send_multi(message_buffer const& msg, rank_type const* neighs, - tag_type const* tags, std::size_t neighs_size) + tag_type const* tags, std::size_t neighs_size, void* stream = nullptr) { assert(msg); auto mrs = m_state->make_multi_request_state(neighs_size); @@ -214,17 
+219,17 @@ class communicator { send(msg.m.m_heap_ptr.get(), msg.size() * sizeof(T), neighs[i], tags[i], util::unique_function( - [mrs](rank_type, tag_type) { --(mrs->m_counter); })); + [mrs](rank_type, tag_type) { --(mrs->m_counter); }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer const& msg, - std::vector const& neighs, std::vector const& tags) + std::vector const& neighs, std::vector const& tags, void* stream = nullptr) { assert(neighs.size() == tags.size()); - return send_multi(msg, neighs.data(), tags.data(), neighs.size()); + return send_multi(msg, neighs.data(), tags.data(), neighs.size(), stream); } // callback versions @@ -234,7 +239,7 @@ class communicator // ---- template - recv_request recv(message_buffer&& msg, rank_type src, tag_type tag, CallBack&& callback) + recv_request recv(message_buffer&& msg, rank_type src, tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK(CallBack) assert(msg); @@ -242,11 +247,11 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - cb_rref{std::forward(callback), std::move(msg)})); + cb_rref{std::forward(callback), std::move(msg)}), stream); } template - recv_request recv(message_buffer& msg, rank_type src, tag_type tag, CallBack&& callback) + recv_request recv(message_buffer& msg, rank_type src, tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_REF(CallBack) assert(msg); @@ -254,7 +259,7 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - cb_lref{std::forward(callback), &msg})); + cb_lref{std::forward(callback), &msg}), stream); } // shared_recv @@ -262,7 +267,7 @@ class communicator template shared_recv_request shared_recv(message_buffer&& msg, rank_type src, tag_type tag, - CallBack&& callback) + CallBack&& callback, void* stream = nullptr) { 
OOMPH_CHECK_CALLBACK(CallBack) assert(msg); @@ -270,12 +275,12 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return shared_recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - cb_rref{std::forward(callback), std::move(msg)})); + cb_rref{std::forward(callback), std::move(msg)}), stream); } template shared_recv_request shared_recv(message_buffer& msg, rank_type src, tag_type tag, - CallBack&& callback) + CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_REF(CallBack) assert(msg); @@ -283,14 +288,14 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return shared_recv(m_ptr, s * sizeof(T), src, tag, util::unique_function( - cb_lref{std::forward(callback), &msg})); + cb_lref{std::forward(callback), &msg}), stream); } // send // ---- template - send_request send(message_buffer&& msg, rank_type dst, tag_type tag, CallBack&& callback) + send_request send(message_buffer&& msg, rank_type dst, tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK(CallBack) assert(msg); @@ -298,11 +303,11 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return send(m_ptr, s * sizeof(T), dst, tag, util::unique_function( - cb_rref{std::forward(callback), std::move(msg)})); + cb_rref{std::forward(callback), std::move(msg)}), stream); } template - send_request send(message_buffer& msg, rank_type dst, tag_type tag, CallBack&& callback) + send_request send(message_buffer& msg, rank_type dst, tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_REF(CallBack) assert(msg); @@ -310,12 +315,12 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return send(m_ptr, s * sizeof(T), dst, tag, util::unique_function( - cb_lref{std::forward(callback), &msg})); + cb_lref{std::forward(callback), &msg}), stream); } template send_request send(message_buffer const& msg, rank_type dst, tag_type tag, - CallBack&& callback) + CallBack&& callback, void* stream = nullptr) { 
OOMPH_CHECK_CALLBACK_CONST_REF(CallBack) assert(msg); @@ -323,7 +328,7 @@ class communicator auto m_ptr = msg.m.m_heap_ptr.get(); return send(m_ptr, s * sizeof(T), dst, tag, util::unique_function( - cb_lref_const{std::forward(callback), &msg})); + cb_lref_const{std::forward(callback), &msg}), stream); } // send_multi @@ -331,7 +336,7 @@ class communicator template send_multi_request send_multi(message_buffer&& msg, std::vector neighs, - tag_type tag, CallBack&& callback) + tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI(CallBack) assert(msg); @@ -349,14 +354,14 @@ class communicator callback(message_buffer(std::move(mrs->m_msg), mrs->m_msg_size), std::move(mrs->m_neighs), t); } - })); + }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer&& msg, std::vector neighs, - std::vector tags, CallBack&& callback) + std::vector tags, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI_TAGS(CallBack) assert(msg); @@ -377,14 +382,14 @@ class communicator callback(message_buffer(std::move(mrs->m_msg), mrs->m_msg_size), std::move(mrs->m_neighs), mrs->m_tags); } - })); + }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer& msg, std::vector neighs, - tag_type tag, CallBack&& callback) + tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI_REF(CallBack) assert(msg); @@ -402,14 +407,14 @@ class communicator callback(*reinterpret_cast*>(mrs->m_msg_ptr), std::move(mrs->m_neighs), t); } - })); + }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer& msg, std::vector neighs, - std::vector tags, CallBack&& callback) + std::vector tags, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI_REF_TAGS(CallBack) assert(msg); @@ -429,14 +434,14 @@ class communicator callback(*reinterpret_cast*>(mrs->m_msg_ptr), std::move(mrs->m_neighs), 
std::move(mrs->m_tags)); } - })); + }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer const& msg, std::vector neighs, - tag_type tag, CallBack&& callback) + tag_type tag, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI_CONST_REF(CallBack) assert(msg); @@ -454,14 +459,14 @@ class communicator callback(*reinterpret_cast const*>(mrs->m_msg_ptr), std::move(mrs->m_neighs), t); } - })); + }), stream); } return {std::move(mrs)}; } template send_multi_request send_multi(message_buffer const& msg, std::vector neighs, - std::vector tags, CallBack&& callback) + std::vector tags, CallBack&& callback, void* stream = nullptr) { OOMPH_CHECK_CALLBACK_MULTI_CONST_REF_TAGS(CallBack) assert(msg); @@ -481,7 +486,7 @@ class communicator callback(*reinterpret_cast const*>(mrs->m_msg_ptr), std::move(mrs->m_neighs), std::move(mrs->m_tags)); } - })); + }), stream); } return {std::move(mrs)}; } @@ -499,13 +504,13 @@ class communicator #endif send_request send(detail::message_buffer::heap_ptr_impl const* m_ptr, std::size_t size, - rank_type dst, tag_type tag, util::unique_function&& cb); + rank_type dst, tag_type tag, util::unique_function&& cb, void* stream); recv_request recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb); + tag_type tag, util::unique_function&& cb, void* stream); shared_recv_request shared_recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, - rank_type src, tag_type tag, util::unique_function&& cb); + rank_type src, tag_type tag, util::unique_function&& cb, void* stream); }; } // namespace oomph diff --git a/include/oomph/context.hpp b/include/oomph/context.hpp index cc74344a..a19fffcf 100644 --- a/include/oomph/context.hpp +++ b/include/oomph/context.hpp @@ -104,7 +104,7 @@ class context //unsigned int num_tag_ranges() const noexcept { return m_tag_range_factory.num_ranges(); } - const char* 
get_transport_option(const std::string& opt); + const char* get_transport_option(const std::string& opt) const; private: detail::message_buffer make_buffer_core(std::size_t size); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ffc2d2b0..affb05cc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,3 +22,7 @@ endif() if (OOMPH_WITH_LIBFABRIC) add_subdirectory(libfabric) endif() + +if (OOMPH_WITH_NCCL) + add_subdirectory(nccl) +endif() diff --git a/src/communicator.cpp b/src/communicator.cpp index 823042cc..4b764fa8 100644 --- a/src/communicator.cpp +++ b/src/communicator.cpp @@ -45,34 +45,52 @@ communicator::mpi_comm() const noexcept return m_state->m_impl->mpi_comm(); } +bool +communicator::is_stream_aware() const noexcept +{ + return m_state->m_impl->is_stream_aware(); +} + void communicator::progress() { m_state->m_impl->progress(); } +void +communicator::start_group() +{ + return m_state->m_impl->start_group(); +} + +void +communicator::end_group() +{ + return m_state->m_impl->end_group(); +} + send_request communicator::send(detail::message_buffer::heap_ptr_impl const* m_ptr, std::size_t size, - rank_type dst, tag_type tag, util::unique_function&& cb) + rank_type dst, tag_type tag, util::unique_function&& cb, void* stream) { return m_state->m_impl->send(m_ptr->m, size, dst, tag, std::move(cb), - &(m_state->scheduled_sends)); + &(m_state->scheduled_sends), stream); } recv_request communicator::recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, rank_type src, - tag_type tag, util::unique_function&& cb) + tag_type tag, util::unique_function&& cb, void* stream) { return m_state->m_impl->recv(m_ptr->m, size, src, tag, std::move(cb), - &(m_state->scheduled_recvs)); + &(m_state->scheduled_recvs), stream); } shared_recv_request communicator::shared_recv(detail::message_buffer::heap_ptr_impl* m_ptr, std::size_t size, - rank_type src, tag_type tag, util::unique_function&& cb) + rank_type src, tag_type tag, util::unique_function&& 
cb, void* stream) { return m_state->m_impl->shared_recv(m_ptr->m, size, src, tag, std::move(cb), - m_state->m_shared_scheduled_recvs); + m_state->m_shared_scheduled_recvs, stream); } detail::message_buffer diff --git a/src/context.cpp b/src/context.cpp index 73d55516..b1063e04 100644 --- a/src/context.cpp +++ b/src/context.cpp @@ -68,7 +68,7 @@ context::local_size() const noexcept } const char* -context::get_transport_option(const std::string& opt) +context::get_transport_option(const std::string& opt) const { return m->get_transport_option(opt); } diff --git a/src/libfabric/communicator.hpp b/src/libfabric/communicator.hpp index ff8fc945..2e9d2713 100644 --- a/src/libfabric/communicator.hpp +++ b/src/libfabric/communicator.hpp @@ -75,6 +75,11 @@ class communicator_impl : public communicator_base // -------------------------------------------------------------------- auto& get_heap() noexcept { return m_context->get_heap(); } + bool is_stream_aware() const noexcept { return false; } + + void start_group() {} + void end_group() {} + // -------------------------------------------------------------------- /// generate a tag with 0xRRRRRRRRtttttttt rank, tag. /// original tag can be 32bits, then we add 32bits of rank info. 
@@ -169,7 +174,7 @@ class communicator_impl : public communicator_base // -------------------------------------------------------------------- send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, oomph::tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, [[maybe_unused]] void* stream) { [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); std::uint64_t stag = make_tag64(tag, /*this->rank(), */ this->m_context->get_context_tag()); @@ -242,7 +247,7 @@ class communicator_impl : public communicator_base recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, oomph::tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, [[maybe_unused]] void* stream) { [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); std::uint64_t stag = make_tag64(tag, /*src, */ this->m_context->get_context_tag()); @@ -295,7 +300,8 @@ class communicator_impl : public communicator_base shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, oomph::tag_type tag, util::unique_function&& cb, - std::atomic* scheduled) + std::atomic* scheduled, + [[maybe_unused]] void* stream) { [[maybe_unused]] auto scp = com_deb<9>.scope(NS_DEBUG::ptr(this), __func__); std::uint64_t stag = make_tag64(tag, /*src, */ this->m_context->get_context_tag()); diff --git a/src/libfabric/context.cpp b/src/libfabric/context.cpp index 5621a83b..1b4a7674 100644 --- a/src/libfabric/context.cpp +++ b/src/libfabric/context.cpp @@ -61,7 +61,7 @@ context_impl::get_communicator() } const char* -context_impl::get_transport_option(const std::string& opt) +context_impl::get_transport_option(const std::string& opt) const { if (opt == "name") { return "libfabric"; } else if (opt == "progress") { return libfabric_progress_string(); } diff --git a/src/libfabric/context.hpp 
b/src/libfabric/context.hpp index a7c0c112..6d1f6acb 100644 --- a/src/libfabric/context.hpp +++ b/src/libfabric/context.hpp @@ -82,7 +82,7 @@ class context_impl : public context_base inline std::uintptr_t get_context_tag() { return m_ctxt_tag; } inline controller_type* get_controller() /*const */ { return m_controller.get(); } - const char* get_transport_option(const std::string& opt); + const char* get_transport_option(const std::string& opt) const; void progress() { get_controller()->poll_for_work_completions(nullptr); } diff --git a/src/mpi/communicator.hpp b/src/mpi/communicator.hpp index 0022b157..a9f3115e 100644 --- a/src/mpi/communicator.hpp +++ b/src/mpi/communicator.hpp @@ -34,8 +34,13 @@ class communicator_impl : public communicator_base auto& get_heap() noexcept { return m_context->get_heap(); } + bool is_stream_aware() const noexcept { return false; } + + void start_group() {} + void end_group() {} + mpi_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, - tag_type tag) + tag_type tag, [[maybe_unused]] void* stream) { MPI_Request r; const_device_guard dg(ptr); @@ -44,7 +49,7 @@ class communicator_impl : public communicator_base } mpi_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, - tag_type tag) + tag_type tag, [[maybe_unused]] void* stream) { MPI_Request r; device_guard dg(ptr); @@ -54,9 +59,9 @@ class communicator_impl : public communicator_base send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, void* stream) { - auto req = send(ptr, size, dst, tag); + auto req = send(ptr, size, dst, tag, stream); if (!has_reached_recursion_depth() && req.is_ready()) { auto inc = recursion(); @@ -75,9 +80,9 @@ class communicator_impl : public communicator_base recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, 
tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, void* stream) { - auto req = recv(ptr, size, src, tag); + auto req = recv(ptr, size, src, tag, stream); if (!has_reached_recursion_depth() && req.is_ready()) { auto inc = recursion(); @@ -96,9 +101,9 @@ class communicator_impl : public communicator_base shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, tag_type tag, util::unique_function&& cb, - std::atomic* scheduled) + std::atomic* scheduled, void* stream) { - auto req = recv(ptr, size, src, tag); + auto req = recv(ptr, size, src, tag, stream); if (!m_context->has_reached_recursion_depth() && req.is_ready()) { auto inc = m_context->recursion(); diff --git a/src/mpi/context.cpp b/src/mpi/context.cpp index 9f3273d4..8d0b8736 100644 --- a/src/mpi/context.cpp +++ b/src/mpi/context.cpp @@ -22,7 +22,7 @@ context_impl::get_communicator() return comm; } -const char *context_impl::get_transport_option(const std::string &opt) { +const char *context_impl::get_transport_option(const std::string &opt) const { if (opt == "name") { return "mpi"; } diff --git a/src/mpi/context.hpp b/src/mpi/context.hpp index 53f1e81d..f322d2c4 100644 --- a/src/mpi/context.hpp +++ b/src/mpi/context.hpp @@ -83,7 +83,7 @@ class context_impl : public context_base unsigned int num_tag_bits() const noexcept { return m_n_tag_bits; } - const char* get_transport_option(const std::string& opt); + const char* get_transport_option(const std::string& opt) const; }; template<> diff --git a/src/nccl/CMakeLists.txt b/src/nccl/CMakeLists.txt new file mode 100644 index 00000000..51bc3105 --- /dev/null +++ b/src/nccl/CMakeLists.txt @@ -0,0 +1,9 @@ +add_library(oomph_private_nccl_headers INTERFACE) +target_include_directories(oomph_private_nccl_headers INTERFACE + "$") +target_link_libraries(oomph_nccl PRIVATE oomph_private_nccl_headers) + +list(TRANSFORM oomph_sources PREPEND ${CMAKE_CURRENT_SOURCE_DIR}/../ + 
OUTPUT_VARIABLE oomph_sources_nccl) +target_sources(oomph_nccl PRIVATE ${oomph_sources_nccl}) +target_sources(oomph_nccl PRIVATE context.cpp cuda_event_pool.cpp) diff --git a/src/nccl/cached_cuda_event.hpp b/src/nccl/cached_cuda_event.hpp new file mode 100644 index 00000000..a62c3f3c --- /dev/null +++ b/src/nccl/cached_cuda_event.hpp @@ -0,0 +1,46 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include "cuda_event.hpp" +#include "cuda_event_pool.hpp" + +namespace oomph::detail +{ +// A cuda_event backed by a cuda_event_pool. +// +// Same semantics as cuda_event, but the event is retrieved from a static +// cuda_event_pool on construction and returned to the pool on destruction. +struct cached_cuda_event +{ + cuda_event m_event; + + cached_cuda_event() + : m_event(get_cuda_event_pool().pop()) + { + } + cached_cuda_event(cached_cuda_event&& other) noexcept = default; + cached_cuda_event& operator=(cached_cuda_event&& other) noexcept = default; + cached_cuda_event(const cached_cuda_event&) = default; + cached_cuda_event& operator=(const cached_cuda_event&) = default; + ~cached_cuda_event() noexcept + { + if (m_event) { get_cuda_event_pool().push(std::move(m_event)); } + } + + operator bool() noexcept { return bool(m_event); } + + void record(cudaStream_t stream) { return m_event.record(stream); } + + bool is_ready() const { return m_event.is_ready(); } + + cudaEvent_t get() { return m_event.get(); } +}; +} // namespace oomph::detail diff --git a/src/nccl/communicator.hpp b/src/nccl/communicator.hpp new file mode 100644 index 00000000..bf639e40 --- /dev/null +++ b/src/nccl/communicator.hpp @@ -0,0 +1,203 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include +#include + +#include + +#include + +// paths relative to backend +#include "../communicator_base.hpp" +#include "../device_guard.hpp" +#include "./context.hpp" +#include "cached_cuda_event.hpp" +#include "group_cuda_event.hpp" +#include "request.hpp" +#include "request_queue.hpp" +#include "request_state.hpp" + +namespace oomph +{ +class communicator_impl : public communicator_base +{ + public: + context_impl* m_context; + request_queue m_send_reqs; + request_queue m_recv_reqs; + + private: + struct group_info + { + // A shared CUDA event used for synchronization at the end of the NCCL + // group. All streams used within the group are waited for before the + // group kernel starts and all streams can be used to wait for the + // completion of the group kernel. From + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/streams.html: + // + // NCCL allows for using multiple streams within a group call. This will + // enforce a stream dependency of all streams before the NCCL kernel + // starts and block all streams until the NCCL kernel completes. + // + // It will behave as if the NCCL group operation was posted on every + // stream, but given it is a single operation, it will cause a global + // synchronization point between the streams. + detail::group_cuda_event m_event{}; + + // We arbitrarily use the last stream used within a group to synchronize + // the whole group. + cudaStream_t m_last_stream{}; + }; + + // NCCL group information. When no group is active this is std::nullopt. + // When a group is active it contains information used for synchronizing + // with the end of the group kernel. 
+ std::optional m_group_info; + + bool is_group_active() const noexcept { return m_group_info.has_value(); } + + public: + communicator_impl(context_impl* ctxt) + : communicator_base(ctxt) + , m_context(ctxt) + { + } + + auto& get_heap() noexcept { return m_context->get_heap(); } + + bool is_stream_aware() const noexcept { return true; } + + void start_group() + { + assert(!is_group_active()); + + OOMPH_CHECK_NCCL_RESULT(ncclGroupStart()); + m_group_info.emplace(); + } + + void end_group() + { + assert(is_group_active()); + + OOMPH_CHECK_NCCL_RESULT(ncclGroupEnd()); + + // All streams used in a NCCL group synchronize with the end of the group. + // We arbitrarily pick the last stream to synchronize against. + m_group_info->m_event.record(m_group_info->m_last_stream); + m_group_info.reset(); + } + + nccl_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, + [[maybe_unused]] tag_type tag, void* stream) + { + const_device_guard dg(ptr); + OOMPH_CHECK_NCCL_RESULT(ncclSend(dg.data(), size, ncclChar, dst, m_context->get_comm(), + static_cast(stream))); + + if (is_group_active()) + { + m_group_info->m_last_stream = static_cast(stream); + // The event is stored now, but recorded only in end_group. Until + // an event has been recorded the event is never ready. + return {m_group_info->m_event}; + } + else + { + detail::cached_cuda_event event; + event.record(static_cast(stream)); + return {std::move(event)}; + } + } + + nccl_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, + [[maybe_unused]] tag_type tag, void* stream) + { + device_guard dg(ptr); + OOMPH_CHECK_NCCL_RESULT(ncclRecv(dg.data(), size, ncclChar, src, m_context->get_comm(), + static_cast(stream))); + + if (is_group_active()) + { + m_group_info->m_last_stream = static_cast(stream); + // The event is stored now, but recorded only in end_group. Until + // an event has been recorded the event is never ready. 
+ return {m_group_info->m_event}; + } + else + { + detail::cached_cuda_event event; + event.record(static_cast(stream)); + return {std::move(event)}; + } + } + + send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, + tag_type tag, util::unique_function&& cb, std::size_t* scheduled, + void* stream) + { + auto req = send(ptr, size, dst, tag, stream); + auto s = m_req_state_factory.make(m_context, this, scheduled, dst, tag, std::move(cb), + std::move(req)); + s->create_self_ref(); + m_send_reqs.enqueue(s.get()); + return {std::move(s)}; + } + + recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, + tag_type tag, util::unique_function&& cb, std::size_t* scheduled, + void* stream) + { + auto req = recv(ptr, size, src, tag, stream); + auto s = m_req_state_factory.make(m_context, this, scheduled, src, tag, std::move(cb), + std::move(req)); + s->create_self_ref(); + m_recv_reqs.enqueue(s.get()); + return {std::move(s)}; + } + + shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, + rank_type src, tag_type tag, util::unique_function&& cb, + std::atomic* scheduled, void* stream) + { + auto req = recv(ptr, size, src, tag, stream); + auto s = std::make_shared(m_context, this, scheduled, src, + tag, std::move(cb), std::move(req)); + s->create_self_ref(); + m_context->m_req_queue.enqueue(s.get()); + return {std::move(s)}; + } + + void progress() + { + if (is_group_active()) + { + // If we're inside an active group no work has been submitted yet. + // We cannot make progress so we disallow calling progress. wait + // also calls progress and would deadlock if called within an active + // group. + throw std::logic_error( + "OOMPH Error: Calling progress while a NCCL group is active is not allowed"); + } + + // Communication progresses independently, but requests must be marked + // ready and callbacks must be invoked. 
+ m_send_reqs.progress(); + m_recv_reqs.progress(); + m_context->progress(); + } + + bool cancel_recv(detail::request_state*) { return false; } +}; + +} // namespace oomph diff --git a/src/nccl/context.cpp b/src/nccl/context.cpp new file mode 100644 index 00000000..fdba81e9 --- /dev/null +++ b/src/nccl/context.cpp @@ -0,0 +1,32 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ + +// paths relative to backend +#include "context.hpp" +#include "communicator.hpp" + +namespace oomph +{ +communicator_impl* +context_impl::get_communicator() +{ + auto comm = new communicator_impl{this}; + m_comms_set.insert(comm); + return comm; +} + +const char* +context_impl::get_transport_option(const std::string& opt) const +{ + if (opt == "name") { return "nccl"; } + else { return "unspecified"; } +} + +} // namespace oomph diff --git a/src/nccl/context.hpp b/src/nccl/context.hpp new file mode 100644 index 00000000..33a1651f --- /dev/null +++ b/src/nccl/context.hpp @@ -0,0 +1,82 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include + +#include + +// paths relative to backend +#include "../context_base.hpp" +#include "nccl_communicator.hpp" +#include "region.hpp" +#include "request_queue.hpp" + +namespace oomph +{ +class context_impl : public context_base +{ + public: + using region_type = region; + using device_region_type = region; + using heap_type = hwmalloc::heap; + + private: + heap_type m_heap; + detail::nccl_comm m_comm; + + public: + shared_request_queue m_req_queue; + + public: + context_impl(MPI_Comm comm, bool thread_safe, hwmalloc::heap_config const& heap_config) + : context_base(comm, thread_safe) + , m_heap{this, heap_config} + , m_comm{oomph::detail::nccl_comm{comm}} + { + if (thread_safe) { throw std::runtime_error("NCCL not supported with thread_safe = true"); } + } + + context_impl(context_impl const&) = delete; + context_impl(context_impl&&) = delete; + + ncclComm_t get_comm() const noexcept { return m_comm.get(); } + + region make_region(void* ptr) const { return {ptr}; } + + auto& get_heap() noexcept { return m_heap; } + + communicator_impl* get_communicator(); + + void progress() { m_req_queue.progress(); } + + bool cancel_recv(detail::shared_request_state*) { return false; } + + const char* get_transport_option(const std::string& opt) const; +}; + +template<> +inline region +register_memory(context_impl& c, void* ptr, std::size_t) +{ + return c.make_region(ptr); +} + +#if OOMPH_ENABLE_DEVICE +template<> +inline region +register_device_memory(context_impl& c, int, void* ptr, std::size_t) +{ + return c.make_region(ptr); +} +#endif + +} // namespace oomph diff --git a/src/nccl/cuda_error.hpp b/src/nccl/cuda_error.hpp new file mode 100644 index 00000000..baf9a17b --- /dev/null +++ b/src/nccl/cuda_error.hpp @@ -0,0 +1,34 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include +#include + +#include + +#define OOMPH_CHECK_CUDA_RESULT(x) \ + if (x != cudaSuccess) \ + throw std::runtime_error("OOMPH Error: CUDA Call failed " + std::string(#x) + " (" + \ + std::string(cudaGetErrorString(x)) + ") in " + \ + std::string(__FILE__) + ":" + std::to_string(__LINE__)); + +#define OOMPH_CHECK_CUDA_RESULT_NO_THROW(x) \ + try \ + { \ + OOMPH_CHECK_CUDA_RESULT(x) \ + } \ + catch (const std::exception& e) \ + { \ + std::cerr << e.what() << std::endl; \ + std::terminate(); \ + } diff --git a/src/nccl/cuda_event.hpp b/src/nccl/cuda_event.hpp new file mode 100644 index 00000000..01762bea --- /dev/null +++ b/src/nccl/cuda_event.hpp @@ -0,0 +1,74 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +#include + +#include + +#include "cuda_error.hpp" + +namespace oomph::detail +{ +// RAII wrapper for a cudaEvent_t. +// +// Move-only wrapper around cudaEvent_t that automatically destroys the +// underlying event on destruction. Can be used to record events on streams. 
+struct cuda_event +{ + cudaEvent_t m_event; + oomph::util::moved_bit m_moved; + bool m_recorded{false}; + + cuda_event() + { + OOMPH_CHECK_CUDA_RESULT(cudaEventCreateWithFlags(&m_event, cudaEventDisableTiming)); + } + cuda_event(cuda_event&& other) noexcept = default; + cuda_event& operator=(cuda_event&& other) noexcept = default; + cuda_event(const cuda_event&) = delete; + cuda_event& operator=(const cuda_event&) = delete; + ~cuda_event() noexcept + { + if (!m_moved) { OOMPH_CHECK_CUDA_RESULT_NO_THROW(cudaEventDestroy(m_event)); } + } + + operator bool() noexcept { return !m_moved; } + + void record(cudaStream_t stream) + { + assert(!m_moved); + OOMPH_CHECK_CUDA_RESULT(cudaEventRecord(m_event, stream)); + m_recorded = true; + } + + bool is_ready() const + { + if (m_moved || !m_recorded) { return false; } + + cudaError_t res = cudaEventQuery(m_event); + if (res == cudaSuccess) { return true; } + else if (res == cudaErrorNotReady) { return false; } + else + { + OOMPH_CHECK_CUDA_RESULT(res); + return false; + } + } + + cudaEvent_t get() + { + assert(!m_moved); + return m_event; + } +}; +} // namespace oomph::detail diff --git a/src/nccl/cuda_event_pool.cpp b/src/nccl/cuda_event_pool.cpp new file mode 100644 index 00000000..4b180dcb --- /dev/null +++ b/src/nccl/cuda_event_pool.cpp @@ -0,0 +1,21 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include "cuda_event_pool.hpp" + +namespace oomph::detail +{ +cuda_event_pool& +get_cuda_event_pool() +{ + static cuda_event_pool pool{128}; + return pool; +} +} // namespace oomph::detail diff --git a/src/nccl/cuda_event_pool.hpp b/src/nccl/cuda_event_pool.hpp new file mode 100644 index 00000000..d669b83b --- /dev/null +++ b/src/nccl/cuda_event_pool.hpp @@ -0,0 +1,65 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. 
+ * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include + +#include + +#include + +#include "cuda_error.hpp" +#include "cuda_event.hpp" + +namespace oomph::detail +{ +// Pool of cuda_events. +// +// Simple wrapper over a vector of cuda_events. Events can be popped from the +// pool. New events are created if the pool is empty. Events can be returned to +// the pool for reuse. Events do not need to originate from the pool. Not +// thread-safe. +class cuda_event_pool +{ + private: + std::vector m_events; + + public: + cuda_event_pool(std::size_t expected_pool_size) + : m_events(expected_pool_size) + { + } + + cuda_event_pool(const cuda_event_pool&) = delete; + cuda_event_pool& operator=(const cuda_event_pool&) = delete; + cuda_event_pool(cuda_event_pool&& other) noexcept = delete; + cuda_event_pool& operator=(cuda_event_pool&&) noexcept = delete; + + public: + cuda_event pop() + { + if (m_events.empty()) { return {}; } + else + { + auto event{std::move(m_events.back())}; + m_events.pop_back(); + return event; + } + } + + void push(cuda_event&& event) { m_events.push_back(std::move(event)); } + void clear() { m_events.clear(); } +}; + +// Get a static instance of a cuda_event_pool. +cuda_event_pool& get_cuda_event_pool(); +} // namespace oomph::detail diff --git a/src/nccl/group_cuda_event.hpp b/src/nccl/group_cuda_event.hpp new file mode 100644 index 00000000..eb60e80b --- /dev/null +++ b/src/nccl/group_cuda_event.hpp @@ -0,0 +1,41 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +#include "cached_cuda_event.hpp" + +namespace oomph::detail +{ +// A shared cuda_event suitable for use with NCCL groups. 
+// +// A cached_cuda_event stored in a shared_ptr for shared usage between multiple +// requests. +struct group_cuda_event +{ + std::shared_ptr m_event; + + group_cuda_event() + : m_event(std::make_shared()) + { + } + group_cuda_event(const group_cuda_event&) = default; + group_cuda_event& operator=(const group_cuda_event&) = default; + group_cuda_event(group_cuda_event&&) = default; + group_cuda_event& operator=(group_cuda_event&&) = default; + + void record(cudaStream_t stream) { m_event->record(stream); } + + bool is_ready() { return m_event->is_ready(); } + + cudaEvent_t get() { return m_event->get(); } +}; +} // namespace oomph::detail diff --git a/src/nccl/handle.hpp b/src/nccl/handle.hpp new file mode 100644 index 00000000..9527592e --- /dev/null +++ b/src/nccl/handle.hpp @@ -0,0 +1,21 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +namespace oomph +{ +struct handle +{ + void* m_ptr; + std::size_t m_size; +}; +} // namespace oomph diff --git a/src/nccl/nccl_communicator.hpp b/src/nccl/nccl_communicator.hpp new file mode 100644 index 00000000..2719c1a0 --- /dev/null +++ b/src/nccl/nccl_communicator.hpp @@ -0,0 +1,58 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +#include +#include + +#include "../mpi_comm.hpp" +#include "cuda_error.hpp" +#include "nccl_error.hpp" + +namespace oomph::detail +{ +class nccl_comm +{ + ncclComm_t m_comm; + oomph::util::moved_bit m_moved; + + public: + nccl_comm(mpi_comm mpi_comm) + { + ncclUniqueId id; + if (mpi_comm.rank() == 0) { OOMPH_CHECK_NCCL_RESULT(ncclGetUniqueId(&id)); } + + OOMPH_CHECK_MPI_RESULT(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, mpi_comm.get())); + + OOMPH_CHECK_NCCL_RESULT(ncclCommInitRank(&m_comm, mpi_comm.size(), id, mpi_comm.rank())); + ncclResult_t result; + do { + OOMPH_CHECK_NCCL_RESULT(ncclCommGetAsyncError(m_comm, &result)); + } while (result == ncclInProgress); + OOMPH_CHECK_NCCL_RESULT(result); + } + nccl_comm(nccl_comm&&) noexcept = default; + nccl_comm& operator=(nccl_comm&&) noexcept = default; + nccl_comm(nccl_comm const&) = delete; + nccl_comm& operator=(nccl_comm const&) = delete; + ~nccl_comm() noexcept + { + if (!m_moved) + { + OOMPH_CHECK_CUDA_RESULT_NO_THROW(cudaDeviceSynchronize()); + OOMPH_CHECK_NCCL_RESULT_NO_THROW(ncclCommDestroy(m_comm)); + } + } + + ncclComm_t get() const noexcept { return m_comm; } +}; +} // namespace oomph::detail diff --git a/src/nccl/nccl_error.hpp b/src/nccl/nccl_error.hpp new file mode 100644 index 00000000..ca4cbe3b --- /dev/null +++ b/src/nccl/nccl_error.hpp @@ -0,0 +1,36 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include + +#include + +#define OOMPH_CHECK_NCCL_RESULT(x) \ + { \ + ncclResult_t r = x; \ + if (r != ncclSuccess && r != ncclInProgress) \ + throw std::runtime_error("OOMPH Error: NCCL Call failed " + std::string(#x) + " = " + \ + std::to_string(r) + " (\"" + ncclGetErrorString(r) + \ + "\") in " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__)); \ + } +#define OOMPH_CHECK_NCCL_RESULT_NO_THROW(x) \ + try \ + { \ + OOMPH_CHECK_NCCL_RESULT(x) \ + } \ + catch (const std::exception& e) \ + { \ + std::cerr << e.what() << std::endl; \ + std::terminate(); \ + } diff --git a/src/nccl/region.hpp b/src/nccl/region.hpp new file mode 100644 index 00000000..71a84f87 --- /dev/null +++ b/src/nccl/region.hpp @@ -0,0 +1,44 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +// paths relative to backend +#include "handle.hpp" + +namespace oomph +{ +class region +{ + public: + using handle_type = handle; + + private: + void* m_ptr; + + public: + region(void* ptr) + : m_ptr{ptr} + { + } + + region(region const&) = delete; + + region(region&& r) noexcept + : m_ptr{std::exchange(r.m_ptr, nullptr)} + { + } + + // get a handle to some portion of the region + handle_type get_handle(std::size_t offset, std::size_t size) + { + return {(void*)((char*)m_ptr + offset), size}; + } +}; +} // namespace oomph diff --git a/src/nccl/request.hpp b/src/nccl/request.hpp new file mode 100644 index 00000000..e6c24c7d --- /dev/null +++ b/src/nccl/request.hpp @@ -0,0 +1,33 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +#include + +#include "cuda_error.hpp" +#include "cuda_event.hpp" +#include "group_cuda_event.hpp" + +namespace oomph +{ +struct nccl_request +{ + bool is_ready() + { + return std::visit([](auto& event) { return event.is_ready(); }, m_event); + } + + // We store either a single event for a particular request, or a shared + // event that signals the end of a NCCL group. + std::variant m_event; +}; +} // namespace oomph diff --git a/src/nccl/request_queue.hpp b/src/nccl/request_queue.hpp new file mode 100644 index 00000000..b6e1d1e1 --- /dev/null +++ b/src/nccl/request_queue.hpp @@ -0,0 +1,157 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +#include + +// paths relative to backend +#include "request_state.hpp" + +namespace oomph +{ +class request_queue +{ + private: + using element_type = detail::request_state; + using queue_type = std::vector; + + private: // members + queue_type m_queue; + queue_type m_ready_queue; + bool in_progress = false; + + public: // ctors + request_queue() + { + m_queue.reserve(256); + m_ready_queue.reserve(256); + } + + public: // member functions + std::size_t size() const noexcept { return m_queue.size(); } + + void enqueue(element_type* e) + { + m_queue.push_back(e); + } + + int progress() + { + if (in_progress) return 0; + in_progress = true; + + const auto qs = size(); + if (qs == 0) + { + in_progress = false; + return 0; + } + + m_ready_queue.clear(); + m_ready_queue.reserve(qs); + std::size_t num_not_ready = 0; + for (std::size_t i = 0; i < qs; ++i) + { + auto* req = m_queue[i]; + if (req->m_req.is_ready()) { m_ready_queue.push_back(req); } + else + { + if (num_not_ready != i) + { + m_queue[num_not_ready] = req; + } + ++num_not_ready; + } + } + m_queue.erase(m_queue.begin() + 
num_not_ready, m_queue.end()); + + int completed = m_ready_queue.size(); + for (auto* req : m_ready_queue) + { + auto ptr = req->release_self_ref(); + req->invoke_cb(); + } + + in_progress = false; + + return completed; + } + + bool cancel(element_type*) { return false; } +}; + +class shared_request_queue +{ + private: + using element_type = detail::shared_request_state; + using queue_type = boost::lockfree::queue, + boost::lockfree::allocator>>; + + private: // members + queue_type m_queue; + std::atomic m_size; + + public: // ctors + shared_request_queue() + : m_queue(256) + , m_size(0) + { + } + + public: // member functions + std::size_t size() const noexcept { return m_size.load(); } + + void enqueue(element_type* e) + { + m_queue.push(e); + ++m_size; + } + + int progress() + { + static thread_local bool in_progress = false; + static thread_local std::vector m_local_queue; + int found = 0; + + if (in_progress) return 0; + in_progress = true; + + element_type* e; + while (m_queue.pop(e)) + { + if (e->m_req.is_ready()) + { + found = 1; + break; + } + else { m_local_queue.push_back(e); } + } + + for (auto x : m_local_queue) m_queue.push(x); + m_local_queue.clear(); + + if (found) { --m_size; } + + in_progress = false; + + if (found) + { + auto ptr = e->release_self_ref(); + e->invoke_cb(); + } + + return found; + } + + bool cancel(element_type*) { return false; } +}; +} // namespace oomph diff --git a/src/nccl/request_state.hpp b/src/nccl/request_state.hpp new file mode 100644 index 00000000..e0ac23a6 --- /dev/null +++ b/src/nccl/request_state.hpp @@ -0,0 +1,93 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2025, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include + +// paths relative to backend +#include "../request_state_base.hpp" +#include "request.hpp" + +namespace oomph::detail +{ +struct request_state +: public util::enable_shared_from_this +, public request_state_base +{ + using base = request_state_base; + using shared_ptr_t = util::unsafe_shared_ptr; + + nccl_request m_req; + shared_ptr_t m_self_ptr; + std::size_t m_index; + + request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, std::size_t* scheduled, + rank_type rank, tag_type tag, cb_type&& cb, nccl_request m) + : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} + , m_req{std::move(m)} + { + } + + void progress(); + + bool cancel(); + + void create_self_ref() + { + // create a self-reference cycle!! + // this is useful if we only keep a raw pointer around internally, which still is supposed + // to keep the object alive + m_self_ptr = shared_from_this(); + } + + shared_ptr_t release_self_ref() noexcept + { + assert(((bool)m_self_ptr) && "doesn't own a self-reference!"); + return std::move(m_self_ptr); + } +}; + +struct shared_request_state +: public std::enable_shared_from_this +, public request_state_base +{ + using base = request_state_base; + using shared_ptr_t = std::shared_ptr; + + nccl_request m_req; + shared_ptr_t m_self_ptr; + + shared_request_state(oomph::context_impl* ctxt, oomph::communicator_impl* comm, + std::atomic* scheduled, rank_type rank, tag_type tag, cb_type&& cb, + nccl_request m) + : base{ctxt, comm, scheduled, rank, tag, std::move(cb)} + , m_req{std::move(m)} + { + } + + void progress(); + + bool cancel(); + + void create_self_ref() + { + // create a self-reference cycle!! 
+ // this is useful if we only keep a raw pointer around internally, which still is supposed + // to keep the object alive + m_self_ptr = shared_from_this(); + } + + shared_ptr_t release_self_ref() noexcept + { + assert(((bool)m_self_ptr) && "doesn't own a self-reference!"); + return std::move(m_self_ptr); + } +}; +} // namespace oomph::detail diff --git a/src/ucx/communicator.hpp b/src/ucx/communicator.hpp index dcb4a4ac..cb2720b0 100644 --- a/src/ucx/communicator.hpp +++ b/src/ucx/communicator.hpp @@ -70,6 +70,11 @@ class communicator_impl : public communicator_base auto& get_heap() noexcept { return m_context->get_heap(); } + bool is_stream_aware() const noexcept { return false; } + + void start_group() {} + void end_group() {} + void progress() { while (ucp_worker_progress(m_send_worker->get())) {} @@ -124,7 +129,7 @@ class communicator_impl : public communicator_base send_request send(context_impl::heap_type::pointer const& ptr, std::size_t size, rank_type dst, tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, [[maybe_unused]] void* stream) { const auto& ep = m_send_worker->connect(dst); const auto stag = @@ -186,7 +191,7 @@ class communicator_impl : public communicator_base recv_request recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, tag_type tag, util::unique_function&& cb, - std::size_t* scheduled) + std::size_t* scheduled, [[maybe_unused]] void* stream) { const auto rtag = (communicator::any_source == src) @@ -258,7 +263,7 @@ class communicator_impl : public communicator_base shared_recv_request shared_recv(context_impl::heap_type::pointer& ptr, std::size_t size, rank_type src, tag_type tag, util::unique_function&& cb, - std::atomic* scheduled) + std::atomic* scheduled, [[maybe_unused]] void* stream) { const auto rtag = (communicator::any_source == src) diff --git a/src/ucx/context.cpp b/src/ucx/context.cpp index 8a93faea..27d1dca2 100644 --- a/src/ucx/context.cpp +++ 
b/src/ucx/context.cpp @@ -96,7 +96,7 @@ context_impl::~context_impl() } const char* -context_impl::get_transport_option(const std::string& opt) +context_impl::get_transport_option(const std::string& opt) const { if (opt == "name") { return "ucx"; } else { return "unspecified"; } diff --git a/src/ucx/context.hpp b/src/ucx/context.hpp index 2f790ae1..f440cb1d 100644 --- a/src/ucx/context.hpp +++ b/src/ucx/context.hpp @@ -238,7 +238,7 @@ class context_impl : public context_base return found; } - const char* get_transport_option(const std::string& opt); + const char* get_transport_option(const std::string& opt) const; unsigned int num_tag_bits() const noexcept { return OOMPH_UCX_TAG_BITS; } }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5217bbaf..7885aec0 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -10,7 +10,7 @@ set(OOMPH_TEST_LEAK_GPU_MEMORY OFF CACHE BOOL "Do not free memory (bug on Piz Da set(serial_tests test_unique_function test_unsafe_shared_ptr) # list of parallel tests to be executed -set(parallel_tests test_context test_send_recv test_send_multi test_cancel test_locality) +set(parallel_tests test_context test_send_recv test_send_multi test_cancel test_locality test_group) #test_tag_range) if (OOMPH_ENABLE_BARRIER) list(APPEND parallel_tests test_barrier) @@ -48,6 +48,7 @@ function(reg_serial_test t) add_test( NAME ${t} COMMAND $) + set_tests_properties(${t} PROPERTIES LABELS "serial") endfunction() foreach(t ${serial_tests}) @@ -61,11 +62,15 @@ function(reg_parallel_test t_ lib n) oomph_target_compile_options(${t}) target_link_libraries(${t} PRIVATE gtest_main_mpi) target_link_libraries(${t} PRIVATE oomph_${lib}) - add_test( - NAME ${t} - COMMAND ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} ${n} ${MPIEXEC_PREFLAGS} - $ ${MPIEXEC_POSTFLAGS}) - set_tests_properties(${t} PROPERTIES RUN_SERIAL TRUE) + if("${MPIEXEC_EXECUTABLE}" STREQUAL "") + add_test(NAME ${t} COMMAND $) + else() + add_test( + NAME ${t} + COMMAND 
${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} ${n} ${MPIEXEC_PREFLAGS} + $ ${MPIEXEC_POSTFLAGS}) + endif() + set_tests_properties(${t} PROPERTIES RUN_SERIAL TRUE LABELS "parallel-ranks-${n}") endfunction() if (OOMPH_WITH_MPI) @@ -86,4 +91,10 @@ if (OOMPH_WITH_LIBFABRIC) endforeach() endif() +if (OOMPH_WITH_NCCL) + foreach(t ${parallel_tests}) + reg_parallel_test(${t} nccl 4) + endforeach() +endif() + add_subdirectory(bindings) diff --git a/test/bindings/fortran/CMakeLists.txt b/test/bindings/fortran/CMakeLists.txt index 974d2f7c..2a5980c5 100644 --- a/test/bindings/fortran/CMakeLists.txt +++ b/test/bindings/fortran/CMakeLists.txt @@ -25,12 +25,17 @@ function(reg_parallel_test_f t_ lib n nthr) $ $ $) - add_test( - NAME ${t} - COMMAND ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} ${n} ${MPIEXEC_PREFLAGS} - $ ${MPIEXEC_POSTFLAGS}) + if("${MPIEXEC_EXECUTABLE}" STREQUAL "") + add_test(NAME ${t} COMMAND $) + else() + add_test( + NAME ${t} + COMMAND ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} ${n} ${MPIEXEC_PREFLAGS} + $ ${MPIEXEC_POSTFLAGS}) + endif() set_tests_properties(${t} PROPERTIES - ENVIRONMENT OMP_NUM_THREADS=${nthr}) + ENVIRONMENT OMP_NUM_THREADS=${nthr} + LABELS "parallel-ranks-${n}") endfunction() if (OOMPH_WITH_MPI) diff --git a/test/test_barrier.cpp b/test/test_barrier.cpp index 3016c091..84ad2113 100644 --- a/test/test_barrier.cpp +++ b/test/test_barrier.cpp @@ -55,98 +55,138 @@ class test_barrier TEST_F(mpi_test_fixture, in_node1) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - barrier b(ctxt, n_threads); - - oomph::test_barrier{b}.test_in_node1(ctxt); + try { + auto ctxt = context(MPI_COMM_WORLD, true); + std::size_t n_threads = 4; + barrier b(ctxt, n_threads); + + oomph::test_barrier{b}.test_in_node1(ctxt); + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported 
with thread_safe = true")); + } else { + throw; + } + } } TEST_F(mpi_test_fixture, in_barrier_1) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + try { + auto ctxt = context(MPI_COMM_WORLD, true); + std::size_t n_threads = 4; + barrier b(ctxt, n_threads); - auto comm = ctxt.get_communicator(); - auto comm2 = ctxt.get_communicator(); + auto comm = ctxt.get_communicator(); + auto comm2 = ctxt.get_communicator(); - for (int i = 0; i < 20; i++) { b.rank_barrier(); } + for (int i = 0; i < 20; i++) { b.rank_barrier(); } + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw; + } + } } TEST_F(mpi_test_fixture, in_barrier) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); + try { + auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + std::size_t n_threads = 4; + barrier b(ctxt, n_threads); - auto work = [&]() - { - auto comm = ctxt.get_communicator(); - auto comm2 = ctxt.get_communicator(); - for (int i = 0; i < 10; i++) + auto work = [&]() { - comm.progress(); - b.thread_barrier(); - } - }; + auto comm = ctxt.get_communicator(); + auto comm2 = ctxt.get_communicator(); + for (int i = 0; i < 10; i++) + { + comm.progress(); + b.thread_barrier(); + } + }; - std::vector ths; - for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work}); } - for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + std::vector ths; + for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work}); } + for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL 
not supported with thread_safe = true")); + } else { + throw; + } + } } TEST_F(mpi_test_fixture, full_barrier) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); + try { + auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + std::size_t n_threads = 4; + barrier b(ctxt, n_threads); - auto work = [&]() - { - auto comm = ctxt.get_communicator(); - auto comm3 = ctxt.get_communicator(); - for (int i = 0; i < 10; i++) { b(); } - }; + auto work = [&]() + { + auto comm = ctxt.get_communicator(); + auto comm3 = ctxt.get_communicator(); + for (int i = 0; i < 10; i++) { b(); } + }; - std::vector ths; - for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work}); } - for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + std::vector ths; + for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work}); } + for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw; + } + } } TEST_F(mpi_test_fixture, full_barrier_sendrecv) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); + try { + auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - barrier b(ctxt, n_threads); + std::size_t n_threads = 4; + barrier b(ctxt, n_threads); - auto work = [&](int tid) - { - auto comm = ctxt.get_communicator(); - auto comm2 = ctxt.get_communicator(); - int s_rank = (tid < 3) ? comm.rank() : ((comm.rank() + 1) % comm.size()); - int s_tag = comm.rank() * 10 + tid; - int r_rank = (tid > 0) ? comm.rank() : ((comm.rank() + comm.size() - 1) % comm.size()); - int r_tag = (tid > 0) ? 
(comm.rank() * 10 + tid - 1) : (r_rank * 10 + n_threads - 1); - - auto s_buffer = comm.make_buffer(1000); - auto r_buffer = comm.make_buffer(1000); - for (auto& x : s_buffer) x = s_tag; - auto r_req = comm.recv(r_buffer, r_rank, r_tag); - auto s_req = comm.send(s_buffer, s_rank, s_tag); - b(); - while (!(r_req.test() && s_req.test())) {}; - b(); - }; - - std::vector ths; - for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work, i}); } - for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + auto work = [&](int tid) + { + auto comm = ctxt.get_communicator(); + auto comm2 = ctxt.get_communicator(); + int s_rank = (tid < 3) ? comm.rank() : ((comm.rank() + 1) % comm.size()); + int s_tag = comm.rank() * 10 + tid; + int r_rank = (tid > 0) ? comm.rank() : ((comm.rank() + comm.size() - 1) % comm.size()); + int r_tag = (tid > 0) ? (comm.rank() * 10 + tid - 1) : (r_rank * 10 + n_threads - 1); + + auto s_buffer = comm.make_buffer(1000); + auto r_buffer = comm.make_buffer(1000); + for (auto& x : s_buffer) x = s_tag; + auto r_req = comm.recv(r_buffer, r_rank, r_tag); + auto s_req = comm.send(s_buffer, s_rank, s_tag); + b(); + while (!(r_req.test() && s_req.test())) {}; + b(); + }; + + std::vector ths; + for (size_t i = 0; i < n_threads; ++i) { ths.push_back(std::thread{work, i}); } + for (size_t i = 0; i < n_threads; ++i) { ths[i].join(); } + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw; + } + } } diff --git a/test/test_cancel.cpp b/test/test_cancel.cpp index f00ed737..1ec5220c 100644 --- a/test/test_cancel.cpp +++ b/test/test_cancel.cpp @@ -65,6 +65,10 @@ TEST_F(mpi_test_fixture, test_cancel_request) { using namespace oomph; auto ctxt = context(MPI_COMM_WORLD, false); + if (ctxt.get_transport_option("name") == std::string("nccl")) { + // NCCL does not support 
cancellation + return; + } auto comm = ctxt.get_communicator(); test_1(comm, 1); test_1(comm, 32); @@ -74,19 +78,27 @@ TEST_F(mpi_test_fixture, test_cancel_request) TEST_F(mpi_test_fixture, test_cancel_request_mt) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - - std::vector threads; - threads.reserve(n_threads); - for (size_t i = 0; i < n_threads; ++i) - threads.push_back(std::thread{[&ctxt, i]() { - auto comm = ctxt.get_communicator(); - test_1(comm, 1, i); - test_1(comm, 32, i); - test_1(comm, 4096, i); - }}); - for (auto& t : threads) t.join(); + try { + auto ctxt = context(MPI_COMM_WORLD, true); + std::size_t n_threads = 4; + + std::vector threads; + threads.reserve(n_threads); + for (size_t i = 0; i < n_threads; ++i) + threads.push_back(std::thread{[&ctxt, i]() { + auto comm = ctxt.get_communicator(); + test_1(comm, 1, i); + test_1(comm, 32, i); + test_1(comm, 4096, i); + }}); + for (auto& t : threads) t.join(); + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw; + } + } } void @@ -145,6 +157,10 @@ TEST_F(mpi_test_fixture, test_cancel_cb) { using namespace oomph; auto ctxt = context(MPI_COMM_WORLD, false); + if (ctxt.get_transport_option("name") == std::string("nccl")) { + // NCCL does not support cancellation + return; + } auto comm = ctxt.get_communicator(); test_2(comm, 1); test_2(comm, 32); @@ -154,17 +170,25 @@ TEST_F(mpi_test_fixture, test_cancel_cb) TEST_F(mpi_test_fixture, test_cancel_cb_mt) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); - std::size_t n_threads = 4; - - std::vector threads; - threads.reserve(n_threads); - for (size_t i = 0; i < n_threads; ++i) - threads.push_back(std::thread{[&ctxt, i]() { - auto comm = ctxt.get_communicator(); - test_2(comm, 1, i); - test_2(comm, 32, i); 
- test_2(comm, 4096, i); - }}); - for (auto& t : threads) t.join(); + try { + auto ctxt = context(MPI_COMM_WORLD, true); + std::size_t n_threads = 4; + + std::vector threads; + threads.reserve(n_threads); + for (size_t i = 0; i < n_threads; ++i) + threads.push_back(std::thread{[&ctxt, i]() { + auto comm = ctxt.get_communicator(); + test_2(comm, 1, i); + test_2(comm, 32, i); + test_2(comm, 4096, i); + }}); + for (auto& t : threads) t.join(); + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw; + } + } } diff --git a/test/test_context.cpp b/test/test_context.cpp index 930c248a..845578e4 100644 --- a/test/test_context.cpp +++ b/test/test_context.cpp @@ -20,57 +20,65 @@ const int num_threads = 4; TEST_F(mpi_test_fixture, context_ordered) { using namespace oomph; - auto ctxt = context(MPI_COMM_WORLD, true); + try { + auto ctxt = context(MPI_COMM_WORLD, true); - //auto func = [&ctxt](int tid) - //{ - // auto comm = ctxt.get_communicator(); - // auto smsg_1 = comm.make_buffer(size); - // auto smsg_2 = comm.make_buffer(size); - // auto rmsg_1 = comm.make_buffer(size); - // auto rmsg_2 = comm.make_buffer(size); - // bool sent_1 = false; - // bool sent_2 = false; - // if (comm.rank() == 0) - // { - // const int payload_offset = 1 + tid; - // for (unsigned int i = 0; i < size; ++i) - // { - // smsg_1[i] = i + payload_offset; - // smsg_2[i] = i + payload_offset + 1; - // } - // std::vector neighs(comm.size()>1 ? 
comm.size() - 1 : 1, 0); - // for (int i = 1; i < comm.size(); ++i) neighs[i - 1] = i; + //auto func = [&ctxt](int tid) + //{ + // auto comm = ctxt.get_communicator(); + // auto smsg_1 = comm.make_buffer(size); + // auto smsg_2 = comm.make_buffer(size); + // auto rmsg_1 = comm.make_buffer(size); + // auto rmsg_2 = comm.make_buffer(size); + // bool sent_1 = false; + // bool sent_2 = false; + // if (comm.rank() == 0) + // { + // const int payload_offset = 1 + tid; + // for (unsigned int i = 0; i < size; ++i) + // { + // smsg_1[i] = i + payload_offset; + // smsg_2[i] = i + payload_offset + 1; + // } + // std::vector neighs(comm.size()>1 ? comm.size() - 1 : 1, 0); + // for (int i = 1; i < comm.size(); ++i) neighs[i - 1] = i; - // comm.send_multi(std::move(smsg_1), neighs, tid, - // [&sent_1](decltype(smsg_1), std::vector, tag_type) { sent_1 = true; }); + // comm.send_multi(std::move(smsg_1), neighs, tid, + // [&sent_1](decltype(smsg_1), std::vector, tag_type) { sent_1 = true; }); - // comm.send_multi(std::move(smsg_2), neighs, tid, - // [&sent_2](decltype(smsg_2), std::vector, tag_type) { sent_2 = true; }); + // comm.send_multi(std::move(smsg_2), neighs, tid, + // [&sent_2](decltype(smsg_2), std::vector, tag_type) { sent_2 = true; }); - // } - // if (comm.rank() > 0 || comm.size() == 1) - // { - // // ordered sends/recvs with same tag should arrive in order - // comm.recv(rmsg_1, 0, tid).wait(); - // comm.recv(rmsg_2, 0, tid).wait(); + // } + // if (comm.rank() > 0 || comm.size() == 1) + // { + // // ordered sends/recvs with same tag should arrive in order + // comm.recv(rmsg_1, 0, tid).wait(); + // comm.recv(rmsg_2, 0, tid).wait(); - // // check message - // const int payload_offset = 1 + tid; - // for (unsigned int i = 0; i < size; ++i) - // { - // EXPECT_EQ(rmsg_1[i], i + payload_offset); - // EXPECT_EQ(rmsg_2[i], i + payload_offset + 1); - // } - // } - // if (comm.rank() == 0) - // while (!sent_1 || !sent_2) { comm.progress(); } - //}; + // // check message + // 
const int payload_offset = 1 + tid; + // for (unsigned int i = 0; i < size; ++i) + // { + // EXPECT_EQ(rmsg_1[i], i + payload_offset); + // EXPECT_EQ(rmsg_2[i], i + payload_offset + 1); + // } + // } + // if (comm.rank() == 0) + // while (!sent_1 || !sent_2) { comm.progress(); } + //}; - //std::vector threads; - //threads.reserve(num_threads); - //for (int i = 0; i < num_threads; ++i) threads.push_back(std::thread{func, i}); - //for (auto& t : threads) t.join(); + //std::vector threads; + //threads.reserve(num_threads); + //for (int i = 0; i < num_threads; ++i) threads.push_back(std::thread{func, i}); + //for (auto& t : threads) t.join(); + } catch (std::runtime_error const& e) { + if (oomph::context(MPI_COMM_WORLD, false).get_transport_option("name") == std::string("nccl")) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw; + } + } } //TEST_F(mpi_test_fixture, context_multi) diff --git a/test/test_group.cpp b/test/test_group.cpp new file mode 100644 index 00000000..31a47ffe --- /dev/null +++ b/test/test_group.cpp @@ -0,0 +1,53 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include "./mpi_runner/mpi_test_fixture.hpp" + +TEST_F(mpi_test_fixture, group_progress_wait) +{ + using namespace oomph; + auto ctxt = context(MPI_COMM_WORLD, false); + auto comm = ctxt.get_communicator(); + + int rank = comm.rank(); + int size = comm.size(); + int next_rank = (rank + 1) % size; + int prev_rank = (rank + size - 1) % size; + + auto buf_send = comm.make_buffer(1); + auto buf_recv = comm.make_buffer(1); + buf_send[0] = rank; + + comm.start_group(); + auto req_send = comm.send(buf_send, next_rank, 0); + auto req_recv = comm.recv(buf_recv, prev_rank, 0); + + if (ctxt.get_transport_option("name") == std::string("nccl")) + { + EXPECT_THROW(comm.progress(), std::logic_error); + EXPECT_THROW(req_send.wait(), std::logic_error); + EXPECT_THROW(req_recv.wait(), std::logic_error); + } + else + { + EXPECT_NO_THROW(comm.progress()); + EXPECT_NO_THROW(req_send.wait()); + EXPECT_NO_THROW(req_recv.wait()); + } + comm.end_group(); + + // For nccl, we threw during wait, so requests are not finished. 
+ if (ctxt.get_transport_option("name") == std::string("nccl")) + { + req_send.wait(); + req_recv.wait(); + } +} diff --git a/test/test_send_recv.cpp b/test/test_send_recv.cpp index 0cfd1170..37e1c18d 100644 --- a/test/test_send_recv.cpp +++ b/test/test_send_recv.cpp @@ -20,9 +20,16 @@ #define NTHREADS 4 std::vector> shared_received(NTHREADS); -thread_local int thread_id; +thread_local int thread_id; -void reset_counters() +bool +is_nccl_backend(oomph::context const& ctxt) +{ + return ctxt.get_transport_option("name") == std::string("nccl"); +} + +void +reset_counters() { for (auto& x : shared_received) x.store(0); } @@ -192,7 +199,7 @@ launch_test(Func f) } // multi threaded - { + try { oomph::context ctxt(MPI_COMM_WORLD, true); std::vector threads; threads.reserve(NTHREADS); @@ -205,6 +212,12 @@ launch_test(Func f) for (int i = 0; i < NTHREADS; ++i) threads.push_back(std::thread{f, std::ref(ctxt), SIZE, i, NTHREADS, true}); for (auto& t : threads) t.join(); + } catch (std::runtime_error const& e) { + if (is_nccl_backend(oomph::context(MPI_COMM_WORLD, false))) { + EXPECT_EQ(e.what(), std::string("NCCL not supported with thread_safe = true")); + } else { + throw; + } } } @@ -219,8 +232,10 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, // use is_ready() -> must manually progress the communicator for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); auto sreq = env.comm.send(env.smsg, env.speer_rank, env.tag); + env.comm.end_group(); while (!(rreq.is_ready() && sreq.is_ready())) { env.comm.progress(); @@ -232,8 +247,10 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, // use test() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); auto sreq = env.comm.send(env.smsg, env.speer_rank, env.tag); + env.comm.end_group(); 
while (!(rreq.test() && sreq.test())) {}; EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -242,8 +259,11 @@ test_send_recv(oomph::context& ctxt, std::size_t size, int tid, int num_threads, // use wait() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rreq = env.comm.recv(env.rmsg, env.rpeer_rank, env.tag); - env.comm.send(env.smsg, env.speer_rank, env.tag).wait(); + auto sreq = env.comm.send(env.smsg, env.speer_rank, env.tag); + env.comm.end_group(); + sreq.wait(); rreq.wait(); EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -268,7 +288,7 @@ test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa using tag_type = test_environment::tag_type; using message = test_environment::message; - Env env(ctxt, size, tid, num_threads, user_alloc); + Env env(ctxt, size, tid, num_threads, user_alloc); volatile int received = 0; volatile int sent = 0; @@ -279,8 +299,10 @@ test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa // use is_ready() -> must manually progress the communicator for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.recv(env.rmsg, env.rpeer_rank, 1, recv_callback); auto sh = env.comm.send(env.smsg, env.speer_rank, 1, send_callback); + env.comm.end_group(); while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -293,8 +315,10 @@ test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa // use test() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.recv(env.rmsg, env.rpeer_rank, 1, recv_callback); auto sh = env.comm.send(env.smsg, env.speer_rank, 1, send_callback); + env.comm.end_group(); while (!rh.test() || !sh.test()) {} EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -307,8 +331,11 @@ 
test_send_recv_cb(oomph::context& ctxt, std::size_t size, int tid, int num_threa // use wait() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.recv(env.rmsg, env.rpeer_rank, 1, recv_callback); - env.comm.send(env.smsg, env.speer_rank, 1, send_callback).wait(); + auto sh = env.comm.send(env.smsg, env.speer_rank, 1, send_callback); + env.comm.end_group(); + sh.wait(); rh.wait(); EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -336,7 +363,7 @@ test_send_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int nu using tag_type = test_environment::tag_type; using message = test_environment::message; - Env env(ctxt, size, tid, num_threads, user_alloc); + Env env(ctxt, size, tid, num_threads, user_alloc); volatile int received = 0; volatile int sent = 0; @@ -355,8 +382,10 @@ test_send_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int nu // use is_ready() -> must manually progress the communicator for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -369,8 +398,10 @@ test_send_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, int nu // use test() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); while (!rh.test() || !sh.test()) {} EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -383,8 +414,11 @@ test_send_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, 
int nu // use wait() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); - env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback).wait(); + auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); + sh.wait(); rh.wait(); EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -412,7 +446,7 @@ test_send_shared_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, using tag_type = test_environment::tag_type; using message = test_environment::message; - Env env(ctxt, size, tid, num_threads, user_alloc); + Env env(ctxt, size, tid, num_threads, user_alloc); thread_id = env.thread_id; @@ -436,8 +470,10 @@ test_send_shared_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, // use is_ready() -> must manually progress the communicator for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); while (!rh.is_ready() || !sh.is_ready()) { env.comm.progress(); } EXPECT_TRUE(env.rmsg); EXPECT_TRUE(env.check_recv_buffer()); @@ -451,8 +487,10 @@ test_send_shared_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, // use test() -> communicator is progressed automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); while (!rh.test() || !sh.test()) {} EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -465,8 +503,11 @@ test_send_shared_recv_cb_disown(oomph::context& ctxt, std::size_t size, int tid, // use wait() -> communicator is progressed 
automatically for (int i = 0; i < NITERS; i++) { + env.comm.start_group(); auto rh = env.comm.shared_recv(std::move(env.rmsg), env.rpeer_rank, 1, recv_callback); - env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback).wait(); + auto sh = env.comm.send(std::move(env.smsg), env.speer_rank, 1, send_callback); + env.comm.end_group(); + sh.wait(); rh.wait(); EXPECT_TRUE(env.check_recv_buffer()); env.fill_recv_buffer(); @@ -494,7 +535,7 @@ test_send_recv_cb_resubmit(oomph::context& ctxt, std::size_t size, int tid, int using tag_type = test_environment::tag_type; using message = test_environment::message; - Env env(ctxt, size, tid, num_threads, user_alloc); + Env env(ctxt, size, tid, num_threads, user_alloc); volatile int received = 0; volatile int sent = 0; @@ -525,8 +566,10 @@ test_send_recv_cb_resubmit(oomph::context& ctxt, std::size_t size, int tid, int } }; + env.comm.start_group(); env.comm.recv(env.rmsg, env.rpeer_rank, 1, recursive_recv_callback{env, received}); env.comm.send(env.smsg, env.speer_rank, 1, recursive_send_callback{env, sent}); + env.comm.end_group(); while (sent < NITERS || received < NITERS) { env.comm.progress(); }; } @@ -584,8 +627,10 @@ test_send_recv_cb_resubmit_disown(oomph::context& ctxt, std::size_t size, int ti } }; + env.comm.start_group(); env.comm.recv(std::move(env.rmsg), env.rpeer_rank, 1, recursive_recv_callback{env, received}); env.comm.send(std::move(env.smsg), env.speer_rank, 1, recursive_send_callback{env, sent}); + env.comm.end_group(); while (sent < NITERS || received < NITERS) { env.comm.progress(); }; }