From 72ad7c0b05ee3f3e8ecb6456790140b7207c756b Mon Sep 17 00:00:00 2001 From: Zihan Yang Date: Tue, 21 Apr 2026 20:37:21 +0800 Subject: [PATCH] Fix incorrect option for param group in adagrad --- csrc/embedding/adagrad.cc | 6 +++--- recis/optim/__init__.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/csrc/embedding/adagrad.cc b/csrc/embedding/adagrad.cc index 95389b3..f723974 100644 --- a/csrc/embedding/adagrad.cc +++ b/csrc/embedding/adagrad.cc @@ -70,17 +70,17 @@ void SparseAdagrad::add_param_group( const SparseOptimizerParamGroup ¶m_group) { SparseOptimizerParamGroup param_group_(param_group.params()); // set options for group - if (!param_group_.has_options()) { + if (!param_group.has_options()) { param_group_.set_options(defaults_->clone()); } else { - param_group_.set_options(param_group_.options().clone()); + param_group_.set_options(param_group.options().clone()); } // init optimizer global state for hashtable name <-> hashtable ptr for (const auto ¶m : param_group_.params()) { init_param_state(param.second, param_group_.options()); } // add param group - param_groups_.emplace_back(param_group); + param_groups_.emplace_back(param_group_); } void SparseAdagrad::add_parameters( diff --git a/recis/optim/__init__.py b/recis/optim/__init__.py index 145c7bd..12d0a14 100644 --- a/recis/optim/__init__.py +++ b/recis/optim/__init__.py @@ -10,6 +10,7 @@ from recis.optim.sparse_adam import SparseAdam as SparseAdam from recis.optim.sparse_adamw import SparseAdamW as SparseAdamW from recis.optim.sparse_adamw_tf import SparseAdamWTF as SparseAdamWTF +from recis.optim.sparse_adagrad import SparseAdagrad as SparseAdagrad __all__ = [ @@ -22,5 +23,6 @@ "SparseAdam", "SparseAdamW", "SparseAdamWTF", + "SparseAdagrad", "wrapped_named_optimizer", ]