20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
@@ -6,10 +6,10 @@ ci:
autofix_commit_msg: "fix: [pre-commit.ci] auto fixes [...]"
autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate"
autoupdate_schedule: monthly
-default_stages: [commit, push, manual]
+default_stages: [pre-commit, pre-push, manual]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
-rev: v4.6.0
+rev: v6.0.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
@@ -26,24 +26,24 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/pre-commit/mirrors-clang-format
-rev: v18.1.8
+rev: v21.1.2
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.5.0
+rev: v0.13.3
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/PyCQA/isort
-rev: 5.13.2
+rev: 6.1.0
hooks:
- id: isort
-- repo: https://github.com/psf/black
-rev: 24.4.2
+- repo: https://github.com/psf/black-pre-commit-mirror
+rev: 25.9.0
hooks:
- id: black-jupyter
- repo: https://github.com/asottile/pyupgrade
-rev: v3.16.0
+rev: v3.20.0
hooks:
- id: pyupgrade
args: [--py38-plus] # sync with requires-python
@@ -52,7 +52,7 @@ repos:
^examples/
)
- repo: https://github.com/pycqa/flake8
-rev: 7.1.0
+rev: 7.3.0
hooks:
- id: flake8
additional_dependencies:
@@ -68,7 +68,7 @@
^docs/source/conf.py$
)
- repo: https://github.com/codespell-project/codespell
-rev: v2.3.0
+rev: v2.4.1
hooks:
- id: codespell
additional_dependencies: [".[toml]"]
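Note on the `default_stages` change: pre-commit 3.2+ renamed the legacy stage identifiers `commit` and `push` to `pre-commit` and `pre-push`, which is what the updated list reflects; the rest of this file is routine hook version bumps plus a switch from `psf/black` to Black's official pre-commit mirror, `psf/black-pre-commit-mirror`. A minimal sketch of per-hook stage overrides under the renamed identifiers (illustrative only; these hooks and overrides are not part of this diff):

```yaml
# Illustrative sketch, assuming pre-commit >= 3.2 stage names.
default_stages: [pre-commit, pre-push, manual]
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v6.0.0
    hooks:
      - id: check-merge-conflict
        stages: [pre-commit]   # was: [commit]
      - id: check-added-large-files
        stages: [pre-push]     # was: [push]
```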
40 changes: 20 additions & 20 deletions include/adam_op/adam_op.h
@@ -27,51 +27,51 @@ namespace py = pybind11;

namespace adam_op {

-TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
-const torch::Tensor &mu,
-const torch::Tensor &nu,
+TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
+const torch::Tensor& mu,
+const torch::Tensor& nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

-torch::Tensor adamForwardMu(const torch::Tensor &updates,
-const torch::Tensor &mu,
+torch::Tensor adamForwardMu(const torch::Tensor& updates,
+const torch::Tensor& mu,
const pyfloat_t b1);

-torch::Tensor adamForwardNu(const torch::Tensor &updates,
-const torch::Tensor &nu,
+torch::Tensor adamForwardNu(const torch::Tensor& updates,
+const torch::Tensor& nu,
const pyfloat_t b2);

-torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
-const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
+const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

-TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
-const torch::Tensor &updates,
-const torch::Tensor &mu,
+TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
+const torch::Tensor& updates,
+const torch::Tensor& mu,
const pyfloat_t b1);

-TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
-const torch::Tensor &updates,
-const torch::Tensor &nu,
+TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
+const torch::Tensor& updates,
+const torch::Tensor& nu,
const pyfloat_t b2);

-TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
-const torch::Tensor &updates,
-const torch::Tensor &new_mu,
-const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
+const torch::Tensor& updates,
+const torch::Tensor& new_mu,
+const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps_root,
const pyuint_t count);

-void buildSubmodule(py::module &mod); // NOLINT[runtime/references]
+void buildSubmodule(py::module& mod); // NOLINT[runtime/references]

} // namespace adam_op
} // namespace torchopt
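The C++ changes in this PR are purely mechanical: the reference sigil moves from the parameter name to the type (`const torch::Tensor &x` becomes `const torch::Tensor& x`), in line with the clang-format bump above. A sketch of `.clang-format` options that would enforce this style (assumed for illustration; the project's actual `.clang-format` is not shown in this view):

```yaml
# Assumed settings -- not taken from this PR.
BasedOnStyle: Google
PointerAlignment: Left        # int* p            rather than  int *p
ReferenceAlignment: Left      # torch::Tensor& t  rather than  torch::Tensor &t
DerivePointerAlignment: false
```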
38 changes: 19 additions & 19 deletions include/adam_op/adam_op_impl_cpu.h
@@ -23,45 +23,45 @@

namespace torchopt {
namespace adam_op {
-TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
-const torch::Tensor &mu,
-const torch::Tensor &nu,
+TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates,
+const torch::Tensor& mu,
+const torch::Tensor& nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

-torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
-const torch::Tensor &mu,
+torch::Tensor adamForwardMuCPU(const torch::Tensor& updates,
+const torch::Tensor& mu,
const pyfloat_t b1);

-torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
-const torch::Tensor &nu,
+torch::Tensor adamForwardNuCPU(const torch::Tensor& updates,
+const torch::Tensor& nu,
const pyfloat_t b2);

-torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
-const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu,
+const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

-TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
-const torch::Tensor &updates,
-const torch::Tensor &mu,
+TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu,
+const torch::Tensor& updates,
+const torch::Tensor& mu,
const pyfloat_t b1);

-TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
-const torch::Tensor &updates,
-const torch::Tensor &nu,
+TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu,
+const torch::Tensor& updates,
+const torch::Tensor& nu,
const pyfloat_t b2);

-TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates,
-const torch::Tensor &updates,
-const torch::Tensor &new_mu,
-const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
+const torch::Tensor& updates,
+const torch::Tensor& new_mu,
+const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps_root,
38 changes: 19 additions & 19 deletions include/adam_op/adam_op_impl_cuda.cuh
@@ -23,45 +23,45 @@

namespace torchopt {
namespace adam_op {
-TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
-const torch::Tensor &mu,
-const torch::Tensor &nu,
+TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor& updates,
+const torch::Tensor& mu,
+const torch::Tensor& nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

-torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
-const torch::Tensor &mu,
+torch::Tensor adamForwardMuCUDA(const torch::Tensor& updates,
+const torch::Tensor& mu,
const pyfloat_t b1);

-torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
-const torch::Tensor &nu,
+torch::Tensor adamForwardNuCUDA(const torch::Tensor& updates,
+const torch::Tensor& nu,
const pyfloat_t b2);

-torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
-const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor& new_mu,
+const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
const pyfloat_t eps_root,
const pyuint_t count);

-TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
-const torch::Tensor &updates,
-const torch::Tensor &mu,
+TensorArray<2> adamBackwardMuCUDA(const torch::Tensor& dmu,
+const torch::Tensor& updates,
+const torch::Tensor& mu,
const pyfloat_t b1);

-TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
-const torch::Tensor &updates,
-const torch::Tensor &nu,
+TensorArray<2> adamBackwardNuCUDA(const torch::Tensor& dnu,
+const torch::Tensor& updates,
+const torch::Tensor& nu,
const pyfloat_t b2);

-TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
-const torch::Tensor &updates,
-const torch::Tensor &new_mu,
-const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor& dupdates,
+const torch::Tensor& updates,
+const torch::Tensor& new_mu,
+const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps_root,
2 changes: 1 addition & 1 deletion include/utils.h
@@ -24,7 +24,7 @@
#endif

namespace torchopt {
-__forceinline__ size_t getTensorPlainSize(const torch::Tensor &tensor) {
+__forceinline__ size_t getTensorPlainSize(const torch::Tensor& tensor) {
const auto dim = tensor.dim();
size_t n = 1;
for (std::decay_t<decltype(dim)> i = 0; i < dim; ++i) {
40 changes: 20 additions & 20 deletions src/adam_op/adam_op.cpp
@@ -29,9 +29,9 @@ namespace py = pybind11;

namespace adam_op {

-TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
-const torch::Tensor &mu,
-const torch::Tensor &nu,
+TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
+const torch::Tensor& mu,
+const torch::Tensor& nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
@@ -49,8 +49,8 @@ TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
}
}

-torch::Tensor adamForwardMu(const torch::Tensor &updates,
-const torch::Tensor &mu,
+torch::Tensor adamForwardMu(const torch::Tensor& updates,
+const torch::Tensor& mu,
const pyfloat_t b1) {
#if defined(__USE_CUDA__)
if (updates.device().is_cuda()) {
@@ -64,8 +64,8 @@ torch::Tensor adamForwardMu(const torch::Tensor &updates,
}
}

-torch::Tensor adamForwardNu(const torch::Tensor &updates,
-const torch::Tensor &nu,
+torch::Tensor adamForwardNu(const torch::Tensor& updates,
+const torch::Tensor& nu,
const pyfloat_t b2) {
#if defined(__USE_CUDA__)
if (updates.device().is_cuda()) {
@@ -79,8 +79,8 @@ torch::Tensor adamForwardNu(const torch::Tensor &updates,
}
}

-torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
-const torch::Tensor &new_nu,
+torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
+const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps,
@@ -98,9 +98,9 @@ torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
}
}

-TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
-const torch::Tensor &updates,
-const torch::Tensor &mu,
+TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
+const torch::Tensor& updates,
+const torch::Tensor& mu,
const pyfloat_t b1) {
#if defined(__USE_CUDA__)
if (dmu.device().is_cuda()) {
@@ -114,9 +114,9 @@ TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
}
}

-TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
-const torch::Tensor &updates,
-const torch::Tensor &nu,
+TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
+const torch::Tensor& updates,
+const torch::Tensor& nu,
const pyfloat_t b2) {
#if defined(__USE_CUDA__)
if (dnu.device().is_cuda()) {
@@ -130,10 +130,10 @@ TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
}
}

-TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
-const torch::Tensor &updates,
-const torch::Tensor &new_mu,
-const torch::Tensor &new_nu,
+TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
+const torch::Tensor& updates,
+const torch::Tensor& new_mu,
+const torch::Tensor& new_nu,
const pyfloat_t b1,
const pyfloat_t b2,
const pyfloat_t eps_root,
@@ -152,7 +152,7 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
}
}

-void buildSubmodule(py::module &mod) { // NOLINT[runtime/references]
+void buildSubmodule(py::module& mod) { // NOLINT[runtime/references]
py::module m = mod.def_submodule("adam_op", "Adam Ops");
m.def("forward_",
&adamForwardInplace,