diff --git a/python/python/psyche/models/factory.py b/python/python/psyche/models/factory.py index 95d133b36..89be8e358 100644 --- a/python/python/psyche/models/factory.py +++ b/python/python/psyche/models/factory.py @@ -15,6 +15,7 @@ def make_causal_lm( param_dtype: torch.dtype = torch.bfloat16, reduce_dtype: torch.dtype = torch.float32, fsdp_modules: Optional[Iterable[str]] = None, + compile: bool = True, ) -> CausalLM: if not isinstance(device, torch.device): device = torch.device(device if isinstance(device, str) else f"cuda:{device}") @@ -45,5 +46,6 @@ def make_causal_lm( param_dtype=param_dtype, reduce_dtype=reduce_dtype, fsdp_modules=fsdp_modules, + compile=compile, ) raise ValueError(f"Unknown architecture {architecture}") diff --git a/python/python/psyche/models/ttitan.py b/python/python/psyche/models/ttitan.py index cd8fa6a15..9f59ccebd 100644 --- a/python/python/psyche/models/ttitan.py +++ b/python/python/psyche/models/ttitan.py @@ -282,6 +282,7 @@ def from_pretrained( param_dtype: torch.dtype = torch.bfloat16, reduce_dtype: torch.dtype = torch.float32, fsdp_modules: Optional[Iterable[str]] = None, + compile: bool = True, ): config_json = None if isinstance(source, PretrainedSourceStateDict): @@ -302,7 +303,7 @@ def from_pretrained( job_config = JobConfig() job_config.training.seq_len = config_tt.max_seq_len - job_config.compile.enable = True + job_config.compile.enable = compile job_config.compile.components = ["model", "loss"] job_config.compile.fullgraph = False job_config.activation_checkpoint.mode = "full" diff --git a/shared/client/src/state/evals.rs b/shared/client/src/state/evals.rs index cb3837173..96aeb15e4 100644 --- a/shared/client/src/state/evals.rs +++ b/shared/client/src/state/evals.rs @@ -232,7 +232,7 @@ impl ModelTaskRunner { pub fn start(&self, trainers: Vec) -> RunningEvals { let cancel = CancellationToken::new(); - trace!("Starting evals!"); + info!("Starting evals!"); RunningEvals { cancel: cancel.clone(), diff --git a/shared/client/src/state/stats.rs b/shared/client/src/state/stats.rs index 8e1c4f83e..f12d0ed03 100644 --- a/shared/client/src/state/stats.rs +++ b/shared/client/src/state/stats.rs @@ -7,7 +7,7 @@ use psyche_modeling::Trainer; use psyche_network::P2PEndpointInfo; use std::{collections::HashMap, sync::Arc, time::Duration}; use tokenizers::Tokenizer; -use tracing::{debug, trace, warn}; +use tracing::{debug, info, trace, warn}; use wandb::{DataValue, LogData}; use crate::state::evals::{EnumModelTask, PROMPT_TASK_NAME}; @@ -125,6 +125,7 @@ impl StatsLogger { .collect::() ); round_log.insert(formatted_key.clone(), val); + info!("{}: {:.4}", formatted_key, val); self.metrics.record_eval_metric(&key, val); } diff --git a/shared/eval/examples/evaluate.rs b/shared/eval/examples/evaluate.rs index b344abd0b..1b14706cf 100644 --- a/shared/eval/examples/evaluate.rs +++ b/shared/eval/examples/evaluate.rs @@ -237,19 +237,22 @@ fn run_data_parallel( { psyche_python_extension_impl::init_embedded_python()?; - Box::new(psyche_modeling::PythonDistributedCausalLM::new( - python_arch, - psyche_modeling::PretrainedSource::RepoFiles(repo), - device, - psyche_modeling::AttentionImplementation::default(), - psyche_modeling::ParallelismConfig { - dp: data_parallelism, - tp: 1, - }, - None, - None, - None, - )?) as Box + Box::new( + psyche_modeling::PythonDistributedCausalLM::new_with_options( + python_arch, + psyche_modeling::PretrainedSource::RepoFiles(repo), + device, + psyche_modeling::AttentionImplementation::default(), + psyche_modeling::ParallelismConfig { + dp: data_parallelism, + tp: 1, + }, + None, + None, + None, + false, // disable torch.compile for eval — avoids recompilation on every different sequence length + )?, + ) as Box } #[cfg(not(feature = "python"))] diff --git a/shared/modeling/src/python_causal_lm.rs b/shared/modeling/src/python_causal_lm.rs index 07e57e1c8..b3087b2b3 100644 --- a/shared/modeling/src/python_causal_lm.rs +++ b/shared/modeling/src/python_causal_lm.rs @@ -133,6 +133,26 @@ impl PythonCausalLM { attn_implementation: AttentionImplementation, parallelism: Option, override_max_position_embeddings: Option, + ) -> Result { + Self::new_with_options( + architecture, + source, + device, + attn_implementation, + parallelism, + override_max_position_embeddings, + true, + ) + } + + pub fn new_with_options( + architecture: &str, + source: &PretrainedSource, + device: Device, + attn_implementation: AttentionImplementation, + parallelism: Option, + override_max_position_embeddings: Option, + compile: bool, ) -> Result { let config = source.get_config()?; let result: PyResult = Python::with_gil(|py| { @@ -176,7 +196,9 @@ impl PythonCausalLM { parallelism.as_ref().map(|x| x.tp).unwrap_or(1), override_max_position_embeddings, ); - let causal_lm = make_causal_lm.call1(args)?; + let kwargs = PyDict::new(py); + kwargs.set_item("compile", compile)?; + let causal_lm = make_causal_lm.call(args, Some(&kwargs))?; Ok(causal_lm.unbind()) }); let causal_lm = result?; diff --git a/shared/modeling/src/python_distributed_causal_lm.rs b/shared/modeling/src/python_distributed_causal_lm.rs index 14f418ad4..a24c66172 100644 --- a/shared/modeling/src/python_distributed_causal_lm.rs +++ b/shared/modeling/src/python_distributed_causal_lm.rs @@ -228,6 +228,30 @@ impl PythonDistributedCausalLM { override_max_position_embeddings: Option, port: Option, num_local_ranks: Option, + ) -> Result { + Self::new_with_options( + architecture, + source, + device, + attn_implementation, + parallelism, + override_max_position_embeddings, + port, + num_local_ranks, + true, + ) + } + + pub fn new_with_options( + architecture: String, + source: PretrainedSource, + device: Device, + attn_implementation: AttentionImplementation, + parallelism: ParallelismConfig, + override_max_position_embeddings: Option, + port: Option, + num_local_ranks: Option, + compile: bool, ) -> Result { if !tch::Cuda::is_available() { return Err(PythonDistributedCausalLMError::CUDANotAvailable); @@ -323,13 +347,14 @@ impl PythonDistributedCausalLM { } comm.set("dp", &format!("{}", parallelism.dp))?; comm.set("tp", &format!("{}", parallelism.tp))?; - let local = PythonCausalLM::new( + let local = PythonCausalLM::new_with_options( &architecture, &source, device, attn_implementation, Some(parallelism), override_max_position_embeddings, + compile, )?; Ok((comm, local)) })