Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion genesis/engine/solvers/rigid/constraint/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2911,8 +2911,43 @@ def initialize_Ma(
# ======================================================= Core ========================================================


@qd.kernel(fastcache=gs.use_fastcache)
def _get_gpu_saturation_threshold():
"""Max concurrent warps the GPU can run — above this, envs alone saturate the GPU."""
if not hasattr(_get_gpu_saturation_threshold, "_cached"):
import torch

props = torch.cuda.get_device_properties(torch.cuda.current_device())
_get_gpu_saturation_threshold._cached = (
props.multi_processor_count * props.max_threads_per_multi_processor // props.warp_size
)
return _get_gpu_saturation_threshold._cached


def func_solve_init(
    dofs_info,
    dofs_state,
    entities_info,
    constraint_state,
    rigid_global_info,
    static_rigid_sim_config,
):
    """Dispatch solver initialization to the decomposed or monolithic kernel.

    On GPU backends without gradient tracking, batches with fewer envs than the
    GPU's warp-saturation threshold take the decomposed (graph-captured) path,
    which amortizes kernel-launch overhead; every other case uses the monolithic
    kernel.
    """
    use_decomposed = False
    if gs.backend is not gs.cpu and not static_rigid_sim_config.requires_grad:
        n_envs = dofs_state.acc_smooth.shape[1]
        use_decomposed = n_envs <= _get_gpu_saturation_threshold()

    if use_decomposed:
        # Imported lazily to avoid a circular import between solver and its breakdown module.
        from genesis.engine.solvers.rigid.constraint.solver_breakdown import func_solve_init_decomposed

        func_solve_init_decomposed(
            dofs_info, dofs_state, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config
        )
    else:
        func_solve_init_monolith(
            dofs_info, dofs_state, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config
        )


@qd.kernel(fastcache=gs.use_fastcache)
def func_solve_init_monolith(
dofs_info: array_class.DofsInfo,
dofs_state: array_class.DofsState,
entities_info: array_class.EntitiesInfo,
Expand Down Expand Up @@ -3138,6 +3173,7 @@ def func_solve_body(

@func_solve_body.register(
is_compatible=lambda *args, **kwargs: _get_static_config(*args, **kwargs).prefer_parallel_linesearch != 1
or _get_static_config(*args, **kwargs).requires_grad
)
@qd.kernel(fastcache=gs.use_fastcache)
def func_solve_body_monolith(
Expand Down
184 changes: 184 additions & 0 deletions genesis/engine/solvers/rigid/constraint/solver_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,189 @@ def _func_check_early_exit(
graph_counter[()] = 0


# ================================================ Init funcs (for gpu_graph) ====================================


@qd.func
def _func_init_warmstart(
    dofs_state: array_class.DofsState,
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: qd.template(),
):
    """Seed qacc per (dof, env): warmstarted solution when available, else smooth acceleration."""
    dof_count = dofs_state.acc_smooth.shape[0]
    env_count = dofs_state.acc_smooth.shape[1]

    qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_d, i_b in qd.ndrange(dof_count, env_count):
        take_warmstart = constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]
        if take_warmstart:
            constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b]
        else:
            constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b]


@qd.func
def _func_init_Ma(
    dofs_info: array_class.DofsInfo,
    entities_info: array_class.EntitiesInfo,
    constraint_state: array_class.ConstraintState,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: qd.template(),
):
    """Fill Ma with the mass/acceleration product M @ qacc.

    Delegates to the monolithic solver's initialize_Ma so both code paths share a
    single, floating-point-identical implementation.
    """
    solver.initialize_Ma(
        qacc=constraint_state.qacc,
        Ma=constraint_state.Ma,
        dofs_info=dofs_info,
        entities_info=entities_info,
        rigid_global_info=rigid_global_info,
        static_rigid_sim_config=static_rigid_sim_config,
    )


@qd.func
def _func_init_Jaref(
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: qd.template(),
):
    """Compute Jaref = J @ qacc - aref for every active constraint in every env."""
    max_constraints = constraint_state.Jaref.shape[0]
    dof_count = constraint_state.jac.shape[1]
    env_count = constraint_state.grad.shape[1]

    qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_c, i_b in qd.ndrange(max_constraints, env_count):
        # Rows beyond the env's active constraint count are left untouched.
        if i_c < constraint_state.n_constraints[i_b]:
            acc = -constraint_state.aref[i_c, i_b]
            if qd.static(static_rigid_sim_config.sparse_solve):
                # Sparse path: visit only the dofs this constraint actually touches.
                for k in range(constraint_state.jac_n_relevant_dofs[i_c, i_b]):
                    i_d = constraint_state.jac_relevant_dofs[i_c, k, i_b]
                    acc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b]
            else:
                # Dense path: full row of the Jacobian.
                for i_d in range(dof_count):
                    acc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.qacc[i_d, i_b]
            constraint_state.Jaref[i_c, i_b] = acc


@qd.func
def _func_init_improved(
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: qd.template(),
):
    """Mark each env as 'improved' iff it has at least one active constraint."""
    env_count = constraint_state.grad.shape[1]

    qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_b in range(env_count):
        has_constraints = constraint_state.n_constraints[i_b] > 0
        constraint_state.improved[i_b] = has_constraints


@qd.func
def _func_init_search(
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: qd.template(),
):
    """Initialize the search direction to steepest descent: search = -Mgrad."""
    dof_count = constraint_state.search.shape[0]
    env_count = constraint_state.grad.shape[1]

    qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_d, i_b in qd.ndrange(dof_count, env_count):
        constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b]


