diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 38aa9306b1..e566c20a7d 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -2911,8 +2911,43 @@ def initialize_Ma( # ======================================================= Core ======================================================== -@qd.kernel(fastcache=gs.use_fastcache) +def _get_gpu_saturation_threshold(): + """Max concurrent warps the GPU can run — above this, envs alone saturate the GPU.""" + if not hasattr(_get_gpu_saturation_threshold, "_cached"): + import torch + + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + _get_gpu_saturation_threshold._cached = ( + props.multi_processor_count * props.max_threads_per_multi_processor // props.warp_size + ) + return _get_gpu_saturation_threshold._cached + + def func_solve_init( + dofs_info, + dofs_state, + entities_info, + constraint_state, + rigid_global_info, + static_rigid_sim_config, +): + if gs.backend is not gs.cpu and not static_rigid_sim_config.requires_grad: + n_envs = dofs_state.acc_smooth.shape[1] + if n_envs <= _get_gpu_saturation_threshold(): + from genesis.engine.solvers.rigid.constraint.solver_breakdown import func_solve_init_decomposed + + func_solve_init_decomposed( + dofs_info, dofs_state, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config + ) + return + + func_solve_init_monolith( + dofs_info, dofs_state, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config + ) + + +@qd.kernel(fastcache=gs.use_fastcache) +def func_solve_init_monolith( dofs_info: array_class.DofsInfo, dofs_state: array_class.DofsState, entities_info: array_class.EntitiesInfo, @@ -3138,6 +3173,7 @@ def func_solve_body( @func_solve_body.register( is_compatible=lambda *args, **kwargs: _get_static_config(*args, **kwargs).prefer_parallel_linesearch != 1 + or _get_static_config(*args, **kwargs).requires_grad ) @qd.kernel(fastcache=gs.use_fastcache) def func_solve_body_monolith( diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index ea6fadbc7d..60c0203323 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -791,6 +791,189 @@ def _func_check_early_exit( graph_counter[()] = 0 +# ================================================ Init funcs (for gpu_graph) ==================================== + + +@qd.func +def _func_init_warmstart( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Select qacc from warmstart or acc_smooth, parallelized over (dof, env).""" + n_dofs = dofs_state.acc_smooth.shape[0] + _B = dofs_state.acc_smooth.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange(n_dofs, _B): + if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]: + constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b] + else: + constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] + + +@qd.func +def _func_init_Ma( + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Compute Ma = M @ qacc, parallelized over (dof, env).""" + solver.initialize_Ma( + Ma=constraint_state.Ma, + qacc=constraint_state.qacc, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@qd.func +def _func_init_Jaref( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Compute Jaref = -aref + J @ qacc, parallelized over (constraint, env).""" + len_constraints = constraint_state.Jaref.shape[0] + n_dofs = constraint_state.jac.shape[1] + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_c, i_b in qd.ndrange(len_constraints, _B): + if i_c < constraint_state.n_constraints[i_b]: + Jaref = -constraint_state.aref[i_c, i_b] + if qd.static(static_rigid_sim_config.sparse_solve): + for i_d_ in range(constraint_state.jac_n_relevant_dofs[i_c, i_b]): + i_d = constraint_state.jac_relevant_dofs[i_c, i_d_, i_b] + Jaref += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b] + else: + for i_d in range(n_dofs): + Jaref += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b] + constraint_state.Jaref[i_c, i_b] = Jaref + + +@qd.func +def _func_init_improved( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Set improved = (n_constraints > 0) for each env.""" + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(_B): + constraint_state.improved[i_b] = constraint_state.n_constraints[i_b] > 0 + + +@qd.func +def _func_init_search( + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Set search = -Mgrad, parallelized over (dof, env).""" + n_dofs = constraint_state.search.shape[0] + _B = constraint_state.grad.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange(n_dofs, _B): + constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] + + +@qd.func +def _func_init_update_constraint( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Init-only constraint update — wraps monolith's func_update_constraint for exact FP match.""" + solver.func_update_constraint( + qacc=constraint_state.qacc, + Ma=constraint_state.Ma, + cost=constraint_state.cost, + dofs_state=dofs_state, + constraint_state=constraint_state, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@qd.func +def _func_init_update_gradient( + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Init-only gradient update — wraps monolith's func_update_gradient.""" + solver.func_update_gradient( + dofs_state=dofs_state, + entities_info=entities_info, + constraint_state=constraint_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +# ================================================ Init gpu_graph kernel ========================================= + + +@qd.kernel(gpu_graph=True, fastcache=gs.use_fastcache) +def _kernel_solve_init_gpu_graph( + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + _func_init_warmstart(dofs_state, constraint_state, static_rigid_sim_config) + _func_init_Ma(dofs_info, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config) + _func_init_Jaref(constraint_state, static_rigid_sim_config) + _func_init_improved(constraint_state, static_rigid_sim_config) + _func_init_update_constraint(dofs_state, constraint_state, static_rigid_sim_config) + if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): + _func_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config) + _func_init_update_gradient(entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config) + _func_init_search(constraint_state, static_rigid_sim_config) + + +def func_solve_init_decomposed( + dofs_info, + dofs_state, + entities_info, + constraint_state, + rigid_global_info, + static_rigid_sim_config, +): + """ + GPU graph accelerated init using gpu_graph=True to batch all init steps into a single graph submission. + + On CUDA, captures all init steps as a CUDA graph for reduced kernel launch overhead. + On other backends, falls back to a C++-side loop that still reduces Python launch overhead. + + Steps (each a separate graph node): + 1. Warmstart selection (ndrange over dofs) + 2. Ma = M @ qacc (ndrange over dofs with entity lookup) + 3. Jaref = -aref + J @ qacc (ndrange over constraints) + 4. Set improved flags + 5. Update constraint (wraps monolith for exact FP match) + 6. Newton hessian (Newton only) + 7. Update gradient + 8. search = -Mgrad (ndrange over dofs) + """ + _kernel_solve_init_gpu_graph( + dofs_info, + entities_info, + dofs_state, + constraint_state, + rigid_global_info, + static_rigid_sim_config, + ) + + # ============================================== Solve body dispatch ================================================ @@ -825,6 +1008,7 @@ def _kernel_solve_gpu_graph( @solver.func_solve_body.register( is_compatible=lambda *args, **kwargs: solver._get_static_config(*args, **kwargs).prefer_parallel_linesearch != 0 + and not solver._get_static_config(*args, **kwargs).requires_grad ) def func_solve_decomposed( entities_info,