diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml new file mode 100644 index 00000000..7362029a --- /dev/null +++ b/.github/workflows/cd.yml @@ -0,0 +1,123 @@ +# This file is autogenerated by maturin v0.14.17 +# To update, run +# +# maturin generate-ci github +# +name: CD + +on: + push: + branches: + - main + - master + tags: + - "*" + pull_request: + workflow_dispatch: + +permissions: + contents: read + +jobs: + linux: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: -i python${{ matrix.python-version }} --release --out dist + sccache: "true" + manylinux: auto + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + windows: + runs-on: windows-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + target: [x64, x86] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + architecture: ${{ matrix.target }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: "true" + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + macos: + runs-on: macos-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + target: [x86_64, aarch64] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: "true" + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + sdist: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist + - name: Upload sdist + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + release: + name: Release + runs-on: ubuntu-latest + needs: [linux, windows, macos, sdist] + steps: + - uses: actions/download-artifact@v3 + with: + name: wheels + - name: Publish to TestPyPI + uses: messense/maturin-action@v1 + env: + MATURIN_PYPI_TOKEN: ${{ secrets.TEST_PYPI_API_TOKEN }} + MATURIN_REPOSITORY: "testpypi" + with: + command: upload + args: --skip-existing * diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e37164a9..8d7fcfc4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: motion +name: CI on: workflow_dispatch: @@ -19,6 +19,30 @@ jobs: - name: Checkout code uses: actions/checkout@v2 + - name: Fetch main branch + run: git fetch origin main + + - name: Get current version from main branch + id: main_version + run: | + MAIN_VERSION=$(git show origin/main:pyproject.toml | grep '^version =' | awk -F '"' '{print $2}') + echo "Main branch version: $MAIN_VERSION" + echo "MAIN_VERSION=$MAIN_VERSION" >> $GITHUB_ENV + + - name: Get version from current branch + id: current_version + run: | + CURRENT_VERSION=$(grep '^version =' pyproject.toml | awk -F '"' '{print $2}') + echo "Current branch version: $CURRENT_VERSION" + echo "CURRENT_VERSION=$CURRENT_VERSION" >> $GITHUB_ENV + + - name: Check if version is bumped + run: | + if [[ "$CURRENT_VERSION" == "$MAIN_VERSION" ]]; then + echo "Error: Version in current branch is not bumped from main branch" + exit 1 + fi + - name: Start Redis run: | retries=3 @@ -35,9 +59,20 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + profile: minimal + override: true + - name: Install dependencies run: poetry install + - name: Build and install + run: | + make build + - name: Run pytest run: make tests diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index d749adff..00000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,67 +0,0 @@ -name: publish - -on: - push: - branches: - - main - paths: - - "motion/**" - - "poetry.lock" - - "pyproject.toml" - -jobs: - publish-to-pypi: - runs-on: ubuntu-latest - if: ${{github.event.head_commit.author.name != 'github-actions[bot]' }} - steps: - - name: Print author - run: | - echo "Commit author name: ${{ github.event.head_commit.author.name }}" - echo "Commit author email: ${{ github.event.head_commit.author.email }}" - - name: Checkout code - uses: actions/checkout@v2 - with: - persist-credentials: false # otherwise, the token used is the GITHUB_TOKEN, instead of your personal token - fetch-depth: 0 # otherwise, you will failed to push refs to dest repo - - name: Configure Git - run: | - git config --global user.name "github-actions[bot]" - git config --global user.email "github-actions[bot]@users.noreply.github.com" - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - name: Install Poetry - uses: snok/install-poetry@v1 - - name: Bump version - id: bump_version - run: | - poetry version patch - - name: Build and publish - run: | - poetry config pypi-token.pypi ${{ secrets.PYPI_API_TOKEN }} - poetry build - poetry publish - git add pyproject.toml - git commit -m "Bump up version" - - name: Push changes - uses: ad-m/github-push-action@master - with: - github_token: ${{ secrets.BRANCH_PROTECTION_WORKAROUND }} - branch: main - - - name: Set release number - run: echo "RELEASE_NUMBER=$(poetry version --no-ansi | awk -F' ' '{print $2}')" >> $GITHUB_ENV - - name: Create release - id: create_release - uses: actions/create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.AUTORELEASE_TOKEN }} - RELEASE_NUMBER: ${{ env.RELEASE_NUMBER }} - with: - tag_name: v${{ env.RELEASE_NUMBER }} - release_name: Release ${{ env.RELEASE_NUMBER }} - body: | - An autorelease from the latest version of main. - draft: false - prerelease: true diff --git a/.gitignore b/.gitignore index d553f8fc..d8dd66fd 100644 --- a/.gitignore +++ b/.gitignore @@ -16,9 +16,11 @@ dist* site* *package-lock.json projects -unnecessary.py -motionstate* -*.whl *.so +*rustc* +*motionenv* +*.whl +*unnecessary* target* -.motionenv* \ No newline at end of file +unnecessary.py +motionstate* diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 00000000..b27d878c --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,487 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bytes" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "combine" +version = "4.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +dependencies = [ + "bytes", + "memchr", +] + +[[package]] +name = "form_urlencoded" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "idna" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "indoc" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + +[[package]] +name = "libc" +version = "0.2.147" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" + +[[package]] +name = "lock_api" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "memchr" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "motion" +version = "0.2.0" +dependencies = [ + "bincode", + "pyo3", + "redis", + "redlock", + "serde", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "parking_lot" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +dependencies = [ + "cfg-if", + "instant", + "libc", + "redox_syscall", + "smallvec", + "winapi", +] + +[[package]] +name = "percent-encoding" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "redis" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffd6543a7bc6428396845f6854ccf3d1ae8823816592e2cbe74f20f50f209d02" +dependencies = [ + "combine", + "itoa", + "percent-encoding", + "ryu", + "sha1_smol", + "socket2", + "url", +] + +[[package]] +name = "redlock" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce310b27b5923ad1cd21b0e007fba6b6a9926f773099d867fb0477cc188b8aa2" +dependencies = [ + "rand", + "redis", +] + +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.39", +] + +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + +[[package]] +name = "smallvec" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" + +[[package]] +name = "socket2" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" + +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "unicode-bidi" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" + +[[package]] +name = "unicode-ident" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" + +[[package]] +name = "unicode-normalization" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unindent" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" + +[[package]] +name = "url" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..e82be5f4 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "motion" +version = "0.2.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +name = "motion" +crate-type = ["cdylib"] + +[dependencies] +redis = "0.23" +pyo3 = { version = "0.19", features = ["extension-module"] } +redlock = "2.0.0" +serde = {version="1.0.183", features = ["derive"]} +bincode = "1.3.3" diff --git a/Makefile b/Makefile index 5588a9b3..bda50399 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: tests lint install mypy update docs +.PHONY: tests lint install mypy update docs build tests: poetry run pytest @@ -17,4 +17,7 @@ update: poetry update docs: - poetry run mkdocs serve \ No newline at end of file + poetry run mkdocs serve + +build: + poetry run maturin develop --release diff --git a/docs/api/component.md b/docs/api/component.md index 064d7ecf..1dcbc46a 100644 --- a/docs/api/component.md +++ b/docs/api/component.md @@ -7,8 +7,6 @@ - update - init_state - __call__ - - save_state - - load_state - name - params show_root_full_path: false diff --git a/docs/api/props-and-state.md b/docs/api/props-and-state.md index d56aa99a..89a2859b 100644 --- a/docs/api/props-and-state.md +++ b/docs/api/props-and-state.md @@ -9,11 +9,9 @@ show_source: false show_signature_annotations: true -::: motion.dicts.State +::: motion.state.State handler: python options: - members: - - instance_id show_root_full_path: false show_root_toc_entry: false show_root_heading: true diff --git a/motion/__init__.py b/motion/__init__.py index acef994f..35e9348d 100644 --- a/motion/__init__.py +++ b/motion/__init__.py @@ -11,6 +11,7 @@ from motion.dicts import MDataFrame from motion.copy_utils import copy_db from motion.server.application import Application +from .motion import TempValue from motion.mtable import MTable __all__ = [ @@ -25,5 +26,6 @@ "copy_db", "RedisParams", "Application", + "TempValue", "MTable", ] diff --git a/motion/component.py b/motion/component.py index fc032196..77559743 100644 --- a/motion/component.py +++ b/motion/component.py @@ -147,8 +147,8 @@ def __init__( self._serve_routes: Dict[str, Route] = {} self._update_routes: Dict[str, List[Route]] = {} self._init_state_func: Optional[Callable] = None - self._save_state_func: Optional[Callable] = None - self._load_state_func: Optional[Callable] = None + # self._save_state_func: Optional[Callable] = None + # self._load_state_func: Optional[Callable] = None @property def cache_ttl(self) -> int: @@ -245,59 +245,59 @@ def setUp(): self._init_state_func = func return func - def save_state(self, func: Callable) -> Callable: - """Decorator for the save_state function. This function - saves the state of the component to be accessible in - future component instances of the same name. + # def save_state(self, func: Callable) -> Callable: + # """Decorator for the save_state function. This function + # saves the state of the component to be accessible in + # future component instances of the same name. - Usage: - ```python - from motion import Component + # Usage: + # ```python + # from motion import Component - MyComponent = Component("MyComponent") + # MyComponent = Component("MyComponent") - @c.save_state - def save(state): - # state might have other unpicklable keys, like a DB connection - return {"fit_count": state["fit_count"]} - ``` + # @c.save_state + # def save(state): + # # state might have other unpicklable keys, like a DB connection + # return {"fit_count": state["fit_count"]} + # ``` - Args: - func (Callable): Function that returns a cloudpickleable object. + # Args: + # func (Callable): Function that returns a cloudpickleable object. - Returns: - Callable: Decorated save_state function. - """ - self._save_state_func = func - return func + # Returns: + # Callable: Decorated save_state function. + # """ + # self._save_state_func = func + # return func - def load_state(self, func: Callable) -> Callable: - """Decorator for the load_state function. This function - loads the state of the component from the unpickled state. + # def load_state(self, func: Callable) -> Callable: + # """Decorator for the load_state function. This function + # loads the state of the component from the unpickled state. - Usage: - ```python - from motion import Component + # Usage: + # ```python + # from motion import Component - MyComponent = Component("MyComponent") + # MyComponent = Component("MyComponent") - @c.load_state - def load(state): - conn = sqlite3.connect(":memory:") - cursor = conn.cursor() - return {"cursor": cursor, "fit_count": state["fit_count"]} - ``` + # @c.load_state + # def load(state): + # conn = sqlite3.connect(":memory:") + # cursor = conn.cursor() + # return {"cursor": cursor, "fit_count": state["fit_count"]} + # ``` - Args: - func (Callable): Function that consumes a cloudpickleable object. - Should return a dictionary representing the state of the - component instance. + # Args: + # func (Callable): Function that consumes a cloudpickleable object. + # Should return a dictionary representing the state of the + # component instance. - Returns: - Callable: Decorated load_state function. - """ - self._load_state_func = func - return func + # Returns: + # Callable: Decorated load_state function. + # """ + # self._load_state_func = func + # return func def serve(self, keys: Union[str, List[str]]) -> Callable: """Decorator for any serve operation for a dataflow through the @@ -549,8 +549,8 @@ def ... instance_id=instance_id, init_state_func=self._init_state_func, init_state_params=init_state_params, - save_state_func=self._save_state_func, - load_state_func=self._load_state_func, + # save_state_func=self._save_state_func, + # load_state_func=self._load_state_func, serve_routes=self._serve_routes, update_routes=self._update_routes, logging_level=logging_level, diff --git a/motion/copy_utils.py b/motion/copy_utils.py index 11bac16f..1fff6264 100644 --- a/motion/copy_utils.py +++ b/motion/copy_utils.py @@ -3,8 +3,7 @@ from one Redis instance to another. """ -import logging - +import picologging as logging import redis.asyncio as redis from motion.utils import RedisParams diff --git a/motion/dicts.py b/motion/dicts.py index 6ad74f30..eab12cf7 100644 --- a/motion/dicts.py +++ b/motion/dicts.py @@ -4,6 +4,8 @@ """ from typing import Any, Optional +# from motionstate import StateAccessor + import pandas as pd import pyarrow as pa @@ -87,73 +89,155 @@ def serve_result(self) -> Any: """ return self._serve_result - # def __getattr__(self, key: str) -> object: - # return self.__getitem__(key) - - # def __setattr__(self, key: str, value: Any) -> None: - # self[key] = value - - # def __getstate__(self) -> dict: - # return dict(self) - -class State(dict): - """Dictionary that stores state for a component instance. - The instance id is stored in the `instance_id` attribute. - - Example usage: - - ```python - from motion import Component - - some_component = Component("SomeComponent") - - @some_component.init_state - def setUp(): - return {"model": ...} - - @some_component.serve("retrieve") - def retrieve_nn(state, props): - # model can be accessed via state["model"] - prediction = state["model"](props["image_embedding"]) - # match the prediction to some other data to do a retrieval - nn_component_instance = SomeOtherMotionComponent(state.instance_id) - return nn_component_instance.run("retrieve", props={"prediction": prediction}) - - if __name__ == "__main__": - c = some_component() - nearest_neighbors = c.run("retrieve", props={"image_embedding": ...}) - ``` - """ - - def __init__( - self, - component_name: str, - instance_id: str, - *args: Any, - **kwargs: Any, - ) -> None: - self.component_name = component_name - self._instance_id = instance_id - super().__init__(*args, **kwargs) - - @property - def instance_id(self) -> str: - """ - Returns the instance id of the component. - Useful if wanting to create other component instances - within a serve or update operation. - """ - return self._instance_id - - def __getitem__(self, key: str) -> object: - try: - return super().__getitem__(key) - except KeyError: - raise KeyError( - f"Key `{key}` not found in state for " - + f"instance {self.component_name}__{self._instance_id}." - ) +# STATE_ERROR_MSG = ( +# "Cannot edit state directly. Use component update operations instead." +# ) + + +# class State: +# """Python class that stores state for a component instance. +# The instance id is stored in the `instance_id` attribute. + +# Example usage: + +# ```python +# from motion import Component + +# some_component = Component("SomeComponent") + +# @some_component.init_state +# def setUp(): +# return {"model": ...} + +# @some_component.serve("retrieve") +# def retrieve_nn(state, props): +# # model can be accessed via state["model"] +# prediction = state["model"](props["image_embedding"]) +# # match the prediction to some other data to do a retrieval +# nn_component_instance = SomeOtherMotionComponent(state.instance_id) +# return nn_component_instance.run("retrieve", props={"prediction": prediction}) + +# if __name__ == "__main__": +# c = some_component() +# nearest_neighbors = c.run("retrieve", props={"image_embedding": ...}) +# ``` +# """ + +# def __init__( +# self, +# component_name: str, +# instance_id: str, +# redis_host: str, +# redis_port: int, +# redis_db: int = 0, +# redis_password: Optional[str] = None, +# *args: Any, +# **kwargs: Any, +# ) -> None: +# self.component_name = component_name +# self._instance_id = instance_id +# self._state_accessor = StateAccessor( +# component_name, +# instance_id, +# 1000 +# * 60 +# * 2, # 2 minutes lock duration TODO: make this configurable +# redis_host, +# redis_port, +# redis_db, +# redis_password, +# ) +# super().__init__(*args, **kwargs) + +# def get_version(self) -> int: +# return self._state_accessor.version + +# @property +# def instance_id(self) -> str: +# """ +# Returns the instance id of the component. +# Useful if wanting to create other component instances +# within a serve or update operation. +# """ +# return self._instance_id + +# def clear_cache(self) -> None: +# # Clear the cache +# self._state_accessor.clear_cache() + +# def __getitem__(self, key: str) -> object: +# try: +# # Get from state accessor +# return self._state_accessor.get(key) +# except KeyError: +# raise KeyError( +# f"Key `{key}` not found in state for " +# + f"instance {self.component_name}__{self._instance_id}." +# ) + +# def get(self, key: str, default: Optional[Any] = None) -> Any: +# try: +# return self[key] +# except KeyError: +# return default + +# def __setitem__(self, key: str, value: Any) -> None: +# # Disable this functionality +# raise RuntimeError(STATE_ERROR_MSG) + +# def flushUpdateDict( +# self, update_dict: dict, from_migration: bool = False +# ) -> None: +# self._state_accessor.bulk_set(update_dict, from_migration) + +# def __delitem__(self, key: str) -> None: +# raise RuntimeError(STATE_ERROR_MSG) + +# def update(self, *args: Any, **kwargs: Any) -> None: +# raise RuntimeError(STATE_ERROR_MSG) + +# def clear(self) -> None: +# raise RuntimeError(STATE_ERROR_MSG) + +# def pop(self, *args: Any, **kwargs: Any) -> None: +# raise RuntimeError(STATE_ERROR_MSG) + +# def popitem(self, *args: Any, **kwargs: Any) -> None: +# raise RuntimeError(STATE_ERROR_MSG) + +# def keys(self) -> List[str]: +# return self._state_accessor.keys() + +# def values(self) -> List[Any]: +# """Values in the state dictionary. + +# Note: This fetches all the values from the state. We +# do not recommend using this method as it can be slow. +# Consider accessing values directly via `state[key]`. + +# Returns: +# List[Any]: List of values in the state. +# """ + +# return self._state_accessor.values() + +# def items(self) -> List[Tuple[str, Any]]: +# """Items in the state dictionary. + +# Note: This fetches all the key-value pairs from the state. +# We do not recommend using this method as it can be slow. +# If you need to iterate over the state, conditionally accessing +# values, we recommend using the `keys` method instead and then +# calling `state[key]` to access the value. + +# Returns: +# List[Tuple[str, Any]]: List of key-value pairs in the state. +# """ +# return self._state_accessor.items() + +# def __iter__(self) -> Iterator[str]: +# return iter(self._state_accessor.keys()) class MDataFrame(pd.DataFrame): diff --git a/motion/execute.py b/motion/execute.py index bfafccc0..02426d56 100644 --- a/motion/execute.py +++ b/motion/execute.py @@ -1,5 +1,4 @@ import asyncio -import logging import multiprocessing import threading from concurrent.futures import ThreadPoolExecutor @@ -7,20 +6,20 @@ from uuid import uuid4 import cloudpickle +import picologging as logging import psutil import redis -from motion.dicts import Properties, State +from motion.dicts import Properties from motion.route import Route from motion.server.update_task import UpdateProcess, UpdateThread +from motion.state import State from motion.utils import ( RedisParams, UpdateEvent, UpdateEventGroup, get_redis_params, hash_object, - loadState, - saveState, ) logger = logging.getLogger(__name__) @@ -33,8 +32,8 @@ def __init__( cache_ttl: int, init_state_func: Optional[Callable], init_state_params: Dict[str, Any], - save_state_func: Optional[Callable], - load_state_func: Optional[Callable], + # save_state_func: Optional[Callable], + # load_state_func: Optional[Callable], serve_routes: Dict[str, Route], update_routes: Dict[str, List[Route]], update_task_type: Literal["thread", "process"] = "thread", @@ -46,8 +45,6 @@ def __init__( self._init_state_func = init_state_func self._init_state_params = init_state_params - self._load_state_func = load_state_func - self._save_state_func = save_state_func self.running: Any = multiprocessing.Value("b", False) self._redis_socket_timeout = redis_socket_timeout @@ -65,26 +62,16 @@ def __init__( self.running.value = True # Set up state - self.version = self._redis_con.get(f"MOTION_VERSION:{self._instance_name}") self._state = State( - instance_name.split("__")[0], - instance_name.split("__")[1], - {}, + self._instance_name.split("__")[0], + self._instance_name.split("__")[1], + self._redis_params.dict(), ) - if self.version is None: - self.version = 1 - # Setup state - self._state.update(self.setUp(**self._init_state_params)) - saveState( - self._state, - self._redis_con, - self._instance_name, - self._save_state_func, - ) - else: - # Load state - self.version = -1 # will get updated in _loadState - self._loadState() + + # If it is version 0, then call the init_state_func + if self._state.get_version() == 0 and self._init_state_func is not None: + update_dict = self._init_state_func(**self._init_state_params) + self._state.flushUpdateDict(update_dict) # Set up routes self._serve_routes: Dict[str, Route] = serve_routes @@ -121,19 +108,11 @@ def _connectToRedis(self) -> Tuple[RedisParams, redis.Redis]: r = redis.Redis(**param_dict) return rp, r - def _loadState(self) -> None: - redis_v = self._redis_con.get(f"MOTION_VERSION:{self._instance_name}") - if not redis_v: - raise ValueError( - f"Error loading state for {self._instance_name}." + " No version found." - ) - - if self.version and self.version < int(redis_v): - # Reload state - self._state = loadState( - self._redis_con, self._instance_name, self._load_state_func - ) - self.version = int(redis_v) + def _loadState(self, force_refresh: bool = True) -> State: + # Clear state cache + if force_refresh: + self._state.clear_cache() + return self._state def setUp(self, **kwargs: Any) -> Dict[str, Any]: # Set up initial state @@ -170,8 +149,6 @@ def _build_fit_jobs(self) -> None: self.worker_task = update_cls( instance_name=self._instance_name, routes=self.route_dict_for_fit, - save_state_func=self._save_state_func, - load_state_func=self._load_state_func, queue_identifiers=self.queue_ids_for_fit, channel_identifiers=self.channel_dict_for_fit, redis_params=self._redis_params.dict(), @@ -206,8 +183,6 @@ def _monitor_process(self) -> None: self.worker_task = update_cls( instance_name=self._instance_name, routes=self.route_dict_for_fit, - save_state_func=self._save_state_func, - load_state_func=self._load_state_func, queue_identifiers=self.queue_ids_for_fit, channel_identifiers=self.channel_dict_for_fit, redis_params=self._redis_params.dict(), @@ -264,57 +239,14 @@ def shutdown(self, is_open: bool) -> None: self.monitor_thread.join() - def _updateState( - self, - new_state: Dict[str, Any], - force_update: bool = True, - use_lock: bool = True, - ) -> None: + def _updateState(self, new_state: Dict[str, Any]) -> None: if not new_state: return if not isinstance(new_state, dict): raise TypeError("State should be a dict.") - # Get latest state - if use_lock: - with self._redis_con.lock( - f"MOTION_LOCK:{self._instance_name}", timeout=120 - ): - if force_update: - self._loadState() - self._state.update(new_state) - - # Save state to redis - saveState( - self._state, - self._redis_con, - self._instance_name, - self._save_state_func, - ) - - version = self._redis_con.get(f"MOTION_VERSION:{self._instance_name}") - if version is None: - raise ValueError("Version not found in Redis.") - self.version = int(version) - - else: - if force_update: - self._loadState() - self._state.update(new_state) - - # Save state to redis - saveState( - self._state, - self._redis_con, - self._instance_name, - self._save_state_func, - ) - - version = self._redis_con.get(f"MOTION_VERSION:{self._instance_name}") - if version is None: - raise ValueError("Version not found in Redis.") - self.version = int(version) + self._state.flushUpdateDict(new_state) def _enqueue_and_trigger_update( self, @@ -334,32 +266,21 @@ def _enqueue_and_trigger_update( if flush_update: route = self._update_routes[key][update_udf_name] - # Hold lock - - with self._redis_con.lock( - f"MOTION_LOCK:{self._instance_name}", timeout=120 - ): - try: - self._loadState() - - state_update = route.run( - state=self._state, - props=props, - ) - - if not isinstance(state_update, dict): - raise ValueError("State update must be a dict.") - else: - # Update state - self._updateState( - state_update, - force_update=False, - use_lock=False, - ) - except Exception as e: - raise RuntimeError( - "Error running update route in main process: " + str(e) - ) + try: + state_update = route.run( + state=self._state, + props=props, + ) + + if not isinstance(state_update, dict): + raise ValueError("State update must be a dict.") + else: + # Update state + self._updateState(state_update) + except Exception as e: + raise RuntimeError( + "Error running update route in main process: " + str(e) + ) else: # Enqueue update @@ -406,33 +327,24 @@ async def _async_enqueue_and_trigger_update( if flush_update: route = self._update_routes[key][update_udf_name] - with self._redis_con.lock( - f"MOTION_LOCK:{self._instance_name}", timeout=120 - ): - try: - self._loadState() - - state_update = route.run( - state=self._state, - props=props, - ) - - if asyncio.iscoroutine(state_update): - state_update = await state_update - - if not isinstance(state_update, dict): - raise ValueError("State update must be a dict.") - else: - # Update state - self._updateState( - state_update, - force_update=False, - use_lock=False, - ) - except Exception as e: - raise RuntimeError( - "Error running update route in main process: " + str(e) - ) + try: + state_update = route.run( + state=self._state, + props=props, + ) + + if asyncio.iscoroutine(state_update): + state_update = await state_update + + if not isinstance(state_update, dict): + raise ValueError("State update must be a dict.") + else: + # Update state + self._updateState(state_update) + except Exception as e: + raise RuntimeError( + "Error running update route in main process: " + str(e) + ) else: # Enqueue update @@ -511,6 +423,13 @@ def run( # Run the serve route if key in self._serve_routes.keys(): + # Check if the function is an async function + if asyncio.iscoroutinefunction(self._serve_routes[key].udf): + raise TypeError( + f"Route {key} is an async function. " + + "Call `await instance.arun(...)` instead." + ) + route_hit = True ( route_run, @@ -535,7 +454,7 @@ def run( ) # Cache result - if value_hash: + if value_hash and self._cache_ttl > 0: cache_result_key = ( f"MOTION_RESULT:{self._instance_name}/{key}/{value_hash}" ) @@ -588,7 +507,7 @@ async def arun( props._serve_result = serve_result # Cache result - if value_hash: + if value_hash and self._cache_ttl > 0: cache_result_key = ( f"MOTION_RESULT:{self._instance_name}/{key}/{value_hash}" ) diff --git a/motion/instance.py b/motion/instance.py index dbd84fec..f495b3f0 100644 --- a/motion/instance.py +++ b/motion/instance.py @@ -1,7 +1,8 @@ import atexit -import logging from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional +import picologging as logging + from motion.execute import Executor from motion.route import Route from motion.utils import DEFAULT_KEY_TTL, configureLogging @@ -27,8 +28,8 @@ def __init__( instance_id: str, init_state_func: Optional[Callable], init_state_params: Optional[Dict[str, Any]], - save_state_func: Optional[Callable], - load_state_func: Optional[Callable], + # save_state_func: Optional[Callable], + # load_state_func: Optional[Callable], serve_routes: Dict[str, Route], update_routes: Dict[str, List[Route]], logging_level: str = "WARNING", @@ -55,7 +56,7 @@ def __init__( configureLogging(logging_level) # self._serverless = serverless # indicator = "serverless" if serverless else "local" - logger.info(f"Creating local instance of {self._component_name}...") + logger.debug(f"Creating local instance of {self._component_name}...") atexit.register(self.shutdown) # Create instance name @@ -69,8 +70,8 @@ def __init__( cache_ttl=self._cache_ttl, init_state_func=init_state_func, init_state_params=init_state_params if init_state_params else {}, - save_state_func=save_state_func, - load_state_func=load_state_func, + # save_state_func=save_state_func, + # load_state_func=load_state_func, serve_routes=serve_routes, update_routes=update_routes, update_task_type=update_task_type, @@ -162,9 +163,9 @@ def setUp(): c_instance.get_version() # Returns 1 (first version) ``` """ - return self._executor.version # type: ignore + return self._executor._state.get_version() # type: ignore - def write_state(self, state_update: Dict[str, Any], latest: bool = False) -> None: + def write_state(self, state_update: Dict[str, Any]) -> None: """Writes the state update to the component instance's state. If a update op is currently running, the state update will be applied after the update op is finished. Warning: this could @@ -194,17 +195,12 @@ def setUp(): Args: state_update (Dict[str, Any]): Dictionary of key-value pairs to update the state with. - latest (bool, optional): Whether or not to apply the update - to the latest version of the state. - If true, Motion will redownload the latest version - of the state and apply the update to that version. You - only need to set this to true if you are updating an - instance you connected to a while ago and might be - outdated. Defaults to False. """ - self._executor._updateState(state_update, force_update=latest) + self._executor._updateState(state_update) - def read_state(self, key: str, default_value: Optional[Any] = None) -> Any: + def read_state( + self, key: str, default_value: Optional[Any] = None, force_refresh: bool = True + ) -> Any: """Gets the current value for the key in the component instance's state. Usage: @@ -232,13 +228,16 @@ def setUp(): key (str): Key in the state to get the value for. default_value (Optional[Any], optional): Default value to return if the key is not found. Defaults to None. + force_refresh (bool, optional): Read the latest value of the state + in the KV store, otherwise return what is in the cache. Returns: Any: Current value for the key, or default_value if the key is not found. """ - self._executor._loadState() - return self._executor._state.get(key, default_value) + + return self._executor._loadState(force_refresh).get(key, default_value) + def flush_update(self, dataflow_key: str) -> None: """Flushes the update queue corresponding to the dataflow diff --git a/motion/migrate.py b/motion/migrate.py index a4d7ebba..1aae7fb1 100644 --- a/motion/migrate.py +++ b/motion/migrate.py @@ -1,15 +1,15 @@ import inspect -import logging from multiprocessing import Pool from typing import Callable, List, Optional, Tuple +import picologging as logging import redis from pydantic import BaseConfig, BaseModel, Field from tqdm import tqdm from motion.component import Component -from motion.dicts import State -from motion.utils import get_redis_params, loadState, saveState +from motion.state import State +from motion.utils import get_redis_params logger = logging.getLogger(__name__) @@ -17,31 +17,37 @@ def process_migration( instance_name: str, migrate_func: Callable, - load_state_fn: Callable, - save_state_fn: Callable, + timeout: int = 60, # 60-second timeout for lock ) -> Tuple[str, Optional[Exception]]: try: rp = get_redis_params() redis_con = redis.Redis( **rp.dict(), ) - state = loadState(redis_con, instance_name, load_state_fn) - new_state = migrate_func(state) - assert isinstance(new_state, dict), ( - "Migration function must return a dict." - + " Warning: partial progress may have been made!" - ) - empty_state = State( + + state = State( instance_name.split("__")[0], instance_name.split("__")[1], - {}, + redis_params=rp.dict(), ) - empty_state.update(new_state) - saveState(empty_state, redis_con, instance_name, save_state_fn) + + # Acquire lock to prevent other writes during migration + with redis_con.lock(f"MOTION_LOCK:{instance_name}", timeout=timeout): + # state = loadState(redis_con, instance_name, load_state_fn) + state_updates = migrate_func(state) + + assert isinstance(state_updates, dict), ( + "Migration function must return a dict of updates." + + " Warning: partial progress may have been made!" + ) + state.flushUpdateDict(state_updates, from_migration=True) + # saveState(empty_state, redis_con, instance_name, save_state_fn) except Exception as e: if isinstance(e, AssertionError): raise e else: + # Log an error and continue + logger.error(f"Error migrating {instance_name}: {e}", exc_info=True) return instance_name, e redis_con.close() @@ -85,7 +91,10 @@ def __init__(self, component: Component, migrate_func: Callable) -> None: self.migrate_func = migrate_func def migrate( - self, instance_ids: List[str] = [], num_workers: int = 4 + self, + instance_ids: List[str] = [], + num_workers: int = 4, + lock_timeout: int = 60, ) -> List[MigrationResult]: """Performs the migrate_func for component instances' states. If instance_ids is empty, then migrate_func is performed for all @@ -98,6 +107,10 @@ def migrate( num_workers (int, optional): Number of workers to use for parallel processing the migration. Defaults to 4. + lock_timeout (int, optional): + Number of seconds to lock the state object during its migration + operation to prevent any other writes from other processes. + Defaults to 60. Returns: List[MigrationResult]: @@ -117,8 +130,8 @@ def migrate( ] if not instance_names: instance_names = [ - key.decode("utf-8").replace("MOTION_STATE:", "") # type: ignore - for key in redis_con.keys(f"MOTION_STATE:{self.component.name}__*") + key.decode("utf-8").replace("MOTION_VERSION:", "") # type: ignore + for key in redis_con.keys(f"MOTION_VERSION:{self.component.name}__*") ] if not instance_names: @@ -131,8 +144,7 @@ def migrate( ( instance_name, self.migrate_func, - self.component._load_state_func, - self.component._save_state_func, + lock_timeout, ) for instance_name in instance_names ] diff --git a/motion/motion.pyi b/motion/motion.pyi new file mode 100644 index 00000000..c7f99a69 --- /dev/null +++ b/motion/motion.pyi @@ -0,0 +1,41 @@ +from typing import Any, Dict, List, Optional, Tuple +from pyo3 import PyObject + +class StateAccessor: + component_name: str + instance_id: str + lock_duration: int + client: Any # Figure out redis type + cache: Dict[str, PyObject] + lock_manager: Any # Figure out redis type + max_lock_attempts: int + + def __init__( + self, + component_name: str, + instance_id: str, + lock_duration: int, + redis_host: str, + redis_port: int, + redis_db: int, + redis_password: Optional[str] = None, + redis_ssl: Optional[bool] = None, + ) -> None: ... + @property + def version(self) -> int: ... + def set(self, key: str, value: Any) -> None: ... + def bulk_set(self, items: Dict[str, Any], from_migration: bool = False) -> None: ... + def get(self, key: str) -> Any: ... + def items(self) -> List[Tuple[str, Any]]: ... + def keys(self) -> List[str]: ... + def values(self) -> List[Any]: ... + def clear_cache(self) -> None: ... + +class TempValue: + def __init__(self, value: PyObject, ttl: int) -> None: ... + @property + def value(self) -> PyObject: ... + @property + def ttl(self) -> int: ... + @ttl.setter + def ttl(self, new_ttl: int) -> None: ... diff --git a/motion/route.py b/motion/route.py index cbde151e..245e7ed7 100644 --- a/motion/route.py +++ b/motion/route.py @@ -1,7 +1,7 @@ import inspect -import logging from typing import Any, Callable, Dict +import picologging as logging from pydantic import BaseModel, Field, PrivateAttr logger = logging.getLogger(__name__) diff --git a/motion/server/update_task.py b/motion/server/update_task.py index 20733bbb..02b27a87 100644 --- a/motion/server/update_task.py +++ b/motion/server/update_task.py @@ -2,13 +2,14 @@ import traceback from multiprocessing import Process from threading import Thread -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional import cloudpickle import redis from motion.route import Route -from motion.utils import loadState, logger, saveState +from motion.state import State +from motion.utils import logger class BaseUpdateTask: @@ -17,8 +18,6 @@ def __init__( task_type: str, instance_name: str, routes: Dict[str, Route], - save_state_func: Optional[Callable], - load_state_func: Optional[Callable], queue_identifiers: List[str], channel_identifiers: Dict[str, str], redis_params: Dict[str, Any], @@ -28,8 +27,6 @@ def __init__( self.task_type = task_type self.name = f"UpdateTask-{task_type}-{instance_name}" self.instance_name = instance_name - self.save_state_func = save_state_func - self.load_state_func = load_state_func self.routes = routes self.queue_identifiers = queue_identifiers @@ -94,32 +91,40 @@ def custom_run(self) -> None: # Run update op try: - with redis_con.lock(f"MOTION_LOCK:{self.instance_name}", timeout=120): - old_state = loadState( - redis_con, - self.instance_name, - self.load_state_func, - ) - state_update = self.routes[queue_name].run( - state=old_state, - props=item["props"], + # with redis_con.lock(f"MOTION_LOCK:{self.instance_name}", timeout=120): + # old_state = loadState( + # redis_con, + # self.instance_name, + # self.load_state_func, + # ) + + old_state = State( + self.instance_name.split("__")[0], + self.instance_name.split("__")[1], + redis_params=self.redis_params, + ) + + state_update = self.routes[queue_name].run( + state=old_state, + props=item["props"], + ) + # Await if state_update is a coroutine + if asyncio.iscoroutine(state_update): + state_update = asyncio.run(state_update) + + if not isinstance(state_update, dict): + logger.error( + "Update methods should return a dict of state updates." ) - # Await if state_update is a coroutine - if asyncio.iscoroutine(state_update): - state_update = asyncio.run(state_update) - - if not isinstance(state_update, dict): - logger.error( - "Update methods should return a dict of state updates." - ) - else: - old_state.update(state_update) - saveState( - old_state, - redis_con, - self.instance_name, - self.save_state_func, - ) + else: + old_state.flushUpdateDict(state_update) + # old_state.update(state_update) + # saveState( + # old_state, + # redis_con, + # self.instance_name, + # self.save_state_func, + # ) except Exception: logger.error(traceback.format_exc()) exception_str = str(traceback.format_exc()) diff --git a/motion/state/__init__.py b/motion/state/__init__.py new file mode 100644 index 00000000..163c1e51 --- /dev/null +++ b/motion/state/__init__.py @@ -0,0 +1,3 @@ +from motion.state.state import State + +__all__ = ["State"] diff --git a/motion/state/state.py b/motion/state/state.py new file mode 100644 index 00000000..77ddccea --- /dev/null +++ b/motion/state/state.py @@ -0,0 +1,147 @@ +""" +This file contains the state class. +""" +from typing import Any, Dict, Iterator, List, Optional, Tuple + +from ..motion import StateAccessor # type: ignore + +STATE_ERROR_MSG = "Cannot edit state directly. Use component update operations instead." + + +class State: + """Python class that stores state for a component instance. + The instance id is stored in the `instance_id` attribute. + + Example usage: + + ```python + from motion import Component + + some_component = Component("SomeComponent") + + @some_component.init_state + def setUp(): + return {"model": ...} + + @some_component.serve("retrieve") + def retrieve_nn(state, props): + # model can be accessed via state["model"] + prediction = state["model"](props["image_embedding"]) + # match the prediction to some other data to do a retrieval + nn_component_instance = SomeOtherMotionComponent(state.instance_id) + return nn_component_instance.run("retrieve", props={"prediction": prediction}) + + if __name__ == "__main__": + c = some_component() + nearest_neighbors = c.run("retrieve", props={"image_embedding": ...}) + ``` + """ + + def __init__( + self, + component_name: str, + instance_id: str, + redis_params: Dict[str, Any], + *args: Any, + **kwargs: Any, + ) -> None: + self.component_name = component_name + self._instance_id = instance_id + self._state_accessor = StateAccessor( + component_name, + instance_id, + 1000 * 60 * 2, # 2 minutes lock duration TODO: make this configurable + redis_params["host"], + redis_params["port"], + redis_params["db"], + redis_params["password"], + redis_params.get("ssl", False), + ) + super().__init__(*args, **kwargs) + + def get_version(self) -> int: + return self._state_accessor.version # type: ignore + + @property + def instance_id(self) -> str: + """ + Returns the instance id of the component. + Useful if wanting to create other component instances + within a serve or update operation. + """ + return self._instance_id # type: ignore + + def clear_cache(self) -> None: + # Clear the cache + self._state_accessor.clear_cache() + + def __getitem__(self, key: str) -> Any: + try: + # Get from state accessor + return self._state_accessor.get(key) + except KeyError: + raise KeyError( + f"Key `{key}` not found in state for " + + f"instance {self.component_name}__{self._instance_id}." + ) + + def get(self, key: str, default: Optional[Any] = None) -> Any: + try: + return self[key] + except KeyError: + return default + + def __setitem__(self, key: str, value: Any) -> None: + # Disable this functionality + raise RuntimeError(STATE_ERROR_MSG) + + def flushUpdateDict(self, update_dict: dict, from_migration: bool = False) -> None: + self._state_accessor.bulk_set(update_dict, from_migration) + + def __delitem__(self, key: str) -> None: + raise RuntimeError(STATE_ERROR_MSG) + + def update(self, *args: Any, **kwargs: Any) -> None: + raise RuntimeError(STATE_ERROR_MSG) + + def clear(self) -> None: + raise RuntimeError(STATE_ERROR_MSG) + + def pop(self, *args: Any, **kwargs: Any) -> None: + raise RuntimeError(STATE_ERROR_MSG) + + def popitem(self, *args: Any, **kwargs: Any) -> None: + raise RuntimeError(STATE_ERROR_MSG) + + def keys(self) -> List[str]: + return self._state_accessor.keys() # type: ignore + + def values(self) -> List[Any]: + """Values in the state dictionary. + + Note: This fetches all the values from the state. We + do not recommend using this method as it can be slow. + Consider accessing values directly via `state[key]`. + + Returns: + List[Any]: List of values in the state. + """ + + return self._state_accessor.values() # type: ignore + + def items(self) -> List[Tuple[str, Any]]: + """Items in the state dictionary. + + Note: This fetches all the key-value pairs from the state. + We do not recommend using this method as it can be slow. + If you need to iterate over the state, conditionally accessing + values, we recommend using the `keys` method instead and then + calling `state[key]` to access the value. + + Returns: + List[Tuple[str, Any]]: List of key-value pairs in the state. + """ + return self._state_accessor.items() # type: ignore + + def __iter__(self) -> Iterator[str]: + return iter(self._state_accessor.keys()) # type: ignore diff --git a/motion/utils.py b/motion/utils.py index dca431ae..d4aeca02 100644 --- a/motion/utils.py +++ b/motion/utils.py @@ -1,17 +1,16 @@ import hashlib -import logging import os import random from pathlib import Path -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional -import cloudpickle -import colorlog +import picologging as logging import redis import yaml from pydantic import BaseModel -from motion.dicts import CustomDict, State +from motion.dicts import CustomDict +from motion.state import State logger = logging.getLogger(__name__) @@ -160,13 +159,15 @@ def clear_instance(instance_name: str) -> bool: return False # Delete the instance state, version, and cached results - redis_con.delete(f"MOTION_STATE:{instance_name}") redis_con.delete(f"MOTION_VERSION:{instance_name}") redis_con.delete(f"MOTION_LOCK:{instance_name}") + state_vals_to_delete = redis_con.keys(f"MOTION_STATE:{instance_name}/*") results_to_delete = redis_con.keys(f"MOTION_RESULT:{instance_name}/*") queues_to_delete = redis_con.keys(f"MOTION_QUEUE:{instance_name}/*") pipeline = redis_con.pipeline() + for state_val in state_vals_to_delete: + pipeline.delete(state_val) for result in results_to_delete: pipeline.delete(result) for queue in queues_to_delete: @@ -214,10 +215,15 @@ def inspect_state(instance_name: str) -> Dict[str, Any]: raise ValueError(f"Instance {instance_name} does not exist.") # Get the state - state = loadState(redis_con, instance_name, None) + state = State( + instance_name.split("__")[0], + instance_name.split("__")[1], + redis_params=rp.dict(), + ) + # Iterate through all items + all_items = {k: v for k, v in state.items()} - redis_con.close() - return state + return all_items def validate_args(parameters: Any, op: str) -> bool: @@ -234,68 +240,69 @@ def validate_args(parameters: Any, op: str) -> bool: def configureLogging(level: str) -> None: - formatter = colorlog.ColoredFormatter( - "%(log_color)s%(asctime)s %(levelname)-8s%(reset)s %(blue)s%(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - log_colors={ - "DEBUG": "cyan", - "INFO": "green", - "WARNING": "yellow", - "ERROR": "red", - "CRITICAL": "bold_red", - }, - ) - - logger = logging.getLogger("motion") - if logger.hasHandlers(): - logger.handlers.clear() - - stream_handler = logging.StreamHandler() - stream_handler.setFormatter(formatter) - - logger.addHandler(stream_handler) - logger.setLevel(level) - - -def loadState( - redis_con: redis.Redis, - instance_name: str, - load_state_func: Optional[Callable], -) -> State: - # Get state from redis - state = State(instance_name.split("__")[0], instance_name.split("__")[1], {}) - loaded_state = redis_con.get(f"MOTION_STATE:{instance_name}") - - if not loaded_state: - # This is an error - logger.warning(f"Could not find state for {instance_name}.") - return state - - # Unpickle state - loaded_state = cloudpickle.loads(loaded_state) - - if load_state_func is not None: - state.update(load_state_func(loaded_state)) - else: - state.update(loaded_state) - - return state - - -def saveState( - state_to_save: State, - redis_con: redis.Redis, - instance_name: str, - save_state_func: Optional[Callable], -) -> None: - # Save state to redis - if save_state_func is not None: - state_to_save = save_state_func(state_to_save) - - state_pickled = cloudpickle.dumps(state_to_save) - - redis_con.set(f"MOTION_STATE:{instance_name}", state_pickled) - redis_con.incr(f"MOTION_VERSION:{instance_name}") + # formatter = colorlog.ColoredFormatter( + # "%(log_color)s%(asctime)s %(levelname)-8s%(reset)s %(blue)s%(message)s", + # datefmt="%Y-%m-%d %H:%M:%S", + # log_colors={ + # "DEBUG": "cyan", + # "INFO": "green", + # "WARNING": "yellow", + # "ERROR": "red", + # "CRITICAL": "bold_red", + # }, + # ) + + logging.basicConfig(level=level) # type: ignore + # logger = logging.getLogger("motion") + # if logger.hasHandlers(): + # logger.handlers.clear() + + # stream_handler = logging.StreamHandler() + # stream_handler.setFormatter(formatter) + + # logger.addHandler(stream_handler) + # logger.setLevel(level) + + +# def loadState( +# redis_con: redis.Redis, +# instance_name: str, +# load_state_func: Optional[Callable], +# ) -> State: +# # Get state from redis +# state = State(instance_name.split("__")[0], instance_name.split("__")[1], {}) +# loaded_state = redis_con.get(f"MOTION_STATE:{instance_name}") + +# if not loaded_state: +# # This is an error +# logger.warning(f"Could not find state for {instance_name}.") +# return state + +# # Unpickle state +# loaded_state = cloudpickle.loads(loaded_state) + +# if load_state_func is not None: +# state.update(load_state_func(loaded_state)) +# else: +# state.update(loaded_state) + +# return state + + +# def saveState( +# state_to_save: State, +# redis_con: redis.Redis, +# instance_name: str, +# save_state_func: Optional[Callable], +# ) -> None: +# # Save state to redis +# if save_state_func is not None: +# state_to_save = save_state_func(state_to_save) + +# state_pickled = cloudpickle.dumps(state_to_save) + +# redis_con.set(f"MOTION_STATE:{instance_name}", state_pickled) +# redis_con.incr(f"MOTION_VERSION:{instance_name}") class UpdateEvent: diff --git a/poetry.lock b/poetry.lock index 6af668cd..e257a064 100644 --- a/poetry.lock +++ b/poetry.lock @@ -379,24 +379,6 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -[[package]] -name = "colorlog" -version = "6.7.0" -description = "Add colours to the output of Python's logging module." -category = "main" -optional = false -python-versions = ">=3.6" -files = [ - {file = "colorlog-6.7.0-py2.py3-none-any.whl", hash = "sha256:0d33ca236784a1ba3ff9c532d4964126d8a2c44f1f0cb1d2b0728196f512f662"}, - {file = "colorlog-6.7.0.tar.gz", hash = "sha256:bd94bd21c1e13fac7bd3153f4bc3a7dc0eb0974b8bc2fdf1a989e474f6e582e5"}, -] - -[package.dependencies] -colorama = {version = "*", markers = "sys_platform == \"win32\""} - -[package.extras] -development = ["black", "flake8", "mypy", "pytest", "types-colorama"] - [[package]] name = "coverage" version = "7.3.2" @@ -1731,6 +1713,54 @@ files = [ [package.dependencies] ptyprocess = ">=0.5" +[[package]] +name = "picologging" +version = "0.9.2" +description = "A fast and lightweight logging library for Python" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "picologging-0.9.2-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:3ef4c40dd5029660d54949422eae1af4e1aa37f5fc2d551ccfc7bf31a72c4083"}, + {file = "picologging-0.9.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:d56134545a322e9bb97c95e52076dc86ccd4c9aa4ea21167f5f1c55f390de8de"}, + {file = "picologging-0.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3b59501f10cc088fe6a094ce3080ca52c538038081b2d3e632801ccd9a97e17"}, + {file = "picologging-0.9.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:37b54895e2a122d0a009b82bd586a35a5762c66045958d0fc1d8c35a623ebc15"}, + {file = "picologging-0.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1b4ed6f6d574760e7ce8a277a0478b55da791b1251a6fc44ba2161b096f23ae"}, + {file = "picologging-0.9.2-cp310-cp310-win32.whl", hash = "sha256:8abb06d75563f35f69aef1fa59753705e2e566a97f954694f7b445a1f638fef7"}, + {file = "picologging-0.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:eeae1f52c1a6aeb88256f484d6ae52e92aa4d5761e66c34a913fc98226e269f5"}, + {file = "picologging-0.9.2-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:b62fdb47ae7c1261e945e0f562d0962b8559e1fd46afc1be30f7ca7decc01eca"}, + {file = "picologging-0.9.2-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:4fe519f607d03ccb0a629eb23491a9e155b8945e08b8a153600b2853852d9eda"}, + {file = "picologging-0.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7dfff150c3cf081e4a68c2528687cb7de8114f48578eaa07fb6168e7de2f012f"}, + {file = "picologging-0.9.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bb5a85e87bce59a1874494ae2995f1fa4bcbf139e753af717a768313d6057404"}, + {file = "picologging-0.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d45472e29e40be5b1b473cfa4c320adebb088fdb7bd47f9111d21ee37b3fdd33"}, + {file = "picologging-0.9.2-cp311-cp311-win32.whl", hash = "sha256:8af05accda1babea4804c65262c7ac1e51586290f7fbff6e7fe93dca09ca1b0f"}, + {file = "picologging-0.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:b46ac219edac40d98f891e6a2064dbf10f1af328d6e6038997ebaf8b34c8de4f"}, + {file = "picologging-0.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:787f7b0ae666500d99836cd7704d58f752c5f60fd2c2f0c86ac00c09bfe18b80"}, + {file = "picologging-0.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01f35a6a3a0bbfaca8cd8a15af94d6cea1fe8a75d3295aa3779b1e048596fcc1"}, + {file = "picologging-0.9.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:79a2117cd911aaaf705ef15510ca68478666e3d9f8840f108270b7650ad6d235"}, + {file = "picologging-0.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2d6886efbd31cb0fe1b900320056f3ad6b785029c8d2a1d64b1ab0b6e59164c"}, + {file = "picologging-0.9.2-cp37-cp37m-win32.whl", hash = "sha256:39fcb164b6ce296ad6598397f2f8a21e71f84f42a73cfc08963d227f2ece17db"}, + {file = "picologging-0.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:fc7c23960a2095cb3d57a60e3dfead16cfde99adcb635efeeb05a09bd5b28b20"}, + {file = "picologging-0.9.2-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:a98e5b2516763943fab7b5a29a3a1a387a69b1a5ba7e2c145a9f4221ee6ae6b5"}, + {file = "picologging-0.9.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:872a21939523f6fb72026d6a6e09a72df21221b922c62919e129cb69565506fa"}, + {file = "picologging-0.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47c13801574aa1eb89297a870d47b59dd18f5b5ad24b30d96ae8ba438439b604"}, + {file = "picologging-0.9.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc4d9053113670a11bf2f69a8a69cca88d57eda7b4e3b1c0eb1a8cf88d49fa72"}, + {file = "picologging-0.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98787d9dc4401672d0c1f39af1095fc4f53db3b07e2e17fb74059778939ab32d"}, + {file = "picologging-0.9.2-cp38-cp38-win32.whl", hash = "sha256:e322d1f5b6b7b9f2023cd205cd2ce5749e944359d5f53928e3c3d850f910aa33"}, + {file = "picologging-0.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:a8ab594d621104d7618929663a5b03af4b2a58679b594f19ba887f4e5d79fb1c"}, + {file = "picologging-0.9.2-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:9d7c02d0fc015dabde3c096d7ce84bd7f167114abba37c7dd6b761086518b6fa"}, + {file = "picologging-0.9.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:eefcc1e5bdc0003b37973e803c80b430cd3c1a9ab53839d3e5562d890be1cf77"}, + {file = "picologging-0.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c41ee46960d581faa2fe20098cd0541701d35c04c2e89774781524e7290cd897"}, + {file = "picologging-0.9.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:37e61e6770113473e0aee841b14b6a92e3c8e1982a038f141e5dcd88ac7ff446"}, + {file = "picologging-0.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa728c45ef76c284b9febb794998e784ddf8e508d460e9d5f1ac70c0303317f"}, + {file = "picologging-0.9.2-cp39-cp39-win32.whl", hash = "sha256:ab1bbb809ab955ca852dc8a3e78fe14090280dcd81337f8efcb38272b3f24291"}, + {file = "picologging-0.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:bb026157fd752be9388d9df134631a5089f83e2dbeb1bd85319e5eb4ed410964"}, + {file = "picologging-0.9.2.tar.gz", hash = "sha256:bcb578063a2e2af01948b5d1cbd08c1d54a5411c916da826bf3f695724b93623"}, +] + +[package.extras] +dev = ["black", "flaky", "hypothesis", "pre-commit", "pytest", "pytest-cov", "rich"] + [[package]] name = "pkginfo" version = "1.9.6" @@ -3303,4 +3333,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4aa1b21fc79f39ac540e82e675c33fd6cbad507bac860a863c538f42af9929d9" +content-hash = "7b3052cf6cfb6f45a3811cb455180b681b7b5706cf6f7bed4494a8f6548b40e7" diff --git a/pyproject.toml b/pyproject.toml index 1eae0364..64dd3b45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "motion-python" -version = "0.1.109" +version = "0.2.0" description = "A trigger-based framework for creating and executing ML pipelines." authors = ["Shreya Shankar "] readme = "README.md" @@ -9,7 +9,6 @@ packages = [{include = "motion"}] [tool.poetry.dependencies] python = "^3.9" click = "^8.1.3" -colorlog = "^6.7.0" pydantic = "^1.10.7" cloudpickle = "^2.0" redis = "^4.5.5" @@ -18,6 +17,7 @@ rich = "^13.4.1" pyyaml = "^6.0.1" pandas = "^2.1.0" tqdm = "^4.66.1" +picologging = "^0.9.2" fastvs = "^0.1.7" pyarrow = "^14.0.1" @@ -72,8 +72,8 @@ show_error_codes = true omit = [".*", "*/site-packages/*"] [build-system] -requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" +requires = ["maturin>=0.14,<0.15"] +build-backend = "maturin" [tool.coverage.report] exclude_lines = [ diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..a5dd2ea9 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,602 @@ +// pub mod state_value; +// use state_value::StateValue; + +pub mod temp_value; +use temp_value::TempValue; + +use pyo3::exceptions; +use pyo3::prelude::*; +use pyo3::types::{PyAny, PyBytes, PyDict, PyList}; +use redis::Commands; +use redlock::RedLock; +use std::collections::HashMap; +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +enum PyValue { + Int(i64), + Float(f64), + String(String), + List(Vec), + Dict(HashMap), + // ... Add other types as needed. +} + +#[pyclass] +pub struct StateAccessor { + component_name: String, + instance_id: String, + lock_duration: usize, + version: u64, + client: redis::Client, + cache: HashMap, // Stores deserialized values + lock_manager: RedLock, + max_lock_attempts: u32, +} + +#[pymethods] +impl StateAccessor { + #[new] + pub fn new( + component_name: String, + instance_id: String, + lock_duration: u64, + redis_host: &str, + redis_port: u16, + redis_db: i64, + redis_password: Option<&str>, + redis_ssl: Option, + ) -> PyResult { + let use_ssl: bool = redis_ssl.unwrap_or(false); + let protocol: &str = if use_ssl { "rediss" } else { "redis" }; + + // Constructing the Redis URL with SSL consideration + let redis_url = match redis_password { + Some(password) => format!( + "{}://:{}@{}:{}/{}", + protocol, password, redis_host, redis_port, redis_db + ), + None => format!("{}://{}:{}/{}", protocol, redis_host, redis_port, redis_db), + }; + + let client = redis::Client::open(redis_url.clone()).map_err(|err| { + PyErr::new::(format!( + "Redis connection error: {}", + err + )) + })?; + + // Read the version from Redis. If it doesn't exist, set it to 0. + let mut con = client.get_connection().unwrap(); + let instancename = format!("MOTION_VERSION:{}__{}", component_name, instance_id); + let version: u64 = con.get(&instancename).unwrap_or(0); + + // Create a lock manager + let lock_manager = RedLock::new(vec![redis_url]); + let max_lock_attempts = 3; + + Ok(StateAccessor { + component_name, + instance_id, + lock_duration: lock_duration.try_into().unwrap(), + version, + client, + cache: HashMap::new(), + lock_manager, + max_lock_attempts, + }) + } + + #[getter] + pub fn version(&self) -> PyResult { + Ok(self.version) + } + + pub fn set(&mut self, py: Python, key: &str, value: &PyAny) -> PyResult<()> { + // Warning: This function does not check if the value is a TempValue. + // But it is also never called from the Python side, so it's fine. + let mut con = self.client.get_connection().unwrap(); + let serialized_data = Arc::new(serialize_value(py, value)?); + + // Create key name as MOTION_STATE:__/ + let keyname = format!( + "MOTION_STATE:{}__{}/{}", + self.component_name, self.instance_id, key + ); + + // Acquire the lock using rslock + // Lockname will be MOTION_LOCK:__ + let lock_name = format!("MOTION_LOCK:{}__{}", self.component_name, self.instance_id); + let mut lock = None; + + // Loop until we get the lock + for _ in 0..self.max_lock_attempts { + match self + .lock_manager + .lock(lock_name.as_bytes(), self.lock_duration) + { + Ok(Some(l)) => { + lock = Some(l); + break; + } + Ok(None) => { + // Lock was not acquired. Sleep for 100ms and try again. + std::thread::sleep(std::time::Duration::from_millis(100)); + } + Err(e) => { + // Handle the Redis error, maybe return an error or log it. + return Err(PyErr::new::(format!( + "Failed to acquire lock due to Redis error: {}", + e + ))); + } + } + } + if lock.is_none() { + return Err(PyErr::new::(format!( + "Failed to acquire lock after {} attempts", + self.max_lock_attempts + ))); + } + + // Critical section + // Insert the key and value into the cache + self.cache.insert(keyname.clone(), value.into_py(py)); + + // Increment the version and write it to Redis + self.version += 1; + + // Insert the key and value into Redis through an atomic pipeline + redis::pipe() + .atomic() + .set(keyname.clone(), &*serialized_data) + .ignore() + .set( + format!( + "MOTION_VERSION:{}__{}", + self.component_name, self.instance_id + ), + self.version, + ) + .ignore() + .query(&mut con) + .map_err(|err| { + // Undo the cache insert and version increment + self.cache.remove(&keyname); + self.version -= 1; + + // Drop the lock + self.lock_manager.unlock(lock.as_ref().unwrap()); + + PyErr::new::(format!( + "Redis set data error: {}", + err + )) + })?; + + // Drop the lock + self.lock_manager.unlock(lock.as_ref().unwrap()); + + Ok(()) + } + + pub fn bulk_set(&mut self, py: Python, items: &PyDict, from_migration: bool) -> PyResult<()> { + let mut con = self.client.get_connection().unwrap(); + + // Preserialize all the data + let mut serialized_items = Vec::with_capacity(items.len()); + for (key, value) in items.iter() { + let keyname = format!( + "MOTION_STATE:{}__{}/{}", + self.component_name, self.instance_id, key + ); + + // If value is of type TempValue, we should serialize + // the value inside it instead of the TempValue itself + // and extract the TTL from the TempValue. On default, + // the TTL will be None. + // let (value_to_serialize, ttl): (PyObject, Option); + if value.is_instance_of::() { + let temp_value: PyRef = value.extract()?; + // let value_to_serialize = &temp_value.value; + let value_ref: &PyAny = temp_value.value.as_ref(py); + let ttl = Some(temp_value.ttl); + + let serialized_data = Arc::new(serialize_value(py, value_ref)?); + serialized_items.push((keyname, serialized_data, ttl)); + } else { + let serialized_data = Arc::new(serialize_value(py, value)?); + serialized_items.push((keyname, serialized_data, None)); + } + + // let serialized_data = Arc::new(serialize_value(py, value_to_serialize)?); + // serialized_items.push((keyname, serialized_data, ttl)); + } + + let mut pipeline = redis::pipe(); + pipeline.atomic(); + + // If not from_migration, acquire the lock using rslock + // Lockname will be MOTION_LOCK:__ + let mut lock = None; + if !from_migration { + let lock_name = format!("MOTION_LOCK:{}__{}", self.component_name, self.instance_id); + + // Loop until we get the lock + for _ in 0..self.max_lock_attempts { + match self + .lock_manager + .lock(lock_name.as_bytes(), self.lock_duration) + { + Ok(Some(l)) => { + lock = Some(l); + break; + } + Ok(None) => { + // Lock was not acquired. Sleep for 100ms and try again. + std::thread::sleep(std::time::Duration::from_millis(100)); + } + Err(e) => { + // Handle the Redis error, maybe return an error or log it. + return Err(PyErr::new::(format!( + "Failed to acquire lock due to Redis error: {}", + e + ))); + } + } + } + if lock.is_none() { + return Err(PyErr::new::(format!( + "Failed to acquire lock after {} attempts", + self.max_lock_attempts + ))); + } + } + + // Critical section + for (keyname, serialized_data, ttl) in serialized_items.iter() { + let unserialized_value = items + .get_item(keyname.replace( + &format!( + "MOTION_STATE:{}__{}/", + self.component_name, self.instance_id + ), + "", + )) + .unwrap(); + + // Insert the key and value into the cache + self.cache + .insert(keyname.clone(), unserialized_value.into_py(py)); + + // If ttl is not None, set the TTL + if let Some(ttl) = ttl { + pipeline + .cmd("SETEX") + .arg(keyname) + .arg(ttl) + .arg(&**serialized_data); + } else { + pipeline.cmd("SET").arg(keyname).arg(&**serialized_data); + } + } + + // Increment the version and write it to Redis + self.version += 1; + pipeline + .set( + format!( + "MOTION_VERSION:{}__{}", + self.component_name, self.instance_id + ), + self.version, + ) + .ignore(); + + // Execute the pipeline, throwing a Python error if it fails + pipeline.query(&mut con).map_err(|err| { + // Undo the cache insert and version increment + for (key, _) in items { + let keyname = format!( + "MOTION_STATE:{}__{}/{}", + self.component_name, self.instance_id, key + ); + self.cache.remove(&keyname); + } + self.version -= 1; + + // Drop the lock if from_migration is false + if !from_migration { + self.lock_manager.unlock(lock.as_ref().unwrap()); + } + + PyErr::new::(format!( + "Redis bulk set error: {}", + err + )) + })?; + + // Drop the lock if from_migration is false + if !from_migration { + self.lock_manager.unlock(lock.as_ref().unwrap()); + } + + Ok(()) + } + + pub fn get(&mut self, py: Python, key: &str) -> PyResult { + // Create key name as MOTION_STATE:__/ + let keyname = format!( + "MOTION_STATE:{}__{}/{}", + self.component_name, self.instance_id, key + ); + + // Return the cached object if it exists + if let Some(value) = self.cache.get(&keyname) { + return Ok(value.clone_ref(py)); + } + + // Otherwise, fetch it from Redis + let mut con = self.client.get_connection().unwrap(); + let result_data: redis::RedisResult>> = con.get(&keyname); + + match result_data { + Ok(Some(data)) => { + // Deserialize the value + let deserialized_value = deserialize_value(py, &data)?; + + // Insert the deserialized value into the cache + self.cache + .insert(keyname.clone(), deserialized_value.clone_ref(py)); + + Ok(deserialized_value) + } + Ok(None) => Err(PyErr::new::("Key not found")), + Err(err) => Err(PyErr::new::(format!( + "Redis get error: {}", + err + ))), + } + } + + pub fn items(&mut self, py: Python) -> PyResult { + let items_list = pyo3::types::PyList::empty(py); + let pattern = format!( + "MOTION_STATE:{}__{}/{}", + self.component_name, self.instance_id, "*" + ); + + let replaced_pattern = pattern.replace("*", ""); + let mut con = self.client.get_connection().unwrap(); + + // Minimized Redis calls by fetching everything in one go. + let keys: Vec = con.keys(pattern).map_err(|err| { + PyErr::new::(format!("Redis keys error: {}", err)) + })?; + + for key in keys { + let key_without_prefix = key.replace(&replaced_pattern, ""); + + // Avoid cloning the key for Python conversion. + let py_key = key_without_prefix.as_str().into_py(py); + let value = self.get(py, &key_without_prefix)?; + let tuple = pyo3::types::PyTuple::new(py, &[py_key, value]); + items_list.append(tuple)?; + } + + Ok(items_list.into()) + } + + pub fn keys(&self, _py: Python) -> PyResult> { + let pattern = format!( + "MOTION_STATE:{}__{}/{}", + self.component_name, self.instance_id, "*" + ); + + let mut con = self.client.get_connection().unwrap(); + let keys: Vec = con.keys(pattern.clone()).map_err(|err| { + PyErr::new::(format!("Redis keys error: {}", err)) + })?; + + let replaced_pattern = pattern.replace("*", ""); + Ok(keys + .into_iter() + .map(|key| key.replace(&replaced_pattern, "")) + .collect()) + } + + pub fn values(&mut self, py: Python) -> PyResult { + let values_list = pyo3::types::PyList::empty(py); + let keys = self.keys(py)?; + for key in keys.iter() { + let value = self.get(py, &key)?; + values_list.append(value)?; + } + Ok(values_list.into()) + } + + pub fn clear_cache(&mut self) { + self.cache.clear(); + + // Reset version to whatever is in Redis + let mut con = self.client.get_connection().unwrap(); + let version_key = format!( + "MOTION_VERSION:{}__{}", + self.component_name, self.instance_id + ); + let version: u64 = con.get(version_key).unwrap_or(0); + self.version = version; + } +} + +// Serialization Helpers + +fn cloudpickle_serialize(py: Python, value: &PyAny) -> PyResult> { + let cloudpickle = py.import("cloudpickle")?; + let bytes = cloudpickle + .getattr("dumps")? + .call1((value,))? + .extract::<&PyBytes>()?; + Ok(bytes.as_bytes().to_vec()) +} + +fn cloudpickle_deserialize(py: Python, value: &[u8]) -> PyResult { + let cloudpickle = py.import("cloudpickle")?; + let bytes_value = PyBytes::new(py, value); + let obj = cloudpickle.getattr("loads")?.call1((bytes_value,))?; + Ok(obj.into()) +} + +fn py_to_rust(value: &PyAny) -> PyResult { + if let Ok(val) = value.extract::() { + Ok(PyValue::Int(val)) + } else if let Ok(val) = value.extract::() { + Ok(PyValue::Float(val)) + } else if let Ok(val) = value.extract::() { + Ok(PyValue::String(val)) + } else if let Ok(val) = value.downcast::() { + let list: Vec<_> = val + .iter() + .map(|item| py_to_rust(item)) + .collect::>()?; + Ok(PyValue::List(list)) + } else if let Ok(val) = value.downcast::() { + let mut dict = HashMap::new(); + for (key, val) in val.iter() { + let key_str = key.extract::()?; + let val_rust = py_to_rust(val)?; + dict.insert(key_str, val_rust); + } + Ok(PyValue::Dict(dict)) + } else { + Err(PyErr::new::( + "Unsupported type for bincode serialization", + )) + } +} + +fn rust_to_py(py: Python, value: &PyValue) -> PyResult { + match value { + PyValue::Int(val) => Ok(val.into_py(py)), + PyValue::Float(val) => Ok(val.into_py(py)), + PyValue::String(val) => Ok(val.into_py(py)), + PyValue::List(val) => { + let list = PyList::empty(py); + for item in val { + let py_item = rust_to_py(py, item)?; + list.append(py_item)?; + } + Ok(list.into()) + } + PyValue::Dict(val) => { + let dict = PyDict::new(py); + for (key, value) in val { + let py_val = rust_to_py(py, value)?; + dict.set_item(key, py_val)?; + } + Ok(dict.into()) + } // ... Handle other cases. + } +} + +fn serialize_value(py: Python, value: &PyAny) -> PyResult> { + if let Ok(rust_value) = py_to_rust(value) { + let serialized = bincode::serialize(&rust_value) + .map_err(|e| PyErr::new::(format!("{}", e)))?; + + Ok(serialized) + } else { + // Fall back to cloudpickle if not any of the defined types + let serialized = cloudpickle_serialize(py, value)?; + + Ok(serialized) + } +} + +fn deserialize_value(py: Python, value: &[u8]) -> PyResult { + match bincode::deserialize::(value) { + Ok(rust_value) => rust_to_py(py, &rust_value), + Err(_) => { + // Fall back to pickle if bincode deserialization fails + let deserialized = cloudpickle_deserialize(py, value)?; + Ok(deserialized.into()) + } + } +} + +#[pymodule] +fn motion(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + // m.add_class::()?; + m.add_class::()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use pyo3::types::IntoPyDict; + + #[test] + fn state_init_with_valid_url() { + let _state = StateAccessor::new( + "component".to_string(), + "instance".to_string(), + 180 as u64, + "127.0.0.1", + 6381, + 0, + None, + None, + ) + .unwrap(); + } + + #[test] + fn state_init_with_invalid_url() { + let result = StateAccessor::new( + "component".to_string(), + "instance".to_string(), + 180 as u64, + "invalid", + 6381, + 0, + None, + None, + ); + assert!(result.is_err()); + } + + #[test] + fn cache_test() { + pyo3::Python::with_gil(|py| { + let mut state = StateAccessor::new( + "component".to_string(), + "instance".to_string(), + 180 as u64, + "127.0.0.1", + 6381, + 0, + None, + None, + ) + .unwrap(); + + // Set a value to Redis + let _ = state + .bulk_set(py, [("test_key", 42)].into_py_dict(py), false) + .unwrap(); + + // Clear cache to simulate fetching from Redis + state.clear_cache(); + let first_fetch = state.get(py, "test_key").unwrap(); + assert_eq!(first_fetch.extract::(py).unwrap(), 42); + + // This should be fetched from cache + let second_fetch = state.get(py, "test_key").unwrap(); + assert_eq!(second_fetch.extract::(py).unwrap(), 42); + }); + } +} diff --git a/src/state_value.rs b/src/state_value.rs new file mode 100644 index 00000000..ee5bb1d9 --- /dev/null +++ b/src/state_value.rs @@ -0,0 +1,47 @@ +// This file is not used yet. It is a placeholder for a future feature. + +use pyo3::exceptions::PyNotImplementedError; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyType}; + +#[pyclass(subclass)] +pub struct StateValue; + +#[pymethods] +impl StateValue { + #[new] + pub fn new() -> Self { + StateValue {} + } + + #[classmethod] + pub fn load(_cls: &PyType, _data: &PyBytes) -> PyResult<()> { + Err(PyNotImplementedError::new_err( + "The 'load' method has not been implemented.", + )) + } + + pub fn save(&self, _py: Python) -> PyResult<&PyBytes> { + Err(PyNotImplementedError::new_err( + "The 'save' method has not been implemented.", + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_not_implemented() { + pyo3::Python::with_gil(|py| { + let state_object = py.get_type::(); + let result = state_object.call_method1("load", ("some_data",)); + assert!(result.is_err()); + + let obj = state_object.call0().unwrap(); + let result = obj.call_method0("save"); + assert!(result.is_err()); + }); + } +} diff --git a/src/temp_value.rs b/src/temp_value.rs new file mode 100644 index 00000000..68789cb6 --- /dev/null +++ b/src/temp_value.rs @@ -0,0 +1,57 @@ +use pyo3::prelude::*; +use pyo3::PyObject; + +#[pyclass] +pub struct TempValue { + pub value: PyObject, + pub ttl: u64, +} + +#[pymethods] +impl TempValue { + #[new] + pub fn new(py: Python, value: PyObject, ttl: u64) -> Self { + TempValue { + value: value.into_py(py), + ttl, + } + } + + #[getter] + pub fn value(&self, py: Python) -> PyObject { + self.value.clone_ref(py) + } + + #[getter] + pub fn ttl(&self) -> u64 { + self.ttl + } + + #[setter] + pub fn set_ttl(&mut self, new_ttl: u64) { + self.ttl = new_ttl; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pyo3::types::IntoPyDict; + + #[test] + fn test_tempvalue_creation() { + Python::with_gil(|py| { + // Your code that requires Python's GIL goes here + let d = [("TempValue", py.get_type::())].into_py_dict(py); + let instance: PyObject = py + .eval("TempValue(value='hello', ttl=100)", Some(d), None) + .unwrap() + .extract() + .unwrap(); + + // Extract ttl as i64 and compare + let ttl_value: i64 = instance.getattr(py, "ttl").unwrap().extract(py).unwrap(); + assert_eq!(ttl_value, 100); + }); + } +} diff --git a/tests/parallel/test_async.py b/tests/parallel/test_async.py index 3ebb6ad9..f93e4224 100644 --- a/tests/parallel/test_async.py +++ b/tests/parallel/test_async.py @@ -50,13 +50,12 @@ async def test_async_update(): @pytest.mark.asyncio -@pytest.mark.timeout(1) # This test should take less than 3 seconds +@pytest.mark.timeout(2) # This test should take less than 3 seconds async def test_gather(): c = Counter(disable_update_task=True) tasks = [ - c.arun("multiply", props={"value": i}, flush_update=True) - for i in range(100) + c.arun("multiply", props={"value": i}, flush_update=True) for i in range(100) ] # Run all tasks at the same time await asyncio.gather(*tasks) diff --git a/tests/state/test_db_conn.py b/tests/state/test_db_conn.py index a3fce2f2..1945d947 100644 --- a/tests/state/test_db_conn.py +++ b/tests/state/test_db_conn.py @@ -1,6 +1,9 @@ +# This file tests the StateValue functionality + from motion import Component import sqlite3 +import os c = Component("DBComponent") @@ -8,8 +11,13 @@ @c.init_state def setUp(): # Create in-memory sqlite database - conn = sqlite3.connect(":memory:") + path = ":file::memory:?cache=shared:" + conn = sqlite3.connect(path) cursor = conn.cursor() + + # Drop table if exists + cursor.execute("DROP TABLE IF EXISTS users") + cursor.execute( """CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -17,32 +25,18 @@ def setUp(): age INTEGER)""" ) - cursor.execute( - "INSERT INTO users (name, age) VALUES (?, ?)", ("John Doe", 25) - ) - cursor.execute( - "INSERT INTO users (name, age) VALUES (?, ?)", ("Jane Smith", 30) - ) + cursor.execute("INSERT INTO users (name, age) VALUES (?, ?)", ("John Doe", 25)) + cursor.execute("INSERT INTO users (name, age) VALUES (?, ?)", ("Jane Smith", 30)) conn.commit() - return {"cursor": cursor, "fit_count": 0} - - -@c.save_state -def save(state): - return {"fit_count": state["fit_count"]} - - -@c.load_state -def load(state): - conn = sqlite3.connect(":memory:") - cursor = conn.cursor() - return {"cursor": cursor, "fit_count": state["fit_count"]} + return {"path": path, "fit_count": 0} @c.serve("count") def execute_fn(state, props): - return state["cursor"].execute("SELECT COUNT(*) FROM users").fetchall() + conn = sqlite3.connect(state["path"]) + cursor = conn.cursor() + return cursor.execute("SELECT COUNT(*) FROM users").fetchall() @c.serve("something") @@ -57,6 +51,9 @@ def increment(state, props): def test_db_component(): c_instance = c() - assert c_instance.run("count", props={"value": 1}) == [(2,)] + assert c_instance.run("count", props={"value": 1}, flush_update=True) == [(2,)] c_instance.run("something", props={"value": 1}, flush_update=True) assert c_instance.run("something", props={"value": 5}) == 1 + + # Delete the database + os.remove(c_instance.read_state("path")) diff --git a/tests/state/test_rust_vs_cloudpickle.py b/tests/state/test_rust_vs_cloudpickle.py new file mode 100644 index 00000000..ea618cdb --- /dev/null +++ b/tests/state/test_rust_vs_cloudpickle.py @@ -0,0 +1,51 @@ +from motion import Component + +import pytest +import time +import random +import numpy as np +import copy + +FragmentedState = Component("FragmentedState") +UnifiedState = Component("UnifiedState") + +NUM_KEYS = 1000 +VECTOR_LEN = 10000 +D = { + str(i): np.array([random.random() for _ in range(VECTOR_LEN)]) + for i in range(NUM_KEYS) +} +print("Done generating keys") + + +@FragmentedState.init_state +def setupf(): + # Make a bunch of keys and values + return D + + +@UnifiedState.init_state +def setupu(): + d = copy.deepcopy(D) + return {"state": d} + + +def test_key_level_serialization_faster(): + fs = FragmentedState() + print("Done initializing FragmentedState") + us = UnifiedState() + print("Done initializing UnifiedState") + + # Time how long it takes to read a key + start = time.time() + fs.read_state("0") + end = time.time() + key_level_time = end - start + + # Time how long it takes to serialize the same dict and read it + start = time.time() + us.read_state("state")["0"] + end = time.time() + unified_level_time = end - start + + assert key_level_time < unified_level_time diff --git a/tests/state/test_temp_value.py b/tests/state/test_temp_value.py new file mode 100644 index 00000000..2cbbc656 --- /dev/null +++ b/tests/state/test_temp_value.py @@ -0,0 +1,30 @@ +from motion import Component +from motion import TempValue + +import pytest +import time + +TempCounter = Component("TempCounter") + + +def test_temp_state_value(): + counter = TempCounter() + + # Assert nothing in it + assert counter.read_state("value") is None + + # Add a temp value + val = TempValue(0, ttl=1) + counter.write_state({"value": val}) + + # Check that it's there before clearing cache + assert counter.read_state("value") == 0 + + # Check that it's there after clearing cache + assert counter.read_state("value", force_refresh=True) == 0 + + # Sleep for a bit + time.sleep(1) + + # Check that it's gone after clearing cache + assert counter.read_state("value", force_refresh=True) is None diff --git a/tests/test_migrate.py b/tests/test_migrate.py index ac1ecaf9..06fbde0a 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -22,11 +22,15 @@ def migrator_not_returning_dict(state): return "this isn't a dict" -def good_migrator(state): +def bad_inplace_update(state): state.update({"another_val": 0}) return state +def good_migrator(state): + return {"another_val": 0} + + def test_state_migration(): # Create a bunch of instances instance_ids = [] @@ -42,9 +46,14 @@ def test_state_migration(): sm = StateMigrator(Something, migrator_not_returning_dict) sm.migrate() - # Run good migrator + # Run another bad migrator + sm = StateMigrator(Something, bad_inplace_update) + result = sm.migrate([instance_ids[0]], num_workers=1) + assert result[0].exception is not None + + # Run good migrator on one instance sm = StateMigrator(Something, good_migrator) - result = sm.migrate([instance_ids[0]]) + result = sm.migrate([instance_ids[0]], num_workers=1) assert len(result) == 1 assert result[0].instance_id == instance_ids[0] assert result[0].exception is None @@ -59,4 +68,5 @@ def test_state_migration(): # Assert the instances have the new state for instance_id in instance_ids: s = Something(instance_id) - assert s._executor._state == {"state_val": 0, "another_val": 0} + assert s._executor._state["state_val"] == 0 + assert s._executor._state["another_val"] == 0 diff --git a/unnecessary.py b/unnecessary.py new file mode 100644 index 00000000..76a4dc47 --- /dev/null +++ b/unnecessary.py @@ -0,0 +1,334 @@ +from typing import Any, Awaitable, Dict, Optional + +import httpx +import requests + + +class ComponentInstanceClient: + def __init__( + self, + component_name: str, + instance_id: str, + uri: str, + access_token: str, + **kwargs: Any, + ): + """Creates a new instance of a Motion component. + + Args: + component_name (str): + Name of the component we are creating an instance of. + instance_id (str): + ID of the instance we are creating. + """ + self._component_name = component_name + + # Create instance name + self._instance_name = f"{self._component_name}__{instance_id}" + + self.uri = uri + self.access_token = access_token + + self.kwargs = kwargs + + @property + def instance_name(self) -> str: + """Component name with a random phrase to represent + the name of this instance. + In the form of componentname__randomphrase. + """ + return self._instance_name + + @property + def instance_id(self) -> str: + """Latter part of the instance name, which is a random phrase + or a user-defined ID.""" + return self._instance_name.split("__")[-1] + + def write_state(self, state_update: Dict[str, Any], latest: bool = False) -> None: + """Writes the state update to the component instance's state. + If a update op is currently running, the state update will be + applied after the update op is finished. Warning: this could + take a while if your update ops take a long time! + + Usage: + ```python + from motion import Component + + C = Component("MyComponent") + + @C.init_state + def setUp(): + return {"value": 0} + + # Define serve and update operations + ... + + if __name__ == "__main__": + c_instance = C() + c_instance.read_state("value") # Returns 0 + c_instance.write_state({"value": 1, "value2": 2}) + c_instance.read_state("value") # Returns 1 + c_instance.read_state("value2") # Returns 2 + ``` + + Args: + state_update (Dict[str, Any]): Dictionary of key-value pairs + to update the state with. + latest (bool, optional): Whether or not to apply the update + to the latest version of the state. + If true, Motion will redownload the latest version + of the state and apply the update to that version. You + only need to set this to true if you are updating an + instance you connected to a while ago and might be + outdated. Defaults to False. + """ + # Ask server to update state + response = requests.post( + f"{self.uri}/update_state", + json={ + "instance_id": self.instance_id, + "state_update": state_update, + "kwargs": {"latest": latest}, + }, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code != 200: + raise RuntimeError( + f"Failed to update state for instance {self.instance_id}: {response.text}" + ) + + def read_state(self, key: str, default_value: Optional[Any] = None) -> Any: + """Gets the current value for the key in the component instance's state. + + Usage: + ```python + from motion import Component + + C = Component("MyComponent") + + @C.init_state + def setUp(): + return {"value": 0} + + # Define serve and update operations + ... + + if __name__ == "__main__": + c_instance = C() + c_instance.read_state("value") # Returns 0 + c_instance.run(...) + c_instance.read_state("value") # This will return the current value + # of "value" in the state + ``` + + Args: + key (str): Key in the state to get the value for. + default_value (Optional[Any], optional): Default value to return + if the key is not found. Defaults to None. + + Returns: + Any: Current value for the key, or default_value if the key + is not found. + """ + # Ask server to read state + response = requests.get( + f"{self.uri}/read_state", + params={"instance_id": self.instance_id, "key": key}, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code != 200: + raise RuntimeError( + f"Failed to read state for instance {self.instance_id}: {response.text}" + ) + + # Get response + result = response.json()["value"] + if not result: + return default_value + + return result + + def run( + self, + # *, + dataflow_key: str, + props: Dict[str, Any] = {}, + ignore_cache: bool = False, + force_refresh: bool = False, + flush_update: bool = False, + ) -> Any: + """Runs the dataflow (serve and update ops) for the keyword argument + passed in. If the key is not found to have any ops, an error + is raised. Only one dataflow key should be passed in. + + Example Usage: + ```python + from motion import Component + + C = Component("MyComponent") + + @C.init_state + def setUp(): + return {"value": 0} + + @C.serve("add") + def add(state, value): + return state["value"] + value + + @C.update("add") + def add(state, value): + return {"value": state["value"] + value} + + if __name__ == "__main__": + c = C() # Create instance of C + c.run("add", props={"value": 1}, flush_update=True) # (1)! + c.run("add", props={"value": 1}) # Returns 1 + c.run("add", props={"value": 2}, flush_update=True) # (2)! + + c.run("add", props={"value": 3}) + time.sleep(3) # Wait for the previous update op to finish + + c.run("add", props={"value": 3}, force_refresh=True) # (3)! + + # 1. Waits for the update op to finish, then updates the state + # 2. Returns 2, result state["value"] = 4 + # 3. Force refreshes the state before running the dataflow, and + # reruns the serve op even though the result might be cached. + ``` + + + Args: + dataflow_key (str): Key of the dataflow to run. + props (Dict[str, Any]): Keyword arguments to pass into the + dataflow ops, in addition to the state. + ignore_cache (bool, optional): + If True, ignores the cache and runs the serve op. Does not + force refresh the state. Defaults to False. + force_refresh (bool, optional): Read the latest value of the + state before running an serve call, otherwise a stale + version of the state or a cached result may be used. + Defaults to False. + flush_update (bool, optional): + If True, waits for the update op to finish executing before + returning. If the update queue hasn't reached batch_size + yet, the update op runs anyways. Force refreshes the + state after the update op completes. Defaults to False. + + Raises: + ValueError: If more than one dataflow key-value pair is passed. + RuntimeError: + If flush_update is called and the component instance update + processes are disabled. + + Returns: + Any: Result of the serve call. Might take a long time + to run if `flush_update = True` and the update operation is + computationally expensive. + """ + + # Ask server to run dataflow + response = requests.post( + f"{self.uri}/{self.component_name}", + json={ + "component_name": self.component_name, + "instance_id": self.instance_id, + "dataflow_key": dataflow_key, + "is_async": False, + "props": props, + "kwargs": { + "ignore_cache": ignore_cache, + "force_refresh": force_refresh, + "flush_update": flush_update, + }, + }, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code != 200: + raise RuntimeError( + f"Failed to run dataflow for instance {self.instance_id}: {response.text}" + ) + + # Get response + result = response.json()["value"] + return result + + async def arun( + self, + # *, + dataflow_key: str, + props: Dict[str, Any] = {}, + ignore_cache: bool = False, + force_refresh: bool = False, + flush_update: bool = False, + ) -> Awaitable[Any]: + """Async version of run. Runs the dataflow (serve and update ops) for + the specified key. You should use arun if either the serve or update op + is an async function. + + Example Usage: + ```python + from motion import Component + import asyncio + + C = Component("MyComponent") + + @C.serve("sleep") + async def sleep(state, value): + await asyncio.sleep(value) + return "Slept!" + + async def main(): + c = C() + await c.arun("sleep", props={"value": 1}) + + if __name__ == "__main__": + asyncio.run(main()) + ``` + + Args: + dataflow_key (str): Key of the dataflow to run. + props (Dict[str, Any]): Keyword arguments to pass into the + dataflow ops, in addition to the state. + ignore_cache (bool, optional): + If True, ignores the cache and runs the serve op. Does not + force refresh the state. Defaults to False. + force_refresh (bool, optional): Read the latest value of the + state before running an serve call, otherwise a stale + version of the state or a cached result may be used. + Defaults to False. + flush_update (bool, optional): + If True, waits for the update op to finish executing before + returning. If the update queue hasn't reached batch_size + yet, the update op runs anyways. Force refreshes the + state after the update op completes. Defaults to False. + + Raises: + ValueError: If more than one dataflow key-value pair is passed. + If flush_update is called and the component instance update + processes are disabled. + + Returns: + Awaitable[Any]: Awaitable Result of the serve call. + """ + + # Ask server to run dataflow asynchronously + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.uri}/{self.component_name}", + json={ + "component_name": self.component_name, + "instance_id": self.instance_id, + "dataflow_key": dataflow_key, + "is_async": True, + "props": props, + "kwargs": { + "ignore_cache": ignore_cache, + "force_refresh": force_refresh, + "flush_update": flush_update, + }, + }, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + response.raise_for_status() + return response.json()["value"]