@qd.func
def _func_init_update_constraint(
    dofs_state: array_class.DofsState,
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: qd.template(),
):
    """Run the monolithic solver's constraint update once during init.

    Delegating to solver.func_update_constraint keeps the decomposed path
    bit-for-bit identical to the monolithic one.
    """
    solver.func_update_constraint(
        dofs_state=dofs_state,
        constraint_state=constraint_state,
        qacc=constraint_state.qacc,
        Ma=constraint_state.Ma,
        cost=constraint_state.cost,
        static_rigid_sim_config=static_rigid_sim_config,
    )


@qd.func
def _func_init_update_gradient(
    entities_info: array_class.EntitiesInfo,
    dofs_state: array_class.DofsState,
    constraint_state: array_class.ConstraintState,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: qd.template(),
):
    """Run the monolithic solver's gradient update once during init (delegation keeps
    both code paths numerically identical)."""
    solver.func_update_gradient(
        entities_info=entities_info,
        dofs_state=dofs_state,
        constraint_state=constraint_state,
        rigid_global_info=rigid_global_info,
        static_rigid_sim_config=static_rigid_sim_config,
    )


# ================================================ Init gpu_graph kernel =========================================


@qd.kernel(gpu_graph=True, fastcache=gs.use_fastcache)
def _kernel_solve_init_gpu_graph(
    dofs_info: array_class.DofsInfo,
    entities_info: array_class.EntitiesInfo,
    dofs_state: array_class.DofsState,
    constraint_state: array_class.ConstraintState,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: qd.template(),
):
    # Solver-init pipeline captured as one GPU graph: with gpu_graph=True each call
    # below becomes a graph node, so the whole sequence is submitted in a single
    # launch. The order is load-bearing — every step consumes state written by the
    # previous one. Do not reorder.
    # 1) qacc <- warmstart (or acc_smooth)
    _func_init_warmstart(dofs_state, constraint_state, static_rigid_sim_config)
    # 2) Ma <- M @ qacc
    _func_init_Ma(dofs_info, entities_info, constraint_state, rigid_global_info, static_rigid_sim_config)
    # 3) Jaref <- J @ qacc - aref
    _func_init_Jaref(constraint_state, static_rigid_sim_config)
    # 4) per-env 'improved' flags
    _func_init_improved(constraint_state, static_rigid_sim_config)
    # 5) constraint update (shared with the monolithic kernel for exact FP match)
    _func_init_update_constraint(dofs_state, constraint_state, static_rigid_sim_config)
    if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
        # 6) Newton-only: Hessian used by the gradient step (branch resolved at compile time)
        _func_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config)
    # 7) gradient update, then 8) search <- -Mgrad
    _func_init_update_gradient(entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config)
    _func_init_search(constraint_state, static_rigid_sim_config)


def func_solve_init_decomposed(
    dofs_info,
    dofs_state,
    entities_info,
    constraint_state,
    rigid_global_info,
    static_rigid_sim_config,
):
    """Initialize the constraint solver through the graph-captured decomposed kernel.

    Thin wrapper over ``_kernel_solve_init_gpu_graph`` (note: the kernel's argument
    order differs from this function's signature). With ``gpu_graph=True`` the
    kernel's steps are captured as a CUDA graph on CUDA backends, replacing many
    small kernel launches with one graph submission; on other backends the steps
    still run back-to-back on the C++ side, which avoids per-step Python launch
    overhead.

    Graph nodes, in order:
      1. warmstart selection of qacc (ndrange over dofs)
      2. Ma = M @ qacc (ndrange over dofs with entity lookup)
      3. Jaref = J @ qacc - aref (ndrange over constraints)
      4. per-env improved flags
      5. constraint update (shared with the monolith for exact FP match)
      6. Newton Hessian (Newton solver only)
      7. gradient update
      8. search = -Mgrad (ndrange over dofs)
    """
    _kernel_solve_init_gpu_graph(
        dofs_info,
        entities_info,
        dofs_state,
        constraint_state,
        rigid_global_info,
        static_rigid_sim_config,
    )


# ============================================== Solve body dispatch ================================================


Expand Down Expand Up @@ -825,6 +1008,7 @@ def _kernel_solve_gpu_graph(

@solver.func_solve_body.register(
is_compatible=lambda *args, **kwargs: solver._get_static_config(*args, **kwargs).prefer_parallel_linesearch != 0
and not solver._get_static_config(*args, **kwargs).requires_grad
)
def func_solve_decomposed(
entities_info,
Expand Down
Loading