From 6377756a9396389ef57c78f0244b78ecbbc50b86 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Oct 2025 18:59:51 +0000 Subject: [PATCH 1/2] chore(pre-commit): [pre-commit.ci] autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v4.6.0 → v6.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.6.0...v6.0.0) - [github.com/pre-commit/mirrors-clang-format: v18.1.8 → v21.1.2](https://github.com/pre-commit/mirrors-clang-format/compare/v18.1.8...v21.1.2) - [github.com/astral-sh/ruff-pre-commit: v0.5.0 → v0.13.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.5.0...v0.13.3) - [github.com/PyCQA/isort: 5.13.2 → 6.1.0](https://github.com/PyCQA/isort/compare/5.13.2...6.1.0) - https://github.com/psf/black → https://github.com/psf/black-pre-commit-mirror - [github.com/psf/black-pre-commit-mirror: 24.4.2 → 25.9.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...25.9.0) - [github.com/asottile/pyupgrade: v3.16.0 → v3.20.0](https://github.com/asottile/pyupgrade/compare/v3.16.0...v3.20.0) - [github.com/pycqa/flake8: 7.1.0 → 7.3.0](https://github.com/pycqa/flake8/compare/7.1.0...7.3.0) - [github.com/codespell-project/codespell: v2.3.0 → v2.4.1](https://github.com/codespell-project/codespell/compare/v2.3.0...v2.4.1) --- .pre-commit-config.yaml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7ab860a5..b3d574b0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,10 +6,10 @@ ci: autofix_commit_msg: "fix: [pre-commit.ci] auto fixes [...]" autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate" autoupdate_schedule: monthly -default_stages: [commit, push, manual] +default_stages: [pre-commit, pre-push, manual] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v6.0.0 hooks: - id: check-symlinks - id: destroyed-symlinks @@ -26,24 +26,24 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.8 + rev: v21.1.2 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.0 + rev: v0.13.3 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/PyCQA/isort - rev: 5.13.2 + rev: 6.1.0 hooks: - id: isort - - repo: https://github.com/psf/black - rev: 24.4.2 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 25.9.0 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade - rev: v3.16.0 + rev: v3.20.0 hooks: - id: pyupgrade args: [--py38-plus] # sync with requires-python @@ -52,7 +52,7 @@ repos: ^examples/ ) - repo: https://github.com/pycqa/flake8 - rev: 7.1.0 + rev: 7.3.0 hooks: - id: flake8 additional_dependencies: @@ -68,7 +68,7 @@ repos: ^docs/source/conf.py$ ) - repo: https://github.com/codespell-project/codespell - rev: v2.3.0 + rev: v2.4.1 hooks: - id: codespell additional_dependencies: [".[toml]"] From 5babbf71c03f4a53f68815d2dbc3909137d349ef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Oct 2025 19:02:54 +0000 Subject: [PATCH 2/2] fix: [pre-commit.ci] auto fixes [...] --- include/adam_op/adam_op.h | 40 +- include/adam_op/adam_op_impl_cpu.h | 38 +- include/adam_op/adam_op_impl_cuda.cuh | 38 +- include/utils.h | 2 +- src/adam_op/adam_op.cpp | 40 +- src/adam_op/adam_op_impl_cpu.cpp | 86 +- src/adam_op/adam_op_impl_cuda.cu | 86 +- torchopt/diff/zero_order/decorator.py | 5 +- torchopt/distributed/api.py | 4 +- torchopt/nn/stateless.py | 2 +- torchopt/utils.py | 18 +- torchopt/version.py | 2 +- torchopt/visual.py | 2 +- tutorials/1_Functional_Optimizer.ipynb | 1164 +++++++++--------- tutorials/2_Visualization.ipynb | 7 +- tutorials/3_Meta_Optimizer.ipynb | 44 +- tutorials/4_Stop_Gradient.ipynb | 12 +- tutorials/5_Implicit_Differentiation.ipynb | 1150 ++++++++--------- tutorials/6_Zero_Order_Differentiation.ipynb | 8 +- 19 files changed, 1387 insertions(+), 1361 deletions(-) diff --git a/include/adam_op/adam_op.h b/include/adam_op/adam_op.h index 2d0abcd3..e18f9edb 100644 --- a/include/adam_op/adam_op.h +++ b/include/adam_op/adam_op.h @@ -27,51 +27,51 @@ namespace py = pybind11; namespace adam_op { -TensorArray<3> adamForwardInplace(const torch::Tensor &updates, - const torch::Tensor &mu, - const torch::Tensor &nu, +TensorArray<3> adamForwardInplace(const torch::Tensor& updates, + const torch::Tensor& mu, + const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count); -torch::Tensor adamForwardMu(const torch::Tensor &updates, - const torch::Tensor &mu, +torch::Tensor adamForwardMu(const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1); -torch::Tensor adamForwardNu(const torch::Tensor &updates, - const torch::Tensor &nu, +torch::Tensor adamForwardNu(const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2); -torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count); -TensorArray<2> adamBackwardMu(const torch::Tensor &dmu, - const torch::Tensor &updates, - const torch::Tensor &mu, +TensorArray<2> adamBackwardMu(const torch::Tensor& dmu, + const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1); -TensorArray<2> adamBackwardNu(const torch::Tensor &dnu, - const torch::Tensor &updates, - const torch::Tensor &nu, +TensorArray<2> adamBackwardNu(const torch::Tensor& dnu, + const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2); -TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates, - const torch::Tensor &updates, - const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates, + const torch::Tensor& updates, + const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps_root, const pyuint_t count); -void buildSubmodule(py::module &mod); // NOLINT[runtime/references] +void buildSubmodule(py::module& mod); // NOLINT[runtime/references] } // namespace adam_op } // namespace torchopt diff --git a/include/adam_op/adam_op_impl_cpu.h b/include/adam_op/adam_op_impl_cpu.h index 4d54377e..c2125ff1 100644 --- a/include/adam_op/adam_op_impl_cpu.h +++ b/include/adam_op/adam_op_impl_cpu.h @@ -23,45 +23,45 @@ namespace torchopt { namespace adam_op { -TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates, - const torch::Tensor &mu, - const torch::Tensor &nu, +TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates, + const torch::Tensor& mu, + const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count); -torch::Tensor adamForwardMuCPU(const torch::Tensor &updates, - const torch::Tensor &mu, +torch::Tensor adamForwardMuCPU(const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1); -torch::Tensor adamForwardNuCPU(const torch::Tensor &updates, - const torch::Tensor &nu, +torch::Tensor adamForwardNuCPU(const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2); -torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count); -TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu, - const torch::Tensor &updates, - const torch::Tensor &mu, +TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu, + const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1); -TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu, - const torch::Tensor &updates, - const torch::Tensor &nu, +TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu, + const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2); -TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates, - const torch::Tensor &updates, - const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates, + const torch::Tensor& updates, + const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps_root, diff --git a/include/adam_op/adam_op_impl_cuda.cuh b/include/adam_op/adam_op_impl_cuda.cuh index 17002b36..f38b3e7f 100644 --- a/include/adam_op/adam_op_impl_cuda.cuh +++ b/include/adam_op/adam_op_impl_cuda.cuh @@ -23,45 +23,45 @@ namespace torchopt { namespace adam_op { -TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates, - const torch::Tensor &mu, - const torch::Tensor &nu, +TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor& updates, + const torch::Tensor& mu, + const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count); -torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates, - const torch::Tensor &mu, +torch::Tensor adamForwardMuCUDA(const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1); -torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates, - const torch::Tensor &nu, +torch::Tensor adamForwardNuCUDA(const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2); -torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count); -TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu, - const torch::Tensor &updates, - const torch::Tensor &mu, +TensorArray<2> adamBackwardMuCUDA(const torch::Tensor& dmu, + const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1); -TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu, - const torch::Tensor &updates, - const torch::Tensor &nu, +TensorArray<2> adamBackwardNuCUDA(const torch::Tensor& dnu, + const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2); -TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, - const torch::Tensor &updates, - const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor& dupdates, + const torch::Tensor& updates, + const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps_root, diff --git a/include/utils.h b/include/utils.h index cefabfac..3b029949 100644 --- a/include/utils.h +++ b/include/utils.h @@ -24,7 +24,7 @@ #endif namespace torchopt { -__forceinline__ size_t getTensorPlainSize(const torch::Tensor &tensor) { +__forceinline__ size_t getTensorPlainSize(const torch::Tensor& tensor) { const auto dim = tensor.dim(); size_t n = 1; for (std::decay_t i = 0; i < dim; ++i) { diff --git a/src/adam_op/adam_op.cpp b/src/adam_op/adam_op.cpp index 47f5d7f1..a0f61cc9 100644 --- a/src/adam_op/adam_op.cpp +++ b/src/adam_op/adam_op.cpp @@ -29,9 +29,9 @@ namespace py = pybind11; namespace adam_op { -TensorArray<3> adamForwardInplace(const torch::Tensor &updates, - const torch::Tensor &mu, - const torch::Tensor &nu, +TensorArray<3> adamForwardInplace(const torch::Tensor& updates, + const torch::Tensor& mu, + const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, @@ -49,8 +49,8 @@ TensorArray<3> adamForwardInplace(const torch::Tensor &updates, } } -torch::Tensor adamForwardMu(const torch::Tensor &updates, - const torch::Tensor &mu, +torch::Tensor adamForwardMu(const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1) { #if defined(__USE_CUDA__) if (updates.device().is_cuda()) { @@ -64,8 +64,8 @@ torch::Tensor adamForwardMu(const torch::Tensor &updates, } } -torch::Tensor adamForwardNu(const torch::Tensor &updates, - const torch::Tensor &nu, +torch::Tensor adamForwardNu(const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2) { #if defined(__USE_CUDA__) if (updates.device().is_cuda()) { @@ -79,8 +79,8 @@ torch::Tensor adamForwardNu(const torch::Tensor &updates, } } -torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, @@ -98,9 +98,9 @@ torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu, } } -TensorArray<2> adamBackwardMu(const torch::Tensor &dmu, - const torch::Tensor &updates, - const torch::Tensor &mu, +TensorArray<2> adamBackwardMu(const torch::Tensor& dmu, + const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1) { #if defined(__USE_CUDA__) if (dmu.device().is_cuda()) { @@ -114,9 +114,9 @@ TensorArray<2> adamBackwardMu(const torch::Tensor &dmu, } } -TensorArray<2> adamBackwardNu(const torch::Tensor &dnu, - const torch::Tensor &updates, - const torch::Tensor &nu, +TensorArray<2> adamBackwardNu(const torch::Tensor& dnu, + const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2) { #if defined(__USE_CUDA__) if (dnu.device().is_cuda()) { @@ -130,10 +130,10 @@ TensorArray<2> adamBackwardNu(const torch::Tensor &dnu, } } -TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates, - const torch::Tensor &updates, - const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates, + const torch::Tensor& updates, + const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps_root, @@ -152,7 +152,7 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates, } } -void buildSubmodule(py::module &mod) { // NOLINT[runtime/references] +void buildSubmodule(py::module& mod) { // NOLINT[runtime/references] py::module m = mod.def_submodule("adam_op", "Adam Ops"); m.def("forward_", &adamForwardInplace, diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index 9c460685..38aa2bc0 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -37,9 +37,9 @@ void adamForwardInplaceCPUKernel(const other_t b1, const other_t eps, const other_t eps_root, const size_t n, - scalar_t *__restrict__ updates_ptr, - scalar_t *__restrict__ mu_ptr, - scalar_t *__restrict__ nu_ptr) { + scalar_t* __restrict__ updates_ptr, + scalar_t* __restrict__ mu_ptr, + scalar_t* __restrict__ nu_ptr) { #pragma omp parallel for num_threads( \ std::min(n / MIN_NUMEL_USE_OMP, \ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) @@ -61,9 +61,9 @@ void adamForwardInplaceCPUKernel(const other_t b1, } } -TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates, - const torch::Tensor &mu, - const torch::Tensor &nu, +TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates, + const torch::Tensor& mu, + const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, @@ -91,11 +91,11 @@ TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates, } template -void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr, - const scalar_t *__restrict__ mu_ptr, +void adamForwardMuCPUKernel(const scalar_t* __restrict__ updates_ptr, + const scalar_t* __restrict__ mu_ptr, const other_t b1, const size_t n, - scalar_t *__restrict__ mu_out_ptr) { + scalar_t* __restrict__ mu_out_ptr) { #pragma omp parallel for num_threads( \ std::min(n / MIN_NUMEL_USE_OMP, \ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) @@ -107,8 +107,8 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr, } } -torch::Tensor adamForwardMuCPU(const torch::Tensor &updates, - const torch::Tensor &mu, +torch::Tensor adamForwardMuCPU(const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1) { auto mu_out = torch::empty_like(mu); @@ -125,11 +125,11 @@ torch::Tensor adamForwardMuCPU(const torch::Tensor &updates, } template -void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr, - const scalar_t *__restrict__ nu_ptr, +void adamForwardNuCPUKernel(const scalar_t* __restrict__ updates_ptr, + const scalar_t* __restrict__ nu_ptr, const other_t b2, const size_t n, - scalar_t *__restrict__ nu_out_ptr) { + scalar_t* __restrict__ nu_out_ptr) { #pragma omp parallel for num_threads( \ std::min(n / MIN_NUMEL_USE_OMP, \ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) @@ -142,8 +142,8 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr, } } -torch::Tensor adamForwardNuCPU(const torch::Tensor &updates, - const torch::Tensor &nu, +torch::Tensor adamForwardNuCPU(const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2) { auto nu_out = torch::empty_like(nu); @@ -160,14 +160,14 @@ torch::Tensor adamForwardNuCPU(const torch::Tensor &updates, } template -void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr, - const scalar_t *__restrict__ new_nu_ptr, +void adamForwardUpdatesCPUKernel(const scalar_t* __restrict__ new_mu_ptr, + const scalar_t* __restrict__ new_nu_ptr, const other_t inv_one_minus_pow_b1, const other_t inv_one_minus_pow_b2, const other_t eps, const other_t eps_root, const size_t n, - scalar_t *__restrict__ updates_out_ptr) { + scalar_t* __restrict__ updates_out_ptr) { #pragma omp parallel for num_threads( \ std::min(n / MIN_NUMEL_USE_OMP, \ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) @@ -180,8 +180,8 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr, } } -torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, @@ -209,11 +209,11 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu, } template -void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr, +void adamBackwardMuCPUKernel(const scalar_t* __restrict__ dmu_ptr, const other_t b1, const size_t n, - scalar_t *__restrict__ dupdates_out_ptr, - scalar_t *__restrict__ dmu_out_ptr) { + scalar_t* __restrict__ dupdates_out_ptr, + scalar_t* __restrict__ dmu_out_ptr) { #pragma omp parallel for num_threads( \ std::min(n / MIN_NUMEL_USE_OMP, \ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) @@ -225,9 +225,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr, } } -TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu, - const torch::Tensor &updates, - const torch::Tensor &mu, +TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu, + const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1) { auto dupdates_out = torch::empty_like(updates); auto dmu_out = torch::empty_like(mu); @@ -245,12 +245,12 @@ TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu, } template -void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr, - const scalar_t *__restrict__ updates_ptr, +void adamBackwardNuCPUKernel(const scalar_t* __restrict__ dnu_ptr, + const scalar_t* __restrict__ updates_ptr, const other_t b2, const size_t n, - scalar_t *__restrict__ dupdates_out_ptr, - scalar_t *__restrict__ dnu_out_ptr) { + scalar_t* __restrict__ dupdates_out_ptr, + scalar_t* __restrict__ dnu_out_ptr) { #pragma omp parallel for num_threads( \ std::min(n / MIN_NUMEL_USE_OMP, \ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) @@ -263,9 +263,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr, } } -TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu, - const torch::Tensor &updates, - const torch::Tensor &nu, +TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu, + const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2) { auto dupdates_out = torch::empty_like(updates); auto dnu_out = torch::empty_like(nu); @@ -284,14 +284,14 @@ TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu, } template -void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr, - const scalar_t *__restrict__ updates_ptr, - const scalar_t *__restrict__ new_mu_ptr, +void adamBackwardUpdatesCPUKernel(const scalar_t* __restrict__ dupdates_ptr, + const scalar_t* __restrict__ updates_ptr, + const scalar_t* __restrict__ new_mu_ptr, const other_t one_minus_pow_b1, const other_t inv_one_minus_pow_b2, const size_t n, - scalar_t *__restrict__ dnew_mu_out_ptr, - scalar_t *__restrict__ dnew_nu_out_ptr) { + scalar_t* __restrict__ dnew_mu_out_ptr, + scalar_t* __restrict__ dnew_nu_out_ptr) { #pragma omp parallel for num_threads( \ std::min(n / MIN_NUMEL_USE_OMP, \ static_cast(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) @@ -316,10 +316,10 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr, } } -TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates, - const torch::Tensor &updates, - const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates, + const torch::Tensor& updates, + const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps_root, diff --git a/src/adam_op/adam_op_impl_cuda.cu b/src/adam_op/adam_op_impl_cuda.cu index a12eca4f..538ad7e5 100644 --- a/src/adam_op/adam_op_impl_cuda.cu +++ b/src/adam_op/adam_op_impl_cuda.cu @@ -35,9 +35,9 @@ __global__ void adamForwardInplaceCUDAKernel(const other_t b1, const other_t eps, const other_t eps_root, const size_t n, - scalar_t *__restrict__ updates_ptr, - scalar_t *__restrict__ mu_ptr, - scalar_t *__restrict__ nu_ptr) { + scalar_t* __restrict__ updates_ptr, + scalar_t* __restrict__ mu_ptr, + scalar_t* __restrict__ nu_ptr) { const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; #pragma unroll for (int i = 0; i < unroll_size; ++i) { @@ -62,9 +62,9 @@ __global__ void adamForwardInplaceCUDAKernel(const other_t b1, } } -TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates, - const torch::Tensor &mu, - const torch::Tensor &nu, +TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor& updates, + const torch::Tensor& mu, + const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, @@ -112,11 +112,11 @@ TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates, } template -__global__ void adamForwardMuCUDAKernel(const scalar_t *__restrict__ updates_ptr, - const scalar_t *__restrict__ mu_ptr, +__global__ void adamForwardMuCUDAKernel(const scalar_t* __restrict__ updates_ptr, + const scalar_t* __restrict__ mu_ptr, const other_t b1, const size_t n, - scalar_t *__restrict__ mu_out_ptr) { + scalar_t* __restrict__ mu_out_ptr) { const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; #pragma unroll for (int i = 0; i < unroll_size; ++i) { @@ -132,8 +132,8 @@ __global__ void adamForwardMuCUDAKernel(const scalar_t *__restrict__ updates_ptr } } -torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates, - const torch::Tensor &mu, +torch::Tensor adamForwardMuCUDA(const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1) { auto mu_out = torch::empty_like(mu); @@ -165,11 +165,11 @@ torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates, } template -__global__ void adamForwardNuCUDAKernel(const scalar_t *__restrict__ updates_ptr, - const scalar_t *__restrict__ nu_ptr, +__global__ void adamForwardNuCUDAKernel(const scalar_t* __restrict__ updates_ptr, + const scalar_t* __restrict__ nu_ptr, const other_t b2, const size_t n, - scalar_t *__restrict__ nu_out_ptr) { + scalar_t* __restrict__ nu_out_ptr) { const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; #pragma unroll for (int i = 0; i < unroll_size; ++i) { @@ -186,8 +186,8 @@ __global__ void adamForwardNuCUDAKernel(const scalar_t *__restrict__ updates_ptr } } -torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates, - const torch::Tensor &nu, +torch::Tensor adamForwardNuCUDA(const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2) { auto nu_out = torch::empty_like(nu); @@ -219,14 +219,14 @@ torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates, } template -__global__ void adamForwardUpdatesCUDAKernel(const scalar_t *__restrict__ new_mu_ptr, - const scalar_t *__restrict__ new_nu_ptr, +__global__ void adamForwardUpdatesCUDAKernel(const scalar_t* __restrict__ new_mu_ptr, + const scalar_t* __restrict__ new_nu_ptr, const other_t inv_one_minus_pow_b1, const other_t inv_one_minus_pow_b2, const other_t eps, const other_t eps_root, const size_t n, - scalar_t *__restrict__ updates_out_ptr) { + scalar_t* __restrict__ updates_out_ptr) { const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; #pragma unroll for (int i = 0; i < unroll_size; ++i) { @@ -243,8 +243,8 @@ __global__ void adamForwardUpdatesCUDAKernel(const scalar_t *__restrict__ new_mu } } -torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, @@ -291,11 +291,11 @@ torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu, } template -__global__ void adamBackwardMuCUDAKernel(const scalar_t *__restrict__ dmu_ptr, +__global__ void adamBackwardMuCUDAKernel(const scalar_t* __restrict__ dmu_ptr, const other_t b1, const size_t n, - scalar_t *__restrict__ dupdates_out_ptr, - scalar_t *__restrict__ dmu_out_ptr) { + scalar_t* __restrict__ dupdates_out_ptr, + scalar_t* __restrict__ dmu_out_ptr) { const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; #pragma unroll for (int i = 0; i < unroll_size; ++i) { @@ -311,9 +311,9 @@ __global__ void adamBackwardMuCUDAKernel(const scalar_t *__restrict__ dmu_ptr, } } -TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu, - const torch::Tensor &updates, - const torch::Tensor &mu, +TensorArray<2> adamBackwardMuCUDA(const torch::Tensor& dmu, + const torch::Tensor& updates, + const torch::Tensor& mu, const pyfloat_t b1) { auto dupdates_out = torch::empty_like(updates); auto dmu_out = torch::empty_like(mu); @@ -346,12 +346,12 @@ TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu, } template -__global__ void adamBackwardNuCUDAKernel(const scalar_t *__restrict__ dnu_ptr, - const scalar_t *__restrict__ updates_ptr, +__global__ void adamBackwardNuCUDAKernel(const scalar_t* __restrict__ dnu_ptr, + const scalar_t* __restrict__ updates_ptr, const other_t b2, const size_t n, - scalar_t *__restrict__ dupdates_out_ptr, - scalar_t *__restrict__ dnu_out_ptr) { + scalar_t* __restrict__ dupdates_out_ptr, + scalar_t* __restrict__ dnu_out_ptr) { const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; #pragma unroll for (int i = 0; i < unroll_size; ++i) { @@ -368,9 +368,9 @@ __global__ void adamBackwardNuCUDAKernel(const scalar_t *__restrict__ dnu_ptr, } } -TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu, - const torch::Tensor &updates, - const torch::Tensor &nu, +TensorArray<2> adamBackwardNuCUDA(const torch::Tensor& dnu, + const torch::Tensor& updates, + const torch::Tensor& nu, const pyfloat_t b2) { auto dupdates_out = torch::empty_like(updates); auto dnu_out = torch::empty_like(nu); @@ -405,14 +405,14 @@ TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu, } template -__global__ void adamBackwardUpdatesCUDAKernel(const scalar_t *__restrict__ dupdates_ptr, - const scalar_t *__restrict__ updates_ptr, - const scalar_t *__restrict__ new_mu_ptr, +__global__ void adamBackwardUpdatesCUDAKernel(const scalar_t* __restrict__ dupdates_ptr, + const scalar_t* __restrict__ updates_ptr, + const scalar_t* __restrict__ new_mu_ptr, const other_t one_minus_pow_b1, const other_t inv_one_minus_pow_b2, const size_t n, - scalar_t *__restrict__ dnew_mu_out_ptr, - scalar_t *__restrict__ dnew_nu_out_ptr) { + scalar_t* __restrict__ dnew_mu_out_ptr, + scalar_t* __restrict__ dnew_nu_out_ptr) { const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size; #pragma unroll for (int i = 0; i < unroll_size; ++i) { @@ -441,10 +441,10 @@ __global__ void adamBackwardUpdatesCUDAKernel(const scalar_t *__restrict__ dupda } } -TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates, - const torch::Tensor &updates, - const torch::Tensor &new_mu, - const torch::Tensor &new_nu, +TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor& dupdates, + const torch::Tensor& updates, + const torch::Tensor& new_mu, + const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps_root, diff --git a/torchopt/diff/zero_order/decorator.py b/torchopt/diff/zero_order/decorator.py index e498b43c..ea10702d 100644 --- a/torchopt/diff/zero_order/decorator.py +++ b/torchopt/diff/zero_order/decorator.py @@ -17,7 +17,6 @@ from __future__ import annotations import functools -import itertools from typing import Any, Callable, Literal, Sequence from typing_extensions import TypeAlias # Python 3.10+ @@ -124,7 +123,7 @@ def add_perturbation( for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] flat_noisy_params = list( - itertools.starmap(add_perturbation, zip(flat_diff_params, noises)), + map(add_perturbation, flat_diff_params, noises), ) noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, @@ -228,7 +227,7 @@ def add_perturbation(tensor: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: for _ in range(num_samples): noises = [distribution.sample(sample_shape=p.shape) for p in flat_diff_params] flat_noisy_params = list( - itertools.starmap(add_perturbation, zip(flat_diff_params, noises)), + map(add_perturbation, flat_diff_params, noises), ) noisy_params: list[Any] = pytree.tree_unflatten( # type: ignore[assignment] diff_params_treespec, diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 97be682f..a99f4802 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -318,12 +318,12 @@ def remote_async_call( futures.append(fut) future = cast( - Future[List[T]], + 'Future[List[T]]', torch.futures.collect_all(futures).then(lambda fut: [f.wait() for f in fut.wait()]), ) if reducer is not None: return cast( - Future[U], + 'Future[U]', future.then(lambda fut: reducer(fut.wait())), ) return future diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index c7f92b86..0f4f17b3 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -84,7 +84,7 @@ def reparametrize( module: nn.Module, named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]], allow_missing: bool = False, -) -> Generator[nn.Module, None, None]: +) -> Generator[nn.Module]: """Reparameterize the module parameters and/or buffers.""" if not isinstance(named_tensors, dict): named_tensors = dict(named_tensors) diff --git a/torchopt/utils.py b/torchopt/utils.py index 5f9202a3..c95fd0dc 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -79,13 +79,13 @@ def fn_(obj: Any) -> None: obj.detach_().requires_grad_(requires_grad) if isinstance(target, ModuleState): - true_target = cast(TensorTree, (target.params, target.buffers)) + true_target = cast('TensorTree', (target.params, target.buffers)) elif isinstance(target, nn.Module): - true_target = cast(TensorTree, tuple(target.parameters())) + true_target = cast('TensorTree', tuple(target.parameters())) elif isinstance(target, MetaOptimizer): - true_target = cast(TensorTree, target.state_dict()) + true_target = cast('TensorTree', target.state_dict()) else: - true_target = cast(TensorTree, target) # tree of tensors + true_target = cast('TensorTree', target) # tree of tensors pytree.tree_map_(fn_, true_target) @@ -325,7 +325,7 @@ def recover_state_dict( from torchopt.optim.meta.base import MetaOptimizer if isinstance(target, nn.Module): - params, buffers, *_ = state = cast(ModuleState, state) + params, buffers, *_ = state = cast('ModuleState', state) params_containers, buffers_containers = extract_module_containers(target, with_buffers=True) if state.detach_buffers: @@ -343,7 +343,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: ): tgt.update(src) elif isinstance(target, MetaOptimizer): - state = cast(Sequence[OptState], state) + state = cast('Sequence[OptState]', state) target.load_state_dict(state) else: raise TypeError(f'Unexpected class of {target}') @@ -422,9 +422,9 @@ def module_clone( # noqa: C901 if isinstance(target, (nn.Module, MetaOptimizer)): if isinstance(target, nn.Module): - containers = cast(TensorTree, extract_module_containers(target, with_buffers=True)) + containers = cast('TensorTree', extract_module_containers(target, with_buffers=True)) else: - containers = cast(TensorTree, target.state_dict()) + containers = cast('TensorTree', target.state_dict()) tensors = pytree.tree_leaves(containers) memo = {id(t): t for t in tensors} cloned = copy.deepcopy(target, memo=memo) @@ -476,7 +476,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor: else: replicate = clone_detach_ - return pytree.tree_map(replicate, cast(TensorTree, target)) + return pytree.tree_map(replicate, cast('TensorTree', target)) @overload diff --git a/torchopt/version.py b/torchopt/version.py index 9fdcac9b..69aff7da 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -25,7 +25,7 @@ try: prefix, sep, suffix = ( - subprocess.check_output( # noqa: S603 + subprocess.check_output( ['git', 'describe', '--abbrev=7'], # noqa: S607 cwd=os.path.dirname(os.path.abspath(__file__)), stderr=subprocess.DEVNULL, diff --git a/torchopt/visual.py b/torchopt/visual.py index 7638d7ec..37b08c15 100644 --- a/torchopt/visual.py +++ b/torchopt/visual.py @@ -129,7 +129,7 @@ def make_dot( # noqa: C901 elif isinstance(param, Generator): param_map.update({v: k for k, v in param}) else: - param_map.update({v: k for k, v in cast(Mapping, param).items()}) + param_map.update({v: k for k, v in cast('Mapping', param).items()}) node_attr = { 'style': 'filled', diff --git a/tutorials/1_Functional_Optimizer.ipynb b/tutorials/1_Functional_Optimizer.ipynb index afc55f38..231bceff 100644 --- a/tutorials/1_Functional_Optimizer.ipynb +++ b/tutorials/1_Functional_Optimizer.ipynb @@ -1,588 +1,588 @@ { "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# TorchOpt as Functional Optimizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this tutorial, we will introduce how TorchOpt can be treated as functional optimizer to conduct normal optimization with functional programming style. We will also illustrate how to conduct differentiable optimization with functional programming in PyTorch." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Basic API\n", - "\n", - "In this first part, we will illustrate how TorchOpt can be used as a functional optimizer. We compare it with different API in [JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org) to help understand the similarity and dissimilarity. We use simple network, Adam optimizer and MSE loss objective." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from collections import OrderedDict\n", - "\n", - "import functorch\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import optax\n", - "import torch\n", - "import torch.autograd\n", - "import torch.nn as nn\n", - "\n", - "import torchopt\n", - "\n", - "\n", - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, 1, bias=True)\n", - " nn.init.ones_(self.fc.weight)\n", - " nn.init.zeros_(self.fc.bias)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "def mse(inputs, targets):\n", - " return ((inputs - targets) ** 2).mean()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.1 Original JAX implementation\n", - "\n", - "The first example is JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which belongs to functional programming style." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "def origin_jax():\n", - " batch_size = 1\n", - " dim = 1\n", - " params = OrderedDict([('weight', jnp.ones((dim, 1))), ('bias', jnp.zeros((1,)))])\n", - "\n", - " def model(params, x):\n", - " return jnp.matmul(x, params['weight']) + params['bias']\n", - "\n", - " # Obtain the `opt_state` that contains statistics for the optimizer\n", - " learning_rate = 1.0\n", - " optimizer = optax.adam(learning_rate)\n", - " opt_state = optimizer.init(params)\n", - "\n", - " def compute_loss(params, x, y):\n", - " pred = model(params, x)\n", - " return mse(pred, y)\n", - "\n", - " xs = 2 * jnp.ones((batch_size, dim))\n", - " ys = jnp.ones((batch_size, 1))\n", - "\n", - " grads = jax.grad(compute_loss)(params, xs, ys)\n", - " updates, opt_state = optimizer.update(grads, opt_state)\n", - "\n", - " print('Parameters before update:', params)\n", - " params = optax.apply_updates(params, updates)\n", - " print('Parameters after update:', params)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "OrderedDict([\n", - " ('weight', DeviceArray([[1.]], dtype=float32)),\n", - " ('bias', DeviceArray([0.], dtype=float32))\n", - "])\n", - "Parameters after update:\n", - "OrderedDict([\n", - " ('weight', DeviceArray([[6.735325e-06]], dtype=float32)),\n", - " ('bias', DeviceArray([-0.99999326], dtype=float32))\n", - "])\n" - ] - } - ], - "source": [ - "origin_jax()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.2 `functorch` with TorchOpt\n", - "\n", - "The second example is [`functorch`](https://pytorch.org/functorch) coupled with TorchOpt. It basically follows the same structure with the JAX example." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def interact_with_functorch():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - " model, params = functorch.make_functional(net) # get the functional version of the model\n", - "\n", - " # Obtain the `opt_state` that contains statistics for the optimizer\n", - " learning_rate = 1.0\n", - " optimizer = torchopt.adam(learning_rate)\n", - " opt_state = optimizer.init(params)\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = model(params, xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " grads = torch.autograd.grad(loss, params)\n", - " updates, opt_state = optimizer.update(grads, opt_state)\n", - "\n", - " print('Parameters before update:', params)\n", - " params = torchopt.apply_updates(params, updates)\n", - " print('Parameters after update:', params)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "(\n", - " Parameter containing: tensor([[1.]], requires_grad=True),\n", - " Parameter containing: tensor([0.], requires_grad=True)\n", - ")\n", - "Parameters after update:\n", - "(\n", - " Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", - " Parameter containing: tensor([-1.0000], requires_grad=True)\n", - ")\n" - ] - } - ], - "source": [ - "interact_with_functorch()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "TorchOpt also offers a wrapper `torchopt.FuncOptimizer` to make it easier to maintain the optimizer states." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "def interact_with_functorch_with_wrapper():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - " model, params = functorch.make_functional(net) # get the functional version of the model\n", - "\n", - " learning_rate = 1.0\n", - " optimizer = torchopt.FuncOptimizer(torchopt.adam(learning_rate))\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = model(params, xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " print('Parameters before update:', params)\n", - " params = optimizer.step(loss, params)\n", - " print('Parameters after update:', params)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "(\n", - " Parameter containing: tensor([[1.]], requires_grad=True),\n", - " Parameter containing: tensor([0.], requires_grad=True)\n", - ")\n", - "Parameters after update:\n", - "(\n", - " tensor([[6.6757e-06]], grad_fn=),\n", - " tensor([-1.0000], grad_fn=)\n", - ")\n" - ] - } - ], - "source": [ - "interact_with_functorch_with_wrapper()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.3 Full TorchOpt\n", - "\n", - "`torchopt.Optimizer` is the base class for our PyTorch-like optimizer. Combined with the functional optimizer `torchopt.sgd` and `torchopt.adam`, we can define our high-level API `torchopt.SGD` and `torchopt.Adam`. The third example is to illustrate that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference happens between `torchopt.adam()` and `torchopt.Adam()`." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "def full_torchopt():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - "\n", - " learning_rate = 1.0\n", - " # High-level API\n", - " optim = torchopt.Adam(net.parameters(), lr=learning_rate)\n", - " # Low-level API\n", - " optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = net(xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " print('Parameters before update:', dict(net.named_parameters()))\n", - " optim.zero_grad()\n", - " loss.backward()\n", - " optim.step()\n", - " print('Parameters after update:', dict(net.named_parameters()))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", - "}\n", - "Parameters after update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", - "}\n" - ] - } - ], - "source": [ - "full_torchopt()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.4 Original PyTorch\n", - "\n", - "The final example is to original PyTorch example with `torch.optim`." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "def origin_torch():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - "\n", - " learning_rate = 1.0\n", - " optim = torch.optim.Adam(net.parameters(), lr=learning_rate)\n", - "\n", - " xs = 2 * torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = net(xs)\n", - " loss = mse(pred, ys)\n", - "\n", - " print('Parameters before update:', dict(net.named_parameters()))\n", - " optim.zero_grad()\n", - " loss.backward()\n", - " optim.step()\n", - " print('Parameters after update:', dict(net.named_parameters()))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameters before update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", - "}\n", - "Parameters after update:\n", - "{\n", - " 'fc.weight': Parameter containing: tensor([[1.1921e-07]], requires_grad=True),\n", - " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", - "}\n" - ] - } - ], - "source": [ - "origin_torch()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Differentiable Optimization with Functional Optimizer\n", - "\n", - "Coupled with functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag as `False` in update and `apply_updates` function. (which might be helpful for meta-learning algorithm implementation with functional programming style). \n", - "\n", - "Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. Refer to the Meta-Optimizer notebook for PyTorch-like differentiable optimizers." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "def differentiable():\n", - " batch_size = 1\n", - " dim = 1\n", - " net = Net(dim)\n", - " model, params = functorch.make_functional(net) # get the functional version of the model\n", - "\n", - " # Meta-parameter\n", - " meta_param = nn.Parameter(torch.ones(1))\n", - "\n", - " # SGD example\n", - " learning_rate = 1.0\n", - " optimizer = torchopt.sgd(learning_rate)\n", - " opt_state = optimizer.init(params)\n", - "\n", - " xs = torch.ones((batch_size, dim))\n", - " ys = torch.ones((batch_size, 1))\n", - "\n", - " pred = model(params, xs)\n", - " # Where meta_param is used\n", - " pred = pred + meta_param\n", - " loss = mse(pred, ys)\n", - "\n", - " grads = torch.autograd.grad(loss, params, create_graph=True)\n", - " updates, opt_state = optimizer.update(grads, opt_state, inplace=False)\n", - " # Update parameters with single step SGD update\n", - " params = torchopt.apply_updates(params, updates, inplace=False)\n", - "\n", - " pred = model(params, xs)\n", - " loss = mse(pred, ys)\n", - " loss.backward()\n", - "\n", - " print('Gradient for the meta-parameter:', meta_param.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Gradient for the meta-parameter: tensor([32.])\n" - ] - } - ], - "source": [ - "differentiable()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.1 Track the Gradient of Momentum\n", - "\n", - "Note that most modern optimizers involve momentum term in the gradient update (basically only SGD with `momentum = 0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through momentum term. The default option is `moment_requires_grad=True`." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.adam(lr=1.0, moment_requires_grad=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.adam(lr=1.0, moment_requires_grad=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.sgd(lr=1.0, momentum=0.8, moment_requires_grad=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Accelerated Optimizer\n", - "\n", - "Users can use accelerated optimizer by setting the `use_accelerated_op` as `True`. Currently we only support the Adam optimizer." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Check whether the `accelerated_op` is available:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TorchOpt as Functional Optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/1_Functional_Optimizer.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial, we will introduce how TorchOpt can be treated as functional optimizer to conduct normal optimization with functional programming style. We will also illustrate how to conduct differentiable optimization with functional programming in PyTorch." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Basic API\n", + "\n", + "In this first part, we will illustrate how TorchOpt can be used as a functional optimizer. We compare it with different API in [JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org) to help understand the similarity and dissimilarity. We use simple network, Adam optimizer and MSE loss objective." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "\n", + "import functorch\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import optax\n", + "import torch\n", + "import torch.autograd\n", + "import torch.nn as nn\n", + "\n", + "import torchopt\n", + "\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + " nn.init.ones_(self.fc.weight)\n", + " nn.init.zeros_(self.fc.bias)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "def mse(inputs, targets):\n", + " return ((inputs - targets) ** 2).mean()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.1 Original JAX implementation\n", + "\n", + "The first example is JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which belongs to functional programming style." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def origin_jax():\n", + " batch_size = 1\n", + " dim = 1\n", + " params = OrderedDict([('weight', jnp.ones((dim, 1))), ('bias', jnp.zeros((1,)))])\n", + "\n", + " def model(params, x):\n", + " return jnp.matmul(x, params['weight']) + params['bias']\n", + "\n", + " # Obtain the `opt_state` that contains statistics for the optimizer\n", + " learning_rate = 1.0\n", + " optimizer = optax.adam(learning_rate)\n", + " opt_state = optimizer.init(params)\n", + "\n", + " def compute_loss(params, x, y):\n", + " pred = model(params, x)\n", + " return mse(pred, y)\n", + "\n", + " xs = 2 * jnp.ones((batch_size, dim))\n", + " ys = jnp.ones((batch_size, 1))\n", + "\n", + " grads = jax.grad(compute_loss)(params, xs, ys)\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = optax.apply_updates(params, updates)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "OrderedDict([\n", + " ('weight', DeviceArray([[1.]], dtype=float32)),\n", + " ('bias', DeviceArray([0.], dtype=float32))\n", + "])\n", + "Parameters after update:\n", + "OrderedDict([\n", + " ('weight', DeviceArray([[6.735325e-06]], dtype=float32)),\n", + " ('bias', DeviceArray([-0.99999326], dtype=float32))\n", + "])\n" + ] + } + ], + "source": [ + "origin_jax()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.2 `functorch` with TorchOpt\n", + "\n", + "The second example is [`functorch`](https://pytorch.org/functorch) coupled with TorchOpt. It basically follows the same structure with the JAX example." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def interact_with_functorch():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the model\n", + "\n", + " # Obtain the `opt_state` that contains statistics for the optimizer\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.adam(learning_rate)\n", + " opt_state = optimizer.init(params)\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " grads = torch.autograd.grad(loss, params)\n", + " updates, opt_state = optimizer.update(grads, opt_state)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = torchopt.apply_updates(params, updates)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "(\n", + " Parameter containing: tensor([[1.]], requires_grad=True),\n", + " Parameter containing: tensor([0.], requires_grad=True)\n", + ")\n", + "Parameters after update:\n", + "(\n", + " Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", + " Parameter containing: tensor([-1.0000], requires_grad=True)\n", + ")\n" + ] + } + ], + "source": [ + "interact_with_functorch()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TorchOpt also offers a wrapper `torchopt.FuncOptimizer` to make it easier to maintain the optimizer states." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def interact_with_functorch_with_wrapper():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the model\n", + "\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.FuncOptimizer(torchopt.adam(learning_rate))\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', params)\n", + " params = optimizer.step(loss, params)\n", + " print('Parameters after update:', params)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "(\n", + " Parameter containing: tensor([[1.]], requires_grad=True),\n", + " Parameter containing: tensor([0.], requires_grad=True)\n", + ")\n", + "Parameters after update:\n", + "(\n", + " tensor([[6.6757e-06]], grad_fn=),\n", + " tensor([-1.0000], grad_fn=)\n", + ")\n" + ] + } + ], + "source": [ + "interact_with_functorch_with_wrapper()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.3 Full TorchOpt\n", + "\n", + "`torchopt.Optimizer` is the base class for our PyTorch-like optimizer. Combined with the functional optimizer `torchopt.sgd` and `torchopt.adam`, we can define our high-level API `torchopt.SGD` and `torchopt.Adam`. The third example is to illustrate that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference happens between `torchopt.adam()` and `torchopt.Adam()`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def full_torchopt():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + "\n", + " learning_rate = 1.0\n", + " # High-level API\n", + " optim = torchopt.Adam(net.parameters(), lr=learning_rate)\n", + " # Low-level API\n", + " optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = net(xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', dict(net.named_parameters()))\n", + " optim.zero_grad()\n", + " loss.backward()\n", + " optim.step()\n", + " print('Parameters after update:', dict(net.named_parameters()))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", + "}\n", + "Parameters after update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[6.6757e-06]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", + "}\n" + ] + } + ], + "source": [ + "full_torchopt()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.4 Original PyTorch\n", + "\n", + "The final example is to original PyTorch example with `torch.optim`." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def origin_torch():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + "\n", + " learning_rate = 1.0\n", + " optim = torch.optim.Adam(net.parameters(), lr=learning_rate)\n", + "\n", + " xs = 2 * torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = net(xs)\n", + " loss = mse(pred, ys)\n", + "\n", + " print('Parameters before update:', dict(net.named_parameters()))\n", + " optim.zero_grad()\n", + " loss.backward()\n", + " optim.step()\n", + " print('Parameters after update:', dict(net.named_parameters()))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Parameters before update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([0.], requires_grad=True)\n", + "}\n", + "Parameters after update:\n", + "{\n", + " 'fc.weight': Parameter containing: tensor([[1.1921e-07]], requires_grad=True),\n", + " 'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)\n", + "}\n" + ] + } + ], + "source": [ + "origin_torch()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Differentiable Optimization with Functional Optimizer\n", + "\n", + "Coupled with functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag as `False` in update and `apply_updates` function. (which might be helpful for meta-learning algorithm implementation with functional programming style). \n", + "\n", + "Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. Refer to the Meta-Optimizer notebook for PyTorch-like differentiable optimizers." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def differentiable():\n", + " batch_size = 1\n", + " dim = 1\n", + " net = Net(dim)\n", + " model, params = functorch.make_functional(net) # get the functional version of the model\n", + "\n", + " # Meta-parameter\n", + " meta_param = nn.Parameter(torch.ones(1))\n", + "\n", + " # SGD example\n", + " learning_rate = 1.0\n", + " optimizer = torchopt.sgd(learning_rate)\n", + " opt_state = optimizer.init(params)\n", + "\n", + " xs = torch.ones((batch_size, dim))\n", + " ys = torch.ones((batch_size, 1))\n", + "\n", + " pred = model(params, xs)\n", + " # Where meta_param is used\n", + " pred = pred + meta_param\n", + " loss = mse(pred, ys)\n", + "\n", + " grads = torch.autograd.grad(loss, params, create_graph=True)\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=False)\n", + " # Update parameters with single step SGD update\n", + " params = torchopt.apply_updates(params, updates, inplace=False)\n", + "\n", + " pred = model(params, xs)\n", + " loss = mse(pred, ys)\n", + " loss.backward()\n", + "\n", + " print('Gradient for the meta-parameter:', meta_param.grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Gradient for the meta-parameter: tensor([32.])\n" + ] + } + ], + "source": [ + "differentiable()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1 Track the Gradient of Momentum\n", + "\n", + "Note that most modern optimizers involve momentum term in the gradient update (basically only SGD with `momentum = 0` does not involve). We provide an option for user to choose whether to also track the meta-gradient through momentum term. The default option is `moment_requires_grad=True`." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "optim = torchopt.adam(lr=1.0, moment_requires_grad=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "optim = torchopt.adam(lr=1.0, moment_requires_grad=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "optim = torchopt.sgd(lr=1.0, momentum=0.8, moment_requires_grad=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Accelerated Optimizer\n", + "\n", + "Users can use accelerated optimizer by setting the `use_accelerated_op` as `True`. Currently we only support the Adam optimizer." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check whether the `accelerated_op` is available:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], + "source": [ + "torchopt.accelerated_op_available(torch.device('cpu'))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], + "source": [ + "torchopt.accelerated_op_available(torch.device('cuda'))" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "net = Net(1).cuda()\n", + "optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "optim = torchopt.adam(lr=1.0, use_accelerated_op=True)" + ] } - ], - "source": [ - "torchopt.accelerated_op_available(torch.device('cpu'))" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + }, + "vscode": { + "interpreter": { + "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" + } } - ], - "source": [ - "torchopt.accelerated_op_available(torch.device('cuda'))" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "net = Net(1).cuda()\n", - "optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "optim = torchopt.adam(lr=1.0, use_accelerated_op=True)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" }, - "vscode": { - "interpreter": { - "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/tutorials/2_Visualization.ipynb b/tutorials/2_Visualization.ipynb index dd58c48d..f2b89ec4 100644 --- a/tutorials/2_Visualization.ipynb +++ b/tutorials/2_Visualization.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, @@ -181,8 +181,9 @@ "# Draw computation graph\n", "display(\n", " torchopt.visual.make_dot(\n", - " loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]\n", - " )\n", + " loss,\n", + " [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}],\n", + " ),\n", ")" ] } diff --git a/tutorials/3_Meta_Optimizer.ipynb b/tutorials/3_Meta_Optimizer.ipynb index 69be77ed..6c254f33 100644 --- a/tutorials/3_Meta_Optimizer.ipynb +++ b/tutorials/3_Meta_Optimizer.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, @@ -200,8 +200,9 @@ "outer_loss = F.mse_loss(net(x), y)\n", "display(\n", " torchopt.visual.make_dot(\n", - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", - " )\n", + " outer_loss,\n", + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", + " ),\n", ")" ] }, @@ -247,8 +248,9 @@ "outer_loss = F.mse_loss(net(x), y)\n", "display(\n", " torchopt.visual.make_dot(\n", - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", - " )\n", + " outer_loss,\n", + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", + " ),\n", ")" ] }, @@ -513,21 +515,30 @@ "source": [ "functional_adam = torchopt.adam(\n", " lr=torchopt.schedule.linear_schedule(\n", - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", - " )\n", + " init_value=1e-3,\n", + " end_value=1e-4,\n", + " transition_steps=10000,\n", + " transition_begin=2000,\n", + " ),\n", ")\n", "\n", "adam = torchopt.Adam(\n", " net.parameters(),\n", " lr=torchopt.schedule.linear_schedule(\n", - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " init_value=1e-3,\n", + " end_value=1e-4,\n", + " transition_steps=10000,\n", + " transition_begin=2000,\n", " ),\n", ")\n", "\n", "meta_adam = torchopt.MetaAdam(\n", " net,\n", " lr=torchopt.schedule.linear_schedule(\n", - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", + " init_value=1e-3,\n", + " end_value=1e-4,\n", + " transition_steps=10000,\n", + " transition_begin=2000,\n", " ),\n", ")" ] @@ -610,19 +621,26 @@ "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n", "\n", "net_state_0 = torchopt.extract_state_dict(\n", - " net, by='reference', enable_visual=True, visual_prefix='step0.'\n", + " net,\n", + " by='reference',\n", + " enable_visual=True,\n", + " visual_prefix='step0.',\n", ")\n", "inner_loss = F.mse_loss(net(x), y)\n", "optim.step(inner_loss)\n", "net_state_1 = torchopt.extract_state_dict(\n", - " net, by='reference', enable_visual=True, visual_prefix='step1.'\n", + " net,\n", + " by='reference',\n", + " enable_visual=True,\n", + " visual_prefix='step1.',\n", ")\n", "\n", "outer_loss = F.mse_loss(net(x), y)\n", "display(\n", " torchopt.visual.make_dot(\n", - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", - " )\n", + " outer_loss,\n", + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", + " ),\n", ")" ] }, diff --git a/tutorials/4_Stop_Gradient.ipynb b/tutorials/4_Stop_Gradient.ipynb index d8c24bc6..d6f03aa9 100644 --- a/tutorials/4_Stop_Gradient.ipynb +++ b/tutorials/4_Stop_Gradient.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "metadata": {}, @@ -192,7 +192,7 @@ " one_step_net_state,\n", " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", " ),\n", - " )\n", + " ),\n", ")" ] }, @@ -393,7 +393,7 @@ " one_step_net_state,\n", " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", " ),\n", - " )\n", + " ),\n", ")\n", "\n", "# Outer update\n", @@ -457,7 +457,9 @@ "torchopt.stop_gradient(net)\n", "torchopt.stop_gradient(optim)\n", "one_step_net_state_detached = torchopt.extract_state_dict(\n", - " net, enable_visual=True, visual_prefix='step1.detached.'\n", + " net,\n", + " enable_visual=True,\n", + " visual_prefix='step1.detached.',\n", ")\n", "\n", "# Inner update\n", @@ -480,7 +482,7 @@ " one_step_net_state_detached,\n", " {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n", " ),\n", - " )\n", + " ),\n", ")" ] }, diff --git a/tutorials/5_Implicit_Differentiation.ipynb b/tutorials/5_Implicit_Differentiation.ipynb index 23407801..5f4d3357 100644 --- a/tutorials/5_Implicit_Differentiation.ipynb +++ b/tutorials/5_Implicit_Differentiation.ipynb @@ -1,576 +1,578 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", - "metadata": {}, - "source": [ - "# TorchOpt for Implicit Differentiation" - ] - }, - { - "cell_type": "markdown", - "id": "2b547376", - "metadata": {}, - "source": [ - "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)" - ] - }, - { - "cell_type": "markdown", - "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", - "metadata": {}, - "source": [ - "By treating the solution $\\phi^{\\star}$ as an implicit function of $\\theta$, the idea of implicit differentiation is to directly get analytical best-response derivatives $\\partial \\phi^{\\star}(\\theta)/ \\partial \\theta$ by implicit function theorem. This is suitable for algorithms when the inner-level optimal solution is achieved ${\\left. \\frac{\\partial F (\\phi, \\theta)}{\\partial \\phi} \\right\\rvert}_{\\phi = \\phi^{\\star}} = 0$ or reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377)." - ] - }, - { - "cell_type": "markdown", - "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", - "metadata": {}, - "source": [ - "In this tutorial, we will introduce how TorchOpt can be used to conduct implicit differentiation." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", - "metadata": {}, - "outputs": [], - "source": [ - "import functorch\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "import torchopt" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", - "metadata": {}, - "source": [ - "## 1. Functional API\n", - "\n", - "The basic functional API is `torchopt.diff.implicit.custom_root`, which is used as the decorator for the forward process implicit gradient procedures. Users are required to implement the stationary conditions for the inner-loop process, which will be used as the input of custom_root decorator. We show the pseudo code in the following part.\n", - "\n", - "```python\n", - "# Functional API for implicit gradient\n", - "def stationary(params, meta_params, data):\n", - " # stationary condition construction\n", - " return stationary condition\n", - "\n", - "# Decorator that wraps the function\n", - "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", - "@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver)\n", - "def solve(params, meta_params, data):\n", - " # Forward optimization process for params\n", - " return optimal_params\n", - "\n", - "# Define params, meta_params and get data\n", - "params, meta_prams, data = ..., ..., ...\n", - "optimal_params = solve(params, meta_params, data)\n", - "loss = outer_loss(optimal_params)\n", - "\n", - "meta_grads = torch.autograd.grad(loss, meta_params)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", - "metadata": {}, - "source": [ - "Here we use the example of [iMAML](https://arxiv.org/abs/1909.04630) as a real example. For iMAML, the inner-loop objective is described by the following equation.\n", - "\n", - "$$\n", - "{\\mathcal{Alg}}^{\\star} \\left( \\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\text{tr}} \\right) = \\underset{\\phi'}{\\operatorname{\\arg \\min}} ~ G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\triangleq \\mathcal{L} \\left( \\boldsymbol{\\phi}', \\mathcal{D}_{i}^{\\text{tr}} \\right) + \\frac{\\lambda}{2} {\\left\\| \\boldsymbol{\\phi}' - \\boldsymbol{\\theta} \\right\\|}^{2}\n", - "$$\n", - "\n", - "According to this function, we can define the forward function `inner_solver`, where we solve this equation based on sufficient gradient descents. For such inner-loop process, the optimality condition is that the gradient w.r.t inner-loop parameter is $0$.\n", - "\n", - "$$\n", - "{\\left. \\nabla_{\\boldsymbol{\\phi}'} G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\right\\rvert}_{\\boldsymbol{\\phi}' = \\boldsymbol{\\phi}^{\\star}} = 0\n", - "$$\n", - "\n", - "Thus we can define the optimality function by defining `imaml_objective` and make it first-order gradient w.r.t the inner-loop parameter as $0$. We achieve so by calling out `functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated by the `@torchopt.diff.implicit.custom_root` decorator and the optimality condition we define." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", - "metadata": {}, - "outputs": [], - "source": [ - "# Inner-loop objective function\n", - "# The optimality function: grad(imaml_objective)\n", - "def imaml_objective(params, meta_params, data):\n", - " x, y, fmodel = data\n", - " y_pred = fmodel(params, x)\n", - " regularization_loss = 0.0\n", - " for p1, p2 in zip(params, meta_params):\n", - " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " loss = F.mse_loss(y_pred, y) + regularization_loss\n", - " return loss\n", - "\n", - "\n", - "# Optimality Condition is: the gradient w.r.t inner-loop optimal params is 0 (we achieve so by\n", - "# specifying argnums=0 in functorch.grad) the argnums=1 specify which meta-parameter we want to\n", - "# backpropogate, in this case we want to backpropogate to the initial parameters so we set it as 1.\n", - "# You can also set argnums as (1, 2) if you want to backpropogate through multiple meta-parameters\n", - "\n", - "\n", - "# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient of\n", - "# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., params.\n", - "# torchopt.linear_solve.solve_normal_cg specify that we use the conjugate gradient based linear solver\n", - "@torchopt.diff.implicit.custom_root(\n", - " functorch.grad(imaml_objective, argnums=0), # optimality function\n", - " argnums=1,\n", - " solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", - ")\n", - "def inner_solver(params, meta_params, data):\n", - " # Initial functional optimizer based on TorchOpt\n", - " x, y, fmodel = data\n", - " optimizer = torchopt.sgd(lr=2e-2)\n", - " opt_state = optimizer.init(params)\n", - " with torch.enable_grad():\n", - " # Temporarily enable gradient computation for conducting the optimization\n", - " for i in range(100):\n", - " pred = fmodel(params, x)\n", - " loss = F.mse_loss(pred, y) # compute loss\n", - "\n", - " # Compute regularization loss\n", - " regularization_loss = 0.0\n", - " for p1, p2 in zip(params, meta_params):\n", - " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " final_loss = loss + regularization_loss\n", - "\n", - " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", - " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", - " params = torchopt.apply_updates(params, updates, inplace=True)\n", - "\n", - " optimal_params = params\n", - " return optimal_params\n", - "\n", - "\n", - "# torchopt.linear_solve.solve_inv specify that we use the Neumann Series inversion linear solver\n", - "@torchopt.diff.implicit.custom_root(\n", - " functorch.grad(imaml_objective, argnums=0), # optimality function\n", - " argnums=1,\n", - " solve=torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.1),\n", - ")\n", - "def inner_solver_inv_ns(params, meta_params, data):\n", - " # Initial functional optimizer based on TorchOpt\n", - " x, y, fmodel = data\n", - " optimizer = torchopt.sgd(lr=2e-2)\n", - " opt_state = optimizer.init(params)\n", - " with torch.enable_grad():\n", - " # Temporarily enable gradient computation for conducting the optimization\n", - " for i in range(100):\n", - " pred = fmodel(params, x)\n", - " loss = F.mse_loss(pred, y) # compute loss\n", - "\n", - " # Compute regularization loss\n", - " regularization_loss = 0.0\n", - " for p1, p2 in zip(params, meta_params):\n", - " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " final_loss = loss + regularization_loss\n", - "\n", - " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", - " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", - " params = torchopt.apply_updates(params, updates, inplace=True)\n", - "\n", - " optimal_params = params\n", - " return optimal_params" - ] - }, - { - "cell_type": "markdown", - "id": "32a75c81-d479-4120-a73d-5b2b488358d0", - "metadata": {}, - "source": [ - "In the next step, we consider a specific case for one layer neural network to fit the linear data." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "fb95538b-1fd9-4ec8-9f57-6360bedc05b7", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(0)\n", - "x = torch.randn(20, 4)\n", - "w = torch.randn(4, 1)\n", - "b = torch.randn(1)\n", - "y = x @ w + b + 0.5 * torch.randn(20, 1)" - ] - }, - { - "cell_type": "markdown", - "id": "eeb1823a-2231-4471-bb68-cce7724f2578", - "metadata": {}, - "source": [ - "We instantiate an one layer neural network, where the weights and bias are initialized with constant." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "d50a7bfe-ac69-4089-8cf8-3cbd69d6d4e7", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, 1, bias=True)\n", - " nn.init.ones_(self.fc.weight)\n", - " nn.init.zeros_(self.fc.bias)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "model = Net(4)\n", - "fmodel, meta_params = functorch.make_functional(model)\n", - "data = (x, y, fmodel)\n", - "\n", - "\n", - "# Clone function for parameters\n", - "def clone(params):\n", - " cloned = []\n", - " for item in params:\n", - " if isinstance(item, torch.Tensor):\n", - " cloned.append(item.clone().detach_().requires_grad_(True))\n", - " else:\n", - " cloned.append(item)\n", - " return tuple(cloned)" - ] - }, - { - "cell_type": "markdown", - "id": "065c36c4-89e2-4a63-8213-63db6ee3b08e", - "metadata": {}, - "source": [ - "We take the forward process by calling out the forward function, then we pass the optimal params into the outer-loop loss function." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "115e79c6-911f-4743-a2ed-e50a71c3a813", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "optimal_params = inner_solver(clone(meta_params), meta_params, data)\n", - "\n", - "outer_loss = fmodel(optimal_params, x).mean()" - ] - }, - { - "cell_type": "markdown", - "id": "e2812351-f635-496e-9732-c80831ac04a6", - "metadata": {}, - "source": [ - "Finally, we can get the meta-gradient as shown below." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "6bdcbe8d-2336-4f80-b124-eb43c5a2fc0a", - "metadata": {}, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" - ] + "cells": + [ + { + "cell_type": "markdown", + "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", + "metadata": {}, + "source": [ + "# TorchOpt for Implicit Differentiation" + ] + }, + { + "cell_type": "markdown", + "id": "2b547376", + "metadata": {}, + "source": [ + "[](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "8d7f9865-dc02-43d4-be90-da1160c4e4dd", + "metadata": {}, + "source": [ + "By treating the solution $\\phi^{\\star}$ as an implicit function of $\\theta$, the idea of implicit differentiation is to directly get analytical best-response derivatives $\\partial \\phi^{\\star}(\\theta)/ \\partial \\theta$ by implicit function theorem. This is suitable for algorithms when the inner-level optimal solution is achieved ${\\left. \\frac{\\partial F (\\phi, \\theta)}{\\partial \\phi} \\right\\rvert}_{\\phi = \\phi^{\\star}} = 0$ or reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377)." + ] + }, + { + "cell_type": "markdown", + "id": "d7e4b9e1-115f-45ad-a9b3-ea338bcfe6dd", + "metadata": {}, + "source": [ + "In this tutorial, we will introduce how TorchOpt can be used to conduct implicit differentiation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8f13ae67-e328-409f-84a8-1fc425c03a66", + "metadata": {}, + "outputs": [], + "source": [ + "import functorch\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchopt" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "0cdaac49-4b94-4900-9bb5-a39057ac8b21", + "metadata": {}, + "source": [ + "## 1. Functional API\n", + "\n", + "The basic functional API is `torchopt.diff.implicit.custom_root`, which is used as the decorator for the forward process implicit gradient procedures. Users are required to implement the stationary conditions for the inner-loop process, which will be used as the input of custom_root decorator. We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "# Functional API for implicit gradient\n", + "def stationary(params, meta_params, data):\n", + " # stationary condition construction\n", + " return stationary condition\n", + "\n", + "# Decorator that wraps the function\n", + "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", + "@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver)\n", + "def solve(params, meta_params, data):\n", + " # Forward optimization process for params\n", + " return optimal_params\n", + "\n", + "# Define params, meta_params and get data\n", + "params, meta_prams, data = ..., ..., ...\n", + "optimal_params = solve(params, meta_params, data)\n", + "loss = outer_loss(optimal_params)\n", + "\n", + "meta_grads = torch.autograd.grad(loss, meta_params)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "dbef87df-2164-4f1d-8919-37a6fbdc5011", + "metadata": {}, + "source": [ + "Here we use the example of [iMAML](https://arxiv.org/abs/1909.04630) as a real example. For iMAML, the inner-loop objective is described by the following equation.\n", + "\n", + "$$\n", + "{\\mathcal{Alg}}^{\\star} \\left( \\boldsymbol{\\theta}, \\mathcal{D}_{i}^{\\text{tr}} \\right) = \\underset{\\phi'}{\\operatorname{\\arg \\min}} ~ G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\triangleq \\mathcal{L} \\left( \\boldsymbol{\\phi}', \\mathcal{D}_{i}^{\\text{tr}} \\right) + \\frac{\\lambda}{2} {\\left\\| \\boldsymbol{\\phi}' - \\boldsymbol{\\theta} \\right\\|}^{2}\n", + "$$\n", + "\n", + "According to this function, we can define the forward function `inner_solver`, where we solve this equation based on sufficient gradient descents. For such inner-loop process, the optimality condition is that the gradient w.r.t inner-loop parameter is $0$.\n", + "\n", + "$$\n", + "{\\left. \\nabla_{\\boldsymbol{\\phi}'} G \\left( \\boldsymbol{\\phi}', \\boldsymbol{\\theta} \\right) \\right\\rvert}_{\\boldsymbol{\\phi}' = \\boldsymbol{\\phi}^{\\star}} = 0\n", + "$$\n", + "\n", + "Thus we can define the optimality function by defining `imaml_objective` and make it first-order gradient w.r.t the inner-loop parameter as $0$. We achieve so by calling out `functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated by the `@torchopt.diff.implicit.custom_root` decorator and the optimality condition we define." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8d623b2f-48ee-4df6-a2ce-cf306b4c9067", + "metadata": {}, + "outputs": [], + "source": [ + "# Inner-loop objective function\n", + "# The optimality function: grad(imaml_objective)\n", + "def imaml_objective(params, meta_params, data):\n", + " x, y, fmodel = data\n", + " y_pred = fmodel(params, x)\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " loss = F.mse_loss(y_pred, y) + regularization_loss\n", + " return loss\n", + "\n", + "\n", + "# Optimality Condition is: the gradient w.r.t inner-loop optimal params is 0 (we achieve so by\n", + "# specifying argnums=0 in functorch.grad) the argnums=1 specify which meta-parameter we want to\n", + "# backpropogate, in this case we want to backpropogate to the initial parameters so we set it as 1.\n", + "# You can also set argnums as (1, 2) if you want to backpropogate through multiple meta-parameters\n", + "\n", + "\n", + "# Here we pass argnums=1 to the custom_root. That means we want to compute the gradient of\n", + "# optimal_params w.r.t. the 1-indexed argument in inner_solver, i.e., params.\n", + "# torchopt.linear_solve.solve_normal_cg specify that we use the conjugate gradient based linear solver\n", + "@torchopt.diff.implicit.custom_root(\n", + " functorch.grad(imaml_objective, argnums=0), # optimality function\n", + " argnums=1,\n", + " solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + ")\n", + "def inner_solver(params, meta_params, data):\n", + " # Initial functional optimizer based on TorchOpt\n", + " x, y, fmodel = data\n", + " optimizer = torchopt.sgd(lr=2e-2)\n", + " opt_state = optimizer.init(params)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for i in range(100):\n", + " pred = fmodel(params, x)\n", + " loss = F.mse_loss(pred, y) # compute loss\n", + "\n", + " # Compute regularization loss\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " final_loss = loss + regularization_loss\n", + "\n", + " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", + "\n", + " optimal_params = params\n", + " return optimal_params\n", + "\n", + "\n", + "# torchopt.linear_solve.solve_inv specify that we use the Neumann Series inversion linear solver\n", + "@torchopt.diff.implicit.custom_root(\n", + " functorch.grad(imaml_objective, argnums=0), # optimality function\n", + " argnums=1,\n", + " solve=torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.1),\n", + ")\n", + "def inner_solver_inv_ns(params, meta_params, data):\n", + " # Initial functional optimizer based on TorchOpt\n", + " x, y, fmodel = data\n", + " optimizer = torchopt.sgd(lr=2e-2)\n", + " opt_state = optimizer.init(params)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for i in range(100):\n", + " pred = fmodel(params, x)\n", + " loss = F.mse_loss(pred, y) # compute loss\n", + "\n", + " # Compute regularization loss\n", + " regularization_loss = 0.0\n", + " for p1, p2 in zip(params, meta_params):\n", + " regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " final_loss = loss + regularization_loss\n", + "\n", + " grads = torch.autograd.grad(final_loss, params) # compute gradients\n", + " updates, opt_state = optimizer.update(grads, opt_state, inplace=True) # get updates\n", + " params = torchopt.apply_updates(params, updates, inplace=True)\n", + "\n", + " optimal_params = params\n", + " return optimal_params" + ] + }, + { + "cell_type": "markdown", + "id": "32a75c81-d479-4120-a73d-5b2b488358d0", + "metadata": {}, + "source": [ + "In the next step, we consider a specific case for one layer neural network to fit the linear data." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fb95538b-1fd9-4ec8-9f57-6360bedc05b7", + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(0)\n", + "x = torch.randn(20, 4)\n", + "w = torch.randn(4, 1)\n", + "b = torch.randn(1)\n", + "y = x @ w + b + 0.5 * torch.randn(20, 1)" + ] + }, + { + "cell_type": "markdown", + "id": "eeb1823a-2231-4471-bb68-cce7724f2578", + "metadata": {}, + "source": [ + "We instantiate an one layer neural network, where the weights and bias are initialized with constant." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d50a7bfe-ac69-4089-8cf8-3cbd69d6d4e7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, 1, bias=True)\n", + " nn.init.ones_(self.fc.weight)\n", + " nn.init.zeros_(self.fc.bias)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "model = Net(4)\n", + "fmodel, meta_params = functorch.make_functional(model)\n", + "data = (x, y, fmodel)\n", + "\n", + "\n", + "# Clone function for parameters\n", + "def clone(params):\n", + " cloned = []\n", + " for item in params:\n", + " if isinstance(item, torch.Tensor):\n", + " cloned.append(item.clone().detach_().requires_grad_(True))\n", + " else:\n", + " cloned.append(item)\n", + " return tuple(cloned)" + ] + }, + { + "cell_type": "markdown", + "id": "065c36c4-89e2-4a63-8213-63db6ee3b08e", + "metadata": {}, + "source": [ + "We take the forward process by calling out the forward function, then we pass the optimal params into the outer-loop loss function." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "115e79c6-911f-4743-a2ed-e50a71c3a813", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "optimal_params = inner_solver(clone(meta_params), meta_params, data)\n", + "\n", + "outer_loss = fmodel(optimal_params, x).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "e2812351-f635-496e-9732-c80831ac04a6", + "metadata": {}, + "source": [ + "Finally, we can get the meta-gradient as shown below." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6bdcbe8d-2336-4f80-b124-eb43c5a2fc0a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "torch.autograd.grad(outer_loss, meta_params)" + ] + }, + { + "cell_type": "markdown", + "id": "926ae8bb", + "metadata": {}, + "source": [ + "Also we can switch to the Neumann Series inversion linear solver." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "43df0374", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)\n", + "outer_loss = fmodel(optimal_params, x).mean()\n", + "torch.autograd.grad(outer_loss, meta_params)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c92e67ea-b220-4a14-a1ea-4eb3c5f52b6b", + "metadata": {}, + "source": [ + "## 2. OOP API\n", + "\n", + "The basic OOP class is the class `ImplicitMetaGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the stationary condition/objective function and the inner-loop solve function to enable implicit gradient computation. We show the pseudo code in the following part.\n", + "\n", + "```python\n", + "from torchopt.nn import ImplicitMetaGradientModule\n", + "\n", + "# Inherited from the class ImplicitMetaGradientModule\n", + "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", + "class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):\n", + " def __init__(self, meta_module):\n", + " ...\n", + "\n", + " def forward(self, batch):\n", + " # Forward process\n", + " ...\n", + "\n", + " def optimality(self, batch, labels):\n", + " # Stationary condition construction for calculating implicit gradient\n", + " # NOTE: If this method is not implemented, it will be automatically derived from the\n", + " # gradient of the `objective` function.\n", + " ...\n", + "\n", + " def objective(self, batch, labels):\n", + " # Define the inner-loop optimization objective\n", + " # NOTE: This method is optional if method `optimality` is implemented.\n", + " ...\n", + "\n", + " def solve(self, batch, labels):\n", + " # Conduct the inner-loop optimization\n", + " ...\n", + " return self # optimized module\n", + "\n", + "# Get meta_params and data\n", + "meta_params, data = ..., ...\n", + "inner_net = InnerNet()\n", + "\n", + "# Solve for inner-loop process related to the meta-parameters\n", + "optimal_inner_net = inner_net.solve(meta_params, *data)\n", + "\n", + "# Get outer-loss and solve for meta-gradient\n", + "loss = outer_loss(optimal_inner_net)\n", + "meta_grad = torch.autograd.grad(loss, meta_params)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "62fbe520-11d0-41ff-9b0a-c6508b1d01cf", + "metadata": {}, + "source": + [ + "The class `ImplicitMetaGradientModule` is to enable the gradient flow from `self.parameters()` to `self.meta_parameters()`. In `__init__` function, users need to define the inner parameters and meta-parameters. By default, `ImplicitMetaGradientModule` treats all tensors and modules from input as `self.meta_parameters()`, and all tensors and modules defined in the `__init__` are regarded as `self.parameters()`. Users can also register `self.parameters()` and `self.meta_parameters()` by calling `self.register_parameter()` and `self.register_meta_parameter()` respectively." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c3999684-f4d3-4bc0-86ab-a7e803b2fe80", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" + ] + } + ], + "source": [ + "class InnerNet(\n", + " torchopt.nn.ImplicitMetaGradientModule,\n", + " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + "):\n", + " def __init__(self, meta_net, n_inner_iter, reg_param):\n", + " super().__init__()\n", + " # Declaration of the meta-parameter\n", + " self.meta_net = meta_net\n", + " # Get a deepcopy, register inner-parameter\n", + " self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)\n", + " self.n_inner_iter = n_inner_iter\n", + " self.reg_param = reg_param\n", + "\n", + " def forward(self, x):\n", + " return self.net(x)\n", + "\n", + " def objective(self, x, y):\n", + " # We do not implement the optimality conditions, so it will be automatically derived from\n", + " # the gradient of the `objective` function.\n", + " y_pred = self(x)\n", + " loss = F.mse_loss(y_pred, y)\n", + " regularization_loss = 0\n", + " for p1, p2 in zip(\n", + " self.parameters(), # parameters of `self.net`\n", + " self.meta_parameters(), # parameters of `self.meta_net`\n", + " ):\n", + " regularization_loss += (\n", + " 0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", + " )\n", + " return loss + regularization_loss\n", + "\n", + " def solve(self, x, y):\n", + " params = tuple(self.parameters())\n", + " inner_optim = torchopt.SGD(params, lr=2e-2)\n", + " with torch.enable_grad():\n", + " # Temporarily enable gradient computation for conducting the optimization\n", + " for _ in range(self.n_inner_iter):\n", + " loss = self.objective(x, y)\n", + " inner_optim.zero_grad()\n", + " # NOTE: The parameter inputs should be explicitly specified in `backward` function\n", + " # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into\n", + " # all the leaf Tensors (including the meta-parameters) that were used to compute the\n", + " # objective output. Alternatively, please use `torch.autograd.grad` instead.\n", + " loss.backward(inputs=params) # backward pass in inner-loop\n", + " inner_optim.step() # update inner parameters\n", + " return self\n", + "\n", + "\n", + "# Initialize the meta-network\n", + "meta_net = Net(4)\n", + "inner_net = InnerNet(meta_net, 100, reg_param=1)\n", + "\n", + "# Solve for inner-loop\n", + "optimal_inner_net = inner_net.solve(x, y)\n", + "outer_loss = optimal_inner_net(x).mean()\n", + "\n", + "# Derive the meta-gradient\n", + "torch.autograd.grad(outer_loss, meta_net.parameters())" + ] + }, + { + "cell_type": "markdown", + "id": "2b69a5d6-b5e4-4f08-af0a-40afc2382b45", + "metadata": {}, + "source": [ + "We also show an example on how to implement implicit gradient calculation when the inner-level optimal solution reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [DEQ](https://arxiv.org/abs/1909.01377), based on the OOP API. " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "de87c308-d847-4491-9aa1-bc393e6dd1d8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(\n", + "│ tensor([[ 0.0272, 0.0031, -0.0156, -0.0238],\n", + "│ │ [ 0.1004, 0.0113, -0.0573, -0.0878],\n", + "│ │ [ 0.0666, 0.0075, -0.0380, -0.0583],\n", + "│ │ [ 0.1446, 0.0163, -0.0826, -0.1265]]),\n", + "│ tensor([0.0574, 0.2114, 0.1403, 0.3046])\n", + ")\n" + ] + } + ], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.fc = nn.Linear(dim, dim)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "\n", + "class InnerNet(\n", + " torchopt.nn.ImplicitMetaGradientModule,\n", + " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", + "):\n", + " def __init__(self, meta_net, x0):\n", + " super().__init__()\n", + " # Register meta-parameter\n", + " self.meta_net = meta_net\n", + " # Declaration of the inner-parameter, register inner-parameter\n", + " self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)\n", + "\n", + " def forward(self, x):\n", + " return self.meta_net(x)\n", + "\n", + " def optimality(self):\n", + " # Fixed-point condition\n", + " return (self.x - self(self.x),)\n", + "\n", + " def solve(self):\n", + " # Solving inner-loop fixed-point iteration\n", + " # This is just an illustrating example for solving fixed-point iteration\n", + " # one can use more advanced method to solve fixed-point iteration\n", + " # such as anderson acceleration.\n", + " for _ in range(10):\n", + " self.x.copy_(self(self.x))\n", + " return self\n", + "\n", + "\n", + "# Initialize meta-network\n", + "torch.manual_seed(0)\n", + "meta_net = Net(4)\n", + "x0 = torch.randn(1, 4)\n", + "inner_net = InnerNet(meta_net, x0)\n", + "\n", + "# Solve for inner-loop\n", + "optimal_inner_net = inner_net.solve()\n", + "outer_loss = optimal_inner_net.x.mean()\n", + "\n", + "# Derive the meta-gradient\n", + "torch.autograd.grad(outer_loss, meta_net.parameters())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + }, + "vscode": { + "interpreter": { + "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 } - ], - "source": [ - "torch.autograd.grad(outer_loss, meta_params)" - ] - }, - { - "cell_type": "markdown", - "id": "926ae8bb", - "metadata": {}, - "source": [ - "Also we can switch to the Neumann Series inversion linear solver." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "43df0374", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" - ] - } - ], - "source": [ - "optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)\n", - "outer_loss = fmodel(optimal_params, x).mean()\n", - "torch.autograd.grad(outer_loss, meta_params)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "c92e67ea-b220-4a14-a1ea-4eb3c5f52b6b", - "metadata": {}, - "source": [ - "## 2. OOP API\n", - "\n", - "The basic OOP class is the class `ImplicitMetaGradientModule`. We make the network as an `nn.Module` following a classical PyTorch style. Users need to define the stationary condition/objective function and the inner-loop solve function to enable implicit gradient computation. We show the pseudo code in the following part.\n", - "\n", - "```python\n", - "from torchopt.nn import ImplicitMetaGradientModule\n", - "\n", - "# Inherited from the class ImplicitMetaGradientModule\n", - "# Optionally specify the linear solver (conjugate gradient or Neumann series)\n", - "class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):\n", - " def __init__(self, meta_module):\n", - " ...\n", - "\n", - " def forward(self, batch):\n", - " # Forward process\n", - " ...\n", - "\n", - " def optimality(self, batch, labels):\n", - " # Stationary condition construction for calculating implicit gradient\n", - " # NOTE: If this method is not implemented, it will be automatically derived from the\n", - " # gradient of the `objective` function.\n", - " ...\n", - "\n", - " def objective(self, batch, labels):\n", - " # Define the inner-loop optimization objective\n", - " # NOTE: This method is optional if method `optimality` is implemented.\n", - " ...\n", - "\n", - " def solve(self, batch, labels):\n", - " # Conduct the inner-loop optimization\n", - " ...\n", - " return self # optimized module\n", - "\n", - "# Get meta_params and data\n", - "meta_params, data = ..., ...\n", - "inner_net = InnerNet()\n", - "\n", - "# Solve for inner-loop process related to the meta-parameters\n", - "optimal_inner_net = inner_net.solve(meta_params, *data)\n", - "\n", - "# Get outer-loss and solve for meta-gradient\n", - "loss = outer_loss(optimal_inner_net)\n", - "meta_grad = torch.autograd.grad(loss, meta_params)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "62fbe520-11d0-41ff-9b0a-c6508b1d01cf", - "metadata": {}, - "source": [ - "The class `ImplicitMetaGradientModule` is to enable the gradient flow from `self.parameters()` to `self.meta_parameters()`. In `__init__` function, users need to define the inner parameters and meta-parameters. By default, `ImplicitMetaGradientModule` treats all tensors and modules from input as `self.meta_parameters()`, and all tensors and modules defined in the `__init__` are regarded as `self.parameters()`. Users can also register `self.parameters()` and `self.meta_parameters()` by calling `self.register_parameter()` and `self.register_meta_parameter()` respectively." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "c3999684-f4d3-4bc0-86ab-a7e803b2fe80", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor([[-0.0369, 0.0248, 0.0347, 0.0067]]), tensor([0.3156]))\n" - ] - } - ], - "source": [ - "class InnerNet(\n", - " torchopt.nn.ImplicitMetaGradientModule,\n", - " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", - "):\n", - " def __init__(self, meta_net, n_inner_iter, reg_param):\n", - " super().__init__()\n", - " # Declaration of the meta-parameter\n", - " self.meta_net = meta_net\n", - " # Get a deepcopy, register inner-parameter\n", - " self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)\n", - " self.n_inner_iter = n_inner_iter\n", - " self.reg_param = reg_param\n", - "\n", - " def forward(self, x):\n", - " return self.net(x)\n", - "\n", - " def objective(self, x, y):\n", - " # We do not implement the optimality conditions, so it will be automatically derived from\n", - " # the gradient of the `objective` function.\n", - " y_pred = self(x)\n", - " loss = F.mse_loss(y_pred, y)\n", - " regularization_loss = 0\n", - " for p1, p2 in zip(\n", - " self.parameters(), # parameters of `self.net`\n", - " self.meta_parameters(), # parameters of `self.meta_net`\n", - " ):\n", - " regularization_loss += (\n", - " 0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))\n", - " )\n", - " return loss + regularization_loss\n", - "\n", - " def solve(self, x, y):\n", - " params = tuple(self.parameters())\n", - " inner_optim = torchopt.SGD(params, lr=2e-2)\n", - " with torch.enable_grad():\n", - " # Temporarily enable gradient computation for conducting the optimization\n", - " for _ in range(self.n_inner_iter):\n", - " loss = self.objective(x, y)\n", - " inner_optim.zero_grad()\n", - " # NOTE: The parameter inputs should be explicitly specified in `backward` function\n", - " # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into\n", - " # all the leaf Tensors (including the meta-parameters) that were used to compute the\n", - " # objective output. Alternatively, please use `torch.autograd.grad` instead.\n", - " loss.backward(inputs=params) # backward pass in inner-loop\n", - " inner_optim.step() # update inner parameters\n", - " return self\n", - "\n", - "\n", - "# Initialize the meta-network\n", - "meta_net = Net(4)\n", - "inner_net = InnerNet(meta_net, 100, reg_param=1)\n", - "\n", - "# Solve for inner-loop\n", - "optimal_inner_net = inner_net.solve(x, y)\n", - "outer_loss = optimal_inner_net(x).mean()\n", - "\n", - "# Derive the meta-gradient\n", - "torch.autograd.grad(outer_loss, meta_net.parameters())" - ] - }, - { - "cell_type": "markdown", - "id": "2b69a5d6-b5e4-4f08-af0a-40afc2382b45", - "metadata": {}, - "source": [ - "We also show an example on how to implement implicit gradient calculation when the inner-level optimal solution reaches some stationary conditions $F (\\phi^{\\star}, \\theta) = 0$, such as [DEQ](https://arxiv.org/abs/1909.01377), based on the OOP API. " - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "de87c308-d847-4491-9aa1-bc393e6dd1d8", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "(\n", - "│ tensor([[ 0.0272, 0.0031, -0.0156, -0.0238],\n", - "│ │ [ 0.1004, 0.0113, -0.0573, -0.0878],\n", - "│ │ [ 0.0666, 0.0075, -0.0380, -0.0583],\n", - "│ │ [ 0.1446, 0.0163, -0.0826, -0.1265]]),\n", - "│ tensor([0.0574, 0.2114, 0.1403, 0.3046])\n", - ")\n" - ] - } - ], - "source": [ - "class Net(nn.Module):\n", - " def __init__(self, dim):\n", - " super().__init__()\n", - " self.fc = nn.Linear(dim, dim)\n", - "\n", - " def forward(self, x):\n", - " return self.fc(x)\n", - "\n", - "\n", - "class InnerNet(\n", - " torchopt.nn.ImplicitMetaGradientModule,\n", - " linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),\n", - "):\n", - " def __init__(self, meta_net, x0):\n", - " super().__init__()\n", - " # Register meta-parameter\n", - " self.meta_net = meta_net\n", - " # Declaration of the inner-parameter, register inner-parameter\n", - " self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)\n", - "\n", - " def forward(self, x):\n", - " return self.meta_net(x)\n", - "\n", - " def optimality(self):\n", - " # Fixed-point condition\n", - " return (self.x - self(self.x),)\n", - "\n", - " def solve(self):\n", - " # Solving inner-loop fixed-point iteration\n", - " # This is just an illustrating example for solving fixed-point iteration\n", - " # one can use more advanced method to solve fixed-point iteration\n", - " # such as anderson acceleration.\n", - " for _ in range(10):\n", - " self.x.copy_(self(self.x))\n", - " return self\n", - "\n", - "\n", - "# Initialize meta-network\n", - "torch.manual_seed(0)\n", - "meta_net = Net(4)\n", - "x0 = torch.randn(1, 4)\n", - "inner_net = InnerNet(meta_net, x0)\n", - "\n", - "# Solve for inner-loop\n", - "optimal_inner_net = inner_net.solve()\n", - "outer_loss = optimal_inner_net.x.mean()\n", - "\n", - "# Derive the meta-gradient\n", - "torch.autograd.grad(outer_loss, meta_net.parameters())" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" - }, - "vscode": { - "interpreter": { - "hash": "2a8cc1ff2cbc47027bf9993941710d9ab9175f14080903d9c7c432ee63d681da" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tutorials/6_Zero_Order_Differentiation.ipynb b/tutorials/6_Zero_Order_Differentiation.ipynb index d6cb028c..683eb34d 100644 --- a/tutorials/6_Zero_Order_Differentiation.ipynb +++ b/tutorials/6_Zero_Order_Differentiation.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "markdown", "id": "8850c832-3b54-4971-8ee0-2cd64b585ea8", @@ -175,7 +175,11 @@ "\n", "\n", "@torchopt.diff.zero_order(\n", - " distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01\n", + " distribution=distribution,\n", + " method='forward',\n", + " argnums=0,\n", + " num_samples=100,\n", + " sigma=0.01,\n", ")\n", "def forward_process(params, fn, x, y):\n", " y_pred = fn(params, x)\n",