diff --git a/.clang-tidy b/.clang-tidy index d7a7da71e..edef98736 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -20,7 +20,11 @@ Checks: > -modernize-macro-to-enum, google-readability-todo -#WarningsAsErrors: '*' +WarningsAsErrors: > + *, + -clang-diagnostic-unused-command-line-argument, + -clang-diagnostic-ignored-optimization-argument + HeaderFilterRegex: '.*\/include\/mrc\/.*' AnalyzeTemporaryDtors: false FormatStyle: file diff --git a/.devcontainer/conda/Dockerfile b/.devcontainer/conda/Dockerfile index d1ffbce77..62c801dd2 100644 --- a/.devcontainer/conda/Dockerfile +++ b/.devcontainer/conda/Dockerfile @@ -13,6 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM rapidsai/devcontainers:23.04-cuda11.8-mambaforge-ubuntu22.04 AS base +FROM rapidsai/devcontainers:23.04-cuda12.1-mambaforge-ubuntu22.04 AS base ENV PATH="${PATH}:/workspaces/mrc/.devcontainer/bin" diff --git a/.devcontainer/opt/mrc/bin/post-attach-command.sh b/.devcontainer/opt/mrc/bin/post-attach-command.sh index eb00a5061..4af1fe68e 100755 --- a/.devcontainer/opt/mrc/bin/post-attach-command.sh +++ b/.devcontainer/opt/mrc/bin/post-attach-command.sh @@ -1,5 +1,5 @@ #!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,6 @@ sed -ri "s/conda activate base/conda activate $ENV_NAME/g" ~/.bashrc; if conda_env_find "${ENV_NAME}" ; \ -then mamba env update --name ${ENV_NAME} -f ${MRC_ROOT}/ci/conda/environments/dev_env.yml --prune; \ -else mamba env create --name ${ENV_NAME} -f ${MRC_ROOT}/ci/conda/environments/dev_env.yml; \ +then mamba env update --name ${ENV_NAME} -f ${MRC_ROOT}/conda/environments/all_cuda-125_arch-x86_64.yaml --prune; \ +else mamba env create --name ${ENV_NAME} -f ${MRC_ROOT}/conda/environments/all_cuda-125_arch-x86_64.yaml; \ fi diff --git a/.devcontainer/opt/mrc/conda/Dockerfile b/.devcontainer/opt/mrc/conda/Dockerfile index d1ffbce77..62c801dd2 100644 --- a/.devcontainer/opt/mrc/conda/Dockerfile +++ b/.devcontainer/opt/mrc/conda/Dockerfile @@ -13,6 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM rapidsai/devcontainers:23.04-cuda11.8-mambaforge-ubuntu22.04 AS base +FROM rapidsai/devcontainers:23.04-cuda12.1-mambaforge-ubuntu22.04 AS base ENV PATH="${PATH}:/workspaces/mrc/.devcontainer/bin" diff --git a/.github/copy-pr-bot.yaml b/.github/copy-pr-bot.yaml new file mode 100644 index 000000000..895ba83ee --- /dev/null +++ b/.github/copy-pr-bot.yaml @@ -0,0 +1,4 @@ +# Configuration file for `copy-pr-bot` GitHub App +# https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/ + +enabled: true diff --git a/.github/ops-bot.yaml b/.github/ops-bot.yaml index 2ef41b367..1e59002c6 100644 --- a/.github/ops-bot.yaml +++ b/.github/ops-bot.yaml @@ -5,5 +5,4 @@ auto_merger: true branch_checker: true label_checker: true release_drafter: true -copy_prs: true -rerun_tests: true +forward_merger: true diff --git a/.github/workflows/ci_pipe.yml b/.github/workflows/ci_pipe.yml index 0c6da79f4..189a098e7 100644 --- a/.github/workflows/ci_pipe.yml +++ b/.github/workflows/ci_pipe.yml @@ -1,4 +1,4 @@ -# 
SPDX-FileCopyrightText: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -61,6 +61,7 @@ env: GH_TOKEN: "${{ github.token }}" GIT_COMMIT: "${{ github.sha }}" MRC_ROOT: "${{ github.workspace }}/mrc" + RAPIDS_CONDA_RETRY_MAX: "5" WORKSPACE: "${{ github.workspace }}/mrc" WORKSPACE_TMP: "${{ github.workspace }}/tmp" @@ -294,7 +295,7 @@ jobs: run: ./mrc/ci/scripts/github/benchmark.sh - name: post_benchmark shell: bash - run: ./mrc/ci/scripts/github/benchmark.sh + run: ./mrc/ci/scripts/github/post_benchmark.sh package: diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pr.yaml similarity index 78% rename from .github/workflows/pull_request.yml rename to .github/workflows/pr.yaml index f10b02fea..6f36c3754 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pr.yaml @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -43,11 +43,18 @@ permissions: statuses: none jobs: + pr-builder: + needs: + - checks + - prepare + - ci_pipe + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.10 prepare: name: Prepare runs-on: ubuntu-latest container: - image: rapidsai/ci:latest + image: rapidsai/ci-conda:latest steps: - name: Get PR Info id: get-pr-info @@ -58,11 +65,22 @@ jobs: is_main_branch: ${{ github.ref_name == 'main' }} is_dev_branch: ${{ startsWith(github.ref_name, 'branch-') }} has_conda_build_label: ${{ steps.get-pr-info.outcome == 'success' && contains(fromJSON(steps.get-pr-info.outputs.pr-info).labels.*.name, 'conda-build') || false }} + has_skip_ci_label: ${{ steps.get-pr-info.outcome == 'success' && contains(fromJSON(steps.get-pr-info.outputs.pr-info).labels.*.name, 'skip-ci') || false }} pr_info: ${{ steps.get-pr-info.outcome == 'success' && steps.get-pr-info.outputs.pr-info || '' }} + + checks: + needs: [prepare] + if: ${{ !fromJSON(needs.prepare.outputs.has_skip_ci_label) && fromJSON(needs.prepare.outputs.is_pr )}} + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.10 + with: + enable_check_generated_files: false + ci_pipe: name: CI Pipeline needs: [prepare] uses: ./.github/workflows/ci_pipe.yml + if: ${{ ! fromJSON(needs.prepare.outputs.has_skip_ci_label) }} with: # Run checks for any PR branch run_check: ${{ fromJSON(needs.prepare.outputs.is_pr) }} @@ -71,9 +89,9 @@ jobs: # Update conda package only for non PR branches. 
Use 'main' for main branch and 'dev' for all other branches conda_upload_label: ${{ !fromJSON(needs.prepare.outputs.is_pr) && (fromJSON(needs.prepare.outputs.is_main_branch) && 'main' || 'dev') || '' }} # Build container - container: nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-build-230711 + container: nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-build-241002 # Test container - test_container: nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-test-230711 + test_container: nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-test-241002 # Info about the PR. Empty for non PR branches. Useful for extracting PR number, title, etc. pr_info: ${{ needs.prepare.outputs.pr_info }} secrets: diff --git a/.gitignore b/.gitignore index 1a20325a2..53c9f38e0 100755 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /build*/ +.tmp *.engine .Dockerfile .gitignore @@ -17,6 +18,9 @@ include/mrc/version.hpp .vscode/settings.json .vscode/tasks.json +# Ignore user-defined clangd settings +.clangd + # Created by https://www.gitignore.io/api/vim,c++,cmake,python,synology ### C++ ### diff --git a/.gitmodules b/.gitmodules index 76d78c90c..fc54a6f5a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "morpheus_utils"] path = external/utilities url = https://github.com/nv-morpheus/utilities.git - branch = branch-23.07 + branch = branch-24.10 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..2ff37ad1a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. 
+ +repos: + - repo: https://github.com/rapidsai/dependency-file-generator + rev: v1.13.11 + hooks: + - id: rapids-dependency-file-generator + args: ["--clean"] + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--settings-file=./python/setup.cfg"] + files: ^python/ + - repo: https://github.com/PyCQA/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + args: ["--config=./python/setup.cfg"] + files: ^python/ + - repo: https://github.com/google/yapf + rev: v0.40.2 + hooks: + - id: yapf + args: ["--style", "./python/setup.cfg"] + files: ^python/ + +default_language_version: + python: python3 diff --git a/CHANGELOG.md b/CHANGELOG.md index c10b0f7d1..e1d2610a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,111 @@ + + +# MRC 24.06.00 (03 Jul 2024) + +## 🚀 New Features + +- Add JSONValues container for holding Python values as JSON objects if possible, and as pybind11::object otherwise ([#455](https://github.com/nv-morpheus/MRC/pull/455)) [@dagardner-nv](https://github.com/dagardner-nv) + +## 🛠️ Improvements + +- resolve rapids-dependency-file-generator warning ([#482](https://github.com/nv-morpheus/MRC/pull/482)) [@jameslamb](https://github.com/jameslamb) +- Downgrade doxygen to match Morpheus ([#469](https://github.com/nv-morpheus/MRC/pull/469)) [@cwharris](https://github.com/cwharris) +- Consolidate redundant split_string_to_array, split_string_on & split_path methods ([#465](https://github.com/nv-morpheus/MRC/pull/465)) [@dagardner-nv](https://github.com/dagardner-nv) +- Add pybind11 type caster for JSONValues ([#458](https://github.com/nv-morpheus/MRC/pull/458)) [@dagardner-nv](https://github.com/dagardner-nv) + +# MRC 24.03.01 (16 Apr 2024) + +## 🐛 Bug Fixes + +- Add auto register helpers to AsyncSink and AsyncSource ([#473](https://github.com/nv-morpheus/MRC/pull/473)) [@dagardner-nv](https://github.com/dagardner-nv) + +# MRC 24.03.00 (7 Apr 2024) + +## 🚨 Breaking Changes + +- Update cast_from_pyobject to throw on unsupported 
types rather than returning null ([#451](https://github.com/nv-morpheus/MRC/pull/451)) [@dagardner-nv](https://github.com/dagardner-nv) +- RAPIDS 24.02 Upgrade ([#433](https://github.com/nv-morpheus/MRC/pull/433)) [@cwharris](https://github.com/cwharris) + +## 🐛 Bug Fixes + +- Update CR year ([#460](https://github.com/nv-morpheus/MRC/pull/460)) [@dagardner-nv](https://github.com/dagardner-nv) +- Removing the INFO log when creating an AsyncioRunnable ([#456](https://github.com/nv-morpheus/MRC/pull/456)) [@mdemoret-nv](https://github.com/mdemoret-nv) +- Update cast_from_pyobject to throw on unsupported types rather than returning null ([#451](https://github.com/nv-morpheus/MRC/pull/451)) [@dagardner-nv](https://github.com/dagardner-nv) +- Adopt updated builds of CI runners ([#442](https://github.com/nv-morpheus/MRC/pull/442)) [@dagardner-nv](https://github.com/dagardner-nv) +- Update Conda channels to prioritize `conda-forge` over `nvidia` ([#436](https://github.com/nv-morpheus/MRC/pull/436)) [@cwharris](https://github.com/cwharris) +- Remove redundant copy of libmrc_pymrc.so ([#429](https://github.com/nv-morpheus/MRC/pull/429)) [@dagardner-nv](https://github.com/dagardner-nv) +- Unifying cmake exports name across all Morpheus repos ([#427](https://github.com/nv-morpheus/MRC/pull/427)) [@mdemoret-nv](https://github.com/mdemoret-nv) +- Updating the workspace settings to remove deprecated python options ([#425](https://github.com/nv-morpheus/MRC/pull/425)) [@mdemoret-nv](https://github.com/mdemoret-nv) +- Use `dependencies.yaml` to generate environment files ([#416](https://github.com/nv-morpheus/MRC/pull/416)) [@cwharris](https://github.com/cwharris) + +## 📖 Documentation + +- Update minimum requirements ([#467](https://github.com/nv-morpheus/MRC/pull/467)) [@dagardner-nv](https://github.com/dagardner-nv) + +## 🚀 New Features + +- Add maximum simultaneous tasks support to `TaskContainer` ([#464](https://github.com/nv-morpheus/MRC/pull/464)) 
[@cwharris](https://github.com/cwharris) +- Add `TestScheduler` to support testing time-based coroutines without waiting for timeouts ([#453](https://github.com/nv-morpheus/MRC/pull/453)) [@cwharris](https://github.com/cwharris) +- Adding RoundRobinRouter node type for distributing values to downstream nodes ([#449](https://github.com/nv-morpheus/MRC/pull/449)) [@mdemoret-nv](https://github.com/mdemoret-nv) +- Add IoScheduler to enable epoll-based Task scheduling ([#448](https://github.com/nv-morpheus/MRC/pull/448)) [@cwharris](https://github.com/cwharris) +- Update ops-bot.yaml ([#446](https://github.com/nv-morpheus/MRC/pull/446)) [@AyodeAwe](https://github.com/AyodeAwe) +- RAPIDS 24.02 Upgrade ([#433](https://github.com/nv-morpheus/MRC/pull/433)) [@cwharris](https://github.com/cwharris) + +## 🛠️ Improvements + +- Update MRC to use CCCL instead of libcudacxx ([#444](https://github.com/nv-morpheus/MRC/pull/444)) [@cwharris](https://github.com/cwharris) +- Optionally skip the CI pipeline if the PR contains the skip-ci label ([#426](https://github.com/nv-morpheus/MRC/pull/426)) [@dagardner-nv](https://github.com/dagardner-nv) +- Add flake8, yapf, and isort pre-commit hooks. 
([#420](https://github.com/nv-morpheus/MRC/pull/420)) [@cwharris](https://github.com/cwharris) + +# MRC 23.11.00 (30 Nov 2023) + +## 🐛 Bug Fixes + +- Use a traditional semaphore in AsyncioRunnable ([#412](https://github.com/nv-morpheus/MRC/pull/412)) [@cwharris](https://github.com/cwharris) +- Fix libhwloc & stubgen versions to match dev yaml ([#405](https://github.com/nv-morpheus/MRC/pull/405)) [@dagardner-nv](https://github.com/dagardner-nv) +- Update boost versions to match version used in dev env ([#404](https://github.com/nv-morpheus/MRC/pull/404)) [@dagardner-nv](https://github.com/dagardner-nv) +- Fix EdgeHolder from incorrectly reporting an active connection ([#402](https://github.com/nv-morpheus/MRC/pull/402)) [@dagardner-nv](https://github.com/dagardner-nv) +- Safe handling of control plane promises & fix CI ([#391](https://github.com/nv-morpheus/MRC/pull/391)) [@dagardner-nv](https://github.com/dagardner-nv) +- Revert boost upgrade, and update clang to v16 ([#382](https://github.com/nv-morpheus/MRC/pull/382)) [@dagardner-nv](https://github.com/dagardner-nv) +- Fixing an issue with `update-versions.sh` which always blocked CI ([#377](https://github.com/nv-morpheus/MRC/pull/377)) [@mdemoret-nv](https://github.com/mdemoret-nv) +- Add test for gc being invoked in a thread finalizer ([#365](https://github.com/nv-morpheus/MRC/pull/365)) [@dagardner-nv](https://github.com/dagardner-nv) +- Adopt patched pybind11 ([#364](https://github.com/nv-morpheus/MRC/pull/364)) [@dagardner-nv](https://github.com/dagardner-nv) + +## 📖 Documentation + +- Add missing flags to docker command to mount the working dir and set -cap-add=sys_nice ([#383](https://github.com/nv-morpheus/MRC/pull/383)) [@dagardner-nv](https://github.com/dagardner-nv) +- Make Quick Start Guide not use `make_node_full` ([#376](https://github.com/nv-morpheus/MRC/pull/376)) [@cwharris](https://github.com/cwharris) + +## 🚀 New Features + +- Add AsyncioRunnable 
([#411](https://github.com/nv-morpheus/MRC/pull/411)) [@cwharris](https://github.com/cwharris) +- Adding more coroutine components to support async generators and task containers ([#408](https://github.com/nv-morpheus/MRC/pull/408)) [@mdemoret-nv](https://github.com/mdemoret-nv) +- Update ObservableProxy::pipe to support any number of operators ([#387](https://github.com/nv-morpheus/MRC/pull/387)) [@cwharris](https://github.com/cwharris) +- Updates for MRC/Morpheus to build in the same RAPIDS devcontainer environment ([#375](https://github.com/nv-morpheus/MRC/pull/375)) [@cwharris](https://github.com/cwharris) + +## 🛠️ Improvements + +- Move Pycoro from Morpheus to MRC ([#409](https://github.com/nv-morpheus/MRC/pull/409)) [@cwharris](https://github.com/cwharris) +- update rapidsai/ci to rapidsai/ci-conda ([#396](https://github.com/nv-morpheus/MRC/pull/396)) [@AyodeAwe](https://github.com/AyodeAwe) +- Add local CI scripts & rebase docker image ([#394](https://github.com/nv-morpheus/MRC/pull/394)) [@dagardner-nv](https://github.com/dagardner-nv) +- Use `copy-pr-bot` ([#369](https://github.com/nv-morpheus/MRC/pull/369)) [@ajschmidt8](https://github.com/ajschmidt8) +- Update Versions for v23.11.00 ([#357](https://github.com/nv-morpheus/MRC/pull/357)) [@mdemoret-nv](https://github.com/mdemoret-nv) + # MRC 23.07.00 (19 Jul 2023) ## 🚨 Breaking Changes diff --git a/CMakeLists.txt b/CMakeLists.txt index 4f0f92c91..c4ff438bf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -30,13 +30,14 @@ option(MRC_BUILD_PYTHON "Enable building the python bindings for MRC" ON) option(MRC_BUILD_TESTS "Whether or not to build MRC tests" ON) option(MRC_ENABLE_CODECOV "Enable gcov code coverage" OFF) option(MRC_ENABLE_DEBUG_INFO "Enable printing debug information" OFF) +option(MRC_PYTHON_INPLACE_BUILD "Whether or not to copy built python modules back to the source tree for debug purposes." OFF) option(MRC_USE_CCACHE "Enable caching compilation results with ccache" OFF) option(MRC_USE_CLANG_TIDY "Enable running clang-tidy as part of the build process" OFF) option(MRC_USE_CONDA "Enables finding dependencies via conda. All dependencies must be installed first in the conda environment" ON) option(MRC_USE_IWYU "Enable running include-what-you-use as part of the build process" OFF) -set(MRC_RAPIDS_VERSION "23.06" CACHE STRING "Which version of RAPIDS to build for. Sets default versions for RAPIDS CMake and RMM.") +set(MRC_RAPIDS_VERSION "24.10" CACHE STRING "Which version of RAPIDS to build for. Sets default versions for RAPIDS CMake and RMM.") set(MRC_CACHE_DIR "${CMAKE_SOURCE_DIR}/.cache" CACHE PATH "Directory to contain all CPM and CCache data") mark_as_advanced(MRC_CACHE_DIR) @@ -78,10 +79,12 @@ morpheus_utils_initialize_package_manager( morpheus_utils_initialize_cuda_arch(mrc) project(mrc - VERSION 23.07.00 + VERSION 24.10.00 LANGUAGES C CXX ) +morpheus_utils_initialize_install_prefix(MRC_USE_CONDA) + rapids_cmake_write_version_file(${CMAKE_BINARY_DIR}/autogenerated/include/mrc/version.hpp) # Delay enabling CUDA until after we have determined our CXX compiler diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b437ba321..bd755b289 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,7 +85,7 @@ cd $MRC_ROOT #### Create MRC Conda environment ```bash # note: `mamba` may be used in place of `conda` for better performance. 
-conda env create -n mrc --file $MRC_ROOT/ci/conda/environments/dev_env.yml +conda env create -n mrc --file $MRC_ROOT/conda/environments/all_cuda-125_arch-x86_64.yaml conda activate mrc ``` #### Build MRC diff --git a/Dockerfile b/Dockerfile index cae834533..f2df5d0fb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ # syntax=docker/dockerfile:1.3 -# SPDX-FileCopyrightText: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,39 +16,49 @@ # limitations under the License. -ARG FROM_IMAGE="rapidsai/ci" -ARG CUDA_VER=11.8.0 +ARG FROM_IMAGE="rapidsai/ci-conda" +ARG CUDA_VER=12.5.1 ARG LINUX_DISTRO=ubuntu -ARG LINUX_VER=20.04 +ARG LINUX_VER=22.04 ARG PYTHON_VER=3.10 # ============= base =================== -FROM ${FROM_IMAGE}:cuda11.8.0-ubuntu20.04-py3.10 AS base +FROM ${FROM_IMAGE}:cuda${CUDA_VER}-${LINUX_DISTRO}${LINUX_VER}-py${PYTHON_VER} AS base ARG PROJ_NAME=mrc +ARG USERNAME=morpheus +ARG USER_UID=1000 +ARG USER_GID=$USER_UID SHELL ["/bin/bash", "-c"] RUN --mount=type=cache,target=/var/cache/apt \ apt update &&\ apt install --no-install-recommends -y \ - libnuma1 && \ + libnuma1 \ + sudo && \ rm -rf /var/lib/apt/lists/* -COPY ./ci/conda/environments/* /opt/mrc/conda/environments/ +# create a user inside the container +RUN useradd --uid $USER_UID --gid $USER_GID -m $USERNAME && \ + usermod --shell /bin/bash $USERNAME && \ + echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME && \ + chmod 0440 /etc/sudoers.d/$USERNAME + +COPY ./conda/environments/all_cuda-125_arch-x86_64.yaml /opt/mrc/conda/environments/all_cuda-125_arch-x86_64.yaml RUN --mount=type=cache,target=/opt/conda/pkgs,sharing=locked \ echo "create env: ${PROJ_NAME}" && \ + sudo -g conda -u $USERNAME \ CONDA_ALWAYS_YES=true \ - /opt/conda/bin/mamba 
env create -q -n ${PROJ_NAME} --file /opt/mrc/conda/environments/dev_env.yml && \ - /opt/conda/bin/mamba env update -q -n ${PROJ_NAME} --file /opt/mrc/conda/environments/clang_env.yml && \ - /opt/conda/bin/mamba env update -q -n ${PROJ_NAME} --file /opt/mrc/conda/environments/ci_env.yml && \ + /opt/conda/bin/mamba env create -q -n ${PROJ_NAME} --file /opt/mrc/conda/environments/all_cuda-125_arch-x86_64.yaml && \ chmod -R a+rwX /opt/conda && \ rm -rf /tmp/conda RUN /opt/conda/bin/conda init --system &&\ sed -i 's/xterm-color)/xterm-color|*-256color)/g' ~/.bashrc &&\ - echo "conda activate ${PROJ_NAME}" >> ~/.bashrc + echo "conda activate ${PROJ_NAME}" >> ~/.bashrc && \ + cp /root/.bashrc /home/$USERNAME/.bashrc # disable sscache wrappers around compilers ENV CMAKE_CUDA_COMPILER_LAUNCHER= @@ -78,7 +88,6 @@ RUN --mount=type=cache,target=/var/cache/apt \ less \ openssh-client \ psmisc \ - sudo \ vim-tiny \ && \ rm -rf /var/lib/apt/lists/* @@ -93,17 +102,6 @@ RUN --mount=type=cache,target=/var/cache/apt \ apt-get install --no-install-recommends -y dotnet-sdk-6.0 &&\ rm -rf /var/lib/apt/lists/* -# create a user inside the container -ARG USERNAME=morpheus -ARG USER_UID=1000 -ARG USER_GID=$USER_UID - -RUN useradd --uid $USER_UID --gid $USER_GID -m $USERNAME && \ - usermod --shell /bin/bash $USERNAME && \ - echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME && \ - chmod 0440 /etc/sudoers.d/$USERNAME && \ - cp /root/.bashrc /home/$USERNAME/.bashrc - USER $USERNAME # default working directory diff --git a/README.md b/README.md index 0f05e754a..3baca24c3 100644 --- a/README.md +++ b/README.md @@ -38,8 +38,8 @@ MRC includes both Python and C++ bindings and supports installation via [conda]( ### Prerequisites -- Pascal architecture (Compute capability 6.0) or better -- NVIDIA driver `450.80.02` or higher +- Volta architecture (Compute capability 7.0) or better +- [CUDA 12.1](https://developer.nvidia.com/cuda-12-1-0-download-archive) - [conda or 
miniconda](https://conda.io/projects/conda/en/latest/user-guide/install/linux.html) - If using Docker: - [Docker](https://docs.docker.com/get-docker/) @@ -118,7 +118,7 @@ cd $MRC_ROOT #### Create MRC Conda Environment ```bash # note: `mamba` may be used in place of `conda` for better performance. -conda env create -n mrc-dev --file $MRC_ROOT/ci/conda/environments/dev_env.yml +conda env create -n mrc-dev --file $MRC_ROOT/conda/environments/all_cuda-125_arch-x86_64.yaml conda activate mrc-dev ``` @@ -151,13 +151,17 @@ pytest $MRC_ROOT/python ### Docker Installation A Dockerfile is provided at `$MRC_ROOT` and can be built with ```bash -docker build -t mrc:latest . +DOCKER_BUILDKIT=1 docker build -t mrc:latest . ``` To run the container ```bash -docker run --gpus all --rm -it mrc:latest /bin/bash +docker run --gpus all --cap-add=sys_nice -v $PWD:/work --rm -it mrc:latest /bin/bash ``` +> **Note:** +> Users wishing to debug MRC in a Docker container should add the following to the `docker run` command: +> `--cap-add=SYS_PTRACE` + ## Quickstart Guide To quickly learn about both the C++ and Python MRC APIs, including following along with various complexity examples, we recommend following the MRC Quickstart Repository located [here](/docs/quickstart/README.md). This tutorial walks new users through topics like diff --git a/ci/conda/environments/ci_env.yml b/ci/check_style.sh old mode 100644 new mode 100755 similarity index 51% rename from ci/conda/environments/ci_env.yml rename to ci/check_style.sh index ad05425dd..4ed7a3bf6 --- a/ci/conda/environments/ci_env.yml +++ b/ci/check_style.sh @@ -1,4 +1,5 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,10 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Additional dependencies only needed during a CI build -name: mrc -channels: - - conda-forge -dependencies: - - codecov=2.1 - - conda-merge>=0.2 +set -euo pipefail + +rapids-logger "Create checks conda environment" +. /opt/conda/etc/profile.d/conda.sh + +rapids-dependency-file-generator \ + --output conda \ + --file-key checks \ + --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml + +rapids-mamba-retry env create --yes -f env.yaml -n checks +conda activate checks + +# Run pre-commit checks +pre-commit run --all-files --show-diff-on-failure diff --git a/ci/conda/environments/clang_env.yml b/ci/conda/environments/clang_env.yml deleted file mode 100644 index 9c8867ae4..000000000 --- a/ci/conda/environments/clang_env.yml +++ /dev/null @@ -1,29 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Additional dependencies needed for clang, assumes dependencies from `dev_env.yml` -# or `dev_env_nogcc.yml` has already been installed -name: mrc -channels: - - conda-forge -dependencies: - - clang=15 - - clang-tools=15 - - clangdev=15 - - clangxx=15 - - libclang=15 - - libclang-cpp=15 - - llvmdev=15 - - include-what-you-use=0.19 diff --git a/ci/conda/environments/dev_env.yml b/ci/conda/environments/dev_env.yml deleted file mode 100644 index 58d83d9a7..000000000 --- a/ci/conda/environments/dev_env.yml +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Dependencies needed for development environment. 
Runtime deps are in meta.yml -name: mrc -channels: - - rapidsai - - nvidia/label/cuda-11.8.0 - - nvidia - - rapidsai-nightly - - conda-forge -dependencies: - - autoconf>=2.69 - - bash-completion - - benchmark=1.6.0 - - boost-cpp=1.74 - - ccache - - cmake=3.24 - - cuda-toolkit # Version comes from the channel above - - cxx-compiler # Sets up the distro versions of our compilers - - doxygen=1.9.2 - - flake8 - - flatbuffers=2.0 - - gcovr=5.0 - - gdb - - gflags=2.2 - - git>=2.35.3 # Needed for wildcards on safe.directory - - glog=0.6 - - gmock=1.13 - - graphviz=3.0 - - libgrpc=1.54.0 - - gtest=1.13 - - gxx=11.2 # Specifies which versions of GXX and GCC to use - - isort - - jinja2=3.0 - - lcov=1.15 - - libhwloc=2.5 - - libprotobuf=3.21 - - librmm=23.06 - - libtool - - ninja=1.10 - - nlohmann_json=3.9 - - numactl-libs-cos7-x86_64 - - numpy>=1.21 - - pip - - pkg-config=0.29 - - pybind11-stubgen=0.10 - - pytest - - pytest-timeout - - python=3.10 - - scikit-build>=0.17 - - sysroot_linux-64=2.17 - - ucx=1.14 - - yapf - - # Remove once `mamba repoquery whoneeds cudatoolkit` is empty. For now, we need to specify a version - - cudatoolkit=11.8 diff --git a/ci/conda/recipes/libmrc/build.sh b/ci/conda/recipes/libmrc/build.sh index 3bdbf295f..3b9a469e8 100644 --- a/ci/conda/recipes/libmrc/build.sh +++ b/ci/conda/recipes/libmrc/build.sh @@ -62,7 +62,6 @@ CMAKE_ARGS="-DMRC_RAPIDS_VERSION=${rapids_version} ${CMAKE_ARGS}" CMAKE_ARGS="-DMRC_USE_CCACHE=OFF ${CMAKE_ARGS}" CMAKE_ARGS="-DMRC_USE_CONDA=ON ${CMAKE_ARGS}" CMAKE_ARGS="-DPython_EXECUTABLE=${PYTHON} ${CMAKE_ARGS}" -CMAKE_ARGS="-DUCX_VERSION=${ucx} ${CMAKE_ARGS}" echo "CC : ${CC}" echo "CXX : ${CXX}" diff --git a/ci/conda/recipes/libmrc/conda_build_config.yaml b/ci/conda/recipes/libmrc/conda_build_config.yaml index 008688e98..0ab4a5dd9 100644 --- a/ci/conda/recipes/libmrc/conda_build_config.yaml +++ b/ci/conda/recipes/libmrc/conda_build_config.yaml @@ -14,71 +14,20 @@ # limitations under the License. 
c_compiler_version: - - 11.2 + - 12.1 cxx_compiler_version: - - 11.2 + - 12.1 cuda_compiler: - cuda-nvcc cuda_compiler_version: - - 11.8 + - 12.5 python: - - 3.8 - - 3.10 - 3.10 # Setup the dependencies to build with multiple versions of RAPIDS rapids_version: # Keep around compatibility with current version -2 - - 23.02 - - 23.04 - - 23.06 - -# Multiple versions of abseil are required to satisfy the solver for some -# environments. RAPIDS 22.06 only works with gRPC 1.45 and 22.08 only works with -# 1.46. For each version of gRPC, support 2 abseil versions. Zip all of the keys -# together to avoid impossible combinations -libabseil: - - 20230125.0 - - 20230125.0 - - 20230125.0 - -libgrpc: - - 1.51 - - 1.51 - - 1.54 - -ucx: - - 1.13 - - 1.14 - - 1.14 - -libprotobuf: - - 3.21 - - 3.21 - - 3.21 - -zip_keys: - - python - - rapids_version - - libabseil - - libgrpc - - ucx - - libprotobuf - -# The following mimic what is available in the pinning feedstock: -# https://github.com/conda-forge/conda-forge-pinning-feedstock/blob/main/recipe/conda_build_config.yaml -boost: - - 1.74.0 -boost_cpp: - - 1.74.0 -gflags: - - 2.2 -glog: - - 0.6 - -pin_run_as_build: - boost-cpp: - max_pin: x.x + - 24.10 diff --git a/ci/conda/recipes/libmrc/meta.yaml b/ci/conda/recipes/libmrc/meta.yaml index 6abbd7c19..2664ba474 100644 --- a/ci/conda/recipes/libmrc/meta.yaml +++ b/ci/conda/recipes/libmrc/meta.yaml @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,10 +14,8 @@ # limitations under the License. {% set version = environ.get('GIT_VERSION', '0.0.0.dev').lstrip('v') + environ.get('VERSION_SUFFIX', '') %} -{% set minor_version = version.split('.')[0] + '.' 
+ version.split('.')[1] %} {% set py_version = environ.get('CONDA_PY', '3.10') %} -{% set cuda_version = '.'.join(environ.get('CUDA', '11.8').split('.')[:2]) %} -{% set cuda_major = cuda_version.split('.')[0] %} +{% set cuda_version = '.'.join(environ.get('CUDA', '12.5').split('.')[:2]) %} package: name: libmrc-split @@ -40,39 +38,32 @@ requirements: - {{ compiler("c") }} - {{ compiler("cuda") }} - {{ compiler("cxx") }} - - autoconf >=2.69 - ccache - - cmake >=3.24 - - cuda-cudart-dev # Needed by CMake to compile a test application + - cmake =3.27 - libtool - - ninja - - numactl-libs-cos7-x86_64 - - pkg-config 0.29.* - - sysroot_linux-64 >=2.17 + - ninja =1.11 + - numactl =2.0.18 + - pkg-config =0.29 + - sysroot_linux-64 >=2.28 host: # Libraries necessary to build. Keep sorted! - - boost-cpp - - cuda-cudart-dev - - cuda-nvml-dev - - doxygen 1.9.2.* - - flatbuffers 2.0.* - - gflags - - glog - - gmock 1.13.* - - libgrpc - - gtest 1.13.* - - libabseil - - libhwloc 2.5.* - - libprotobuf + - boost-cpp =1.84 + - cuda-cudart-dev {{ cuda_version }}.* + - cuda-nvml-dev {{ cuda_version }}.* + - cuda-nvrtc-dev {{ cuda_version }}.* + - cuda-version {{ cuda_version }}.* + - doxygen 1.10.0 + - glog>=0.7.1,<0.8 + - libgrpc =1.62.2 + - gtest =1.14 + - libhwloc =2.9.2 - librmm {{ rapids_version }} - - nlohmann_json 3.9.1 + - nlohmann_json =3.11 - pybind11-abi # See: https://conda-forge.org/docs/maintainer/knowledge_base.html#pybind11-abi-constraints - - pybind11-stubgen 0.10.5 + - pybind11-stubgen =0.10 - python {{ python }} - - scikit-build >=0.17 - - ucx - # Need to specify cudatoolkit to get correct version. 
Remove once all libraries migrate to cuda-toolkit - - cudatoolkit {{ cuda_version }}.* + - scikit-build =0.17 + - ucx =1.15 outputs: - name: libmrc @@ -88,32 +79,26 @@ outputs: - {{ compiler("c") }} - {{ compiler("cuda") }} - {{ compiler("cxx") }} - - cmake >=3.24 - - numactl-libs-cos7-x86_64 - - sysroot_linux-64 2.17 + - cmake =3.27 + - numactl =2.0.18 + - sysroot_linux-64 >=2.28 host: # Any libraries with weak run_exports need to go here to be added to the run. Keep sorted! - - boost-cpp - - cuda-cudart # Needed to allow pin_compatible to work - - glog - - libgrpc - - libabseil # Needed for transitive run_exports from libgrpc. Does not need a version - - libhwloc 2.5.* - - libprotobuf # Needed for transitive run_exports from libgrpc. Does not need a version + - boost-cpp =1.84 + - cuda-version # Needed to allow pin_compatible to work + - glog>=0.7.1,<0.8 + - libgrpc =1.62.2 + - libhwloc =2.9.2 - librmm {{ rapids_version }} - - nlohmann_json 3.9.* - - ucx - # Need to specify cudatoolkit to get correct version. Remove once all libraries migrate to cuda-toolkit - - cudatoolkit {{ cuda_version }}.* + - nlohmann_json =3.11 + - ucx =1.15 run: # Manually add any packages necessary for run that do not have run_exports. Keep sorted! 
- - {{ pin_compatible('cuda-cudart', min_pin='x.x', max_pin='x') }} - - {{ pin_compatible('nlohmann_json', max_pin='x.x')}} - - {{ pin_compatible('ucx', max_pin='x.x')}} - - boost-cpp # Needed to use pin_run_as_build - run_constrained: - # Since we dont explicitly require this but other packages might, constrain the versions - - {{ pin_compatible('cudatoolkit', min_pin='x.x', max_pin='x') }} + - cuda-version {{ cuda_version }}.* + - nlohmann_json =3.11 + - ucx =1.15 + - cuda-cudart + - boost-cpp =1.84 test: script: test_libmrc.sh files: diff --git a/ci/conda/recipes/run_conda_build.sh b/ci/conda/recipes/run_conda_build.sh index bbdc3ed2e..263a93388 100755 --- a/ci/conda/recipes/run_conda_build.sh +++ b/ci/conda/recipes/run_conda_build.sh @@ -95,11 +95,11 @@ fi # Choose default variants if hasArg quick; then # For quick build, just do most recent version of rapids - CONDA_ARGS_ARRAY+=("--variants" "{rapids_version: 23.06}") + CONDA_ARGS_ARRAY+=("--variants" "{rapids_version: 24.10}") fi -# And default channels -CONDA_ARGS_ARRAY+=("-c" "rapidsai" "-c" "nvidia/label/cuda-11.8.0" "-c" "nvidia" "-c" "conda-forge" "-c" "main") +# And default channels (should match dependencies.yaml) +CONDA_ARGS_ARRAY+=("-c" "conda-forge" "-c" "rapidsai" "-c" "rapidsai-nightly" "-c" "nvidia") # Set GIT_VERSION to set the project version inside of meta.yaml export GIT_VERSION="$(get_version)" diff --git a/ci/githooks/pre-commit b/ci/githooks/pre-commit index e74e35fb3..7fa4b83a1 100755 --- a/ci/githooks/pre-commit +++ b/ci/githooks/pre-commit @@ -41,6 +41,5 @@ export CHANGED_FILES=$(GIT_DIFF_ARGS="--cached --name-only" get_modified_files) if [[ "${CHANGED_FILES}" != "" ]]; then run_and_check "python3 ci/scripts/copyright.py --git-diff-staged --update-current-year --verify-apache-v2 --git-add" - run_and_check "ci/scripts/python_checks.sh" SKIP_CLANG_TIDY=1 SKIP_IWYU=1 run_and_check "ci/scripts/cpp_checks.sh" fi diff --git a/ci/iwyu/mappings.imp b/ci/iwyu/mappings.imp index 
7e9f70083..627e20127 100644 --- a/ci/iwyu/mappings.imp +++ b/ci/iwyu/mappings.imp @@ -11,6 +11,7 @@ # boost { "include": ["@", "private", "", "public"] }, +{ "include": ["@", "private", "", "public"] }, # cuda { "include": ["", "private", "", "public"] }, @@ -33,6 +34,7 @@ { "symbol": ["@grpc::.*", "private", "", "public"] }, # nlohmann json +{ "include": ["", "public", "", "public"] }, { "include": ["", "private", "", "public"] }, { "include": ["", "private", "", "public"] }, { "include": ["", "private", "", "public"] }, @@ -108,7 +110,9 @@ { "symbol": ["nlohmann::json", "private", "", "public"] }, # pybind11 -{ "include": [ "", private, "", "public" ] }, +{ "include": [ "@", private, "", "public" ] }, +{ "include": [ "@\"pybind11/detail/.*.h\"", private, "\"pybind11/pybind11.h\"", "public" ] }, + { "symbol": ["pybind11", "private", "", "public"] }, { "symbol": ["pybind11", "private", "", "public"] }, diff --git a/ci/release/pr_code_freeze_template.md b/ci/release/pr_code_freeze_template.md new file mode 100644 index 000000000..62f0e82ed --- /dev/null +++ b/ci/release/pr_code_freeze_template.md @@ -0,0 +1,11 @@ +## :snowflake: Code freeze for `branch-${CURRENT_VERSION}` and `v${CURRENT_VERSION}` release + +### What does this mean? +Only critical/hotfix level issues should be merged into `branch-${CURRENT_VERSION}` until release (merging of this PR). + +All other development PRs should be retargeted towards the next release branch: `branch-${NEXT_VERSION}`. + +### What is the purpose of this PR? 
+- Update documentation +- Allow testing for the new release +- Enable a means to merge `branch-${CURRENT_VERSION}` into `main` for the release diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index 8e4895f23..31b541957 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -60,6 +60,10 @@ function sed_runner() { # .gitmodules git submodule set-branch -b branch-${NEXT_SHORT_TAG} morpheus_utils +if [[ "$(git diff --name-only | grep .gitmodules)" != "" ]]; then + # Only update the submodules if setting the branch changed .gitmodules + git submodule update --remote +fi # Root CMakeLists.txt sed_runner 's/'"VERSION ${CURRENT_FULL_VERSION}.*"'/'"VERSION ${NEXT_FULL_VERSION}"'/g' CMakeLists.txt diff --git a/ci/scripts/bootstrap_local_ci.sh b/ci/scripts/bootstrap_local_ci.sh new file mode 100755 index 000000000..f1ff55bb2 --- /dev/null +++ b/ci/scripts/bootstrap_local_ci.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +export WORKSPACE_TMP="$(pwd)/.tmp/local_ci_workspace" +mkdir -p ${WORKSPACE_TMP} +git clone ${GIT_URL} mrc +cd mrc/ +git checkout ${GIT_BRANCH} +git pull +git checkout ${GIT_COMMIT} + +export MRC_ROOT=$(pwd) +export WORKSPACE=${MRC_ROOT} +export LOCAL_CI=1 +GH_SCRIPT_DIR="${MRC_ROOT}/ci/scripts/github" + +unset CMAKE_CUDA_COMPILER_LAUNCHER +unset CMAKE_CXX_COMPILER_LAUNCHER +unset CMAKE_C_COMPILER_LAUNCHER + +if [[ "${STAGE}" != "bash" ]]; then + # benchmark & codecov are composite stages, the rest are composed of a single shell script + if [[ "${STAGE}" == "benchmark" || "${STAGE}" == "codecov" ]]; then + CI_SCRIPT="${WORKSPACE_TMP}/ci_script.sh" + echo "#!/bin/bash" > ${CI_SCRIPT} + if [[ "${STAGE}" == "benchmark" ]]; then + echo "${GH_SCRIPT_DIR}/pre_benchmark.sh" >> ${CI_SCRIPT} + echo "${GH_SCRIPT_DIR}/benchmark.sh" >> ${CI_SCRIPT} + echo "${GH_SCRIPT_DIR}/post_benchmark.sh" >> ${CI_SCRIPT} + else + echo "${GH_SCRIPT_DIR}/build.sh" >> ${CI_SCRIPT} + echo "${GH_SCRIPT_DIR}/test_codecov.sh" >> ${CI_SCRIPT} + fi + + chmod +x ${CI_SCRIPT} + else + if [[ "${STAGE}" =~ "build" ]]; then + CI_SCRIPT="${GH_SCRIPT_DIR}/build.sh" + elif [[ "${STAGE}" =~ "test" ]]; then + CI_SCRIPT="${GH_SCRIPT_DIR}/test.sh" + else + CI_SCRIPT="${GH_SCRIPT_DIR}/${STAGE}.sh" + fi + fi + + ${CI_SCRIPT} +fi diff --git a/ci/scripts/cpp_checks.sh b/ci/scripts/cpp_checks.sh index 416c92167..c9127cc36 100755 --- a/ci/scripts/cpp_checks.sh +++ b/ci/scripts/cpp_checks.sh @@ -1,5 +1,5 @@ #!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -80,10 +80,25 @@ if [[ -n "${MRC_MODIFIED_FILES}" ]]; then # Include What You Use if [[ "${SKIP_IWYU}" == "" ]]; then - IWYU_DIRS="cpp python" - NUM_PROC=$(get_num_proc) - IWYU_OUTPUT=`${IWYU_TOOL} -p ${BUILD_DIR} -j ${NUM_PROC} ${IWYU_DIRS} 2>&1` - IWYU_RETVAL=$? + # Remove .h, .hpp, and .cu files from the modified list + shopt -s extglob + IWYU_MODIFIED_FILES=( "${MRC_MODIFIED_FILES[@]/*.@(h|hpp|cu)/}" ) + + if [[ -n "${IWYU_MODIFIED_FILES}" ]]; then + # Get the list of compiled files relative to this directory + WORKING_PREFIX="${PWD}/" + COMPILED_FILES=( $(jq -r .[].file ${BUILD_DIR}/compile_commands.json | sort -u ) ) + COMPILED_FILES=( "${COMPILED_FILES[@]/#$WORKING_PREFIX/}" ) + COMBINED_FILES=("${COMPILED_FILES[@]}") + COMBINED_FILES+=("${IWYU_MODIFIED_FILES[@]}") + + # Find the intersection between compiled files and modified files + IWYU_MODIFIED_FILES=( $(printf '%s\0' "${COMBINED_FILES[@]}" | sort -z | uniq -d -z | xargs -0n1) ) + + NUM_PROC=$(get_num_proc) + IWYU_OUTPUT=`${IWYU_TOOL} -p ${BUILD_DIR} -j ${NUM_PROC} ${IWYU_MODIFIED_FILES[@]} 2>&1` + IWYU_RETVAL=$? + fi fi else echo "No modified C++ files to check" diff --git a/ci/scripts/github/build.sh b/ci/scripts/github/build.sh index e63f04eef..300452c05 100755 --- a/ci/scripts/github/build.sh +++ b/ci/scripts/github/build.sh @@ -20,7 +20,12 @@ source ${WORKSPACE}/ci/scripts/github/common.sh update_conda_env -CMAKE_CACHE_FLAGS="-DCCACHE_PROGRAM_PATH=$(which sccache) -DMRC_USE_CCACHE=ON" +if [[ "${LOCAL_CI}" == "" ]]; then + CMAKE_CACHE_FLAGS="-DCCACHE_PROGRAM_PATH=$(which sccache) -DMRC_USE_CCACHE=ON" +else + CMAKE_CACHE_FLAGS="" +fi + rapids-logger "Check versions" python3 --version @@ -56,18 +61,20 @@ cmake -B build -G Ninja ${CMAKE_FLAGS} . 
rapids-logger "Building MRC" cmake --build build --parallel ${PARALLEL_LEVEL} -rapids-logger "sccache usage for MRC build:" -sccache --show-stats +if [[ "${LOCAL_CI}" == "" ]]; then + rapids-logger "sccache usage for MRC build:" + sccache --show-stats +fi -if [[ "${BUILD_CC}" != "gcc-coverage" ]]; then +if [[ "${BUILD_CC}" != "gcc-coverage" || ${LOCAL_CI} == "1" ]]; then rapids-logger "Archiving results" tar cfj "${WORKSPACE_TMP}/dot_cache.tar.bz" .cache tar cfj "${WORKSPACE_TMP}/build.tar.bz" build ls -lh ${WORKSPACE_TMP}/ rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" - aws s3 cp --no-progress "${WORKSPACE_TMP}/build.tar.bz" "${ARTIFACT_URL}/build.tar.bz" - aws s3 cp --no-progress "${WORKSPACE_TMP}/dot_cache.tar.bz" "${ARTIFACT_URL}/dot_cache.tar.bz" + upload_artifact "${WORKSPACE_TMP}/build.tar.bz" + upload_artifact "${WORKSPACE_TMP}/dot_cache.tar.bz" fi rapids-logger "Success" diff --git a/ci/scripts/github/checks.sh b/ci/scripts/github/checks.sh index 4ea5c5583..c85885d9c 100755 --- a/ci/scripts/github/checks.sh +++ b/ci/scripts/github/checks.sh @@ -1,5 +1,5 @@ #!/usr/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,7 +24,8 @@ update_conda_env rapids-logger "Configuring CMake" git submodule update --init --recursive -cmake -B build -G Ninja ${CMAKE_BUILD_ALL_FEATURES} . +CMAKE_CLANG_OPTIONS="-DCMAKE_C_COMPILER:FILEPATH=$(which clang) -DCMAKE_CXX_COMPILER:FILEPATH=$(which clang++) -DCMAKE_CUDA_COMPILER:FILEPATH=$(which nvcc)" +cmake -B build -G Ninja ${CMAKE_CLANG_OPTIONS} ${CMAKE_BUILD_ALL_FEATURES} . 
rapids-logger "Building targets that generate source code" cmake --build build --target mrc_style_checks --parallel ${PARALLEL_LEVEL} @@ -35,8 +36,5 @@ ${MRC_ROOT}/ci/scripts/version_checks.sh rapids-logger "Running C++ style checks" ${MRC_ROOT}/ci/scripts/cpp_checks.sh -rapids-logger "Runing Python style checks" -${MRC_ROOT}/ci/scripts/python_checks.sh - rapids-logger "Checking copyright headers" python ${MRC_ROOT}/ci/scripts/copyright.py --verify-apache-v2 --git-diff-commits ${CHANGE_TARGET} ${GIT_COMMIT} diff --git a/ci/scripts/github/common.sh b/ci/scripts/github/common.sh index 02684da2f..6b5ba72bd 100644 --- a/ci/scripts/github/common.sh +++ b/ci/scripts/github/common.sh @@ -1,5 +1,5 @@ #!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,9 +35,7 @@ id export NUM_PROC=${PARALLEL_LEVEL:-$(nproc)} export BUILD_CC=${BUILD_CC:-"gcc"} -export CONDA_ENV_YML="${MRC_ROOT}/ci/conda/environments/dev_env.yml" -export CONDA_CLANG_ENV_YML="${MRC_ROOT}/ci/conda/environments/clang_env.yml" -export CONDA_CI_ENV_YML="${MRC_ROOT}/ci/conda/environments/ci_env.yml" +export CONDA_ENV_YML="${MRC_ROOT}/conda/environments/all_cuda-125_arch-x86_64.yaml" export CMAKE_BUILD_ALL_FEATURES="-DCMAKE_MESSAGE_CONTEXT_SHOW=ON -DMRC_BUILD_BENCHMARKS=ON -DMRC_BUILD_EXAMPLES=ON -DMRC_BUILD_PYTHON=ON -DMRC_BUILD_TESTS=ON -DMRC_USE_CONDA=ON -DMRC_PYTHON_BUILD_STUBS=ON" export CMAKE_BUILD_WITH_CODECOV="-DCMAKE_BUILD_TYPE=Debug -DMRC_ENABLE_CODECOV=ON -DMRC_PYTHON_PERFORM_INSTALL:BOOL=ON -DMRC_PYTHON_INPLACE_BUILD:BOOL=ON" @@ -56,7 +54,12 @@ export S3_URL="s3://rapids-downloads/ci/mrc" export DISPLAY_URL="https://downloads.rapids.ai/ci/mrc" export 
ARTIFACT_ENDPOINT="/pull-request/${PR_NUM}/${GIT_COMMIT}/${NVARCH}/${BUILD_CC}" export ARTIFACT_URL="${S3_URL}${ARTIFACT_ENDPOINT}" -export DISPLAY_ARTIFACT_URL="${DISPLAY_URL}${ARTIFACT_ENDPOINT}" + +if [[ "${LOCAL_CI}" == "1" ]]; then + export DISPLAY_ARTIFACT_URL="${LOCAL_CI_TMP}" +else + export DISPLAY_ARTIFACT_URL="${DISPLAY_URL}${ARTIFACT_ENDPOINT}" +fi # Set sccache env vars export SCCACHE_S3_KEY_PREFIX=mrc-${NVARCH}-${BUILD_CC} @@ -78,34 +81,27 @@ function update_conda_env() { # Deactivate the environment first before updating conda deactivate - # Make sure we have the conda-merge package installed - if [[ -z "$(conda list | grep conda-merge)" ]]; then - rapids-mamba-retry install -q -n mrc -c conda-forge "conda-merge>=0.2" + if [[ "${SKIP_CONDA_ENV_UPDATE}" == "" ]]; then + # Update the conda env with prune remove excess packages (in case one was removed from the env) + # use conda instead of mamba due to bug: https://github.com/mamba-org/mamba/issues/3059 + rapids-conda-retry env update -n mrc --solver=libmamba --prune --file ${CONDA_ENV_YML} fi - # Create a temp directory which we store the combined environment file in - condatmpdir=$(mktemp -d) - - # Merge the environments together so we can use --prune. 
Otherwise --prune - # will clobber the last env update - conda run -n mrc --live-stream conda-merge ${CONDA_ENV_YML} ${CONDA_CLANG_ENV_YML} ${CONDA_CI_ENV_YML} > ${condatmpdir}/merged_env.yml - - # Update the conda env with prune remove excess packages (in case one was removed from the env) - rapids-mamba-retry env update -n mrc --prune --file ${condatmpdir}/merged_env.yml - - # Delete the temp directory - rm -rf ${condatmpdir} - # Finally, reactivate conda activate mrc rapids-logger "Final Conda Environment" - conda list + mamba list } print_env_vars -function fetch_base_branch() { +function fetch_base_branch_gh_api() { + # For PRs, $GIT_BRANCH is like: pull-request/989 + REPO_NAME=$(basename "${GITHUB_REPOSITORY}") + ORG_NAME="${GITHUB_REPOSITORY_OWNER}" + PR_NUM="${GITHUB_REF_NAME##*/}" + rapids-logger "Retrieving base branch from GitHub API" [[ -n "$GH_TOKEN" ]] && CURL_HEADERS=('-H' "Authorization: token ${GH_TOKEN}") RESP=$( @@ -115,25 +111,31 @@ function fetch_base_branch() { "${GITHUB_API_URL}/repos/${ORG_NAME}/${REPO_NAME}/pulls/${PR_NUM}" ) - BASE_BRANCH=$(echo "${RESP}" | jq -r '.base.ref') + export BASE_BRANCH=$(echo "${RESP}" | jq -r '.base.ref') # Change target is the branch name we are merging into but due to the weird way jenkins does # the checkout it isn't recognized by git without the origin/ prefix export CHANGE_TARGET="origin/${BASE_BRANCH}" - git submodule update --init --recursive - rapids-logger "Base branch: ${BASE_BRANCH}" } -function fetch_s3() { - ENDPOINT=$1 - DESTINATION=$2 - if [[ "${USE_S3_CURL}" == "1" ]]; then - curl -f "${DISPLAY_URL}${ENDPOINT}" -o "${DESTINATION}" - FETCH_STATUS=$? 
+function fetch_base_branch_local() { + rapids-logger "Retrieving base branch from git" + git remote add upstream ${GIT_UPSTREAM_URL} + git fetch upstream --tags + source ${MRC_ROOT}/ci/scripts/common.sh + export BASE_BRANCH=$(get_base_branch) + export CHANGE_TARGET="upstream/${BASE_BRANCH}" +} + +function fetch_base_branch() { + if [[ "${LOCAL_CI}" == "1" ]]; then + fetch_base_branch_local else - aws s3 cp --no-progress "${S3_URL}${ENDPOINT}" "${DESTINATION}" - FETCH_STATUS=$? + fetch_base_branch_gh_api fi + + git submodule update --init --recursive + rapids-logger "Base branch: ${BASE_BRANCH}" } function show_conda_info() { @@ -143,3 +145,25 @@ function show_conda_info() { conda config --show-sources conda list --show-channel-urls } + +function upload_artifact() { + FILE_NAME=$1 + BASE_NAME=$(basename "${FILE_NAME}") + rapids-logger "Uploading artifact: ${BASE_NAME}" + if [[ "${LOCAL_CI}" == "1" ]]; then + cp ${FILE_NAME} "${LOCAL_CI_TMP}/${BASE_NAME}" + else + aws s3 cp --only-show-errors "${FILE_NAME}" "${ARTIFACT_URL}/${BASE_NAME}" + echo "- ${DISPLAY_ARTIFACT_URL}/${BASE_NAME}" >> ${GITHUB_STEP_SUMMARY} + fi +} + +function download_artifact() { + ARTIFACT=$1 + rapids-logger "Downloading ${ARTIFACT} from ${DISPLAY_ARTIFACT_URL}" + if [[ "${LOCAL_CI}" == "1" ]]; then + cp "${LOCAL_CI_TMP}/${ARTIFACT}" "${WORKSPACE_TMP}/${ARTIFACT}" + else + aws s3 cp --no-progress "${ARTIFACT_URL}/${ARTIFACT}" "${WORKSPACE_TMP}/${ARTIFACT}" + fi +} diff --git a/ci/scripts/github/conda.sh b/ci/scripts/github/conda.sh index 3b8104ad3..36a878528 100755 --- a/ci/scripts/github/conda.sh +++ b/ci/scripts/github/conda.sh @@ -16,6 +16,7 @@ set -e +CI_SCRIPT_ARGS="$@" source ${WORKSPACE}/ci/scripts/github/common.sh # Its important that we are in the base environment for the build @@ -39,4 +40,15 @@ conda info rapids-logger "Building Conda Package" # Run the conda build and upload -${MRC_ROOT}/ci/conda/recipes/run_conda_build.sh "$@" +${MRC_ROOT}/ci/conda/recipes/run_conda_build.sh 
"${CI_SCRIPT_ARGS}" + +if [[ " ${CI_SCRIPT_ARGS} " =~ " upload " ]]; then + rapids-logger "Building Conda Package... Done" +else + # if we didn't receive the upload argument, we can still upload the artifact to S3 + tar cfj "${WORKSPACE_TMP}/conda.tar.bz" "${RAPIDS_CONDA_BLD_OUTPUT_DIR}" + ls -lh ${WORKSPACE_TMP}/ + + rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" + upload_artifact "${WORKSPACE_TMP}/conda.tar.bz" +fi diff --git a/ci/scripts/github/docs.sh b/ci/scripts/github/docs.sh index 2e0a1f64c..c5f10a53a 100755 --- a/ci/scripts/github/docs.sh +++ b/ci/scripts/github/docs.sh @@ -39,6 +39,6 @@ rapids-logger "Tarring the docs" tar cfj "${WORKSPACE_TMP}/docs.tar.bz" build/docs/html rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" -aws s3 cp --no-progress "${WORKSPACE_TMP}/docs.tar.bz" "${ARTIFACT_URL}/docs.tar.bz" +upload_artifact "${WORKSPACE_TMP}/docs.tar.bz" rapids-logger "Success" diff --git a/ci/scripts/github/post_benchmark.sh b/ci/scripts/github/post_benchmark.sh index d08bce2b4..943abc7e0 100755 --- a/ci/scripts/github/post_benchmark.sh +++ b/ci/scripts/github/post_benchmark.sh @@ -1,5 +1,5 @@ #!/usr/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,6 +25,6 @@ cd $(dirname ${REPORTS_DIR}) tar cfj ${WORKSPACE_TMP}/benchmark_reports.tar.bz $(basename ${REPORTS_DIR}) rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" -aws s3 cp ${WORKSPACE_TMP}/benchmark_reports.tar.bz "${ARTIFACT_URL}/benchmark_reports.tar.bz" +upload_artifact ${WORKSPACE_TMP}/benchmark_reports.tar.bz exit $(cat ${WORKSPACE_TMP}/exit_status) diff --git a/ci/scripts/github/pre_benchmark.sh b/ci/scripts/github/pre_benchmark.sh index 419df25c2..c14a29144 100755 --- a/ci/scripts/github/pre_benchmark.sh +++ b/ci/scripts/github/pre_benchmark.sh @@ -1,5 +1,5 @@ #!/usr/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,7 +21,7 @@ source ${WORKSPACE}/ci/scripts/github/common.sh update_conda_env rapids-logger "Fetching Build artifacts from ${DISPLAY_ARTIFACT_URL}/" -fetch_s3 "${ARTIFACT_ENDPOINT}/build.tar.bz" "${WORKSPACE_TMP}/build.tar.bz" +download_artifact "build.tar.bz" tar xf "${WORKSPACE_TMP}/build.tar.bz" diff --git a/ci/scripts/github/test.sh b/ci/scripts/github/test.sh index 0aab525a0..40000a516 100755 --- a/ci/scripts/github/test.sh +++ b/ci/scripts/github/test.sh @@ -22,8 +22,8 @@ source ${WORKSPACE}/ci/scripts/github/common.sh update_conda_env rapids-logger "Fetching Build artifacts from ${DISPLAY_ARTIFACT_URL}/" -fetch_s3 "${ARTIFACT_ENDPOINT}/dot_cache.tar.bz" "${WORKSPACE_TMP}/dot_cache.tar.bz" -fetch_s3 "${ARTIFACT_ENDPOINT}/build.tar.bz" "${WORKSPACE_TMP}/build.tar.bz" +download_artifact "dot_cache.tar.bz" +download_artifact "build.tar.bz" tar xf "${WORKSPACE_TMP}/dot_cache.tar.bz" tar xf "${WORKSPACE_TMP}/build.tar.bz" @@ -60,7 +60,7 @@ cd 
$(dirname ${REPORTS_DIR}) tar cfj ${WORKSPACE_TMP}/test_reports.tar.bz $(basename ${REPORTS_DIR}) rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" -aws s3 cp ${WORKSPACE_TMP}/test_reports.tar.bz "${ARTIFACT_URL}/test_reports.tar.bz" +upload_artifact ${WORKSPACE_TMP}/test_reports.tar.bz TEST_RESULTS=$(($CTEST_RESULTS+$PYTEST_RESULTS)) exit ${TEST_RESULTS} diff --git a/ci/scripts/github/test_codecov.sh b/ci/scripts/github/test_codecov.sh index 4a0ef3ce8..97955859a 100755 --- a/ci/scripts/github/test_codecov.sh +++ b/ci/scripts/github/test_codecov.sh @@ -58,13 +58,16 @@ cd ${MRC_ROOT}/build # correctly and enabling relative only ignores system and conda files. find . -type f -name '*.gcda' -exec x86_64-conda_cos6-linux-gnu-gcov -pbc --source-prefix ${MRC_ROOT} --relative-only {} + 1> /dev/null -rapids-logger "Uploading codecov for C++ tests" -# Get the list of files that we are interested in (Keeps the upload small) -GCOV_FILES=$(find . -type f \( -iname "cpp#mrc#include#*.gcov" -or -iname "python#*.gcov" -or -iname "cpp#mrc#src#*.gcov" \)) +if [[ "${LOCAL_CI}" == "" ]]; then + rapids-logger "Uploading codecov for C++ tests" -# Upload the .gcov files directly to codecov. They do a good job at processing the partials -/opt/conda/envs/mrc/bin/codecov ${CODECOV_ARGS} -f ${GCOV_FILES} -F cpp + # Get the list of files that we are interested in (Keeps the upload small) + GCOV_FILES=$(find . -type f \( -iname "cpp#mrc#include#*.gcov" -or -iname "python#*.gcov" -or -iname "cpp#mrc#src#*.gcov" \)) + + # Upload the .gcov files directly to codecov. They do a good job at processing the partials + /opt/conda/envs/mrc/bin/codecov ${CODECOV_ARGS} -f ${GCOV_FILES} -F cpp +fi # Remove the gcov files and any gcda files to reset counters find . -type f \( -iname "*.gcov" -or -iname "*.gcda" \) -exec rm {} \; @@ -85,13 +88,15 @@ cd ${MRC_ROOT}/build # correctly and enabling relative only ignores system and conda files. find . 
-type f -name '*.gcda' -exec x86_64-conda_cos6-linux-gnu-gcov -pbc --source-prefix ${MRC_ROOT} --relative-only {} + 1> /dev/null -rapids-logger "Uploading codecov for Python tests" +if [[ "${LOCAL_CI}" == "" ]]; then + rapids-logger "Uploading codecov for Python tests" -# Get the list of files that we are interested in (Keeps the upload small) -GCOV_FILES=$(find . -type f \( -iname "cpp#mrc#include#*.gcov" -or -iname "python#*.gcov" -or -iname "cpp#mrc#src#*.gcov" \)) + # Get the list of files that we are interested in (Keeps the upload small) + GCOV_FILES=$(find . -type f \( -iname "cpp#mrc#include#*.gcov" -or -iname "python#*.gcov" -or -iname "cpp#mrc#src#*.gcov" \)) -# Upload the .gcov files directly to codecov. They do a good job at processing the partials -/opt/conda/envs/mrc/bin/codecov ${CODECOV_ARGS} -f ${GCOV_FILES} -F py + # Upload the .gcov files directly to codecov. They do a good job at processing the partials + /opt/conda/envs/mrc/bin/codecov ${CODECOV_ARGS} -f ${GCOV_FILES} -F py +fi # Remove the gcov files and any gcda files to reset counters find . -type f \( -iname "*.gcov" -or -iname "*.gcda" \) -exec rm {} \; @@ -101,7 +106,7 @@ cd $(dirname ${REPORTS_DIR}) tar cfj ${WORKSPACE_TMP}/test_reports.tar.bz $(basename ${REPORTS_DIR}) rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" -aws s3 cp ${WORKSPACE_TMP}/test_reports.tar.bz "${ARTIFACT_URL}/test_reports.tar.bz" +upload_artifact ${WORKSPACE_TMP}/test_reports.tar.bz TEST_RESULTS=$(($CTEST_RESULTS+$PYTEST_RESULTS)) exit ${TEST_RESULTS} diff --git a/ci/scripts/python_checks.sh b/ci/scripts/python_checks.sh deleted file mode 100755 index fb6015735..000000000 --- a/ci/scripts/python_checks.sh +++ /dev/null @@ -1,105 +0,0 @@ -#!/bin/bash - -# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Based on style.sh from Morpheus - -SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" -source ${SCRIPT_DIR}/common.sh - -# Ignore errors and set path -set +e -LC_ALL=C.UTF-8 -LANG=C.UTF-8 - -# Pre-populate the return values in case they are skipped -ISORT_RETVAL=0 -FLAKE_RETVAL=0 -YAPF_RETVAL=0 - -get_modified_files ${PYTHON_FILE_REGEX} MRC_MODIFIED_FILES - -# When invoked by the git pre-commit hook CHANGED_FILES will already be defined -if [[ -n "${MRC_MODIFIED_FILES}" ]]; then - echo -e "Running Python checks on ${#MRC_MODIFIED_FILES[@]} files:" - - for f in "${MRC_MODIFIED_FILES[@]}"; do - echo " $f" - done - - if [[ "${SKIP_ISORT}" == "" ]]; then - ISORT_OUTPUT=`python3 -m isort --settings-file ${PY_CFG} --filter-files --check-only ${MRC_MODIFIED_FILES[@]} 2>&1` - ISORT_RETVAL=$? - fi - - if [[ "${SKIP_FLAKE}" == "" ]]; then - FLAKE_OUTPUT=`python3 -m flake8 --config ${PY_CFG} ${MRC_MODIFIED_FILES[@]} 2>&1` - FLAKE_RETVAL=$? - fi - - if [[ "${SKIP_YAPF}" == "" ]]; then - # Run yapf. Will return 1 if there are any diffs - YAPF_OUTPUT=`python3 -m yapf --style ${PY_CFG} --diff ${MRC_MODIFIED_FILES[@]} 2>&1` - YAPF_RETVAL=$? 
- fi -else - echo "No modified Python files to check" -fi - -# Output results if failure otherwise show pass -if [[ "${SKIP_ISORT}" != "" ]]; then - echo -e "\n\n>>>> SKIPPED: isort check\n\n" -elif [ "${ISORT_RETVAL}" != "0" ]; then - echo -e "\n\n>>>> FAILED: isort style check; begin output\n\n" - echo -e "${ISORT_OUTPUT}" - echo -e "\n\n>>>> FAILED: isort style check; end output\n\n" \ - "To auto-fix many issues (not all) run:\n" \ - " ./ci/scripts/fix_all.sh\n\n" -else - echo -e "\n\n>>>> PASSED: isort style check\n\n" -fi - -if [[ "${SKIP_FLAKE}" != "" ]]; then - echo -e "\n\n>>>> SKIPPED: flake8 check\n\n" -elif [ "${FLAKE_RETVAL}" != "0" ]; then - echo -e "\n\n>>>> FAILED: flake8 style check; begin output\n\n" - echo -e "${FLAKE_OUTPUT}" - echo -e "\n\n>>>> FAILED: flake8 style check; end output\n\n" \ - "To auto-fix many issues (not all) run:\n" \ - " ./ci/scripts/fix_all.sh\n\n" -else - echo -e "\n\n>>>> PASSED: flake8 style check\n\n" -fi - -if [[ "${SKIP_YAPF}" != "" ]]; then - echo -e "\n\n>>>> SKIPPED: yapf check\n\n" -elif [ "${YAPF_RETVAL}" != "0" ]; then - echo -e "\n\n>>>> FAILED: yapf style check; begin output\n\n" - echo -e "Incorrectly formatted files:" - YAPF_OUTPUT=`echo "${YAPF_OUTPUT}" | sed -nr 's/^\+\+\+ ([^ ]*) *\(reformatted\)$/\1/p'` - echo -e "${YAPF_OUTPUT}" - echo -e "\n\n>>>> FAILED: yapf style check; end output\n\n" \ - "To auto-fix many issues (not all) run:\n" \ - " ./ci/scripts/fix_all.sh\n\n" -else - echo -e "\n\n>>>> PASSED: yapf style check\n\n" -fi - -RETVALS=(${ISORT_RETVAL} ${FLAKE_RETVAL} ${YAPF_RETVAL}) -IFS=$'\n' -RETVAL=`echo "${RETVALS[*]}" | sort -nr | head -n1` - -exit $RETVAL diff --git a/ci/scripts/run_ci_local.sh b/ci/scripts/run_ci_local.sh new file mode 100755 index 000000000..41299c3fb --- /dev/null +++ b/ci/scripts/run_ci_local.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +case "$1" in + "" ) + STAGES=("bash") + ;; + "all" ) + STAGES=("checks" "build-clang" "build-gcc" "test-clang" "test-gcc" "codecov" "docs" "benchmark" "conda") + ;; + "build" ) + STAGES=("build-clang" "build-gcc") + ;; + "test" ) + STAGES=("test-clang" "test-gcc") + ;; + "checks" | "build-clang" | "build-gcc" | "test" | "test-clang" | "test-gcc" | "codecov" | "docs" | "benchmark" | \ + "conda" | "bash" ) + STAGES=("$1") + ;; + * ) + echo "Error: Invalid argument \"$1\" provided. 
Expected values: \"all\", \"checks\", \"build\", " \ + "\"build-clang\", \"build-gcc\", \"test\", \"test-clang\", \"test-gcc\", \"codecov\"," \ + "\"docs\", \"benchmark\", \"conda\" or \"bash\"" + exit 1 + ;; +esac + +# CI image doesn't contain ssh, need to use https +function git_ssh_to_https() +{ + local url=$1 + echo $url | sed -e 's|^git@github\.com:|https://github.com/|' +} + +MRC_ROOT=${MRC_ROOT:-$(git rev-parse --show-toplevel)} + +GIT_URL=$(git remote get-url origin) +GIT_URL=$(git_ssh_to_https ${GIT_URL}) + +GIT_UPSTREAM_URL=$(git remote get-url upstream) +GIT_UPSTREAM_URL=$(git_ssh_to_https ${GIT_UPSTREAM_URL}) + +GIT_BRANCH=$(git branch --show-current) +GIT_COMMIT=$(git log -n 1 --pretty=format:%H) + +BASE_LOCAL_CI_TMP=${BASE_LOCAL_CI_TMP:-${MRC_ROOT}/.tmp/local_ci_tmp} +CONTAINER_VER=${CONTAINER_VER:-241002} +CUDA_VER=${CUDA_VER:-12.1} +DOCKER_EXTRA_ARGS=${DOCKER_EXTRA_ARGS:-""} + +BUILD_CONTAINER="nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-build-${CONTAINER_VER}" +TEST_CONTAINER="nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-test-${CONTAINER_VER}" + +# These variables are common to all stages +BASE_ENV_LIST="--env LOCAL_CI_TMP=/ci_tmp" +BASE_ENV_LIST="${BASE_ENV_LIST} --env GIT_URL=${GIT_URL}" +BASE_ENV_LIST="${BASE_ENV_LIST} --env GIT_UPSTREAM_URL=${GIT_UPSTREAM_URL}" +BASE_ENV_LIST="${BASE_ENV_LIST} --env GIT_BRANCH=${GIT_BRANCH}" +BASE_ENV_LIST="${BASE_ENV_LIST} --env GIT_COMMIT=${GIT_COMMIT}" +BASE_ENV_LIST="${BASE_ENV_LIST} --env PARALLEL_LEVEL=$(nproc)" +BASE_ENV_LIST="${BASE_ENV_LIST} --env CUDA_VER=${CUDA_VER}" +BASE_ENV_LIST="${BASE_ENV_LIST} --env SKIP_CONDA_ENV_UPDATE=${SKIP_CONDA_ENV_UPDATE}" + +for STAGE in "${STAGES[@]}"; do + # Take a copy of the base env list, then make stage specific changes + ENV_LIST="${BASE_ENV_LIST}" + + if [[ "${STAGE}" =~ benchmark|clang|codecov|gcc ]]; then + if [[ "${STAGE}" =~ "clang" ]]; then + BUILD_CC="clang" + elif [[ "${STAGE}" == "codecov" ]]; then + BUILD_CC="gcc-coverage" + else + BUILD_CC="gcc" + fi + + 
ENV_LIST="${ENV_LIST} --env BUILD_CC=${BUILD_CC}" + LOCAL_CI_TMP="${BASE_LOCAL_CI_TMP}/${BUILD_CC}" + mkdir -p ${LOCAL_CI_TMP} + else + LOCAL_CI_TMP="${BASE_LOCAL_CI_TMP}" + fi + + mkdir -p ${LOCAL_CI_TMP} + cp ${MRC_ROOT}/ci/scripts/bootstrap_local_ci.sh ${LOCAL_CI_TMP} + + + DOCKER_RUN_ARGS="--rm -ti --net=host -v "${LOCAL_CI_TMP}":/ci_tmp ${ENV_LIST} --env STAGE=${STAGE}" + if [[ "${STAGE}" =~ "test" || "${STAGE}" =~ "codecov" || "${USE_GPU}" == "1" ]]; then + CONTAINER="${TEST_CONTAINER}" + DOCKER_RUN_ARGS="${DOCKER_RUN_ARGS} --runtime=nvidia --gpus all --cap-add=sys_nice --cap-add=sys_ptrace" + else + CONTAINER="${BUILD_CONTAINER}" + DOCKER_RUN_ARGS="${DOCKER_RUN_ARGS} --runtime=runc" + if [[ "${STAGE}" == "benchmark" ]]; then + DOCKER_RUN_ARGS="${DOCKER_RUN_ARGS} --cap-add=sys_nice --cap-add=sys_ptrace" + fi + fi + + if [[ "${STAGE}" == "bash" ]]; then + DOCKER_RUN_CMD="bash --init-file /ci_tmp/bootstrap_local_ci.sh" + else + DOCKER_RUN_CMD="/ci_tmp/bootstrap_local_ci.sh" + fi + + echo "Running ${STAGE} stage in ${CONTAINER}" + docker run ${DOCKER_RUN_ARGS} ${DOCKER_EXTRA_ARGS} ${CONTAINER} ${DOCKER_RUN_CMD} + + STATUS=$? + if [[ ${STATUS} -ne 0 ]]; then + echo "Error: docker exited with a non-zero status code for ${STAGE} of ${STATUS}" + exit ${STATUS} + fi +done diff --git a/ci/scripts/run_clang_tidy_for_ci.sh b/ci/scripts/run_clang_tidy_for_ci.sh index 54191c68b..b0a7dc2c1 100755 --- a/ci/scripts/run_clang_tidy_for_ci.sh +++ b/ci/scripts/run_clang_tidy_for_ci.sh @@ -1,5 +1,5 @@ #!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,11 +16,7 @@ # set -x -# Call clang-tidy adding warnings-as-errors option. 
Currently this is not -# possible with clang-tidy-diff.py until this is merged: -# https://reviews.llvm.org/D49864 - # Also add -fno-caret-diagnostics to prevent clangs own compiler warnings from # coming through: # https://github.com/llvm/llvm-project/blob/3f3faa36ff3d84af3c3ed84772d7e4278bc44ff1/libc/cmake/modules/LLVMLibCObjectRules.cmake#L226 -${CLANG_TIDY:-clang-tidy} --warnings-as-errors='*' --extra-arg=-fno-caret-diagnostics "$@" +${CLANG_TIDY:-clang-tidy} --extra-arg=-fno-caret-diagnostics "$@" diff --git a/cmake/dependencies.cmake b/cmake/dependencies.cmake index 3e09a3524..477bb374e 100644 --- a/cmake/dependencies.cmake +++ b/cmake/dependencies.cmake @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,8 +24,8 @@ morpheus_utils_initialize_cpm(MRC_CACHE_DIR) # Start with CUDA. 
Need to add it to our export set rapids_find_package(CUDAToolkit REQUIRED - BUILD_EXPORT_SET ${PROJECT_NAME}-core-exports - INSTALL_EXPORT_SET ${PROJECT_NAME}-core-exports + BUILD_EXPORT_SET ${PROJECT_NAME}-exports + INSTALL_EXPORT_SET ${PROJECT_NAME}-exports ) # Boost @@ -40,22 +40,16 @@ morpheus_utils_configure_ucx() # ===== morpheus_utils_configure_hwloc() +# cccl +# ========= +morpheus_utils_configure_cccl() + # NVIDIA RAPIDS RMM # ================= morpheus_utils_configure_rmm() -# gflags -# ====== -rapids_find_package(gflags REQUIRED - GLOBAL_TARGETS gflags - BUILD_EXPORT_SET ${PROJECT_NAME}-core-exports - INSTALL_EXPORT_SET ${PROJECT_NAME}-core-exports -) - # glog # ==== -# - link against shared -# - todo: compile with -DWITH_GFLAGS=OFF and remove gflags dependency morpheus_utils_configure_glog() # nvidia cub @@ -69,11 +63,11 @@ find_path(CUB_INCLUDE_DIRS "cub/cub.cuh" # ========= rapids_find_package(gRPC REQUIRED GLOBAL_TARGETS - gRPC::address_sorting gRPC::gpr gRPC::grpc gRPC::grpc_unsecure gRPC::grpc++ gRPC::grpc++_alts gRPC::grpc++_error_details gRPC::grpc++_reflection - gRPC::grpc++_unsecure gRPC::grpc_plugin_support gRPC::grpcpp_channelz gRPC::upb gRPC::grpc_cpp_plugin gRPC::grpc_csharp_plugin gRPC::grpc_node_plugin - gRPC::grpc_objective_c_plugin gRPC::grpc_php_plugin gRPC::grpc_python_plugin gRPC::grpc_ruby_plugin - BUILD_EXPORT_SET ${PROJECT_NAME}-core-exports - INSTALL_EXPORT_SET ${PROJECT_NAME}-core-exports + gRPC::address_sorting gRPC::gpr gRPC::grpc gRPC::grpc_unsecure gRPC::grpc++ gRPC::grpc++_alts gRPC::grpc++_error_details gRPC::grpc++_reflection + gRPC::grpc++_unsecure gRPC::grpc_plugin_support gRPC::grpcpp_channelz gRPC::upb gRPC::grpc_cpp_plugin gRPC::grpc_csharp_plugin gRPC::grpc_node_plugin + gRPC::grpc_objective_c_plugin gRPC::grpc_php_plugin gRPC::grpc_python_plugin gRPC::grpc_ruby_plugin + BUILD_EXPORT_SET ${PROJECT_NAME}-exports + INSTALL_EXPORT_SET ${PROJECT_NAME}-exports ) # RxCpp @@ -84,8 +78,8 @@ morpheus_utils_configure_rxcpp() 
# ====== rapids_find_package(nlohmann_json REQUIRED GLOBAL_TARGETS nlohmann_json::nlohmann_json - BUILD_EXPORT_SET ${PROJECT_NAME}-core-exports - INSTALL_EXPORT_SET ${PROJECT_NAME}-core-exports + BUILD_EXPORT_SET ${PROJECT_NAME}-exports + INSTALL_EXPORT_SET ${PROJECT_NAME}-exports FIND_ARGS CONFIG ) @@ -94,33 +88,21 @@ rapids_find_package(nlohmann_json REQUIRED # ========= morpheus_utils_configure_prometheus_cpp() -# libcudacxx -# ========= -morpheus_utils_configure_libcudacxx() - if(MRC_BUILD_BENCHMARKS) # google benchmark # ================ - rapids_find_package(benchmark REQUIRED - GLOBAL_TARGETS benchmark::benchmark - BUILD_EXPORT_SET ${PROJECT_NAME}-core-exports - - # No install set - FIND_ARGS - CONFIG + include(${rapids-cmake-dir}/cpm/gbench.cmake) + rapids_cpm_gbench( + BUILD_EXPORT_SET ${PROJECT_NAME}-exports ) endif() if(MRC_BUILD_TESTS) # google test # =========== - rapids_find_package(GTest REQUIRED - GLOBAL_TARGETS GTest::gtest GTest::gmock GTest::gtest_main GTest::gmock_main - BUILD_EXPORT_SET ${PROJECT_NAME}-core-exports - - # No install set - FIND_ARGS - CONFIG + include(${rapids-cmake-dir}/cpm/gtest.cmake) + rapids_cpm_gtest( + BUILD_EXPORT_SET ${PROJECT_NAME}-exports ) endif() diff --git a/conda/environments/all_cuda-121_arch-x86_64.yaml b/conda/environments/all_cuda-121_arch-x86_64.yaml new file mode 100644 index 000000000..7e6e17b84 --- /dev/null +++ b/conda/environments/all_cuda-121_arch-x86_64.yaml @@ -0,0 +1,56 @@ +# This file is generated by `rapids-dependency-file-generator`. +# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. 
+channels: +- conda-forge +- rapidsai +- rapidsai-nightly +- nvidia +dependencies: +- bash-completion +- benchmark=1.8.3 +- boost-cpp=1.84 +- ccache +- clang-tools=16 +- clang=16 +- clangdev=16 +- clangxx=16 +- cmake=3.27 +- codecov=2.1 +- cuda-cudart-dev=12.1 +- cuda-nvcc +- cuda-nvml-dev=12.1 +- cuda-nvrtc-dev=12.1 +- cuda-tools=12.1 +- cuda-version=12.1 +- cxx-compiler +- doxygen=1.9.2 +- flake8 +- gcovr=5.2 +- gdb +- glog=0.6 +- gtest=1.14 +- gxx=11.2 +- include-what-you-use=0.20 +- libclang-cpp=16 +- libclang=16 +- libgrpc=1.59 +- libhwloc=2.9.2 +- librmm=24.02 +- libxml2=2.11.6 +- llvmdev=16 +- ninja=1.11 +- nlohmann_json=3.11 +- numactl-libs-cos7-x86_64 +- numpy=1.24 +- pkg-config=0.29 +- pre-commit +- pybind11-stubgen=0.10 +- pytest +- pytest-asyncio +- pytest-timeout +- python-graphviz +- python=3.10 +- scikit-build=0.17 +- ucx=1.15 +- yapf +name: all_cuda-121_arch-x86_64 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml new file mode 100644 index 000000000..1e672e8b0 --- /dev/null +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -0,0 +1,55 @@ +# This file is generated by `rapids-dependency-file-generator`. +# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. 
+channels: +- conda-forge +- rapidsai +- rapidsai-nightly +- nvidia +dependencies: +- bash-completion +- benchmark=1.8.3 +- boost-cpp=1.84 +- ccache +- clang-tools=16 +- clang=16 +- clangdev=16 +- clangxx=16 +- cmake=3.27 +- codecov=2.1 +- cuda-cudart-dev=12.5 +- cuda-nvcc +- cuda-nvml-dev=12.5 +- cuda-nvrtc-dev=12.5 +- cuda-version=12.5 +- cxx-compiler +- doxygen=1.9.2 +- flake8 +- gcovr=5.2 +- gdb +- glog>=0.7.1,<0.8 +- gtest=1.14 +- gxx=12.1 +- include-what-you-use=0.20 +- libclang-cpp=16 +- libclang=16 +- libgrpc=1.62.2 +- libhwloc=2.9.2 +- librmm=24.10 +- libxml2=2.11.6 +- llvmdev=16 +- ninja=1.11 +- nlohmann_json=3.11 +- numactl=2.0.18 +- numpy=1.24 +- pkg-config=0.29 +- pre-commit +- pybind11-stubgen=0.10 +- pytest +- pytest-asyncio +- pytest-timeout +- python-graphviz +- python=3.10 +- scikit-build=0.17 +- ucx=1.15 +- yapf +name: all_cuda-125_arch-x86_64 diff --git a/conda/environments/ci_cuda-121_arch-x86_64.yaml b/conda/environments/ci_cuda-121_arch-x86_64.yaml new file mode 100644 index 000000000..6d5ccef0a --- /dev/null +++ b/conda/environments/ci_cuda-121_arch-x86_64.yaml @@ -0,0 +1,44 @@ +# This file is generated by `rapids-dependency-file-generator`. +# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. 
+channels: +- conda-forge +- rapidsai +- rapidsai-nightly +- nvidia +dependencies: +- benchmark=1.8.3 +- boost-cpp=1.84 +- ccache +- cmake=3.27 +- codecov=2.1 +- cuda-cudart-dev=12.1 +- cuda-nvcc +- cuda-nvml-dev=12.1 +- cuda-nvrtc-dev=12.1 +- cuda-tools=12.1 +- cuda-version=12.1 +- cxx-compiler +- doxygen=1.9.2 +- gcovr=5.2 +- glog=0.6 +- gtest=1.14 +- gxx=11.2 +- include-what-you-use=0.20 +- libgrpc=1.59 +- libhwloc=2.9.2 +- librmm=24.02 +- libxml2=2.11.6 +- ninja=1.11 +- nlohmann_json=3.11 +- numactl-libs-cos7-x86_64 +- pkg-config=0.29 +- pre-commit +- pybind11-stubgen=0.10 +- pytest +- pytest-asyncio +- pytest-timeout +- python-graphviz +- python=3.10 +- scikit-build=0.17 +- ucx=1.15 +name: ci_cuda-121_arch-x86_64 diff --git a/conda/environments/ci_cuda-125_arch-x86_64.yaml b/conda/environments/ci_cuda-125_arch-x86_64.yaml new file mode 100644 index 000000000..78cf2d601 --- /dev/null +++ b/conda/environments/ci_cuda-125_arch-x86_64.yaml @@ -0,0 +1,43 @@ +# This file is generated by `rapids-dependency-file-generator`. +# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. 
+channels: +- conda-forge +- rapidsai +- rapidsai-nightly +- nvidia +dependencies: +- benchmark=1.8.3 +- boost-cpp=1.84 +- ccache +- cmake=3.27 +- codecov=2.1 +- cuda-cudart-dev=12.5 +- cuda-nvcc +- cuda-nvml-dev=12.5 +- cuda-nvrtc-dev=12.5 +- cuda-version=12.5 +- cxx-compiler +- doxygen=1.9.2 +- gcovr=5.2 +- glog>=0.7.1,<0.8 +- gtest=1.14 +- gxx=12.1 +- include-what-you-use=0.20 +- libgrpc=1.62.2 +- libhwloc=2.9.2 +- librmm=24.10 +- libxml2=2.11.6 +- ninja=1.11 +- nlohmann_json=3.11 +- numactl=2.0.18 +- pkg-config=0.29 +- pre-commit +- pybind11-stubgen=0.10 +- pytest +- pytest-asyncio +- pytest-timeout +- python-graphviz +- python=3.10 +- scikit-build=0.17 +- ucx=1.15 +name: ci_cuda-125_arch-x86_64 diff --git a/cpp/mrc/CMakeLists.txt b/cpp/mrc/CMakeLists.txt index a0af3cbcd..88ac29a70 100644 --- a/cpp/mrc/CMakeLists.txt +++ b/cpp/mrc/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,6 +38,7 @@ add_library(libmrc src/internal/data_plane/server.cpp src/internal/executor/executor_definition.cpp src/internal/grpc/progress_engine.cpp + src/internal/grpc/promise_handler.cpp src/internal/grpc/server.cpp src/internal/memory/device_resources.cpp src/internal/memory/host_resources.cpp @@ -114,13 +115,17 @@ add_library(libmrc src/public/core/logging.cpp src/public/core/thread.cpp src/public/coroutines/event.cpp + src/public/coroutines/io_scheduler.cpp src/public/coroutines/sync_wait.cpp + src/public/coroutines/task_container.cpp + src/public/coroutines/test_scheduler.cpp src/public/coroutines/thread_local_context.cpp src/public/coroutines/thread_pool.cpp src/public/cuda/device_guard.cpp src/public/cuda/sync.cpp src/public/edge/edge_adapter_registry.cpp src/public/edge/edge_builder.cpp + src/public/exceptions/exception_catcher.cpp src/public/manifold/manifold.cpp src/public/memory/buffer_view.cpp src/public/memory/codable/buffer.cpp @@ -149,6 +154,7 @@ add_library(libmrc src/public/runnable/types.cpp src/public/runtime/remote_descriptor.cpp src/public/utils/bytes_to_string.cpp + src/public/utils/string_utils.cpp src/public/utils/thread_utils.cpp src/public/utils/type_utils.cpp ) @@ -157,19 +163,18 @@ add_library(${PROJECT_NAME}::libmrc ALIAS libmrc) target_link_libraries(libmrc PUBLIC - mrc_protos - mrc_architect_protos - rmm::rmm - CUDA::cudart - rxcpp::rxcpp - glog::glog - libcudacxx::libcudacxx - Boost::fiber Boost::context + Boost::fiber + CUDA::cudart glog::glog - gRPC::grpc++ - gRPC::grpc gRPC::gpr + gRPC::grpc + gRPC::grpc++ + libcudacxx::libcudacxx + mrc_architect_protos + mrc_protos + rmm::rmm + rxcpp::rxcpp PRIVATE hwloc::hwloc prometheus-cpp::core # private in MR !199 @@ -191,7 +196,7 @@ target_compile_definitions(libmrc $<$:MRC_ENABLE_BENCHMARKING> ) -if (MRC_ENABLE_CODECOV) +if(MRC_ENABLE_CODECOV) target_compile_definitions(libmrc 
INTERFACE "MRC_CODECOV_ENABLED") endif() @@ -201,7 +206,6 @@ set_target_properties(libmrc PROPERTIES OUTPUT_NAME ${PROJECT_NAME}) # ################################################################################################## # - install targets -------------------------------------------------------------------------------- - rapids_cmake_install_lib_dir(lib_dir) include(CPack) include(GNUInstallDirs) @@ -209,7 +213,7 @@ include(GNUInstallDirs) install( TARGETS libmrc DESTINATION ${lib_dir} - EXPORT ${PROJECT_NAME}-core-exports + EXPORT ${PROJECT_NAME}-exports COMPONENT Core ) @@ -221,7 +225,6 @@ install( # ################################################################################################## # - subdirectories --------------------------------------------------------------------------------- - if(MRC_BUILD_TESTS) add_subdirectory(tests) @@ -234,7 +237,6 @@ endif() # ################################################################################################## # - install export --------------------------------------------------------------------------------- - set(doc_string [=[ Provide targets for mrc. 
@@ -247,7 +249,7 @@ set(rapids_project_version_compat SameMinorVersion) # Need to explicitly set VERSION ${PROJECT_VERSION} here since rapids_cmake gets # confused with the `RAPIDS_VERSION` variable we use rapids_export(INSTALL ${PROJECT_NAME} - EXPORT_SET ${PROJECT_NAME}-core-exports + EXPORT_SET ${PROJECT_NAME}-exports GLOBAL_TARGETS libmrc VERSION ${PROJECT_VERSION} NAMESPACE mrc:: @@ -258,7 +260,7 @@ rapids_export(INSTALL ${PROJECT_NAME} # ################################################################################################## # - build export ---------------------------------------------------------------------------------- rapids_export(BUILD ${PROJECT_NAME} - EXPORT_SET ${PROJECT_NAME}-core-exports + EXPORT_SET ${PROJECT_NAME}-exports GLOBAL_TARGETS libmrc VERSION ${PROJECT_VERSION} LANGUAGES C CXX CUDA diff --git a/cpp/mrc/benchmarks/bench_baselines.cpp b/cpp/mrc/benchmarks/bench_baselines.cpp index a57fff83f..6d9b737b9 100644 --- a/cpp/mrc/benchmarks/bench_baselines.cpp +++ b/cpp/mrc/benchmarks/bench_baselines.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,7 +19,6 @@ #include "mrc/benchmarking/util.hpp" #include -#include #include #include @@ -27,7 +26,6 @@ #include #include #include -#include #include #include diff --git a/cpp/mrc/benchmarks/bench_coroutines.cpp b/cpp/mrc/benchmarks/bench_coroutines.cpp index 443806ccc..b6f1b22ed 100644 --- a/cpp/mrc/benchmarks/bench_coroutines.cpp +++ b/cpp/mrc/benchmarks/bench_coroutines.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,6 +24,7 @@ #include #include +#include #include #include diff --git a/cpp/mrc/benchmarks/bench_fibers.cpp b/cpp/mrc/benchmarks/bench_fibers.cpp index bd75ae526..09b176ab1 100644 --- a/cpp/mrc/benchmarks/bench_fibers.cpp +++ b/cpp/mrc/benchmarks/bench_fibers.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,6 +21,8 @@ #include #include +#include + static void boost_fibers_create_single_task_and_sync_post(benchmark::State& state) { // warmup diff --git a/cpp/mrc/benchmarks/bench_segment.cpp b/cpp/mrc/benchmarks/bench_segment.cpp index 2ddeed4e2..75c1e1ea1 100644 --- a/cpp/mrc/benchmarks/bench_segment.cpp +++ b/cpp/mrc/benchmarks/bench_segment.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,9 +18,6 @@ #include "mrc/benchmarking/segment_watcher.hpp" #include "mrc/benchmarking/tracer.hpp" #include "mrc/benchmarking/util.hpp" -#include "mrc/node/rx_node.hpp" -#include "mrc/node/rx_sink.hpp" -#include "mrc/node/rx_source.hpp" #include "mrc/pipeline/executor.hpp" #include "mrc/pipeline/pipeline.hpp" #include "mrc/segment/builder.hpp" // IWYU pragma: keep @@ -33,7 +30,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/mrc/include/mrc/channel/status.hpp b/cpp/mrc/include/mrc/channel/status.hpp index 91f6b2800..4c9734a69 100644 --- a/cpp/mrc/include/mrc/channel/status.hpp +++ b/cpp/mrc/include/mrc/channel/status.hpp @@ -17,6 +17,8 @@ #pragma once +#include + namespace mrc::channel { enum class Status @@ -29,4 +31,25 @@ enum class Status error }; +static inline std::ostream& operator<<(std::ostream& os, const Status& s) +{ + switch (s) + { + case Status::success: + return os << "success"; + case Status::empty: + return os << "empty"; + case Status::full: + return os << "full"; + case Status::closed: + return os << "closed"; + case Status::timeout: + return os << "timeout"; + case Status::error: + return os << "error"; + default: + throw std::logic_error("Unsupported channel::Status enum. 
Was a new value added recently?"); + } } + +} // namespace mrc::channel diff --git a/cpp/mrc/include/mrc/core/userspace_threads.hpp b/cpp/mrc/include/mrc/core/userspace_threads.hpp index 19e36c9c2..273b04b3a 100644 --- a/cpp/mrc/include/mrc/core/userspace_threads.hpp +++ b/cpp/mrc/include/mrc/core/userspace_threads.hpp @@ -19,44 +19,51 @@ #include -namespace mrc { +namespace mrc::userspace_threads { -struct userspace_threads // NOLINT -{ - using mutex = boost::fibers::mutex; // NOLINT +// Suppress naming conventions in this file to allow matching std and boost libraries +// NOLINTBEGIN(readability-identifier-naming) + +using mutex = boost::fibers::mutex; + +using recursive_mutex = boost::fibers::recursive_mutex; - using cv = boost::fibers::condition_variable; // NOLINT +using cv = boost::fibers::condition_variable; - using launch = boost::fibers::launch; // NOLINT +using cv_any = boost::fibers::condition_variable_any; - template - using promise = boost::fibers::promise; // NOLINT +using launch = boost::fibers::launch; - template - using future = boost::fibers::future; // NOLINT +template +using promise = boost::fibers::promise; - template - using shared_future = boost::fibers::shared_future; // NOLINT +template +using future = boost::fibers::future; - template // NOLINT - using packaged_task = boost::fibers::packaged_task; // NOLINT +template +using shared_future = boost::fibers::shared_future; - template // NOLINT - static auto async(Function&& f, Args&&... args) - { - return boost::fibers::async(f, std::forward(args)...); - } +template +using packaged_task = boost::fibers::packaged_task; + +template +static auto async(Function&& f, Args&&... 
args) +{ + return boost::fibers::async(f, std::forward(args)...); +} + +template +static void sleep_for(std::chrono::duration const& timeout_duration) +{ + boost::this_fiber::sleep_for(timeout_duration); +} + +template +static void sleep_until(std::chrono::time_point const& sleep_time_point) +{ + boost::this_fiber::sleep_until(sleep_time_point); +} - template // NOLINT - static void sleep_for(std::chrono::duration const& timeout_duration) - { - boost::this_fiber::sleep_for(timeout_duration); - } +// NOLINTEND(readability-identifier-naming) - template // NOLINT - static void sleep_until(std::chrono::time_point const& sleep_time_point) - { - boost::this_fiber::sleep_until(sleep_time_point); - } -}; -} // namespace mrc +} // namespace mrc::userspace_threads diff --git a/cpp/mrc/include/mrc/core/utils.hpp b/cpp/mrc/include/mrc/core/utils.hpp index 84e2f8e06..78410149a 100644 --- a/cpp/mrc/include/mrc/core/utils.hpp +++ b/cpp/mrc/include/mrc/core/utils.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,6 +23,7 @@ #include #include +#include #include #include #include @@ -60,9 +61,12 @@ std::set extract_keys(const std::map& stdmap) class Unwinder { public: - explicit Unwinder(std::function unwind_fn) : m_unwind_fn(std::move(unwind_fn)) {} + explicit Unwinder(std::function unwind_fn) : + m_unwind_fn(std::move(unwind_fn)), + m_ctor_exception_count(std::uncaught_exceptions()) + {} - ~Unwinder() + ~Unwinder() noexcept(false) { if (!!m_unwind_fn) { @@ -71,8 +75,14 @@ class Unwinder m_unwind_fn(); } catch (...) 
{ - LOG(ERROR) << "Fatal error during unwinder function"; - std::terminate(); + if (std::uncaught_exceptions() > m_ctor_exception_count) + { + LOG(ERROR) << "Error occurred during unwinder function, but another exception is active."; + std::terminate(); + } + + LOG(ERROR) << "Error occurred during unwinder function. Rethrowing"; + throw; } } } @@ -92,6 +102,9 @@ class Unwinder } private: + // Stores the number of active exceptions during creation. If the number of active exceptions during destruction is + // greater, we do not throw and log error and terminate + int m_ctor_exception_count; std::function m_unwind_fn; }; diff --git a/cpp/mrc/include/mrc/coroutines/async_generator.hpp b/cpp/mrc/include/mrc/coroutines/async_generator.hpp new file mode 100644 index 000000000..22036c2e7 --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/async_generator.hpp @@ -0,0 +1,399 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Original Source: https://github.com/lewissbaker/cppcoro + * Original License: MIT; included below + */ + +// Copyright 2017 Lewis Baker + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is furnished +// to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#pragma once + +#include "mrc/utils/macros.hpp" + +#include + +#include +#include +#include +#include + +namespace mrc::coroutines { + +template +class AsyncGenerator; + +namespace detail { + +template +class AsyncGeneratorIterator; +class AsyncGeneratorYieldOperation; +class AsyncGeneratorAdvanceOperation; + +class AsyncGeneratorPromiseBase +{ + public: + AsyncGeneratorPromiseBase() noexcept : m_exception(nullptr) {} + + DELETE_COPYABILITY(AsyncGeneratorPromiseBase) + + constexpr static std::suspend_always initial_suspend() noexcept + { + return {}; + } + + AsyncGeneratorYieldOperation final_suspend() noexcept; + + void unhandled_exception() noexcept + { + m_exception = std::current_exception(); + } + + auto return_void() noexcept -> void {} + + auto finished() const noexcept -> bool + { + return m_value == nullptr; + } + + auto rethrow_on_unhandled_exception() -> void + { + if (m_exception) + { + std::rethrow_exception(m_exception); + } + } + + protected: + AsyncGeneratorYieldOperation internal_yield_value() noexcept; + void* m_value{nullptr}; + + private: + std::exception_ptr m_exception; + std::coroutine_handle<> m_consumer; + + friend class AsyncGeneratorYieldOperation; + friend class AsyncGeneratorAdvanceOperation; +}; + +class AsyncGeneratorYieldOperation final +{ + public: + AsyncGeneratorYieldOperation(std::coroutine_handle<> consumer) noexcept : m_consumer(consumer) {} + + constexpr static bool await_ready() noexcept + { + return false; + } + + std::coroutine_handle<> await_suspend([[maybe_unused]] std::coroutine_handle<> producer) const noexcept + { + return m_consumer; + } + + constexpr static void await_resume() noexcept {} + + private: + std::coroutine_handle<> m_consumer; +}; + +inline AsyncGeneratorYieldOperation AsyncGeneratorPromiseBase::final_suspend() noexcept +{ + m_value = nullptr; + return internal_yield_value(); +} + +inline AsyncGeneratorYieldOperation AsyncGeneratorPromiseBase::internal_yield_value() noexcept +{ + return 
AsyncGeneratorYieldOperation{m_consumer}; +} + +class AsyncGeneratorAdvanceOperation +{ + protected: + AsyncGeneratorAdvanceOperation(std::nullptr_t) noexcept : m_promise(nullptr), m_producer(nullptr) {} + + AsyncGeneratorAdvanceOperation(AsyncGeneratorPromiseBase& promise, std::coroutine_handle<> producer) noexcept : + m_promise(std::addressof(promise)), + m_producer(producer) + {} + + public: + constexpr static bool await_ready() noexcept + { + return false; + } + + std::coroutine_handle<> await_suspend(std::coroutine_handle<> consumer) noexcept + { + m_promise->m_consumer = consumer; + return m_producer; + } + + protected: + AsyncGeneratorPromiseBase* m_promise; + std::coroutine_handle<> m_producer; +}; + +template +class AsyncGeneratorPromise final : public AsyncGeneratorPromiseBase +{ + using value_t = std::remove_reference_t; + using reference_t = std::conditional_t, T, T&>; + using pointer_t = value_t*; + + public: + AsyncGeneratorPromise() noexcept = default; + + AsyncGenerator get_return_object() noexcept; + + template ::value, int> = 0> + auto yield_value(value_t& value) noexcept -> AsyncGeneratorYieldOperation + { + m_value = std::addressof(value); + return internal_yield_value(); + } + + auto yield_value(value_t&& value) noexcept -> AsyncGeneratorYieldOperation + { + m_value = std::addressof(value); + return internal_yield_value(); + } + + auto value() const noexcept -> reference_t + { + return *static_cast(m_value); + } +}; + +template +class AsyncGeneratorIncrementOperation final : public AsyncGeneratorAdvanceOperation +{ + public: + AsyncGeneratorIncrementOperation(AsyncGeneratorIterator& iterator) noexcept : + AsyncGeneratorAdvanceOperation(iterator.m_coroutine.promise(), iterator.m_coroutine), + m_iterator(iterator) + {} + + AsyncGeneratorIterator& await_resume(); + + private: + AsyncGeneratorIterator& m_iterator; +}; + +struct AsyncGeneratorSentinel +{}; + +template +class AsyncGeneratorIterator final +{ + using promise_t = AsyncGeneratorPromise; 
+ using handle_t = std::coroutine_handle; + + public: + using iterator_category = std::input_iterator_tag; // NOLINT + // Not sure what type should be used for difference_type as we don't + // allow calculating difference between two iterators. + using difference_t = std::ptrdiff_t; + using value_t = std::remove_reference_t; + using reference = std::add_lvalue_reference_t; // NOLINT + using pointer = std::add_pointer_t; // NOLINT + + AsyncGeneratorIterator(std::nullptr_t) noexcept : m_coroutine(nullptr) {} + + AsyncGeneratorIterator(handle_t coroutine) noexcept : m_coroutine(coroutine) {} + + AsyncGeneratorIncrementOperation operator++() noexcept + { + return AsyncGeneratorIncrementOperation{*this}; + } + + reference operator*() const noexcept + { + return m_coroutine.promise().value(); + } + + bool operator==(const AsyncGeneratorIterator& other) const noexcept + { + return m_coroutine == other.m_coroutine; + } + + bool operator!=(const AsyncGeneratorIterator& other) const noexcept + { + return !(*this == other); + } + + operator bool() const noexcept + { + return m_coroutine && !m_coroutine.promise().finished(); + } + + private: + friend class AsyncGeneratorIncrementOperation; + + handle_t m_coroutine; +}; + +template +inline AsyncGeneratorIterator& AsyncGeneratorIncrementOperation::await_resume() +{ + if (m_promise->finished()) + { + // Update iterator to end() + m_iterator = AsyncGeneratorIterator{nullptr}; + m_promise->rethrow_on_unhandled_exception(); + } + + return m_iterator; +} + +template +class AsyncGeneratorBeginOperation final : public AsyncGeneratorAdvanceOperation +{ + using promise_t = AsyncGeneratorPromise; + using handle_t = std::coroutine_handle; + + public: + AsyncGeneratorBeginOperation(std::nullptr_t) noexcept : AsyncGeneratorAdvanceOperation(nullptr) {} + + AsyncGeneratorBeginOperation(handle_t producer) noexcept : + AsyncGeneratorAdvanceOperation(producer.promise(), producer) + {} + + bool await_ready() const noexcept + { + return m_promise 
== nullptr || AsyncGeneratorAdvanceOperation::await_ready(); + } + + AsyncGeneratorIterator await_resume() + { + if (m_promise == nullptr) + { + // Called begin() on the empty generator. + return AsyncGeneratorIterator{nullptr}; + } + + if (m_promise->finished()) + { + // Completed without yielding any values. + m_promise->rethrow_on_unhandled_exception(); + return AsyncGeneratorIterator{nullptr}; + } + + return AsyncGeneratorIterator{handle_t::from_promise(*static_cast(m_promise))}; + } +}; + +} // namespace detail + +template +class [[nodiscard]] AsyncGenerator +{ + public: + // There must be a type called `promise_type` for coroutines to work. Skil linting + using promise_type = detail::AsyncGeneratorPromise; // NOLINT(readability-identifier-naming) + using iterator = detail::AsyncGeneratorIterator; // NOLINT(readability-identifier-naming) + + AsyncGenerator() noexcept : m_coroutine(nullptr) {} + + explicit AsyncGenerator(promise_type& promise) noexcept : + m_coroutine(std::coroutine_handle::from_promise(promise)) + {} + + AsyncGenerator(AsyncGenerator&& other) noexcept : m_coroutine(other.m_coroutine) + { + other.m_coroutine = nullptr; + } + + ~AsyncGenerator() + { + if (m_coroutine) + { + m_coroutine.destroy(); + } + } + + AsyncGenerator& operator=(AsyncGenerator&& other) noexcept + { + AsyncGenerator temp(std::move(other)); + swap(temp); + return *this; + } + + AsyncGenerator(const AsyncGenerator&) = delete; + AsyncGenerator& operator=(const AsyncGenerator&) = delete; + + auto begin() noexcept + { + if (!m_coroutine) + { + return detail::AsyncGeneratorBeginOperation{nullptr}; + } + + return detail::AsyncGeneratorBeginOperation{m_coroutine}; + } + + auto end() noexcept + { + return iterator{nullptr}; + } + + void swap(AsyncGenerator& other) noexcept + { + using std::swap; + swap(m_coroutine, other.m_coroutine); + } + + private: + std::coroutine_handle m_coroutine; +}; + +template +void swap(AsyncGenerator& a, AsyncGenerator& b) noexcept +{ + a.swap(b); +} + 
+namespace detail { +template +AsyncGenerator AsyncGeneratorPromise::get_return_object() noexcept +{ + return AsyncGenerator{*this}; +} + +} // namespace detail + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/closable_ring_buffer.hpp b/cpp/mrc/include/mrc/coroutines/closable_ring_buffer.hpp new file mode 100644 index 000000000..386dd7d32 --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/closable_ring_buffer.hpp @@ -0,0 +1,703 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Original Source: https://github.com/jbaldwin/libcoro + * Original License: Apache License, Version 2.0; included below + */ + +/** + * Copyright 2021 Josh Baldwin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "mrc/core/expected.hpp" +#include "mrc/coroutines/schedule_policy.hpp" +#include "mrc/coroutines/thread_local_context.hpp" +#include "mrc/coroutines/thread_pool.hpp" + +#include + +#include +#include +#include +#include +#include + +namespace mrc::coroutines { + +enum class RingBufferOpStatus +{ + Success, + Stopped, +}; + +/** + * @tparam ElementT The type of element the ring buffer will store. Note that this type should be + * cheap to move if possible as it is moved into and out of the buffer upon write and + * read operations. + */ +template +class ClosableRingBuffer +{ + using mutex_type = std::mutex; + + public: + struct Options + { + // capacity of ring buffer + std::size_t capacity{8}; + + // when there is an awaiting reader, the active execution context of the next writer will resume the awaiting + // reader, the schedule_policy_t dictates how that is accomplished. + SchedulePolicy reader_policy{SchedulePolicy::Reschedule}; + + // when there is an awaiting writer, the active execution context of the next reader will resume the awaiting + // writer, the producder_policy_t dictates how that is accomplished. + SchedulePolicy writer_policy{SchedulePolicy::Reschedule}; + + // when there is an awaiting writer, the active execution context of the next reader will resume the awaiting + // writer, the producder_policy_t dictates how that is accomplished. + SchedulePolicy completed_policy{SchedulePolicy::Reschedule}; + }; + + /** + * @throws std::runtime_error If `num_elements` == 0. 
+ */ + explicit ClosableRingBuffer(Options opts = {}) : + m_elements(opts.capacity), // elements needs to be extended from just holding ElementT to include a TraceContext + m_num_elements(opts.capacity), + m_writer_policy(opts.writer_policy), + m_reader_policy(opts.reader_policy), + m_completed_policy(opts.completed_policy) + { + if (m_num_elements == 0) + { + throw std::runtime_error{"num_elements cannot be zero"}; + } + } + + ~ClosableRingBuffer() + { + // Wake up anyone still using the ring buffer. + notify_waiters(); + } + + ClosableRingBuffer(const ClosableRingBuffer&) = delete; + ClosableRingBuffer(ClosableRingBuffer&&) = delete; + + auto operator=(const ClosableRingBuffer&) noexcept -> ClosableRingBuffer& = delete; + auto operator=(ClosableRingBuffer&&) noexcept -> ClosableRingBuffer& = delete; + + struct Operation + { + virtual void resume() = 0; + }; + + struct WriteOperation : ThreadLocalContext, Operation + { + WriteOperation(ClosableRingBuffer& rb, ElementT e) : + m_rb(rb), + m_e(std::move(e)), + m_policy(m_rb.m_writer_policy) + {} + + auto await_ready() noexcept -> bool + { + // return immediate if the buffer is closed + if (m_rb.m_stopped.load(std::memory_order::acquire)) + { + m_stopped = true; + return true; + } + + // start a span to time the write - this would include time suspended if the buffer is full + // m_write_span->AddEvent("start_on", {{"thead.id", mrc::this_thread::get_id()}}); + + // the lock is owned by the operation, not scoped to the await_ready function + m_lock = std::unique_lock(m_rb.m_mutex); + return m_rb.try_write_locked(m_lock, m_e); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + // m_lock was acquired as part of await_ready; await_suspend is responsible for releasing the lock + auto lock = std::move(m_lock); // use raii + + ThreadLocalContext::suspend_thread_local_context(); + + m_awaiting_coroutine = awaiting_coroutine; + m_next = m_rb.m_write_waiters; + m_rb.m_write_waiters = 
this; + return true; + } + + /** + * @return write_result + */ + auto await_resume() -> RingBufferOpStatus + { + ThreadLocalContext::resume_thread_local_context(); + return (!m_stopped ? RingBufferOpStatus::Success : RingBufferOpStatus::Stopped); + } + + WriteOperation& use_scheduling_policy(SchedulePolicy policy) & + { + m_policy = policy; + return *this; + } + + WriteOperation use_scheduling_policy(SchedulePolicy policy) && + { + m_policy = policy; + return std::move(*this); + } + + WriteOperation& resume_immediately() & + { + m_policy = SchedulePolicy::Immediate; + return *this; + } + + WriteOperation resume_immediately() && + { + m_policy = SchedulePolicy::Immediate; + return std::move(*this); + } + + WriteOperation& resume_on(ThreadPool* thread_pool) & + { + m_policy = SchedulePolicy::Reschedule; + set_resume_on_thread_pool(thread_pool); + return *this; + } + + WriteOperation resume_on(ThreadPool* thread_pool) && + { + m_policy = SchedulePolicy::Reschedule; + set_resume_on_thread_pool(thread_pool); + return std::move(*this); + } + + private: + friend ClosableRingBuffer; + + void resume() + { + if (m_policy == SchedulePolicy::Immediate) + { + set_resume_on_thread_pool(nullptr); + } + resume_coroutine(m_awaiting_coroutine); + } + + /// The lock is acquired in await_ready; if ready it is release; otherwise, await_suspend should release it + std::unique_lock m_lock; + /// The ring buffer the element is being written into. + ClosableRingBuffer& m_rb; + /// If the operation needs to suspend, the coroutine to resume when the element can be written. + std::coroutine_handle<> m_awaiting_coroutine; + /// Linked list of write operations that are awaiting to write their element. + WriteOperation* m_next{nullptr}; + /// The element this write operation is producing into the ring buffer. + ElementT m_e; + /// Was the operation stopped? 
+ bool m_stopped{false}; + /// Scheduling Policy - default provided by the ClosableRingBuffer, but can be overrided owner of the Awaiter + SchedulePolicy m_policy; + /// Span to measure the duration the writer spent writting data + // trace::Handle m_write_span{nullptr}; + }; + + struct ReadOperation : ThreadLocalContext, Operation + { + explicit ReadOperation(ClosableRingBuffer& rb) : m_rb(rb), m_policy(m_rb.m_reader_policy) {} + + auto await_ready() noexcept -> bool + { + // the lock is owned by the operation, not scoped to the await_ready function + m_lock = std::unique_lock(m_rb.m_mutex); + // m_read_span->AddEvent("start_on", {{"thead.id", mrc::this_thread::get_id()}}); + return m_rb.try_read_locked(m_lock, this); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + // m_lock was acquired as part of await_ready; await_suspend is responsible for releasing the lock + auto lock = std::move(m_lock); + + // the buffer is empty; don't suspend if the stop signal has been set. + if (m_rb.m_stopped.load(std::memory_order::acquire)) + { + m_stopped = true; + return false; + } + + // m_read_span->AddEvent("buffer_empty"); + ThreadLocalContext::suspend_thread_local_context(); + + m_awaiting_coroutine = awaiting_coroutine; + m_next = m_rb.m_read_waiters; + m_rb.m_read_waiters = this; + return true; + } + + /** + * @return The consumed element or std::nullopt if the read has failed. 
+ */ + auto await_resume() -> mrc::expected + { + ThreadLocalContext::resume_thread_local_context(); + + if (m_stopped) + { + return mrc::unexpected(RingBufferOpStatus::Stopped); + } + + return std::move(m_e); + } + + ReadOperation& use_scheduling_policy(SchedulePolicy policy) + { + m_policy = policy; + return *this; + } + + ReadOperation& resume_immediately() + { + m_policy = SchedulePolicy::Immediate; + return *this; + } + + ReadOperation& resume_on(ThreadPool* thread_pool) + { + m_policy = SchedulePolicy::Reschedule; + set_resume_on_thread_pool(thread_pool); + return *this; + } + + private: + friend ClosableRingBuffer; + + void resume() + { + if (m_policy == SchedulePolicy::Immediate) + { + set_resume_on_thread_pool(nullptr); + } + resume_coroutine(m_awaiting_coroutine); + } + + /// The lock is acquired in await_ready; if ready it is release; otherwise, await_suspend should release it + std::unique_lock m_lock; + /// The ring buffer to read an element from. + ClosableRingBuffer& m_rb; + /// If the operation needs to suspend, the coroutine to resume when the element can be consumed. + std::coroutine_handle<> m_awaiting_coroutine; + /// Linked list of read operations that are awaiting to read an element. + ReadOperation* m_next{nullptr}; + /// The element this read operation will read. + ElementT m_e; + /// Was the operation stopped? 
+ bool m_stopped{false}; + /// Scheduling Policy - default provided by the ClosableRingBuffer, but can be overrided owner of the Awaiter + SchedulePolicy m_policy; + /// Span measure time awaiting on reading data + // trace::Handle m_read_span; + }; + + struct CompletedOperation : ThreadLocalContext, Operation + { + explicit CompletedOperation(ClosableRingBuffer& rb) : m_rb(rb), m_policy(m_rb.m_completed_policy) {} + + auto await_ready() noexcept -> bool + { + // the lock is owned by the operation, not scoped to the await_ready function + m_lock = std::unique_lock(m_rb.m_mutex); + // m_read_span->AddEvent("start_on", {{"thead.id", mrc::this_thread::get_id()}}); + return m_rb.try_completed_locked(m_lock, this); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + // m_lock was acquired as part of await_ready; await_suspend is responsible for releasing the lock + auto lock = std::move(m_lock); + + // m_read_span->AddEvent("buffer_empty"); + ThreadLocalContext::suspend_thread_local_context(); + + m_awaiting_coroutine = awaiting_coroutine; + m_next = m_rb.m_completed_waiters; + m_rb.m_completed_waiters = this; + return true; + } + + /** + * @return The consumed element or std::nullopt if the read has failed. 
+ */ + auto await_resume() + { + ThreadLocalContext::resume_thread_local_context(); + } + + ReadOperation& use_scheduling_policy(SchedulePolicy policy) + { + m_policy = policy; + return *this; + } + + ReadOperation& resume_immediately() + { + m_policy = SchedulePolicy::Immediate; + return *this; + } + + ReadOperation& resume_on(ThreadPool* thread_pool) + { + m_policy = SchedulePolicy::Reschedule; + set_resume_on_thread_pool(thread_pool); + return *this; + } + + private: + friend ClosableRingBuffer; + + void resume() + { + if (m_policy == SchedulePolicy::Immediate) + { + set_resume_on_thread_pool(nullptr); + } + resume_coroutine(m_awaiting_coroutine); + } + + /// The lock is acquired in await_ready; if ready it is release; otherwise, await_suspend should release it + std::unique_lock m_lock; + /// The ring buffer to read an element from. + ClosableRingBuffer& m_rb; + /// If the operation needs to suspend, the coroutine to resume when the element can be consumed. + std::coroutine_handle<> m_awaiting_coroutine; + /// Linked list of read operations that are awaiting to read an element. + CompletedOperation* m_next{nullptr}; + /// Was the operation stopped? + bool m_stopped{false}; + /// Scheduling Policy - default provided by the ClosableRingBuffer, but can be overrided owner of the Awaiter + SchedulePolicy m_policy; + /// Span measure time awaiting on reading data + // trace::Handle m_read_span; + }; + + /** + * Produces the given element into the ring buffer. This operation will suspend until a slot + * in the ring buffer becomes available. + * @param e The element to write. + */ + [[nodiscard]] auto write(ElementT e) -> WriteOperation + { + return WriteOperation{*this, std::move(e)}; + } + + /** + * Consumes an element from the ring buffer. This operation will suspend until an element in + * the ring buffer becomes available. 
+ */ + [[nodiscard]] auto read() -> ReadOperation + { + return ReadOperation{*this}; + } + + /** + * Blocks until `close()` has been called and all elements have been returned + */ + [[nodiscard]] auto completed() -> CompletedOperation + { + return CompletedOperation{*this}; + } + + void close() + { + // if there are awaiting readers, then we must wait them up and signal that the buffer is closed; + // otherwise, mark the buffer as closed and fail all new writes immediately. readers should be allowed + // to keep reading until the buffer is empty. when the buffer is empty, readers will fail to suspend and exit + // with a stopped status + + // Only wake up waiters once. + if (m_stopped.load(std::memory_order::acquire)) + { + return; + } + + std::unique_lock lk{m_mutex}; + m_stopped.exchange(true, std::memory_order::release); + + // the buffer is empty and no more items will be added + if (m_used == 0) + { + // there should be no awaiting writers + CHECK(m_write_waiters == nullptr); + + // signal all awaiting readers that the buffer is stopped + while (m_read_waiters != nullptr) + { + auto* to_resume = m_read_waiters; + to_resume->m_stopped = true; + m_read_waiters = m_read_waiters->m_next; + + lk.unlock(); + to_resume->resume(); + lk.lock(); + } + + // signal all awaiting completed that the buffer is completed + while (m_completed_waiters != nullptr) + { + auto* to_resume = m_completed_waiters; + to_resume->m_stopped = true; + m_completed_waiters = m_completed_waiters->m_next; + + lk.unlock(); + to_resume->resume(); + lk.lock(); + } + } + } + + bool is_closed() const noexcept + { + return m_stopped.load(std::memory_order::acquire); + } + + /** + * @return The current number of elements contained in the ring buffer. + */ + auto size() const -> size_t + { + std::atomic_thread_fence(std::memory_order::acquire); + return m_used; + } + + /** + * @return True if the ring buffer contains zero elements. 
+ */ + auto empty() const -> bool + { + return size() == 0; + } + + /** + * Wakes up all currently awaiting writers and readers. Their await_resume() function + * will return an expected read result that the ring buffer has stopped. + */ + auto notify_waiters() -> void + { + // Only wake up waiters once. + if (m_stopped.load(std::memory_order::acquire)) + { + return; + } + + std::unique_lock lk{m_mutex}; + m_stopped.exchange(true, std::memory_order::release); + + while (m_write_waiters != nullptr) + { + auto* to_resume = m_write_waiters; + to_resume->m_stopped = true; + m_write_waiters = m_write_waiters->m_next; + + lk.unlock(); + to_resume->resume(); + lk.lock(); + } + + while (m_read_waiters != nullptr) + { + auto* to_resume = m_read_waiters; + to_resume->m_stopped = true; + m_read_waiters = m_read_waiters->m_next; + + lk.unlock(); + to_resume->resume(); + lk.lock(); + } + + while (m_completed_waiters != nullptr) + { + auto* to_resume = m_completed_waiters; + to_resume->m_stopped = true; + m_completed_waiters = m_completed_waiters->m_next; + + lk.unlock(); + to_resume->resume(); + lk.lock(); + } + } + + private: + friend WriteOperation; + friend ReadOperation; + friend CompletedOperation; + + mutex_type m_mutex{}; + + std::vector m_elements; + const std::size_t m_num_elements; + const SchedulePolicy m_writer_policy; + const SchedulePolicy m_reader_policy; + const SchedulePolicy m_completed_policy; + + /// The current front pointer to an open slot if not full. + size_t m_front{0}; + /// The current back pointer to the oldest item in the buffer if not empty. + size_t m_back{0}; + /// The number of items in the ring buffer. + size_t m_used{0}; + + /// The LIFO list of write waiters - single writers will have order perserved + // Note: if there are multiple writers order can not be guaranteed, so no need for FIFO + WriteOperation* m_write_waiters{nullptr}; + /// The LIFO list of read watier. 
+ ReadOperation* m_read_waiters{nullptr}; + /// The LIFO list of completed watier. + CompletedOperation* m_completed_waiters{nullptr}; + + std::atomic m_stopped{false}; + + auto try_write_locked(std::unique_lock& lk, ElementT& e) -> bool + { + if (m_used == m_num_elements) + { + DCHECK(m_read_waiters == nullptr); + return false; + } + + // We will be able to write an element into the buffer. + m_elements[m_front] = std::move(e); + m_front = (m_front + 1) % m_num_elements; + ++m_used; + + ReadOperation* to_resume = nullptr; + + if (m_read_waiters != nullptr) + { + to_resume = m_read_waiters; + m_read_waiters = m_read_waiters->m_next; + + // Since the read operation suspended it needs to be provided an element to read. + to_resume->m_e = std::move(m_elements[m_back]); + m_back = (m_back + 1) % m_num_elements; + --m_used; // And we just consumed up another item. + } + + // After this point we will no longer be checking state objects on the buffer + lk.unlock(); + + if (to_resume != nullptr) + { + to_resume->resume(); + } + + return true; + } + + auto try_read_locked(std::unique_lock& lk, ReadOperation* op) -> bool + { + if (m_used == 0) + { + return false; + } + + // We will be successful in reading an element from the buffer. + op->m_e = std::move(m_elements[m_back]); + m_back = (m_back + 1) % m_num_elements; + --m_used; + + WriteOperation* writer_to_resume = nullptr; + + if (m_write_waiters != nullptr) + { + writer_to_resume = m_write_waiters; + m_write_waiters = m_write_waiters->m_next; + + // Since the write operation suspended it needs to be provided a slot to place its element. + m_elements[m_front] = std::move(writer_to_resume->m_e); + m_front = (m_front + 1) % m_num_elements; + ++m_used; // And we just written another item. + } + + CompletedOperation* completed_waiters = nullptr; + + // Check if we are stopped and there are no more elements in the buffer. 
+ if (m_used == 0 && m_stopped.load(std::memory_order::acquire)) + { + completed_waiters = m_completed_waiters; + m_completed_waiters = nullptr; + } + + // After this point we will no longer be checking state objects on the buffer + lk.unlock(); + + // Resume any writer + if (writer_to_resume != nullptr) + { + DCHECK(completed_waiters == nullptr) << "Logic error. Wrote value but count is 0"; + + writer_to_resume->resume(); + } + + // Resume completed if there are any + while (completed_waiters != nullptr) + { + completed_waiters->resume(); + + completed_waiters = completed_waiters->m_next; + } + + return true; + } + + auto try_completed_locked(std::unique_lock& lk, CompletedOperation* op) -> bool + { + // Condition is already met, no need to wait + if (!m_stopped.load(std::memory_order::acquire) || m_used >= 0) + { + return false; + } + + DCHECK(m_write_waiters == nullptr) << "Should not have any writers with a closed buffer"; + + // release lock + lk.unlock(); + + return true; + } +}; + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/detail/poll_info.hpp b/cpp/mrc/include/mrc/coroutines/detail/poll_info.hpp new file mode 100644 index 000000000..d1173fe90 --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/detail/poll_info.hpp @@ -0,0 +1,118 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Original Source: https://github.com/jbaldwin/libcoro + * Original License: Apache License, Version 2.0; included below + */ + +/** + * Copyright 2021 Josh Baldwin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mrc/coroutines/fd.hpp" +#include "mrc/coroutines/poll.hpp" +#include "mrc/coroutines/time.hpp" + +#include +#include +#include +#include + +namespace mrc::coroutines::detail { +/** + * Poll Info encapsulates everything about a poll operation for the event as well as its paired + * timeout. This is important since coroutines that are waiting on an event or timeout do not + * immediately execute, they are re-scheduled onto the thread pool, so its possible its pair + * event or timeout also triggers while the coroutine is still waiting to resume. This means that + * the first one to happen, the event itself or its timeout, needs to disable the other pair item + * prior to resuming the coroutine. + * + * Finally, its also important to note that the event and its paired timeout could happen during + * the same epoll_wait and possibly trigger the coroutine to start twice. Only one can win, so the + * first one processed sets m_processed to true and any subsequent events in the same epoll batch + * are effectively discarded. 
+ */ +struct PollInfo +{ + using timed_events_t = std::multimap; + + PollInfo() = default; + ~PollInfo() = default; + + PollInfo(const PollInfo&) = delete; + PollInfo(PollInfo&&) = delete; + auto operator=(const PollInfo&) -> PollInfo& = delete; + auto operator=(PollInfo&&) -> PollInfo& = delete; + + struct PollAwaiter + { + explicit PollAwaiter(PollInfo& pi) noexcept : m_pi(pi) {} + + static auto await_ready() noexcept -> bool + { + return false; + } + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> void + { + m_pi.m_awaiting_coroutine = awaiting_coroutine; + std::atomic_thread_fence(std::memory_order::release); + } + auto await_resume() const noexcept -> mrc::coroutines::PollStatus + { + return m_pi.m_poll_status; + } + + PollInfo& m_pi; + }; + + auto operator co_await() noexcept -> PollAwaiter + { + return PollAwaiter{*this}; + } + + /// The file descriptor being polled on. This is needed so that if the timeout occurs first then + /// the event loop can immediately disable the event within epoll. + fd_t m_fd{-1}; + /// The timeout's position in the timeout map. A poll() with no timeout or yield() this is empty. + /// This is needed so that if the event occurs first then the event loop can immediately disable + /// the timeout within epoll. + std::optional m_timer_pos{std::nullopt}; + /// The awaiting coroutine for this poll info to resume upon event or timeout. + std::coroutine_handle<> m_awaiting_coroutine; + /// The status of the poll operation. + mrc::coroutines::PollStatus m_poll_status{mrc::coroutines::PollStatus::error}; + /// Did the timeout and event trigger at the same time on the same epoll_wait call? + /// Once this is set to true all future events on this poll info are null and void. 
+ bool m_processed{false}; +}; + +} // namespace mrc::coroutines::detail diff --git a/cpp/mrc/include/mrc/coroutines/fd.hpp b/cpp/mrc/include/mrc/coroutines/fd.hpp new file mode 100644 index 000000000..86a5e1563 --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/fd.hpp @@ -0,0 +1,44 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Original Source: https://github.com/jbaldwin/libcoro + * Original License: Apache License, Version 2.0; included below + */ + +/** + * Copyright 2021 Josh Baldwin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#pragma once
+
+namespace mrc::coroutines {
+// File-descriptor type used by the coroutine I/O utilities (e.g. poll/io_scheduler).
+// NOTE(review): assumed to be a POSIX fd given the epoll-based scheduler -- confirm.
+using fd_t = int;
+
+}  // namespace mrc::coroutines
diff --git a/cpp/mrc/include/mrc/coroutines/io_scheduler.hpp b/cpp/mrc/include/mrc/coroutines/io_scheduler.hpp
new file mode 100644
index 000000000..0345a6c0c
--- /dev/null
+++ b/cpp/mrc/include/mrc/coroutines/io_scheduler.hpp
@@ -0,0 +1,424 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Original Source: https://github.com/jbaldwin/libcoro
+ * Original License: Apache License, Version 2.0; included below
+ */
+
+/**
+ * Copyright 2021 Josh Baldwin
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#pragma once + +#include "mrc/coroutines/detail/poll_info.hpp" +#include "mrc/coroutines/fd.hpp" +#include "mrc/coroutines/scheduler.hpp" +#include "mrc/coroutines/task.hpp" +#include "mrc/coroutines/thread_pool.hpp" +#include "mrc/coroutines/time.hpp" + +#ifdef LIBCORO_FEATURE_NETWORKING + #include "coro/net/socket.hpp" +#endif + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mrc::coroutines { +enum class PollOperation : uint64_t; +enum class PollStatus; + +class IoScheduler : public Scheduler +{ + private: + using timed_events_t = detail::PollInfo::timed_events_t; + + public: + static std::shared_ptr get_instance(); + + class schedule_operation; + + friend schedule_operation; + + enum class ThreadStrategy + { + /// Spawns a dedicated background thread for the scheduler to run on. + spawn, + /// Requires the user to call process_events() to drive the scheduler. + manual + }; + + enum class ExecutionStrategy + { + /// Tasks will be FIFO queued to be executed on a thread pool. This is better for tasks that + /// are long lived and will use lots of CPU because long lived tasks will block other i/o + /// operations while they complete. This strategy is generally better for lower latency + /// requirements at the cost of throughput. + process_tasks_on_thread_pool, + /// Tasks will be executed inline on the io scheduler thread. This is better for short tasks + /// that can be quickly processed and not block other i/o operations for very long. This + /// strategy is generally better for higher throughput at the cost of latency. + process_tasks_inline + }; + + struct Options + { + /// Should the io scheduler spawn a dedicated event processor? + ThreadStrategy thread_strategy{ThreadStrategy::spawn}; + /// If spawning a dedicated event processor a functor to call upon that thread starting. 
+ std::function on_io_thread_start_functor{nullptr}; + /// If spawning a dedicated event processor a functor to call upon that thread stopping. + std::function on_io_thread_stop_functor{nullptr}; + /// Thread pool options for the task processor threads. See thread pool for more details. + ThreadPool::Options pool{ + .thread_count = ((std::thread::hardware_concurrency() > 1) ? (std::thread::hardware_concurrency() - 1) : 1), + .on_thread_start_functor = nullptr, + .on_thread_stop_functor = nullptr}; + + /// If inline task processing is enabled then the io worker will resume tasks on its thread + /// rather than scheduling them to be picked up by the thread pool. + const ExecutionStrategy execution_strategy{ExecutionStrategy::process_tasks_on_thread_pool}; + }; + + explicit IoScheduler(Options opts = Options{ + .thread_strategy = ThreadStrategy::spawn, + .on_io_thread_start_functor = nullptr, + .on_io_thread_stop_functor = nullptr, + .pool = {.thread_count = ((std::thread::hardware_concurrency() > 1) + ? (std::thread::hardware_concurrency() - 1) + : 1), + .on_thread_start_functor = nullptr, + .on_thread_stop_functor = nullptr}, + .execution_strategy = ExecutionStrategy::process_tasks_on_thread_pool}); + + IoScheduler(const IoScheduler&) = delete; + IoScheduler(IoScheduler&&) = delete; + auto operator=(const IoScheduler&) -> IoScheduler& = delete; + auto operator=(IoScheduler&&) -> IoScheduler& = delete; + + ~IoScheduler() override; + + /** + * Given a ThreadStrategy::manual this function should be called at regular intervals to + * process events that are ready. If a using ThreadStrategy::spawn this is run continously + * on a dedicated background thread and does not need to be manually invoked. + * @param timeout If no events are ready how long should the function wait for events to be ready? + * Passing zero (default) for the timeout will check for any events that are + * ready now, and then return. This could be zero events. 
Passing -1 means block + * indefinitely until an event happens. + * @param return The number of tasks currently executing or waiting to execute. + */ + auto process_events(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> std::size_t; + + class schedule_operation + { + friend class IoScheduler; + explicit schedule_operation(IoScheduler& scheduler) noexcept : m_scheduler(scheduler) {} + + public: + /** + * Operations always pause so the executing thread can be switched. + */ + static constexpr auto await_ready() noexcept -> bool + { + return false; + } + + /** + * Suspending always returns to the caller (using void return of await_suspend()) and + * stores the coroutine internally for the executing thread to resume from. + */ + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> void + { + if (m_scheduler.m_opts.execution_strategy == ExecutionStrategy::process_tasks_inline) + { + m_scheduler.m_size.fetch_add(1, std::memory_order::release); + { + std::scoped_lock lk{m_scheduler.m_scheduled_tasks_mutex}; + m_scheduler.m_scheduled_tasks.emplace_back(awaiting_coroutine); + } + + // Trigger the event to wake-up the scheduler if this event isn't currently triggered. + bool expected{false}; + if (m_scheduler.m_schedule_fd_triggered.compare_exchange_strong(expected, + true, + std::memory_order::release, + std::memory_order::relaxed)) + { + eventfd_t value{1}; + eventfd_write(m_scheduler.m_schedule_fd, value); + } + } + else + { + m_scheduler.m_thread_pool->resume(awaiting_coroutine); + } + } + + /** + * no-op as this is the function called first by the thread pool's executing thread. + */ + auto await_resume() noexcept -> void {} + + private: + /// The thread pool that this operation will execute on. + IoScheduler& m_scheduler; + }; + + /** + * Schedules the current task onto this IoScheduler for execution. 
+ */ + auto schedule() -> schedule_operation + { + return schedule_operation{*this}; + } + + /** + * Schedules a task onto the IoScheduler and moves ownership of the task to the IoScheduler. + * Only void return type tasks can be scheduled in this manner since the task submitter will no + * longer have control over the scheduled task. + * @param task The task to execute on this IoScheduler. It's lifetime ownership will be transferred + * to this IoScheduler. + */ + auto schedule(mrc::coroutines::Task&& task) -> void; + + /** + * Schedules the current task to run after the given amount of time has elapsed. + * @param amount The amount of time to wait before resuming execution of this task. + * Given zero or negative amount of time this behaves identical to schedule(). + */ + [[nodiscard]] auto schedule_after(std::chrono::milliseconds amount) -> mrc::coroutines::Task; + + /** + * Schedules the current task to run at a given time point in the future. + * @param time The time point to resume execution of this task. Given 'now' or a time point + * in the past this behaves identical to schedule(). + */ + [[nodiscard]] auto schedule_at(time_point_t time) -> mrc::coroutines::Task; + + /** + * Yields the current task to the end of the queue of waiting tasks. + */ + [[nodiscard]] mrc::coroutines::Task yield() override + { + co_await schedule_operation{*this}; + }; + + /** + * Yields the current task for the given amount of time. + * @param amount The amount of time to yield for before resuming executino of this task. + * Given zero or negative amount of time this behaves identical to yield(). + */ + [[nodiscard]] mrc::coroutines::Task yield_for(std::chrono::milliseconds amount) override; + + /** + * Yields the current task until the given time point in the future. + * @param time The time point to resume execution of this task. Given 'now' or a time point in the + * in the past this behaves identical to yield(). 
+ */ + [[nodiscard]] mrc::coroutines::Task yield_until(time_point_t time) override; + + /** + * Polls the given file descriptor for the given operations. + * @param fd The file descriptor to poll for events. + * @param op The operations to poll for. + * @param timeout The amount of time to wait for the events to trigger. A timeout of zero will + * block indefinitely until the event triggers. + * @return The result of the poll operation. + */ + [[nodiscard]] auto poll(fd_t fd, + mrc::coroutines::PollOperation op, + std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + -> mrc::coroutines::Task; + +#ifdef LIBCORO_FEATURE_NETWORKING + /** + * Polls the given mrc::coroutines::net::socket for the given operations. + * @param sock The socket to poll for events on. + * @param op The operations to poll for. + * @param timeout The amount of time to wait for the events to trigger. A timeout of zero will + * block indefinitely until the event triggers. + * @return THe result of the poll operation. + */ + [[nodiscard]] auto poll(const net::socket& sock, + mrc::coroutines::poll_op op, + std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) + -> mrc::coroutines::Task + { + return poll(sock.native_handle(), op, timeout); + } +#endif + + /** + * Resumes execution of a direct coroutine handle on this io scheduler. + * @param handle The coroutine handle to resume execution. 
+ */ + void resume(std::coroutine_handle<> handle) noexcept override + { + if (m_opts.execution_strategy == ExecutionStrategy::process_tasks_inline) + { + { + std::scoped_lock lk{m_scheduled_tasks_mutex}; + m_scheduled_tasks.emplace_back(handle); + } + + bool expected{false}; + if (m_schedule_fd_triggered.compare_exchange_strong(expected, + true, + std::memory_order::release, + std::memory_order::relaxed)) + { + eventfd_t value{1}; + eventfd_write(m_schedule_fd, value); + } + } + else + { + m_thread_pool->resume(handle); + } + } + + /** + * @return The number of tasks waiting in the task queue + the executing tasks. + */ + auto size() const noexcept -> std::size_t + { + if (m_opts.execution_strategy == ExecutionStrategy::process_tasks_inline) + { + return m_size.load(std::memory_order::acquire); + } + + return m_size.load(std::memory_order::acquire) + m_thread_pool->size(); + } + + /** + * @return True if the task queue is empty and zero tasks are currently executing. + */ + auto empty() const noexcept -> bool + { + return size() == 0; + } + + /** + * Starts the shutdown of the io scheduler. All currently executing and pending tasks will complete + * prior to shutting down. This call is blocking and will not return until all tasks complete. + */ + auto shutdown() noexcept -> void; + + /** + * Scans for completed coroutines and destroys them freeing up resources. This is also done on starting + * new tasks but this allows the user to cleanup resources manually. One usage might be making sure fds + * are cleaned up as soon as possible. + */ + auto garbage_collect() noexcept -> void; + + private: + /// The configuration options. + Options m_opts; + + /// The event loop epoll file descriptor. + fd_t m_epoll_fd{-1}; + /// The event loop fd to trigger a shutdown. + fd_t m_shutdown_fd{-1}; + /// The event loop timer fd for timed events, e.g. yield_for() or scheduler_after(). 
+ fd_t m_timer_fd{-1}; + /// The schedule file descriptor if the scheduler is in inline processing mode. + fd_t m_schedule_fd{-1}; + std::atomic m_schedule_fd_triggered{false}; + + /// The number of tasks executing or awaiting events in this io scheduler. + std::atomic m_size{0}; + + /// The background io worker threads. + std::thread m_io_thread; + /// Thread pool for executing tasks when not in inline mode. + std::unique_ptr m_thread_pool{nullptr}; + + std::mutex m_timed_events_mutex{}; + /// The map of time point's to poll infos for tasks that are yielding for a period of time + /// or for tasks that are polling with timeouts. + timed_events_t m_timed_events{}; + + /// Has the IoScheduler been requested to shut down? + std::atomic m_shutdown_requested{false}; + + std::atomic m_io_processing{false}; + auto process_events_manual(std::chrono::milliseconds timeout) -> void; + auto process_events_dedicated_thread() -> void; + auto process_events_execute(std::chrono::milliseconds timeout) -> void; + static auto event_to_poll_status(uint32_t events) -> PollStatus; + + auto process_scheduled_execute_inline() -> void; + std::mutex m_scheduled_tasks_mutex{}; + std::vector> m_scheduled_tasks{}; + + /// Tasks that have their ownership passed into the scheduler. This is a bit strange for now + /// but the concept doesn't pass since IoScheduler isn't fully defined yet. 
+ /// The type is mrc::coroutines::Task_container* + /// Do not inline any functions that use this in the IoScheduler header, it can cause the linker + /// to complain about "defined in discarded section" because it gets defined multiple times + void* m_owned_tasks{nullptr}; + + static constexpr const int MShutdownObject{0}; + static constexpr const void* MShutdownPtr = &MShutdownObject; + + static constexpr const int MTimerObject{0}; + static constexpr const void* MTimerPtr = &MTimerObject; + + static constexpr const int MScheduleObject{0}; + static constexpr const void* MSchedulePtr = &MScheduleObject; + + static const constexpr std::chrono::milliseconds MDefaultTimeout{1000}; + static const constexpr std::chrono::milliseconds MNoTimeout{0}; + static const constexpr std::size_t MMaxEvents = 16; + std::array m_events{}; + std::vector> m_handles_to_resume{}; + + auto process_event_execute(detail::PollInfo* pi, PollStatus status) -> void; + auto process_timeout_execute() -> void; + + auto add_timer_token(time_point_t tp, detail::PollInfo& pi) -> timed_events_t::iterator; + auto remove_timer_token(timed_events_t::iterator pos) -> void; + auto update_timeout(time_point_t now) -> void; +}; + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/poll.hpp b/cpp/mrc/include/mrc/coroutines/poll.hpp new file mode 100644 index 000000000..86bb28867 --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/poll.hpp @@ -0,0 +1,82 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Original Source: https://github.com/jbaldwin/libcoro + * Original License: Apache License, Version 2.0; included below + */ + +/** + * Copyright 2021 Josh Baldwin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +namespace mrc::coroutines { +enum class PollOperation : uint64_t +{ + /// Poll for read operations. + read = EPOLLIN, + /// Poll for write operations. + write = EPOLLOUT, + /// Poll for read and write operations. + read_write = EPOLLIN | EPOLLOUT +}; + +inline auto poll_op_readable(PollOperation op) -> bool +{ + return (static_cast(op) & EPOLLIN) != 0; +} + +inline auto poll_op_writeable(PollOperation op) -> bool +{ + return (static_cast(op) & EPOLLOUT) != 0; +} + +auto to_string(PollOperation op) -> const std::string&; + +enum class PollStatus +{ + /// The poll operation was was successful. + event, + /// The poll operation timed out. + timeout, + /// The file descriptor had an error while polling. 
+ error, + /// The file descriptor has been closed by the remote or an internal error/close. + closed +}; + +auto to_string(PollStatus status) -> const std::string&; + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/schedule_on.hpp b/cpp/mrc/include/mrc/coroutines/schedule_on.hpp new file mode 100644 index 000000000..73505a1bd --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/schedule_on.hpp @@ -0,0 +1,98 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Original Source: https://github.com/lewissbaker/cppcoro + * Original License: MIT; included below + */ + +/////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Lewis Baker +// Licenced under MIT license. See LICENSE.txt for details. +/////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "async_generator.hpp" + +#include +#include +#include + +#include + +namespace mrc::coroutines { + +/** + * @brief Schedules an awaitable to run on the supplied scheduler. Returns the value as if it were awaited on in the + * current thread. 
+ */ +template +auto schedule_on(SchedulerT& scheduler, AwaitableT awaitable) -> Task::awaiter_return_type>::type> +{ + using return_t = typename boost::detail::remove_rvalue_ref< + typename mrc::coroutines::concepts::awaitable_traits::awaiter_return_type>::type; + + co_await scheduler.schedule(); + + if constexpr (std::is_same_v) + { + co_await std::move(awaitable); + VLOG(10) << "schedule_on completed"; + co_return; + } + else + { + auto result = co_await std::move(awaitable); + VLOG(10) << "schedule_on completed"; + co_return std::move(result); + } +} + +/** + * @brief Schedules an async generator to run on the supplied scheduler. Each value in the generator run on the + * scheduler. The return value is the same as if the generator was run on the current thread. + * + * @tparam T + * @tparam SchedulerT + * @param scheduler + * @param source + * @return mrc::coroutines::AsyncGenerator + */ +template +mrc::coroutines::AsyncGenerator schedule_on(SchedulerT& scheduler, mrc::coroutines::AsyncGenerator source) +{ + // Transfer exection to the scheduler before the implicit calls to + // 'co_await begin()' or subsequent calls to `co_await iterator::operator++()` + // below. This ensures that all calls to the generator's coroutine_handle<>::resume() + // are executed on the execution context of the scheduler. + co_await scheduler.schedule(); + + const auto iter_end = source.end(); + auto iter = co_await source.begin(); + while (iter != iter_end) + { + co_yield *iter; + + co_await scheduler.schedule(); + + (void)co_await ++iter; + } +} + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/scheduler.hpp b/cpp/mrc/include/mrc/coroutines/scheduler.hpp new file mode 100644 index 000000000..d8efff83b --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/scheduler.hpp @@ -0,0 +1,62 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mrc/coroutines/task.hpp" +#include "mrc/coroutines/time.hpp" + +#include +#include +#include +#include +#include + +namespace mrc::coroutines { + +/** + * @brief Scheduler base class + */ +class Scheduler : public std::enable_shared_from_this +{ + public: + virtual ~Scheduler() = default; + + /** + * @brief Resumes a coroutine according to the scheduler's implementation. + */ + virtual void resume(std::coroutine_handle<> handle) noexcept = 0; + + /** + * @brief Suspends the current function and resumes it according to the scheduler's implementation. + */ + [[nodiscard]] virtual Task<> yield() = 0; + + /** + * @brief Suspends the current function for a given duration and resumes it according to the schedulers's + * implementation. + */ + [[nodiscard]] virtual Task<> yield_for(std::chrono::milliseconds amount) = 0; + + /** + * @brief Suspends the current function until a given time point and resumes it according to the schedulers's + * implementation. 
+ */ + [[nodiscard]] virtual Task<> yield_until(time_point_t time) = 0; +}; + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/task_container.hpp b/cpp/mrc/include/mrc/coroutines/task_container.hpp new file mode 100644 index 000000000..88730b919 --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/task_container.hpp @@ -0,0 +1,171 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Original Source: https://github.com/jbaldwin/libcoro + * Original License: Apache License, Version 2.0; included below + */ + +/** + * Copyright 2021 Josh Baldwin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "mrc/coroutines/task.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace mrc::coroutines { +class Scheduler; + +class TaskContainer +{ + public: + using task_position_t = std::list>>::iterator; + + /** + * @param e Tasks started in the container are scheduled onto this executor. For tasks created + * from a coro::io_scheduler, this would usually be that coro::io_scheduler instance. + */ + TaskContainer(std::shared_ptr e, std::size_t max_concurrent_tasks = 0); + + TaskContainer(const TaskContainer&) = delete; + TaskContainer(TaskContainer&&) = delete; + auto operator=(const TaskContainer&) -> TaskContainer& = delete; + auto operator=(TaskContainer&&) -> TaskContainer& = delete; + + ~TaskContainer(); + + enum class GarbageCollectPolicy + { + /// Execute garbage collection. + yes, + /// Do not execute garbage collection. + no + }; + + /** + * Stores a user task and starts its execution on the container's thread pool. + * @param user_task The scheduled user's task to store in this task container and start its execution. + * @param cleanup Should the task container run garbage collect at the beginning of this store + * call? Calling at regular intervals will reduce memory usage of completed + * tasks and allow for the task container to re-use allocated space. + */ + auto start(Task&& user_task, GarbageCollectPolicy cleanup = GarbageCollectPolicy::yes) -> void; + + /** + * Garbage collects any tasks that are marked as deleted. This frees up space to be re-used by + * the task container for newly stored tasks. + * @return The number of tasks that were deleted. + */ + auto garbage_collect() -> std::size_t; + + /** + * @return The number of active tasks in the container. + */ + auto size() -> std::size_t; + + /** + * @return True if there are no active tasks in the container. + */ + auto empty() -> bool; + + /** + * @return The capacity of this task manager before it will need to grow in size. 
+ */ + auto capacity() -> std::size_t; + + /** + * Will continue to garbage collect and yield until all tasks are complete. This method can be + * co_await'ed to make it easier to wait for the task container to have all its tasks complete. + * + * This does not shut down the task container, but can be used when shutting down, or if your + * logic requires all the tasks contained within to complete, it is similar to coro::latch. + */ + auto garbage_collect_and_yield_until_empty() -> Task; + + private: + /** + * Special constructor for internal types to create their embeded task containers. + */ + TaskContainer(Scheduler& e); + + /** + * Interal GC call, expects the public function to lock. + */ + auto gc_internal() -> std::size_t; + + /** + * Starts the next taks in the queue if one is available and max concurrent tasks has not yet been met. + */ + void try_start_next_task(std::unique_lock lock); + + /** + * Encapsulate the users tasks in a cleanup task which marks itself for deletion upon + * completion. Simply co_await the users task until its completed and then mark the given + * position within the task manager as being deletable. The scheduler's next iteration + * in its event loop will then free that position up to be re-used. + * + * This function will also unconditionally catch all unhandled exceptions by the user's + * task to prevent the scheduler from throwing exceptions. + * @param user_task The user's task. + * @param pos The position where the task data will be stored in the task manager. + * @return The user's task wrapped in a self cleanup task. + */ + auto make_cleanup_task(Task user_task, task_position_t pos) -> Task; + + /// Mutex for safely mutating the task containers across threads, expected usage is within + /// thread pools for indeterminate lifetime requests. + std::mutex m_mutex{}; + /// The number of alive tasks. + std::size_t m_size{}; + /// Maintains the lifetime of the tasks until they are completed. 
+ std::list>> m_tasks{}; + /// The set of tasks that have completed and need to be deleted. + std::vector m_tasks_to_delete{}; + /// The executor to schedule tasks that have just started. This is only used for lifetime management and may be + /// nullptr + std::shared_ptr m_scheduler_lifetime{nullptr}; + /// This is used internally since io_scheduler cannot pass itself in as a shared_ptr. + Scheduler* m_scheduler{nullptr}; + /// tasks to be processed in order of start + std::queue m_next_tasks; + /// maximum number of tasks to be run simultaneously + std::size_t m_max_concurrent_tasks; + + friend Scheduler; +}; + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/test_scheduler.hpp b/cpp/mrc/include/mrc/coroutines/test_scheduler.hpp new file mode 100644 index 000000000..5d74f2168 --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/test_scheduler.hpp @@ -0,0 +1,112 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "mrc/coroutines/scheduler.hpp" +#include "mrc/coroutines/task.hpp" + +#include +#include +#include +#include +#include + +#pragma once + +namespace mrc::coroutines { + +class TestScheduler : public Scheduler +{ + private: + struct Operation + { + public: + Operation(TestScheduler* self, std::chrono::time_point time); + + static constexpr bool await_ready() + { + return false; + } + + void await_suspend(std::coroutine_handle<> handle); + + void await_resume() {} + + private: + TestScheduler* m_self; + std::chrono::time_point m_time; + }; + + using item_t = std::pair, std::chrono::time_point>; + struct ItemCompare + { + bool operator()(item_t& lhs, item_t& rhs); + }; + + std::priority_queue, ItemCompare> m_queue; + std::chrono::time_point m_time = std::chrono::steady_clock::now(); + + public: + /** + * @brief Enqueue's the coroutine handle to be resumed at the current logical time. + */ + void resume(std::coroutine_handle<> handle) noexcept override; + + /** + * Suspends the current function and enqueue's it to be resumed at the current logical time. + */ + mrc::coroutines::Task<> yield() override; + + /** + * Suspends the current function and enqueue's it to be resumed at the current logica time + the given duration. + */ + mrc::coroutines::Task<> yield_for(std::chrono::milliseconds time) override; + + /** + * Suspends the current function and enqueue's it to be resumed at the given logical time. + */ + mrc::coroutines::Task<> yield_until(std::chrono::time_point time) override; + + /** + * Returns the time according to the scheduler. Time may be progressed by resume_next, resume_for, and resume_until. + * + * @return the current time according to the scheduler. + */ + std::chrono::time_point time(); + + /** + * Immediately resumes the next-in-queue coroutine handle. + * + * @return true if more coroutines exist in the queue after resuming, false otherwise. 
+ */ + bool resume_next(); + + /** + * Immediately resumes next-in-queue coroutines up to the current logical time + the given duration, in-order. + * + * @return true if more coroutines exist in the queue after resuming, false otherwise. + */ + bool resume_for(std::chrono::milliseconds time); + + /** + * Immediately resumes next-in-queue coroutines up to the given logical time. + * + * @return true if more coroutines exist in the queue after resuming, false otherwise. + */ + bool resume_until(std::chrono::time_point time); +}; + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/time.hpp b/cpp/mrc/include/mrc/coroutines/time.hpp new file mode 100644 index 000000000..f7844b5b7 --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/time.hpp @@ -0,0 +1,46 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Original Source: https://github.com/jbaldwin/libcoro + * Original License: Apache License, Version 2.0; included below + */ + +/** + * Copyright 2021 Josh Baldwin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace mrc::coroutines { +using clock_t = std::chrono::steady_clock; +using time_point_t = clock_t::time_point; +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/edge/edge_channel.hpp b/cpp/mrc/include/mrc/edge/edge_channel.hpp index 5da85d74c..240764421 100644 --- a/cpp/mrc/include/mrc/edge/edge_channel.hpp +++ b/cpp/mrc/include/mrc/edge/edge_channel.hpp @@ -20,6 +20,7 @@ #include "mrc/edge/edge_readable.hpp" #include "mrc/edge/edge_writable.hpp" #include "mrc/edge/forward.hpp" +#include "mrc/utils/macros.hpp" #include @@ -89,6 +90,24 @@ class EdgeChannel { CHECK(m_channel) << "Cannot create an EdgeChannel from an empty pointer"; } + + EdgeChannel(EdgeChannel&& other) : m_channel(std::move(other.m_channel)) {} + + EdgeChannel& operator=(EdgeChannel&& other) + { + if (this == &other) + { + return *this; + } + + m_channel = std::move(other.m_channel); + + return *this; + } + + // This should not be copyable because it requires passing in a unique_ptr + DELETE_COPYABILITY(EdgeChannel); + virtual ~EdgeChannel() = default; [[nodiscard]] std::shared_ptr> get_reader() const diff --git a/cpp/mrc/include/mrc/edge/edge_holder.hpp b/cpp/mrc/include/mrc/edge/edge_holder.hpp index b3d801484..0262a7e71 100644 --- a/cpp/mrc/include/mrc/edge/edge_holder.hpp +++ b/cpp/mrc/include/mrc/edge/edge_holder.hpp @@ -152,7 +152,6 @@ class EdgeHolder void release_edge_connection() { - m_owned_edge_lifetime.reset(); m_connected_edge.reset(); } diff --git a/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp 
b/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp new file mode 100644 index 000000000..98c4a7d6d --- /dev/null +++ b/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp @@ -0,0 +1,53 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#pragma once + +namespace mrc { + +/** + * @brief A utility for catching out-of-stack exceptions in a thread-safe manner such that they + * can be checked and throw from a parent thread. + */ +class ExceptionCatcher +{ + public: + /** + * @brief "catches" an exception to the catcher + */ + void push_exception(std::exception_ptr ex); + + /** + * @brief checks to see if any exceptions have been "caught" by the catcher. + */ + bool has_exception(); + + /** + * @brief rethrows the next exception (in the order in which it was "caught"). + */ + void rethrow_next_exception(); + + private: + std::mutex m_mutex{}; + std::queue m_exceptions{}; +}; + +} // namespace mrc diff --git a/cpp/mrc/include/mrc/memory/resources/detail/arena.hpp b/cpp/mrc/include/mrc/memory/resources/detail/arena.hpp index e25631606..b514fb5c5 100644 --- a/cpp/mrc/include/mrc/memory/resources/detail/arena.hpp +++ b/cpp/mrc/include/mrc/memory/resources/detail/arena.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -169,7 +169,7 @@ inline bool block_size_compare(block lhs, block rhs) */ constexpr std::size_t align_up(std::size_t value) noexcept { - return rmm::detail::align_up(value, rmm::detail::CUDA_ALLOCATION_ALIGNMENT); + return rmm::align_up(value, rmm::CUDA_ALLOCATION_ALIGNMENT); } /** @@ -180,7 +180,7 @@ constexpr std::size_t align_up(std::size_t value) noexcept */ constexpr std::size_t align_down(std::size_t value) noexcept { - return rmm::detail::align_down(value, rmm::detail::CUDA_ALLOCATION_ALIGNMENT); + return rmm::align_down(value, rmm::CUDA_ALLOCATION_ALIGNMENT); } /** diff --git a/cpp/mrc/include/mrc/node/node_parent.hpp b/cpp/mrc/include/mrc/node/node_parent.hpp new file mode 100644 index 000000000..51f1f36be --- /dev/null +++ b/cpp/mrc/include/mrc/node/node_parent.hpp @@ -0,0 +1,41 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mrc::node { + +template +class NodeParent +{ + public: + using child_types_t = std::tuple; + + virtual std::tuple>...> get_children_refs() const = 0; +}; + +} // namespace mrc::node diff --git a/cpp/mrc/include/mrc/node/operators/combine_latest.hpp b/cpp/mrc/include/mrc/node/operators/combine_latest.hpp index a5d50d217..b1e71d11f 100644 --- a/cpp/mrc/include/mrc/node/operators/combine_latest.hpp +++ b/cpp/mrc/include/mrc/node/operators/combine_latest.hpp @@ -20,6 +20,7 @@ #include "mrc/channel/status.hpp" #include "mrc/node/sink_properties.hpp" #include "mrc/node/source_properties.hpp" +#include "mrc/utils/tuple_utils.hpp" #include "mrc/utils/type_utils.hpp" #include @@ -31,91 +32,6 @@ namespace mrc::node { -// template -// class ParameterPackIndexer -// { -// public: -// ParameterPackIndexer(TypesT... ts) : ParameterPackIndexer(std::make_index_sequence{}, ts...) -// {} - -// std::tuple...> tup; - -// private: -// template -// ParameterPackIndexer(std::index_sequence const& /*unused*/, TypesT... ts) : tup{std::make_tuple(ts, -// Is)...} -// {} -// }; - -// template -// constexpr size_t getTypeIndexInTemplateList() -// { -// if constexpr (std::is_same::value) -// { -// return 0; -// } -// else -// { -// return 1 + getTypeIndexInTemplateList(); -// } -// } - -namespace detail { -struct Surely -{ - template - auto operator()(const T&... t) const -> decltype(std::make_tuple(t.value()...)) - { - return std::make_tuple(t.value()...); - } -}; -} // namespace detail - -// template -// inline auto surely(const std::tuple& tpl) -> decltype(rxcpp::util::apply(tpl, detail::surely())) -// { -// return rxcpp::util::apply(tpl, detail::surely()); -// } - -template -inline auto surely2(const std::tuple& tpl) -{ - return std::apply([](auto... 
args) { - return std::make_tuple(args.value()...); - }); -} - -// template -// static auto surely2(const std::tuple& tpl, std::index_sequence) -// { -// return std::make_tuple(std::make_shared>(*self)...); -// } - -// template -// struct IndexTypePair -// { -// static constexpr size_t index{i}; -// using Type = T; -// }; - -// template -// struct make_index_type_tuple_helper -// { -// template -// struct idx; - -// template -// struct idx> -// { -// using tuple_type = std::tuple...>; -// }; - -// using tuple_type = typename idx>::tuple_type; -// }; - -// template -// using make_index_type_tuple = typename make_index_type_tuple_helper::tuple_type; - template class CombineLatest : public WritableAcceptor> { @@ -128,9 +44,7 @@ class CombineLatest : public WritableAcceptor> public: CombineLatest() : m_upstream_holders(build_ingress(const_cast(this), std::index_sequence_for{})) - { - // auto a = build_ingress(const_cast(this), std::index_sequence_for{}); - } + {} virtual ~CombineLatest() = default; @@ -193,9 +107,9 @@ class CombineLatest : public WritableAcceptor> // Check if we should push the new value if (m_values_set == sizeof...(TypesT)) { - // std::tuple new_val = surely2(m_state); + std::tuple new_val = utils::tuple_surely(m_state); - // status = this->get_writable_edge()->await_write(std::move(new_val)); + status = this->get_writable_edge()->await_write(std::move(new_val)); } return status; @@ -209,6 +123,9 @@ class CombineLatest : public WritableAcceptor> if (m_completions == sizeof...(TypesT)) { + // Clear the held tuple to remove any dangling values + m_state = std::tuple...>(); + WritableAcceptor>::release_edge_connection(); } } diff --git a/cpp/mrc/include/mrc/node/operators/round_robin_router_typeless.hpp b/cpp/mrc/include/mrc/node/operators/round_robin_router_typeless.hpp new file mode 100644 index 000000000..0eafd8572 --- /dev/null +++ b/cpp/mrc/include/mrc/node/operators/round_robin_router_typeless.hpp @@ -0,0 +1,144 @@ +/* + * SPDX-FileCopyrightText: 
Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mrc/edge/deferred_edge.hpp" + +#include +#include +#include + +namespace mrc::node { + +class RoundRobinRouterTypeless : public edge::IWritableProviderBase, public edge::IWritableAcceptorBase +{ + public: + std::shared_ptr get_writable_edge_handle() const override + { + auto* self = const_cast(this); + + // Create a new upstream edge. 
On connection, have it attach to any downstreams + auto deferred_ingress = std::make_shared( + [self](std::shared_ptr deferred_edge) { + // Set the broadcast indices function + deferred_edge->set_indices_fn([self](edge::DeferredWritableMultiEdgeBase& deferred_edge) { + // Increment the index and return the key for that index + auto next_idx = self->m_current_idx++; + + auto current_keys = deferred_edge.edge_connection_keys(); + + return std::vector{current_keys[next_idx % current_keys.size()]}; + }); + + // Need to work with weak ptr here otherwise we will keep it from closing + std::weak_ptr weak_deferred_edge = deferred_edge; + + // Use a connector here in case the object never gets set to an edge + deferred_edge->add_connector([self, weak_deferred_edge]() { + // Lock whenever working on the handles + std::unique_lock lock(self->m_mutex); + + // Save to the upstream handles + self->m_upstream_handles.emplace_back(weak_deferred_edge); + + auto deferred_edge = weak_deferred_edge.lock(); + + CHECK(deferred_edge) << "Edge was destroyed before making connection."; + + for (const auto& downstream : self->m_downstream_handles) + { + auto count = deferred_edge->edge_connection_count(); + + // Connect + deferred_edge->set_writable_edge_handle(count, downstream); + } + + // Now add a disconnector that will remove it from the list + deferred_edge->add_disconnector([self, weak_deferred_edge]() { + // Need to lock here since this could be driven by different progress engines + std::unique_lock lock(self->m_mutex); + + bool is_expired = weak_deferred_edge.expired(); + + // Cull all expired ptrs from the list + auto iter = self->m_upstream_handles.begin(); + + while (iter != self->m_upstream_handles.end()) + { + if ((*iter).expired()) + { + iter = self->m_upstream_handles.erase(iter); + } + else + { + ++iter; + } + } + + // If there are no more upstream handles, then delete the downstream + if (self->m_upstream_handles.empty()) + { + self->m_downstream_handles.clear(); + } + 
}); + }); + }); + + return deferred_ingress; + } + + edge::EdgeTypeInfo writable_provider_type() const override + { + return edge::EdgeTypeInfo::create_deferred(); + } + + void set_writable_edge_handle(std::shared_ptr ingress) override + { + // Lock whenever working on the handles + std::unique_lock lock(m_mutex); + + // We have a new downstream object. Hold onto it + m_downstream_handles.push_back(ingress); + + // If we have an upstream object, try to make a connection now + for (auto& upstream_weak : m_upstream_handles) + { + auto upstream = upstream_weak.lock(); + + CHECK(upstream) << "Upstream edge went out of scope before downstream edges were connected"; + + auto count = upstream->edge_connection_count(); + + // Connect + upstream->set_writable_edge_handle(count, ingress); + } + } + + edge::EdgeTypeInfo writable_acceptor_type() const override + { + return edge::EdgeTypeInfo::create_deferred(); + } + + private: + std::mutex m_mutex; + std::atomic_size_t m_current_idx{0}; + std::vector> m_upstream_handles; + std::vector> m_downstream_handles; +}; + +} // namespace mrc::node diff --git a/cpp/mrc/include/mrc/node/operators/with_latest_from.hpp b/cpp/mrc/include/mrc/node/operators/with_latest_from.hpp new file mode 100644 index 000000000..dd99ed285 --- /dev/null +++ b/cpp/mrc/include/mrc/node/operators/with_latest_from.hpp @@ -0,0 +1,213 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mrc/channel/buffered_channel.hpp" +#include "mrc/channel/status.hpp" +#include "mrc/core/utils.hpp" +#include "mrc/node/sink_properties.hpp" +#include "mrc/node/source_properties.hpp" +#include "mrc/utils/tuple_utils.hpp" +#include "mrc/utils/type_utils.hpp" + +#include +#include + +#include +#include +#include +#include + +namespace mrc::node { + +template +class WithLatestFrom : public WritableAcceptor> +{ + template + using queue_t = BufferedChannel; + template + using wrapped_queue_t = std::unique_ptr>; + using output_t = std::tuple; + + template + static auto build_ingress(WithLatestFrom* self, std::index_sequence /*unused*/) + { + return std::make_tuple(std::make_shared>(*self)...); + } + + public: + WithLatestFrom() : + m_primary_queue(std::make_unique>>()), + m_upstream_holders(build_ingress(const_cast(this), std::index_sequence_for{})) + {} + + virtual ~WithLatestFrom() = default; + + template + std::shared_ptr>> get_sink() const + { + return std::get(m_upstream_holders); + } + + protected: + template + class Upstream : public WritableProvider> + { + using upstream_t = NthTypeOf; + + public: + Upstream(WithLatestFrom& parent) + { + this->init_owned_edge(std::make_shared(parent)); + } + + private: + class InnerEdge : public edge::IEdgeWritable> + { + public: + InnerEdge(WithLatestFrom& parent) : m_parent(parent) {} + ~InnerEdge() + { + m_parent.edge_complete(); + } + + virtual channel::Status await_write(upstream_t&& data) + { + return m_parent.set_upstream_value(std::move(data)); + } + + private: + WithLatestFrom& m_parent; + }; + }; + + private: + template + channel::Status set_upstream_value(NthTypeOf value) + { + std::unique_lock lock(m_mutex); + + // Get a reference to the current value + auto& nth_val = std::get(m_state); + + // Check if we have fully initialized + if (m_values_set < sizeof...(TypesT)) + { + if 
(!nth_val.has_value()) + { + ++m_values_set; + } + + // Move the value into the state + nth_val = std::move(value); + + // For the primary upstream only, move the value onto a queue + if constexpr (N == 0) + { + // Temporarily unlock to prevent deadlock + lock.unlock(); + + Unwinder relock([&]() { + lock.lock(); + }); + + // Move it into the queue + CHECK_EQ(m_primary_queue->await_write(std::move(nth_val.value())), channel::Status::success); + } + + // Check if this put us over the edge + if (m_values_set == sizeof...(TypesT)) + { + // Need to complete initialization. First close the primary channel + m_primary_queue->close_channel(); + + auto& primary_val = std::get<0>(m_state); + + // Loop over the values in the queue, pushing each one + while (m_primary_queue->await_read(primary_val.value()) == channel::Status::success) + { + std::tuple new_val = utils::tuple_surely(m_state); + + CHECK_EQ(this->get_writable_edge()->await_write(std::move(new_val)), channel::Status::success); + } + } + } + else + { + // Move the value into the state + nth_val = std::move(value); + + // Only when we are the primary, do we push a new value + if constexpr (N == 0) + { + std::tuple new_val = utils::tuple_surely(m_state); + + return this->get_writable_edge()->await_write(std::move(new_val)); + } + } + + return channel::Status::success; + } + + void edge_complete() + { + std::unique_lock lock(m_mutex); + + m_completions++; + + if (m_completions == sizeof...(TypesT)) + { + NthTypeOf<0, TypesT...> tmp; + bool had_values = false; + + // Try to clear out any values left in the channel + while (m_primary_queue->await_read(tmp) == channel::Status::success) + { + had_values = true; + } + + LOG_IF(ERROR, had_values) << "The primary source values were never pushed downstream. 
Ensure all upstream " + "sources pushed at least 1 value"; + + // Clear the held tuple to remove any dangling values + m_state = std::tuple...>(); + + WritableAcceptor>::release_edge_connection(); + } + } + + boost::fibers::mutex m_mutex; + + // The number of elements that have been set. Can start emitting when m_values_set == sizeof...(TypesT) + size_t m_values_set{0}; + + // Counts the number of upstream completions. When m_completions == sizeof...(TypesT), the downstream edges are + // released + size_t m_completions{0}; + + // Holds onto the latest values to eventually push when new ones are emitted + std::tuple...> m_state; + + // Queue to allow backpressure to upstreams. Only 1 queue for the primary is needed + wrapped_queue_t> m_primary_queue; + + // Upstream edges + std::tuple>...> m_upstream_holders; +}; + +} // namespace mrc::node diff --git a/cpp/mrc/include/mrc/node/operators/zip.hpp b/cpp/mrc/include/mrc/node/operators/zip.hpp new file mode 100644 index 000000000..f06a39657 --- /dev/null +++ b/cpp/mrc/include/mrc/node/operators/zip.hpp @@ -0,0 +1,291 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "mrc/channel/buffered_channel.hpp" +#include "mrc/channel/channel.hpp" +#include "mrc/channel/status.hpp" +#include "mrc/node/node_parent.hpp" +#include "mrc/node/sink_properties.hpp" +#include "mrc/node/source_properties.hpp" +#include "mrc/types.hpp" +#include "mrc/utils/string_utils.hpp" +#include "mrc/utils/tuple_utils.hpp" +#include "mrc/utils/type_utils.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mrc::node { + +class ZipBase +{ + public: + virtual ~ZipBase() = default; +}; + +template +class Zip : public ZipBase, + public WritableAcceptor>, + public NodeParent...> +{ + template + using queue_t = BufferedChannel; + template + using wrapped_queue_t = std::unique_ptr>; + using queues_tuple_type = std::tuple...>; + using output_t = std::tuple; + + template + static auto build_ingress(Zip* self, std::index_sequence /*unused*/) + { + return std::make_tuple(std::make_shared>(*self)...); + } + + static auto build_queues(size_t channel_size) + { + return std::make_tuple(std::make_unique>(channel_size)...); + } + + template + static std::tuple>>...> + build_child_pairs(Zip* self, std::index_sequence /*unused*/) + { + return std::make_tuple(std::make_pair(MRC_CONCAT_STR("sink[" << Is << "]"), std::ref(self->get_sink()))...); + } + + template + channel::Status tuple_pop_each(queues_tuple_type& queues_tuple, output_t& output_tuple) + { + channel::Status status = std::get(queues_tuple)->await_read(std::get(output_tuple)); + + if constexpr (I + 1 < sizeof...(TypesT)) + { + // Iterate to the next index + channel::Status inner_status = tuple_pop_each(queues_tuple, output_tuple); + + // If the inner status failed, return that, otherwise return our status + status = inner_status == channel::Status::success ? 
status : inner_status; + } + + return status; + } + + public: + Zip(size_t channel_size = channel::default_channel_size()) : + m_queues(build_queues(channel_size)), + m_upstream_holders(build_ingress(const_cast(this), std::index_sequence_for{})) + { + // Must be sure to set any array values + m_queue_counts.fill(0); + } + + ~Zip() override = default; + + template + edge::IWritableProvider>& get_sink() const + { + return *std::get(m_upstream_holders); + } + + std::tuple>>...> get_children_refs() + const override + { + return build_child_pairs(const_cast(this), std::index_sequence_for{}); + } + + protected: + template + class Upstream : public WritableProvider> + { + using upstream_t = NthTypeOf; + + public: + Upstream(Zip& parent) + { + this->init_owned_edge(std::make_shared(parent)); + } + + private: + class InnerEdge : public edge::IEdgeWritable> + { + public: + InnerEdge(Zip& parent) : m_parent(parent) {} + ~InnerEdge() + { + m_parent.edge_complete(); + } + + virtual channel::Status await_write(upstream_t&& data) + { + return m_parent.upstream_await_write(std::move(data)); + } + + private: + Zip& m_parent; + }; + }; + + private: + template + channel::Status upstream_await_write(NthTypeOf value) + { + // Push before locking so we dont deadlock + auto push_status = std::get(m_queues)->await_write(std::move(value)); + + if (push_status != channel::Status::success) + { + return push_status; + } + + std::unique_lock lock(m_mutex); + + // Update the counts array + m_queue_counts[N]++; + + if (m_queue_counts[N] == m_max_queue_count) + { + // Close the queue to prevent pushing more messages + std::get(m_queues)->close_channel(); + } + + DCHECK_LE(m_queue_counts[N], m_max_queue_count) << "Queue count has surpassed the max count"; + + // See if we have values in every queue + auto all_queues_have_value = std::transform_reduce(m_queue_counts.begin(), + m_queue_counts.end(), + true, + std::logical_and<>(), + [this](const size_t& v) { + return v > m_pull_count; + }); + + 
channel::Status status = channel::Status::success; + + if (all_queues_have_value) + { + // For each tuple, pop a value off + std::tuple new_val; + + auto channel_status = tuple_pop_each(m_queues, new_val); + + DCHECK_EQ(channel_status, channel::Status::success) << "Queues returned failed status"; + + // Push the new value + status = this->get_writable_edge()->await_write(std::move(new_val)); + + m_pull_count++; + } + + return status; + } + + template + void edge_complete() + { + std::unique_lock lock(m_mutex); + + if (m_queue_counts[N] < m_max_queue_count) + { + // We are setting a new lower limit. Check to make sure this isnt an issue + m_max_queue_count = m_queue_counts[N]; + + utils::tuple_for_each(m_queues, + [this](std::unique_ptr>& q, size_t idx) { + if (m_queue_counts[idx] >= m_max_queue_count) + { + // Close the channel + q->close_channel(); + + if (m_queue_counts[idx] > m_max_queue_count) + { + LOG(ERROR) + << "Unbalanced count in upstream sources for Zip operator. Upstream '" + << N << "' ended with " << m_queue_counts[N] << " elements but " + << m_queue_counts[idx] + << " elements have already been pushed by upstream '" << idx << "'"; + } + } + }); + } + + m_completions++; + + if (m_completions == sizeof...(TypesT)) + { + // Warn on any left over values + auto left_over_messages = std::transform_reduce(m_queue_counts.begin(), + m_queue_counts.end(), + 0, + std::plus<>(), + [this](const size_t& v) { + return v - m_pull_count; + }); + if (left_over_messages > 0) + { + LOG(ERROR) << "Unbalanced count in upstream sources for Zip operator. 
" << left_over_messages + << " messages were left in the queues"; + } + + // Finally, drain the queues of any remaining values + utils::tuple_for_each(m_queues, + [](std::unique_ptr>& q, size_t idx) { + QueueValueT value; + + while (q->await_read(value) == channel::Status::success) {} + }); + + WritableAcceptor>::release_edge_connection(); + } + } + + mutable Mutex m_mutex; + + // Once an upstream is closed, this is set representing the max number of values in a queue before its closed + size_t m_max_queue_count{std::numeric_limits::max()}; + + // Counts the number of upstream completions. When m_completions == sizeof...(TypesT), the downstream edges are + // released + size_t m_completions{0}; + + // Holds the number of values pushed to each queue + std::array m_queue_counts; + + // The number of messages pulled off the queue + size_t m_pull_count{0}; + + // Queue used to allow backpressure to upstreams + queues_tuple_type m_queues; + + // Upstream edges + std::tuple>...> m_upstream_holders; +}; + +} // namespace mrc::node diff --git a/cpp/mrc/include/mrc/node/sink_channel_owner.hpp b/cpp/mrc/include/mrc/node/sink_channel_owner.hpp index 8997e3a8d..764cd944a 100644 --- a/cpp/mrc/include/mrc/node/sink_channel_owner.hpp +++ b/cpp/mrc/include/mrc/node/sink_channel_owner.hpp @@ -38,13 +38,13 @@ class SinkChannelOwner : public virtual SinkProperties { edge::EdgeChannel edge_channel(std::move(channel)); - this->do_set_channel(edge_channel); + this->do_set_channel(std::move(edge_channel)); } protected: SinkChannelOwner() = default; - void do_set_channel(edge::EdgeChannel& edge_channel) + void do_set_channel(edge::EdgeChannel edge_channel) { // Create 2 edges, one for reading and writing. 
On connection, persist the other to allow the node to still use // get_readable+edge diff --git a/cpp/mrc/include/mrc/node/source_channel_owner.hpp b/cpp/mrc/include/mrc/node/source_channel_owner.hpp index 226492e5e..2be60c690 100644 --- a/cpp/mrc/include/mrc/node/source_channel_owner.hpp +++ b/cpp/mrc/include/mrc/node/source_channel_owner.hpp @@ -40,13 +40,13 @@ class SourceChannelOwner : public virtual SourceProperties { edge::EdgeChannel edge_channel(std::move(channel)); - this->do_set_channel(edge_channel); + this->do_set_channel(std::move(edge_channel)); } protected: SourceChannelOwner() = default; - void do_set_channel(edge::EdgeChannel& edge_channel) + void do_set_channel(edge::EdgeChannel edge_channel) { // Create 2 edges, one for reading and writing. On connection, persist the other to allow the node to still use // get_writable_edge diff --git a/cpp/mrc/include/mrc/segment/component.hpp b/cpp/mrc/include/mrc/segment/component.hpp index 3e25f9b63..462a49bff 100644 --- a/cpp/mrc/include/mrc/segment/component.hpp +++ b/cpp/mrc/include/mrc/segment/component.hpp @@ -31,7 +31,13 @@ template class Component final : public Object { public: - Component(std::unique_ptr resource) : m_resource(std::move(resource)) {} + Component(std::unique_ptr resource) : + ObjectProperties(Object::build_state()), + Object(), + m_resource(std::move(resource)) + { + this->init_children(); + } ~Component() final = default; private: diff --git a/cpp/mrc/include/mrc/segment/egress_port.hpp b/cpp/mrc/include/mrc/segment/egress_port.hpp index 7fe52a5ce..909d0e1b2 100644 --- a/cpp/mrc/include/mrc/segment/egress_port.hpp +++ b/cpp/mrc/include/mrc/segment/egress_port.hpp @@ -59,10 +59,14 @@ class EgressPort final : public Object>, public: EgressPort(SegmentAddress address, PortName name) : + ObjectProperties(Object>::build_state()), m_segment_address(address), m_port_name(std::move(name)), m_sink(std::make_unique>()) - {} + { + // Must call after constructing Object + this->init_children(); 
+ } private: node::RxSinkBase* get_object() const final diff --git a/cpp/mrc/include/mrc/segment/ingress_port.hpp b/cpp/mrc/include/mrc/segment/ingress_port.hpp index fec6d469e..8757f5a70 100644 --- a/cpp/mrc/include/mrc/segment/ingress_port.hpp +++ b/cpp/mrc/include/mrc/segment/ingress_port.hpp @@ -53,10 +53,14 @@ class IngressPort : public Object>, public IngressPortBase public: IngressPort(SegmentAddress address, PortName name) : + ObjectProperties(Object>::build_state()), m_segment_address(address), m_port_name(std::move(name)), m_source(std::make_unique>()) - {} + { + // Must call after constructing Object + this->init_children(); + } private: node::RxSourceBase* get_object() const final diff --git a/cpp/mrc/include/mrc/segment/object.hpp b/cpp/mrc/include/mrc/segment/object.hpp index 2ccc80094..e7b291960 100644 --- a/cpp/mrc/include/mrc/segment/object.hpp +++ b/cpp/mrc/include/mrc/segment/object.hpp @@ -19,28 +19,64 @@ #include "mrc/channel/ingress.hpp" #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_readable.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/exceptions/runtime_error.hpp" #include "mrc/node/forward.hpp" +#include "mrc/node/node_parent.hpp" #include "mrc/node/sink_properties.hpp" #include "mrc/node/source_properties.hpp" #include "mrc/node/type_traits.hpp" #include "mrc/runnable/launch_options.hpp" #include "mrc/runnable/runnable.hpp" #include "mrc/segment/forward.hpp" +#include "mrc/type_traits.hpp" +#include "mrc/utils/tuple_utils.hpp" +#include +#include #include #include #include +#include namespace mrc::segment { -struct ObjectProperties +template +class SharedObject; + +template +class ReferencedObject; + +struct ObjectPropertiesState { + std::string name; + std::string type_name; + bool is_writable_acceptor; + bool is_writable_provider; + bool is_readable_acceptor; + bool is_readable_provider; +}; + +class ObjectProperties +{ + public: virtual ~ObjectProperties() = 0; - virtual void set_name(const std::string& name) 
= 0; - virtual std::string name() const = 0; - virtual std::string type_name() const = 0; + virtual void set_name(const std::string& name) + { + m_state->name = name; + } + + virtual std::string name() const + { + return m_state->name; + } + + virtual std::string type_name() const + { + return m_state->type_name; + } virtual bool is_sink() const = 0; virtual bool is_source() const = 0; @@ -48,10 +84,22 @@ struct ObjectProperties virtual std::type_index sink_type(bool ignore_holder = false) const = 0; virtual std::type_index source_type(bool ignore_holder = false) const = 0; - virtual bool is_writable_acceptor() const = 0; - virtual bool is_writable_provider() const = 0; - virtual bool is_readable_acceptor() const = 0; - virtual bool is_readable_provider() const = 0; + bool is_writable_acceptor() const + { + return m_state->is_writable_acceptor; + } + bool is_writable_provider() const + { + return m_state->is_writable_provider; + } + bool is_readable_acceptor() const + { + return m_state->is_readable_acceptor; + } + bool is_readable_provider() const + { + return m_state->is_readable_provider; + } virtual edge::IWritableAcceptorBase& writable_acceptor_base() = 0; virtual edge::IWritableProviderBase& writable_provider_base() = 0; @@ -74,6 +122,20 @@ struct ObjectProperties virtual runnable::LaunchOptions& launch_options() = 0; virtual const runnable::LaunchOptions& launch_options() const = 0; + + virtual std::shared_ptr get_child(const std::string& name) const = 0; + virtual std::map> get_children() const = 0; + + protected: + ObjectProperties(std::shared_ptr state) : m_state(std::move(state)) {} + + std::shared_ptr get_state() const + { + return m_state; + } + + private: + std::shared_ptr m_state; }; inline ObjectProperties::~ObjectProperties() = default; @@ -147,15 +209,33 @@ edge::IReadableProvider& ObjectProperties::readable_provider_typed() } // Object - template -class Object : public virtual ObjectProperties +class Object : public virtual ObjectProperties, 
public std::enable_shared_from_this> { + protected: + static std::shared_ptr build_state() + { + auto state = std::make_shared(); + + state->type_name = std::string(::mrc::type_name()); + state->is_writable_acceptor = std::is_base_of_v; + state->is_writable_provider = std::is_base_of_v; + state->is_readable_acceptor = std::is_base_of_v; + state->is_readable_provider = std::is_base_of_v; + + return state; + } + public: + // Object(const Object& other) : m_name(other.m_name), m_launch_options(other.m_launch_options) {} + // Object(Object&&) = delete; + // Object& operator=(const Object&) = delete; + // Object& operator=(Object&&) = delete; + ObjectT& object(); - std::string name() const final; - std::string type_name() const final; + // std::string name() const final; + // std::string type_name() const final; bool is_source() const final; bool is_sink() const final; @@ -163,10 +243,10 @@ class Object : public virtual ObjectProperties std::type_index sink_type(bool ignore_holder) const final; std::type_index source_type(bool ignore_holder) const final; - bool is_writable_acceptor() const final; - bool is_writable_provider() const final; - bool is_readable_acceptor() const final; - bool is_readable_provider() const final; + // bool is_writable_acceptor() const final; + // bool is_writable_provider() const final; + // bool is_readable_acceptor() const final; + // bool is_readable_provider() const final; edge::IWritableAcceptorBase& writable_acceptor_base() final; edge::IWritableProviderBase& writable_provider_base() final; @@ -198,15 +278,106 @@ class Object : public virtual ObjectProperties return m_launch_options; } + std::shared_ptr get_child(const std::string& name) const override + { + CHECK(m_children.contains(name)) << "Child " << name << " not found in " << this->name(); + + if (auto child = m_children.at(name).lock()) + { + return child; + } + + auto* mutable_this = const_cast(this); + + // Otherwise, we need to build one + auto child = 
mutable_this->m_create_children_fns.at(name)(); + + mutable_this->m_children[name] = child; + + return child; + } + + std::map> get_children() const override + { + std::map> children; + + for (const auto& [name, child] : m_children) + { + children[name] = this->get_child(name); + } + + return children; + } + + template + requires std::derived_from + std::shared_ptr> as() const + { + auto shared_object = std::make_shared>(*const_cast(this)); + + return shared_object; + } + protected: - // Move to protected to allow only the IBuilder to set the name - void set_name(const std::string& name) override; + Object() : ObjectProperties(build_state()) + { + LOG(INFO) << "Creating Object '" << this->name() << "' with type: " << this->type_name(); + } - private: - std::string m_name{}; + template + requires std::derived_from + Object(const Object& other) : + ObjectProperties(other), + m_launch_options(other.m_launch_options), + m_children(other.m_children), + m_create_children_fns(other.m_create_children_fns) + { + LOG(INFO) << "Copying Object '" << this->name() << "' from type: " << other.type_name() + << " to type: " << this->type_name(); + } + + void init_children() + { + if constexpr (is_base_of_template::value) + { + using child_types_t = typename ObjectT::child_types_t; + // Get the name/reference pairs from the NodeParent + auto children_ref_pairs = this->object().get_children_refs(); + + // Finally, convert the tuple of name/ChildObject pairs into a map + utils::tuple_for_each( + children_ref_pairs, + [this](std::pair>& pair, + size_t idx) { + // auto child_obj = std::make_shared>(this->shared_from_this(), + // pair.second); + + // m_children.emplace(std::move(pair.first), std::move(child_obj)); + + m_children.emplace(pair.first, std::weak_ptr()); + + m_create_children_fns.emplace(pair.first, [this, obj_ref = pair.second]() { + return std::make_shared>(this->shared_from_this(), obj_ref); + }); + }); + } + } + + // // Move to protected to allow only the IBuilder to set 
the name + // void set_name(const std::string& name) override; + + private: virtual ObjectT* get_object() const = 0; + runnable::LaunchOptions m_launch_options; + + std::map> m_children; + std::map()>> m_create_children_fns; + + // Allows converting to base classes + template + friend class Object; }; template @@ -222,23 +393,23 @@ ObjectT& Object::object() return *node; } -template -void Object::set_name(const std::string& name) -{ - m_name = name; -} +// template +// void Object::set_name(const std::string& name) +// { +// m_name = name; +// } -template -std::string Object::name() const -{ - return m_name; -} +// template +// std::string Object::name() const +// { +// return m_name; +// } -template -std::string Object::type_name() const -{ - return std::string(::mrc::type_name()); -} +// template +// std::string Object::type_name() const +// { +// return std::string(::mrc::type_name()); +// } template bool Object::is_source() const @@ -276,83 +447,130 @@ std::type_index Object::source_type(bool ignore_holder) const return base->source_type(ignore_holder); } -template -bool Object::is_writable_acceptor() const -{ - return std::is_base_of_v; -} +// template +// bool Object::is_writable_acceptor() const +// { +// return std::is_base_of_v; +// } + +// template +// bool Object::is_writable_provider() const +// { +// return std::is_base_of_v; +// } + +// template +// bool Object::is_readable_acceptor() const +// { +// return std::is_base_of_v; +// } + +// template +// bool Object::is_readable_provider() const +// { +// return std::is_base_of_v; +// } template -bool Object::is_writable_provider() const +edge::IWritableAcceptorBase& Object::writable_acceptor_base() { - return std::is_base_of_v; -} + // if constexpr (!std::is_base_of_v) + // { + // LOG(ERROR) << type_name() << " is not a IIngressAcceptorBase"; + // throw exceptions::MrcRuntimeError("Object is not a IIngressAcceptorBase"); + // } -template -bool Object::is_readable_acceptor() const -{ - return 
std::is_base_of_v; + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IIngressAcceptorBase"; + return *base; } template -bool Object::is_readable_provider() const +edge::IWritableProviderBase& Object::writable_provider_base() { - return std::is_base_of_v; + // if constexpr (!std::is_base_of_v) + // { + // LOG(ERROR) << type_name() << " is not a IIngressProviderBase"; + // throw exceptions::MrcRuntimeError("Object is not a IIngressProviderBase"); + // } + + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IWritableProviderBase"; + return *base; } template -edge::IWritableAcceptorBase& Object::writable_acceptor_base() +edge::IReadableAcceptorBase& Object::readable_acceptor_base() { - if constexpr (!std::is_base_of_v) - { - LOG(ERROR) << type_name() << " is not a IIngressAcceptorBase"; - throw exceptions::MrcRuntimeError("Object is not a IIngressAcceptorBase"); - } + // if constexpr (!std::is_base_of_v) + // { + // LOG(ERROR) << type_name() << " is not a IEgressAcceptorBase"; + // throw exceptions::MrcRuntimeError("Object is not a IEgressAcceptorBase"); + // } - auto* base = dynamic_cast(get_object()); - CHECK(base); + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IReadableAcceptorBase"; return *base; } template -edge::IWritableProviderBase& Object::writable_provider_base() +edge::IReadableProviderBase& Object::readable_provider_base() { - if constexpr (!std::is_base_of_v) - { - LOG(ERROR) << type_name() << " is not a IIngressProviderBase"; - throw exceptions::MrcRuntimeError("Object is not a IIngressProviderBase"); - } + // if constexpr (!std::is_base_of_v) + // { + // LOG(ERROR) << type_name() << " is not a IEgressProviderBase"; + // throw exceptions::MrcRuntimeError("Object is not a IEgressProviderBase"); + // } - auto* base = dynamic_cast(get_object()); - CHECK(base); + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a 
IReadableProviderBase"; return *base; } template -edge::IReadableAcceptorBase& Object::readable_acceptor_base() +class SharedObject final : public Object { - if constexpr (!std::is_base_of_v) + public: + SharedObject(std::shared_ptr owner, std::reference_wrapper resource) : + ObjectProperties(Object::build_state()), + m_owner(std::move(owner)), + m_resource(std::move(resource)) + {} + ~SharedObject() final = default; + + private: + ObjectT* get_object() const final { - LOG(ERROR) << type_name() << " is not a IEgressAcceptorBase"; - throw exceptions::MrcRuntimeError("Object is not a IEgressAcceptorBase"); + return &m_resource.get(); } - auto* base = dynamic_cast(get_object()); - CHECK(base); - return *base; -} + std::shared_ptr m_owner; + std::reference_wrapper m_resource; +}; template -edge::IReadableProviderBase& Object::readable_provider_base() +class ReferencedObject final : public Object { - if constexpr (!std::is_base_of_v) + public: + template + requires std::derived_from + ReferencedObject(Object& other) : + ObjectProperties(other), + Object(other), + m_owner(other.shared_from_this()), + m_resource(other.object()) + {} + + ~ReferencedObject() final = default; + + private: + ObjectT* get_object() const final { - LOG(ERROR) << type_name() << " is not a IEgressProviderBase"; - throw exceptions::MrcRuntimeError("Object is not a IEgressProviderBase"); + return &m_resource.get(); } - auto* base = dynamic_cast(get_object()); - CHECK(base); - return *base; -} + std::shared_ptr m_owner; + std::reference_wrapper m_resource; +}; + } // namespace mrc::segment diff --git a/cpp/mrc/include/mrc/segment/runnable.hpp b/cpp/mrc/include/mrc/segment/runnable.hpp index ab5b590ca..b40e01a00 100644 --- a/cpp/mrc/include/mrc/segment/runnable.hpp +++ b/cpp/mrc/include/mrc/segment/runnable.hpp @@ -37,15 +37,20 @@ template class Runnable : public Object, public runnable::Launchable { public: - template - Runnable(ArgsT&&... 
args) : m_node(std::make_unique(std::forward(args)...)) - {} - - Runnable(std::unique_ptr node) : m_node(std::move(node)) + Runnable(std::unique_ptr node) : + ObjectProperties(Object::build_state()), + Object(), + m_node(std::move(node)) { CHECK(m_node); + + this->init_children(); } + template + Runnable(ArgsT&&... args) : Runnable(std::make_unique(std::forward(args)...)) + {} + private: NodeT* get_object() const final; std::unique_ptr prepare_launcher(runnable::LaunchControl& launch_control) final; diff --git a/cpp/mrc/include/mrc/types.hpp b/cpp/mrc/include/mrc/types.hpp index 063e00831..4bbdc8171 100644 --- a/cpp/mrc/include/mrc/types.hpp +++ b/cpp/mrc/include/mrc/types.hpp @@ -24,33 +24,40 @@ namespace mrc { +// Suppress naming conventions in this file to allow matching std and boost libraries +// NOLINTBEGIN(readability-identifier-naming) + // Typedefs template -using Promise = userspace_threads::promise; // NOLINT(readability-identifier-naming) +using Promise = userspace_threads::promise; template -using Future = userspace_threads::future; // NOLINT(readability-identifier-naming) +using Future = userspace_threads::future; template -using SharedFuture = userspace_threads::shared_future; // NOLINT(readability-identifier-naming) +using SharedFuture = userspace_threads::shared_future; + +using Mutex = userspace_threads::mutex; -using Mutex = userspace_threads::mutex; // NOLINT(readability-identifier-naming) +using RecursiveMutex = userspace_threads::recursive_mutex; -using CondV = userspace_threads::cv; // NOLINT(readability-identifier-naming) +using CondV = userspace_threads::cv; -using MachineID = std::uint64_t; // NOLINT(readability-identifier-naming) -using InstanceID = std::uint64_t; // NOLINT(readability-identifier-naming) -using TagID = std::uint64_t; // NOLINT(readability-identifier-naming) +using MachineID = std::uint64_t; +using InstanceID = std::uint64_t; +using TagID = std::uint64_t; template -using Handle = std::shared_ptr; // 
NOLINT(readability-identifier-naming) +using Handle = std::shared_ptr; + +using SegmentID = std::uint16_t; +using SegmentRank = std::uint16_t; +using SegmentAddress = std::uint32_t; // id + rank -using SegmentID = std::uint16_t; // NOLINT(readability-identifier-naming) -using SegmentRank = std::uint16_t; // NOLINT(readability-identifier-naming) -using SegmentAddress = std::uint32_t; // NOLINT(readability-identifier-naming) // id + rank +using PortName = std::string; +using PortID = std::uint16_t; +using PortAddress = std::uint64_t; // id + rank + port -using PortName = std::string; // NOLINT(readability-identifier-naming) -using PortID = std::uint16_t; // NOLINT(readability-identifier-naming) -using PortAddress = std::uint64_t; // NOLINT(readability-identifier-naming) // id + rank + port +// NOLINTEND(readability-identifier-naming) } // namespace mrc diff --git a/cpp/mrc/include/mrc/utils/string_utils.hpp b/cpp/mrc/include/mrc/utils/string_utils.hpp index a835189f3..024ff8096 100644 --- a/cpp/mrc/include/mrc/utils/string_utils.hpp +++ b/cpp/mrc/include/mrc/utils/string_utils.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,8 +17,23 @@ #pragma once -#include +// for ostringstream +#include // IWYU pragma: keep #include +#include // Concats multiple strings together using ostringstream. Use with MRC_CONCAT_STR("Start [" << my_int << "]") #define MRC_CONCAT_STR(strs) ((std::ostringstream&)(std::ostringstream() << strs)).str() + +namespace mrc { + +/** + * @brief Splits a string into an vector of strings based on a delimiter. + * + * @param str The string to split. + * @param delimiter The delimiter to split the string on. + * @return std::vector vector array of strings. 
+ */ +std::vector split_string_to_vector(const std::string& str, const std::string& delimiter); + +} // namespace mrc diff --git a/cpp/mrc/include/mrc/utils/tuple_utils.hpp b/cpp/mrc/include/mrc/utils/tuple_utils.hpp new file mode 100644 index 000000000..edf0f2e9d --- /dev/null +++ b/cpp/mrc/include/mrc/utils/tuple_utils.hpp @@ -0,0 +1,67 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace mrc::utils { + +template +auto tuple_surely(TupleT&& tuple, std::index_sequence /*unused*/) +{ + return std::tuple>::value_type...>( + (std::get(tuple).value())...); +} + +/** + * @brief Converts a std::tuple, std::optional, ...> to std::tuple + * + * @tparam TupleT The type of tuple + * @param tuple + * @return auto A new Tuple with `std::optional` types removed + */ +template +auto tuple_surely(TupleT&& tuple) +{ + return tuple_surely(std::forward(tuple), + std::make_index_sequence>::value>()); +} + +template +void tuple_for_each(TupleT&& tuple, FuncT&& f, std::index_sequence /*unused*/) +{ + (f(std::get(std::forward(tuple)), Is), ...); +} + +/** + * @brief Executes a function for each element of a tuple. 
+ * + * @tparam TupleT The type of the tuple + * @tparam FuncT The type of the lambda + * @param tuple Tuple to run the function on + * @param f A function which accepts an element of the tuple as the first arg and the index for the second arg. + * Recommended to use `auto` or a templated lambda as the first argument + */ +template +void tuple_for_each(TupleT&& tuple, FuncT&& f) +{ + tuple_for_each(std::forward(tuple), + std::forward(f), + std::make_index_sequence>::value>()); +} +} // namespace mrc::utils diff --git a/cpp/mrc/src/internal/codable/decodable_storage_view.cpp b/cpp/mrc/src/internal/codable/decodable_storage_view.cpp index a4db24dac..5d29c7128 100644 --- a/cpp/mrc/src/internal/codable/decodable_storage_view.cpp +++ b/cpp/mrc/src/internal/codable/decodable_storage_view.cpp @@ -37,7 +37,6 @@ #include #include #include -#include namespace mrc::codable { diff --git a/cpp/mrc/src/internal/codable/storage_view.cpp b/cpp/mrc/src/internal/codable/storage_view.cpp index 3ae474ad7..834af06e1 100644 --- a/cpp/mrc/src/internal/codable/storage_view.cpp +++ b/cpp/mrc/src/internal/codable/storage_view.cpp @@ -19,7 +19,6 @@ #include -#include #include namespace mrc::codable { diff --git a/cpp/mrc/src/internal/control_plane/client.cpp b/cpp/mrc/src/internal/control_plane/client.cpp index 7a85adc2e..54f68a5da 100644 --- a/cpp/mrc/src/internal/control_plane/client.cpp +++ b/cpp/mrc/src/internal/control_plane/client.cpp @@ -19,8 +19,10 @@ #include "internal/control_plane/client/connections_manager.hpp" #include "internal/grpc/progress_engine.hpp" -#include "internal/grpc/promise_handler.hpp" +#include "internal/grpc/promise_handler.hpp" // for PromiseHandler +#include "internal/grpc/stream_writer.hpp" // for StreamWriter #include "internal/runnable/runnable_resources.hpp" +#include "internal/service.hpp" #include "internal/system/system.hpp" #include "mrc/channel/status.hpp" @@ -33,23 +35,42 @@ #include "mrc/runnable/launch_control.hpp" #include 
"mrc/runnable/launcher.hpp" #include "mrc/runnable/runner.hpp" +#include "mrc/types.hpp" +#include // for promise #include #include #include +#include #include namespace mrc::control_plane { +std::atomic_uint64_t AsyncEventStatus::s_request_id_counter; + +AsyncEventStatus::AsyncEventStatus() : m_request_id(++s_request_id_counter) {} + +size_t AsyncEventStatus::request_id() const +{ + return m_request_id; +} + +void AsyncEventStatus::set_future(Future future) +{ + m_future = std::move(future); +} + Client::Client(resources::PartitionResourceBase& base, std::shared_ptr cq) : resources::PartitionResourceBase(base), + Service("control_plane::Client"), m_cq(std::move(cq)), m_owns_progress_engine(false) {} Client::Client(resources::PartitionResourceBase& base) : resources::PartitionResourceBase(base), + Service("control_plane::Client"), m_cq(std::make_shared()), m_owns_progress_engine(true) {} @@ -73,13 +94,11 @@ void Client::do_service_start() if (m_owns_progress_engine) { CHECK(m_cq); - auto progress_engine = std::make_unique(m_cq); - auto progress_handler = std::make_unique(); + auto progress_engine = std::make_unique(m_cq); + m_progress_handler = std::make_unique(); - mrc::make_edge(*progress_engine, *progress_handler); + mrc::make_edge(*progress_engine, *m_progress_handler); - m_progress_handler = - runnable().launch_control().prepare_launcher(launch_options(), std::move(progress_handler))->ignition(); m_progress_engine = runnable().launch_control().prepare_launcher(launch_options(), std::move(progress_engine))->ignition(); } @@ -135,7 +154,6 @@ void Client::do_service_await_live() if (m_owns_progress_engine) { m_progress_engine->await_live(); - m_progress_handler->await_live(); } m_event_handler->await_live(); } @@ -150,7 +168,6 @@ void Client::do_service_await_join() { m_cq->Shutdown(); m_progress_engine->await_join(); - m_progress_handler->await_join(); } } @@ -161,10 +178,21 @@ void Client::do_handle_event(event_t&& event) // handle a subset of events directly 
on the event handler case protos::EventType::Response: { - auto* promise = reinterpret_cast*>(event.msg.tag()); - if (promise != nullptr) + auto event_tag = event.msg.tag(); + + if (event_tag != 0) { - promise->set_value(std::move(event.msg)); + // Lock to prevent multiple threads + std::unique_lock lock(m_mutex); + + // Find the promise associated with the event tag + auto promise = m_pending_events.extract(event_tag); + + // Unlock to allow other threads to continue as soon as possible + lock.unlock(); + + // Finally, set the value + promise.mapped().set_value(std::move(event.msg)); } } break; @@ -242,11 +270,11 @@ const mrc::runnable::LaunchOptions& Client::launch_options() const return m_launch_options; } -void Client::issue_event(const protos::EventType& event_type) +AsyncEventStatus Client::issue_event(const protos::EventType& event_type) { protos::Event event; event.set_event(event_type); - m_writer->await_write(std::move(event)); + return this->write_event(std::move(event), false); } void Client::request_update() @@ -260,4 +288,37 @@ void Client::request_update() // } } +AsyncEventStatus Client::write_event(protos::Event event, bool await_response) +{ + if (event.tag() != 0) + { + LOG(WARNING) << "event tag is set but this field should exclusively be used by the control plane client. 
" + "Clearing to avoid confusion"; + event.clear_tag(); + } + + AsyncEventStatus status; + + if (await_response) + { + // If we are supporting awaiting, create the promise now + Promise promise; + + // Set the future to the status + status.set_future(promise.get_future()); + + // Set the tag to the request ID to allow looking up the promise later + event.set_tag(status.request_id()); + + // Save the promise to the pending promises to be retrieved later + std::unique_lock lock(m_mutex); + + m_pending_events[status.request_id()] = std::move(promise); + } + + // Finally, write the event + m_writer->await_write(std::move(event)); + + return status; +} } // namespace mrc::control_plane diff --git a/cpp/mrc/src/internal/control_plane/client.hpp b/cpp/mrc/src/internal/control_plane/client.hpp index 0a07991a6..efda25db8 100644 --- a/cpp/mrc/src/internal/control_plane/client.hpp +++ b/cpp/mrc/src/internal/control_plane/client.hpp @@ -19,22 +19,22 @@ #include "internal/control_plane/client/instance.hpp" // IWYU pragma: keep #include "internal/grpc/client_streaming.hpp" -#include "internal/grpc/stream_writer.hpp" #include "internal/resources/partition_resources_base.hpp" #include "internal/service.hpp" #include "mrc/core/error.hpp" +#include "mrc/exceptions/runtime_error.hpp" #include "mrc/node/forward.hpp" #include "mrc/node/writable_entrypoint.hpp" #include "mrc/protos/architect.grpc.pb.h" #include "mrc/protos/architect.pb.h" #include "mrc/runnable/launch_options.hpp" #include "mrc/types.hpp" -#include "mrc/utils/macros.hpp" -#include #include +#include +#include // for size_t #include #include #include @@ -65,10 +65,56 @@ namespace mrc::runnable { class Runner; } // namespace mrc::runnable +namespace mrc::rpc { +class PromiseHandler; +template +struct StreamWriter; +} // namespace mrc::rpc + namespace mrc::control_plane { -template -class AsyncStatus; +class AsyncEventStatus +{ + public: + size_t request_id() const; + + template + Expected await_response() + { + if 
(!m_future.valid()) + { + throw exceptions::MrcRuntimeError( + "This AsyncEventStatus is not expecting a response or the response has already been awaited"); + } + + auto event = m_future.get(); + + if (event.has_error()) + { + return Error::create(event.error().message()); + } + + ResponseT response; + if (!event.message().UnpackTo(&response)) + { + throw Error::create("fatal error: unable to unpack message; server sent the wrong message type"); + } + + return response; + } + + private: + AsyncEventStatus(); + + void set_future(Future future); + + static std::atomic_size_t s_request_id_counter; + + size_t m_request_id; + Future m_future; + + friend class Client; +}; /** * @brief Primary Control Plane Client @@ -128,13 +174,13 @@ class Client final : public resources::PartitionResourceBase, public Service template Expected await_unary(const protos::EventType& event_type, RequestT&& request); - template - void async_unary(const protos::EventType& event_type, RequestT&& request, AsyncStatus& status); + template + AsyncEventStatus async_unary(const protos::EventType& event_type, RequestT&& request); template - void issue_event(const protos::EventType& event_type, MessageT&& message); + AsyncEventStatus issue_event(const protos::EventType& event_type, MessageT&& message); - void issue_event(const protos::EventType& event_type); + AsyncEventStatus issue_event(const protos::EventType& event_type); bool has_subscription_service(const std::string& name) const; @@ -150,6 +196,8 @@ class Client final : public resources::PartitionResourceBase, public Service void request_update(); private: + AsyncEventStatus write_event(protos::Event event, bool await_response = false); + void route_state_update(std::uint64_t tag, protos::StateUpdate&& update); void do_service_start() final; @@ -175,7 +223,7 @@ class Client final : public resources::PartitionResourceBase, public Service // if true, then the following runners should not be null // if false, then the following runners must be 
null const bool m_owns_progress_engine; - std::unique_ptr m_progress_handler; + std::unique_ptr m_progress_handler; std::unique_ptr m_progress_engine; std::unique_ptr m_event_handler; @@ -201,70 +249,39 @@ class Client final : public resources::PartitionResourceBase, public Service std::mutex m_mutex; + std::map> m_pending_events; + friend network::NetworkResources; }; // todo: create this object from the client which will own the stop_source // create this object with a stop_token associated with the client's stop_source -template -class AsyncStatus -{ - public: - AsyncStatus() = default; - - DELETE_COPYABILITY(AsyncStatus); - DELETE_MOVEABILITY(AsyncStatus); - - Expected await_response() - { - // todo(ryan): expand this into a wait_until with a deadline and a stop token - auto event = m_promise.get_future().get(); - - if (event.has_error()) - { - return Error::create(event.error().message()); - } - - ResponseT response; - if (!event.message().UnpackTo(&response)) - { - throw Error::create("fatal error: unable to unpack message; server sent the wrong message type"); - } - - return response; - } - - private: - Promise m_promise; - friend Client; -}; - template Expected Client::await_unary(const protos::EventType& event_type, RequestT&& request) { - AsyncStatus status; - async_unary(event_type, std::move(request), status); - return status.await_response(); + auto status = this->async_unary(event_type, std::move(request)); + return status.template await_response(); } -template -void Client::async_unary(const protos::EventType& event_type, RequestT&& request, AsyncStatus& status) +template +AsyncEventStatus Client::async_unary(const protos::EventType& event_type, RequestT&& request) { protos::Event event; event.set_event(event_type); - event.set_tag(reinterpret_cast(&status.m_promise)); CHECK(event.mutable_message()->PackFrom(request)); - m_writer->await_write(std::move(event)); + + return this->write_event(std::move(event), true); } template -void 
Client::issue_event(const protos::EventType& event_type, MessageT&& message) +AsyncEventStatus Client::issue_event(const protos::EventType& event_type, MessageT&& message) { protos::Event event; event.set_event(event_type); CHECK(event.mutable_message()->PackFrom(message)); - m_writer->await_write(std::move(event)); + + return this->write_event(std::move(event), false); } } // namespace mrc::control_plane diff --git a/cpp/mrc/src/internal/control_plane/client/connections_manager.cpp b/cpp/mrc/src/internal/control_plane/client/connections_manager.cpp index 76cc2477e..1cb40b953 100644 --- a/cpp/mrc/src/internal/control_plane/client/connections_manager.cpp +++ b/cpp/mrc/src/internal/control_plane/client/connections_manager.cpp @@ -31,7 +31,6 @@ #include #include -#include #include #include #include diff --git a/cpp/mrc/src/internal/control_plane/client/instance.cpp b/cpp/mrc/src/internal/control_plane/client/instance.cpp index 65c0040ad..5843c59a8 100644 --- a/cpp/mrc/src/internal/control_plane/client/instance.cpp +++ b/cpp/mrc/src/internal/control_plane/client/instance.cpp @@ -24,6 +24,7 @@ #include "internal/utils/contains.hpp" #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/protos/architect.pb.h" #include "mrc/runnable/launch_control.hpp" @@ -49,6 +50,7 @@ Instance::Instance(Client& client, resources::PartitionResourceBase& base, mrc::edge::IWritableAcceptor& update_channel) : resources::PartitionResourceBase(base), + Service("control_plane::client::Instance"), m_client(client), m_instance_id(instance_id) { diff --git a/cpp/mrc/src/internal/control_plane/client/state_manager.cpp b/cpp/mrc/src/internal/control_plane/client/state_manager.cpp index 1970e3574..e21fc6519 100644 --- a/cpp/mrc/src/internal/control_plane/client/state_manager.cpp +++ b/cpp/mrc/src/internal/control_plane/client/state_manager.cpp @@ -22,6 +22,7 @@ #include "mrc/core/error.hpp" #include "mrc/edge/edge_builder.hpp" 
+#include "mrc/edge/edge_writable.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/protos/architect.pb.h" #include "mrc/runnable/launch_control.hpp" diff --git a/cpp/mrc/src/internal/control_plane/client/subscription_service.cpp b/cpp/mrc/src/internal/control_plane/client/subscription_service.cpp index 50e6e2351..c190e3995 100644 --- a/cpp/mrc/src/internal/control_plane/client/subscription_service.cpp +++ b/cpp/mrc/src/internal/control_plane/client/subscription_service.cpp @@ -34,6 +34,7 @@ namespace mrc::control_plane::client { SubscriptionService::SubscriptionService(const std::string& service_name, Instance& instance) : + Service("control_plane::client::SubscriptionService"), m_service_name(std::move(service_name)), m_instance(instance) { diff --git a/cpp/mrc/src/internal/control_plane/server.cpp b/cpp/mrc/src/internal/control_plane/server.cpp index aa980aba8..afaee91c7 100644 --- a/cpp/mrc/src/internal/control_plane/server.cpp +++ b/cpp/mrc/src/internal/control_plane/server.cpp @@ -41,7 +41,6 @@ #include #include -#include #include #include #include @@ -86,9 +85,16 @@ static Expected<> unary_response(Server::event_t& event, Expected&& me return {}; } -Server::Server(runnable::RunnableResources& runnable) : m_runnable(runnable), m_server(m_runnable) {} +Server::Server(runnable::RunnableResources& runnable) : + Service("control_plane::Server"), + m_runnable(runnable), + m_server(m_runnable) +{} -Server::~Server() = default; +Server::~Server() +{ + Service::call_in_destructor(); +} void Server::do_service_start() { diff --git a/cpp/mrc/src/internal/control_plane/server.hpp b/cpp/mrc/src/internal/control_plane/server.hpp index d3d319502..6f7464de9 100644 --- a/cpp/mrc/src/internal/control_plane/server.hpp +++ b/cpp/mrc/src/internal/control_plane/server.hpp @@ -35,7 +35,7 @@ #include #include #include - +// IWYU pragma: no_include "internal/control_plane/server/subscription_manager.hpp" // IWYU pragma: no_forward_declare mrc::node::WritableEntrypoint namespace 
mrc::node { @@ -45,7 +45,7 @@ class Queue; namespace mrc::control_plane::server { class ClientInstance; -class SubscriptionService; +class SubscriptionService; // IWYU pragma: keep } // namespace mrc::control_plane::server namespace mrc::rpc { template diff --git a/cpp/mrc/src/internal/control_plane/server/connection_manager.cpp b/cpp/mrc/src/internal/control_plane/server/connection_manager.cpp index 617c3b4c6..2098f283b 100644 --- a/cpp/mrc/src/internal/control_plane/server/connection_manager.cpp +++ b/cpp/mrc/src/internal/control_plane/server/connection_manager.cpp @@ -27,7 +27,6 @@ #include #include -#include #include #include diff --git a/cpp/mrc/src/internal/data_plane/client.cpp b/cpp/mrc/src/internal/data_plane/client.cpp index 0f0a5ee4c..dc8709e43 100644 --- a/cpp/mrc/src/internal/data_plane/client.cpp +++ b/cpp/mrc/src/internal/data_plane/client.cpp @@ -25,7 +25,7 @@ #include "internal/memory/transient_pool.hpp" #include "internal/remote_descriptor/manager.hpp" #include "internal/runnable/runnable_resources.hpp" -#include "internal/ucx/common.hpp" +#include "internal/service.hpp" #include "internal/ucx/endpoint.hpp" #include "internal/ucx/ucx_resources.hpp" #include "internal/ucx/worker.hpp" @@ -53,7 +53,6 @@ #include #include #include -#include namespace mrc::data_plane { @@ -64,13 +63,17 @@ Client::Client(resources::PartitionResourceBase& base, control_plane::client::ConnectionsManager& connections_manager, memory::TransientPool& transient_pool) : resources::PartitionResourceBase(base), + Service("data_plane::Client"), m_ucx(ucx), m_connnection_manager(connections_manager), m_transient_pool(transient_pool), m_rd_channel(std::make_unique>()) {} -Client::~Client() = default; +Client::~Client() +{ + Service::call_in_destructor(); +} std::shared_ptr Client::endpoint_shared(const InstanceID& id) const { diff --git a/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp b/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp index 3ecf2d3f6..78cf64f7e 
100644 --- a/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp +++ b/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp @@ -38,6 +38,7 @@ DataPlaneResources::DataPlaneResources(resources::PartitionResourceBase& base, const InstanceID& instance_id, control_plane::Client& control_plane_client) : resources::PartitionResourceBase(base), + Service("DataPlaneResources"), m_ucx(ucx), m_host(host), m_control_plane_client(control_plane_client), diff --git a/cpp/mrc/src/internal/data_plane/server.cpp b/cpp/mrc/src/internal/data_plane/server.cpp index a230ad934..d2f3974c9 100644 --- a/cpp/mrc/src/internal/data_plane/server.cpp +++ b/cpp/mrc/src/internal/data_plane/server.cpp @@ -36,7 +36,6 @@ #include "mrc/runnable/runner.hpp" #include "mrc/types.hpp" -#include #include #include #include @@ -47,7 +46,6 @@ #include #include #include -#include #include #include @@ -148,6 +146,7 @@ Server::Server(resources::PartitionResourceBase& provider, memory::TransientPool& transient_pool, InstanceID instance_id) : resources::PartitionResourceBase(provider), + Service("data_plane::Server"), m_ucx(ucx), m_host(host), m_instance_id(instance_id), diff --git a/cpp/mrc/src/internal/executor/executor_definition.cpp b/cpp/mrc/src/internal/executor/executor_definition.cpp index de630115d..a341f4434 100644 --- a/cpp/mrc/src/internal/executor/executor_definition.cpp +++ b/cpp/mrc/src/internal/executor/executor_definition.cpp @@ -76,6 +76,7 @@ static bool valid_pipeline(const pipeline::PipelineDefinition& pipeline) ExecutorDefinition::ExecutorDefinition(std::unique_ptr system) : SystemProvider(std::move(system)), + Service("ExecutorDefinition"), m_resources_manager(std::make_unique(*this)) {} @@ -128,7 +129,6 @@ void ExecutorDefinition::join() void ExecutorDefinition::do_service_start() { CHECK(m_pipeline_manager); - m_pipeline_manager->service_start(); pipeline::SegmentAddresses initial_segments; for (const auto& [id, segment] : m_pipeline_manager->pipeline().segments()) diff --git 
a/cpp/mrc/src/internal/grpc/client_streaming.hpp b/cpp/mrc/src/internal/grpc/client_streaming.hpp index 8ee6bd82e..ad2c82fb5 100644 --- a/cpp/mrc/src/internal/grpc/client_streaming.hpp +++ b/cpp/mrc/src/internal/grpc/client_streaming.hpp @@ -18,6 +18,7 @@ #pragma once #include "internal/grpc/progress_engine.hpp" +#include "internal/grpc/promise_handler.hpp" #include "internal/grpc/stream_writer.hpp" #include "internal/runnable/runnable_resources.hpp" #include "internal/service.hpp" @@ -152,6 +153,7 @@ class ClientStream : private Service, public std::enable_shared_from_this>(grpc::ClientContext* context)>; ClientStream(prepare_fn_t prepare_fn, runnable::RunnableResources& runnable) : + Service("rpc::ClientStream"), m_prepare_fn(prepare_fn), m_runnable(runnable), m_reader_source(std::make_unique>( @@ -195,10 +197,10 @@ class ClientStream : private Service, public std::enable_shared_from_this read; + auto* wrapper = new PromiseWrapper("Client::Read"); IncomingData data; - m_stream->Read(&data.msg, &read); - auto ok = read.get_future().get(); + m_stream->Read(&data.msg, wrapper); + auto ok = wrapper->get_future(); if (!ok) { m_write_channel.reset(); @@ -216,9 +218,9 @@ class ClientStream : private Service, public std::enable_shared_from_this promise; - m_stream->Write(request, &promise); - auto ok = promise.get_future().get(); + auto* wrapper = new PromiseWrapper("Client::Write"); + m_stream->Write(request, wrapper); + auto ok = wrapper->get_future(); if (!ok) { m_can_write = false; @@ -234,10 +236,20 @@ class ClientStream : private Service, public std::enable_shared_from_this writes_done; - m_stream->WritesDone(&writes_done); - writes_done.get_future().get(); - DVLOG(10) << "client issued writes done to server"; + { + auto* wrapper = new PromiseWrapper("Client::WritesDone"); + m_stream->WritesDone(wrapper); + wrapper->get_future(); + } + + { + // Now issue finish since this is OK at the client level + auto* wrapper = new PromiseWrapper("Client::Finish"); + 
m_stream->Finish(&m_status, wrapper); + wrapper->get_future(); + } + + // DVLOG(10) << "client issued writes done to server"; }; } @@ -284,9 +296,9 @@ class ClientStream : private Service, public std::enable_shared_from_this promise; - m_stream->StartCall(&promise); - auto ok = promise.get_future().get(); + auto* wrapper = new PromiseWrapper("Client::StartCall", false); + m_stream->StartCall(wrapper); + auto ok = wrapper->get_future(); if (!ok) { @@ -327,10 +339,6 @@ class ClientStream : private Service, public std::enable_shared_from_thisawait_join(); m_reader->await_join(); - - Promise finish; - m_stream->Finish(&m_status, &finish); - auto ok = finish.get_future().get(); } } diff --git a/cpp/mrc/src/internal/grpc/progress_engine.cpp b/cpp/mrc/src/internal/grpc/progress_engine.cpp index 68f157bf5..f540bf8b9 100644 --- a/cpp/mrc/src/internal/grpc/progress_engine.cpp +++ b/cpp/mrc/src/internal/grpc/progress_engine.cpp @@ -23,7 +23,6 @@ #include #include -#include #include #include @@ -40,6 +39,9 @@ void ProgressEngine::data_source(rxcpp::subscriber& s) while (s.is_subscribed()) { + event.ok = false; + event.tag = nullptr; + switch (m_cq->AsyncNext(&event.tag, &event.ok, gpr_time_0(GPR_CLOCK_REALTIME))) { case grpc::CompletionQueue::NextStatus::GOT_EVENT: { diff --git a/cpp/mrc/src/internal/grpc/progress_engine.hpp b/cpp/mrc/src/internal/grpc/progress_engine.hpp index 7bea6239e..23afa26f1 100644 --- a/cpp/mrc/src/internal/grpc/progress_engine.hpp +++ b/cpp/mrc/src/internal/grpc/progress_engine.hpp @@ -23,7 +23,6 @@ #include #include -#include namespace grpc { class CompletionQueue; diff --git a/cpp/mrc/src/internal/grpc/promise_handler.cpp b/cpp/mrc/src/internal/grpc/promise_handler.cpp new file mode 100644 index 000000000..444d69738 --- /dev/null +++ b/cpp/mrc/src/internal/grpc/promise_handler.cpp @@ -0,0 +1,67 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "internal/grpc/promise_handler.hpp" + +// MRC_CONCAT_STR is needed for debug builds, in CI IWYU is run with a release config +#include "mrc/utils/string_utils.hpp" // IWYU pragma: keep for MRC_CONCAT_STR + +#include // for future +#include // for COMPACT_GOOGLE_LOG_INFO + +#include +#include // for operator<<, basic_ostream +#include // for move + +namespace mrc::rpc { + +std::atomic_size_t PromiseWrapper::s_id_counter = 0; + +PromiseWrapper::PromiseWrapper(std::string method, bool in_runtime) : id(++s_id_counter), method(std::move(method)) +{ +#if (!defined(NDEBUG)) + this->prefix = MRC_CONCAT_STR("Promise[" << id << ", " << this << "](" << method << "): "); +#endif + VLOG(20) << this->to_string() << "#1 creating promise"; +} + +void PromiseWrapper::set_value(bool val) +{ + auto tmp_prefix = this->to_string(); + + VLOG(20) << tmp_prefix << "#2 setting promise to " << val; + this->promise.set_value(val); + VLOG(20) << tmp_prefix << "#3 setting promise to " << val << "... 
done"; +} + +bool PromiseWrapper::get_future() +{ + auto future = this->promise.get_future(); + + auto value = future.get(); + + VLOG(20) << this->to_string() << "#4 got future with value " << value; + + return value; +} + +std::string PromiseWrapper::to_string() const +{ + return this->prefix; +} + +} // namespace mrc::rpc diff --git a/cpp/mrc/src/internal/grpc/promise_handler.hpp b/cpp/mrc/src/internal/grpc/promise_handler.hpp index 437a22e69..0220eb685 100644 --- a/cpp/mrc/src/internal/grpc/promise_handler.hpp +++ b/cpp/mrc/src/internal/grpc/promise_handler.hpp @@ -20,21 +20,55 @@ #include "internal/grpc/progress_engine.hpp" #include "mrc/node/generic_sink.hpp" +#include "mrc/node/sink_properties.hpp" // for SinkProperties, Status -#include +#include // for promise + +#include // for atomic_size_t +#include // for size_t +#include namespace mrc::rpc { +struct PromiseWrapper +{ + PromiseWrapper(std::string method, bool in_runtime = true); + + ~PromiseWrapper() = default; + + size_t id; + std::string method; + std::string prefix; + boost::fibers::promise promise; + + void set_value(bool val); + + bool get_future(); + + std::string to_string() const; + + private: + static std::atomic_size_t s_id_counter; +}; + /** * @brief MRC Sink to handle ProgressEvents which correspond to Promise tags */ -class PromiseHandler final : public mrc::node::GenericSink +class PromiseHandler final : public mrc::node::GenericSinkComponent { - void on_data(ProgressEvent&& event) final + mrc::channel::Status on_data(ProgressEvent&& event) final { - auto* promise = static_cast*>(event.tag); + auto* promise = static_cast(event.tag); + promise->set_value(event.ok); - } + return mrc::channel::Status::success; + delete promise; + }; + + void on_complete() override + { + SinkProperties::release_edge_connection(); + }; }; } // namespace mrc::rpc diff --git a/cpp/mrc/src/internal/grpc/server.cpp b/cpp/mrc/src/internal/grpc/server.cpp index 9e0c0ecb4..e03293d15 100644 --- 
a/cpp/mrc/src/internal/grpc/server.cpp +++ b/cpp/mrc/src/internal/grpc/server.cpp @@ -18,7 +18,7 @@ #include "internal/grpc/server.hpp" #include "internal/grpc/progress_engine.hpp" -#include "internal/grpc/promise_handler.hpp" +#include "internal/grpc/promise_handler.hpp" // for PromiseHandler #include "internal/runnable/runnable_resources.hpp" #include "mrc/edge/edge_builder.hpp" @@ -31,7 +31,7 @@ namespace mrc::rpc { -Server::Server(runnable::RunnableResources& runnable) : m_runnable(runnable) +Server::Server(runnable::RunnableResources& runnable) : Service("rpc::Server"), m_runnable(runnable) { m_cq = m_builder.AddCompletionQueue(); m_builder.AddListeningPort("0.0.0.0:13337", grpc::InsecureServerCredentials()); @@ -47,11 +47,10 @@ void Server::do_service_start() m_server = m_builder.BuildAndStart(); auto progress_engine = std::make_unique(m_cq); - auto event_handler = std::make_unique(); - mrc::make_edge(*progress_engine, *event_handler); + m_event_hander = std::make_unique(); + mrc::make_edge(*progress_engine, *m_event_hander); m_progress_engine = m_runnable.launch_control().prepare_launcher(std::move(progress_engine))->ignition(); - m_event_hander = m_runnable.launch_control().prepare_launcher(std::move(event_handler))->ignition(); } void Server::do_service_stop() @@ -70,19 +69,17 @@ void Server::do_service_kill() void Server::do_service_await_live() { - if (m_progress_engine && m_event_hander) + if (m_progress_engine) { m_progress_engine->await_live(); - m_event_hander->await_live(); } } void Server::do_service_await_join() { - if (m_progress_engine && m_event_hander) + if (m_progress_engine) { m_progress_engine->await_join(); - m_event_hander->await_join(); } } diff --git a/cpp/mrc/src/internal/grpc/server.hpp b/cpp/mrc/src/internal/grpc/server.hpp index cacd4602d..db9436d95 100644 --- a/cpp/mrc/src/internal/grpc/server.hpp +++ b/cpp/mrc/src/internal/grpc/server.hpp @@ -34,6 +34,10 @@ namespace mrc::runnable { class Runner; } // namespace mrc::runnable 
+namespace mrc::rpc { +class PromiseHandler; +} // namespace mrc::rpc + namespace mrc::rpc { class Server : public Service @@ -61,7 +65,7 @@ class Server : public Service std::shared_ptr m_cq; std::unique_ptr m_server; std::unique_ptr m_progress_engine; - std::unique_ptr m_event_hander; + std::unique_ptr m_event_hander; }; } // namespace mrc::rpc diff --git a/cpp/mrc/src/internal/grpc/server_streaming.hpp b/cpp/mrc/src/internal/grpc/server_streaming.hpp index 0d4da8b44..f2d50e1d4 100644 --- a/cpp/mrc/src/internal/grpc/server_streaming.hpp +++ b/cpp/mrc/src/internal/grpc/server_streaming.hpp @@ -18,6 +18,7 @@ #pragma once #include "internal/grpc/progress_engine.hpp" +#include "internal/grpc/promise_handler.hpp" #include "internal/grpc/stream_writer.hpp" #include "internal/runnable/runnable_resources.hpp" #include "internal/service.hpp" @@ -164,6 +165,7 @@ class ServerStream : private Service, public std::enable_shared_from_this* stream, void* tag)>; ServerStream(request_fn_t request_fn, runnable::RunnableResources& runnable) : + Service("rpc::ServerStream"), m_runnable(runnable), m_stream(std::make_unique>(&m_context)), m_reader_source(std::make_unique>( @@ -223,10 +225,11 @@ class ServerStream : private Service, public std::enable_shared_from_this read; + IncomingData data; - m_stream->Read(&data.msg, &read); - auto ok = read.get_future().get(); + auto* wrapper = new PromiseWrapper("Server::Read"); + m_stream->Read(&data.msg, wrapper); + auto ok = wrapper->get_future(); data.ok = ok; data.stream = writer(); s.on_next(std::move(data)); @@ -247,9 +250,9 @@ class ServerStream : private Service, public std::enable_shared_from_this promise; - m_stream->Write(request, &promise); - auto ok = promise.get_future().get(); + auto* wrapper = new PromiseWrapper("Server::Write"); + m_stream->Write(request, wrapper); + auto ok = wrapper->get_future(); if (!ok) { DVLOG(10) << "server failed to write to client; disabling writes and beginning shutdown"; @@ -272,10 +275,10 @@ class 
ServerStream : private Service, public std::enable_shared_from_this finish; - m_stream->Finish(*m_status, &finish); - auto ok = finish.get_future().get(); - DVLOG(10) << "server done with finish"; + auto* wrapper = new PromiseWrapper("Server::Finish"); + m_stream->Finish(*m_status, wrapper); + auto ok = wrapper->get_future(); + // DVLOG(10) << "server done with finish"; } } @@ -317,10 +320,9 @@ class ServerStream : private Service, public std::enable_shared_from_this promise; - m_init_fn(&promise); - auto ok = promise.get_future().get(); - + auto* wrapper = new PromiseWrapper("Server::m_init_fn"); + m_init_fn(wrapper); + auto ok = wrapper->get_future(); if (!ok) { DVLOG(10) << "server stream could not be initialized"; diff --git a/cpp/mrc/src/internal/memory/device_resources.cpp b/cpp/mrc/src/internal/memory/device_resources.cpp index 907eb1a4a..9ec0f5b04 100644 --- a/cpp/mrc/src/internal/memory/device_resources.cpp +++ b/cpp/mrc/src/internal/memory/device_resources.cpp @@ -35,16 +35,12 @@ #include "mrc/types.hpp" #include "mrc/utils/bytes_to_string.hpp" -#include #include -#include -#include #include #include #include #include -#include namespace mrc::memory { diff --git a/cpp/mrc/src/internal/memory/host_resources.cpp b/cpp/mrc/src/internal/memory/host_resources.cpp index c98c78618..42acfd32b 100644 --- a/cpp/mrc/src/internal/memory/host_resources.cpp +++ b/cpp/mrc/src/internal/memory/host_resources.cpp @@ -35,13 +35,10 @@ #include "mrc/types.hpp" #include "mrc/utils/bytes_to_string.hpp" -#include #include -#include #include #include -#include #include #include #include diff --git a/cpp/mrc/src/internal/network/network_resources.cpp b/cpp/mrc/src/internal/network/network_resources.cpp index b28a0d14f..ea078bee5 100644 --- a/cpp/mrc/src/internal/network/network_resources.cpp +++ b/cpp/mrc/src/internal/network/network_resources.cpp @@ -27,7 +27,6 @@ #include "mrc/core/task_queue.hpp" #include "mrc/types.hpp" -#include #include #include diff --git 
a/cpp/mrc/src/internal/pipeline/controller.cpp b/cpp/mrc/src/internal/pipeline/controller.cpp index 93946abbe..459817351 100644 --- a/cpp/mrc/src/internal/pipeline/controller.cpp +++ b/cpp/mrc/src/internal/pipeline/controller.cpp @@ -31,12 +31,10 @@ #include #include #include -#include #include #include #include #include -#include namespace mrc::pipeline { diff --git a/cpp/mrc/src/internal/pipeline/manager.cpp b/cpp/mrc/src/internal/pipeline/manager.cpp index 0487fdfb9..abec10d4a 100644 --- a/cpp/mrc/src/internal/pipeline/manager.cpp +++ b/cpp/mrc/src/internal/pipeline/manager.cpp @@ -34,16 +34,14 @@ #include #include -#include #include #include -#include #include -#include namespace mrc::pipeline { Manager::Manager(std::shared_ptr pipeline, resources::Manager& resources) : + Service("pipeline::Manager"), m_pipeline(std::move(pipeline)), m_resources(resources) { diff --git a/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp b/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp index 50e3abca1..dddd73a3c 100644 --- a/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp +++ b/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp @@ -24,6 +24,7 @@ #include "internal/runnable/runnable_resources.hpp" #include "internal/segment/segment_definition.hpp" #include "internal/segment/segment_instance.hpp" +#include "internal/service.hpp" #include "mrc/core/addresses.hpp" #include "mrc/core/task_queue.hpp" @@ -46,13 +47,17 @@ namespace mrc::pipeline { PipelineInstance::PipelineInstance(std::shared_ptr definition, resources::Manager& resources) : PipelineResources(resources), + Service("pipeline::PipelineInstance"), m_definition(std::move(definition)) { CHECK(m_definition); m_joinable_future = m_joinable_promise.get_future().share(); } -PipelineInstance::~PipelineInstance() = default; +PipelineInstance::~PipelineInstance() +{ + Service::call_in_destructor(); +} void PipelineInstance::update() { diff --git a/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp 
b/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp index d9f2489b8..7dc51e38e 100644 --- a/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp +++ b/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp @@ -25,12 +25,13 @@ #include #include #include +// IWYU pragma: no_include "internal/segment/segment_instance.hpp" namespace mrc::resources { class Manager; } // namespace mrc::resources namespace mrc::segment { -class SegmentInstance; +class SegmentInstance; // IWYU pragma: keep } // namespace mrc::segment namespace mrc::manifold { struct Interface; diff --git a/cpp/mrc/src/internal/pubsub/publisher_service.cpp b/cpp/mrc/src/internal/pubsub/publisher_service.cpp index 2ea517e44..5175e5315 100644 --- a/cpp/mrc/src/internal/pubsub/publisher_service.cpp +++ b/cpp/mrc/src/internal/pubsub/publisher_service.cpp @@ -39,10 +39,8 @@ #include #include -#include #include #include -#include namespace mrc::pubsub { diff --git a/cpp/mrc/src/internal/pubsub/subscriber_service.cpp b/cpp/mrc/src/internal/pubsub/subscriber_service.cpp index c53dac546..fba47135b 100644 --- a/cpp/mrc/src/internal/pubsub/subscriber_service.cpp +++ b/cpp/mrc/src/internal/pubsub/subscriber_service.cpp @@ -27,6 +27,7 @@ #include "internal/runtime/partition.hpp" #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/node/operators/router.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/protos/codable.pb.h" @@ -41,7 +42,6 @@ #include #include #include -#include namespace mrc::pubsub { diff --git a/cpp/mrc/src/internal/remote_descriptor/manager.cpp b/cpp/mrc/src/internal/remote_descriptor/manager.cpp index fe73a61bc..b624b7c82 100644 --- a/cpp/mrc/src/internal/remote_descriptor/manager.cpp +++ b/cpp/mrc/src/internal/remote_descriptor/manager.cpp @@ -55,9 +55,7 @@ #include #include #include -#include #include -#include namespace mrc::remote_descriptor { @@ -86,6 +84,7 @@ ucs_status_t active_message_callback(void* arg, } // namespace Manager::Manager(const InstanceID& 
instance_id, resources::PartitionResources& resources) : + Service("remote_descriptor::Manager"), m_instance_id(instance_id), m_resources(resources) { diff --git a/cpp/mrc/src/internal/resources/manager.cpp b/cpp/mrc/src/internal/resources/manager.cpp index b47334c04..fab210109 100644 --- a/cpp/mrc/src/internal/resources/manager.cpp +++ b/cpp/mrc/src/internal/resources/manager.cpp @@ -26,6 +26,7 @@ #include "internal/network/network_resources.hpp" #include "internal/resources/partition_resources_base.hpp" #include "internal/runnable/runnable_resources.hpp" +#include "internal/system/device_partition.hpp" #include "internal/system/engine_factory_cpu_sets.hpp" #include "internal/system/host_partition.hpp" #include "internal/system/partition.hpp" @@ -45,6 +46,7 @@ #include #include +#include #include #include #include @@ -54,16 +56,18 @@ namespace mrc::resources { +std::atomic_size_t Manager::s_id_counter = 0; thread_local Manager* Manager::m_thread_resources{nullptr}; thread_local PartitionResources* Manager::m_thread_partition{nullptr}; Manager::Manager(const system::SystemProvider& system) : SystemProvider(system), + m_runtime_id(++s_id_counter), m_threading(std::make_unique(system)) { const auto& partitions = this->system().partitions().flattened(); const auto& host_partitions = this->system().partitions().host_partitions(); - const bool network_enabled = !this->system().options().architect_url().empty(); + bool network_enabled = !this->system().options().architect_url().empty(); // construct the runnable resources on each host_partition - launch control and main for (std::size_t i = 0; i < host_partitions.size(); ++i) @@ -197,6 +201,11 @@ Manager::~Manager() m_network.clear(); } +std::size_t Manager::runtime_id() const +{ + return m_runtime_id; +} + std::size_t Manager::partition_count() const { return system().partitions().flattened().size(); diff --git a/cpp/mrc/src/internal/resources/manager.hpp b/cpp/mrc/src/internal/resources/manager.hpp index 
a823bbe27..55e4af014 100644 --- a/cpp/mrc/src/internal/resources/manager.hpp +++ b/cpp/mrc/src/internal/resources/manager.hpp @@ -24,25 +24,29 @@ #include "mrc/types.hpp" +#include #include #include #include #include +// IWYU pragma: no_include "internal/memory/device_resources.hpp" +// IWYU pragma: no_include "internal/network/network_resources.hpp" +// IWYU pragma: no_include "internal/ucx/ucx_resources.hpp" namespace mrc::network { -class NetworkResources; +class NetworkResources; // IWYU pragma: keep } // namespace mrc::network namespace mrc::control_plane { class ControlPlaneResources; } // namespace mrc::control_plane namespace mrc::memory { -class DeviceResources; +class DeviceResources; // IWYU pragma: keep } // namespace mrc::memory namespace mrc::system { class ThreadingResources; } // namespace mrc::system namespace mrc::ucx { -class UcxResources; +class UcxResources; // IWYU pragma: keep } // namespace mrc::ucx namespace mrc::runtime { class Runtime; @@ -57,6 +61,8 @@ class Manager final : public system::SystemProvider // Manager(std::unique_ptr resources); ~Manager() override; + std::size_t runtime_id() const; + static Manager& get_resources(); static PartitionResources& get_partition(); @@ -68,6 +74,8 @@ class Manager final : public system::SystemProvider private: Future shutdown(); + const size_t m_runtime_id; // unique id for this runtime + const std::unique_ptr m_threading; std::vector m_runnable; // one per host partition std::vector> m_ucx; // one per flattened partition if network is enabled @@ -82,6 +90,7 @@ class Manager final : public system::SystemProvider // which must be destroyed before all other std::vector> m_network; // one per flattened partition + static std::atomic_size_t s_id_counter; static thread_local PartitionResources* m_thread_partition; static thread_local Manager* m_thread_resources; diff --git a/cpp/mrc/src/internal/runnable/fiber_engine.cpp b/cpp/mrc/src/internal/runnable/fiber_engine.cpp index 10dc1eb51..f208d5791 100644 
--- a/cpp/mrc/src/internal/runnable/fiber_engine.cpp +++ b/cpp/mrc/src/internal/runnable/fiber_engine.cpp @@ -21,8 +21,6 @@ #include "mrc/runnable/types.hpp" #include "mrc/types.hpp" -#include - #include namespace mrc::runnable { diff --git a/cpp/mrc/src/internal/runnable/fiber_engines.cpp b/cpp/mrc/src/internal/runnable/fiber_engines.cpp index 87dfa5556..ed720803c 100644 --- a/cpp/mrc/src/internal/runnable/fiber_engines.cpp +++ b/cpp/mrc/src/internal/runnable/fiber_engines.cpp @@ -27,7 +27,6 @@ #include #include -#include #include namespace mrc::runnable { diff --git a/cpp/mrc/src/internal/runnable/runnable_resources.cpp b/cpp/mrc/src/internal/runnable/runnable_resources.cpp index 4fa98f1ce..9930c7778 100644 --- a/cpp/mrc/src/internal/runnable/runnable_resources.cpp +++ b/cpp/mrc/src/internal/runnable/runnable_resources.cpp @@ -27,7 +27,6 @@ #include "mrc/runnable/types.hpp" #include "mrc/types.hpp" -#include #include #include diff --git a/cpp/mrc/src/internal/runnable/thread_engine.cpp b/cpp/mrc/src/internal/runnable/thread_engine.cpp index fb18c3b60..b22edd730 100644 --- a/cpp/mrc/src/internal/runnable/thread_engine.cpp +++ b/cpp/mrc/src/internal/runnable/thread_engine.cpp @@ -24,7 +24,6 @@ #include "mrc/runnable/types.hpp" #include "mrc/types.hpp" -#include #include #include diff --git a/cpp/mrc/src/internal/runnable/thread_engines.cpp b/cpp/mrc/src/internal/runnable/thread_engines.cpp index 23f9c430a..92ea1a65e 100644 --- a/cpp/mrc/src/internal/runnable/thread_engines.cpp +++ b/cpp/mrc/src/internal/runnable/thread_engines.cpp @@ -28,7 +28,6 @@ #include #include #include -#include #include namespace mrc::runnable { diff --git a/cpp/mrc/src/internal/segment/builder_definition.cpp b/cpp/mrc/src/internal/segment/builder_definition.cpp index e631c3f1e..c2cf2ae19 100644 --- a/cpp/mrc/src/internal/segment/builder_definition.cpp +++ b/cpp/mrc/src/internal/segment/builder_definition.cpp @@ -28,9 +28,9 @@ #include "mrc/modules/properties/persistent.hpp" // IWYU pragma: 
keep #include "mrc/modules/segment_modules.hpp" #include "mrc/node/port_registry.hpp" +#include "mrc/runnable/launchable.hpp" #include "mrc/segment/egress_port.hpp" // IWYU pragma: keep #include "mrc/segment/ingress_port.hpp" // IWYU pragma: keep -#include "mrc/segment/initializers.hpp" #include "mrc/segment/object.hpp" #include "mrc/types.hpp" @@ -164,7 +164,7 @@ std::shared_ptr BuilderDefinition::get_egress(std::string name void BuilderDefinition::init_module(std::shared_ptr smodule) { - this->ns_push(smodule); + this->module_push(smodule); VLOG(2) << "Initializing module: " << m_namespace_prefix; smodule->m_module_instance_registered_namespace = m_namespace_prefix; smodule->initialize(*this); @@ -177,7 +177,8 @@ void BuilderDefinition::init_module(std::shared_ptr // Just save to a vector to keep it alive m_modules.push_back(persist); } - this->ns_pop(); + + this->module_pop(smodule); } void BuilderDefinition::register_module_input(std::string input_name, std::shared_ptr object) @@ -366,6 +367,24 @@ void BuilderDefinition::add_object(const std::string& name, std::shared_ptr<::mr // Save by the original name m_egress_ports[local_name] = egress_port; } + + // Now register any child objects + auto children = object->get_children(); + + if (!children.empty()) + { + // Push the namespace for this object + this->ns_push(local_name); + + for (auto& [child_name, child_object] : children) + { + // Add the child object + this->add_object(child_name, child_object); + } + + // Pop the namespace for this object + this->ns_pop(local_name); + } } std::shared_ptr<::mrc::segment::IngressPortBase> BuilderDefinition::get_ingress_base(const std::string& name) @@ -402,20 +421,43 @@ std::function BuilderDefinition::make_throughput_counter(con }; } -void BuilderDefinition::ns_push(std::shared_ptr smodule) +std::string BuilderDefinition::module_push(std::shared_ptr smodule) { m_module_stack.push_back(smodule); - m_namespace_stack.push_back(smodule->component_prefix()); + + return 
this->ns_push(smodule->component_prefix()); +} + +std::string BuilderDefinition::module_pop(std::shared_ptr smodule) +{ + CHECK_EQ(smodule, m_module_stack.back()) + << "Namespace stack mismatch. Expected " << m_module_stack.back()->component_prefix() << " but got " + << smodule->component_prefix(); + + m_module_stack.pop_back(); + + return this->ns_pop(smodule->component_prefix()); +} + +std::string BuilderDefinition::ns_push(const std::string& name) +{ + m_namespace_stack.push_back(name); m_namespace_prefix = std::accumulate(m_namespace_stack.begin(), m_namespace_stack.end(), std::string(""), ::accum_merge); + + return m_namespace_prefix; } -void BuilderDefinition::ns_pop() +std::string BuilderDefinition::ns_pop(const std::string& name) { - m_module_stack.pop_back(); + CHECK_EQ(name, m_namespace_stack.back()) + << "Namespace stack mismatch. Expected " << m_namespace_stack.back() << " but got " << name; + m_namespace_stack.pop_back(); m_namespace_prefix = std::accumulate(m_namespace_stack.begin(), m_namespace_stack.end(), std::string(""), ::accum_merge); + + return m_namespace_prefix; } } // namespace mrc::segment diff --git a/cpp/mrc/src/internal/segment/builder_definition.hpp b/cpp/mrc/src/internal/segment/builder_definition.hpp index aa0c96140..fa8d0ece3 100644 --- a/cpp/mrc/src/internal/segment/builder_definition.hpp +++ b/cpp/mrc/src/internal/segment/builder_definition.hpp @@ -135,8 +135,11 @@ class BuilderDefinition : public IBuilder // Local methods bool has_object(const std::string& name) const; - void ns_push(std::shared_ptr smodule); - void ns_pop(); + std::string module_push(std::shared_ptr smodule); + std::string module_pop(std::shared_ptr smodule); + + std::string ns_push(const std::string& name); + std::string ns_pop(const std::string& name); // definition std::shared_ptr m_definition; diff --git a/cpp/mrc/src/internal/segment/segment_instance.cpp b/cpp/mrc/src/internal/segment/segment_instance.cpp index 871b7a2ca..53f66b804 100644 --- 
a/cpp/mrc/src/internal/segment/segment_instance.cpp +++ b/cpp/mrc/src/internal/segment/segment_instance.cpp @@ -36,7 +36,6 @@ #include "mrc/segment/utils.hpp" #include "mrc/types.hpp" -#include #include #include @@ -54,6 +53,7 @@ SegmentInstance::SegmentInstance(std::shared_ptr defini SegmentRank rank, pipeline::PipelineResources& resources, std::size_t partition_id) : + Service("segment::SegmentInstance"), m_name(definition->name()), m_id(definition->id()), m_rank(rank), @@ -78,7 +78,10 @@ SegmentInstance::SegmentInstance(std::shared_ptr defini .get(); } -SegmentInstance::~SegmentInstance() = default; +SegmentInstance::~SegmentInstance() +{ + Service::call_in_destructor(); +} const std::string& SegmentInstance::name() const { diff --git a/cpp/mrc/src/internal/service.cpp b/cpp/mrc/src/internal/service.cpp index 01c51b014..3ea3f6b90 100644 --- a/cpp/mrc/src/internal/service.cpp +++ b/cpp/mrc/src/internal/service.cpp @@ -17,131 +17,293 @@ #include "internal/service.hpp" +#include "mrc/core/utils.hpp" +#include "mrc/exceptions/runtime_error.hpp" +#include "mrc/utils/string_utils.hpp" + #include -#include +#include +#include // for function +#include +#include // for operator<<, basic_ostream #include namespace mrc { +Service::Service(std::string service_name) : m_service_name(std::move(service_name)) {} + Service::~Service() { + if (!m_call_in_destructor_called) + { + LOG(ERROR) << "Must call Service::call_in_destructor to ensure service is cleaned up before being " + "destroyed"; + } + auto state = this->state(); CHECK(state == ServiceState::Initialized || state == ServiceState::Completed); } +const std::string& Service::service_name() const +{ + return m_service_name; +} + +bool Service::is_service_startable() const +{ + std::lock_guard lock(m_mutex); + return (m_state == ServiceState::Initialized); +} + +bool Service::is_running() const +{ + std::lock_guard lock(m_mutex); + return (m_state > ServiceState::Initialized && m_state < ServiceState::Completed); +} + 
+const ServiceState& Service::state() const +{ + std::lock_guard lock(m_mutex); + return m_state; +} + void Service::service_start() { - if (forward_state(ServiceState::Running)) + std::unique_lock lock(m_mutex); + + if (!this->is_service_startable()) { - do_service_start(); + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR(this->debug_prefix() << " Service has already been started")); + } + + if (advance_state(ServiceState::Starting)) + { + // Unlock the mutex before calling start to avoid a deadlock + lock.unlock(); + + try + { + this->do_service_start(); + + // Use ensure_state here in case the service itself called stop or kill + this->ensure_state(ServiceState::Running); + } catch (...) + { + // On error, set this to completed and rethrow the error to allow for cleanup + this->advance_state(ServiceState::Completed); + + throw; + } } } void Service::service_await_live() { - do_service_await_live(); + { + std::unique_lock lock(m_mutex); + + if (this->is_service_startable()) + { + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR(this->debug_prefix() << " Service must be started before " + "awaiting live")); + } + + // Check if this is our first call to service_await_join + if (!m_service_await_live_called) + { + // Prevent reentry + m_service_await_live_called = true; + + // We now create a promise and a future to track the completion of this function + Promise live_promise; + + m_live_future = live_promise.get_future(); + + // Unlock the mutex before calling await to avoid a deadlock + lock.unlock(); + + try + { + // Now call the await join (this can throw!) + this->do_service_await_live(); + + // Set the value only if there was not an exception + live_promise.set_value(); + + } catch (...) + { + // Join must have thrown, set the exception in the promise (it will be retrieved later) + live_promise.set_exception(std::current_exception()); + } + } + } + + // Wait for the future to be returned. 
This will rethrow any exception thrown in do_service_await_join + m_live_future.get(); } void Service::service_stop() { - bool execute = false; + std::unique_lock lock(m_mutex); + + if (this->is_service_startable()) { - std::lock_guard lock(m_mutex); - if (m_state < ServiceState::Stopping) - { - execute = (m_state < ServiceState::Stopping); - m_state = ServiceState::Stopping; - } + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR(this->debug_prefix() << " Service must be started before " + "stopping")); } - if (execute) + + // Ensure we are at least in the stopping state. If so, execute the stop call + if (this->ensure_state(ServiceState::Stopping)) { - do_service_stop(); + lock.unlock(); + + this->do_service_stop(); } } void Service::service_kill() { - bool execute = false; + std::unique_lock lock(m_mutex); + + if (this->is_service_startable()) { - std::lock_guard lock(m_mutex); - if (m_state < ServiceState::Killing) - { - execute = (m_state < ServiceState::Killing); - m_state = ServiceState::Killing; - } + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR(this->debug_prefix() << " Service must be started before " + "killing")); } - if (execute) + + // Ensure we are at least in the stopping state. 
If so, execute the stop call + if (this->ensure_state(ServiceState::Killing)) { - do_service_kill(); + lock.unlock(); + + this->do_service_kill(); } } void Service::service_await_join() { - bool execute = false; { - std::lock_guard lock(m_mutex); - if (m_state < ServiceState::Completed) + std::unique_lock lock(m_mutex); + + if (this->is_service_startable()) { - execute = (m_state < ServiceState::Completed); - m_state = ServiceState::Awaiting; + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR(this->debug_prefix() << " Service must be started before " + "awaiting join")); } - } - if (execute) - { - do_service_await_join(); - forward_state(ServiceState::Completed); - } -} -const ServiceState& Service::state() const -{ - std::lock_guard lock(m_mutex); - return m_state; -} + // Check if this is our first call to service_await_join + if (!m_service_await_join_called) + { + // Prevent reentry + m_service_await_join_called = true; -bool Service::is_service_startable() const -{ - std::lock_guard lock(m_mutex); - return (m_state == ServiceState::Initialized); + // We now create a promise and a future to track the completion of the service + Promise completed_promise; + + m_completed_future = completed_promise.get_future(); + + // Unlock the mutex before calling await join to avoid a deadlock + lock.unlock(); + + try + { + Unwinder ensure_completed_set([this]() { + // Always set the state to completed before releasing the future + this->advance_state(ServiceState::Completed); + }); + + // Now call the await join (this can throw!) 
+ this->do_service_await_join(); + + // Set the value only if there was not an exception + completed_promise.set_value(); + + } catch (const std::exception& ex) + { + LOG(ERROR) << this->debug_prefix() << " caught exception in service_await_join: " << ex.what(); + // Join must have thrown, set the exception in the promise (it will be retrieved later) + completed_promise.set_exception(std::current_exception()); + } + } + } + + // Wait for the completed future to be returned. This will rethrow any exception thrown in do_service_await_join + m_completed_future.get(); } -bool Service::forward_state(ServiceState new_state) +std::string Service::debug_prefix() const { - std::lock_guard lock(m_mutex); - CHECK(m_state <= new_state) << m_description - << ": invalid ServiceState requested; ServiceState is only allowed to advance"; - if (m_state < new_state) - { - m_state = new_state; - return true; - } - return false; + return MRC_CONCAT_STR("Service[" << m_service_name << "]:"); } void Service::call_in_destructor() { + // Guarantee that we set the flag that this was called + Unwinder ensure_flag([this]() { + m_call_in_destructor_called = true; + }); + auto state = this->state(); if (state > ServiceState::Initialized) { if (state == ServiceState::Running) { - LOG(ERROR) << m_description << ": service was not stopped/killed before being destructed; issuing kill"; - service_kill(); + LOG(ERROR) << this->debug_prefix() + << ": service was not stopped/killed before being destructed; issuing kill"; + this->service_kill(); } if (state != ServiceState::Completed) { - LOG(ERROR) << m_description << ": service was not joined before being destructed; issuing join"; - service_await_join(); + LOG(ERROR) << this->debug_prefix() << ": service was not joined before being destructed; issuing join"; + this->service_await_join(); } } } void Service::service_set_description(std::string description) { - m_description = std::move(description); + m_service_name = std::move(description); +} + 
+bool Service::advance_state(ServiceState new_state, bool assert_state_change) +{ + std::lock_guard lock(m_mutex); + + // State needs to always be moving foward or the same + CHECK_GE(new_state, m_state) << this->debug_prefix() + << " invalid ServiceState requested; ServiceState is only allowed to advance. " + "Current: " + << m_state << ", Requested: " << new_state; + + if (m_state < new_state) + { + DVLOG(20) << this->debug_prefix() << " advancing state. From: " << m_state << " to " << new_state; + + m_state = new_state; + + return true; + } + + CHECK(!assert_state_change) << this->debug_prefix() + << " invalid ServiceState requested; ServiceState was required to move forward " + "but the state was already set to " + << m_state; + + return false; +} + +bool Service::ensure_state(ServiceState desired_state) +{ + std::lock_guard lock(m_mutex); + + if (desired_state > m_state) + { + return advance_state(desired_state); + } + + return false; } } // namespace mrc diff --git a/cpp/mrc/src/internal/service.hpp b/cpp/mrc/src/internal/service.hpp index f707321e2..47d5b7fab 100644 --- a/cpp/mrc/src/internal/service.hpp +++ b/cpp/mrc/src/internal/service.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,7 +17,10 @@ #pragma once -#include +#include "mrc/types.hpp" + +#include // for ostream +#include // for logic_error #include namespace mrc { @@ -25,44 +28,90 @@ namespace mrc { enum class ServiceState { Initialized, + Starting, Running, - Awaiting, Stopping, Killing, Completed, }; -// struct IService -// { -// virtual ~IService() = default; +/** + * @brief Converts a `ServiceState` enum to a string + * + * @param f + * @return std::string + */ +inline std::string servicestate_to_str(const ServiceState& s) +{ + switch (s) + { + case ServiceState::Initialized: + return "Initialized"; + case ServiceState::Starting: + return "Starting"; + case ServiceState::Running: + return "Running"; + case ServiceState::Stopping: + return "Stopping"; + case ServiceState::Killing: + return "Killing"; + case ServiceState::Completed: + return "Completed"; + default: + throw std::logic_error("Unsupported ServiceState enum. 
Was a new value added recently?"); + } +} -// virtual void service_start() = 0; -// virtual void service_await_live() = 0; -// virtual void service_stop() = 0; -// virtual void service_kill() = 0; -// virtual void service_await_join() = 0; -// }; +/** + * @brief Stream operator for `AsyncServiceState` + * + * @param os + * @param f + * @return std::ostream& + */ +static inline std::ostream& operator<<(std::ostream& os, const ServiceState& f) +{ + os << servicestate_to_str(f); + return os; +} -class Service // : public IService +class Service { public: virtual ~Service(); + const std::string& service_name() const; + + bool is_service_startable() const; + + bool is_running() const; + + const ServiceState& state() const; + void service_start(); void service_await_live(); void service_stop(); void service_kill(); void service_await_join(); - bool is_service_startable() const; - const ServiceState& state() const; - protected: + Service(std::string service_name); + + // Prefix to use for debug messages. Contains useful information about the service + std::string debug_prefix() const; + void call_in_destructor(); void service_set_description(std::string description); private: - bool forward_state(ServiceState new_state); + // Advances the state. New state value must be greater than or equal to current state. Using a value less than the + // current state will generate an error. Use assert_forward = false to require that the state advances. Normally, + // same states are fine + bool advance_state(ServiceState new_state, bool assert_state_change = false); + + // Ensures the state is at least the current value or higher. 
Does not change the state if the value is less than or + // equal the current state + bool ensure_state(ServiceState desired_state); virtual void do_service_start() = 0; virtual void do_service_await_live() = 0; @@ -71,8 +120,21 @@ class Service // : public IService virtual void do_service_await_join() = 0; ServiceState m_state{ServiceState::Initialized}; - std::string m_description{"mrc::service"}; - mutable std::mutex m_mutex; + std::string m_service_name{"mrc::Service"}; + + // This future is set in `service_await_live` and is used to wait for the service to to be live. We use a future + // here in case it is called multiple times, so that all callers will all be released when the service is live. + SharedFuture m_live_future; + + // This future is set in `service_await_join` and is used to wait for the service to complete. We use a future here + // in case join is called multiple times, so that all callers will all be released when the service completes. + SharedFuture m_completed_future; + + bool m_service_await_live_called{false}; + bool m_service_await_join_called{false}; + bool m_call_in_destructor_called{false}; + + mutable RecursiveMutex m_mutex; }; } // namespace mrc diff --git a/cpp/mrc/src/internal/system/device_info.cpp b/cpp/mrc/src/internal/system/device_info.cpp index b9f3461f2..2ead6abee 100644 --- a/cpp/mrc/src/internal/system/device_info.cpp +++ b/cpp/mrc/src/internal/system/device_info.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -150,7 +150,7 @@ struct NvmlState m_nvml_handle = std::make_unique(); } catch (std::runtime_error e) { - LOG(WARNING) << "NVML: " << e.what() << ". 
Setting DeviceCount to 0, CUDA will not be initialized"; + VLOG(1) << "NVML: " << e.what() << ". Setting DeviceCount to 0, CUDA will not be initialized"; return; } diff --git a/cpp/mrc/src/internal/system/fiber_manager.cpp b/cpp/mrc/src/internal/system/fiber_manager.cpp index 2eec52f12..5a73dcab7 100644 --- a/cpp/mrc/src/internal/system/fiber_manager.cpp +++ b/cpp/mrc/src/internal/system/fiber_manager.cpp @@ -26,9 +26,11 @@ #include "mrc/exceptions/runtime_error.hpp" #include "mrc/options/fiber_pool.hpp" #include "mrc/options/options.hpp" +#include "mrc/utils/string_utils.hpp" #include #include +#include namespace mrc::system { @@ -44,7 +46,7 @@ FiberManager::FiberManager(const ThreadingResources& resources) : m_cpu_set(reso topology.cpu_set().for_each_bit([&](std::int32_t idx, std::int32_t cpu_id) { DVLOG(10) << "initializing fiber queue " << idx << " of " << cpu_count << " on cpu_id " << cpu_id; - m_queues[cpu_id] = std::make_unique(resources, cpu_id); + m_queues[cpu_id] = std::make_unique(resources, cpu_id, MRC_CONCAT_STR("fibq[" << idx << "]")); }); } diff --git a/cpp/mrc/src/internal/system/fiber_task_queue.cpp b/cpp/mrc/src/internal/system/fiber_task_queue.cpp index 709be264e..5af806d21 100644 --- a/cpp/mrc/src/internal/system/fiber_task_queue.cpp +++ b/cpp/mrc/src/internal/system/fiber_task_queue.cpp @@ -28,7 +28,6 @@ #include #include #include -#include #include #include @@ -39,12 +38,16 @@ namespace mrc::system { -FiberTaskQueue::FiberTaskQueue(const ThreadingResources& resources, CpuSet cpu_affinity, std::size_t channel_size) : +FiberTaskQueue::FiberTaskQueue(const ThreadingResources& resources, + CpuSet cpu_affinity, + std::string thread_name, + std::size_t channel_size) : m_queue(channel_size), m_cpu_affinity(std::move(cpu_affinity)), - m_thread(resources.make_thread("fiberq", m_cpu_affinity, [this] { + m_thread(resources.make_thread(std::move(thread_name), m_cpu_affinity, [this] { main(); })) + { DVLOG(10) << "awaiting fiber task queue worker thread 
running on cpus " << m_cpu_affinity; enqueue([] {}).get(); @@ -106,7 +109,7 @@ void FiberTaskQueue::launch(task_pkg_t&& pkg) const boost::fibers::fiber fiber(std::move(pkg.first)); auto& props(fiber.properties()); props.set_priority(pkg.second.priority); - DVLOG(10) << *this << ": created fiber " << fiber.get_id() << " with priority " << pkg.second.priority; + DVLOG(20) << *this << ": created fiber " << fiber.get_id() << " with priority " << pkg.second.priority; fiber.detach(); } diff --git a/cpp/mrc/src/internal/system/fiber_task_queue.hpp b/cpp/mrc/src/internal/system/fiber_task_queue.hpp index c58c8190b..ccd7499b5 100644 --- a/cpp/mrc/src/internal/system/fiber_task_queue.hpp +++ b/cpp/mrc/src/internal/system/fiber_task_queue.hpp @@ -27,6 +27,7 @@ #include #include +#include #include namespace mrc::system { @@ -36,7 +37,10 @@ class ThreadingResources; class FiberTaskQueue final : public core::FiberTaskQueue { public: - FiberTaskQueue(const ThreadingResources& resources, CpuSet cpu_affinity, std::size_t channel_size = 64); + FiberTaskQueue(const ThreadingResources& resources, + CpuSet cpu_affinity, + std::string thread_name, + std::size_t channel_size = 64); ~FiberTaskQueue() final; DELETE_COPYABILITY(FiberTaskQueue); diff --git a/cpp/mrc/src/internal/system/host_partition_provider.cpp b/cpp/mrc/src/internal/system/host_partition_provider.cpp index 953833435..42a579547 100644 --- a/cpp/mrc/src/internal/system/host_partition_provider.cpp +++ b/cpp/mrc/src/internal/system/host_partition_provider.cpp @@ -17,6 +17,7 @@ #include "internal/system/host_partition_provider.hpp" +#include "internal/system/host_partition.hpp" #include "internal/system/partitions.hpp" #include "internal/system/system.hpp" @@ -25,7 +26,6 @@ #include namespace mrc::system { -class HostPartition; HostPartitionProvider::HostPartitionProvider(const SystemProvider& _system, std::size_t _host_partition_id) : SystemProvider(_system), diff --git a/cpp/mrc/src/internal/system/partition_provider.cpp 
b/cpp/mrc/src/internal/system/partition_provider.cpp index 33feb2c77..7597da9cc 100644 --- a/cpp/mrc/src/internal/system/partition_provider.cpp +++ b/cpp/mrc/src/internal/system/partition_provider.cpp @@ -17,6 +17,7 @@ #include "internal/system/partition_provider.hpp" +#include "internal/system/partition.hpp" #include "internal/system/partitions.hpp" #include "internal/system/system.hpp" diff --git a/cpp/mrc/src/internal/system/thread.cpp b/cpp/mrc/src/internal/system/thread.cpp index 413e86f6c..04345006f 100644 --- a/cpp/mrc/src/internal/system/thread.cpp +++ b/cpp/mrc/src/internal/system/thread.cpp @@ -90,13 +90,13 @@ void ThreadResources::initialize_thread(const std::string& desc, const CpuSet& c { std::stringstream ss; ss << "cpu_id: " << cpu_affinity.first(); - affinity = ss.str(); + affinity = MRC_CONCAT_STR("cpu[" << cpu_affinity.str() << "]"); } else { std::stringstream ss; ss << "cpus: " << cpu_affinity.str(); - affinity = ss.str(); + affinity = MRC_CONCAT_STR("cpu[" << cpu_affinity.str() << "]"); auto numa_set = topology.numaset_for_cpuset(cpu_affinity); if (numa_set.weight() != 1) { @@ -110,13 +110,13 @@ void ThreadResources::initialize_thread(const std::string& desc, const CpuSet& c DVLOG(10) << "tid: " << std::this_thread::get_id() << "; setting cpu affinity to " << affinity; auto rc = hwloc_set_cpubind(topology.handle(), &cpu_affinity.bitmap(), HWLOC_CPUBIND_THREAD); CHECK_NE(rc, -1); - set_current_thread_name(MRC_CONCAT_STR("[" << desc << "; " << affinity << "]")); + set_current_thread_name(MRC_CONCAT_STR(desc << ";" << affinity)); } else { DVLOG(10) << "thread_binding is disabled; tid: " << std::this_thread::get_id() << " will use the affinity of caller"; - set_current_thread_name(MRC_CONCAT_STR("[" << desc << "; tid:" << std::this_thread::get_id() << "]")); + set_current_thread_name(MRC_CONCAT_STR(desc << ";tid[" << std::this_thread::get_id() << "]")); } // todo(ryan) - enable thread/memory binding should be a system option, not specifically a 
fiber_pool option diff --git a/cpp/mrc/src/internal/system/threading_resources.cpp b/cpp/mrc/src/internal/system/threading_resources.cpp index 27001092a..1e0f8c16b 100644 --- a/cpp/mrc/src/internal/system/threading_resources.cpp +++ b/cpp/mrc/src/internal/system/threading_resources.cpp @@ -19,9 +19,10 @@ #include "internal/system/fiber_manager.hpp" +#include "mrc/types.hpp" + #include -#include #include namespace mrc::system { diff --git a/cpp/mrc/src/internal/ucx/receive_manager.cpp b/cpp/mrc/src/internal/ucx/receive_manager.cpp index 2796bf84e..70cda928a 100644 --- a/cpp/mrc/src/internal/ucx/receive_manager.cpp +++ b/cpp/mrc/src/internal/ucx/receive_manager.cpp @@ -23,7 +23,6 @@ #include "mrc/types.hpp" #include -#include #include #include // for launch, launch::post #include // for ucp_tag_probe_nb, ucp_tag_recv_info diff --git a/cpp/mrc/src/internal/ucx/ucx_resources.cpp b/cpp/mrc/src/internal/ucx/ucx_resources.cpp index 458dd9814..1ce368662 100644 --- a/cpp/mrc/src/internal/ucx/ucx_resources.cpp +++ b/cpp/mrc/src/internal/ucx/ucx_resources.cpp @@ -30,7 +30,6 @@ #include "mrc/cuda/common.hpp" #include "mrc/types.hpp" -#include #include #include diff --git a/cpp/mrc/src/internal/utils/contains.hpp b/cpp/mrc/src/internal/utils/contains.hpp index 61d613692..690b8e9b2 100644 --- a/cpp/mrc/src/internal/utils/contains.hpp +++ b/cpp/mrc/src/internal/utils/contains.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,13 +29,15 @@ bool contains(const ContainerT& container, const KeyT& key) } template -class KeyIterator : public std::iterator +class KeyIterator { public: + using iterator_category_t = std::bidirectional_iterator_tag; + using value_type = C::key_type; + using difference_type = C::difference_type; + using pointer_t = C::pointer; + using reference_t = C::reference; + KeyIterator() = default; explicit KeyIterator(typename C::const_iterator it) : m_iter(it) {} diff --git a/cpp/mrc/src/internal/utils/parse_config.cpp b/cpp/mrc/src/internal/utils/parse_config.cpp index 7d49ce615..780739eb3 100644 --- a/cpp/mrc/src/internal/utils/parse_config.cpp +++ b/cpp/mrc/src/internal/utils/parse_config.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,29 +19,15 @@ #include "./parse_ints.hpp" +#include "mrc/utils/string_utils.hpp" // for split_string_to_vector + #include #include // for uint32_t #include // for atoi -#include #include #include // for move -namespace { - -std::vector split_string_on(std::string str, char delim) -{ - std::vector tokens; - std::istringstream f(str); - std::string s; - while (std::getline(f, s, delim)) - { - tokens.push_back(s); - } - return tokens; -} -} // namespace - namespace mrc { ConfigurationMap parse_config(std::string config_str) @@ -50,9 +36,9 @@ ConfigurationMap parse_config(std::string config_str) bool left_wildcard = false; - for (const auto& entry : split_string_on(config_str, ';')) + for (const auto& entry : split_string_to_vector(config_str, ";")) { - auto tokens = split_string_on(entry, ':'); + auto tokens = split_string_to_vector(entry, ":"); int concurrency = 1; std::vector s; @@ -76,7 +62,7 @@ ConfigurationMap parse_config(std::string config_str) concurrency = std::atoi(tokens[1].c_str()); case 1: // parse segments - s = split_string_on(tokens[0], ','); + s = split_string_to_vector(tokens[0], ","); segments.insert(s.begin(), s.end()); break; @@ -86,7 +72,7 @@ ConfigurationMap parse_config(std::string config_str) "::;[repeated]"); } - config.push_back(std::make_tuple(std::move(segments), concurrency, std::move(groups))); + config.emplace_back(std::move(segments), concurrency, std::move(groups)); } return config; diff --git a/cpp/mrc/src/internal/utils/parse_ints.cpp b/cpp/mrc/src/internal/utils/parse_ints.cpp index 60c716982..c999339e2 100644 --- a/cpp/mrc/src/internal/utils/parse_ints.cpp +++ b/cpp/mrc/src/internal/utils/parse_ints.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,6 +17,8 @@ #include "./parse_ints.hpp" +#include "mrc/utils/string_utils.hpp" // for split_string_to_vector + #include #include @@ -31,17 +33,6 @@ int convert_string2_int(const std::string& str) return x; } -std::vector split_string_to_array(const std::string& str, char splitter) -{ - std::vector tokens; - std::stringstream ss(str); - std::string temp; - while (getline(ss, temp, splitter)) // split into new "lines" based on character - { - tokens.push_back(temp); - } - return tokens; -} } // namespace namespace mrc { @@ -49,10 +40,10 @@ namespace mrc { std::vector parse_ints(const std::string& data) { std::vector result; - std::vector tokens = split_string_to_array(data, ','); + std::vector tokens = split_string_to_vector(data, ","); for (auto& token : tokens) { - std::vector range = split_string_to_array(token, '-'); + std::vector range = split_string_to_vector(token, "-"); if (range.size() == 1) { result.push_back(convert_string2_int(range[0])); diff --git a/cpp/mrc/src/public/core/thread.cpp b/cpp/mrc/src/public/core/thread.cpp index 0553a8fe0..f81ecb38d 100644 --- a/cpp/mrc/src/public/core/thread.cpp +++ b/cpp/mrc/src/public/core/thread.cpp @@ -20,7 +20,6 @@ #include "mrc/coroutines/thread_pool.hpp" #include -#include #include #include #include diff --git a/cpp/mrc/src/public/coroutines/io_scheduler.cpp b/cpp/mrc/src/public/coroutines/io_scheduler.cpp new file mode 100644 index 000000000..a52f7a756 --- /dev/null +++ b/cpp/mrc/src/public/coroutines/io_scheduler.cpp @@ -0,0 +1,585 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Original Source: https://github.com/jbaldwin/libcoro + * Original License: Apache License, Version 2.0; included below + */ + +/** + * Copyright 2021 Josh Baldwin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "mrc/coroutines/io_scheduler.hpp" + +#include "mrc/coroutines/poll.hpp" +#include "mrc/coroutines/task.hpp" +#include "mrc/coroutines/task_container.hpp" +#include "mrc/coroutines/time.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std::chrono_literals; + +namespace mrc::coroutines { + +std::shared_ptr IoScheduler::get_instance() +{ + static std::shared_ptr instance; + static std::mutex instance_mutex{}; + + if (instance == nullptr) + { + auto lock = std::lock_guard(instance_mutex); + + if (instance == nullptr) + { + instance = std::make_shared(); + } + } + + return instance; +} + +IoScheduler::IoScheduler(Options opts) : + m_opts(std::move(opts)), + m_epoll_fd(epoll_create1(EPOLL_CLOEXEC)), + m_shutdown_fd(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)), + m_timer_fd(timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC)), + m_schedule_fd(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)), + m_owned_tasks(new mrc::coroutines::TaskContainer(std::shared_ptr(this, [](auto _) {}))) +{ + if (opts.execution_strategy == ExecutionStrategy::process_tasks_on_thread_pool) + { + m_thread_pool = std::make_unique(std::move(m_opts.pool)); + } + + epoll_event e{}; + e.events = EPOLLIN; + + e.data.ptr = const_cast(MShutdownPtr); + epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, m_shutdown_fd, &e); + + e.data.ptr = const_cast(MTimerPtr); + epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, m_timer_fd, &e); + + e.data.ptr = const_cast(MSchedulePtr); + epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, m_schedule_fd, &e); + + if (m_opts.thread_strategy == ThreadStrategy::spawn) + { + m_io_thread = std::thread([this]() { + process_events_dedicated_thread(); + }); + } + // else manual mode, the user must call process_events. 
+} + +IoScheduler::~IoScheduler() +{ + shutdown(); + + if (m_io_thread.joinable()) + { + m_io_thread.join(); + } + + if (m_epoll_fd != -1) + { + close(m_epoll_fd); + m_epoll_fd = -1; + } + if (m_timer_fd != -1) + { + close(m_timer_fd); + m_timer_fd = -1; + } + if (m_schedule_fd != -1) + { + close(m_schedule_fd); + m_schedule_fd = -1; + } + + if (m_owned_tasks != nullptr) + { + delete static_cast(m_owned_tasks); + m_owned_tasks = nullptr; + } +} + +auto IoScheduler::process_events(std::chrono::milliseconds timeout) -> std::size_t +{ + process_events_manual(timeout); + return size(); +} + +auto IoScheduler::schedule(mrc::coroutines::Task&& task) -> void +{ + auto* ptr = static_cast(m_owned_tasks); + ptr->start(std::move(task)); +} + +auto IoScheduler::schedule_after(std::chrono::milliseconds amount) -> mrc::coroutines::Task +{ + return yield_for(amount); +} + +auto IoScheduler::schedule_at(time_point_t time) -> mrc::coroutines::Task +{ + return yield_until(time); +} + +auto IoScheduler::yield_for(std::chrono::milliseconds amount) -> mrc::coroutines::Task +{ + if (amount <= 0ms) + { + co_await schedule(); + } + else + { + // Yield/timeout tasks are considered live in the scheduler and must be accounted for. Note + // that if the user gives an invalid amount and schedule() is directly called it will account + // for the scheduled task there. + m_size.fetch_add(1, std::memory_order::release); + + // Yielding does not requiring setting the timer position on the poll info since + // it doesn't have a corresponding 'event' that can trigger, it always waits for + // the timeout to occur before resuming. + + detail::PollInfo pi{}; + add_timer_token(clock_t::now() + amount, pi); + co_await pi; + + m_size.fetch_sub(1, std::memory_order::release); + } + co_return; +} + +auto IoScheduler::yield_until(time_point_t time) -> mrc::coroutines::Task +{ + auto now = clock_t::now(); + + // If the requested time is in the past (or now!) bail out! 
+ if (time <= now) + { + co_await schedule(); + } + else + { + m_size.fetch_add(1, std::memory_order::release); + + auto amount = std::chrono::duration_cast(time - now); + + detail::PollInfo pi{}; + add_timer_token(now + amount, pi); + co_await pi; + + m_size.fetch_sub(1, std::memory_order::release); + } + co_return; +} + +auto IoScheduler::poll(fd_t fd, mrc::coroutines::PollOperation op, std::chrono::milliseconds timeout) + -> mrc::coroutines::Task +{ + // Because the size will drop when this coroutine suspends every poll needs to undo the subtraction + // on the number of active tasks in the scheduler. When this task is resumed by the event loop. + m_size.fetch_add(1, std::memory_order::release); + + // Setup two events, a timeout event and the actual poll for op event. + // Whichever triggers first will delete the other to guarantee only one wins. + // The resume token will be set by the scheduler to what the event turned out to be. + + bool timeout_requested = (timeout > 0ms); + + detail::PollInfo pi{}; + pi.m_fd = fd; + + if (timeout_requested) + { + pi.m_timer_pos = add_timer_token(clock_t::now() + timeout, pi); + } + + epoll_event e{}; + e.events = static_cast(op) | EPOLLONESHOT | EPOLLRDHUP; + e.data.ptr = π + if (epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, fd, &e) == -1) + { + std::cerr << "epoll ctl error on fd " << fd << "\n"; + } + + // The event loop will 'clean-up' whichever event didn't win since the coroutine is scheduled + // onto the thread poll its possible the other type of event could trigger while its waiting + // to execute again, thus restarting the coroutine twice, that would be quite bad. + auto result = co_await pi; + m_size.fetch_sub(1, std::memory_order::release); + co_return result; +} + +auto IoScheduler::shutdown() noexcept -> void +{ + // Only allow shutdown to occur once. 
+ if (not m_shutdown_requested.exchange(true, std::memory_order::acq_rel)) + { + if (m_thread_pool != nullptr) + { + m_thread_pool->shutdown(); + } + + // Signal the event loop to stop asap, triggering the event fd is safe. + uint64_t value{1}; + auto written = ::write(m_shutdown_fd, &value, sizeof(value)); + (void)written; + + if (m_io_thread.joinable()) + { + m_io_thread.join(); + } + } +} + +auto IoScheduler::garbage_collect() noexcept -> void +{ + auto* ptr = static_cast(m_owned_tasks); + ptr->garbage_collect(); +} + +auto IoScheduler::process_events_manual(std::chrono::milliseconds timeout) -> void +{ + bool expected{false}; + if (m_io_processing.compare_exchange_strong(expected, true, std::memory_order::release, std::memory_order::relaxed)) + { + process_events_execute(timeout); + m_io_processing.exchange(false, std::memory_order::release); + } +} + +auto IoScheduler::process_events_dedicated_thread() -> void +{ + if (m_opts.on_io_thread_start_functor != nullptr) + { + m_opts.on_io_thread_start_functor(); + } + + m_io_processing.exchange(true, std::memory_order::release); + // Execute tasks until stopped or there are no more tasks to complete. + while (!m_shutdown_requested.load(std::memory_order::acquire) || size() > 0) + { + process_events_execute(MDefaultTimeout); + } + m_io_processing.exchange(false, std::memory_order::release); + + if (m_opts.on_io_thread_stop_functor != nullptr) + { + m_opts.on_io_thread_stop_functor(); + } +} + +auto IoScheduler::process_events_execute(std::chrono::milliseconds timeout) -> void +{ + auto event_count = epoll_wait(m_epoll_fd, m_events.data(), MMaxEvents, timeout.count()); + if (event_count > 0) + { + for (std::size_t i = 0; i < static_cast(event_count); ++i) + { + epoll_event& event = m_events[i]; + void* handle_ptr = event.data.ptr; + + if (handle_ptr == MTimerPtr) + { + // Process all events that have timed out. 
+ process_timeout_execute(); + } + else if (handle_ptr == MSchedulePtr) + { + // Process scheduled coroutines. + process_scheduled_execute_inline(); + } + else if (handle_ptr == MShutdownPtr) [[unlikely]] + { + // Nothing to do , just needed to wake-up and smell the flowers + } + else + { + // Individual poll task wake-up. + process_event_execute(static_cast(handle_ptr), event_to_poll_status(event.events)); + } + } + } + + // Its important to not resume any handles until the full set is accounted for. If a timeout + // and an event for the same handle happen in the same epoll_wait() call then inline processing + // will destruct the poll_info object before the second event is handled. This is also possible + // with thread pool processing, but probably has an extremely low chance of occuring due to + // the thread switch required. If m_max_events == 1 this would be unnecessary. + + if (!m_handles_to_resume.empty()) + { + if (m_opts.execution_strategy == ExecutionStrategy::process_tasks_inline) + { + for (auto& handle : m_handles_to_resume) + { + handle.resume(); + } + } + else + { + m_thread_pool->resume(m_handles_to_resume); + } + + m_handles_to_resume.clear(); + } +} + +auto IoScheduler::event_to_poll_status(uint32_t events) -> PollStatus +{ + if (((events & EPOLLIN) != 0) || ((events & EPOLLOUT) != 0)) + { + return PollStatus::event; + } + + if ((events & EPOLLERR) != 0) + { + return PollStatus::error; + } + + if (((events & EPOLLRDHUP) != 0) || ((events & EPOLLHUP) != 0)) + { + return PollStatus::closed; + } + + throw std::runtime_error{"invalid epoll state"}; +} + +auto IoScheduler::process_scheduled_execute_inline() -> void +{ + std::vector> tasks{}; + { + // Acquire the entire list, and then reset it. + std::scoped_lock lk{m_scheduled_tasks_mutex}; + tasks.swap(m_scheduled_tasks); + + // Clear the schedule eventfd if this is a scheduled task. 
+ eventfd_t value{0}; + eventfd_read(m_schedule_fd, &value); + + // Clear the in memory flag to reduce eventfd_* calls on scheduling. + m_schedule_fd_triggered.exchange(false, std::memory_order::release); + } + + // This set of handles can be safely resumed now since they do not have a corresponding timeout event. + for (auto& task : tasks) + { + task.resume(); + } + m_size.fetch_sub(tasks.size(), std::memory_order::release); +} + +auto IoScheduler::process_event_execute(detail::PollInfo* pi, PollStatus status) -> void +{ + if (!pi->m_processed) + { + std::atomic_thread_fence(std::memory_order::acquire); + // Its possible the event and the timeout occurred in the same epoll, make sure only one + // is ever processed, the other is discarded. + pi->m_processed = true; + + // Given a valid fd always remove it from epoll so the next poll can blindly EPOLL_CTL_ADD. + if (pi->m_fd != -1) + { + epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, pi->m_fd, nullptr); + } + + // Since this event triggered, remove its corresponding timeout if it has one. + if (pi->m_timer_pos.has_value()) + { + remove_timer_token(pi->m_timer_pos.value()); + } + + pi->m_poll_status = status; + + while (pi->m_awaiting_coroutine == nullptr) + { + std::atomic_thread_fence(std::memory_order::acquire); + } + + m_handles_to_resume.emplace_back(pi->m_awaiting_coroutine); + } +} + +auto IoScheduler::process_timeout_execute() -> void +{ + std::vector poll_infos{}; + auto now = clock_t::now(); + + { + std::scoped_lock lk{m_timed_events_mutex}; + while (!m_timed_events.empty()) + { + auto first = m_timed_events.begin(); + auto [tp, pi] = *first; + + if (tp <= now) + { + m_timed_events.erase(first); + poll_infos.emplace_back(pi); + } + else + { + break; + } + } + } + + for (auto* pi : poll_infos) + { + if (!pi->m_processed) + { + // Its possible the event and the timeout occurred in the same epoll, make sure only one + // is ever processed, the other is discarded. 
+ pi->m_processed = true; + + // Since this timed out, remove its corresponding event if it has one. + if (pi->m_fd != -1) + { + epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, pi->m_fd, nullptr); + } + + while (pi->m_awaiting_coroutine == nullptr) + { + std::atomic_thread_fence(std::memory_order::acquire); + // std::cerr << "process_event_execute() has a nullptr event\n"; + } + + m_handles_to_resume.emplace_back(pi->m_awaiting_coroutine); + pi->m_poll_status = mrc::coroutines::PollStatus::timeout; + } + } + + // Update the time to the next smallest time point, re-take the current now time + // since updating and resuming tasks could shift the time. + update_timeout(clock_t::now()); +} + +auto IoScheduler::add_timer_token(time_point_t tp, detail::PollInfo& pi) -> timed_events_t::iterator +{ + std::scoped_lock lk{m_timed_events_mutex}; + auto pos = m_timed_events.emplace(tp, &pi); + + // If this item was inserted as the smallest time point, update the timeout. + if (pos == m_timed_events.begin()) + { + update_timeout(clock_t::now()); + } + + return pos; +} + +auto IoScheduler::remove_timer_token(timed_events_t::iterator pos) -> void +{ + { + std::scoped_lock lk{m_timed_events_mutex}; + auto is_first = (m_timed_events.begin() == pos); + + m_timed_events.erase(pos); + + // If this was the first item, update the timeout. It would be acceptable to just let it + // also fire the timeout as the event loop will ignore it since nothing will have timed + // out but it feels like the right thing to do to update it to the correct timeout value. 
+ if (is_first) + { + update_timeout(clock_t::now()); + } + } +} + +auto IoScheduler::update_timeout(time_point_t now) -> void +{ + if (!m_timed_events.empty()) + { + auto& [tp, pi] = *m_timed_events.begin(); + + auto amount = tp - now; + + auto seconds = std::chrono::duration_cast(amount); + amount -= seconds; + auto nanoseconds = std::chrono::duration_cast(amount); + + // As a safeguard if both values end up as zero (or negative) then trigger the timeout + // immediately as zero disarms timerfd according to the man pages and negative values + // will result in an error return value. + if (seconds <= 0s) + { + seconds = 0s; + if (nanoseconds <= 0ns) + { + // just trigger "immediately"! + nanoseconds = 1ns; + } + } + + itimerspec ts{}; + ts.it_value.tv_sec = seconds.count(); + ts.it_value.tv_nsec = nanoseconds.count(); + + if (timerfd_settime(m_timer_fd, 0, &ts, nullptr) == -1) + { + std::cerr << "Failed to set timerfd errorno=[" << std::string{strerror(errno)} << "]."; + } + } + else + { + // Setting these values to zero disables the timer. + itimerspec ts{}; + ts.it_value.tv_sec = 0; + ts.it_value.tv_nsec = 0; + if (timerfd_settime(m_timer_fd, 0, &ts, nullptr) == -1) + { + std::cerr << "Failed to set timerfd errorno=[" << std::string{strerror(errno)} << "]."; + } + } +} + +} // namespace mrc::coroutines diff --git a/cpp/mrc/src/public/coroutines/task_container.cpp b/cpp/mrc/src/public/coroutines/task_container.cpp new file mode 100644 index 000000000..85a765517 --- /dev/null +++ b/cpp/mrc/src/public/coroutines/task_container.cpp @@ -0,0 +1,186 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mrc/coroutines/task_container.hpp" + +#include "mrc/coroutines/scheduler.hpp" + +#include + +#include +#include +#include +#include +#include +#include + +namespace mrc::coroutines { + +TaskContainer::TaskContainer(std::shared_ptr e, std::size_t max_concurrent_tasks) : + m_scheduler_lifetime(std::move(e)), + m_scheduler(m_scheduler_lifetime.get()), + m_max_concurrent_tasks(max_concurrent_tasks) +{ + if (m_scheduler_lifetime == nullptr) + { + throw std::runtime_error{"TaskContainer cannot have a nullptr executor"}; + } +} + +TaskContainer::~TaskContainer() +{ + // This will hang the current thread.. but if tasks are not complete thats also pretty bad. + while (not empty()) + { + garbage_collect(); + } +} + +auto TaskContainer::start(Task&& user_task, GarbageCollectPolicy cleanup) -> void +{ + auto lock = std::unique_lock(m_mutex); + + m_size += 1; + + if (cleanup == GarbageCollectPolicy::yes) + { + gc_internal(); + } + + // Store the task inside a cleanup task for self deletion. 
+ auto pos = m_tasks.emplace(m_tasks.end(), std::nullopt); + auto task = make_cleanup_task(std::move(user_task), pos); + *pos = std::move(task); + m_next_tasks.push(pos); + + auto current_task_count = m_size - m_next_tasks.size(); + + if (m_max_concurrent_tasks == 0 or current_task_count < m_max_concurrent_tasks) + { + try_start_next_task(std::move(lock)); + } +} + +auto TaskContainer::garbage_collect() -> std::size_t +{ + auto lock = std::scoped_lock(m_mutex); + return gc_internal(); +} + +auto TaskContainer::size() -> std::size_t +{ + auto lock = std::scoped_lock(m_mutex); + return m_size; +} + +auto TaskContainer::empty() -> bool +{ + return size() == 0; +} + +auto TaskContainer::capacity() -> std::size_t +{ + auto lock = std::scoped_lock(m_mutex); + return m_tasks.size(); +} + +auto TaskContainer::garbage_collect_and_yield_until_empty() -> Task +{ + while (not empty()) + { + garbage_collect(); + co_await m_scheduler->yield(); + } +} + +TaskContainer::TaskContainer(Scheduler& e) : m_scheduler(&e) {} +auto TaskContainer::gc_internal() -> std::size_t +{ + if (m_tasks_to_delete.empty()) + { + return 0; + } + + std::size_t delete_count = m_tasks_to_delete.size(); + + for (const auto& pos : m_tasks_to_delete) + { + // Destroy the cleanup task and the user task. + if (pos->has_value()) + { + pos->value().destroy(); + } + + m_tasks.erase(pos); + } + + m_tasks_to_delete.clear(); + + return delete_count; +} + +void TaskContainer::try_start_next_task(std::unique_lock lock) +{ + if (m_next_tasks.empty()) + { + // no tasks to process + return; + } + + auto pos = m_next_tasks.front(); + m_next_tasks.pop(); + + // release the lock before starting the task + lock.unlock(); + + pos->value().resume(); +} + +auto TaskContainer::make_cleanup_task(Task user_task, task_position_t pos) -> Task +{ + // Immediately move the task onto the executor. + co_await m_scheduler->yield(); + + try + { + // Await the users task to complete. 
+ co_await user_task; + } catch (const std::exception& e) + { + // TODO(MDD): what would be a good way to report this to the user...? Catching here is required + // since the co_await will unwrap the unhandled exception on the task. + // The user's task should ideally be wrapped in a catch all and handle it themselves, but + // that cannot be guaranteed. + LOG(ERROR) << "coro::task_container user_task had an unhandled exception e.what()= " << e.what() << "\n"; + } catch (...) + { + // don't crash if they throw something that isn't derived from std::exception + LOG(ERROR) << "coro::task_container user_task had unhandle exception, not derived from std::exception.\n"; + } + + auto lock = std::unique_lock(m_mutex); + m_tasks_to_delete.push_back(pos); + // This has to be done within scope lock to make sure this coroutine task completes before the + // task container object destructs -- if it was waiting on .empty() to become true. + m_size -= 1; + + try_start_next_task(std::move(lock)); + + co_return; +} + +} // namespace mrc::coroutines diff --git a/cpp/mrc/src/public/coroutines/test_scheduler.cpp b/cpp/mrc/src/public/coroutines/test_scheduler.cpp new file mode 100644 index 000000000..fba53c250 --- /dev/null +++ b/cpp/mrc/src/public/coroutines/test_scheduler.cpp @@ -0,0 +1,115 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mrc/coroutines/test_scheduler.hpp" + +#include +#include + +namespace mrc::coroutines { + +TestScheduler::Operation::Operation(TestScheduler* self, std::chrono::time_point time) : + m_self(self), + m_time(time) +{} + +bool TestScheduler::ItemCompare::operator()(item_t& lhs, item_t& rhs) +{ + return lhs.second > rhs.second; +} + +void TestScheduler::Operation::await_suspend(std::coroutine_handle<> handle) +{ + m_self->m_queue.emplace(std::move(handle), m_time); +} + +void TestScheduler::resume(std::coroutine_handle<> handle) noexcept +{ + m_queue.emplace(std::move(handle), std::chrono::steady_clock::now()); +} + +mrc::coroutines::Task<> TestScheduler::yield() +{ + co_return co_await TestScheduler::Operation{this, m_time}; +} + +mrc::coroutines::Task<> TestScheduler::yield_for(std::chrono::milliseconds time) +{ + co_return co_await TestScheduler::Operation{this, m_time + time}; +} + +mrc::coroutines::Task<> TestScheduler::yield_until(std::chrono::time_point time) +{ + co_return co_await TestScheduler::Operation{this, time}; +} + +std::chrono::time_point TestScheduler::time() +{ + return m_time; +} + +bool TestScheduler::resume_next() +{ + using namespace std::chrono_literals; + + if (m_queue.empty()) + { + return false; + } + + auto handle = m_queue.top(); + + m_queue.pop(); + + m_time = handle.second; + + if (not m_queue.empty()) + { + m_time = m_queue.top().second; + } + + handle.first.resume(); + + return true; +} + +bool TestScheduler::resume_for(std::chrono::milliseconds time) +{ + return resume_until(m_time + time); +} + +bool TestScheduler::resume_until(std::chrono::time_point time) +{ + m_time = time; + + while (not m_queue.empty()) + { + if (m_queue.top().second <= m_time) + { + m_queue.top().first.resume(); + m_queue.pop(); + } + else + { + return true; + } + } + + return false; +} + +} // namespace mrc::coroutines diff --git 
a/cpp/mrc/src/public/coroutines/thread_pool.cpp b/cpp/mrc/src/public/coroutines/thread_pool.cpp index e2724409e..805a64d2a 100644 --- a/cpp/mrc/src/public/coroutines/thread_pool.cpp +++ b/cpp/mrc/src/public/coroutines/thread_pool.cpp @@ -39,7 +39,6 @@ #include "mrc/coroutines/thread_pool.hpp" #include -#include #include #include diff --git a/cpp/mrc/src/public/exceptions/exception_catcher.cpp b/cpp/mrc/src/public/exceptions/exception_catcher.cpp new file mode 100644 index 000000000..c139436f7 --- /dev/null +++ b/cpp/mrc/src/public/exceptions/exception_catcher.cpp @@ -0,0 +1,50 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace mrc { + +void ExceptionCatcher::push_exception(std::exception_ptr ex) +{ + auto lock = std::lock_guard(m_mutex); + m_exceptions.push(ex); +} + +bool ExceptionCatcher::has_exception() +{ + auto lock = std::lock_guard(m_mutex); + return not m_exceptions.empty(); +} + +void ExceptionCatcher::rethrow_next_exception() +{ + auto lock = std::lock_guard(m_mutex); + + if (m_exceptions.empty()) + { + return; + } + + auto ex = m_exceptions.front(); + + m_exceptions.pop(); + + std::rethrow_exception(ex); +} + +} // namespace mrc diff --git a/cpp/mrc/src/public/modules/sample_modules.cpp b/cpp/mrc/src/public/modules/sample_modules.cpp index fe850615c..405dcfe3c 100644 --- a/cpp/mrc/src/public/modules/sample_modules.cpp +++ b/cpp/mrc/src/public/modules/sample_modules.cpp @@ -26,10 +26,8 @@ #include -#include #include #include -#include namespace mrc::modules { diff --git a/cpp/mrc/src/public/utils/string_utils.cpp b/cpp/mrc/src/public/utils/string_utils.cpp new file mode 100644 index 000000000..5ed572f4c --- /dev/null +++ b/cpp/mrc/src/public/utils/string_utils.cpp @@ -0,0 +1,36 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "mrc/utils/string_utils.hpp" + +#include // for split +// We already have included we don't need these others, it is also the only public header +// with a definition for boost::is_any_of, so even if we replaced string.hpp with these others we would still need to +// include string.hpp or a detail/ header +// IWYU pragma: no_include +// IWYU pragma: no_include +// IWYU pragma: no_include + +namespace mrc { +std::vector split_string_to_vector(const std::string& str, const std::string& delimiter) +{ + std::vector results; + boost::split(results, str, boost::is_any_of(delimiter)); + return results; +} + +} // namespace mrc diff --git a/cpp/mrc/src/tests/CMakeLists.txt b/cpp/mrc/src/tests/CMakeLists.txt index 9a746e718..8ef8676fe 100644 --- a/cpp/mrc/src/tests/CMakeLists.txt +++ b/cpp/mrc/src/tests/CMakeLists.txt @@ -61,6 +61,7 @@ add_executable(test_mrc_private test_resources.cpp test_reusable_pool.cpp test_runnable.cpp + test_service.cpp test_system.cpp test_topology.cpp test_ucx.cpp diff --git a/cpp/mrc/src/tests/nodes/common_nodes.cpp b/cpp/mrc/src/tests/nodes/common_nodes.cpp index f7432f670..1c7acd824 100644 --- a/cpp/mrc/src/tests/nodes/common_nodes.cpp +++ b/cpp/mrc/src/tests/nodes/common_nodes.cpp @@ -28,13 +28,11 @@ #include #include -#include #include #include #include #include #include -#include using namespace mrc; using namespace mrc::memory::literals; diff --git a/cpp/mrc/src/tests/nodes/common_nodes.hpp b/cpp/mrc/src/tests/nodes/common_nodes.hpp index aa1ff13d2..bb19235e3 100644 --- a/cpp/mrc/src/tests/nodes/common_nodes.hpp +++ b/cpp/mrc/src/tests/nodes/common_nodes.hpp @@ -30,7 +30,6 @@ #include #include #include -#include namespace test::nodes { diff --git a/cpp/mrc/src/tests/pipelines/multi_segment.cpp b/cpp/mrc/src/tests/pipelines/multi_segment.cpp index 5157ef5d6..05dabd28c 100644 --- a/cpp/mrc/src/tests/pipelines/multi_segment.cpp +++ b/cpp/mrc/src/tests/pipelines/multi_segment.cpp @@ -18,7 +18,9 @@ #include 
"common_pipelines.hpp" #include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_sink_base.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" #include "mrc/pipeline/pipeline.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/egress_ports.hpp" @@ -29,11 +31,8 @@ #include #include -#include #include #include -#include -#include using namespace mrc; diff --git a/cpp/mrc/src/tests/segments/common_segments.cpp b/cpp/mrc/src/tests/segments/common_segments.cpp index 9e0f6b61d..eb1d0126d 100644 --- a/cpp/mrc/src/tests/segments/common_segments.cpp +++ b/cpp/mrc/src/tests/segments/common_segments.cpp @@ -28,7 +28,6 @@ #include #include -#include using namespace mrc; diff --git a/cpp/mrc/src/tests/test_control_plane.cpp b/cpp/mrc/src/tests/test_control_plane.cpp index 96d85945c..b49e5ae0d 100644 --- a/cpp/mrc/src/tests/test_control_plane.cpp +++ b/cpp/mrc/src/tests/test_control_plane.cpp @@ -27,6 +27,7 @@ #include "internal/runnable/runnable_resources.hpp" #include "internal/runtime/partition.hpp" #include "internal/runtime/runtime.hpp" +#include "internal/system/partition.hpp" #include "internal/system/partitions.hpp" #include "internal/system/system.hpp" #include "internal/system/system_provider.hpp" @@ -43,7 +44,6 @@ #include "mrc/pubsub/subscriber.hpp" #include "mrc/types.hpp" -#include #include #include #include @@ -66,7 +66,7 @@ static auto make_runtime(std::function options_lambda = { auto resources = std::make_unique( system::SystemProvider(tests::make_system([&](Options& options) { - options.topology().user_cpuset("0-3"); + options.topology().user_cpuset("0"); options.topology().restrict_gpus(true); options.placement().resources_strategy(PlacementResources::Dedicated); options.placement().cpu_strategy(PlacementStrategy::PerMachine); @@ -85,7 +85,10 @@ class TestControlPlane : public ::testing::Test TEST_F(TestControlPlane, LifeCycle) { - auto sr = make_runtime(); + auto sr = make_runtime([](Options& options) { + 
options.enable_server(true); + options.architect_url("localhost:13337"); + }); auto server = std::make_unique(sr->partition(0).resources().runnable()); server->service_start(); @@ -121,6 +124,35 @@ TEST_F(TestControlPlane, SingleClientConnectDisconnect) server->service_await_join(); } +TEST_F(TestControlPlane, SingleClientConnectDisconnectSingleCore) +{ + // Similar to SingleClientConnectDisconnect except both client & server are locked to the same core + // making issue #379 easier to reproduce. + auto sr = make_runtime([](Options& options) { + options.topology().user_cpuset("0"); + }); + auto server = std::make_unique(sr->partition(0).resources().runnable()); + + server->service_start(); + server->service_await_live(); + + auto cr = make_runtime([](Options& options) { + options.topology().user_cpuset("0"); + options.architect_url("localhost:13337"); + }); + + // the total number of partition is system dependent + auto expected_partitions = cr->resources().system().partitions().flattened().size(); + EXPECT_EQ(cr->partition(0).resources().network()->control_plane().client().connections().instance_ids().size(), + expected_partitions); + + // destroying the resources should gracefully shutdown the data plane and the control plane. 
+ cr.reset(); + + server->service_stop(); + server->service_await_join(); +} + TEST_F(TestControlPlane, DoubleClientConnectExchangeDisconnect) { auto sr = make_runtime(); diff --git a/cpp/mrc/src/tests/test_grpc.cpp b/cpp/mrc/src/tests/test_grpc.cpp index 68acc2913..95ef5801a 100644 --- a/cpp/mrc/src/tests/test_grpc.cpp +++ b/cpp/mrc/src/tests/test_grpc.cpp @@ -43,21 +43,16 @@ #include "mrc/runnable/runner.hpp" #include "mrc/types.hpp" -#include #include #include #include #include #include -#include -#include #include #include -#include #include #include -#include // Avoid forward declaring template specialization base classes // IWYU pragma: no_forward_declare grpc::ServerAsyncReaderWriter diff --git a/cpp/mrc/src/tests/test_memory.cpp b/cpp/mrc/src/tests/test_memory.cpp index 2544827d3..65059071d 100644 --- a/cpp/mrc/src/tests/test_memory.cpp +++ b/cpp/mrc/src/tests/test_memory.cpp @@ -38,11 +38,9 @@ #include #include #include -#include #include #include #include -#include #include #include #include diff --git a/cpp/mrc/src/tests/test_network.cpp b/cpp/mrc/src/tests/test_network.cpp index 1a14cebf4..509649eed 100644 --- a/cpp/mrc/src/tests/test_network.cpp +++ b/cpp/mrc/src/tests/test_network.cpp @@ -38,6 +38,7 @@ #include "internal/ucx/registration_cache.hpp" #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/memory/adaptors.hpp" #include "mrc/memory/buffer.hpp" #include "mrc/memory/literals.hpp" @@ -62,15 +63,11 @@ #include #include #include -#include #include #include #include -#include -#include #include #include -#include using namespace mrc; using namespace mrc::memory::literals; diff --git a/cpp/mrc/src/tests/test_next.cpp b/cpp/mrc/src/tests/test_next.cpp index da54e0a3f..1886664f7 100644 --- a/cpp/mrc/src/tests/test_next.cpp +++ b/cpp/mrc/src/tests/test_next.cpp @@ -25,6 +25,7 @@ #include "mrc/channel/ingress.hpp" #include "mrc/data/reusable_pool.hpp" #include "mrc/edge/edge_builder.hpp" +#include 
"mrc/edge/edge_writable.hpp" #include "mrc/node/generic_node.hpp" #include "mrc/node/generic_sink.hpp" #include "mrc/node/generic_source.hpp" @@ -64,12 +65,10 @@ #include #include #include -#include #include #include #include #include -#include using namespace mrc; @@ -573,7 +572,7 @@ TEST_F(TestNext, RxWithReusableOnNextAndOnError) }); static_assert(rxcpp::detail::is_on_next_of>::value, " "); - static_assert(rxcpp::detail::is_on_next_of>::value, " "); + static_assert(rxcpp::detail::is_on_next_of>::value, " "); auto observer = rxcpp::make_observer_dynamic( [](data_t&& int_ptr) { diff --git a/cpp/mrc/src/tests/test_pipeline.cpp b/cpp/mrc/src/tests/test_pipeline.cpp index 1b6e9c85f..0f23a6fa2 100644 --- a/cpp/mrc/src/tests/test_pipeline.cpp +++ b/cpp/mrc/src/tests/test_pipeline.cpp @@ -35,7 +35,9 @@ #include "mrc/node/queue.hpp" #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_sink_base.hpp" // for RxSinkBase #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" // for RxSourceBase #include "mrc/options/engine_groups.hpp" #include "mrc/options/options.hpp" #include "mrc/options/placement.hpp" @@ -67,7 +69,6 @@ #include #include #include -#include #include #include #include @@ -111,7 +112,6 @@ static void run_custom_manager(std::unique_ptr pipeline, } }); - manager->service_start(); manager->push_updates(std::move(update)); manager->service_await_join(); @@ -139,7 +139,6 @@ static void run_manager(std::unique_ptr pipeline, bool dela } }); - manager->service_start(); manager->push_updates(std::move(update)); manager->service_await_join(); diff --git a/cpp/mrc/src/tests/test_remote_descriptor.cpp b/cpp/mrc/src/tests/test_remote_descriptor.cpp index df4468897..33c85a440 100644 --- a/cpp/mrc/src/tests/test_remote_descriptor.cpp +++ b/cpp/mrc/src/tests/test_remote_descriptor.cpp @@ -39,7 +39,6 @@ #include "mrc/runtime/remote_descriptor_handle.hpp" #include "mrc/types.hpp" -#include #include #include diff --git 
a/cpp/mrc/src/tests/test_resources.cpp b/cpp/mrc/src/tests/test_resources.cpp index b6b4c953f..6f1abebd0 100644 --- a/cpp/mrc/src/tests/test_resources.cpp +++ b/cpp/mrc/src/tests/test_resources.cpp @@ -28,7 +28,6 @@ #include "mrc/options/placement.hpp" #include "mrc/types.hpp" -#include #include #include diff --git a/cpp/mrc/src/tests/test_runnable.cpp b/cpp/mrc/src/tests/test_runnable.cpp index c5bc0a048..6c303d8a2 100644 --- a/cpp/mrc/src/tests/test_runnable.cpp +++ b/cpp/mrc/src/tests/test_runnable.cpp @@ -47,14 +47,12 @@ #include #include #include -#include #include #include #include #include #include #include -#include using namespace mrc; diff --git a/cpp/mrc/src/tests/test_service.cpp b/cpp/mrc/src/tests/test_service.cpp new file mode 100644 index 000000000..39a6a6b95 --- /dev/null +++ b/cpp/mrc/src/tests/test_service.cpp @@ -0,0 +1,407 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tests/common.hpp" // IWYU pragma: associated + +#include "internal/service.hpp" + +#include "mrc/exceptions/runtime_error.hpp" + +#include + +#include +#include // for size_t +#include // for function +#include // for move + +namespace mrc { + +class SimpleService : public Service +{ + public: + SimpleService(bool do_call_in_destructor = true) : + Service("SimpleService"), + m_do_call_in_destructor(do_call_in_destructor) + {} + + ~SimpleService() override + { + if (m_do_call_in_destructor) + { + Service::call_in_destructor(); + } + } + + size_t start_call_count() const + { + return m_start_call_count.load(); + } + + size_t stop_call_count() const + { + return m_stop_call_count.load(); + } + + size_t kill_call_count() const + { + return m_kill_call_count.load(); + } + + size_t await_live_call_count() const + { + return m_await_live_call_count.load(); + } + + size_t await_join_call_count() const + { + return m_await_join_call_count.load(); + } + + void set_start_callback(std::function callback) + { + m_start_callback = std::move(callback); + } + + void set_stop_callback(std::function callback) + { + m_stop_callback = std::move(callback); + } + + void set_kill_callback(std::function callback) + { + m_kill_callback = std::move(callback); + } + + void set_await_live_callback(std::function callback) + { + m_await_live_callback = std::move(callback); + } + + void set_await_join_callback(std::function callback) + { + m_await_join_callback = std::move(callback); + } + + private: + void do_service_start() final + { + if (m_start_callback) + { + m_start_callback(); + } + + m_start_call_count++; + } + + void do_service_stop() final + { + if (m_stop_callback) + { + m_stop_callback(); + } + + m_stop_call_count++; + } + + void do_service_kill() final + { + if (m_kill_callback) + { + m_kill_callback(); + } + + m_kill_call_count++; + } + + void do_service_await_live() final + { + if (m_await_live_callback) + { + m_await_live_callback(); + } + + 
m_await_live_call_count++; + } + + void do_service_await_join() final + { + if (m_await_join_callback) + { + m_await_join_callback(); + } + + m_await_join_call_count++; + } + + bool m_do_call_in_destructor{true}; + + std::atomic_size_t m_start_call_count{0}; + std::atomic_size_t m_stop_call_count{0}; + std::atomic_size_t m_kill_call_count{0}; + std::atomic_size_t m_await_live_call_count{0}; + std::atomic_size_t m_await_join_call_count{0}; + + std::function m_start_callback; + std::function m_stop_callback; + std::function m_kill_callback; + std::function m_await_live_callback; + std::function m_await_join_callback; +}; + +class TestService : public ::testing::Test +{ + protected: +}; + +TEST_F(TestService, LifeCycle) +{ + SimpleService service; + + service.service_start(); + + EXPECT_EQ(service.state(), ServiceState::Running); + EXPECT_EQ(service.start_call_count(), 1); + + service.service_await_live(); + + EXPECT_EQ(service.await_live_call_count(), 1); + + service.service_await_join(); + + EXPECT_EQ(service.state(), ServiceState::Completed); + EXPECT_EQ(service.await_join_call_count(), 1); + + EXPECT_EQ(service.stop_call_count(), 0); + EXPECT_EQ(service.kill_call_count(), 0); +} + +TEST_F(TestService, ServiceNotStarted) +{ + SimpleService service; + + EXPECT_ANY_THROW(service.service_await_live()); + EXPECT_ANY_THROW(service.service_stop()); + EXPECT_ANY_THROW(service.service_kill()); + EXPECT_ANY_THROW(service.service_await_join()); +} + +TEST_F(TestService, ServiceStop) +{ + SimpleService service; + + service.service_start(); + + EXPECT_EQ(service.state(), ServiceState::Running); + + service.service_stop(); + + EXPECT_EQ(service.state(), ServiceState::Stopping); + + service.service_await_join(); + + EXPECT_EQ(service.state(), ServiceState::Completed); + + EXPECT_EQ(service.stop_call_count(), 1); +} + +TEST_F(TestService, ServiceKill) +{ + SimpleService service; + + service.service_start(); + + EXPECT_EQ(service.state(), ServiceState::Running); + + 
service.service_kill(); + + EXPECT_EQ(service.state(), ServiceState::Killing); + + service.service_await_join(); + + EXPECT_EQ(service.state(), ServiceState::Completed); + + EXPECT_EQ(service.kill_call_count(), 1); +} + +TEST_F(TestService, ServiceStopThenKill) +{ + SimpleService service; + + service.service_start(); + + EXPECT_EQ(service.state(), ServiceState::Running); + + service.service_stop(); + + EXPECT_EQ(service.state(), ServiceState::Stopping); + + service.service_kill(); + + EXPECT_EQ(service.state(), ServiceState::Killing); + + service.service_await_join(); + + EXPECT_EQ(service.state(), ServiceState::Completed); + + EXPECT_EQ(service.stop_call_count(), 1); + EXPECT_EQ(service.kill_call_count(), 1); +} + +TEST_F(TestService, ServiceKillThenStop) +{ + SimpleService service; + + service.service_start(); + + EXPECT_EQ(service.state(), ServiceState::Running); + + service.service_kill(); + + EXPECT_EQ(service.state(), ServiceState::Killing); + + service.service_stop(); + + EXPECT_EQ(service.state(), ServiceState::Killing); + + service.service_await_join(); + + EXPECT_EQ(service.state(), ServiceState::Completed); + + EXPECT_EQ(service.stop_call_count(), 0); + EXPECT_EQ(service.kill_call_count(), 1); +} + +TEST_F(TestService, MultipleStartCalls) +{ + SimpleService service; + + service.service_start(); + + // Call again (should be an error) + EXPECT_ANY_THROW(service.service_start()); + + EXPECT_EQ(service.start_call_count(), 1); +} + +TEST_F(TestService, MultipleStopCalls) +{ + SimpleService service; + + service.service_start(); + + // Multiple calls to stop are fine + service.service_stop(); + service.service_stop(); + + EXPECT_EQ(service.stop_call_count(), 1); +} + +TEST_F(TestService, MultipleKillCalls) +{ + SimpleService service; + + service.service_start(); + + // Multiple calls to kill are fine + service.service_kill(); + service.service_kill(); + + EXPECT_EQ(service.kill_call_count(), 1); +} + +TEST_F(TestService, MultipleJoinCalls) +{ + SimpleService 
service; + + service.service_start(); + + service.service_await_live(); + + service.service_await_join(); + service.service_await_join(); + + EXPECT_EQ(service.await_join_call_count(), 1); +} + +TEST_F(TestService, StartWithException) +{ + SimpleService service; + + service.set_start_callback([]() { + throw exceptions::MrcRuntimeError("Live Exception"); + }); + + EXPECT_ANY_THROW(service.service_start()); + + EXPECT_EQ(service.state(), ServiceState::Completed); +} + +TEST_F(TestService, LiveWithException) +{ + SimpleService service; + + service.set_await_join_callback([]() { + throw exceptions::MrcRuntimeError("Live Exception"); + }); + + service.service_start(); + + EXPECT_ANY_THROW(service.service_await_join()); +} + +TEST_F(TestService, MultipleLiveWithException) +{ + SimpleService service; + + service.set_await_live_callback([]() { + throw exceptions::MrcRuntimeError("Live Exception"); + }); + + service.service_start(); + + EXPECT_ANY_THROW(service.service_await_live()); + EXPECT_ANY_THROW(service.service_await_live()); +} + +TEST_F(TestService, JoinWithException) +{ + SimpleService service; + + service.set_await_join_callback([]() { + throw exceptions::MrcRuntimeError("Join Exception"); + }); + + service.service_start(); + + EXPECT_ANY_THROW(service.service_await_join()); +} + +TEST_F(TestService, MultipleJoinWithException) +{ + SimpleService service; + + service.set_await_join_callback([]() { + throw exceptions::MrcRuntimeError("Join Exception"); + }); + + service.service_start(); + + EXPECT_ANY_THROW(service.service_await_join()); + EXPECT_ANY_THROW(service.service_await_join()); +} + +} // namespace mrc diff --git a/cpp/mrc/src/tests/test_ucx.cpp b/cpp/mrc/src/tests/test_ucx.cpp index a80321017..8f65b6f34 100644 --- a/cpp/mrc/src/tests/test_ucx.cpp +++ b/cpp/mrc/src/tests/test_ucx.cpp @@ -39,7 +39,6 @@ #include #include #include -#include using namespace mrc; using namespace ucx; diff --git a/cpp/mrc/tests/CMakeLists.txt b/cpp/mrc/tests/CMakeLists.txt index 
821e0d8a2..2d524caac 100644 --- a/cpp/mrc/tests/CMakeLists.txt +++ b/cpp/mrc/tests/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,9 +15,12 @@ # Keep all source files sorted!!! add_executable(test_mrc + coroutines/test_async_generator.cpp coroutines/test_event.cpp + coroutines/test_io_scheduler.cpp coroutines/test_latch.cpp coroutines/test_ring_buffer.cpp + coroutines/test_task_container.cpp coroutines/test_task.cpp modules/test_mirror_tap_module.cpp modules/test_mirror_tap_orchestrator.cpp @@ -35,6 +38,7 @@ add_executable(test_mrc test_node.cpp test_pipeline.cpp test_segment.cpp + test_string_utils.cpp test_thread.cpp test_type_utils.cpp ) diff --git a/cpp/mrc/tests/benchmarking/test_benchmarking.hpp b/cpp/mrc/tests/benchmarking/test_benchmarking.hpp index 99de4e475..c9f7e368d 100644 --- a/cpp/mrc/tests/benchmarking/test_benchmarking.hpp +++ b/cpp/mrc/tests/benchmarking/test_benchmarking.hpp @@ -31,13 +31,11 @@ #include #include -#include #include #include #include #include #include -#include namespace mrc { diff --git a/cpp/mrc/tests/benchmarking/test_stat_gather.hpp b/cpp/mrc/tests/benchmarking/test_stat_gather.hpp index 746be4356..0af0df8ca 100644 --- a/cpp/mrc/tests/benchmarking/test_stat_gather.hpp +++ b/cpp/mrc/tests/benchmarking/test_stat_gather.hpp @@ -29,14 +29,12 @@ #include #include -#include #include #include #include #include #include #include -#include namespace mrc { class TestSegmentResources; diff --git a/cpp/mrc/tests/coroutines/test_async_generator.cpp b/cpp/mrc/tests/coroutines/test_async_generator.cpp new file mode 100644 index 000000000..81626a28c --- /dev/null +++ b/cpp/mrc/tests/coroutines/test_async_generator.cpp @@ -0,0 +1,133 @@ +/** + 
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mrc/coroutines/async_generator.hpp" +#include "mrc/coroutines/sync_wait.hpp" +#include "mrc/coroutines/task.hpp" + +#include + +#include + +using namespace mrc; + +class TestCoroAsyncGenerator : public ::testing::Test +{}; + +TEST_F(TestCoroAsyncGenerator, Iterator) +{ + auto generator = []() -> coroutines::AsyncGenerator { + for (int i = 0; i < 2; i++) + { + co_yield i; + } + }(); + + auto task = [&]() -> coroutines::Task<> { + auto iter = co_await generator.begin(); + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 0); + EXPECT_NE(iter, generator.end()); + + co_await ++iter; + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 1); + EXPECT_NE(iter, generator.end()); + + co_await ++iter; + EXPECT_FALSE(iter); + EXPECT_EQ(iter, generator.end()); + + co_return; + }; + + coroutines::sync_wait(task()); +} + +TEST_F(TestCoroAsyncGenerator, LoopOnGenerator) +{ + auto generator = []() -> coroutines::AsyncGenerator { + for (int i = 0; i < 2; i++) + { + co_yield i; + } + }(); + + auto task = [&]() -> coroutines::Task<> { + for (int i = 0; i < 2; i++) + { + auto iter = co_await generator.begin(); + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 0); + EXPECT_NE(iter, generator.end()); + + co_await ++iter; + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 1); + EXPECT_NE(iter, generator.end()); + + 
co_await ++iter; + EXPECT_FALSE(iter); + EXPECT_EQ(iter, generator.end()); + + co_return; + } + }; + + coroutines::sync_wait(task()); +} + +TEST_F(TestCoroAsyncGenerator, MultipleBegins) +{ + auto generator = []() -> coroutines::AsyncGenerator { + for (int i = 0; i < 2; i++) + { + co_yield i; + } + }(); + + // this test shows that begin() and operator++() perform essentially the same function + // both advance the generator to the next state + // while a generator is an iterable, it doesn't hold the entire sequence in memory, it does + // what it suggests, it generates the next item from the previous + + auto task = [&]() -> coroutines::Task<> { + auto iter = co_await generator.begin(); + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 0); + EXPECT_NE(iter, generator.end()); + + iter = co_await generator.begin(); + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 1); + EXPECT_NE(iter, generator.end()); + + iter = co_await generator.begin(); + EXPECT_FALSE(iter); + EXPECT_EQ(iter, generator.end()); + + co_return; + }; + + coroutines::sync_wait(task()); +} diff --git a/cpp/mrc/tests/coroutines/test_event.cpp b/cpp/mrc/tests/coroutines/test_event.cpp index 68689637d..61326e0b3 100644 --- a/cpp/mrc/tests/coroutines/test_event.cpp +++ b/cpp/mrc/tests/coroutines/test_event.cpp @@ -48,7 +48,6 @@ #include #include #include -#include #include using namespace mrc; diff --git a/cpp/mrc/tests/coroutines/test_io_scheduler.cpp b/cpp/mrc/tests/coroutines/test_io_scheduler.cpp new file mode 100644 index 000000000..26efb93c1 --- /dev/null +++ b/cpp/mrc/tests/coroutines/test_io_scheduler.cpp @@ -0,0 +1,82 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mrc/coroutines/async_generator.hpp" +#include "mrc/coroutines/io_scheduler.hpp" +#include "mrc/coroutines/sync_wait.hpp" +#include "mrc/coroutines/task.hpp" +#include "mrc/coroutines/time.hpp" +#include "mrc/coroutines/when_all.hpp" + +#include + +#include +#include +#include +#include +#include +#include + +using namespace mrc; +using namespace std::chrono_literals; + +class TestCoroIoScheduler : public ::testing::Test +{}; + +TEST_F(TestCoroIoScheduler, YieldFor) +{ + auto scheduler = coroutines::IoScheduler::get_instance(); + + auto task = [scheduler]() -> coroutines::Task<> { + co_await scheduler->yield_for(10ms); + }; + + coroutines::sync_wait(task()); +} + +TEST_F(TestCoroIoScheduler, YieldUntil) +{ + auto scheduler = coroutines::IoScheduler::get_instance(); + + auto task = [scheduler]() -> coroutines::Task<> { + co_await scheduler->yield_until(coroutines::clock_t::now() + 10ms); + }; + + coroutines::sync_wait(task()); +} + +TEST_F(TestCoroIoScheduler, Concurrent) +{ + auto scheduler = coroutines::IoScheduler::get_instance(); + + auto task = [scheduler]() -> coroutines::Task<> { + co_await scheduler->yield_for(10ms); + }; + + auto start = coroutines::clock_t::now(); + + std::vector> tasks; + + for (uint32_t i = 0; i < 1000; i++) + { + tasks.push_back(task()); + } + + coroutines::sync_wait(coroutines::when_all(std::move(tasks))); + + ASSERT_LT(coroutines::clock_t::now() - start, 20ms); +} diff --git a/cpp/mrc/tests/coroutines/test_latch.cpp b/cpp/mrc/tests/coroutines/test_latch.cpp index 1136bf76e..5be3b31e7 100644 --- 
a/cpp/mrc/tests/coroutines/test_latch.cpp +++ b/cpp/mrc/tests/coroutines/test_latch.cpp @@ -44,7 +44,6 @@ #include #include -#include using namespace mrc; diff --git a/cpp/mrc/tests/coroutines/test_ring_buffer.cpp b/cpp/mrc/tests/coroutines/test_ring_buffer.cpp index fb9afa1c4..a5b0163a2 100644 --- a/cpp/mrc/tests/coroutines/test_ring_buffer.cpp +++ b/cpp/mrc/tests/coroutines/test_ring_buffer.cpp @@ -49,7 +49,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/mrc/tests/coroutines/test_task.cpp b/cpp/mrc/tests/coroutines/test_task.cpp index ffc40a3ef..60cbfafa5 100644 --- a/cpp/mrc/tests/coroutines/test_task.cpp +++ b/cpp/mrc/tests/coroutines/test_task.cpp @@ -49,9 +49,7 @@ #include #include #include -#include #include -#include using namespace mrc; diff --git a/cpp/mrc/tests/coroutines/test_task_container.cpp b/cpp/mrc/tests/coroutines/test_task_container.cpp new file mode 100644 index 000000000..3a5a1bbf0 --- /dev/null +++ b/cpp/mrc/tests/coroutines/test_task_container.cpp @@ -0,0 +1,92 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "mrc/coroutines/sync_wait.hpp" +#include "mrc/coroutines/task.hpp" +#include "mrc/coroutines/task_container.hpp" +#include "mrc/coroutines/test_scheduler.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include + +class TestCoroTaskContainer : public ::testing::Test +{}; + +TEST_F(TestCoroTaskContainer, LifeCycle) {} + +TEST_F(TestCoroTaskContainer, MaxSimultaneousTasks) +{ + using namespace std::chrono_literals; + + const int32_t num_threads = 16; + const int32_t num_tasks_per_thread = 16; + const int32_t num_tasks = num_threads * num_tasks_per_thread; + const int32_t max_concurrent_tasks = 2; + + auto on = std::make_shared(); + auto task_container = mrc::coroutines::TaskContainer(on, max_concurrent_tasks); + + auto start_time = on->time(); + + std::vector> execution_times; + + auto delay = [](std::shared_ptr on, + std::vector>& execution_times) + -> mrc::coroutines::Task<> { + co_await on->yield_for(100ms); + execution_times.emplace_back(on->time()); + }; + + std::vector threads; + + for (auto i = 0; i < num_threads; i++) + { + threads.emplace_back([&]() { + for (auto i = 0; i < num_tasks_per_thread; i++) + { + task_container.start(delay(on, execution_times)); + } + }); + } + + for (auto& thread : threads) + { + thread.join(); + } + + auto task = task_container.garbage_collect_and_yield_until_empty(); + + task.resume(); + + while (on->resume_next()) {} + + mrc::coroutines::sync_wait(task); + + ASSERT_EQ(execution_times.size(), num_tasks); + + for (auto i = 0; i < execution_times.size(); i++) + { + ASSERT_EQ(execution_times[i], start_time + (i / max_concurrent_tasks + 1) * 100ms) << "Failed at index " << i; + } +} diff --git a/cpp/mrc/tests/logging/test_logging.cpp b/cpp/mrc/tests/logging/test_logging.cpp index f72cb113c..0d26a82bb 100644 --- a/cpp/mrc/tests/logging/test_logging.cpp +++ b/cpp/mrc/tests/logging/test_logging.cpp @@ -21,8 +21,6 @@ #include -#include - namespace mrc { TEST_CLASS(Logging); diff --git 
a/cpp/mrc/tests/modules/dynamic_module.cpp b/cpp/mrc/tests/modules/dynamic_module.cpp index 3db4e08cd..9538ed825 100644 --- a/cpp/mrc/tests/modules/dynamic_module.cpp +++ b/cpp/mrc/tests/modules/dynamic_module.cpp @@ -19,13 +19,13 @@ #include "mrc/modules/segment_modules.hpp" #include "mrc/node/rx_source.hpp" #include "mrc/segment/builder.hpp" +#include "mrc/segment/object.hpp" #include "mrc/utils/type_utils.hpp" #include "mrc/version.hpp" #include #include -#include #include #include #include diff --git a/cpp/mrc/tests/modules/test_mirror_tap_module.cpp b/cpp/mrc/tests/modules/test_mirror_tap_module.cpp index 7f68a354b..165382a94 100644 --- a/cpp/mrc/tests/modules/test_mirror_tap_module.cpp +++ b/cpp/mrc/tests/modules/test_mirror_tap_module.cpp @@ -20,10 +20,10 @@ #include "mrc/cuda/device_guard.hpp" #include "mrc/experimental/modules/mirror_tap/mirror_tap.hpp" #include "mrc/modules/properties/persistent.hpp" -#include "mrc/node/operators/broadcast.hpp" #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" #include "mrc/options/options.hpp" #include "mrc/options/topology.hpp" #include "mrc/pipeline/executor.hpp" @@ -38,11 +38,9 @@ #include #include -#include #include #include #include -#include using namespace mrc; diff --git a/cpp/mrc/tests/modules/test_mirror_tap_orchestrator.cpp b/cpp/mrc/tests/modules/test_mirror_tap_orchestrator.cpp index ceeba44e2..2de1cf98c 100644 --- a/cpp/mrc/tests/modules/test_mirror_tap_orchestrator.cpp +++ b/cpp/mrc/tests/modules/test_mirror_tap_orchestrator.cpp @@ -20,9 +20,10 @@ #include "mrc/cuda/device_guard.hpp" #include "mrc/experimental/modules/mirror_tap/mirror_tap_orchestrator.hpp" #include "mrc/modules/properties/persistent.hpp" -#include "mrc/node/operators/broadcast.hpp" #include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_sink_base.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" #include 
"mrc/options/options.hpp" #include "mrc/options/topology.hpp" #include "mrc/pipeline/executor.hpp" @@ -37,12 +38,10 @@ #include #include -#include #include #include #include #include -#include using namespace mrc; diff --git a/cpp/mrc/tests/modules/test_module_util.cpp b/cpp/mrc/tests/modules/test_module_util.cpp index 989ec4ed1..f064df81a 100644 --- a/cpp/mrc/tests/modules/test_module_util.cpp +++ b/cpp/mrc/tests/modules/test_module_util.cpp @@ -20,13 +20,11 @@ #include "mrc/modules/module_registry_util.hpp" #include "mrc/modules/properties/persistent.hpp" #include "mrc/modules/sample_modules.hpp" -#include "mrc/node/rx_source.hpp" #include "mrc/version.hpp" #include +#include -#include -#include #include #include #include diff --git a/cpp/mrc/tests/modules/test_segment_modules.cpp b/cpp/mrc/tests/modules/test_segment_modules.cpp index 6c23a930f..ac4f1ec79 100644 --- a/cpp/mrc/tests/modules/test_segment_modules.cpp +++ b/cpp/mrc/tests/modules/test_segment_modules.cpp @@ -67,6 +67,7 @@ TEST_F(TestSegmentModules, ModuleInitializationTest) { using namespace modules; + GTEST_SKIP() << "To be re-enabled by issue #390"; auto init_wrapper = [](segment::IBuilder& builder) { auto config_1 = nlohmann::json(); auto config_2 = nlohmann::json(); @@ -118,7 +119,7 @@ TEST_F(TestSegmentModules, ModuleInitializationTest) Executor executor(options); executor.register_pipeline(std::move(m_pipeline)); - executor.stop(); + executor.start(); executor.join(); } diff --git a/cpp/mrc/tests/modules/test_stream_buffer_modules.cpp b/cpp/mrc/tests/modules/test_stream_buffer_modules.cpp index c5cb376f8..cab4d21ac 100644 --- a/cpp/mrc/tests/modules/test_stream_buffer_modules.cpp +++ b/cpp/mrc/tests/modules/test_stream_buffer_modules.cpp @@ -39,13 +39,11 @@ #include #include -#include #include #include #include #include #include -#include using namespace mrc; using namespace mrc::modules::stream_buffers; @@ -57,6 +55,7 @@ TEST_F(TestStreamBufferModule, InitailizationTest) { using namespace 
modules; + GTEST_SKIP() << "To be re-enabled by issue #390"; auto init_wrapper = [](segment::IBuilder& builder) { auto config1 = nlohmann::json(); auto mirror_buffer1 = builder.make_module("mirror_tap", config1); @@ -70,7 +69,7 @@ TEST_F(TestStreamBufferModule, InitailizationTest) Executor executor(options); executor.register_pipeline(std::move(m_pipeline)); - executor.stop(); + executor.start(); executor.join(); } diff --git a/cpp/mrc/tests/test_channel.cpp b/cpp/mrc/tests/test_channel.cpp index 6d796dba6..1a5f8ef2e 100644 --- a/cpp/mrc/tests/test_channel.cpp +++ b/cpp/mrc/tests/test_channel.cpp @@ -27,7 +27,6 @@ #include #include -#include #include // for sleep_for #include // for duration, system_clock, milliseconds, time_point @@ -35,7 +34,6 @@ #include // for uint64_t #include // for ref, reference_wrapper #include -#include #include // IWYU thinks algorithm is needed for: auto channel = std::make_shared>(2); // IWYU pragma: no_include diff --git a/cpp/mrc/tests/test_edges.cpp b/cpp/mrc/tests/test_edges.cpp index 86e42dfb5..346750683 100644 --- a/cpp/mrc/tests/test_edges.cpp +++ b/cpp/mrc/tests/test_edges.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,15 +19,21 @@ #include "mrc/channel/buffered_channel.hpp" // IWYU pragma: keep #include "mrc/channel/forward.hpp" +#include "mrc/edge/edge.hpp" // for Edge #include "mrc/edge/edge_builder.hpp" #include "mrc/edge/edge_channel.hpp" +#include "mrc/edge/edge_holder.hpp" // for EdgeHolder #include "mrc/edge/edge_readable.hpp" #include "mrc/edge/edge_writable.hpp" +#include "mrc/exceptions/runtime_error.hpp" #include "mrc/node/generic_source.hpp" #include "mrc/node/operators/broadcast.hpp" #include "mrc/node/operators/combine_latest.hpp" #include "mrc/node/operators/node_component.hpp" +#include "mrc/node/operators/round_robin_router_typeless.hpp" #include "mrc/node/operators/router.hpp" +#include "mrc/node/operators/with_latest_from.hpp" +#include "mrc/node/operators/zip.hpp" #include "mrc/node/rx_node.hpp" #include "mrc/node/sink_channel_owner.hpp" #include "mrc/node/sink_properties.hpp" @@ -39,10 +45,13 @@ #include #include // for observable_member +#include #include +#include #include #include #include +#include #include #include #include @@ -121,25 +130,73 @@ template class TestSource : public WritableAcceptor, public ReadableProvider, public SourceChannelOwner { public: - TestSource() + TestSource(std::vector values) : + m_init_values(values), + m_values(std::deque(std::make_move_iterator(values.begin()), std::make_move_iterator(values.end()))) { this->set_channel(std::make_unique>()); } + TestSource(std::initializer_list values) : + TestSource(std::vector(std::make_move_iterator(values.begin()), std::make_move_iterator(values.end()))) + {} + + TestSource(size_t count) : TestSource(gen_values(count)) {} + + TestSource() : TestSource(3) {} + void run() + { + // Just push them all + this->push(m_values.size()); + } + + void push_one() + { + this->push(1); + } + + void push(size_t count = 1) { auto output = this->get_writable_edge(); - for (int i = 0; i < 3; 
i++) + for (size_t i = 0; i < count; ++i) { - if (output->await_write(T(i)) != channel::Status::success) + if (output->await_write(std::move(m_values.front())) != channel::Status::success) { - break; + this->release_edge_connection(); + throw exceptions::MrcRuntimeError("Failed to push values. await_write returned non-success status"); } + + m_values.pop(); } - this->release_edge_connection(); + if (m_values.empty()) + { + this->release_edge_connection(); + } + } + + const std::vector& get_init_values() + { + return m_init_values; } + + private: + static std::vector gen_values(size_t count) + { + std::vector values; + + for (size_t i = 0; i < count; ++i) + { + values.emplace_back(i); + } + + return values; + } + + std::vector m_init_values; + std::queue m_values; }; template @@ -153,15 +210,8 @@ class TestNode : public WritableProvider, public: TestNode() { - this->set_channel(std::make_unique>()); - } - - void set_channel(std::unique_ptr> channel) - { - edge::EdgeChannel edge_channel(std::move(channel)); - - SinkChannelOwner::do_set_channel(edge_channel); - SourceChannelOwner::do_set_channel(edge_channel); + SinkChannelOwner::set_channel(std::make_unique>()); + SourceChannelOwner::set_channel(std::make_unique>()); } void run() @@ -175,7 +225,12 @@ class TestNode : public WritableProvider, { VLOG(10) << "Node got value: " << t; - output->await_write(std::move(t)); + if (output->await_write(std::move(t)) != channel::Status::success) + { + SinkChannelOwner::release_edge_connection(); + SourceChannelOwner::release_edge_connection(); + throw exceptions::MrcRuntimeError("Failed to push values. 
await_write returned non-success status"); + } } VLOG(10) << "Node exited run"; @@ -203,12 +258,21 @@ class TestSink : public WritableProvider, public ReadableAcceptor, public while (input->await_read(t) == channel::Status::success) { VLOG(10) << "Sink got value"; + m_values.emplace_back(std::move(t)); } VLOG(10) << "Sink exited run"; this->release_edge_connection(); } + + const std::vector& get_values() + { + return m_values; + } + + private: + std::vector m_values; }; template @@ -233,17 +297,40 @@ template class TestSourceComponent : public GenericSourceComponent { public: - TestSourceComponent() = default; + TestSourceComponent(std::vector values) : + m_init_values(values), + m_values(std::deque(std::make_move_iterator(values.begin()), std::make_move_iterator(values.end()))) + {} + + TestSourceComponent(std::initializer_list values) : + TestSourceComponent( + std::vector(std::make_move_iterator(values.begin()), std::make_move_iterator(values.end()))) + {} + + TestSourceComponent(size_t count) : TestSourceComponent(gen_values(count)) {} + + TestSourceComponent() : TestSourceComponent(3) {} + + const std::vector& get_init_values() + { + return m_init_values; + } protected: channel::Status get_data(T& data) override { - data = m_value++; + // Close after all values have been pulled + if (m_values.empty()) + { + return channel::Status::closed; + } + + data = std::move(m_values.front()); + m_values.pop(); VLOG(10) << "TestSourceComponent emmitted value: " << data; - // Close after 3 - return m_value >= 3 ? 
channel::Status::closed : channel::Status::success; + return channel::Status::success; } void on_complete() override @@ -252,7 +339,20 @@ class TestSourceComponent : public GenericSourceComponent } private: - T m_value{1}; + static std::vector gen_values(size_t count) + { + std::vector values; + + for (size_t i = 0; i < count; ++i) + { + values.emplace_back(i); + } + + return values; + } + + std::vector m_init_values; + std::queue m_values; }; template @@ -271,7 +371,7 @@ class TestNodeComponent : public NodeComponent { VLOG(10) << "TestNodeComponent got value: " << t; - return this->get_writable_edge()->await_write(t + 1); + return this->get_writable_edge()->await_write(t); } void do_on_complete() override @@ -315,10 +415,18 @@ class TestSinkComponent : public WritableProvider })); } + const std::vector& get_values() + { + return m_values; + } + + protected: channel::Status await_write(int&& t) { VLOG(10) << "TestSinkComponent got value: " << t; + m_values.emplace_back(std::move(t)); + return channel::Status::success; } @@ -326,6 +434,9 @@ class TestSinkComponent : public WritableProvider { VLOG(10) << "TestSinkComponent completed"; } + + private: + std::vector m_values; }; template @@ -398,6 +509,8 @@ TEST_F(TestEdges, SourceToSink) source->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToSinkUpcast) @@ -409,6 +522,15 @@ TEST_F(TestEdges, SourceToSinkUpcast) source->run(); sink->run(); + + std::vector source_float_vals; + + for (const auto& v : source->get_init_values()) + { + source_float_vals.push_back(v); + } + + EXPECT_EQ(source_float_vals, sink->get_values()); } TEST_F(TestEdges, SourceToSinkTypeless) @@ -420,6 +542,8 @@ TEST_F(TestEdges, SourceToSinkTypeless) source->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToNodeToSink) @@ -434,6 +558,8 @@ TEST_F(TestEdges, SourceToNodeToSink) source->run(); node->run(); sink->run(); + + 
EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToNodeToNodeToSink) @@ -451,6 +577,8 @@ TEST_F(TestEdges, SourceToNodeToNodeToSink) node1->run(); node2->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToSinkMultiFail) @@ -475,6 +603,8 @@ TEST_F(TestEdges, SourceToSinkComponent) mrc::make_edge(*source, *sink); source->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceComponentToSink) @@ -485,6 +615,8 @@ TEST_F(TestEdges, SourceComponentToSink) mrc::make_edge(*source, *sink); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceComponentToNodeToSink) @@ -498,6 +630,8 @@ TEST_F(TestEdges, SourceComponentToNodeToSink) node->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToNodeComponentToSink) @@ -511,6 +645,8 @@ TEST_F(TestEdges, SourceToNodeComponentToSink) source->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToNodeToSinkComponent) @@ -524,6 +660,8 @@ TEST_F(TestEdges, SourceToNodeToSinkComponent) source->run(); node->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToNodeComponentToSinkComponent) @@ -536,6 +674,8 @@ TEST_F(TestEdges, SourceToNodeComponentToSinkComponent) mrc::make_edge(*node, *sink); source->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToRxNodeComponentToSinkComponent) @@ -556,6 +696,8 @@ TEST_F(TestEdges, SourceToRxNodeComponentToSinkComponent) source->run(); EXPECT_TRUE(node->stream_fn_called); + + EXPECT_EQ((std::vector{0, 2, 4}), sink->get_values()); } TEST_F(TestEdges, SourceComponentToNodeToSinkComponent) @@ -568,6 +710,8 @@ TEST_F(TestEdges, SourceComponentToNodeToSinkComponent) mrc::make_edge(*node, *sink); node->run(); + + 
EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToQueueToSink) @@ -581,6 +725,8 @@ TEST_F(TestEdges, SourceToQueueToSink) source->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToQueueToNodeToSink) @@ -597,6 +743,8 @@ TEST_F(TestEdges, SourceToQueueToNodeToSink) source->run(); node->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToQueueToMultiSink) @@ -613,6 +761,9 @@ TEST_F(TestEdges, SourceToQueueToMultiSink) source->run(); sink1->run(); sink2->run(); + + EXPECT_EQ(source->get_init_values(), sink1->get_values()); + EXPECT_EQ(std::vector{}, sink2->get_values()); } TEST_F(TestEdges, SourceToQueueToDifferentSinks) @@ -632,6 +783,9 @@ TEST_F(TestEdges, SourceToQueueToDifferentSinks) node->run(); sink1->run(); sink2->run(); + + EXPECT_EQ((std::vector{}), sink1->get_values()); + EXPECT_EQ(source->get_init_values(), sink2->get_values()); } TEST_F(TestEdges, SourceToRouterToSinks) @@ -648,6 +802,9 @@ TEST_F(TestEdges, SourceToRouterToSinks) source->run(); sink1->run(); sink2->run(); + + EXPECT_EQ((std::vector{1}), sink1->get_values()); + EXPECT_EQ((std::vector{0, 2}), sink2->get_values()); } TEST_F(TestEdges, SourceToRouterToDifferentSinks) @@ -663,6 +820,39 @@ TEST_F(TestEdges, SourceToRouterToDifferentSinks) source->run(); sink1->run(); + + EXPECT_EQ((std::vector{1}), sink1->get_values()); + EXPECT_EQ((std::vector{0, 2}), sink2->get_values()); +} + +TEST_F(TestEdges, SourceToRoundRobinRouterTypelessToDifferentSinks) +{ + auto source = std::make_shared>(); + auto router = std::make_shared(); + auto sink1 = std::make_shared>(); + auto sink2 = std::make_shared>(); + + mrc::make_edge(*source, *router); + mrc::make_edge(*router, *sink1); + mrc::make_edge(*router, *sink2); + + source->run(); + sink1->run(); +} + +TEST_F(TestEdges, SourceToRoundRobinRouterTypelessToDifferentSinks) +{ + auto source = std::make_shared>(); 
+ auto router = std::make_shared(); + auto sink1 = std::make_shared>(); + auto sink2 = std::make_shared>(); + + mrc::make_edge(*source, *router); + mrc::make_edge(*router, *sink1); + mrc::make_edge(*router, *sink2); + + source->run(); + sink1->run(); } TEST_F(TestEdges, SourceToBroadcastToSink) @@ -676,6 +866,8 @@ TEST_F(TestEdges, SourceToBroadcastToSink) source->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToBroadcastTypelessToSinkSinkFirst) @@ -689,6 +881,8 @@ TEST_F(TestEdges, SourceToBroadcastTypelessToSinkSinkFirst) source->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToBroadcastTypelessToSinkSourceFirst) @@ -702,6 +896,8 @@ TEST_F(TestEdges, SourceToBroadcastTypelessToSinkSourceFirst) source->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToMultipleBroadcastTypelessToSinkSinkFirst) @@ -717,6 +913,8 @@ TEST_F(TestEdges, SourceToMultipleBroadcastTypelessToSinkSinkFirst) source->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, SourceToMultipleBroadcastTypelessToSinkSourceFirst) @@ -732,6 +930,8 @@ TEST_F(TestEdges, SourceToMultipleBroadcastTypelessToSinkSourceFirst) source->run(); sink->run(); + + EXPECT_EQ(source->get_init_values(), sink->get_values()); } TEST_F(TestEdges, MultiSourceToMultipleBroadcastTypelessToMultiSink) @@ -753,6 +953,12 @@ TEST_F(TestEdges, MultiSourceToMultipleBroadcastTypelessToMultiSink) source2->run(); sink1->run(); sink2->run(); + + auto expected = source1->get_init_values(); + expected.insert(expected.end(), source2->get_init_values().begin(), source2->get_init_values().end()); + + EXPECT_EQ(expected, sink1->get_values()); + EXPECT_EQ(expected, sink2->get_values()); } TEST_F(TestEdges, SourceToBroadcastToMultiSink) @@ -767,6 +973,11 @@ TEST_F(TestEdges, SourceToBroadcastToMultiSink) 
mrc::make_edge(*broadcast, *sink2); source->run(); + sink1->run(); + sink2->run(); + + EXPECT_EQ(source->get_init_values(), sink1->get_values()); + EXPECT_EQ(source->get_init_values(), sink2->get_values()); } TEST_F(TestEdges, SourceToBroadcastToDifferentSinks) @@ -781,6 +992,10 @@ TEST_F(TestEdges, SourceToBroadcastToDifferentSinks) mrc::make_edge(*broadcast, *sink2); source->run(); + sink1->run(); + + EXPECT_EQ(source->get_init_values(), sink1->get_values()); + EXPECT_EQ(source->get_init_values(), sink2->get_values()); } TEST_F(TestEdges, SourceToBroadcastToSinkComponents) @@ -795,6 +1010,9 @@ TEST_F(TestEdges, SourceToBroadcastToSinkComponents) mrc::make_edge(*broadcast, *sink2); source->run(); + + EXPECT_EQ(source->get_init_values(), sink1->get_values()); + EXPECT_EQ(source->get_init_values(), sink2->get_values()); } TEST_F(TestEdges, SourceComponentDoubleToSinkFloat) @@ -805,6 +1023,8 @@ TEST_F(TestEdges, SourceComponentDoubleToSinkFloat) mrc::make_edge(*source, *sink); sink->run(); + + EXPECT_EQ((std::vector{0, 1, 2}), sink->get_values()); } TEST_F(TestEdges, CombineLatest) @@ -824,6 +1044,212 @@ TEST_F(TestEdges, CombineLatest) source2->run(); sink->run(); + + EXPECT_EQ(sink->get_values(), + (std::vector>{ + std::tuple{2, 0}, + std::tuple{2, 1}, + std::tuple{2, 2}, + })); +} + +TEST_F(TestEdges, Zip) +{ + auto source1 = std::make_shared>(); + auto source2 = std::make_shared>(); + + auto zip = std::make_shared>(); + + auto sink = std::make_shared>>(); + + mrc::make_edge(*source1, zip->get_sink<0>()); + mrc::make_edge(*source2, zip->get_sink<1>()); + mrc::make_edge(*zip, *sink); + + source1->run(); + source2->run(); + + sink->run(); + + EXPECT_EQ(sink->get_values(), + (std::vector>{ + std::tuple{0, 0}, + std::tuple{1, 1}, + std::tuple{2, 2}, + })); +} + +TEST_F(TestEdges, ZipEarlyClose) +{ + // Have one source emit different counts than the other + auto source1 = std::make_shared>(3); + auto source2 = std::make_shared>(4); + + auto zip = std::make_shared>(); + 
+ auto sink = std::make_shared>>(); + + mrc::make_edge(*source1, zip->get_sink<0>()); + mrc::make_edge(*source2, zip->get_sink<1>()); + mrc::make_edge(*zip, *sink); + + source1->run(); + + // Should throw when pushing last value + EXPECT_THROW(source2->run(), exceptions::MrcRuntimeError); +} + +TEST_F(TestEdges, ZipLateClose) +{ + // Have one source emit different counts than the other + auto source1 = std::make_shared>(4); + auto source2 = std::make_shared>(3); + + auto zip = std::make_shared>(); + + auto sink = std::make_shared>>(); + + mrc::make_edge(*source1, zip->get_sink<0>()); + mrc::make_edge(*source2, zip->get_sink<1>()); + mrc::make_edge(*zip, *sink); + + source1->run(); + source2->run(); + + sink->run(); + + EXPECT_EQ(sink->get_values(), + (std::vector>{ + std::tuple{0, 0}, + std::tuple{1, 1}, + std::tuple{2, 2}, + })); +} + +TEST_F(TestEdges, ZipEarlyReset) +{ + // Have one source emit different counts than the other + auto source1 = std::make_shared>(4); + auto source2 = std::make_shared>(3); + + auto zip = std::make_shared>(); + + auto sink = std::make_shared>>(); + + mrc::make_edge(*source1, zip->get_sink<0>()); + mrc::make_edge(*source2, zip->get_sink<1>()); + mrc::make_edge(*zip, *sink); + + // After the edges have been made, reset the zip to ensure that it can be kept alive by its children + zip.reset(); + + source1->run(); + source2->run(); + + sink->run(); + + EXPECT_EQ(sink->get_values(), + (std::vector>{ + std::tuple{0, 0}, + std::tuple{1, 1}, + std::tuple{2, 2}, + })); +} + +TEST_F(TestEdges, WithLatestFrom) +{ + auto source1 = std::make_shared>(5); + auto source2 = std::make_shared>(5); + auto source3 = std::make_shared>(std::vector{"a", "b", "c", "d", "e"}); + + auto with_latest = std::make_shared>(); + + auto sink = std::make_shared>>(); + + mrc::make_edge(*source1, *with_latest->get_sink<0>()); + mrc::make_edge(*source2, *with_latest->get_sink<1>()); + mrc::make_edge(*source3, *with_latest->get_sink<2>()); + mrc::make_edge(*with_latest, 
*sink); + + // Push 2 from each + source2->push(2); + source1->push(2); + source3->push(2); + + // Push 2 from each + source2->push(2); + source1->push(2); + source3->push(2); + + // Push the rest + source3->run(); + source1->run(); + source2->run(); + + sink->run(); + + EXPECT_EQ(sink->get_values(), + (std::vector>{ + std::tuple{0, 1, "a"}, + std::tuple{1, 1, "a"}, + std::tuple{2, 3, "b"}, + std::tuple{3, 3, "b"}, + std::tuple{4, 3, "e"}, + })); +} + +TEST_F(TestEdges, WithLatestFromUnevenPrimary) +{ + auto source1 = std::make_shared>(5); + auto source2 = std::make_shared>(3); + + auto with_latest = std::make_shared>(); + + auto sink = std::make_shared>>(); + + mrc::make_edge(*source1, *with_latest->get_sink<0>()); + mrc::make_edge(*source2, *with_latest->get_sink<1>()); + mrc::make_edge(*with_latest, *sink); + + source2->run(); + source1->run(); + + sink->run(); + + EXPECT_EQ(sink->get_values(), + (std::vector>{ + std::tuple{0, 2}, + std::tuple{1, 2}, + std::tuple{2, 2}, + std::tuple{3, 2}, + std::tuple{4, 2}, + })); +} + +TEST_F(TestEdges, WithLatestFromUnevenSecondary) +{ + auto source1 = std::make_shared>(3); + auto source2 = std::make_shared>(5); + + auto with_latest = std::make_shared>(); + + auto sink = std::make_shared>>(); + + mrc::make_edge(*source1, *with_latest->get_sink<0>()); + mrc::make_edge(*source2, *with_latest->get_sink<1>()); + mrc::make_edge(*with_latest, *sink); + + source1->run(); + source2->run(); + + sink->run(); + + EXPECT_EQ(sink->get_values(), + (std::vector>{ + std::tuple{0, 0}, + std::tuple{1, 0}, + std::tuple{2, 0}, + })); } TEST_F(TestEdges, SourceToNull) @@ -996,4 +1422,37 @@ TEST_F(TestEdges, EdgeTapWithSpliceRxComponent) EXPECT_TRUE(node->stream_fn_called); } + +template +class TestEdgeHolder : public edge::EdgeHolder +{ + public: + bool has_active_connection() const + { + return this->check_active_connection(false); + } + + void call_release_edge_connection() + { + this->release_edge_connection(); + } + + void 
call_init_owned_edge(std::shared_ptr> edge) + { + this->init_owned_edge(std::move(edge)); + } +}; + +TEST_F(TestEdges, EdgeHolderIsConnected) +{ + TestEdgeHolder edge_holder; + auto edge = std::make_shared>(); + EXPECT_FALSE(edge_holder.has_active_connection()); + + edge_holder.call_init_owned_edge(edge); + EXPECT_FALSE(edge_holder.has_active_connection()); + + edge_holder.call_release_edge_connection(); + EXPECT_FALSE(edge_holder.has_active_connection()); +} } // namespace mrc diff --git a/cpp/mrc/tests/test_executor.cpp b/cpp/mrc/tests/test_executor.cpp index e8da2fe0b..989dfe2f1 100644 --- a/cpp/mrc/tests/test_executor.cpp +++ b/cpp/mrc/tests/test_executor.cpp @@ -17,7 +17,9 @@ #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_sink_base.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" #include "mrc/options/engine_groups.hpp" #include "mrc/options/options.hpp" #include "mrc/options/topology.hpp" @@ -41,7 +43,6 @@ #include #include #include -#include #include #include #include @@ -49,7 +50,6 @@ #include #include #include -#include namespace mrc { diff --git a/cpp/mrc/tests/test_mrc.hpp b/cpp/mrc/tests/test_mrc.hpp index e8971c3c5..79143c444 100644 --- a/cpp/mrc/tests/test_mrc.hpp +++ b/cpp/mrc/tests/test_mrc.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -130,7 +130,7 @@ class ParallelTester /** * @brief Method to call at the parallelization test point by all threads. 
Can be used in gtest with - * `EXPECT_TRUE(parallel_test.wait_for(100ms));` to fail if parallelization isnt met + * `EXPECT_TRUE(parallel_test.wait_for(250ms));` to fail if parallelization isnt met * * @tparam RepT Duration Rep type * @tparam PeriodT Duration Period type diff --git a/cpp/mrc/tests/test_node.cpp b/cpp/mrc/tests/test_node.cpp index 428c41d2c..34ea01a85 100644 --- a/cpp/mrc/tests/test_node.cpp +++ b/cpp/mrc/tests/test_node.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -40,7 +40,6 @@ #include #include #include -#include #include #include #include @@ -573,7 +572,7 @@ TEST_P(ParallelTests, SourceMultiThread) } DVLOG(1) << context.info() << " Enqueueing value: '" << i << "'" << std::endl; - ASSERT_TRUE(parallel_test.wait_for(100ms)); + ASSERT_TRUE(parallel_test.wait_for(250ms)); s.on_next(i); } @@ -673,7 +672,7 @@ TEST_P(ParallelTests, SinkMultiThread) // Print value DVLOG(1) << context.info() << " Sink got value: '" << x << "'" << std::endl; - EXPECT_TRUE(parallel_test.wait_for(100ms)); + EXPECT_TRUE(parallel_test.wait_for(250ms)); ++next_count; }, @@ -745,7 +744,7 @@ TEST_P(ParallelTests, NodeMultiThread) DVLOG(1) << context.info() << " Node got value: '" << x << "'" << std::endl; - EXPECT_TRUE(parallel_test.wait_for(100ms)); + EXPECT_TRUE(parallel_test.wait_for(250ms)); // Double the value return x * 2; })); diff --git a/cpp/mrc/tests/test_pipeline.cpp b/cpp/mrc/tests/test_pipeline.cpp index c34731302..6d1bc4499 100644 --- a/cpp/mrc/tests/test_pipeline.cpp +++ b/cpp/mrc/tests/test_pipeline.cpp @@ -16,7 +16,9 @@ */ #include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_sink_base.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" 
#include "mrc/options/options.hpp" #include "mrc/options/topology.hpp" #include "mrc/pipeline/executor.hpp" @@ -33,12 +35,9 @@ #include #include -#include #include #include -#include #include -#include namespace mrc { diff --git a/cpp/mrc/tests/test_segment.cpp b/cpp/mrc/tests/test_segment.cpp index be1bbc29c..bd3b09d78 100644 --- a/cpp/mrc/tests/test_segment.cpp +++ b/cpp/mrc/tests/test_segment.cpp @@ -23,7 +23,6 @@ #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/node/rx_source.hpp" -#include "mrc/node/rx_source_base.hpp" #include "mrc/options/options.hpp" #include "mrc/options/topology.hpp" #include "mrc/pipeline/executor.hpp" @@ -40,9 +39,7 @@ #include #include -#include #include -#include #include #include #include diff --git a/cpp/mrc/tests/test_string_utils.cpp b/cpp/mrc/tests/test_string_utils.cpp new file mode 100644 index 000000000..fbaf5b14e --- /dev/null +++ b/cpp/mrc/tests/test_string_utils.cpp @@ -0,0 +1,58 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "./test_mrc.hpp" // IWYU pragma: associated + +#include "mrc/utils/string_utils.hpp" // for split_string_to_vector + +#include // for EXPECT_EQ + +#include +#include + +namespace mrc { + +TEST_CLASS(StringUtils); + +TEST_F(TestStringUtils, TestSplitStringToVector) +{ + struct TestValues + { + std::string str; + std::string delimiter; + std::vector expected_result; + }; + + std::vector values = { + {"Hello,World,!", ",", {"Hello", "World", "!"}}, + {"a/b/c", "/", {"a", "b", "c"}}, + {"/a/b/c", "/", {"", "a", "b", "c"}}, // leading delimiter + {"a/b/c/", "/", {"a", "b", "c", ""}}, // trailing delimiter + {"abcd", "/", {"abcd"}}, // no delimiter + {"", "/", {""}}, // empty string + {"/", "/", {"", ""}}, // single delimiter + {"//", "/", {"", "", ""}}, // duplicate delimiter + }; + + for (const auto& value : values) + { + auto result = mrc::split_string_to_vector(value.str, value.delimiter); + EXPECT_EQ(result, value.expected_result); + } +} + +} // namespace mrc diff --git a/cpp/mrc/tests/test_thread.cpp b/cpp/mrc/tests/test_thread.cpp index c19753734..88785379f 100644 --- a/cpp/mrc/tests/test_thread.cpp +++ b/cpp/mrc/tests/test_thread.cpp @@ -25,7 +25,6 @@ #include #include -#include using namespace mrc; diff --git a/dependencies.yaml b/dependencies.yaml new file mode 100644 index 000000000..22e206c14 --- /dev/null +++ b/dependencies.yaml @@ -0,0 +1,152 @@ +# Dependency list for https://github.com/rapidsai/dependency-file-generator +files: + all: + output: conda + matrix: + cuda: ["12.5"] + arch: [x86_64] + includes: + - build + - checks + - developer_productivity + - code_style + - testing + - benchmarking + - ci + - examples + - documentation + - python + - cudatoolkit + + ci: + output: conda + matrix: + cuda: ["12.5"] + arch: [x86_64] + includes: + - build + - code_style + - testing + - benchmarking + - ci + - documentation + - python + - cudatoolkit + + checks: + output: none + includes: + - checks + +channels: + - conda-forge + - rapidsai + - 
rapidsai-nightly + - nvidia + +dependencies: + + build: + common: + - output_types: [conda] + packages: + - boost-cpp=1.84 + - ccache + - cmake=3.27 + - cuda-nvcc + - cxx-compiler + - glog>=0.7.1,<0.8 + - gtest=1.14 + - gxx=12.1 + - libgrpc=1.62.2 + - libhwloc=2.9.2 + - librmm=24.10 + - libxml2=2.11.6 + - ninja=1.11 + - nlohmann_json=3.11 + - numactl=2.0.18 + - pkg-config=0.29 + - pybind11-stubgen=0.10 + - scikit-build=0.17 + - ucx=1.15 + + checks: + common: + - output_types: [conda] + packages: + - pre-commit + + developer_productivity: + common: + - output_types: [conda] + packages: + - bash-completion + - clang-tools=16 + - clang=16 + - clangdev=16 + - clangxx=16 + - flake8 + - gdb + - libclang-cpp=16 + - libclang=16 + - llvmdev=16 + - yapf + + code_style: + common: + - output_types: [conda] + packages: + - include-what-you-use=0.20 + + testing: + common: + - output_types: [conda] + packages: + - pytest + - pytest-asyncio + - pytest-timeout + + benchmarking: + common: + - output_types: [conda] + packages: + - benchmark=1.8.3 + + ci: + common: + - output_types: [conda] + packages: + - codecov=2.1 + - gcovr=5.2 + - pre-commit + + examples: + common: + - output_types: [conda] + packages: + - numpy=1.24 + + documentation: + common: + - output_types: [conda] + packages: + - doxygen=1.9.2 + - python-graphviz + + python: + common: + - output_types: [conda] + packages: + - python=3.10 + + cudatoolkit: + specific: + - output_types: [conda] + matrices: + - matrix: + cuda: "12.5" + packages: + - cuda-cudart-dev=12.5 + - cuda-nvml-dev=12.5 + - cuda-nvrtc-dev=12.5 + - cuda-version=12.5 diff --git a/docs/quickstart/CMakeLists.txt b/docs/quickstart/CMakeLists.txt index 1b87c4b87..b201248c6 100644 --- a/docs/quickstart/CMakeLists.txt +++ b/docs/quickstart/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,7 +28,7 @@ list(PREPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../../external/utili include(morpheus_utils/load) project(mrc-quickstart - VERSION 23.07.00 + VERSION 24.10.00 LANGUAGES C CXX ) @@ -42,7 +42,7 @@ set(OPTION_PREFIX "MRC") morpheus_utils_python_configure() -rapids_find_package(mrc REQUIRED) +morpheus_utils_configure_mrc() rapids_find_package(CUDAToolkit REQUIRED) # To make it easier for CI to find output files, set the default executable suffix to .x if not set diff --git a/docs/quickstart/environment_cpp.yml b/docs/quickstart/environment_cpp.yml index 379bf6477..860d54a4c 100644 --- a/docs/quickstart/environment_cpp.yml +++ b/docs/quickstart/environment_cpp.yml @@ -24,14 +24,14 @@ dependencies: - isort - libtool - ninja=1.10 - - numactl-libs-cos7-x86_64 + - numactl=2.0.18 - numpy>=1.21 - nvcc_linux-64=11.8 - pkg-config=0.29 - python=3.10 - scikit-build>=0.12 - - mrc=23.07 - - sysroot_linux-64=2.17 + - mrc=24.10 + - sysroot_linux-64>=2.28 - pip: - cython - flake8 diff --git a/docs/quickstart/hybrid/mrc_qs_hybrid/ex00_wrap_data_objects/CMakeLists.txt b/docs/quickstart/hybrid/mrc_qs_hybrid/ex00_wrap_data_objects/CMakeLists.txt index c46b9b0bd..b1996a71a 100644 --- a/docs/quickstart/hybrid/mrc_qs_hybrid/ex00_wrap_data_objects/CMakeLists.txt +++ b/docs/quickstart/hybrid/mrc_qs_hybrid/ex00_wrap_data_objects/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-mrc_quickstart_add_pybind11_module( +mrc_add_pybind11_module( data MODULE_ROOT ${QUICKSTART_HYBRID_HOME} diff --git a/docs/quickstart/hybrid/mrc_qs_hybrid/ex01_wrap_nodes/CMakeLists.txt b/docs/quickstart/hybrid/mrc_qs_hybrid/ex01_wrap_nodes/CMakeLists.txt index 60ede0c59..329e222d7 100644 --- a/docs/quickstart/hybrid/mrc_qs_hybrid/ex01_wrap_nodes/CMakeLists.txt +++ b/docs/quickstart/hybrid/mrc_qs_hybrid/ex01_wrap_nodes/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -mrc_quickstart_add_pybind11_module( +mrc_add_pybind11_module( nodes MODULE_ROOT ${QUICKSTART_HYBRID_HOME} diff --git a/docs/quickstart/python/mrc_qs_python/_version.py b/docs/quickstart/python/mrc_qs_python/_version.py index 1ca6b055c..9d81e4e25 100644 --- a/docs/quickstart/python/mrc_qs_python/_version.py +++ b/docs/quickstart/python/mrc_qs_python/_version.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,8 +29,7 @@ import re import subprocess import sys -from typing import Callable -from typing import Dict +from typing import Callable, Dict def get_keywords(): diff --git a/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/README.md b/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/README.md index 705c79e15..95cb08d01 100644 --- a/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/README.md +++ b/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/README.md @@ -27,36 +27,33 @@ Lets look at a more complex example: value_count = 0 value_sum = 0 -def node_fn(src: mrc.Observable, dst: mrc.Subscriber): - def update_obj(x: MyCustomClass): - nonlocal value_count - nonlocal value_sum +def update_obj(x: MyCustomClass): + nonlocal value_count + nonlocal value_sum - # Alter the value property of the class - x.value = x.value * 2 + # Alter the value property of the class + x.value = x.value * 2 - # Update the sum values - value_count += 1 - value_sum += x.value + # Update the sum values + value_count += 1 + value_sum += x.value - return x + return x - def on_completed(): +def on_completed(): - # Prevent divide by 0. Just in case - if (value_count <= 0): - return + # Prevent divide by 0. 
Just in case + if (value_count <= 0): + return - return MyCustomClass(value_sum / value_count, "Mean") - - src.pipe( - ops.filter(lambda x: x.value % 2 == 0), - ops.map(update_obj), - ops.on_completed(on_completed) - ).subscribe(dst) + return MyCustomClass(value_sum / value_count, "Mean") # Make an intermediate node -node = seg.make_node_full("node", node_fn) +node = seg.make_node("node", + ops.filter(lambda x: x.value % 2 == 0), + ops.map(update_obj), + ops.on_completed(on_completed) +) ``` In this example, we are using 3 different operators: `filter`, `map`, and `on_completed`: @@ -66,7 +63,7 @@ In this example, we are using 3 different operators: `filter`, `map`, and `on_co - The `map` operator can transform the incoming value and return a new value - In our example, we are doubling the `value` property and recording the total count and total sum of this property - The `on_completed` function is only called once when there are no more messages to process. You can optionally return a value which will be passed on to the rest of the pipeline. - - In our example, we are calculating the average from the sum and count values and emitting a new obect with the value set to the mean + - In our example, we are calculating the average from the sum and count values and emitting a new object with the value set to the mean In combination, these operators perform a higher level functionality to modify the stream, record some information, and finally print an analysis of all emitted values. Let's see it in practice. diff --git a/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/run.py b/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/run.py index e181ad053..3fb1b30a9 100644 --- a/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/run.py +++ b/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/run.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,9 +15,10 @@ import dataclasses -import mrc from mrc.core import operators as ops +import mrc + @dataclasses.dataclass class MyCustomClass: diff --git a/docs/quickstart/python/versioneer.py b/docs/quickstart/python/versioneer.py index 5e21cd07d..350aa2069 100644 --- a/docs/quickstart/python/versioneer.py +++ b/docs/quickstart/python/versioneer.py @@ -286,8 +286,7 @@ import re import subprocess import sys -from typing import Callable -from typing import Dict +from typing import Callable, Dict class VersioneerConfig: diff --git a/external/utilities b/external/utilities index a5b9689e3..6e10e2c9a 160000 --- a/external/utilities +++ b/external/utilities @@ -1 +1 @@ -Subproject commit a5b9689e3a82fe5b49245b0a02c907ea70aed7b8 +Subproject commit 6e10e2c9a686041fdeb3d9e100874c6fa55f0856 diff --git a/mrc.code-workspace b/mrc.code-workspace index 30e2eec34..85b9e0856 100644 --- a/mrc.code-workspace +++ b/mrc.code-workspace @@ -4,12 +4,17 @@ // Extension identifier format: ${publisher}.${name}. Example: vscode.csharp // List of extensions which should be recommended for users of this workspace. "recommendations": [ + "eeyore.yapf", + "esbenp.prettier-vscode", "josetr.cmake-language-support-vscode", "llvm-vs-code-extensions.vscode-clangd", "matepek.vscode-catch2-test-adapter", + "ms-python.flake8", + "ms-python.isort", + "ms-python.pylint", "ms-vscode.cmake-tools", "stkb.rewrap", - "twxs.cmake" + "twxs.cmake", ], // List of extensions recommended by VS Code that should not be recommended for users of this workspace. 
"unwantedRecommendations": [ @@ -46,6 +51,10 @@ { "description": "Skip stdio-common files", "text": "-interpreter-exec console \"skip -gfi **/bits/*.h\"" + }, + { + "description": "Skip stdio-common files in Conda", + "text": "-interpreter-exec console \"skip -rfu ^std::.*\"" } // { // "description": "Stay on same thread when debugging", @@ -53,6 +62,10 @@ // } ], "stopAtEntry": false, + "symbolLoadInfo": { + "exceptionList": "libmrc*.so", + "loadAll": false + }, "type": "cppdbg" }, { @@ -187,13 +200,14 @@ "editor.semanticHighlighting.enabled": true, "editor.suggest.insertMode": "replace", "editor.tabSize": 4, - "editor.wordBasedSuggestions": false, + "editor.wordBasedSuggestions": "off", "editor.wordWrapColumn": 120 }, "[python]": { "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" }, + "editor.defaultFormatter": "eeyore.yapf", "editor.formatOnSave": true, "editor.tabSize": 4 }, @@ -202,7 +216,9 @@ "-DMRC_PYTHON_INPLACE_BUILD:BOOL=ON" // Allow inplace build for python. 
Use `pip install -e .` from the python folder to install ], "cmake.format.allowOptionalArgumentIndentation": true, - "editor.rulers": [120], + "editor.rulers": [ + 120 + ], "files.insertFinalNewline": true, "files.trimTrailingWhitespace": true, "files.watcherExclude": { @@ -212,27 +228,21 @@ "**/.hg/store/**": true, "**/node_modules/*/**": true }, + "flake8.args": [ + "--style=${workspaceFolder}/python/setup.cfg" + ], "isort.args": [ "--settings-file=${workspaceFolder}/python/setup.cfg" ], + "pylint.args": [ + "--rcfile=${workspaceFolder}/python/.pylintrc" + ], "python.analysis.extraPaths": [ "python" ], "python.autoComplete.extraPaths": [ "./python" ], - "python.formatting.provider": "yapf", - "python.formatting.yapfArgs": [ - "--style=${workspaceFolder}/python/setup.cfg" - ], - "python.linting.flake8Args": [ - "--config=${workspaceFolder}/python/setup.cfg" - ], - "python.linting.flake8Enabled": true, - "python.linting.pylintArgs": [ - "--rcfile=${workspaceFolder}/python/.pylintrc" - ], - "python.linting.pylintEnabled": true, "python.testing.cwd": "${workspaceFolder}/python", "python.testing.pytestArgs": [ "-s" @@ -265,6 +275,10 @@ { "description": "Skip stdio-common files", "text": "-interpreter-exec console \"skip -gfi **/bits/*.h\"" + }, + { + "description": "Skip stdio-common files everywhere", + "text": "-interpreter-exec console \"skip -rfu ^std::.*\"" } // { // "description": "Stay on same thread when debugging", @@ -288,6 +302,9 @@ } }, "testMate.cpp.log.logpanel": true, - "testMate.cpp.test.executables": "{build,Build,BUILD,out,Out,OUT}/**/*{test,Test,TEST}_*.x" + "testMate.cpp.test.executables": "{build,Build,BUILD,out,Out,OUT}/**/*{test,Test,TEST}_*.x", + "yapf.args": [ + "--style=${workspaceFolder}/python/setup.cfg" + ] } } diff --git a/protos/CMakeLists.txt b/protos/CMakeLists.txt index 93a538f88..e9cd0e325 100644 --- a/protos/CMakeLists.txt +++ b/protos/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -93,7 +93,7 @@ add_dependencies(${PROJECT_NAME}_style_checks mrc_protos-headers-target) install( TARGETS mrc_protos mrc_architect_protos - EXPORT ${PROJECT_NAME}-core-exports + EXPORT ${PROJECT_NAME}-exports PUBLIC_HEADER DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/protos" ) diff --git a/python/MANIFEST.in b/python/MANIFEST.in index 2a661c98b..9fb4f1bf5 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -1,3 +1,7 @@ include versioneer.py include mrc/_version.py -recursive-include mrc *.so py.typed *.pyi +recursive-include mrc py.typed *.pyi +recursive-include mrc/_pymrc/tests *.so +recursive-include mrc/benchmarking *.so +recursive-include mrc/core *.so +recursive-include mrc/tests *.so diff --git a/python/mrc/_pymrc/CMakeLists.txt b/python/mrc/_pymrc/CMakeLists.txt index b1aa4eb77..adfc03c21 100644 --- a/python/mrc/_pymrc/CMakeLists.txt +++ b/python/mrc/_pymrc/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -18,6 +18,7 @@ find_package(prometheus-cpp REQUIRED) # Keep all source files sorted!!! 
add_library(pymrc + src/coro.cpp src/executor.cpp src/logging.cpp src/module_registry.cpp @@ -36,6 +37,7 @@ add_library(pymrc src/utilities/acquire_gil.cpp src/utilities/deserializers.cpp src/utilities/function_wrappers.cpp + src/utilities/json_values.cpp src/utilities/object_cache.cpp src/utilities/object_wrappers.cpp src/utilities/serializers.cpp @@ -49,8 +51,9 @@ target_link_libraries(pymrc PUBLIC ${PROJECT_NAME}::libmrc ${Python_LIBRARIES} - prometheus-cpp::core pybind11::pybind11 + PRIVATE + prometheus-cpp::core ) target_include_directories(pymrc @@ -71,7 +74,7 @@ rapids_cmake_install_lib_dir(lib_dir) install( TARGETS pymrc DESTINATION ${lib_dir} - EXPORT ${PROJECT_NAME}-core-exports + EXPORT ${PROJECT_NAME}-exports COMPONENT Python ) diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp new file mode 100644 index 000000000..506182a0d --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp @@ -0,0 +1,364 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "pymrc/asyncio_scheduler.hpp" +#include "pymrc/edge_adapter.hpp" +#include "pymrc/node.hpp" +#include "pymrc/utilities/object_wrappers.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace mrc::pymrc { + +/** + * @brief A wrapper for executing a function as an async boost fiber, the result of which is a + * C++20 coroutine awaiter. + */ +template +class BoostFutureAwaitableOperation +{ + class Awaiter; + + public: + BoostFutureAwaitableOperation(std::function fn) : m_fn(std::move(fn)) {} + + /** + * @brief Calls the wrapped function as an asyncboost fiber and returns a C++20 coroutine awaiter. + */ + template + auto operator()(ArgsT&&... args) -> Awaiter + { + // Make a copy of m_fn here so we can call this operator again + return Awaiter(m_fn, std::forward(args)...); + } + + private: + class Awaiter + { + public: + using return_t = typename std::function::result_type; + + template + Awaiter(std::function fn, ArgsT&&... args) + { + m_future = boost::fibers::async(boost::fibers::launch::post, fn, std::forward(args)...); + } + + bool await_ready() noexcept + { + return false; + } + + void await_suspend(std::coroutine_handle<> continuation) noexcept + { + // Launch a new fiber that waits on the future and then resumes the coroutine + boost::fibers::async( + boost::fibers::launch::post, + [this](std::coroutine_handle<> continuation) { + // Wait on the future + m_future.wait(); + + // Resume the coroutine + continuation.resume(); + }, + std::move(continuation)); + } + + auto await_resume() + { + return m_future.get(); + } + + private: + boost::fibers::future m_future; + std::function)> m_inner_fn; + }; + + std::function m_fn; +}; + +/** + * @brief A MRC Sink which receives from a channel using an awaitable interface. 
+ */ +template +class AsyncSink : public mrc::node::WritableProvider, + public mrc::node::ReadableAcceptor, + public mrc::node::SinkChannelOwner, + public pymrc::AutoRegSinkAdapter, + public pymrc::AutoRegEgressPort +{ + protected: + AsyncSink() : + m_read_async([this](T& value) { + return this->get_readable_edge()->await_read(value); + }) + { + // Set the default channel + this->set_channel(std::make_unique>()); + } + + /** + * @brief Asynchronously reads a value from the sink's channel + */ + coroutines::Task read_async(T& value) + { + co_return co_await m_read_async(std::ref(value)); + } + + private: + BoostFutureAwaitableOperation m_read_async; +}; + +/** + * @brief A MRC Source which produces to a channel using an awaitable interface. + */ +template +class AsyncSource : public mrc::node::WritableAcceptor, + public mrc::node::ReadableProvider, + public mrc::node::SourceChannelOwner, + public pymrc::AutoRegSourceAdapter, + public pymrc::AutoRegIngressPort +{ + protected: + AsyncSource() : + m_write_async([this](T&& value) { + return this->get_writable_edge()->await_write(std::move(value)); + }) + { + // Set the default channel + this->set_channel(std::make_unique>()); + } + + /** + * @brief Asynchronously writes a value to the source's channel + */ + coroutines::Task write_async(T&& value) + { + co_return co_await m_write_async(std::move(value)); + } + + private: + BoostFutureAwaitableOperation m_write_async; +}; + +/** + * @brief A MRC Runnable base class which hosts it's own asyncio loop and exposes a flatmap hook + */ +template +class AsyncioRunnable : public AsyncSink, + public AsyncSource, + public mrc::runnable::RunnableWithContext<> +{ + using state_t = mrc::runnable::Runnable::State; + using task_buffer_t = mrc::coroutines::ClosableRingBuffer; + + public: + ~AsyncioRunnable() override = default; + + private: + /** + * @brief Runnable's entrypoint. 
+ */ + void run(mrc::runnable::Context& ctx) override; + + /** + * @brief Runnable's state control, for stopping from MRC. + */ + void on_state_update(const state_t& state) final; + + /** + * @brief The top-level coroutine which is run while the asyncio event loop is running. + */ + coroutines::Task<> main_task(std::shared_ptr scheduler); + + /** + * @brief The per-value coroutine run asynchronously alongside other calls. + */ + coroutines::Task<> process_one(InputT value, + std::shared_ptr on, + ExceptionCatcher& catcher); + + /** + * @brief Value's read from the sink's channel are fed to this function and yields from the + * resulting generator are written to the source's channel. + */ + virtual mrc::coroutines::AsyncGenerator on_data(InputT&& value, + std::shared_ptr on) = 0; + + std::stop_source m_stop_source; +}; + +template +void AsyncioRunnable::run(mrc::runnable::Context& ctx) +{ + std::exception_ptr exception; + + { + py::gil_scoped_acquire gil; + + auto asyncio = py::module_::import("asyncio"); + + auto loop = [](auto& asyncio) -> PyObjectHolder { + try + { + return asyncio.attr("get_running_loop")(); + } catch (...) + { + return py::none(); + } + }(asyncio); + + if (not loop.is_none()) + { + throw std::runtime_error("asyncio loop already running, but runnable is expected to create it."); + } + + // Need to create a loop + DVLOG(10) << "AsyncioRunnable::run() > Creating new event loop"; + + // Gets (or more likely, creates) an event loop and runs it forever until stop is called + loop = asyncio.attr("new_event_loop")(); + + // Set the event loop as the current event loop + asyncio.attr("set_event_loop")(loop); + + // TODO(MDD): Eventually we should get this from the context object. 
For now, just create it directly + auto scheduler = std::make_shared(loop); + + auto py_awaitable = coro::BoostFibersMainPyAwaitable(this->main_task(scheduler)); + + DVLOG(10) << "AsyncioRunnable::run() > Calling run_until_complete() on main_task()"; + + try + { + loop.attr("run_until_complete")(std::move(py_awaitable)); + } catch (...) + { + exception = std::current_exception(); + } + + loop.attr("close")(); + } + + // Sync all progress engines if there are more than one + ctx.barrier(); + + // Only drop the output edges if we are rank 0 + if (ctx.rank() == 0) + { + // Need to drop the output edges + mrc::node::SourceProperties::release_edge_connection(); + mrc::node::SinkProperties::release_edge_connection(); + } + + if (exception != nullptr) + { + std::rethrow_exception(exception); + } +} + +template +coroutines::Task<> AsyncioRunnable::main_task(std::shared_ptr scheduler) +{ + coroutines::TaskContainer outstanding_tasks(scheduler, 8); + + ExceptionCatcher catcher{}; + + while (not m_stop_source.stop_requested() and not catcher.has_exception()) + { + InputT data; + + auto read_status = co_await this->read_async(data); + + if (read_status != mrc::channel::Status::success) + { + break; + } + + outstanding_tasks.start(this->process_one(std::move(data), scheduler, catcher)); + } + + co_await outstanding_tasks.garbage_collect_and_yield_until_empty(); + + catcher.rethrow_next_exception(); +} + +template +coroutines::Task<> AsyncioRunnable::process_one(InputT value, + std::shared_ptr on, + ExceptionCatcher& catcher) +{ + co_await on->yield(); + + try + { + // Call the on_data function + auto on_data_gen = this->on_data(std::move(value), on); + + auto iter = co_await on_data_gen.begin(); + + while (iter != on_data_gen.end()) + { + // Weird bug, cant directly move the value into the async_write call + auto data = std::move(*iter); + + co_await this->write_async(std::move(data)); + + // Advance the iterator + co_await ++iter; + } + } catch (...) 
+ { + catcher.push_exception(std::current_exception()); + } +} + +template +void AsyncioRunnable::on_state_update(const state_t& state) +{ + switch (state) + { + case state_t::Stop: + // Do nothing, we wait for the upstream channel to return closed + // m_stop_source.request_stop(); + break; + + case state_t::Kill: + m_stop_source.request_stop(); + break; + + default: + break; + } +} + +} // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp new file mode 100644 index 000000000..47246cad7 --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp @@ -0,0 +1,111 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "pymrc/coro.hpp" +#include "pymrc/utilities/acquire_gil.hpp" +#include "pymrc/utilities/object_wrappers.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace py = pybind11; + +namespace mrc::pymrc { + +/** + * @brief A MRC Scheduler which allows resuming C++20 coroutines on an Asyncio event loop. 
+ */ +class AsyncioScheduler : public mrc::coroutines::Scheduler +{ + private: + class ContinueOnLoopOperation + { + public: + ContinueOnLoopOperation(PyObjectHolder loop) : m_loop(std::move(loop)) {} + + static bool await_ready() noexcept + { + return false; + } + + void await_suspend(std::coroutine_handle<> handle) noexcept + { + AsyncioScheduler::resume(m_loop, handle); + } + + static void await_resume() noexcept {} + + private: + PyObjectHolder m_loop; + }; + + static void resume(PyObjectHolder loop, std::coroutine_handle<> handle) noexcept + { + pybind11::gil_scoped_acquire acquire; + loop.attr("call_soon_threadsafe")(pybind11::cpp_function([handle]() { + pybind11::gil_scoped_release release; + handle.resume(); + })); + } + + public: + AsyncioScheduler(PyObjectHolder loop) : m_loop(std::move(loop)) {} + + /** + * @brief Resumes a coroutine on the scheduler's Asyncio event loop + */ + void resume(std::coroutine_handle<> handle) noexcept override + { + AsyncioScheduler::resume(m_loop, handle); + } + + /** + * @brief Suspends the current function and resumes it on the scheduler's Asyncio event loop + */ + [[nodiscard]] coroutines::Task<> yield() override + { + co_await ContinueOnLoopOperation(m_loop); + } + + [[nodiscard]] coroutines::Task<> yield_for(std::chrono::milliseconds amount) override + { + co_await coroutines::IoScheduler::get_instance()->yield_for(amount); + co_await ContinueOnLoopOperation(m_loop); + }; + + [[nodiscard]] coroutines::Task<> yield_until(mrc::coroutines::time_point_t time) override + { + co_await coroutines::IoScheduler::get_instance()->yield_until(time); + co_await ContinueOnLoopOperation(m_loop); + }; + + private: + mrc::pymrc::PyHolder m_loop; +}; + +} // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/include/pymrc/coro.hpp b/python/mrc/_pymrc/include/pymrc/coro.hpp new file mode 100644 index 000000000..ad8224a58 --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/coro.hpp @@ -0,0 +1,444 @@ +/* + * SPDX-FileCopyrightText: 
Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include // for operator<<, basic_ostringstream +#include // for runtime_error +#include // for string +#include + +// Dont directly include python headers +// IWYU pragma: no_include + +namespace mrc::pymrc::coro { + +class PYBIND11_EXPORT StopIteration : public pybind11::stop_iteration +{ + public: + StopIteration(pybind11::object&& result) : stop_iteration("--"), m_result(std::move(result)){}; + ~StopIteration() override; + + void set_error() const override + { + PyErr_SetObject(PyExc_StopIteration, this->m_result.ptr()); + } + + private: + pybind11::object m_result; +}; + +class PYBIND11_EXPORT CppToPyAwaitable : public std::enable_shared_from_this +{ + public: + CppToPyAwaitable() = default; + + template + CppToPyAwaitable(mrc::coroutines::Task&& task) + { + auto converter = [](mrc::coroutines::Task incoming_task) -> mrc::coroutines::Task { + DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL when resuming a C++ coroutine"; + + mrc::pymrc::PyHolder holder; + + if constexpr (std::is_same_v) + { + co_await incoming_task; + + // Need the GIL to make the return object + pybind11::gil_scoped_acquire gil; + + holder = pybind11::none(); 
+ } + else + { + auto result = co_await incoming_task; + + // Need the GIL to cast the return object + pybind11::gil_scoped_acquire gil; + + holder = pybind11::cast(std::move(result)); + } + + co_return holder; + }; + + m_task = converter(std::move(task)); + } + + CppToPyAwaitable(mrc::coroutines::Task&& task) : m_task(std::move(task)) {} + + std::shared_ptr iter() + { + return this->shared_from_this(); + } + + std::shared_ptr await() + { + return this->shared_from_this(); + } + + void next() + { + // Need to release the GIL before waiting + pybind11::gil_scoped_release nogil; + + // Run the tick function which will resume the coroutine + this->tick(); + + if (m_task.is_ready()) + { + pybind11::gil_scoped_acquire gil; + + // job done -> throw + auto exception = StopIteration(std::move(m_task.promise().result())); + + // Destroy the task now that we have the value + m_task.destroy(); + + throw exception; + } + } + + protected: + virtual void tick() + { + if (!m_has_resumed) + { + m_has_resumed = true; + + m_task.resume(); + } + } + + bool m_has_resumed{false}; + mrc::coroutines::Task m_task; +}; + +/** + * @brief Similar to CppToPyAwaitable but will yield to other fibers when waiting for the coroutine to finish. 
Use this + * once per loop at the main entry point for the asyncio loop + * + */ +class PYBIND11_EXPORT BoostFibersMainPyAwaitable : public CppToPyAwaitable +{ + public: + using CppToPyAwaitable::CppToPyAwaitable; + + protected: + void tick() override + { + // Call the base class and then see if any fibers need processing by calling yield + CppToPyAwaitable::tick(); + + bool has_fibers = boost::fibers::has_ready_fibers(); + + if (has_fibers) + { + // Yield to other fibers + boost::this_fiber::yield(); + } + } +}; + +class PYBIND11_EXPORT PyTaskToCppAwaitable +{ + public: + PyTaskToCppAwaitable() = default; + PyTaskToCppAwaitable(mrc::pymrc::PyObjectHolder&& task) : m_task(std::move(task)) + { + pybind11::gil_scoped_acquire acquire; + + auto asyncio = pybind11::module_::import("asyncio"); + + if (not asyncio.attr("isfuture")(m_task).cast()) + { + if (not asyncio.attr("iscoroutine")(m_task).cast()) + { + throw std::runtime_error(MRC_CONCAT_STR("PyTaskToCppAwaitable expected task or coroutine but got " + << pybind11::repr(m_task).cast())); + } + + m_task = asyncio.attr("create_task")(m_task); + } + } + + static bool await_ready() noexcept + { + // Always suspend + return false; + } + + void await_suspend(std::coroutine_handle<> caller) noexcept + { + pybind11::gil_scoped_acquire gil; + + auto done_callback = pybind11::cpp_function([this, caller](pybind11::object future) { + try + { + // Save the result value + m_result = future.attr("result")(); + } catch (pybind11::error_already_set) + { + m_exception_ptr = std::current_exception(); + } + + pybind11::gil_scoped_release nogil; + + // Resume the coroutine + caller.resume(); + }); + + m_task.attr("add_done_callback")(done_callback); + } + + mrc::pymrc::PyHolder await_resume() + { + if (m_exception_ptr) + { + std::rethrow_exception(m_exception_ptr); + } + + return std::move(m_result); + } + + private: + mrc::pymrc::PyObjectHolder m_task; + mrc::pymrc::PyHolder m_result; + std::exception_ptr m_exception_ptr; +}; + +// 
====== HELPER MACROS ====== + +#define MRC_PYBIND11_FAIL_ABSTRACT(cname, fnname) \ + pybind11::pybind11_fail(MRC_CONCAT_STR("Tried to call pure virtual function \"" << PYBIND11_STRINGIFY(cname) \ + << "::" << fnname << "\"")); + +// ====== OVERRIDE PURE TEMPLATE ====== +#define MRC_PYBIND11_OVERRIDE_PURE_TEMPLATE_NAME(ret_type, abstract_cname, cname, name, fn, ...) \ + do \ + { \ + PYBIND11_OVERRIDE_IMPL(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__); \ + if constexpr (std::is_same_v) \ + { \ + MRC_PYBIND11_FAIL_ABSTRACT(PYBIND11_TYPE(abstract_cname), name); \ + } \ + else \ + { \ + return cname::fn(__VA_ARGS__); \ + } \ + } while (false) + +#define MRC_PYBIND11_OVERRIDE_PURE_TEMPLATE(ret_type, abstract_cname, cname, fn, ...) \ + MRC_PYBIND11_OVERRIDE_PURE_TEMPLATE_NAME(PYBIND11_TYPE(ret_type), \ + PYBIND11_TYPE(abstract_cname), \ + PYBIND11_TYPE(cname), \ + #fn, \ + fn, \ + __VA_ARGS__) +// ====== OVERRIDE PURE TEMPLATE ====== + +// ====== OVERRIDE COROUTINE IMPL ====== +#define MRC_PYBIND11_OVERRIDE_CORO_IMPL(ret_type, cname, name, ...) \ + do \ + { \ + DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL when resuming a C++ coroutine"; \ + pybind11::gil_scoped_acquire gil; \ + pybind11::function override = pybind11::get_override(static_cast(this), name); \ + if (override) \ + { \ + auto o_coro = override(__VA_ARGS__); \ + auto asyncio_module = pybind11::module::import("asyncio"); \ + /* Return type must be a coroutine to allow calling asyncio.create_task() */ \ + if (!asyncio_module.attr("iscoroutine")(o_coro).cast()) \ + { \ + pybind11::pybind11_fail(MRC_CONCAT_STR("Return value from overriden async function " \ + << PYBIND11_STRINGIFY(cname) << "::" << name \ + << " did not return a coroutine. 
Returned: " \ + << pybind11::str(o_coro).cast())); \ + } \ + auto o_task = asyncio_module.attr("create_task")(o_coro); \ + mrc::pymrc::PyHolder o_result; \ + { \ + pybind11::gil_scoped_release nogil; \ + o_result = co_await mrc::pymrc::coro::PyTaskToCppAwaitable(std::move(o_task)); \ + DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL after returning from co_await"; \ + } \ + if (pybind11::detail::cast_is_temporary_value_reference::value) \ + { \ + static pybind11::detail::override_caster_t caster; \ + co_return pybind11::detail::cast_ref(std::move(o_result), caster); \ + } \ + co_return pybind11::detail::cast_safe(std::move(o_result)); \ + } \ + } while (false) +// ====== OVERRIDE COROUTINE IMPL====== + +// ====== OVERRIDE COROUTINE ====== +#define MRC_PYBIND11_OVERRIDE_CORO_NAME(ret_type, cname, name, fn, ...) \ + do \ + { \ + MRC_PYBIND11_OVERRIDE_CORO_IMPL(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__); \ + return cname::fn(__VA_ARGS__); \ + } while (false) + +#define MRC_PYBIND11_OVERRIDE_CORO(ret_type, cname, fn, ...) \ + MRC_PYBIND11_OVERRIDE_CORO_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__) +// ====== OVERRIDE COROUTINE ====== + +// ====== OVERRIDE COROUTINE PURE====== +#define MRC_PYBIND11_OVERRIDE_CORO_PURE_NAME(ret_type, cname, name, fn, ...) \ + do \ + { \ + MRC_PYBIND11_OVERRIDE_CORO_IMPL(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__); \ + MRC_PYBIND11_FAIL_ABSTRACT(PYBIND11_TYPE(cname), name); \ + } while (false) + +#define MRC_PYBIND11_OVERRIDE_CORO_PURE(ret_type, cname, fn, ...) \ + MRC_PYBIND11_OVERRIDE_CORO_PURE_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__) +// ====== OVERRIDE COROUTINE PURE====== + +// ====== OVERRIDE COROUTINE PURE TEMPLATE====== +#define MRC_PYBIND11_OVERRIDE_CORO_PURE_TEMPLATE_NAME(ret_type, abstract_cname, cname, name, fn, ...) 
\ + do \ + { \ + MRC_PYBIND11_OVERRIDE_CORO_IMPL(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__); \ + if constexpr (std::is_same_v) \ + { \ + MRC_PYBIND11_FAIL_ABSTRACT(PYBIND11_TYPE(abstract_cname), name); \ + } \ + else \ + { \ + co_return co_await cname::fn(__VA_ARGS__); \ + } \ + } while (false) + +#define MRC_PYBIND11_OVERRIDE_CORO_PURE_TEMPLATE(ret_type, abstract_cname, cname, fn, ...) \ + MRC_PYBIND11_OVERRIDE_CORO_PURE_TEMPLATE_NAME(PYBIND11_TYPE(ret_type), \ + PYBIND11_TYPE(abstract_cname), \ + PYBIND11_TYPE(cname), \ + #fn, \ + fn, \ + __VA_ARGS__) +// ====== OVERRIDE COROUTINE PURE TEMPLATE====== + +} // namespace mrc::pymrc::coro + +// NOLINTNEXTLINE(modernize-concat-nested-namespaces) +namespace PYBIND11_NAMESPACE { +namespace detail { + +/** + * @brief Provides a type caster for converting a C++ coroutine to a python awaitable. Include this file in any pybind11 + * module to automatically convert the types. Allows for converting arguments and return values. + * + * @tparam ReturnT The return type of the coroutine + */ +template +struct type_caster> +{ + public: + /** + * This macro establishes the name 'inty' in + * function signatures and declares a local variable + * 'value' of type inty + */ + PYBIND11_TYPE_CASTER(mrc::coroutines::Task, _("typing.Awaitable[") + make_caster::name + _("]")); + + /** + * Conversion part 1 (Python->C++): convert a PyObject into a inty + * instance or return false upon failure. The second argument + * indicates whether implicit conversions should be applied. 
+ */ + bool load(handle src, bool convert) + { + if (!src || src.is_none()) + { + return false; + } + + if (!PyCoro_CheckExact(src.ptr())) + { + return false; + } + + auto cpp_coro = [](mrc::pymrc::PyHolder py_task) -> mrc::coroutines::Task { + DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL when resuming a C++ coroutine"; + + // Always assume we are resuming without the GIL + pybind11::gil_scoped_acquire gil; + + auto asyncio_task = pybind11::module_::import("asyncio").attr("create_task")(py_task); + + mrc::pymrc::PyHolder py_result; + { + // Release the GIL before awaiting + pybind11::gil_scoped_release nogil; + + py_result = co_await mrc::pymrc::coro::PyTaskToCppAwaitable(std::move(asyncio_task)); + } + + // Now cast back to the C++ type + if (pybind11::detail::cast_is_temporary_value_reference::value) + { + static pybind11::detail::override_caster_t caster; + co_return pybind11::detail::cast_ref(std::move(py_result), caster); + } + co_return pybind11::detail::cast_safe(std::move(py_result)); + }; + + value = cpp_coro(pybind11::reinterpret_borrow(std::move(src))); + + return true; + } + + /** + * Conversion part 2 (C++ -> Python): convert an inty instance into + * a Python object. The second and third arguments are used to + * indicate the return value policy and parent object (for + * ``return_value_policy::reference_internal``) and are generally + * ignored by implicit casters. 
+ */ + static handle cast(mrc::coroutines::Task src, return_value_policy policy, handle parent) + { + // Wrap the object in a CppToPyAwaitable + std::shared_ptr awaitable = + std::make_shared(std::move(src)); + + // Convert the object to a python object + auto py_awaitable = pybind11::cast(std::move(awaitable)); + + return py_awaitable.release(); + } +}; + +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/python/mrc/_pymrc/include/pymrc/segment.hpp b/python/mrc/_pymrc/include/pymrc/segment.hpp index 2da23cace..94bce476e 100644 --- a/python/mrc/_pymrc/include/pymrc/segment.hpp +++ b/python/mrc/_pymrc/include/pymrc/segment.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -143,11 +143,6 @@ class BuilderProxy const std::string& name, pybind11::function gen_factory); - static std::shared_ptr make_source( - mrc::segment::IBuilder& self, - const std::string& name, - const std::function& f); - static std::shared_ptr make_source_component(mrc::segment::IBuilder& self, const std::string& name, pybind11::iterator source_iterator); diff --git a/python/mrc/_pymrc/include/pymrc/subscriber.hpp b/python/mrc/_pymrc/include/pymrc/subscriber.hpp index 5a079906f..6cc793dd5 100644 --- a/python/mrc/_pymrc/include/pymrc/subscriber.hpp +++ b/python/mrc/_pymrc/include/pymrc/subscriber.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -47,6 +47,12 @@ class SubscriberProxy static bool is_subscribed(PyObjectSubscriber* self); }; +class SubscriptionProxy +{ + public: + static bool is_subscribed(PySubscription* self); +}; + class ObservableProxy { public: diff --git a/python/mrc/_pymrc/include/pymrc/types.hpp b/python/mrc/_pymrc/include/pymrc/types.hpp index fcaa9942b..5446ec28a 100644 --- a/python/mrc/_pymrc/include/pymrc/types.hpp +++ b/python/mrc/_pymrc/include/pymrc/types.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,9 +21,12 @@ #include "mrc/segment/object.hpp" +#include #include -#include +#include // for function +#include +#include namespace mrc::pymrc { @@ -37,4 +40,16 @@ using PyNode = mrc::segment::ObjectProperties; using PyObjectOperateFn = std::function; // NOLINTEND(readability-identifier-naming) +using python_map_t = std::map; + +/** + * @brief Unserializable handler function type, invoked by `cast_from_pyobject` when an object cannot be serialized to + * JSON. Implementations should return a valid json object, or throw an exception if the object cannot be serialized. + * @param source : pybind11 object + * @param path : string json path to object + * @return nlohmann::json. 
+ */ +using unserializable_handler_fn_t = + std::function; + } // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp b/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp index f6f5c3c30..83e243d63 100644 --- a/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp +++ b/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp @@ -27,7 +27,6 @@ #include #include -#include #include #include #include diff --git a/python/mrc/_pymrc/include/pymrc/utilities/json_values.hpp b/python/mrc/_pymrc/include/pymrc/utilities/json_values.hpp new file mode 100644 index 000000000..8c3db1aab --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/utilities/json_values.hpp @@ -0,0 +1,214 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "pymrc/types.hpp" // for python_map_t & unserializable_handler_fn_t + +#include +#include // for PYBIND11_EXPORT, pybind11::object, type_caster + +#include // for size_t +#include // for map +#include // for string +#include // for move +// IWYU pragma: no_include +// IWYU pragma: no_include + +namespace mrc::pymrc { + +#pragma GCC visibility push(default) + +/** + * @brief Immutable container for holding Python values as JSON objects if possible, and as pybind11::object otherwise. 
+ * The container can be copied and moved, but the underlying JSON object is immutable. + **/ +class PYBIND11_EXPORT JSONValues +{ + public: + JSONValues(); + JSONValues(pybind11::object values); + JSONValues(nlohmann::json values); + + JSONValues(const JSONValues& other) = default; + JSONValues(JSONValues&& other) = default; + ~JSONValues() = default; + + JSONValues& operator=(const JSONValues& other) = default; + JSONValues& operator=(JSONValues&& other) = default; + + /** + * @brief Sets a value in the JSON object at the specified path with the provided Python object. If `value` is + * serializable as JSON it will be stored as JSON, otherwise it will be stored as-is. + * @param path The path in the JSON object where the value should be set. + * @param value The Python object to set. + * @throws std::runtime_error If the path is invalid. + * @return A new JSONValues object with the updated value. + */ + JSONValues set_value(const std::string& path, const pybind11::object& value) const; + + /** + * @brief Sets a value in the JSON object at the specified path with the provided JSON object. + * @param path The path in the JSON object where the value should be set. + * @param value The JSON object to set. + * @throws std::runtime_error If the path is invalid. + * @return A new JSONValues object with the updated value. + */ + JSONValues set_value(const std::string& path, nlohmann::json value) const; + + /** + * @brief Sets a value in the JSON object at the specified path with the provided JSONValues object. + * @param path The path in the JSON object where the value should be set. + * @param value The JSONValues object to set. + * @throws std::runtime_error If the path is invalid. + * @return A new JSONValues object with the updated value. + */ + JSONValues set_value(const std::string& path, const JSONValues& value) const; + + /** + * @brief Returns the number of unserializable Python objects. + * @return The number of unserializable Python objects. 
+ */ + std::size_t num_unserializable() const; + + /** + * @brief Checks if there are any unserializable Python objects. + * @return True if there are unserializable Python objects, false otherwise. + */ + bool has_unserializable() const; + + /** + * @brief Convert to a Python object. + * @return The Python object representation of the values. + */ + pybind11::object to_python() const; + + /** + * @brief Returns a constant reference to the underlying JSON object. Any unserializable Python objects, will be + * represented in the JSON object with a string place-holder with the value `"**pymrc_placeholder"`. + * @return A constant reference to the JSON object. + */ + nlohmann::json::const_reference view_json() const; + + /** + * @brief Converts the JSON object to a JSON object. If any unserializable Python objects are present, the + * `unserializable_handler_fn` will be invoked to handle the object. + * @param unserializable_handler_fn Optional function to handle unserializable objects. + * @return The JSON string representation of the JSON object. + */ + nlohmann::json to_json(unserializable_handler_fn_t unserializable_handler_fn) const; + + /** + * @brief Converts a Python object to a JSON string. Convienence function that matches the + * `unserializable_handler_fn_t` signature. Convienent for use with `to_json` and `get_json`. + * @param obj The Python object to convert. + * @param path The path in the JSON object where the value should be set. + * @return The JSON string representation of the Python object. + */ + static nlohmann::json stringify(const pybind11::object& obj, const std::string& path); + + /** + * @brief Returns the object at the specified path as a Python object. + * @param path Path to the specified object. + * @throws std::runtime_error If the path does not exist or is not a valid path. + * @return Python representation of the object at the specified path. 
+ */ + pybind11::object get_python(const std::string& path) const; + + /** + * @brief Returns the object at the specified path. If the object is an unserializable Python object the + * `unserializable_handler_fn` will be invoked. + * @param path Path to the specified object. + * @param unserializable_handler_fn Function to handle unserializable objects. + * @throws std::runtime_error If the path does not exist or is not a valid path. + * @return The JSON object at the specified path. + */ + nlohmann::json get_json(const std::string& path, unserializable_handler_fn_t unserializable_handler_fn) const; + + /** + * @brief Return a new JSONValues object with the value at the specified path. + * @param path Path to the specified object. + * @throws std::runtime_error If the path does not exist or is not a valid path. + * @return The value at the specified path. + */ + JSONValues operator[](const std::string& path) const; + + private: + JSONValues(nlohmann::json&& values, python_map_t&& py_objects); + nlohmann::json unserializable_handler(const pybind11::object& obj, const std::string& path); + + nlohmann::json m_serialized_values; + python_map_t m_py_objects; +}; + +#pragma GCC visibility pop +} // namespace mrc::pymrc + +/****** Pybind11 caster ******************/ + +// NOLINTNEXTLINE(modernize-concat-nested-namespaces) +namespace PYBIND11_NAMESPACE { +namespace detail { + +template <> +struct type_caster +{ + public: + /** + * This macro establishes a local variable 'value' of type JSONValues + */ + PYBIND11_TYPE_CASTER(mrc::pymrc::JSONValues, _("object")); + + /** + * Conversion part 1 (Python->C++): convert a PyObject into JSONValues + * instance or return false upon failure. The second argument + * indicates whether implicit conversions should be applied. 
+ */ + bool load(handle src, bool convert) + { + if (!src) + { + return false; + } + + if (src.is_none()) + { + value = mrc::pymrc::JSONValues(); + } + else + { + value = std::move(mrc::pymrc::JSONValues(pybind11::reinterpret_borrow(src))); + } + + return true; + } + + /** + * Conversion part 2 (C++ -> Python): convert a JSONValues instance into + * a Python object. The second and third arguments are used to + * indicate the return value policy and parent object (for + * ``return_value_policy::reference_internal``) and are generally + * ignored by implicit casters. + */ + static handle cast(mrc::pymrc::JSONValues src, return_value_policy policy, handle parent) + { + return src.to_python().release(); + } +}; + +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/python/mrc/_pymrc/include/pymrc/utilities/object_cache.hpp b/python/mrc/_pymrc/include/pymrc/utilities/object_cache.hpp index 2721eb5db..68c106064 100644 --- a/python/mrc/_pymrc/include/pymrc/utilities/object_cache.hpp +++ b/python/mrc/_pymrc/include/pymrc/utilities/object_cache.hpp @@ -17,6 +17,8 @@ #pragma once +#include "pymrc/types.hpp" + #include #include @@ -95,7 +97,7 @@ class __attribute__((visibility("default"))) PythonObjectCache */ void atexit_callback(); - std::map m_object_cache; + python_map_t m_object_cache; }; #pragma GCC visibility pop diff --git a/python/mrc/_pymrc/include/pymrc/utils.hpp b/python/mrc/_pymrc/include/pymrc/utils.hpp index f80838c3d..714605e6a 100644 --- a/python/mrc/_pymrc/include/pymrc/utils.hpp +++ b/python/mrc/_pymrc/include/pymrc/utils.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,6 +17,8 @@ #pragma once +#include "pymrc/types.hpp" + #include #include #include @@ -31,8 +33,25 @@ namespace mrc::pymrc { #pragma GCC visibility push(default) pybind11::object cast_from_json(const nlohmann::json& source); + +/** + * @brief Convert a pybind11 object to a JSON object. If the object cannot be serialized, a pybind11::type_error + * exception will be thrown. + * @param source : pybind11 object + * @return nlohmann::json. + */ +nlohmann::json cast_from_pyobject(const pybind11::object& source); + +/** + * @brief Convert a pybind11 object to a JSON object. If the object cannot be serialized, the unserializable_handler_fn + * will be invoked to handle the object. + * @param source : pybind11 object + * @param unserializable_handler_fn : unserializable_handler_fn_t + * @return nlohmann::json. + */ +nlohmann::json cast_from_pyobject(const pybind11::object& source, + unserializable_handler_fn_t unserializable_handler_fn); + void import_module_object(pybind11::module_&, const std::string&, const std::string&); void import_module_object(pybind11::module_& dest, const pybind11::module_& mod); @@ -54,6 +73,13 @@ void from_import_as(pybind11::module_& dest, const std::string& from, const std: */ const std::type_info* cpptype_info_from_object(pybind11::object& obj); + +/** + * @brief Given a pybind11 object, return the Python type name, essentially the same as `str(type(obj))` + * @param obj : pybind11 object + * @return std::string. 
+ */ +std::string get_py_type_name(const pybind11::object& obj); + void show_deprecation_warning(const std::string& deprecation_message, ssize_t stack_level = 1); #pragma GCC visibility pop diff --git a/python/mrc/_pymrc/src/coro.cpp b/python/mrc/_pymrc/src/coro.cpp new file mode 100644 index 000000000..8bb57cb84 --- /dev/null +++ b/python/mrc/_pymrc/src/coro.cpp @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pymrc/coro.hpp" + +namespace mrc::pymrc::coro { + +namespace py = pybind11; + +StopIteration::~StopIteration() = default; + +} // namespace mrc::pymrc::coro diff --git a/python/mrc/_pymrc/src/executor.cpp b/python/mrc/_pymrc/src/executor.cpp index a62e2c1e7..8e1ad5c67 100644 --- a/python/mrc/_pymrc/src/executor.cpp +++ b/python/mrc/_pymrc/src/executor.cpp @@ -25,7 +25,6 @@ #include "mrc/types.hpp" #include -#include #include #include #include diff --git a/python/mrc/_pymrc/src/logging.cpp b/python/mrc/_pymrc/src/logging.cpp index 1150340e8..73455caa5 100644 --- a/python/mrc/_pymrc/src/logging.cpp +++ b/python/mrc/_pymrc/src/logging.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -88,7 +88,8 @@ void log(const std::string& msg, int py_level, const std::string& filename, int LOG(WARNING) << "Log called prior to calling init_logging, initialized with defaults"; } - google::LogMessage(filename.c_str(), line, static_cast(py_level_to_mrc(py_level))).stream() << msg; + google::LogMessage(filename.c_str(), line, static_cast(py_level_to_mrc(py_level))).stream() + << msg; } } // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/src/module_registry.cpp b/python/mrc/_pymrc/src/module_registry.cpp index 424eb2b68..bedcf7ebf 100644 --- a/python/mrc/_pymrc/src/module_registry.cpp +++ b/python/mrc/_pymrc/src/module_registry.cpp @@ -28,7 +28,6 @@ #include #include -#include #include #include #include diff --git a/python/mrc/_pymrc/src/module_wrappers/pickle.cpp b/python/mrc/_pymrc/src/module_wrappers/pickle.cpp index fd6e99290..378fa83e2 100644 --- a/python/mrc/_pymrc/src/module_wrappers/pickle.cpp +++ b/python/mrc/_pymrc/src/module_wrappers/pickle.cpp @@ -24,7 +24,6 @@ #include #include -#include #include #include diff --git a/python/mrc/_pymrc/src/module_wrappers/shared_memory.cpp b/python/mrc/_pymrc/src/module_wrappers/shared_memory.cpp index 9a4106f76..7eac9864f 100644 --- a/python/mrc/_pymrc/src/module_wrappers/shared_memory.cpp +++ b/python/mrc/_pymrc/src/module_wrappers/shared_memory.cpp @@ -20,10 +20,9 @@ #include "pymrc/utilities/object_cache.hpp" #include -#include +#include // IWYU pragma: keep #include -#include #include #include #include diff --git a/python/mrc/_pymrc/src/segment.cpp b/python/mrc/_pymrc/src/segment.cpp index 4e60e63e4..ec78dc927 100644 --- a/python/mrc/_pymrc/src/segment.cpp +++ b/python/mrc/_pymrc/src/segment.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,12 +28,9 @@ #include "mrc/channel/status.hpp" #include "mrc/edge/edge_builder.hpp" #include "mrc/node/port_registry.hpp" -#include "mrc/node/rx_sink_base.hpp" -#include "mrc/node/rx_source_base.hpp" #include "mrc/runnable/context.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" -#include "mrc/types.hpp" #include #include @@ -44,15 +41,14 @@ #include #include #include -#include #include #include +#include #include #include #include #include #include -#include // IWYU thinks we need array for py::print // IWYU pragma: no_include @@ -238,6 +234,11 @@ std::shared_ptr build_source(mrc::segment::IBuil { subscriber.on_next(std::move(next_val)); } + else + { + DVLOG(10) << ctx.info() << " Source unsubscribed. Stopping"; + break; + } } } catch (const std::exception& e) @@ -257,6 +258,61 @@ std::shared_ptr build_source(mrc::segment::IBuil return self.construct_object>(name, wrapper); } +class SubscriberFuncWrapper : public mrc::pymrc::PythonSource +{ + public: + using base_t = mrc::pymrc::PythonSource; + using typename base_t::source_type_t; + using typename base_t::subscriber_fn_t; + + SubscriberFuncWrapper(py::function gen_factory) : PythonSource(build()), m_gen_factory{std::move(gen_factory)} {} + + private: + subscriber_fn_t build() + { + return [this](rxcpp::subscriber subscriber) { + auto& ctx = runnable::Context::get_runtime_context(); + + try + { + DVLOG(10) << ctx.info() << " Starting source"; + py::gil_scoped_acquire gil; + PySubscription subscription = subscriber.get_subscription(); + py::object py_sub = py::cast(subscription); + auto py_iter = m_gen_factory.operator()(std::move(py_sub)); + PyIteratorWrapper iter_wrapper{std::move(py_iter)}; + + for (auto next_val : iter_wrapper) + { + // Only send if its subscribed. 
Very important to ensure the object has been moved! + if (subscriber.is_subscribed()) + { + py::gil_scoped_release no_gil; + subscriber.on_next(std::move(next_val)); + } + else + { + DVLOG(10) << ctx.info() << " Source unsubscribed. Stopping"; + break; + } + } + + } catch (const std::exception& e) + { + LOG(ERROR) << ctx.info() << "Error occurred in source. Error msg: " << e.what(); + + subscriber.on_error(std::current_exception()); + return; + } + subscriber.on_completed(); + + DVLOG(10) << ctx.info() << " Source complete"; + }; + } + + PyFuncWrapper m_gen_factory{}; +}; + std::shared_ptr build_source_component(mrc::segment::IBuilder& self, const std::string& name, PyIteratorWrapper iter_wrapper) @@ -305,6 +361,32 @@ std::shared_ptr BuilderProxy::make_source(mrc::s const std::string& name, py::function gen_factory) { + // Determine if the gen_factory is expecting to receive a subscription object + auto inspect_mod = py::module::import("inspect"); + auto signature = inspect_mod.attr("signature")(gen_factory); + auto params = signature.attr("parameters"); + auto num_params = py::len(params); + bool expects_subscription = false; + + if (num_params > 0) + { + // We know there is at least one parameter. Check if the first parameter is a subscription object + // Note, when we receive a function that has been bound with `functools.partial(fn, arg1=some_value)`, the + // parameter is still visible in the signature of the partial object. 
+ auto mrc_mod = py::module::import("mrc"); + auto param_values = params.attr("values")(); + auto first_param = py::iter(param_values); + auto type_hint = py::object((*first_param).attr("annotation")); + expects_subscription = (type_hint.is(mrc_mod.attr("Subscription")) || + type_hint.equal(py::str("mrc.Subscription")) || + type_hint.equal(py::str("Subscription"))); + } + + if (expects_subscription) + { + return self.construct_object(name, std::move(gen_factory)); + } + return build_source(self, name, PyIteratorWrapper(std::move(gen_factory))); } diff --git a/python/mrc/_pymrc/src/subscriber.cpp b/python/mrc/_pymrc/src/subscriber.cpp index 35f795175..6d94efff9 100644 --- a/python/mrc/_pymrc/src/subscriber.cpp +++ b/python/mrc/_pymrc/src/subscriber.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,7 +28,6 @@ #include #include -#include #include #include #include @@ -116,6 +115,12 @@ bool SubscriberProxy::is_subscribed(PyObjectSubscriber* self) return self->is_subscribed(); } +bool SubscriptionProxy::is_subscribed(PySubscription* self) +{ + // No GIL here + return self->is_subscribed(); +} + PySubscription ObservableProxy::subscribe(PyObjectObservable* self, PyObjectObserver& observer) { // Call the internal subscribe function @@ -128,12 +133,6 @@ PySubscription ObservableProxy::subscribe(PyObjectObservable* self, PyObjectSubs return self->subscribe(subscriber); } -template -PyObjectObservable pipe_ops(const PyObjectObservable* self, OpsT&&... ops) -{ - return (*self | ... 
| ops); -} - PyObjectObservable ObservableProxy::pipe(const PyObjectObservable* self, py::args args) { std::vector operators; @@ -150,66 +149,19 @@ PyObjectObservable ObservableProxy::pipe(const PyObjectObservable* self, py::arg operators.emplace_back(op.get_operate_fn()); } - switch (operators.size()) + if (operators.empty()) + { + throw std::runtime_error("pipe() must be given at least one argument"); + } + + auto result = *self | operators[0]; + + for (auto i = 1; i < operators.size(); i++) { - case 1: - return pipe_ops(self, operators[0]); - case 2: - return pipe_ops(self, operators[0], operators[1]); - case 3: - return pipe_ops(self, operators[0], operators[1], operators[2]); - case 4: - return pipe_ops(self, operators[0], operators[1], operators[2], operators[3]); - case 5: - return pipe_ops(self, operators[0], operators[1], operators[2], operators[3], operators[4]); - case 6: - return pipe_ops(self, operators[0], operators[1], operators[2], operators[3], operators[4], operators[5]); - case 7: - return pipe_ops(self, - operators[0], - operators[1], - operators[2], - operators[3], - operators[4], - operators[5], - operators[6]); - case 8: - return pipe_ops(self, - operators[0], - operators[1], - operators[2], - operators[3], - operators[4], - operators[5], - operators[6], - operators[7]); - case 9: - return pipe_ops(self, - operators[0], - operators[1], - operators[2], - operators[3], - operators[4], - operators[5], - operators[6], - operators[7], - operators[8]); - case 10: - return pipe_ops(self, - operators[0], - operators[1], - operators[2], - operators[3], - operators[4], - operators[5], - operators[6], - operators[7], - operators[8], - operators[9]); - default: - // Not supported error - throw std::runtime_error("pipe() only supports up 10 arguments. 
Please use another pipe() to use more"); + result = result | operators[i]; } + + return result; } } // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/src/utilities/json_values.cpp b/python/mrc/_pymrc/src/utilities/json_values.cpp new file mode 100644 index 000000000..ebc8061f5 --- /dev/null +++ b/python/mrc/_pymrc/src/utilities/json_values.cpp @@ -0,0 +1,299 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pymrc/utilities/json_values.hpp" + +#include "pymrc/utilities/acquire_gil.hpp" +#include "pymrc/utils.hpp" + +#include "mrc/utils/string_utils.hpp" // for MRC_CONCAT_STR, split_string_to_array + +#include +#include + +#include // for function +#include // for next +#include // for map +#include // for operator<< & stringstream +#include // for runtime_error +#include // for move +#include // for vector + +namespace py = pybind11; +using namespace std::string_literals; + +namespace { + +std::vector split_path(const std::string& path) +{ + return mrc::split_string_to_vector(path, "/"s); +} + +struct PyFoundObject +{ + py::object obj; + py::object index = py::none(); +}; + +PyFoundObject find_object_at_path(py::object& obj, + std::vector::const_iterator path, + std::vector::const_iterator path_end) +{ + // Terminal case + const auto& path_str = *path; + if (path_str.empty()) + { + return PyFoundObject(obj); + } + + // Nested object, since obj is a de-serialized python object the only valid container types will be dict and + // list. There are one of two possibilities here: + // 1. The next_path is terminal and we should assign value to the container + // 2. 
The next_path is not terminal and we should recurse into the container + auto next_path = std::next(path); + + if (py::isinstance(obj) || py::isinstance(obj)) + { + py::object index; + if (py::isinstance(obj)) + { + index = py::cast(path_str); + } + else + { + index = py::cast(std::stoul(path_str)); + } + + if (next_path == path_end) + { + return PyFoundObject{obj, index}; + } + + py::object next_obj = obj[index]; + return find_object_at_path(next_obj, next_path, path_end); + } + + throw std::runtime_error("Invalid path"); +} + +PyFoundObject find_object_at_path(py::object& obj, const std::string& path) +{ + auto path_parts = split_path(path); + + // Since our paths always begin with a '/', the first element will always be empty in the case where path="/" + // path_parts will be {"", ""} and we can skip the first element + auto itr = path_parts.cbegin(); + return find_object_at_path(obj, std::next(itr), path_parts.cend()); +} + +void patch_object(py::object& obj, const std::string& path, const py::object& value) +{ + if (path == "/") + { + // Special case for the root object since find_object_at_path will return a copy not a reference we need to + // perform the assignment here + obj = value; + } + else + { + auto found = find_object_at_path(obj, path); + DCHECK(!found.index.is_none()); + found.obj[found.index] = value; + } +} + +std::string validate_path(const std::string& path) +{ + if (path.empty() || path[0] != '/') + { + return "/" + path; + } + + return path; +} +} // namespace + +namespace mrc::pymrc { +JSONValues::JSONValues() : JSONValues(nlohmann::json()) {} + +JSONValues::JSONValues(py::object values) +{ + AcquireGIL gil; + m_serialized_values = cast_from_pyobject(values, [this](const py::object& source, const std::string& path) { + return this->unserializable_handler(source, path); + }); +} + +JSONValues::JSONValues(nlohmann::json values) : m_serialized_values(std::move(values)) {} + +JSONValues::JSONValues(nlohmann::json&& values, python_map_t&& 
py_objects) : + m_serialized_values(std::move(values)), + m_py_objects(std::move(py_objects)) +{} + +std::size_t JSONValues::num_unserializable() const +{ + return m_py_objects.size(); +} + +bool JSONValues::has_unserializable() const +{ + return !m_py_objects.empty(); +} + +py::object JSONValues::to_python() const +{ + AcquireGIL gil; + py::object results = cast_from_json(m_serialized_values); + for (const auto& [path, obj] : m_py_objects) + { + DCHECK(path[0] == '/'); + DVLOG(10) << "Restoring object at path: " << path; + patch_object(results, path, obj); + } + + return results; +} + +nlohmann::json::const_reference JSONValues::view_json() const +{ + return m_serialized_values; +} + +nlohmann::json JSONValues::to_json(unserializable_handler_fn_t unserializable_handler_fn) const +{ + // start with a copy + nlohmann::json json_doc = m_serialized_values; + nlohmann::json patches = nlohmann::json::array(); + for (const auto& [path, obj] : m_py_objects) + { + nlohmann::json patch{{"op", "replace"}, {"path", path}, {"value", unserializable_handler_fn(obj, path)}}; + patches.emplace_back(std::move(patch)); + } + + if (!patches.empty()) + { + json_doc.patch_inplace(patches); + } + + return json_doc; +} + +JSONValues JSONValues::operator[](const std::string& path) const +{ + auto validated_path = validate_path(path); + + if (validated_path == "/") + { + return *this; // Return a copy of the object + } + + nlohmann::json::json_pointer node_json_ptr(validated_path); + if (!m_serialized_values.contains(node_json_ptr)) + { + throw std::runtime_error(MRC_CONCAT_STR("Path: '" << path << "' not found in json")); + } + + // take a copy of the sub-object + nlohmann::json value = m_serialized_values[node_json_ptr]; + python_map_t py_objects; + for (const auto& [py_path, obj] : m_py_objects) + { + if (py_path.find(validated_path) == 0) + { + py_objects[py_path] = obj; + } + } + + return {std::move(value), std::move(py_objects)}; +} + +pybind11::object JSONValues::get_python(const 
std::string& path) const +{ + return (*this)[path].to_python(); +} + +nlohmann::json JSONValues::get_json(const std::string& path, + unserializable_handler_fn_t unserializable_handler_fn) const +{ + return (*this)[path].to_json(unserializable_handler_fn); +} + +nlohmann::json JSONValues::stringify(const pybind11::object& obj, const std::string& path) +{ + AcquireGIL gil; + return py::str(obj).cast(); +} + +JSONValues JSONValues::set_value(const std::string& path, const pybind11::object& value) const +{ + AcquireGIL gil; + py::object py_obj = this->to_python(); + patch_object(py_obj, validate_path(path), value); + return {py_obj}; +} + +JSONValues JSONValues::set_value(const std::string& path, nlohmann::json value) const +{ + // Two possibilities: + // 1) We don't have any unserializable objects, in which case we can just update the JSON object + // 2) We do have unserializable objects, in which case we need to cast value to python and call the python + // version of set_value + + if (!has_unserializable()) + { + // The add operation will update an existing value if it exists, or add a new value if it does not + // ref: https://datatracker.ietf.org/doc/html/rfc6902#section-4.1 + nlohmann::json patch{{"op", "add"}, {"path", validate_path(path)}, {"value", value}}; + nlohmann::json patches = nlohmann::json::array({std::move(patch)}); + auto new_values = m_serialized_values.patch(std::move(patches)); + return {std::move(new_values)}; + } + + AcquireGIL gil; + py::object py_obj = cast_from_json(value); + return set_value(path, py_obj); +} + +JSONValues JSONValues::set_value(const std::string& path, const JSONValues& value) const +{ + if (value.has_unserializable()) + { + AcquireGIL gil; + py::object py_obj = value.to_python(); + return set_value(path, py_obj); + } + + return set_value(path, value.to_json([](const py::object& source, const std::string& path) { + DLOG(FATAL) << "Should never be called"; + return nlohmann::json(); // unreachable but needed to satisfy the 
signature + })); +} + +nlohmann::json JSONValues::unserializable_handler(const py::object& obj, const std::string& path) +{ + /* We don't know how to serialize the Object, throw it into m_py_objects and return a place-holder */ + + // Take a non-const copy of the object + py::object non_const_copy = obj; + DVLOG(10) << "Storing unserializable object at path: " << path; + m_py_objects[path] = std::move(non_const_copy); + + return "**pymrc_placeholder"s; +} + +} // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/src/utilities/object_cache.cpp b/python/mrc/_pymrc/src/utilities/object_cache.cpp index 604a21200..574afc2a2 100644 --- a/python/mrc/_pymrc/src/utilities/object_cache.cpp +++ b/python/mrc/_pymrc/src/utilities/object_cache.cpp @@ -24,7 +24,6 @@ #include #include -#include #include #include #include diff --git a/python/mrc/_pymrc/src/utils.cpp b/python/mrc/_pymrc/src/utils.cpp index ba6a70584..22379b594 100644 --- a/python/mrc/_pymrc/src/utils.cpp +++ b/python/mrc/_pymrc/src/utils.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,6 +17,9 @@ #include "pymrc/utils.hpp" +#include "pymrc/utilities/acquire_gil.hpp" + +#include #include #include #include @@ -25,11 +28,12 @@ #include #include +#include // for function +#include #include #include namespace mrc::pymrc { - namespace py = pybind11; using nlohmann::json; @@ -72,6 +76,18 @@ const std::type_info* cpptype_info_from_object(py::object& obj) return nullptr; } +std::string get_py_type_name(const pybind11::object& obj) +{ + if (!obj) + { + // calling py::type::of on a null object will trigger an abort + return ""; + } + + const auto py_type = py::type::of(obj); + return py_type.attr("__name__").cast(); +} + py::object cast_from_json(const json& source) { if (source.is_null()) @@ -123,7 +139,9 @@ py::object cast_from_json(const json& source) // throw std::runtime_error("Unsupported conversion type."); } -json cast_from_pyobject(const py::object& source) +json cast_from_pyobject_impl(const py::object& source, + unserializable_handler_fn_t unserializable_handler_fn, + const std::string& parent_path = "") { // Dont return via initializer list with JSON. 
It performs type deduction and gives different results // NOLINTBEGIN(modernize-return-braced-init-list) @@ -131,50 +149,89 @@ json cast_from_pyobject(const py::object& source) { return json(); } + if (py::isinstance(source)) { const auto py_dict = source.cast(); auto json_obj = json::object(); for (const auto& p : py_dict) { - json_obj[py::cast(p.first)] = cast_from_pyobject(p.second.cast()); + std::string key{p.first.cast()}; + std::string path{parent_path + "/" + key}; + json_obj[key] = cast_from_pyobject_impl(p.second.cast(), unserializable_handler_fn, path); } return json_obj; } + if (py::isinstance(source) || py::isinstance(source)) { const auto py_list = source.cast(); auto json_arr = json::array(); for (const auto& p : py_list) { - json_arr.push_back(cast_from_pyobject(p.cast())); + std::string path{parent_path + "/" + std::to_string(json_arr.size())}; + json_arr.push_back(cast_from_pyobject_impl(p.cast(), unserializable_handler_fn, path)); } return json_arr; } + if (py::isinstance(source)) { return json(py::cast(source)); } + if (py::isinstance(source)) { return json(py::cast(source)); } + if (py::isinstance(source)) { return json(py::cast(source)); } + if (py::isinstance(source)) { return json(py::cast(source)); } - // else unsupported return null - return json(); + // else unsupported return throw a type error + { + AcquireGIL gil; + std::ostringstream error_message; + std::string path{parent_path}; + if (path.empty()) + { + path = "/"; + } + + if (unserializable_handler_fn != nullptr) + { + return unserializable_handler_fn(source, path); + } + + error_message << "Object (" << py::str(source).cast() << ") of type: " << get_py_type_name(source) + << " at path: " << path << " is not JSON serializable"; + + DVLOG(5) << error_message.str(); + throw py::type_error(error_message.str()); + } + // NOLINTEND(modernize-return-braced-init-list) } +json cast_from_pyobject(const py::object& source, unserializable_handler_fn_t unserializable_handler_fn) +{ + return 
cast_from_pyobject_impl(source, unserializable_handler_fn); +} + +json cast_from_pyobject(const py::object& source) +{ + return cast_from_pyobject_impl(source, nullptr); +} + void show_deprecation_warning(const std::string& deprecation_message, ssize_t stack_level) { PyErr_WarnEx(PyExc_DeprecationWarning, deprecation_message.c_str(), stack_level); diff --git a/python/mrc/_pymrc/src/watchers.cpp b/python/mrc/_pymrc/src/watchers.cpp index d474d7ae4..114bc6dac 100644 --- a/python/mrc/_pymrc/src/watchers.cpp +++ b/python/mrc/_pymrc/src/watchers.cpp @@ -24,8 +24,8 @@ #include "mrc/benchmarking/tracer.hpp" #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" -#include "mrc/node/rx_source.hpp" #include "mrc/segment/builder.hpp" +#include "mrc/segment/object.hpp" #include #include @@ -34,11 +34,9 @@ #include #include -#include #include #include #include -#include namespace mrc::pymrc { diff --git a/python/mrc/_pymrc/tests/CMakeLists.txt b/python/mrc/_pymrc/tests/CMakeLists.txt index 4ac354a78..c056bb2cc 100644 --- a/python/mrc/_pymrc/tests/CMakeLists.txt +++ b/python/mrc/_pymrc/tests/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,10 +17,14 @@ list(APPEND CMAKE_MESSAGE_CONTEXT "tests") find_package(pybind11 REQUIRED) +add_subdirectory(coro) + # Keep all source files sorted!!! 
add_executable(test_pymrc + test_asyncio_runnable.cpp test_codable_pyobject.cpp test_executor.cpp + test_json_values.cpp test_main.cpp test_object_cache.cpp test_pickle_wrapper.cpp diff --git a/python/mrc/_pymrc/tests/coro/CMakeLists.txt b/python/mrc/_pymrc/tests/coro/CMakeLists.txt new file mode 100644 index 000000000..788d04832 --- /dev/null +++ b/python/mrc/_pymrc/tests/coro/CMakeLists.txt @@ -0,0 +1,29 @@ +# ============================================================================= +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +list(APPEND CMAKE_MESSAGE_CONTEXT "coro") + +set(MODULE_SOURCE_FILES) + +# Add the module file +list(APPEND MODULE_SOURCE_FILES module.cpp) + +# Create the python module +mrc_add_pybind11_module(coro + INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include + SOURCE_FILES ${MODULE_SOURCE_FILES} + LINK_TARGETS mrc::pymrc +) + +list(POP_BACK CMAKE_MESSAGE_CONTEXT) diff --git a/python/mrc/_pymrc/tests/coro/module.cpp b/python/mrc/_pymrc/tests/coro/module.cpp new file mode 100644 index 000000000..c5332c78c --- /dev/null +++ b/python/mrc/_pymrc/tests/coro/module.cpp @@ -0,0 +1,70 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +mrc::coroutines::Task subtract(int a, int b) +{ + co_return a - b; +} + +mrc::coroutines::Task call_fib_async(mrc::pymrc::PyHolder fib, int value, int minus) +{ + auto result = co_await subtract(value, minus); + co_return co_await mrc::pymrc::coro::PyTaskToCppAwaitable([](auto fib, auto result) { + pybind11::gil_scoped_acquire acquire; + return fib(result); + }(fib, result)); +} + +mrc::coroutines::Task raise_at_depth_async(mrc::pymrc::PyHolder fn, int depth) +{ + if (depth <= 0) + { + throw std::runtime_error("depth reached zero in c++"); + } + + co_return co_await mrc::pymrc::coro::PyTaskToCppAwaitable([](auto fn, auto depth) { + pybind11::gil_scoped_acquire acquire; + return fn(depth - 1); + }(fn, depth)); +} + +mrc::coroutines::Task call_async(mrc::pymrc::PyHolder fn) +{ + co_return co_await mrc::pymrc::coro::PyTaskToCppAwaitable([](auto fn) { + pybind11::gil_scoped_acquire acquire; + return fn(); + }(fn)); +} + +PYBIND11_MODULE(coro, _module) +{ + pybind11::module_::import("mrc.core.coro"); // satisfies automatic type conversions for tasks + + _module.def("call_fib_async", &call_fib_async); + _module.def("raise_at_depth_async", &raise_at_depth_async); + _module.def("call_async", &call_async); +} diff --git a/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp 
b/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp new file mode 100644 index 000000000..997ae978e --- /dev/null +++ b/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp @@ -0,0 +1,335 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pymrc/asyncio_runnable.hpp" +#include "pymrc/coro.hpp" +#include "pymrc/executor.hpp" +#include "pymrc/pipeline.hpp" +#include "pymrc/utilities/object_wrappers.hpp" + +#include "mrc/coroutines/async_generator.hpp" +#include "mrc/coroutines/sync_wait.hpp" +#include "mrc/coroutines/task.hpp" +#include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_source.hpp" +#include "mrc/options/engine_groups.hpp" +#include "mrc/options/options.hpp" +#include "mrc/options/topology.hpp" +#include "mrc/runnable/types.hpp" +#include "mrc/segment/builder.hpp" +#include "mrc/segment/object.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mrc::coroutines { +class Scheduler; +} // namespace mrc::coroutines + +namespace py = pybind11; +namespace pymrc = mrc::pymrc; +using namespace std::string_literals; +using namespace py::literals; + +class __attribute__((visibility("default"))) TestAsyncioRunnable : public ::testing::Test 
+{ + public: + static void SetUpTestSuite() + { + m_interpreter = std::make_unique(); + pybind11::gil_scoped_acquire acquire; + pybind11::module_::import("mrc.core.coro"); + } + + static void TearDownTestSuite() + { + m_interpreter.reset(); + } + + private: + static std::unique_ptr m_interpreter; +}; + +std::unique_ptr TestAsyncioRunnable::m_interpreter; + +class __attribute__((visibility("default"))) PythonCallbackAsyncioRunnable : public pymrc::AsyncioRunnable +{ + public: + PythonCallbackAsyncioRunnable(pymrc::PyObjectHolder operation) : m_operation(std::move(operation)) {} + + mrc::coroutines::AsyncGenerator on_data(int&& value, std::shared_ptr on) override + { + py::gil_scoped_acquire acquire; + + auto coroutine = m_operation(py::cast(value)); + + pymrc::PyObjectHolder result; + { + py::gil_scoped_release release; + + result = co_await pymrc::coro::PyTaskToCppAwaitable(std::move(coroutine)); + } + + auto result_casted = py::cast(result); + + py::gil_scoped_release release; + + co_yield result_casted; + }; + + private: + pymrc::PyObjectHolder m_operation; +}; + +TEST_F(TestAsyncioRunnable, UseAsyncioTasks) +{ + py::object globals = py::globals(); + py::exec( + R"( + async def fn(value): + import asyncio + await asyncio.sleep(0) + return value * 2 + )", + globals); + + pymrc::PyObjectHolder fn = static_cast(globals["fn"]); + + ASSERT_FALSE(fn.is_none()); + + std::atomic counter = 0; + pymrc::Pipeline p; + + auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { + auto src = seg.make_source("src", [](rxcpp::subscriber& s) { + if (s.is_subscribed()) + { + s.on_next(5); + s.on_next(10); + } + + s.on_completed(); + }); + + auto internal = seg.construct_object("internal", fn); + + auto sink = seg.make_sink("sink", [&counter](int x) { + counter.fetch_add(x, std::memory_order_relaxed); + }); + + seg.make_edge(src, internal); + seg.make_edge(internal, sink); + }; + + p.make_segment("seg1"s, init); + p.make_segment("seg2"s, init); + + auto options = 
std::make_shared(); + options->topology().user_cpuset("0"); + // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. + options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread); + + pymrc::Executor exec{options}; + exec.register_pipeline(p); + + exec.start(); + exec.join(); + + EXPECT_EQ(counter, 60); +} + +TEST_F(TestAsyncioRunnable, UseAsyncioGeneratorThrows) +{ + // pybind11::module_::import("mrc.core.coro"); + + py::object globals = py::globals(); + py::exec( + R"( + async def fn(value): + yield value + )", + globals); + + pymrc::PyObjectHolder fn = static_cast(globals["fn"]); + + ASSERT_FALSE(fn.is_none()); + + std::atomic counter = 0; + pymrc::Pipeline p; + + auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { + auto src = seg.make_source("src", [](rxcpp::subscriber& s) { + if (s.is_subscribed()) + { + s.on_next(5); + s.on_next(10); + } + + s.on_completed(); + }); + + auto internal = seg.construct_object("internal", fn); + + auto sink = seg.make_sink("sink", [&counter](int x) { + counter.fetch_add(x, std::memory_order_relaxed); + }); + + seg.make_edge(src, internal); + seg.make_edge(internal, sink); + }; + + p.make_segment("seg1"s, init); + p.make_segment("seg2"s, init); + + auto options = std::make_shared(); + options->topology().user_cpuset("0"); + // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. 
+ options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread); + + pymrc::Executor exec{options}; + exec.register_pipeline(p); + + exec.start(); + + ASSERT_THROW(exec.join(), std::runtime_error); +} + +TEST_F(TestAsyncioRunnable, UseAsyncioTasksThrows) +{ + // pybind11::module_::import("mrc.core.coro"); + + py::object globals = py::globals(); + py::exec( + R"( + async def fn(value): + raise RuntimeError("oops") + )", + globals); + + pymrc::PyObjectHolder fn = static_cast(globals["fn"]); + + ASSERT_FALSE(fn.is_none()); + + std::atomic counter = 0; + pymrc::Pipeline p; + + auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { + auto src = seg.make_source("src", [](rxcpp::subscriber& s) { + if (s.is_subscribed()) + { + s.on_next(5); + s.on_next(10); + } + + s.on_completed(); + }); + + auto internal = seg.construct_object("internal", fn); + + auto sink = seg.make_sink("sink", [&counter](int x) { + counter.fetch_add(x, std::memory_order_relaxed); + }); + + seg.make_edge(src, internal); + seg.make_edge(internal, sink); + }; + + p.make_segment("seg1"s, init); + p.make_segment("seg2"s, init); + + auto options = std::make_shared(); + options->topology().user_cpuset("0"); + // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. + options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread); + + pymrc::Executor exec{options}; + exec.register_pipeline(p); + + exec.start(); + + ASSERT_THROW(exec.join(), std::runtime_error); +} + +template +auto run_operation(OperationT& operation) -> mrc::coroutines::Task +{ + auto stop_source = std::stop_source(); + + auto coro = [](auto& operation, auto stop_source) -> mrc::coroutines::Task { + try + { + auto value = co_await operation(); + stop_source.request_stop(); + co_return value; + } catch (...) 
+ { + stop_source.request_stop(); + throw; + } + }(operation, stop_source); + + coro.resume(); + + while (not stop_source.stop_requested()) + { + if (boost::fibers::has_ready_fibers()) + { + boost::this_fiber::yield(); + } + } + + co_return co_await coro; +} + +TEST_F(TestAsyncioRunnable, BoostFutureAwaitableOperationCanReturn) +{ + auto operation = mrc::pymrc::BoostFutureAwaitableOperation([]() { + using namespace std::chrono_literals; + boost::this_fiber::sleep_for(10ms); + return 5; + }); + + ASSERT_EQ(mrc::coroutines::sync_wait(run_operation(operation)), 5); +} + +TEST_F(TestAsyncioRunnable, BoostFutureAwaitableOperationCanThrow) +{ + auto operation = mrc::pymrc::BoostFutureAwaitableOperation([]() { + throw std::runtime_error("oops"); + return 5; + }); + + ASSERT_THROW(mrc::coroutines::sync_wait(run_operation(operation)), std::runtime_error); +} diff --git a/python/mrc/_pymrc/tests/test_executor.cpp b/python/mrc/_pymrc/tests/test_executor.cpp index 41e284d91..20ea8b10d 100644 --- a/python/mrc/_pymrc/tests/test_executor.cpp +++ b/python/mrc/_pymrc/tests/test_executor.cpp @@ -33,11 +33,9 @@ #include #include -#include #include #include #include -#include namespace py = pybind11; namespace pymrc = mrc::pymrc; diff --git a/python/mrc/_pymrc/tests/test_json_values.cpp b/python/mrc/_pymrc/tests/test_json_values.cpp new file mode 100644 index 000000000..93c1e0e85 --- /dev/null +++ b/python/mrc/_pymrc/tests/test_json_values.cpp @@ -0,0 +1,561 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "test_pymrc.hpp" + +#include "pymrc/types.hpp" +#include "pymrc/utilities/json_values.hpp" + +#include +#include +#include +#include // IWYU pragma: keep + +#include +#include // for size_t +#include // for initializer_list +#include +#include +#include // for pair +#include +// We already included pybind11.h don't need these others +// IWYU pragma: no_include +// IWYU pragma: no_include +// IWYU pragma: no_include + +namespace py = pybind11; +using namespace mrc::pymrc; +using namespace std::string_literals; +using namespace pybind11::literals; // to bring in the `_a` literal + +PYMRC_TEST_CLASS(JSONValues); + +py::dict mk_py_dict() +{ + // return a simple python dict with a nested dict, a list, an integer, and a float + std::array alphabet = {"a", "b", "c"}; + return py::dict("this"_a = py::dict("is"_a = "a test"s), + "alphabet"_a = py::cast(alphabet), + "ncc"_a = 1701, + "cost"_a = 47.47); +} + +nlohmann::json mk_json() +{ + // return a simple json object comparable to that returned by mk_py_dict + return {{"this", {{"is", "a test"}}}, {"alphabet", {"a", "b", "c"}}, {"ncc", 1701}, {"cost", 47.47}}; +} + +py::object mk_decimal(const std::string& value = "1.0"s) +{ + // return a Python decimal.Decimal object, as a simple object without a supported JSON serialization + return py::module_::import("decimal").attr("Decimal")(value); +} + +TEST_F(TestJSONValues, DefaultConstructor) +{ + JSONValues j; + + EXPECT_EQ(j.to_json(JSONValues::stringify), nlohmann::json()); + EXPECT_TRUE(j.to_python().is_none()); +} + +TEST_F(TestJSONValues, 
ToPythonSerializable) +{ + auto py_dict = mk_py_dict(); + + JSONValues j{py_dict}; + auto result = j.to_python(); + + EXPECT_TRUE(result.equal(py_dict)); + EXPECT_FALSE(result.is(py_dict)); // Ensure we actually serialized the object and not stored it +} + +TEST_F(TestJSONValues, ToPythonFromJSON) +{ + py::dict py_expected_results = mk_py_dict(); + + nlohmann::json json_input = mk_json(); + JSONValues j{json_input}; + auto result = j.to_python(); + + EXPECT_TRUE(result.equal(py_expected_results)); +} + +TEST_F(TestJSONValues, ToJSONFromPython) +{ + auto expected_results = mk_json(); + + py::dict py_input = mk_py_dict(); + + JSONValues j{py_input}; + auto result = j.to_json(JSONValues::stringify); + + EXPECT_EQ(result, expected_results); +} + +TEST_F(TestJSONValues, ToJSONFromPythonUnserializable) +{ + std::string dec_val{"2.2"}; + auto expected_results = mk_json(); + expected_results["other"] = dec_val; + + py::dict py_input = mk_py_dict(); + py_input["other"] = mk_decimal(dec_val); + + JSONValues j{py_input}; + EXPECT_EQ(j.to_json(JSONValues::stringify), expected_results); +} + +TEST_F(TestJSONValues, ToJSONFromJSON) +{ + JSONValues j{mk_json()}; + auto result = j.to_json(JSONValues::stringify); + + EXPECT_EQ(result, mk_json()); +} + +TEST_F(TestJSONValues, ToPythonRootUnserializable) +{ + py::object py_dec = mk_decimal(); + + JSONValues j{py_dec}; + auto result = j.to_python(); + + EXPECT_TRUE(result.equal(py_dec)); + EXPECT_TRUE(result.is(py_dec)); // Ensure we stored the object + + nlohmann::json expexted_json("**pymrc_placeholder"s); + EXPECT_EQ(j.view_json(), expexted_json); +} + +TEST_F(TestJSONValues, ToPythonSimpleDict) +{ + py::object py_dec = mk_decimal(); + py::dict py_dict; + py_dict[py::str("test"s)] = py_dec; + + JSONValues j{py_dict}; + py::dict result = j.to_python(); + + EXPECT_TRUE(result.equal(py_dict)); + EXPECT_FALSE(result.is(py_dict)); // Ensure we actually serialized the dict and not stored it + + py::object result_dec = result["test"]; + 
EXPECT_TRUE(result_dec.is(py_dec)); // Ensure we stored the decimal object +} + +TEST_F(TestJSONValues, ToPythonNestedDictUnserializable) +{ + // decimal.Decimal is not serializable + py::object py_dec1 = mk_decimal("1.1"); + py::object py_dec2 = mk_decimal("1.2"); + py::object py_dec3 = mk_decimal("1.3"); + + std::vector py_values = {py::cast(1), py::cast(2), py_dec3, py::cast(4)}; + py::list py_list = py::cast(py_values); + + // Test with object in a nested dict + py::dict py_dict("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = py_dec1))), + "other"_a = py_dec2, + "nested_list"_a = py_list); + + JSONValues j{py_dict}; + auto result = j.to_python(); + EXPECT_TRUE(result.equal(py_dict)); + EXPECT_FALSE(result.is(py_dict)); // Ensure we actually serialized the object and not stored it + + // Individual Decimal instances shoudl be stored and thus pass an `is` test + py::object result_dec1 = result["a"]["b"]["c"]["d"]; + EXPECT_TRUE(result_dec1.is(py_dec1)); + + py::object result_dec2 = result["other"]; + EXPECT_TRUE(result_dec2.is(py_dec2)); + + py::list nested_list = result["nested_list"]; + py::object result_dec3 = nested_list[2]; + EXPECT_TRUE(result_dec3.is(py_dec3)); +} + +TEST_F(TestJSONValues, ToPythonList) +{ + py::object py_dec = mk_decimal("1.1"s); + + std::vector py_values = {py::cast(1), py::cast(2), py_dec, py::cast(4)}; + py::list py_list = py::cast(py_values); + + JSONValues j{py_list}; + py::list result = j.to_python(); + EXPECT_TRUE(result.equal(py_list)); + py::object result_dec = result[2]; + EXPECT_TRUE(result_dec.is(py_dec)); +} + +TEST_F(TestJSONValues, ToPythonMultipleTypes) +{ + // Test with miultiple types not json serializable: module, class, function, generator + py::object py_mod = py::module_::import("decimal"); + py::object py_cls = py_mod.attr("Decimal"); + py::object globals = py::globals(); + py::exec( + R"( + def gen_fn(): + yield 1 + )", + globals); + + py::object py_fn = globals["gen_fn"]; + py::object py_gen = py_fn(); 
+ + std::vector> expected_list_objs = {{1, py_mod}, + {3, py_cls}, + {5, py_fn}, + {7, py_gen}}; + + std::vector py_values = + {py::cast(0), py_mod, py::cast(2), py_cls, py::cast(4), py_fn, py::cast(6), py_gen}; + py::list py_list = py::cast(py_values); + + std::vector> expected_dict_objs = {{"module", py_mod}, + {"class", py_cls}, + {"function", py_fn}, + {"generator", py_gen}}; + + // Test with object in a nested dict + py::dict py_dict("module"_a = py_mod, + "class"_a = py_cls, + "function"_a = py_fn, + "generator"_a = py_gen, + "nested_list"_a = py_list); + + JSONValues j{py_dict}; + auto result = j.to_python(); + EXPECT_TRUE(result.equal(py_dict)); + EXPECT_FALSE(result.is(py_dict)); // Ensure we actually serialized the object and not stored it + + for (const auto& [key, value] : expected_dict_objs) + { + py::object result_value = result[key.c_str()]; + EXPECT_TRUE(result_value.is(value)); + } + + py::list nested_list = result["nested_list"]; + for (const auto& [index, value] : expected_list_objs) + { + py::object result_value = nested_list[index]; + EXPECT_TRUE(result_value.is(value)); + } +} + +TEST_F(TestJSONValues, NumUnserializable) +{ + { + JSONValues j{mk_json()}; + EXPECT_EQ(j.num_unserializable(), 0); + EXPECT_FALSE(j.has_unserializable()); + } + { + JSONValues j{mk_py_dict()}; + EXPECT_EQ(j.num_unserializable(), 0); + EXPECT_FALSE(j.has_unserializable()); + } + { + // Test with object in a nested dict + py::object py_dec = mk_decimal(); + { + py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = py_dec))), "other"_a = 2); + + JSONValues j{d}; + EXPECT_EQ(j.num_unserializable(), 1); + EXPECT_TRUE(j.has_unserializable()); + } + { + // Storing the same object twice should count twice + py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = py_dec))), "other"_a = py_dec); + + JSONValues j{d}; + EXPECT_EQ(j.num_unserializable(), 2); + EXPECT_TRUE(j.has_unserializable()); + } + { + py::object py_dec2 = mk_decimal("2.0"); + 
py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = py_dec, "e"_a = py_dec2))), + "other"_a = py_dec); + + JSONValues j{d}; + EXPECT_EQ(j.num_unserializable(), 3); + EXPECT_TRUE(j.has_unserializable()); + } + } +} + +TEST_F(TestJSONValues, SetValueNewKeyJSON) +{ + // Set to new key that doesn't exist + auto expected_results = mk_json(); + expected_results["other"] = mk_json(); + + JSONValues values{mk_json()}; + auto new_values = values.set_value("/other", mk_json()); + EXPECT_EQ(new_values.to_json(JSONValues::stringify), expected_results); +} + +TEST_F(TestJSONValues, SetValueExistingKeyJSON) +{ + // Set to existing key + auto expected_results = mk_json(); + expected_results["this"] = mk_json(); + + JSONValues values{mk_json()}; + auto new_values = values.set_value("/this", mk_json()); + EXPECT_EQ(new_values.to_json(JSONValues::stringify), expected_results); +} + +TEST_F(TestJSONValues, SetValueNewKeyJSONWithUnserializable) +{ + // Set to new key that doesn't exist + auto expected_results = mk_py_dict(); + expected_results["other"] = mk_py_dict(); + expected_results["dec"] = mk_decimal(); + + auto input = mk_py_dict(); + input["dec"] = mk_decimal(); + + JSONValues values{input}; + auto new_values = values.set_value("/other", mk_json()); + EXPECT_TRUE(new_values.to_python().equal(expected_results)); +} + +TEST_F(TestJSONValues, SetValueExistingKeyJSONWithUnserializable) +{ + // Set to existing key + auto expected_results = mk_py_dict(); + expected_results["dec"] = mk_decimal(); + expected_results["this"] = mk_py_dict(); + + auto input = mk_py_dict(); + input["dec"] = mk_decimal(); + + JSONValues values{input}; + auto new_values = values.set_value("/this", mk_json()); + EXPECT_TRUE(new_values.to_python().equal(expected_results)); +} + +TEST_F(TestJSONValues, SetValueNewKeyPython) +{ + // Set to new key that doesn't exist + auto expected_results = mk_py_dict(); + expected_results["other"] = mk_decimal(); + + JSONValues values{mk_json()}; + auto 
new_values = values.set_value("/other", mk_decimal()); + EXPECT_TRUE(new_values.to_python().equal(expected_results)); +} + +TEST_F(TestJSONValues, SetValueNestedUnsupportedPython) +{ + JSONValues values{mk_json()}; + EXPECT_THROW(values.set_value("/other/nested", mk_decimal()), py::error_already_set); +} + +TEST_F(TestJSONValues, SetValueNestedUnsupportedJSON) +{ + JSONValues values{mk_json()}; + EXPECT_THROW(values.set_value("/other/nested", nlohmann::json(1.0)), nlohmann::json::out_of_range); +} + +TEST_F(TestJSONValues, SetValueExistingKeyPython) +{ + // Set to existing key + auto expected_results = mk_py_dict(); + expected_results["this"] = mk_decimal(); + + JSONValues values{mk_json()}; + auto new_values = values.set_value("/this", mk_decimal()); + EXPECT_TRUE(new_values.to_python().equal(expected_results)); +} + +TEST_F(TestJSONValues, SetValueNewKeyJSONDefaultConstructed) +{ + nlohmann::json expected_results{{"other", mk_json()}}; + + JSONValues values; + auto new_values = values.set_value("/other", mk_json()); + EXPECT_EQ(new_values.to_json(JSONValues::stringify), expected_results); +} + +TEST_F(TestJSONValues, SetValueJSONValues) +{ + // Set to new key that doesn't exist + auto expected_results = mk_json(); + expected_results["other"] = mk_json(); + + JSONValues values1{mk_json()}; + JSONValues values2{mk_json()}; + auto new_values = values1.set_value("/other", values2); + EXPECT_EQ(new_values.to_json(JSONValues::stringify), expected_results); +} + +TEST_F(TestJSONValues, SetValueJSONValuesWithUnserializable) +{ + // Set to new key that doesn't exist + auto expected_results = mk_py_dict(); + expected_results["other"] = py::dict("dec"_a = mk_decimal()); + + JSONValues values1{mk_json()}; + + auto input_dict = py::dict("dec"_a = mk_decimal()); + JSONValues values2{input_dict}; + + auto new_values = values1.set_value("/other", values2); + EXPECT_TRUE(new_values.to_python().equal(expected_results)); +} + +TEST_F(TestJSONValues, GetJSON) +{ + using namespace 
nlohmann; + const auto json_doc = mk_json(); + std::vector paths = {"/", "/this", "/this/is", "/alphabet", "/ncc", "/cost"}; + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& path : paths) + { + json::json_pointer jp; + if (path != "/") + { + jp = json::json_pointer(path); + } + + EXPECT_TRUE(json_doc.contains(jp)) << "Path: '" << path << "' not found in json"; + EXPECT_EQ(value.get_json(path, JSONValues::stringify), json_doc[jp]); + } + } +} + +TEST_F(TestJSONValues, GetJSONError) +{ + std::vector paths = {"/doesntexist", "/this/fake"}; + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& path : paths) + { + EXPECT_THROW(value.get_json(path, JSONValues::stringify), std::runtime_error); + } + } +} + +TEST_F(TestJSONValues, GetPython) +{ + const auto py_dict = mk_py_dict(); + + // + std::vector> tests = {{"/", py_dict}, + {"/this", py::dict("is"_a = "a test"s)}, + {"/this/is", py::str("a test"s)}, + {"/alphabet", py_dict["alphabet"]}, + {"/ncc", py::int_(1701)}, + {"/cost", py::float_(47.47)}}; + + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& p : tests) + { + const auto& path = p.first; + const auto& expected_result = p.second; + EXPECT_TRUE(value.get_python(path).equal(expected_result)); + } + } +} + +TEST_F(TestJSONValues, GetPythonError) +{ + std::vector paths = {"/doesntexist", "/this/fake"}; + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& path : paths) + { + EXPECT_THROW(value.get_python(path), std::runtime_error) << "Expected failure with path: '" << path << "'"; + } + } +} + +TEST_F(TestJSONValues, SubscriptOpt) +{ + using namespace nlohmann; + const auto json_doc = mk_json(); + std::vector values = {"", "this", "this/is", "alphabet", "ncc", "cost"}; + std::vector paths; + for (const auto& value : values) + { + paths.push_back(value); + paths.push_back("/" 
+ value); + } + + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& path : paths) + { + auto jv = value[path]; + + json::json_pointer jp; + if (!path.empty() && path != "/") + { + std::string json_path = path; + if (json_path[0] != '/') + { + json_path = "/"s + json_path; + } + + jp = json::json_pointer(json_path); + } + + EXPECT_EQ(jv.to_json(JSONValues::stringify), json_doc[jp]); + } + } +} + +TEST_F(TestJSONValues, SubscriptOptError) +{ + std::vector paths = {"/doesntexist", "/this/fake"}; + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& path : paths) + { + EXPECT_THROW(value[path], std::runtime_error); + } + } +} + +TEST_F(TestJSONValues, Stringify) +{ + auto dec_val = mk_decimal("2.2"s); + EXPECT_EQ(JSONValues::stringify(dec_val, "/"s), nlohmann::json("2.2"s)); +} + +TEST_F(TestJSONValues, CastPyToJSONValues) +{ + auto py_dict = mk_py_dict(); + auto j = py_dict.cast(); + EXPECT_TRUE(j.to_python().equal(py_dict)); +} + +TEST_F(TestJSONValues, CastJSONValuesToPy) +{ + auto j = JSONValues{mk_json()}; + auto py_dict = py::cast(j); + EXPECT_TRUE(py_dict.equal(j.to_python())); +} diff --git a/python/mrc/_pymrc/tests/test_pipeline.cpp b/python/mrc/_pymrc/tests/test_pipeline.cpp index 68091ba14..7b375d21a 100644 --- a/python/mrc/_pymrc/tests/test_pipeline.cpp +++ b/python/mrc/_pymrc/tests/test_pipeline.cpp @@ -31,9 +31,7 @@ #include "mrc/options/topology.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" -#include "mrc/types.hpp" -#include #include #include #include @@ -46,7 +44,6 @@ #include #include #include -#include #include #include #include diff --git a/python/mrc/_pymrc/tests/test_serializers.cpp b/python/mrc/_pymrc/tests/test_serializers.cpp index cbf5147c5..e6c72e27c 100644 --- a/python/mrc/_pymrc/tests/test_serializers.cpp +++ b/python/mrc/_pymrc/tests/test_serializers.cpp @@ -28,7 +28,6 @@ #include #include // IWYU pragma: keep 
-#include #include #include #include diff --git a/python/mrc/_pymrc/tests/test_utils.cpp b/python/mrc/_pymrc/tests/test_utils.cpp index a802009fc..7606b6502 100644 --- a/python/mrc/_pymrc/tests/test_utils.cpp +++ b/python/mrc/_pymrc/tests/test_utils.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,9 +32,9 @@ #include #include #include +#include // for size_t #include #include -#include #include #include #include @@ -42,6 +42,7 @@ namespace py = pybind11; namespace pymrc = mrc::pymrc; using namespace std::string_literals; +using namespace pybind11::literals; // to bring in the `_a` literal // Create values too big to fit in int & float types to ensure we can pass // long & double types to both nlohmann/json and python @@ -144,6 +145,73 @@ TEST_F(TestUtils, CastFromPyObject) } } +TEST_F(TestUtils, CastFromPyObjectSerializeErrors) +{ + // Test to verify that cast_from_pyobject throws a python TypeError when encountering something that is not json + // serializable issue #450 + + // decimal.Decimal is not serializable + py::object Decimal = py::module_::import("decimal").attr("Decimal"); + py::object o = Decimal("1.0"); + EXPECT_THROW(pymrc::cast_from_pyobject(o), py::type_error); + + // Test with object in a nested dict + py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = o))), "other"_a = 2); + EXPECT_THROW(pymrc::cast_from_pyobject(d), py::type_error); +} + +TEST_F(TestUtils, CastFromPyObjectUnserializableHandlerFn) +{ + // Test to verify that cast_from_pyobject calls the unserializable_handler_fn when encountering an object that it + // does not know how to serialize + + bool handler_called{false}; + pymrc::unserializable_handler_fn_t handler_fn = 
[&handler_called](const py::object& source, + const std::string& path) { + handler_called = true; + return nlohmann::json(py::cast(source)); + }; + + // decimal.Decimal is not serializable + py::object Decimal = py::module_::import("decimal").attr("Decimal"); + py::object o = Decimal("1.0"); + EXPECT_EQ(pymrc::cast_from_pyobject(o, handler_fn), nlohmann::json(1.0)); + EXPECT_TRUE(handler_called); +} + +TEST_F(TestUtils, CastFromPyObjectUnserializableHandlerFnNestedObj) +{ + std::size_t handler_call_count{0}; + + // Test with object in a nested dict + pymrc::unserializable_handler_fn_t handler_fn = [&handler_call_count](const py::object& source, + const std::string& path) { + ++handler_call_count; + return nlohmann::json(py::cast(source)); + }; + + // decimal.Decimal is not serializable + py::object Decimal = py::module_::import("decimal").attr("Decimal"); + py::object o = Decimal("1.0"); + + py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = o))), "other"_a = o); + nlohmann::json expected_results = {{"a", {{"b", {{"c", {{"d", 1.0}}}}}}}, {"other", 1.0}}; + + EXPECT_EQ(pymrc::cast_from_pyobject(d, handler_fn), expected_results); + EXPECT_EQ(handler_call_count, 2); +} + +TEST_F(TestUtils, GetTypeName) +{ + // invalid objects should return an empty string + EXPECT_EQ(pymrc::get_py_type_name(py::object()), ""); + EXPECT_EQ(pymrc::get_py_type_name(py::none()), "NoneType"); + + py::object Decimal = py::module_::import("decimal").attr("Decimal"); + py::object o = Decimal("1.0"); + EXPECT_EQ(pymrc::get_py_type_name(o), "Decimal"); +} + TEST_F(TestUtils, PyObjectWrapper) { py::list test_list; diff --git a/python/mrc/benchmarking/watchers.cpp b/python/mrc/benchmarking/watchers.cpp index 2a4b3418f..920826239 100644 --- a/python/mrc/benchmarking/watchers.cpp +++ b/python/mrc/benchmarking/watchers.cpp @@ -26,11 +26,9 @@ #include // IWYU pragma: keep #include -#include #include #include #include -#include namespace mrc::pymrc { namespace py = pybind11; diff 
--git a/python/mrc/core/CMakeLists.txt b/python/mrc/core/CMakeLists.txt index d635e071f..f04b17f1f 100644 --- a/python/mrc/core/CMakeLists.txt +++ b/python/mrc/core/CMakeLists.txt @@ -16,6 +16,7 @@ list(APPEND CMAKE_MESSAGE_CONTEXT "core") mrc_add_pybind11_module(common SOURCE_FILES common.cpp) +mrc_add_pybind11_module(coro SOURCE_FILES coro.cpp) mrc_add_pybind11_module(executor SOURCE_FILES executor.cpp) mrc_add_pybind11_module(logging SOURCE_FILES logging.cpp) mrc_add_pybind11_module(node SOURCE_FILES node.cpp) diff --git a/python/mrc/core/common.cpp b/python/mrc/core/common.cpp index 741fec61b..7dde55b4b 100644 --- a/python/mrc/core/common.cpp +++ b/python/mrc/core/common.cpp @@ -18,21 +18,15 @@ #include "pymrc/port_builders.hpp" #include "pymrc/types.hpp" -#include "mrc/node/rx_sink_base.hpp" -#include "mrc/node/rx_source_base.hpp" -#include "mrc/types.hpp" #include "mrc/utils/string_utils.hpp" #include "mrc/version.hpp" -#include #include #include #include -#include #include #include -#include namespace mrc::pymrc { diff --git a/python/mrc/core/coro.cpp b/python/mrc/core/coro.cpp new file mode 100644 index 000000000..d647a7b11 --- /dev/null +++ b/python/mrc/core/coro.cpp @@ -0,0 +1,67 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "pymrc/coro.hpp" + +#include +#include +#include +#include +#include +#include // IWYU pragma: keep + +#include +#include +#include +#include + +namespace mrc::pymrc::coro { + +namespace py = pybind11; + +PYBIND11_MODULE(coro, _module) +{ + _module.doc() = R"pbdoc( + ----------------------- + .. currentmodule:: morpheus.llm + .. autosummary:: + :toctree: _generate + + )pbdoc"; + + py::class_>(_module, "CppToPyAwaitable") + .def(py::init<>()) + .def("__iter__", &CppToPyAwaitable::iter) + .def("__await__", &CppToPyAwaitable::await) + .def("__next__", &CppToPyAwaitable::next); + + py::class_>( // + _module, + "BoostFibersMainPyAwaitable") + .def(py::init<>()); + + _module.def("wrap_coroutine", [](coroutines::Task> fn) -> coroutines::Task { + DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL when resuming a C++ coroutine"; + + auto strings = co_await fn; + + co_return strings[0]; + }); + + // _module.attr("__version__") = + // MRC_CONCAT_STR(morpheus_VERSION_MAJOR << "." << morpheus_VERSION_MINOR << "." << morpheus_VERSION_PATCH); +} +} // namespace mrc::pymrc::coro diff --git a/python/mrc/core/node.cpp b/python/mrc/core/node.cpp index bbbdfe658..0452bc9e5 100644 --- a/python/mrc/core/node.cpp +++ b/python/mrc/core/node.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +20,8 @@ #include "pymrc/utils.hpp" #include "mrc/node/operators/broadcast.hpp" +#include "mrc/node/operators/round_robin_router_typeless.hpp" +#include "mrc/node/operators/zip.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" #include "mrc/utils/string_utils.hpp" @@ -58,6 +60,35 @@ PYBIND11_MODULE(node, py_mod) return node; })); + py::class_, + mrc::segment::ObjectProperties, + std::shared_ptr>>(py_mod, "RoundRobinRouter") + .def(py::init<>([](mrc::segment::IBuilder& builder, std::string name) { + auto node = builder.construct_object(name); + + return node; + })); + + py::class_, + mrc::segment::ObjectProperties, + std::shared_ptr>>(py_mod, "Zip") + .def(py::init<>([](mrc::segment::IBuilder& builder, std::string name, size_t count) { + // std::shared_ptr node; + + if (count == 2) + { + return builder.construct_object>(name)->as(); + } + else + { + py::print("Unsupported count!"); + throw std::runtime_error("Unsupported count!"); + } + })) + .def("get_sink", [](mrc::segment::Object& self, size_t index) { + return self.get_child(MRC_CONCAT_STR("sink[" << index << "]")); + }); + py_mod.attr("__version__") = MRC_CONCAT_STR(mrc_VERSION_MAJOR << "." << mrc_VERSION_MINOR << "." 
<< mrc_VERSION_PATCH); } diff --git a/python/mrc/core/operators.cpp b/python/mrc/core/operators.cpp index b74ff96ec..be931fc27 100644 --- a/python/mrc/core/operators.cpp +++ b/python/mrc/core/operators.cpp @@ -28,7 +28,6 @@ #include #include // IWYU pragma: keep -#include #include namespace mrc::pymrc { diff --git a/python/mrc/core/pipeline.cpp b/python/mrc/core/pipeline.cpp index 2f1dcf970..a6e9f0b5e 100644 --- a/python/mrc/core/pipeline.cpp +++ b/python/mrc/core/pipeline.cpp @@ -27,7 +27,6 @@ #include #include // IWYU pragma: keep -#include #include namespace mrc::pymrc { diff --git a/python/mrc/core/segment.cpp b/python/mrc/core/segment.cpp index ed87f83f2..6c1898d33 100644 --- a/python/mrc/core/segment.cpp +++ b/python/mrc/core/segment.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,12 +38,9 @@ #include #include -#include -#include #include #include #include -#include namespace mrc::pymrc { diff --git a/python/mrc/core/segment/module_definitions/mirror_tap_orchestrator.cpp b/python/mrc/core/segment/module_definitions/mirror_tap_orchestrator.cpp index 570bd3c69..49aee1f7b 100644 --- a/python/mrc/core/segment/module_definitions/mirror_tap_orchestrator.cpp +++ b/python/mrc/core/segment/module_definitions/mirror_tap_orchestrator.cpp @@ -26,18 +26,14 @@ #include "mrc/experimental/modules/stream_buffer/stream_buffer_module.hpp" #include "mrc/modules/module_registry.hpp" #include "mrc/modules/module_registry_util.hpp" -#include "mrc/node/operators/broadcast.hpp" -#include "mrc/node/rx_sink.hpp" -#include "mrc/node/rx_source.hpp" #include "mrc/version.hpp" #include #include // IWYU pragma: keep #include #include +#include -#include -#include #include #include #include diff 
--git a/python/mrc/core/segment/module_definitions/segment_module_registry.cpp b/python/mrc/core/segment/module_definitions/segment_module_registry.cpp index 0ae7b5728..86d21f65c 100644 --- a/python/mrc/core/segment/module_definitions/segment_module_registry.cpp +++ b/python/mrc/core/segment/module_definitions/segment_module_registry.cpp @@ -25,12 +25,9 @@ #include #include // IWYU pragma: keep #include -#include #include // IWYU pragma: keep -#include #include -#include #include #include diff --git a/python/mrc/core/segment/module_definitions/segment_modules.cpp b/python/mrc/core/segment/module_definitions/segment_modules.cpp index 08332dd40..5cc22f61d 100644 --- a/python/mrc/core/segment/module_definitions/segment_modules.cpp +++ b/python/mrc/core/segment/module_definitions/segment_modules.cpp @@ -25,9 +25,7 @@ #include #include -#include #include -#include namespace mrc::pymrc { diff --git a/python/mrc/core/subscriber.cpp b/python/mrc/core/subscriber.cpp index 656ff6884..8d6de717a 100644 --- a/python/mrc/core/subscriber.cpp +++ b/python/mrc/core/subscriber.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,10 +25,10 @@ #include #include // IWYU pragma: keep -#include // IWYU pragma: keep(for call_guard) +#include // IWYU pragma: keep #include +#include -#include #include #include @@ -50,7 +50,8 @@ PYBIND11_MODULE(subscriber, py_mod) // Common must be first in every module pymrc::import(py_mod, "mrc.core.common"); - py::class_(py_mod, "Subscription"); + py::class_(py_mod, "Subscription") + .def("is_subscribed", &SubscriptionProxy::is_subscribed, py::call_guard()); py::class_(py_mod, "Observer") .def("on_next", diff --git a/python/mrc/tests/sample_modules.cpp b/python/mrc/tests/sample_modules.cpp index 8bd6d354e..041d67a91 100644 --- a/python/mrc/tests/sample_modules.cpp +++ b/python/mrc/tests/sample_modules.cpp @@ -20,15 +20,12 @@ #include "pymrc/utils.hpp" #include "mrc/modules/module_registry_util.hpp" -#include "mrc/node/rx_source.hpp" #include "mrc/utils/string_utils.hpp" #include "mrc/version.hpp" #include #include -#include -#include #include #include diff --git a/python/mrc/tests/test_edges.cpp b/python/mrc/tests/test_edges.cpp index 1e9cc0359..ccac5a2d7 100644 --- a/python/mrc/tests/test_edges.cpp +++ b/python/mrc/tests/test_edges.cpp @@ -24,29 +24,22 @@ #include "mrc/channel/status.hpp" #include "mrc/edge/edge_connector.hpp" -#include "mrc/node/rx_sink_base.hpp" -#include "mrc/node/rx_source_base.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" -#include "mrc/types.hpp" #include "mrc/utils/string_utils.hpp" #include "mrc/version.hpp" -#include #include #include #include #include -#include #include #include -#include #include #include #include #include -#include namespace mrc::pytests { diff --git a/python/mrc/tests/utils.cpp b/python/mrc/tests/utils.cpp index 35a64d6e5..b95920b43 100644 --- a/python/mrc/tests/utils.cpp +++ b/python/mrc/tests/utils.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,10 +17,13 @@ #include "pymrc/utils.hpp" +#include "pymrc/utilities/json_values.hpp" // for JSONValues + #include "mrc/utils/string_utils.hpp" #include "mrc/version.hpp" #include +#include // for gil_scoped_acquire #include #include @@ -30,6 +33,21 @@ namespace mrc::pytests { namespace py = pybind11; +// Simple test class which uses pybind11's `gil_scoped_acquire` class in the destructor. Needed to repro #362 +struct RequireGilInDestructor +{ + ~RequireGilInDestructor() + { + // Grab the GIL + py::gil_scoped_acquire gil; + } +}; + +pymrc::JSONValues roundtrip_cast(pymrc::JSONValues v) +{ + return v; +} + PYBIND11_MODULE(utils, py_mod) { py_mod.doc() = R"pbdoc()pbdoc"; @@ -48,6 +66,10 @@ PYBIND11_MODULE(utils, py_mod) }, py::arg("msg") = ""); + py::class_(py_mod, "RequireGilInDestructor").def(py::init<>()); + + py_mod.def("roundtrip_cast", &roundtrip_cast, py::arg("v")); + py_mod.attr("__version__") = MRC_CONCAT_STR(mrc_VERSION_MAJOR << "." << mrc_VERSION_MINOR << "." << mrc_VERSION_PATCH); } diff --git a/python/setup.py b/python/setup.py index cc37c7077..7d19ae679 100644 --- a/python/setup.py +++ b/python/setup.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,20 +22,16 @@ ############################################################################## # - Python package generation ------------------------------------------------ -setup( - name="mrc", - description="mrc", - version=versioneer.get_version(), - classifiers=[ - "Intended Audience :: Developers", "Programming Language :: Python", "Programming Language :: Python :: 3.10" - ], - author="NVIDIA Corporation", - setup_requires=[], - include_package_data=True, - packages=find_namespace_packages(include=["mrc*"], exclude=["tests", "mrc.core.segment.module_definitions"]), - package_data={ - "mrc": ["_pymrc/*.so"] # Add the pymrc library for the root package - }, - license="Apache", - cmdclass=versioneer.get_cmdclass(), - zip_safe=False) +setup(name="mrc", + description="mrc", + version=versioneer.get_version(), + classifiers=[ + "Intended Audience :: Developers", "Programming Language :: Python", "Programming Language :: Python :: 3.10" + ], + author="NVIDIA Corporation", + setup_requires=[], + include_package_data=True, + packages=find_namespace_packages(include=["mrc*"], exclude=["tests", "mrc.core.segment.module_definitions"]), + license="Apache", + cmdclass=versioneer.get_cmdclass(), + zip_safe=False) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 842261f10..7052fe176 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import typing import pytest @@ -50,6 +51,9 @@ def configure_tests_logging(is_debugger_attached: bool): if (is_debugger_attached): log_level = logging.INFO + if (os.environ.get('GLOG_v') is not None): + log_level = logging.DEBUG + mrc_logging.init_logging("mrc_testing", py_level=log_level) diff --git a/python/tests/test_coro.py b/python/tests/test_coro.py new file mode 100644 index 000000000..940160f18 --- /dev/null +++ b/python/tests/test_coro.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio + +import pytest + +from mrc._pymrc.tests.coro.coro import call_async +from mrc._pymrc.tests.coro.coro import call_fib_async +from mrc._pymrc.tests.coro.coro import raise_at_depth_async +from mrc.core import coro + + +@pytest.mark.asyncio +async def test_coro(): + + # hit_inside = False + + async def inner(): + + # nonlocal hit_inside + + result = await coro.wrap_coroutine(asyncio.sleep(1, result=['a', 'b', 'c'])) + + # hit_inside = True + + return [result] + + returned_val = await coro.wrap_coroutine(inner()) + + assert returned_val == 'a' + # assert hit_inside + + +@pytest.mark.asyncio +async def test_coro_many(): + + expected_count = 1000 + hit_count = 0 + + start_time = asyncio.get_running_loop().time() + + async def inner(): + + nonlocal hit_count + + await asyncio.sleep(0.1) + + hit_count += 1 + + return ['a', 'b', 'c'] + + coros = [coro.wrap_coroutine(inner()) for _ in range(expected_count)] + + returned_vals = await asyncio.gather(*coros) + + end_time = asyncio.get_running_loop().time() + + assert returned_vals == ['a'] * expected_count + assert hit_count == expected_count + assert (end_time - start_time) < 1.5 + + +@pytest.mark.asyncio +async def test_python_cpp_async_interleave(): + + def fib(n): + if n < 0: + raise ValueError() + + if n < 2: + return 1 + + return fib(n - 1) + fib(n - 2) + + async def fib_async(n): + if n < 0: + raise ValueError() + + if n < 2: + return 1 + + task_a = call_fib_async(fib_async, n, 1) + task_b = call_fib_async(fib_async, n, 2) + + [a, b] = await asyncio.gather(task_a, task_b) + + return a + b + + assert fib(15) == await fib_async(15) + + +@pytest.mark.asyncio +async def test_python_cpp_async_exception(): + + async def py_raise_at_depth_async(n: int): + if n <= 0: + raise RuntimeError("depth reached zero in python") + + await raise_at_depth_async(py_raise_at_depth_async, n - 1) + + depth = 100 + + with pytest.raises(RuntimeError) as ex: + await raise_at_depth_async(py_raise_at_depth_async, depth + 1) + 
assert "python" in str(ex.value) + + with pytest.raises(RuntimeError) as ex: + await raise_at_depth_async(py_raise_at_depth_async, depth) + assert "c++" in str(ex.value) + + +@pytest.mark.asyncio +async def test_can_cancel_coroutine_from_python(): + + counter = 0 + + async def increment_recursively(): + nonlocal counter + await asyncio.sleep(0) + counter += 1 + await call_async(increment_recursively) + + task = asyncio.ensure_future(call_async(increment_recursively)) + + await asyncio.sleep(0) + assert counter == 0 + await asyncio.sleep(0) + await asyncio.sleep(0) + assert counter == 1 + await asyncio.sleep(0) + await asyncio.sleep(0) + assert counter == 2 + + task.cancel() + + with pytest.raises(asyncio.exceptions.CancelledError): + await task + + assert counter == 3 diff --git a/python/tests/test_edges.py b/python/tests/test_edges.py index 98ed11d0e..3c28c77e4 100644 --- a/python/tests/test_edges.py +++ b/python/tests/test_edges.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -252,6 +252,26 @@ def add_broadcast(seg: mrc.Builder, *upstream: mrc.SegmentObject): return node +def add_round_robin_router(seg: mrc.Builder, *upstream: mrc.SegmentObject): + + node = mrc.core.node.RoundRobinRouter(seg, "RoundRobinRouter") + + for u in upstream: + seg.make_edge(u, node) + + return node + + +def add_zip(seg: mrc.Builder, *upstream: mrc.SegmentObject): + + node = mrc.core.node.Zip(seg, "Zip", len(upstream)) + + for i, u in enumerate(upstream): + seg.make_edge(u, node.get_sink(i)) + + return node + + # THIS TEST IS CAUSING ISSUES WHEN RUNNING ALL TESTS TOGETHER # @dataclasses.dataclass @@ -431,14 +451,15 @@ def fail_if_more_derived_type(combo: typing.Tuple): @pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) @pytest.mark.parametrize("sink1_cpp", [True, False], ids=["sink1_cpp", "sink2_py"]) @pytest.mark.parametrize("sink2_cpp", [True, False], ids=["sink2_cpp", "sink2_py"]) -@pytest.mark.parametrize("source_type,sink1_type,sink2_type", - gen_parameters("source", - "sink1", - "sink2", - is_fail_fn=fail_if_more_derived_type, - values={ - "base": m.Base, "derived": m.DerivedA - })) +@pytest.mark.parametrize( + "source_type,sink1_type,sink2_type", + gen_parameters("source", + "sink1", + "sink2", + is_fail_fn=fail_if_more_derived_type, + values={ + "base": m.Base, "derived": m.DerivedA + })) def test_source_to_broadcast_to_sinks(run_segment, sink1_component: bool, sink2_component: bool, @@ -503,13 +524,84 @@ def segment_init(seg: mrc.Builder): assert results == expected_node_counts +@pytest.mark.parametrize("sink1_component,sink2_component", + gen_parameters("sink1", "sink2", is_fail_fn=lambda x: False)) @pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) -@pytest.mark.parametrize("source_type", - gen_parameters("source", - is_fail_fn=lambda _: False, - values={ - "base": m.Base, 
"derived": m.DerivedA - })) +@pytest.mark.parametrize("sink1_cpp", [True, False], ids=["sink1_cpp", "sink2_py"]) +@pytest.mark.parametrize("sink2_cpp", [True, False], ids=["sink2_cpp", "sink2_py"]) +@pytest.mark.parametrize( + "source_type,sink1_type,sink2_type", + gen_parameters("source", + "sink1", + "sink2", + is_fail_fn=fail_if_more_derived_type, + values={ + "base": m.Base, "derived": m.DerivedA + })) +def test_source_to_round_robin_router_to_sinks(run_segment, + sink1_component: bool, + sink2_component: bool, + source_cpp: bool, + sink1_cpp: bool, + sink2_cpp: bool, + source_type: type, + sink1_type: type, + sink2_type: type): + + def segment_init(seg: mrc.Builder): + + source = add_source(seg, is_cpp=source_cpp, data_type=source_type, is_component=False) + broadcast = add_round_robin_router(seg, source) + add_sink(seg, + broadcast, + is_cpp=sink1_cpp, + data_type=sink1_type, + is_component=sink1_component, + suffix="1", + count=3) + add_sink(seg, + broadcast, + is_cpp=sink2_cpp, + data_type=sink2_type, + is_component=sink2_component, + suffix="2", + count=2) + + results = run_segment(segment_init) + + assert results == expected_node_counts + + +@pytest.mark.parametrize("sink1_component,sink2_component", + gen_parameters("sink1", "sink2", is_fail_fn=lambda x: False)) +@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) +@pytest.mark.parametrize("sink1_cpp", [True, False], ids=["sink1_cpp", "sink1_py"]) +@pytest.mark.parametrize("sink2_cpp", [True, False], ids=["sink2_cpp", "sink2_py"]) +def test_multi_source_to_round_robin_router_to_multi_sink(run_segment, + sink1_component: bool, + sink2_component: bool, + source_cpp: bool, + sink1_cpp: bool, + sink2_cpp: bool): + + def segment_init(seg: mrc.Builder): + + source1 = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="1") + source2 = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="2") + broadcast = 
add_round_robin_router(seg, source1, source2) + add_sink(seg, broadcast, is_cpp=sink1_cpp, data_type=m.Base, is_component=sink1_component, suffix="1") + add_sink(seg, broadcast, is_cpp=sink2_cpp, data_type=m.Base, is_component=sink2_component, suffix="2") + + results = run_segment(segment_init) + + assert results == expected_node_counts + + +@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) +@pytest.mark.parametrize( + "source_type", gen_parameters("source", is_fail_fn=lambda _: False, values={ + "base": m.Base, "derived": m.DerivedA + })) def test_source_to_null(run_segment, source_cpp: bool, source_type: type): def segment_init(seg: mrc.Builder): @@ -522,24 +614,24 @@ def segment_init(seg: mrc.Builder): assert results == expected_node_counts -@pytest.mark.parametrize("source_cpp,node_cpp", - gen_parameters("source", "node", is_fail_fn=lambda _: False, values={ - "cpp": True, "py": False - })) -@pytest.mark.parametrize("source_type,node_type", - gen_parameters("source", - "node", - is_fail_fn=fail_if_more_derived_type, - values={ - "base": m.Base, "derived": m.DerivedA - })) -@pytest.mark.parametrize("source_component,node_component", - gen_parameters("source", - "node", - is_fail_fn=lambda x: x[0] and x[1], - values={ - "run": False, "com": True - })) +@pytest.mark.parametrize( + "source_cpp,node_cpp", + gen_parameters("source", "node", is_fail_fn=lambda _: False, values={ + "cpp": True, "py": False + })) +@pytest.mark.parametrize( + "source_type,node_type", + gen_parameters("source", + "node", + is_fail_fn=fail_if_more_derived_type, + values={ + "base": m.Base, "derived": m.DerivedA + })) +@pytest.mark.parametrize( + "source_component,node_component", + gen_parameters("source", "node", is_fail_fn=lambda x: x[0] and x[1], values={ + "run": False, "com": True + })) def test_source_to_node_to_null(run_segment, source_cpp: bool, node_cpp: bool, @@ -557,3 +649,18 @@ def segment_init(seg: mrc.Builder): results = 
run_segment(segment_init) assert results == expected_node_counts + + +@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) +def test_multi_source_to_zip_to_sink(run_segment, source_cpp: bool): + + def segment_init(seg: mrc.Builder): + + source1 = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="1") + source2 = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="2") + zip = add_zip(seg, source1, source2) + add_sink(seg, zip, is_cpp=False, data_type=m.Base, is_component=False) + + results = run_segment(segment_init) + + assert results == expected_node_counts diff --git a/python/tests/test_executor.py b/python/tests/test_executor.py index 8e5b7ad47..46381d285 100644 --- a/python/tests/test_executor.py +++ b/python/tests/test_executor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +14,8 @@ # limitations under the License. 
import asyncio +import os +import time import typing import pytest @@ -30,6 +32,53 @@ def pairwise(t): node_fn_type = typing.Callable[[mrc.Builder], mrc.SegmentObject] +@pytest.fixture +def source(): + + def build(builder: mrc.Builder): + + def gen_data(): + yield 1 + yield 2 + yield 3 + + return builder.make_source("source", gen_data) + + return build + + +@pytest.fixture +def endless_source(): + + def build(builder: mrc.Builder): + + def gen_data(): + i = 0 + while True: + yield i + i += 1 + time.sleep(0.1) + + return builder.make_source("endless_source", gen_data()) + + return build + + +@pytest.fixture +def blocking_source(): + + def build(builder: mrc.Builder): + + def gen_data(subscription: mrc.Subscription): + yield 1 + while subscription.is_subscribed(): + time.sleep(0.1) + + return builder.make_source("blocking_source", gen_data) + + return build + + @pytest.fixture def source_pyexception(): @@ -64,6 +113,21 @@ def gen_data_and_raise(): return build +@pytest.fixture +def node_exception(): + + def build(builder: mrc.Builder): + + def on_next(data): + time.sleep(1) + print("Received value: {}".format(data), flush=True) + raise RuntimeError("unittest") + + return builder.make_node("node", mrc.core.operators.map(on_next)) + + return build + + @pytest.fixture def sink(): @@ -112,6 +176,8 @@ def build_executor(): def inner(pipe: mrc.Pipeline): options = mrc.Options() + options.topology.user_cpuset = f"0-{os.cpu_count() - 1}" + options.engine_factories.default_engine_type = mrc.core.options.EngineType.Thread executor = mrc.Executor(options) executor.register_pipeline(pipe) @@ -183,5 +249,35 @@ async def run_pipeline(): asyncio.run(run_pipeline()) +@pytest.mark.parametrize("souce_name", ["source", "endless_source", "blocking_source"]) +def test_pyexception_in_node(source: node_fn_type, + endless_source: node_fn_type, + blocking_source: node_fn_type, + node_exception: node_fn_type, + build_pipeline: build_pipeline_type, + build_executor: build_executor_type, + 
souce_name: str): + """ + Test to reproduce Morpheus issue #1838 where an exception raised in a node doesn't always shutdown the executor + when the source is intended to run indefinitely. + """ + + if souce_name == "endless_source": + source_fn = endless_source + elif souce_name == "blocking_source": + source_fn = blocking_source + else: + source_fn = source + + pipe = build_pipeline(source_fn, node_exception) + + executor: mrc.Executor = None + + executor = build_executor(pipe) + + with pytest.raises(RuntimeError): + executor.join() + + if (__name__ in ("__main__", )): test_pyexception_in_source() diff --git a/python/tests/test_gil_tls.py b/python/tests/test_gil_tls.py new file mode 100644 index 000000000..eca5a23d7 --- /dev/null +++ b/python/tests/test_gil_tls.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import threading + +import mrc +from mrc.tests.utils import RequireGilInDestructor + +TLS = threading.local() + + +def test_gil_thread_local_storage(): + """ + Test to reproduce issue #362 + No asserts needed if it doesn't segfault, then we're good + """ + + def source_gen(): + x = RequireGilInDestructor() + TLS.x = x + yield x + + def init_seg(builder: mrc.Builder): + builder.make_source("souce_gen", source_gen) + + pipe = mrc.Pipeline() + pipe.make_segment("seg1", init_seg) + + options = mrc.Options() + executor = mrc.Executor(options) + executor.register_pipeline(pipe) + executor.start() + executor.join() diff --git a/python/tests/test_json_values_cast.py b/python/tests/test_json_values_cast.py new file mode 100644 index 000000000..a65e5ba9d --- /dev/null +++ b/python/tests/test_json_values_cast.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from decimal import Decimal + +import pytest + +from mrc.tests.utils import roundtrip_cast + + +def test_docstrings(): + expected_docstring = "roundtrip_cast(v: object) -> object" + docstring = inspect.getdoc(roundtrip_cast) + assert docstring == expected_docstring + + +@pytest.mark.parametrize( + "value", + [ + 12, + 2.4, + RuntimeError("test"), + Decimal("1.2"), + "test", [1, 2, 3], { + "a": 1, "b": 2 + }, { + "a": 1, "b": RuntimeError("not serializable") + }, { + "a": 1, "b": Decimal("1.3") + } + ], + ids=["int", "float", "exception", "decimal", "str", "list", "dict", "dict_w_exception", "dict_w_decimal"]) +def test_cast_roundtrip(value: object): + result = roundtrip_cast(value) + assert result == value diff --git a/python/tests/test_node.py b/python/tests/test_node.py index a520e9c65..a59e11eef 100644 --- a/python/tests/test_node.py +++ b/python/tests/test_node.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -489,5 +489,39 @@ def on_completed(): assert on_completed_count == 1 +def test_source_with_bound_value(): + """ + This test ensures that the bound values isn't confused with a subscription object + """ + on_next_value = None + + def segment_init(seg: mrc.Builder): + + def source_gen(a): + yield a + + bound_gen = functools.partial(source_gen, a=1) + source = seg.make_source("my_src", bound_gen) + + def on_next(x: int): + nonlocal on_next_value + on_next_value = x + + sink = seg.make_sink("sink", on_next) + seg.make_edge(source, sink) + + pipeline = mrc.Pipeline() + pipeline.make_segment("my_seg", segment_init) + + options = mrc.Options() + executor = mrc.Executor(options) + executor.register_pipeline(pipeline) + + executor.start() + executor.join() + + assert on_next_value == 1 + + if (__name__ == "__main__"): test_launch_options_properties()