From 54edc02d832e1e5950f12deb93c83a9a8c4a60ae Mon Sep 17 00:00:00 2001 From: Peter Fontana Date: Thu, 5 Mar 2026 14:58:13 -0300 Subject: [PATCH 1/2] Add parallelism auto-detection from GCS via signed URLs When --parallelism-auto is set, the client downloads parallelism_data.json from GCS (via run-down signed URLs) or HuggingFace to auto-detect optimal dp/tp/micro_batch_size based on GPU type and count. This removes the need for users to manually configure parallelism settings. --- architectures/centralized/client/src/app.rs | 1 + .../decentralized/solana-client/src/app.rs | 1 + shared/client/Cargo.toml | 1 + shared/client/src/cli.rs | 4 + shared/client/src/lib.rs | 1 + shared/client/src/parallelism_lookup.rs | 139 ++++++++++++++++++ shared/client/src/state/init.rs | 45 +++++- shared/data-provider/src/gcs_signed.rs | 44 ++++++ shared/data-provider/src/lib.rs | 5 +- 9 files changed, 239 insertions(+), 2 deletions(-) create mode 100644 shared/client/src/parallelism_lookup.rs diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index b1ee8430d..ec9299394 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -135,6 +135,7 @@ pub async fn build_app( .await?; let state_options: RunInitConfig = RunInitConfig { + parallelism_auto: p.parallelism_auto, data_parallelism: p.data_parallelism, tensor_parallelism: p.tensor_parallelism, micro_batch_size: p.micro_batch_size, diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 000f76343..b81905a25 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -134,6 +134,7 @@ pub async fn build_app( let state_options: RunInitConfig = RunInitConfig { + parallelism_auto: p.parallelism_auto, data_parallelism: p.data_parallelism, tensor_parallelism: p.tensor_parallelism, micro_batch_size: p.micro_batch_size, diff --git a/shared/client/Cargo.toml b/shared/client/Cargo.toml index 91c314474..bff0c46fb 100644 --- a/shared/client/Cargo.toml +++ b/shared/client/Cargo.toml @@ -37,6 +37,7 @@ sysinfo = "0.32.0" iroh.workspace = true iroh-blobs.workspace = true google-cloud-storage.workspace = true +nvml-wrapper = "0.11.0" [features] parallelism = ["psyche-modeling/parallelism"] diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index b3871695b..17ef0e23f 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -113,6 +113,10 @@ pub struct TrainArgs { #[clap(long, env, value_parser = parse_trim_quotes)] pub run_id: String, + /// Auto-detect parallelism settings from lookup table based on model and GPU count + #[clap(long, env)] + pub parallelism_auto: bool, + #[clap(long, default_value_t = 1, env)] pub data_parallelism: usize, diff --git a/shared/client/src/lib.rs b/shared/client/src/lib.rs index b3b810685..c299cdb7f 100644 --- a/shared/client/src/lib.rs +++ b/shared/client/src/lib.rs @@ -1,6 +1,7 @@ mod cli; mod client; mod fetch_data; +pub mod parallelism_lookup; mod protocol; mod state; mod tui; diff --git a/shared/client/src/parallelism_lookup.rs b/shared/client/src/parallelism_lookup.rs new file mode 100644 index 000000000..4a2301f5b --- /dev/null +++ b/shared/client/src/parallelism_lookup.rs @@ -0,0 +1,139 @@ +use anyhow::{Context, Result}; +use nvml_wrapper::Nvml; +use psyche_coordinator::model; +use psyche_data_provider::{download_parallelism_data_from_gcs_signed, RunDownClient}; +use serde::Deserialize; +use std::collections::HashMap; +use std::sync::Arc; +use tracing::info; + +#[derive(Debug, Clone, Copy, Deserialize)] +pub struct ParallelismConfig { + pub dp: usize, + pub tp: usize, + pub micro_batch_size: usize, +} + +// Table format: gpu_type -> num_gpus -> config +type Table = HashMap>; + +/// Auto-detect parallelism settings by downloading parallelism_data.json +/// from GCS (via signed URLs) or HuggingFace, then looking up the config +/// for the detected GPU type and count. +pub async fn lookup( + checkpoint: &model::Checkpoint, + run_down_client: Option<&Arc>, + hub_read_token: Option<&str>, +) -> Result { + let device_count = tch::Cuda::device_count() as usize; + if device_count == 0 { + anyhow::bail!("No GPUs found for parallelism auto-detection"); + } + + let gpu_type = normalize_gpu_name(&get_gpu_type_from_nvml()?); + info!("Detected {} x {} GPU(s)", device_count, gpu_type); + + let json = download_parallelism_data(checkpoint, run_down_client, hub_read_token).await?; + let table: Table = + serde_json::from_str(&json).context("Failed to parse parallelism_data.json")?; + + lookup_in_table(&table, &gpu_type, device_count) +} + +fn get_gpu_type_from_nvml() -> Result { + let nvml = Nvml::init().context("Failed to initialize NVML")?; + let device = nvml + .device_by_index(0) + .context("Failed to get GPU device 0")?; + device.name().context("Failed to get GPU name") +} + +fn normalize_gpu_name(raw_name: &str) -> String { + let upper = raw_name.to_uppercase(); + if upper.contains("H200") { + "H200".to_string() + } else if upper.contains("H100") { + "H100".to_string() + } else if upper.contains("A100") { + "A100".to_string() + } else if upper.contains("L40S") { + "L40S".to_string() + } else if upper.contains("L40") { + "L40".to_string() + } else if upper.contains("4090") { + "RTX4090".to_string() + } else if upper.contains("3090") { + "RTX3090".to_string() + } else { + raw_name.to_string() + } +} + +async fn download_parallelism_data( + checkpoint: &model::Checkpoint, + run_down_client: Option<&Arc>, + hub_read_token: Option<&str>, +) -> Result { + match checkpoint { + model::Checkpoint::Gcs(_) | model::Checkpoint::P2PGcs(_) => { + let client = run_down_client + .ok_or_else(|| anyhow::anyhow!("RunDownClient required for GCS parallelism lookup"))?; + info!( + "Fetching parallelism_data.json from GCS via run-down signed URLs for run {}", + client.run_id() + ); + download_parallelism_data_from_gcs_signed(client) + .await + .map_err(|e| anyhow::anyhow!("{}", e)) + } + model::Checkpoint::Hub(hub_repo) | model::Checkpoint::P2P(hub_repo) => { + let repo_id: String = (&hub_repo.repo_id).into(); + info!( + "Fetching parallelism_data.json from HuggingFace repo '{}'", + repo_id + ); + download_from_hub(&repo_id, hub_read_token).await + } + _ => anyhow::bail!("Parallelism auto-detection requires Hub or GCS checkpoint type"), + } +} + +async fn download_from_hub(repo_id: &str, token: Option<&str>) -> Result { + let mut builder = hf_hub::api::tokio::ApiBuilder::new(); + if let Some(token) = token { + builder = builder.with_token(Some(token.to_string())); + } + let api = builder.build()?; + let repo = api.model(repo_id.to_string()); + let path = repo.get("parallelism_data.json").await.with_context(|| { + format!( + "parallelism_data.json not found in HuggingFace repo '{}'", + repo_id + ) + })?; + tokio::fs::read_to_string(path) + .await + .context("Failed to read parallelism_data.json") +} + +fn lookup_in_table(table: &Table, gpu_type: &str, num_gpus: usize) -> Result { + let gpu_configs = table.get(gpu_type).ok_or_else(|| { + anyhow::anyhow!( + "No parallelism config for GPU type '{}'. Available: {:?}", + gpu_type, + table.keys().collect::>() + ) + })?; + + gpu_configs + .get(&num_gpus.to_string()) + .copied() + .ok_or_else(|| { + anyhow::anyhow!( + "No parallelism config for {} x {}. Available counts: {:?}", + num_gpus, + gpu_type, + gpu_configs.keys().collect::>() + ) + }) +} diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index a1c9ec48b..02cda12e8 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "parallelism")] +use crate::parallelism_lookup; use crate::{WandBInfo, fetch_data::DataFetcher}; use psyche_coordinator::{ Coordinator, HealthChecks, @@ -51,6 +53,8 @@ pub struct RunInitConfig { pub device: Devices, pub hub_read_token: Option, pub hub_max_concurrent_downloads: usize, + /// If true, auto-detect parallelism from lookup table (overrides dp/tp/micro_batch_size) + pub parallelism_auto: bool, pub data_parallelism: usize, pub tensor_parallelism: usize, pub micro_batch_size: usize, @@ -119,6 +123,9 @@ pub enum InitRunError { #[error("Unsupported architecture: {0}")] UnsupportedArchitecture(String), + #[error("Parallelism auto-detection failed: {0}")] + ParallelismLookupFailed(anyhow::Error), + #[cfg(feature = "python")] #[error("Python distributed error: {0}")] PythonDistributedError(#[from] psyche_modeling::PythonDistributedCausalLMError), @@ -173,7 +180,7 @@ impl RunInitConfigAndIO, ) -> Result, InitRunError> { let Self { - init_config, + mut init_config, tx_witness, tx_health_check, tx_model, @@ -197,6 +204,42 @@ impl RunInitConfigAndIO Result { + let http = reqwest::Client::new(); + + let download_response = run_down + .get_download_urls() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))?; + + let entry = download_response + .urls + .iter() + .find(|e| e.path.ends_with("parallelism_data.json")) + .ok_or_else(|| { + DownloadError::RunDown( + "parallelism_data.json not found in GCS. Upload it alongside the model files." + .to_string(), + ) + })?; + + info!("Downloading parallelism_data.json via signed URL"); + + let response = http + .get(&entry.url) + .send() + .await + .map_err(|e| DownloadError::RunDown(e.to_string()))?; + + if !response.status().is_success() { + return Err(DownloadError::RunDown(format!( + "Failed to download parallelism_data.json: {}", + response.status() + ))); + } + + response + .text() + .await + .map_err(|e| DownloadError::RunDown(e.to_string())) +} + pub async fn download_model_from_gcs_signed_async( run_down: &RunDownClient, ) -> Result, DownloadError> { diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs index d6d46850f..23a0b328e 100644 --- a/shared/data-provider/src/lib.rs +++ b/shared/data-provider/src/lib.rs @@ -23,7 +23,10 @@ pub use gcs::{ GcsCheckpointManifest, GcsManifestMetadata, GcsUploadInfo, ManifestFileEntry, ManifestMetadata, download_model_from_gcs_async, download_model_from_gcs_sync, upload_to_gcs, }; -pub use gcs_signed::{download_model_from_gcs_signed_async, upload_to_gcs_signed}; +pub use gcs_signed::{ + download_model_from_gcs_signed_async, download_parallelism_data_from_gcs_signed, + upload_to_gcs_signed, +}; pub use hub::{ HubUploadInfo, download_dataset_repo_async, download_dataset_repo_sync, download_model_repo_async, download_model_repo_sync, upload_to_hub, From 380a6e410688e9ff7e03dca82747129beceea909 Mon Sep 17 00:00:00 2001 From: Pedro Fontana Date: Fri, 6 Mar 2026 15:38:38 +0000 Subject: [PATCH 2/2] Simplify parallelism auto-detection to GCS-only and add docs --- Cargo.lock | 1 + psyche-book/src/enduser/create-run.md | 24 ++++++++ psyche-book/src/enduser/join-run.md | 10 ++++ shared/client/src/parallelism_lookup.rs | 79 ++++--------------------- shared/client/src/state/init.rs | 20 ++++--- shared/data-provider/src/gcs_signed.rs | 1 - 6 files changed, 58 insertions(+), 77 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 207564f3f..20f2e493a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7404,6 +7404,7 @@ dependencies = [ "iroh", "iroh-blobs", "lazy_static", + "nvml-wrapper", "postcard", "psyche-coordinator", "psyche-core", diff --git a/psyche-book/src/enduser/create-run.md b/psyche-book/src/enduser/create-run.md index 7a9354ca1..42959ddb5 100644 --- a/psyche-book/src/enduser/create-run.md +++ b/psyche-book/src/enduser/create-run.md @@ -86,6 +86,30 @@ run-manager create-run \ At this point, your run has been successfully created. +### Adding parallelism configuration (required for --parallelism-auto) + +If you want clients to use `PARALLELISM_AUTO=true` for automatic configuration, you must add a `parallelism_data.json` file to your model's GCS bucket alongside the model files. + +```json +{ + "H100": { + "1": { "dp": 1, "tp": 1, "micro_batch_size": 4 }, + "8": { "dp": 4, "tp": 2, "micro_batch_size": 4 } + }, + "H200": { + "8": { "dp": 8, "tp": 1, "micro_batch_size": 8 } + } +} +``` + +Format: `gpu_type` → `num_gpus` → config + +- **gpu_type**: GPU model name (e.g., "H100", "H200") +- **num_gpus**: Number of GPUs available (e.g., "1", "8") +- **dp**: Data parallelism +- **tp**: Tensor parallelism +- **micro_batch_size**: Micro batch size per GPU + ### Initializing configuration Initially, the run will not have any configuration defined and will remain paused, so no clients can join yet. diff --git a/psyche-book/src/enduser/join-run.md b/psyche-book/src/enduser/join-run.md index d17fcd909..5b871b372 100644 --- a/psyche-book/src/enduser/join-run.md +++ b/psyche-book/src/enduser/join-run.md @@ -122,19 +122,29 @@ though you might need to. **`NVIDIA_DRIVER_CAPABILITIES`** - An environment variable that the NVIDIA Container Toolkit uses to determine which compute capabilities should be provided to your container. It is recommended to set it to 'all', e.g. `NVIDIA_DRIVER_CAPABILITIES=all`. +**`PARALLELISM_AUTO`** - Set to `true` to automatically detect optimal parallelism settings based on your GPU hardware. + +- When enabled, the client fetches a `parallelism_data.json` lookup table from the model's GCS bucket and picks the best `DATA_PARALLELISM`, `TENSOR_PARALLELISM`, and `MICRO_BATCH_SIZE` for your GPU type and count +- Your GPU type and count must be present in the lookup table +- This is the recommended option for most users +- If set, manual parallelism settings below will be ignored + **`DATA_PARALLELISM`** - Number of GPUs to distribute training data across. - If you have multiple GPUs, you can set this to 2, 4, etc. to speed up training - If you have 1 GPU, set this to `1` +- Ignored if `PARALLELISM_AUTO=true` **`TENSOR_PARALLELISM`** - Number of GPUs to distribute the model across, this lets you train a model you can't fit on one single GPU. - If you have 1 GPU, set this to `1` - If your have `n` GPUs you can distribute the model across all of them by setting it to `n`. +- Ignored if `PARALLELISM_AUTO=true` **`MICRO_BATCH_SIZE`** - Number of samples processed per GPU per training step - Set as high as your GPU memory allows +- Ignored if `PARALLELISM_AUTO=true` **`AUTHORIZER`** - The Solana address that authorized your wallet to join this run diff --git a/shared/client/src/parallelism_lookup.rs b/shared/client/src/parallelism_lookup.rs index 4a2301f5b..3ebac0b3f 100644 --- a/shared/client/src/parallelism_lookup.rs +++ b/shared/client/src/parallelism_lookup.rs @@ -1,7 +1,6 @@ use anyhow::{Context, Result}; use nvml_wrapper::Nvml; -use psyche_coordinator::model; -use psyche_data_provider::{download_parallelism_data_from_gcs_signed, RunDownClient}; +use psyche_data_provider::{RunDownClient, download_parallelism_data_from_gcs_signed}; use serde::Deserialize; use std::collections::HashMap; use std::sync::Arc; @@ -14,17 +13,9 @@ pub struct ParallelismConfig { pub micro_batch_size: usize, } -// Table format: gpu_type -> num_gpus -> config type Table = HashMap>; -/// Auto-detect parallelism settings by downloading parallelism_data.json -/// from GCS (via signed URLs) or HuggingFace, then looking up the config -/// for the detected GPU type and count. -pub async fn lookup( - checkpoint: &model::Checkpoint, - run_down_client: Option<&Arc>, - hub_read_token: Option<&str>, -) -> Result { +pub async fn lookup(run_down_client: &Arc) -> Result { let device_count = tch::Cuda::device_count() as usize; if device_count == 0 { anyhow::bail!("No GPUs found for parallelism auto-detection"); @@ -33,7 +24,14 @@ pub async fn lookup( let gpu_type = normalize_gpu_name(&get_gpu_type_from_nvml()?); info!("Detected {} x {} GPU(s)", device_count, gpu_type); - let json = download_parallelism_data(checkpoint, run_down_client, hub_read_token).await?; + info!( + "Fetching parallelism_data.json from GCS via run-down signed URLs for run {}", + run_down_client.run_id() + ); + let json = download_parallelism_data_from_gcs_signed(run_down_client) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + let table: Table = serde_json::from_str(&json).context("Failed to parse parallelism_data.json")?; @@ -54,68 +52,11 @@ fn normalize_gpu_name(raw_name: &str) -> String { "H200".to_string() } else if upper.contains("H100") { "H100".to_string() - } else if upper.contains("A100") { - "A100".to_string() - } else if upper.contains("L40S") { - "L40S".to_string() - } else if upper.contains("L40") { - "L40".to_string() - } else if upper.contains("4090") { - "RTX4090".to_string() - } else if upper.contains("3090") { - "RTX3090".to_string() } else { raw_name.to_string() } } -async fn download_parallelism_data( - checkpoint: &model::Checkpoint, - run_down_client: Option<&Arc>, - hub_read_token: Option<&str>, -) -> Result { - match checkpoint { - model::Checkpoint::Gcs(_) | model::Checkpoint::P2PGcs(_) => { - let client = run_down_client - .ok_or_else(|| anyhow::anyhow!("RunDownClient required for GCS parallelism lookup"))?; - info!( - "Fetching parallelism_data.json from GCS via run-down signed URLs for run {}", - client.run_id() - ); - download_parallelism_data_from_gcs_signed(client) - .await - .map_err(|e| anyhow::anyhow!("{}", e)) - } - model::Checkpoint::Hub(hub_repo) | model::Checkpoint::P2P(hub_repo) => { - let repo_id: String = (&hub_repo.repo_id).into(); - info!( - "Fetching parallelism_data.json from HuggingFace repo '{}'", - repo_id - ); - download_from_hub(&repo_id, hub_read_token).await - } - _ => anyhow::bail!("Parallelism auto-detection requires Hub or GCS checkpoint type"), - } -} - -async fn download_from_hub(repo_id: &str, token: Option<&str>) -> Result { - let mut builder = hf_hub::api::tokio::ApiBuilder::new(); - if let Some(token) = token { - builder = builder.with_token(Some(token.to_string())); - } - let api = builder.build()?; - let repo = api.model(repo_id.to_string()); - let path = repo.get("parallelism_data.json").await.with_context(|| { - format!( - "parallelism_data.json not found in HuggingFace repo '{}'", - repo_id - ) - })?; - tokio::fs::read_to_string(path) - .await - .context("Failed to read parallelism_data.json") -} - fn lookup_in_table(table: &Table, gpu_type: &str, num_gpus: usize) -> Result { let gpu_configs = table.get(gpu_type).ok_or_else(|| { anyhow::anyhow!( diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index 02cda12e8..df3ce6e44 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -223,13 +223,19 @@ impl RunInitConfigAndIO Result {