From 12ef503274e4eb2ae5b897e8bf7fdb3ecb66d06d Mon Sep 17 00:00:00 2001
From: Dylan Socolobsky
Date: Tue, 3 Mar 2026 16:43:43 -0300
Subject: [PATCH 1/2] evals: avoid calling torch.compile when not needed

This improves performance: we can avoid calling torch.compile in evals,
since it precompiles length-specific kernels that we won't reuse, as
every eval input has a different length.
---
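Note (below the fold, so git am drops it): a minimal sketch of the
mechanism the ttitan.py change relies on, with a toy nn.Linear standing
in for the real model. torch.compile() wraps an nn.Module in an
OptimizedModule that keeps the uncompiled module in `_orig_mod`, so
calling that attribute directly runs eager PyTorch and never hits the
compiler:

    import torch
    import torch.nn as nn

    # Toy stand-in for self.model after torch.compile in ttitan.py.
    model = torch.compile(nn.Linear(8, 8))
    assert hasattr(model, "_orig_mod")  # the uncompiled module lives here

    x = torch.randn(4, 8)
    out = model._orig_mod(x)  # eager forward pass, no (re)compilation

The compiled wrapper can recompile for each new input shape it sees,
which is exactly the cost this patch avoids for variable-length eval
inputs.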
 python/python/psyche/models/ttitan.py | 8 +++++++-
 shared/client/src/state/evals.rs      | 2 +-
 shared/client/src/state/stats.rs      | 3 ++-
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/python/python/psyche/models/ttitan.py b/python/python/psyche/models/ttitan.py
index cd8fa6a15..61f60f8af 100644
--- a/python/python/psyche/models/ttitan.py
+++ b/python/python/psyche/models/ttitan.py
@@ -429,7 +429,13 @@ def forward(
             position_ids = position_ids.narrow(0, start_row, shard_size)
         try:
             with self.amp, torch.cuda.device(input_ids.device.index):
-                pred = self.model(
+                # During eval we bypass torch.compile as a speed optimization (we know we're in eval because labels=None).
+                # There's no need to precompile kernels for different sequence lengths since every input has a different one.
+                # _orig_mod stores the original model, so calling it directly bypasses the compilation stage.
+                model = self.model
+                if labels is None and hasattr(self.model, "_orig_mod"):
+                    model = self.model._orig_mod
+                pred = model(
                     tokens=input_ids.contiguous(),
                     position_ids=(
                         position_ids.contiguous() if position_ids is not None else None
diff --git a/shared/client/src/state/evals.rs b/shared/client/src/state/evals.rs
index cb3837173..96aeb15e4 100644
--- a/shared/client/src/state/evals.rs
+++ b/shared/client/src/state/evals.rs
@@ -232,7 +232,7 @@ impl ModelTaskRunner {
     pub fn start(&self, trainers: Vec<Trainer>) -> RunningEvals {
         let cancel = CancellationToken::new();
 
-        trace!("Starting evals!");
+        info!("Starting evals!");
 
         RunningEvals {
             cancel: cancel.clone(),
diff --git a/shared/client/src/state/stats.rs b/shared/client/src/state/stats.rs
index fcd26c5c8..684fe5a60 100644
--- a/shared/client/src/state/stats.rs
+++ b/shared/client/src/state/stats.rs
@@ -7,7 +7,7 @@ use psyche_modeling::Trainer;
 use psyche_network::P2PEndpointInfo;
 use std::{collections::HashMap, sync::Arc, time::Duration};
 use tokenizers::Tokenizer;
-use tracing::{debug, trace, warn};
+use tracing::{debug, info, trace, warn};
 use wandb::{DataValue, LogData};
 
 use crate::state::evals::{EnumModelTask, PROMPT_TASK_NAME};
@@ -125,6 +125,7 @@ impl StatsLogger {
                 .collect::<String>()
             );
             round_log.insert(formatted_key.clone(), val);
+            info!("{}: {:.4}", formatted_key, val);
 
             self.metrics.record_eval_metric(&key, val);
         }

From 46f9a71dcb1ae1a6d3c86a9ea5360e3d8c0dc571 Mon Sep 17 00:00:00 2001
From: Dylan Socolobsky
Date: Mon, 9 Mar 2026 12:32:11 -0300
Subject: [PATCH 2/2] try alternate approach to avoid compiling CUDA kernels

Instead of bypassing the compiled wrapper at call time, thread a
`compile` flag through the model factory so eval callers can construct
the model without torch.compile in the first place.
---
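Note (below the fold): with the flag in the factory, eval callers opt
out of compilation at construction time instead of unwrapping the model
per call. A minimal sketch of the gating pattern this patch applies in
from_pretrained; the maybe_compile helper is illustrative, not code from
this repo:

    import torch

    def maybe_compile(model: torch.nn.Module, compile: bool = True) -> torch.nn.Module:
        # Mirrors `job_config.compile.enable = compile`: wrap for training,
        # hand back the eager module untouched when compile=False (evals).
        return torch.compile(model) if compile else model

    train_model = maybe_compile(torch.nn.Linear(8, 8), compile=True)
    eval_model = maybe_compile(torch.nn.Linear(8, 8), compile=False)

On the Rust side, new() keeps its old signature and delegates to
new_with_options(..., true), so existing callers are untouched and only
eval paths such as evaluate.rs pass compile = false.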
 python/python/psyche/models/factory.py        |  2 ++
 python/python/psyche/models/ttitan.py         | 11 ++-----
 shared/eval/examples/evaluate.rs              | 29 ++++++++++---------
 shared/modeling/src/python_causal_lm.rs       | 24 ++++++++++++++-
 .../src/python_distributed_causal_lm.rs       | 27 ++++++++++++++++-
 5 files changed, 70 insertions(+), 23 deletions(-)

diff --git a/python/python/psyche/models/factory.py b/python/python/psyche/models/factory.py
index 95d133b36..89be8e358 100644
--- a/python/python/psyche/models/factory.py
+++ b/python/python/psyche/models/factory.py
@@ -15,6 +15,7 @@ def make_causal_lm(
     param_dtype: torch.dtype = torch.bfloat16,
     reduce_dtype: torch.dtype = torch.float32,
     fsdp_modules: Optional[Iterable[str]] = None,
+    compile: bool = True,
 ) -> CausalLM:
     if not isinstance(device, torch.device):
         device = torch.device(device if isinstance(device, str) else f"cuda:{device}")
@@ -45,5 +46,6 @@ def make_causal_lm(
             param_dtype=param_dtype,
             reduce_dtype=reduce_dtype,
             fsdp_modules=fsdp_modules,
+            compile=compile,
         )
     raise ValueError(f"Unknown architecture {architecture}")
diff --git a/python/python/psyche/models/ttitan.py b/python/python/psyche/models/ttitan.py
index 61f60f8af..9f59ccebd 100644
--- a/python/python/psyche/models/ttitan.py
+++ b/python/python/psyche/models/ttitan.py
@@ -282,6 +282,7 @@ def from_pretrained(
         param_dtype: torch.dtype = torch.bfloat16,
         reduce_dtype: torch.dtype = torch.float32,
         fsdp_modules: Optional[Iterable[str]] = None,
+        compile: bool = True,
     ):
         config_json = None
         if isinstance(source, PretrainedSourceStateDict):
@@ -302,7 +303,7 @@ def from_pretrained(
 
         job_config = JobConfig()
         job_config.training.seq_len = config_tt.max_seq_len
-        job_config.compile.enable = True
+        job_config.compile.enable = compile
         job_config.compile.components = ["model", "loss"]
         job_config.compile.fullgraph = False
         job_config.activation_checkpoint.mode = "full"
@@ -429,13 +430,7 @@ def forward(
             position_ids = position_ids.narrow(0, start_row, shard_size)
         try:
             with self.amp, torch.cuda.device(input_ids.device.index):
-                # During eval we bypass torch.compile as a speed optimization (we know we're in eval because labels=None).
-                # There's no need to precompile kernels for different sequence lengths since every input has a different one.
-                # _orig_mod stores the original model, so calling it directly bypasses the compilation stage.
-                model = self.model
-                if labels is None and hasattr(self.model, "_orig_mod"):
-                    model = self.model._orig_mod
-                pred = model(
+                pred = self.model(
                     tokens=input_ids.contiguous(),
                     position_ids=(
                         position_ids.contiguous() if position_ids is not None else None
diff --git a/shared/eval/examples/evaluate.rs b/shared/eval/examples/evaluate.rs
index b344abd0b..1b14706cf 100644
--- a/shared/eval/examples/evaluate.rs
+++ b/shared/eval/examples/evaluate.rs
@@ -237,19 +237,22 @@ fn run_data_parallel(
     {
         psyche_python_extension_impl::init_embedded_python()?;
 
-        Box::new(psyche_modeling::PythonDistributedCausalLM::new(
-            python_arch,
-            psyche_modeling::PretrainedSource::RepoFiles(repo),
-            device,
-            psyche_modeling::AttentionImplementation::default(),
-            psyche_modeling::ParallelismConfig {
-                dp: data_parallelism,
-                tp: 1,
-            },
-            None,
-            None,
-            None,
-        )?) as Box<dyn psyche_modeling::CausalLM>
+        Box::new(
+            psyche_modeling::PythonDistributedCausalLM::new_with_options(
+                python_arch,
+                psyche_modeling::PretrainedSource::RepoFiles(repo),
+                device,
+                psyche_modeling::AttentionImplementation::default(),
+                psyche_modeling::ParallelismConfig {
+                    dp: data_parallelism,
+                    tp: 1,
+                },
+                None,
+                None,
+                None,
+                false, // disable torch.compile for eval; avoids recompiling for every new sequence length
+            )?,
+        ) as Box<dyn psyche_modeling::CausalLM>
     }
 
     #[cfg(not(feature = "python"))]
diff --git a/shared/modeling/src/python_causal_lm.rs b/shared/modeling/src/python_causal_lm.rs
index 07e57e1c8..b3087b2b3 100644
--- a/shared/modeling/src/python_causal_lm.rs
+++ b/shared/modeling/src/python_causal_lm.rs
@@ -133,6 +133,26 @@ impl PythonCausalLM {
         attn_implementation: AttentionImplementation,
         parallelism: Option<ParallelismConfig>,
         override_max_position_embeddings: Option<usize>,
+    ) -> Result<Self, PythonCausalLMError> {
+        Self::new_with_options(
+            architecture,
+            source,
+            device,
+            attn_implementation,
+            parallelism,
+            override_max_position_embeddings,
+            true,
+        )
+    }
+
+    pub fn new_with_options(
+        architecture: &str,
+        source: &PretrainedSource,
+        device: Device,
+        attn_implementation: AttentionImplementation,
+        parallelism: Option<ParallelismConfig>,
+        override_max_position_embeddings: Option<usize>,
+        compile: bool,
     ) -> Result<Self, PythonCausalLMError> {
         let config = source.get_config()?;
         let result: PyResult<Py<PyAny>> = Python::with_gil(|py| {
@@ -176,7 +196,9 @@ impl PythonCausalLM {
                 parallelism.as_ref().map(|x| x.tp).unwrap_or(1),
                 override_max_position_embeddings,
             );
-            let causal_lm = make_causal_lm.call1(args)?;
+            let kwargs = PyDict::new(py);
+            kwargs.set_item("compile", compile)?;
+            let causal_lm = make_causal_lm.call(args, Some(&kwargs))?;
             Ok(causal_lm.unbind())
         });
         let causal_lm = result?;
diff --git a/shared/modeling/src/python_distributed_causal_lm.rs b/shared/modeling/src/python_distributed_causal_lm.rs
index 14f418ad4..a24c66172 100644
--- a/shared/modeling/src/python_distributed_causal_lm.rs
+++ b/shared/modeling/src/python_distributed_causal_lm.rs
@@ -228,6 +228,30 @@ impl PythonDistributedCausalLM {
         override_max_position_embeddings: Option<usize>,
         port: Option<u16>,
         num_local_ranks: Option<usize>,
+    ) -> Result<Self, PythonDistributedCausalLMError> {
+        Self::new_with_options(
+            architecture,
+            source,
+            device,
+            attn_implementation,
+            parallelism,
+            override_max_position_embeddings,
+            port,
+            num_local_ranks,
+            true,
+        )
+    }
+
+    pub fn new_with_options(
+        architecture: String,
+        source: PretrainedSource,
+        device: Device,
+        attn_implementation: AttentionImplementation,
+        parallelism: ParallelismConfig,
+        override_max_position_embeddings: Option<usize>,
+        port: Option<u16>,
+        num_local_ranks: Option<usize>,
+        compile: bool,
     ) -> Result<Self, PythonDistributedCausalLMError> {
         if !tch::Cuda::is_available() {
             return Err(PythonDistributedCausalLMError::CUDANotAvailable);
@@ -323,13 +347,14 @@ impl PythonDistributedCausalLM {
             }
             comm.set("dp", &format!("{}", parallelism.dp))?;
             comm.set("tp", &format!("{}", parallelism.tp))?;
-            let local = PythonCausalLM::new(
+            let local = PythonCausalLM::new_with_options(
                 &architecture,
                 &source,
                 device,
                 attn_implementation,
                 Some(parallelism),
                 override_max_position_embeddings,
+                compile,
             )?;
             Ok((comm, local))
         })