From b0812597f2c0aafd5b708b1dc1feb4b00a8bbd02 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Mon, 9 Mar 2026 15:44:15 +0000 Subject: [PATCH 1/6] fix reading field in python scope avoiding gpu-cpu sync --- genesis/engine/solvers/rigid/rigid_solver.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/genesis/engine/solvers/rigid/rigid_solver.py b/genesis/engine/solvers/rigid/rigid_solver.py index 0927234a24..8cc117cf20 100644 --- a/genesis/engine/solvers/rigid/rigid_solver.py +++ b/genesis/engine/solvers/rigid/rigid_solver.py @@ -463,6 +463,9 @@ def _create_data_manager(self): self._errno = self.data_manager.errno self._rigid_global_info = self.data_manager.rigid_global_info + self._rigid_global_info._n_iterations = ( + self._options.iterations + ) # Python-native mirror to avoid CPU-GPU sync in Python-scope functions self._rigid_adjoint_cache = self.data_manager.rigid_adjoint_cache if self._use_hibernation: self.n_awake_dofs = self._rigid_global_info.n_awake_dofs From 56e085d33e5c47dd41b028d4a7b2cb8282d98ab5 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Mon, 9 Mar 2026 16:29:05 +0000 Subject: [PATCH 2/6] update --- genesis/engine/solvers/rigid/rigid_solver.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/genesis/engine/solvers/rigid/rigid_solver.py b/genesis/engine/solvers/rigid/rigid_solver.py index 8cc117cf20..0927234a24 100644 --- a/genesis/engine/solvers/rigid/rigid_solver.py +++ b/genesis/engine/solvers/rigid/rigid_solver.py @@ -463,9 +463,6 @@ def _create_data_manager(self): self._errno = self.data_manager.errno self._rigid_global_info = self.data_manager.rigid_global_info - self._rigid_global_info._n_iterations = ( - self._options.iterations - ) # Python-native mirror to avoid CPU-GPU sync in Python-scope functions self._rigid_adjoint_cache = self.data_manager.rigid_adjoint_cache if self._use_hibernation: self.n_awake_dofs = self._rigid_global_info.n_awake_dofs From 80fcf8919b05125bbc3395b664fdf42e3a19d38a Mon Sep 17 00:00:00 2001 From: Mingrui Date: Mon, 9 Mar 2026 18:09:48 +0000 Subject: [PATCH 3/6] Decompose func_solve_init with perf_dispatch and 8 separate kernels Convert func_solve_init from a plain @qd.kernel to a @qd.perf_dispatch, and register func_solve_init_decomposed for CUDA backend. This breaks the monolithic init into 8 separate kernel launches: 1. _kernel_init_warmstart (warmstart selection, ndrange dofs) 2. _kernel_init_Ma (Ma = M @ qacc, ndrange dofs) 3. _kernel_init_Jaref (Jaref = -aref + J @ qacc, ndrange constraints) 4. _kernel_init_improved (set improved flags) 5. _kernel_init_update_constraint (wraps monolith for FP match) 6. Newton hessian (conditional, reuses existing kernel) 7. _kernel_init_update_gradient (wraps monolith tiled gradient) 8. _kernel_init_search (search = -Mgrad, ndrange dofs) Co-Authored-By: Claude Opus 4.6 --- .../engine/solvers/rigid/constraint/solver.py | 16 +- .../rigid/constraint/solver_breakdown.py | 179 +++++++++++++++++- 2 files changed, 193 insertions(+), 2 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 38aa9306b1..0891e08828 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -2911,7 +2911,9 @@ def initialize_Ma( # ======================================================= Core ======================================================== -@qd.kernel(fastcache=gs.use_fastcache) +@qd.perf_dispatch( + get_geometry_hash=lambda *args, **kwargs: (*args, frozendict(kwargs)), warmup=3, active=3, repeat_after_seconds=0 +) def func_solve_init( dofs_info: array_class.DofsInfo, dofs_state: array_class.DofsState, @@ -2919,6 +2921,18 @@ def func_solve_init( constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), +) -> None: ... + + +@func_solve_init.register(is_compatible=lambda *args, **kwargs: True) +@qd.kernel(fastcache=gs.use_fastcache) +def func_solve_init_monolith( + dofs_info: array_class.DofsInfo, + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), ): _B = dofs_state.acc_smooth.shape[1] n_dofs = dofs_state.acc_smooth.shape[0] diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index ea6fadbc7d..b8ef7ba55a 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -771,6 +771,7 @@ def _func_update_search_direction( ) +<<<<<<< HEAD @qd.func def _func_check_early_exit( constraint_state: array_class.ConstraintState, @@ -791,8 +792,184 @@ def _func_check_early_exit( graph_counter[()] = 0 -# ============================================== Solve body dispatch ================================================ +# ================================================ Init kernels ================================================ + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_warmstart( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Select qacc from warmstart or acc_smooth, parallelized over (dof, env).""" + n_dofs = dofs_state.acc_smooth.shape[0] + _B = dofs_state.acc_smooth.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange(n_dofs, _B): + if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]: + constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b] + else: + constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_Ma( + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Compute Ma = M @ qacc, parallelized over (dof, env).""" + solver.initialize_Ma( + Ma=constraint_state.Ma, + qacc=constraint_state.qacc, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_Jaref( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Compute Jaref = -aref + J @ qacc, parallelized over (constraint, env).""" + len_constraints = constraint_state.Jaref.shape[0] + n_dofs = constraint_state.jac.shape[1] + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_c, i_b in qd.ndrange(len_constraints, _B): + if i_c < constraint_state.n_constraints[i_b]: + Jaref = -constraint_state.aref[i_c, i_b] + if qd.static(static_rigid_sim_config.sparse_solve): + for i_d_ in range(constraint_state.jac_n_relevant_dofs[i_c, i_b]): + i_d = constraint_state.jac_relevant_dofs[i_c, i_d_, i_b] + Jaref += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b] + else: + for i_d in range(n_dofs): + Jaref += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b] + constraint_state.Jaref[i_c, i_b] = Jaref + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_improved( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Set improved = (n_constraints > 0) for each env.""" + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(_B): + constraint_state.improved[i_b] = constraint_state.n_constraints[i_b] > 0 + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_search( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Set search = -Mgrad, parallelized over (dof, env).""" + n_dofs = constraint_state.search.shape[0] + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange(n_dofs, _B): + constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_update_constraint( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Init-only constraint update — wraps monolith's func_update_constraint for exact FP match.""" + solver.func_update_constraint( + qacc=constraint_state.qacc, + Ma=constraint_state.Ma, + cost=constraint_state.cost, + dofs_state=dofs_state, + constraint_state=constraint_state, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@qd.kernel(fastcache=gs.use_fastcache) +def _kernel_init_update_gradient( + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Init-only gradient update — wraps monolith's func_update_gradient (dispatches to tiled on GPU).""" + solver.func_update_gradient( + dofs_state=dofs_state, + entities_info=entities_info, + constraint_state=constraint_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda}) +def func_solve_init_decomposed( + dofs_info, + dofs_state, + entities_info, + constraint_state, + rigid_global_info, + static_rigid_sim_config, +): + """ + Decomposed version of func_solve_init for CUDA backend (non-mujoco path). + + Breaks the monolithic init kernel into separate kernel launches: + 1. Warmstart selection (ndrange over dofs) + 2. Ma = M @ qacc (ndrange over dofs with entity lookup) + 3. Jaref = -aref + J @ qacc (ndrange over constraints — main optimization) + 4. Set improved flags + 5. Update constraint (wraps monolith's func_update_constraint for exact FP match) + 6. Newton hessian (Newton only — reuse existing kernel) + 7. Update gradient (wraps monolith's func_update_gradient — uses tiled on GPU) + 8. search = -Mgrad (ndrange over dofs) + """ + # 1. Warmstart selection + _kernel_init_warmstart(dofs_state, constraint_state, static_rigid_sim_config) + + # 2. Ma = M @ qacc + _kernel_init_Ma(dofs_info, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config) + + # 3. Jaref = -aref + J @ qacc (parallelized over constraints) + _kernel_init_Jaref(constraint_state, static_rigid_sim_config) + + # 4. Set improved flags (needed by decomposed update_constraint kernels) + _kernel_init_improved(constraint_state, static_rigid_sim_config) + + # 5. Update constraint (init-specific: wraps monolith's func_update_constraint for exact FP match) + _kernel_init_update_constraint(dofs_state, constraint_state, static_rigid_sim_config) + + # 6. Newton hessian (Newton only) + if static_rigid_sim_config.solver_type == gs.constraint_solver.Newton: + _kernel_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config) + + # 7. Update gradient (init-specific: wraps monolith's func_update_gradient, dispatches to tiled on GPU) + _kernel_init_update_gradient( + entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config + ) + + # 8. search = -Mgrad + _kernel_init_search(constraint_state, static_rigid_sim_config) + + +# ============================================== Solve body dispatch ================================================ @qd.kernel(gpu_graph=True, fastcache=gs.use_fastcache) def _kernel_solve_gpu_graph( From 889024ddf89fcb14fee7e6626dc4a2256a45b318 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Tue, 31 Mar 2026 22:23:53 +0100 Subject: [PATCH 4/6] update to use cuda graph --- .../engine/solvers/rigid/constraint/solver.py | 28 +++-- .../rigid/constraint/solver_breakdown.py | 106 +++++++++--------- 2 files changed, 73 insertions(+), 61 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 0891e08828..21e01d0a70 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -2911,20 +2911,26 @@ def initialize_Ma( # ======================================================= Core ======================================================== -@qd.perf_dispatch( - get_geometry_hash=lambda *args, **kwargs: (*args, frozendict(kwargs)), warmup=3, active=3, repeat_after_seconds=0 -) def func_solve_init( - dofs_info: array_class.DofsInfo, - dofs_state: array_class.DofsState, - entities_info: array_class.EntitiesInfo, - constraint_state: array_class.ConstraintState, - rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: qd.template(), -) -> None: ... + dofs_info, + dofs_state, + entities_info, + constraint_state, + rigid_global_info, + static_rigid_sim_config, +): + if gs.backend is not gs.cpu: + from genesis.engine.solvers.rigid.constraint.solver_breakdown import func_solve_init_decomposed + + func_solve_init_decomposed( + dofs_info, dofs_state, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config + ) + else: + func_solve_init_monolith( + dofs_info, dofs_state, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config + ) -@func_solve_init.register(is_compatible=lambda *args, **kwargs: True) @qd.kernel(fastcache=gs.use_fastcache) def func_solve_init_monolith( dofs_info: array_class.DofsInfo, diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index b8ef7ba55a..eefcfaf967 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -771,7 +771,6 @@ def _func_update_search_direction( ) -<<<<<<< HEAD @qd.func def _func_check_early_exit( constraint_state: array_class.ConstraintState, @@ -792,12 +791,11 @@ def _func_check_early_exit( graph_counter[()] = 0 +# ================================================ Init funcs (for gpu_graph) ==================================== -# ================================================ Init kernels ================================================ - -@qd.kernel(fastcache=gs.use_fastcache) -def _kernel_init_warmstart( +@qd.func +def _func_init_warmstart( dofs_state: array_class.DofsState, constraint_state: array_class.ConstraintState, static_rigid_sim_config: qd.template(), @@ -814,8 +812,8 @@ def _kernel_init_warmstart( constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] -@qd.kernel(fastcache=gs.use_fastcache) -def _kernel_init_Ma( +@qd.func +def _func_init_Ma( dofs_info: array_class.DofsInfo, entities_info: array_class.EntitiesInfo, constraint_state: array_class.ConstraintState, @@ -833,8 +831,8 @@ def _kernel_init_Ma( ) -@qd.kernel(fastcache=gs.use_fastcache) -def _kernel_init_Jaref( +@qd.func +def _func_init_Jaref( constraint_state: array_class.ConstraintState, static_rigid_sim_config: qd.template(), ): @@ -857,8 +855,8 @@ def _kernel_init_Jaref( constraint_state.Jaref[i_c, i_b] = Jaref -@qd.kernel(fastcache=gs.use_fastcache) -def _kernel_init_improved( +@qd.func +def _func_init_improved( constraint_state: array_class.ConstraintState, static_rigid_sim_config: qd.template(), ): @@ -870,8 +868,8 @@ def _kernel_init_improved( constraint_state.improved[i_b] = constraint_state.n_constraints[i_b] > 0 -@qd.kernel(fastcache=gs.use_fastcache) -def _kernel_init_search( +@qd.func +def _func_init_search( constraint_state: array_class.ConstraintState, static_rigid_sim_config: qd.template(), ): @@ -884,8 +882,8 @@ def _kernel_init_search( constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] -@qd.kernel(fastcache=gs.use_fastcache) -def _kernel_init_update_constraint( +@qd.func +def _func_init_update_constraint( dofs_state: array_class.DofsState, constraint_state: array_class.ConstraintState, static_rigid_sim_config: qd.template(), @@ -901,15 +899,15 @@ def _kernel_init_update_constraint( ) -@qd.kernel(fastcache=gs.use_fastcache) -def _kernel_init_update_gradient( +@qd.func +def _func_init_update_gradient( entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), ): - """Init-only gradient update — wraps monolith's func_update_gradient (dispatches to tiled on GPU).""" + """Init-only gradient update — wraps monolith's func_update_gradient.""" solver.func_update_gradient( dofs_state=dofs_state, entities_info=entities_info, @@ -919,7 +917,29 @@ def _kernel_init_update_gradient( ) -@solver.func_solve_init.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda}) +# ================================================ Init gpu_graph kernel ========================================= + + +@qd.kernel(gpu_graph=True, fastcache=gs.use_fastcache) +def _kernel_solve_init_gpu_graph( + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + _func_init_warmstart(dofs_state, constraint_state, static_rigid_sim_config) + _func_init_Ma(dofs_info, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config) + _func_init_Jaref(constraint_state, static_rigid_sim_config) + _func_init_improved(constraint_state, static_rigid_sim_config) + _func_init_update_constraint(dofs_state, constraint_state, static_rigid_sim_config) + if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): + _func_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config) + _func_init_update_gradient(entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config) + _func_init_search(constraint_state, static_rigid_sim_config) + + def func_solve_init_decomposed( dofs_info, dofs_state, @@ -929,48 +949,34 @@ def func_solve_init_decomposed( static_rigid_sim_config, ): """ - Decomposed version of func_solve_init for CUDA backend (non-mujoco path). + GPU graph accelerated init using gpu_graph=True to batch all init steps into a single graph submission. + + On CUDA, captures all init steps as a CUDA graph for reduced kernel launch overhead. + On other backends, falls back to a C++-side loop that still reduces Python launch overhead. - Breaks the monolithic init kernel into separate kernel launches: + Steps (each a separate graph node): 1. Warmstart selection (ndrange over dofs) 2. Ma = M @ qacc (ndrange over dofs with entity lookup) - 3. Jaref = -aref + J @ qacc (ndrange over constraints — main optimization) + 3. Jaref = -aref + J @ qacc (ndrange over constraints) 4. Set improved flags - 5. Update constraint (wraps monolith's func_update_constraint for exact FP match) - 6. Newton hessian (Newton only — reuse existing kernel) - 7. Update gradient (wraps monolith's func_update_gradient — uses tiled on GPU) + 5. Update constraint (wraps monolith for exact FP match) + 6. Newton hessian (Newton only) + 7. Update gradient 8. search = -Mgrad (ndrange over dofs) """ - # 1. Warmstart selection - _kernel_init_warmstart(dofs_state, constraint_state, static_rigid_sim_config) - - # 2. Ma = M @ qacc - _kernel_init_Ma(dofs_info, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config) - - # 3. Jaref = -aref + J @ qacc (parallelized over constraints) - _kernel_init_Jaref(constraint_state, static_rigid_sim_config) - - # 4. Set improved flags (needed by decomposed update_constraint kernels) - _kernel_init_improved(constraint_state, static_rigid_sim_config) - - # 5. Update constraint (init-specific: wraps monolith's func_update_constraint for exact FP match) - _kernel_init_update_constraint(dofs_state, constraint_state, static_rigid_sim_config) - - # 6. Newton hessian (Newton only) - if static_rigid_sim_config.solver_type == gs.constraint_solver.Newton: - _kernel_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config) - - # 7. Update gradient (init-specific: wraps monolith's func_update_gradient, dispatches to tiled on GPU) - _kernel_init_update_gradient( - entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config + _kernel_solve_init_gpu_graph( + dofs_info, + entities_info, + dofs_state, + constraint_state, + rigid_global_info, + static_rigid_sim_config, ) - # 8. search = -Mgrad - _kernel_init_search(constraint_state, static_rigid_sim_config) - # ============================================== Solve body dispatch ================================================ + @qd.kernel(gpu_graph=True, fastcache=gs.use_fastcache) def _kernel_solve_gpu_graph( dofs_info: array_class.DofsInfo, From 5cda29654bb2d7ac0be86dfae880a22bb8b1cae1 Mon Sep 17 00:00:00 2001 From: Mingrui Date: Wed, 1 Apr 2026 00:15:21 +0100 Subject: [PATCH 5/6] not use cuda graph on grad tests --- genesis/engine/solvers/rigid/constraint/solver.py | 3 ++- genesis/engine/solvers/rigid/constraint/solver_breakdown.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 21e01d0a70..0e7ef4569e 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -2919,7 +2919,7 @@ def func_solve_init( rigid_global_info, static_rigid_sim_config, ): - if gs.backend is not gs.cpu: + if gs.backend is not gs.cpu and not static_rigid_sim_config.requires_grad: from genesis.engine.solvers.rigid.constraint.solver_breakdown import func_solve_init_decomposed func_solve_init_decomposed( @@ -3158,6 +3158,7 @@ def func_solve_body( @func_solve_body.register( is_compatible=lambda *args, **kwargs: _get_static_config(*args, **kwargs).prefer_parallel_linesearch != 1 + or _get_static_config(*args, **kwargs).requires_grad ) @qd.kernel(fastcache=gs.use_fastcache) def func_solve_body_monolith( diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index eefcfaf967..60c0203323 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -1008,6 +1008,7 @@ def _kernel_solve_gpu_graph( @solver.func_solve_body.register( is_compatible=lambda *args, **kwargs: solver._get_static_config(*args, **kwargs).prefer_parallel_linesearch != 0 + and not solver._get_static_config(*args, **kwargs).requires_grad ) def func_solve_decomposed( entities_info, From c2b210f8916876092fcec69fbc0c7a7ab5bc368b Mon Sep 17 00:00:00 2001 From: Mingrui Date: Wed, 1 Apr 2026 02:01:59 +0100 Subject: [PATCH 6/6] add a gpu thread saturation check --- .../engine/solvers/rigid/constraint/solver.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 0e7ef4569e..e566c20a7d 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -2911,6 +2911,18 @@ def initialize_Ma( # ======================================================= Core ======================================================== +def _get_gpu_saturation_threshold(): + """Max concurrent warps the GPU can run — above this, envs alone saturate the GPU.""" + if not hasattr(_get_gpu_saturation_threshold, "_cached"): + import torch + + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + _get_gpu_saturation_threshold._cached = ( + props.multi_processor_count * props.max_threads_per_multi_processor // props.warp_size + ) + return _get_gpu_saturation_threshold._cached + + def func_solve_init( dofs_info, dofs_state, @@ -2920,15 +2932,18 @@ def func_solve_init( static_rigid_sim_config, ): if gs.backend is not gs.cpu and not static_rigid_sim_config.requires_grad: - from genesis.engine.solvers.rigid.constraint.solver_breakdown import func_solve_init_decomposed + n_envs = dofs_state.acc_smooth.shape[1] + if n_envs <= _get_gpu_saturation_threshold(): + from genesis.engine.solvers.rigid.constraint.solver_breakdown import func_solve_init_decomposed - func_solve_init_decomposed( - dofs_info, dofs_state, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config - ) - else: - func_solve_init_monolith( - dofs_info, dofs_state, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config - ) + func_solve_init_decomposed( + dofs_info, dofs_state, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config + ) + return + + func_solve_init_monolith( + dofs_info, dofs_state, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config + ) @qd.kernel(fastcache=gs.use_fastcache)