diff --git a/cuequivariance_jax/examples/mace_linen.py b/cuequivariance_jax/examples/mace_linen.py index 9d510a5..634e9cc 100644 --- a/cuequivariance_jax/examples/mace_linen.py +++ b/cuequivariance_jax/examples/mace_linen.py @@ -444,12 +444,19 @@ def inference(w, batch_dict): ) try: + import nvtx + cuda = ctypes.CDLL("libcudart.so") cuda.cudaProfilerStart() + if mode in ["train", "both"]: - jax.block_until_ready(step(w, opt_state, batch_dict, target_E, target_F)) + with nvtx.annotate("Train", color="green"): + jax.block_until_ready( + step(w, opt_state, batch_dict, target_E, target_F) + ) if mode in ["inference", "both"]: - jax.block_until_ready(inference(w, batch_dict)) + with nvtx.annotate("Inference", color="blue"): + jax.block_until_ready(inference(w, batch_dict)) cuda.cudaProfilerStop() except Exception: pass diff --git a/cuequivariance_jax/examples/mace_nnx.py b/cuequivariance_jax/examples/mace_nnx.py index f0f1ba0..e1439b4 100644 --- a/cuequivariance_jax/examples/mace_nnx.py +++ b/cuequivariance_jax/examples/mace_nnx.py @@ -644,15 +644,21 @@ def inference(graphdef, state, batch_dict): ) try: + import nvtx + cuda = ctypes.CDLL("libcudart.so") cuda.cudaProfilerStart() if mode in ["train", "both"]: - train_state = step( - train_graphdef, train_state, batch_dict, target_E, target_F - ) - jax.block_until_ready(train_state) + with nvtx.annotate("Train", color="green"): + train_state = step( + train_graphdef, train_state, batch_dict, target_E, target_F + ) + jax.block_until_ready(train_state) if mode in ["inference", "both"]: - jax.block_until_ready(inference(model_graphdef, model_state, batch_dict)) + with nvtx.annotate("Inference", color="blue"): + jax.block_until_ready( + inference(model_graphdef, model_state, batch_dict) + ) cuda.cudaProfilerStop() except Exception: pass