From 70282689ada3002d09ab1bfda23e0c1d22e1d100 Mon Sep 17 00:00:00 2001 From: Paul Zhu Date: Wed, 8 Apr 2026 01:49:57 -0700 Subject: [PATCH 1/3] add nvtx marker --- cuequivariance_jax/examples/mace_linen.py | 8 ++++++-- cuequivariance_jax/examples/mace_nnx.py | 13 ++++++++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/cuequivariance_jax/examples/mace_linen.py b/cuequivariance_jax/examples/mace_linen.py index 9d510a5..6453a00 100644 --- a/cuequivariance_jax/examples/mace_linen.py +++ b/cuequivariance_jax/examples/mace_linen.py @@ -25,6 +25,7 @@ import argparse import ctypes import time +import nvtx from typing import Callable import flax @@ -446,10 +447,13 @@ def inference(w, batch_dict): try: cuda = ctypes.CDLL("libcudart.so") cuda.cudaProfilerStart() + if mode in ["train", "both"]: - jax.block_until_ready(step(w, opt_state, batch_dict, target_E, target_F)) + with nvtx.annotate("Train", color="green"): + jax.block_until_ready(step(w, opt_state, batch_dict, target_E, target_F)) if mode in ["inference", "both"]: - jax.block_until_ready(inference(w, batch_dict)) + with nvtx.annotate("Inference", color="blue"): + jax.block_until_ready(inference(w, batch_dict)) cuda.cudaProfilerStop() except Exception: pass diff --git a/cuequivariance_jax/examples/mace_nnx.py b/cuequivariance_jax/examples/mace_nnx.py index f0f1ba0..40643d1 100644 --- a/cuequivariance_jax/examples/mace_nnx.py +++ b/cuequivariance_jax/examples/mace_nnx.py @@ -26,6 +26,7 @@ import jax import jax.numpy as jnp import numpy as np +import nvtx from cuequivariance.group_theory.experimental.mace import ( symmetric_contraction as mace_symmetric_contraction, ) @@ -647,12 +648,14 @@ def inference(graphdef, state, batch_dict): cuda = ctypes.CDLL("libcudart.so") cuda.cudaProfilerStart() if mode in ["train", "both"]: - train_state = step( - train_graphdef, train_state, batch_dict, target_E, target_F - ) - jax.block_until_ready(train_state) + with nvtx.annotate("Train", color="green"): + train_state = step( + train_graphdef, train_state, batch_dict, target_E, target_F + ) + jax.block_until_ready(train_state) if mode in ["inference", "both"]: - jax.block_until_ready(inference(model_graphdef, model_state, batch_dict)) + with nvtx.annotate("Inference", color="blue"): + jax.block_until_ready(inference(model_graphdef, model_state, batch_dict)) cuda.cudaProfilerStop() except Exception: pass From cb8d9307c4890981857e0834bd52a055e9bf4375 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 8 Apr 2026 12:13:53 +0200 Subject: [PATCH 2/3] run pre-commit --- cuequivariance_jax/examples/mace_linen.py | 8 +++++--- cuequivariance_jax/examples/mace_nnx.py | 4 +++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/cuequivariance_jax/examples/mace_linen.py b/cuequivariance_jax/examples/mace_linen.py index 6453a00..6d336dd 100644 --- a/cuequivariance_jax/examples/mace_linen.py +++ b/cuequivariance_jax/examples/mace_linen.py @@ -25,7 +25,6 @@ import argparse import ctypes import time -import nvtx from typing import Callable import flax @@ -33,6 +32,7 @@ import jax import jax.numpy as jnp import numpy as np +import nvtx import optax from cuequivariance.group_theory.experimental.mace import symmetric_contraction from cuequivariance_jax.experimental.utils import MultiLayerPerceptron, bessel @@ -447,10 +447,12 @@ def inference(w, batch_dict): try: cuda = ctypes.CDLL("libcudart.so") cuda.cudaProfilerStart() - + if mode in ["train", "both"]: with nvtx.annotate("Train", color="green"): - jax.block_until_ready(step(w, opt_state, batch_dict, target_E, target_F)) + jax.block_until_ready( + step(w, opt_state, batch_dict, target_E, target_F) + ) if mode in ["inference", "both"]: with nvtx.annotate("Inference", color="blue"): jax.block_until_ready(inference(w, batch_dict)) diff --git a/cuequivariance_jax/examples/mace_nnx.py b/cuequivariance_jax/examples/mace_nnx.py index 40643d1..d14c677 100644 --- a/cuequivariance_jax/examples/mace_nnx.py +++ b/cuequivariance_jax/examples/mace_nnx.py @@ -655,7 +655,9 @@ def inference(graphdef, state, batch_dict): jax.block_until_ready(train_state) if mode in ["inference", "both"]: with nvtx.annotate("Inference", color="blue"): - jax.block_until_ready(inference(model_graphdef, model_state, batch_dict)) + jax.block_until_ready( + inference(model_graphdef, model_state, batch_dict) + ) cuda.cudaProfilerStop() except Exception: pass From 460ad571f97e81839aacc3b9f4340dcec6fd86aa Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 8 Apr 2026 12:24:55 +0200 Subject: [PATCH 3/3] move import nvtx --- cuequivariance_jax/examples/mace_linen.py | 3 ++- cuequivariance_jax/examples/mace_nnx.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cuequivariance_jax/examples/mace_linen.py b/cuequivariance_jax/examples/mace_linen.py index 6d336dd..634e9cc 100644 --- a/cuequivariance_jax/examples/mace_linen.py +++ b/cuequivariance_jax/examples/mace_linen.py @@ -32,7 +32,6 @@ import jax import jax.numpy as jnp import numpy as np -import nvtx import optax from cuequivariance.group_theory.experimental.mace import symmetric_contraction from cuequivariance_jax.experimental.utils import MultiLayerPerceptron, bessel @@ -445,6 +444,8 @@ def inference(w, batch_dict): ) try: + import nvtx + cuda = ctypes.CDLL("libcudart.so") cuda.cudaProfilerStart() diff --git a/cuequivariance_jax/examples/mace_nnx.py b/cuequivariance_jax/examples/mace_nnx.py index d14c677..e1439b4 100644 --- a/cuequivariance_jax/examples/mace_nnx.py +++ b/cuequivariance_jax/examples/mace_nnx.py @@ -26,7 +26,6 @@ import jax import jax.numpy as jnp import numpy as np -import nvtx from cuequivariance.group_theory.experimental.mace import ( symmetric_contraction as mace_symmetric_contraction, ) @@ -645,6 +644,8 @@ def inference(graphdef, state, batch_dict): ) try: + import nvtx + cuda = ctypes.CDLL("libcudart.so") cuda.cudaProfilerStart() if mode in ["train", "both"]: