diff --git a/.gitignore b/.gitignore index c82f9d9b6..771158424 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,8 @@ target/ ### Python ### __pycache__ +.uv-python/ +.uv-cache/ ### Windows ### Thumbs.db @@ -42,6 +44,7 @@ slurm* node_modules/ bindings .pre-commit-config.yaml +.claude/ ### Data folders ### /data diff --git a/Cargo.lock b/Cargo.lock index 0312196d1..5460ffe10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7170,6 +7170,7 @@ dependencies = [ "clap-markdown", "plotters", "psyche-coordinator", + "psyche-core", "serde", "toml 0.8.23", ] @@ -7472,14 +7473,14 @@ name = "psyche-coordinator" version = "0.2.0" dependencies = [ "anchor-lang", - "anyhow", - "async-trait", + "anchor-lang-idl", + "bitvec", "bytemuck", - "cfg_eval", - "psyche-core", + "data-encoding", + "fnv", + "postcard", + "rand 0.9.2", "serde", - "serde_json", - "serde_with", "ts-rs", ] @@ -7491,15 +7492,14 @@ dependencies = [ "anchor-lang-idl", "anyhow", "approx", - "bitvec", "bytemuck", - "data-encoding", "fast-math", - "fnv", "postcard", + "psyche-coordinator", "rand 0.9.2", "serde", "serde_arrays", + "serde_json", "sha2 0.10.9", "ts-rs", ] @@ -7575,6 +7575,8 @@ dependencies = [ name = "psyche-deserialize-zerocopy-wasm" version = "0.2.0" dependencies = [ + "anchor-lang", + "postcard", "psyche-coordinator", "psyche-core", "psyche-solana-coordinator", @@ -7887,6 +7889,7 @@ dependencies = [ "async-trait", "backon", "futures-util", + "postcard", "psyche-coordinator", "psyche-core", "psyche-event-sourcing", diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 479d6628e..d8a075331 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -2,13 +2,13 @@ use anyhow::{Error, Result}; use bytemuck::Zeroable; use psyche_centralized_shared::{ClientToServerMessage, ServerToClientMessage}; use psyche_client::{ - CheckpointUploader, Client, ClientTUI, ClientTUIState, ModelExtraData, NC, RunInitConfig, - TrainArgs, read_identity_secret_key, + CheckpointUploader, Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, + read_identity_secret_key, }; -use psyche_coordinator::model::{self, CheckpointSource}; -use psyche_coordinator::model_extra_data::CheckpointData; -use psyche_coordinator::{Coordinator, HealthChecks}; -use psyche_core::NodeIdentity; +use psyche_coordinator::coordinator::{Coordinator, HealthChecks}; +use psyche_coordinator::model::{self, CheckpointBytes, CheckpointSource}; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_core::{CheckpointData, ModelExtraData}; use psyche_event_sourcing::event; use psyche_event_sourcing::events::RpcCallType; use psyche_metrics::ClientMetrics; @@ -69,7 +69,7 @@ impl WatcherBackend for Backend { Ok(()) } - async fn send_checkpoint(&mut self, checkpoint: model::CheckpointBytes) -> Result<()> { + async fn send_checkpoint(&mut self, checkpoint: CheckpointBytes) -> Result<()> { self.tx.send(ToSend::Checkpoint(Box::new(checkpoint)))?; Ok(()) } @@ -211,7 +211,7 @@ impl App { // Validate upload credentials now that we have the coordinator state with checkpoint info. if !state_options.checkpoint_config.skip_upload { - let model::Model::LLM(ref llm) = first_coordinator_state.model; + let llm = first_coordinator_state.model; if llm.checkpoint_source != CheckpointSource::Ephemeral { match CheckpointData::from_fixed_vec(&llm.checkpoint_data) { Ok(CheckpointData::Hub { ref repo_id, .. }) => { diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 040458ec5..1eab68897 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -1,14 +1,15 @@ +use crate::dashboard::{DashboardState, DashboardTui}; use anyhow::{Result, bail}; use async_trait::async_trait; use psyche_centralized_shared::{ClientToServerMessage, ServerToClientMessage}; -use psyche_coordinator::model::{self, CheckpointSource, Model}; -use psyche_coordinator::model_extra_data::CheckpointData; -use psyche_coordinator::{ +use psyche_coordinator::coordinator::{ Client, ClientState, Coordinator, CoordinatorError, HealthChecks, Round, RunState, SOLANA_MAX_NUM_CLIENTS, TickResult, }; - -use psyche_core::{FixedVec, NodeIdentity, Shuffle, SizedIterator, TokenSize}; +use psyche_coordinator::fixed_vec::FixedVec; +use psyche_coordinator::model::{CheckpointBytes, CheckpointSource}; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_core::{CheckpointData, Shuffle, SizedIterator, TokenSize}; use psyche_data_provider::{ DataProviderTcpServer, DataServerTui, LocalDataProvider, download_model_from_gcs_async, download_model_repo_async, @@ -34,8 +35,6 @@ use tokio::{select, time::Interval}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, info, info_span, warn}; -use crate::dashboard::{DashboardState, DashboardTui}; - pub(super) type TabWidgetTypes = ( DashboardTui, CoordinatorTui, @@ -83,7 +82,7 @@ impl psyche_watcher::Backend for ChannelCoordinatorBackend { bail!("Server does not send health checks"); } - async fn send_checkpoint(&mut self, _checkpoint: model::CheckpointBytes) -> Result<()> { + async fn send_checkpoint(&mut self, _checkpoint: CheckpointBytes) -> Result<()> { bail!("Server does not send checkpoints"); } } @@ -137,9 +136,7 @@ impl App { } pub fn get_checkpoint(&self) -> CheckpointSource { - match self.coordinator.model { - Model::LLM(llm) => llm.checkpoint_source, - } + self.coordinator.model.checkpoint_source } pub fn get_port(&self) -> u16 { @@ -191,7 +188,7 @@ impl App { }) = data_server_config { // Download model if needed based on checkpoint type - let Model::LLM(llm) = &coordinator.model; + let llm = &coordinator.model; if llm.checkpoint_source == CheckpointSource::Ephemeral { bail!("Can't start up a run with an Ephemeral checkpoint.") } diff --git a/architectures/centralized/server/src/main.rs b/architectures/centralized/server/src/main.rs index 2a0839abe..5576fce11 100644 --- a/architectures/centralized/server/src/main.rs +++ b/architectures/centralized/server/src/main.rs @@ -4,7 +4,7 @@ mod dashboard; use anyhow::{Context, Result, bail}; use app::{App, DataServerInfo}; use clap::{ArgAction, Parser}; -use psyche_coordinator::Coordinator; +use psyche_coordinator::coordinator::Coordinator; use psyche_tui::{ LogOutput, ServiceInfo, logging::{MetricsDestination, OpenTelemetry, RemoteLogsDestination, TraceDestination}, diff --git a/architectures/centralized/shared/src/protocol.rs b/architectures/centralized/shared/src/protocol.rs index bbd92b9b0..7ebd36a2a 100644 --- a/architectures/centralized/shared/src/protocol.rs +++ b/architectures/centralized/shared/src/protocol.rs @@ -1,4 +1,7 @@ -use psyche_coordinator::{Coordinator, HealthChecks, model}; +use psyche_coordinator::{ + coordinator::{Coordinator, HealthChecks}, + model, +}; use psyche_watcher::OpportunisticData; use serde::{Deserialize, Serialize}; diff --git a/architectures/centralized/testing/src/server.rs b/architectures/centralized/testing/src/server.rs index 598538a4a..72f056a4c 100644 --- a/architectures/centralized/testing/src/server.rs +++ b/architectures/centralized/testing/src/server.rs @@ -2,12 +2,14 @@ use crate::{COOLDOWN_TIME, test_utils::sample_rand_run_id}; use crate::{MAX_ROUND_TRAIN_TIME, ROUND_WITNESS_TIME, WARMUP_TIME}; use bytemuck::Zeroable; use psyche_centralized_server::app::App as ServerApp; -use psyche_coordinator::{Client, Round}; -use psyche_coordinator::{ - Coordinator, CoordinatorConfig, CoordinatorEpochState, RunState, SOLANA_MAX_NUM_CLIENTS, - model::{CheckpointSource, LLM, Model}, +use psyche_coordinator::coordinator::{ + Client, Coordinator, CoordinatorConfig, CoordinatorEpochState, Round, RunState, + SOLANA_MAX_NUM_CLIENTS, }; -use psyche_core::{FixedVec, NodeIdentity}; +use psyche_coordinator::fixed_vec::FixedVec; +use psyche_coordinator::model::{CheckpointSource, Model}; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_core::CheckpointData; use std::{collections::HashSet, ops::ControlFlow}; use tokio::{ select, @@ -90,7 +92,7 @@ impl CoordinatorServer { let run_id = sample_rand_run_id(); let coordinator: Coordinator = Coordinator { run_id: run_id.as_str().try_into().unwrap(), - model: Model::LLM(LLM::dummy()), + model: Model::dummy(CheckpointData::Dummy.to_fixed_vec()), config: coordinator_config, epoch_state, ..Coordinator::zeroed() diff --git a/architectures/centralized/testing/src/test_utils.rs b/architectures/centralized/testing/src/test_utils.rs index ae835da36..6e41b639f 100644 --- a/architectures/centralized/testing/src/test_utils.rs +++ b/architectures/centralized/testing/src/test_utils.rs @@ -5,6 +5,7 @@ use crate::client::ClientHandle; use crate::server::CoordinatorServerHandle; use clap::Parser; use psyche_client::TrainArgs; +use psyche_coordinator::coordinator::Coordinator; use rand::distr::{Alphanumeric, SampleString}; use std::env; use tokio_util::sync::CancellationToken; @@ -98,9 +99,7 @@ pub async fn assert_witnesses_healthy_score( // calculate score let mut score = 0; clients.iter().for_each(|client| { - score += psyche_coordinator::Coordinator::trainer_healthy_score_by_witnesses( - &client.id, witnesses, - ); + score += Coordinator::trainer_healthy_score_by_witnesses(&client.id, witnesses); }); assert_eq!( diff --git a/architectures/centralized/testing/tests/integration_tests.rs b/architectures/centralized/testing/tests/integration_tests.rs index 256e10825..a9da35106 100644 --- a/architectures/centralized/testing/tests/integration_tests.rs +++ b/architectures/centralized/testing/tests/integration_tests.rs @@ -9,7 +9,7 @@ use psyche_centralized_testing::{ spawn_clients_with_training_delay, }, }; -use psyche_coordinator::{RunState, model::CheckpointSource}; +use psyche_coordinator::{coordinator::RunState, model::CheckpointSource}; use tracing::info; #[test_log::test(tokio::test(flavor = "multi_thread"))] diff --git a/architectures/decentralized/solana-client/Cargo.toml b/architectures/decentralized/solana-client/Cargo.toml index 95d09dbc4..35f6a7afc 100644 --- a/architectures/decentralized/solana-client/Cargo.toml +++ b/architectures/decentralized/solana-client/Cargo.toml @@ -19,7 +19,7 @@ psyche-metrics.workspace = true psyche-modeling.workspace = true psyche-network.workspace = true psyche-solana-rpc.workspace = true -psyche-solana-coordinator.workspace = true +psyche-solana-coordinator = { workspace = true, features = ["client"] } psyche-tui.workspace = true psyche-watcher.workspace = true rand.workspace = true diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 95947d17e..128c33f7e 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -1,3 +1,4 @@ +use psyche_core::{CheckpointData, ModelExtraData}; use psyche_solana_rpc::SolanaBackend; use anchor_client::{ @@ -10,15 +11,15 @@ use anchor_client::{ }; use anyhow::{Result, anyhow}; use psyche_client::{ - CheckpointUploader, Client, ClientTUI, ClientTUIState, ModelExtraData, NC, RunInitConfig, - TrainArgs, read_identity_secret_key, + CheckpointUploader, Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, + read_identity_secret_key, }; use psyche_coordinator::{ - ClientState, Coordinator, CoordinatorError, RunState, - model::{CheckpointSource, Model}, - model_extra_data::CheckpointData, + coordinator::{ClientState, Coordinator, CoordinatorError, RunState}, + model::CheckpointSource, + node_identity::NodeIdentity, + sha::sha256, }; -use psyche_core::sha256; use psyche_metrics::ClientMetrics; use psyche_network::{DiscoveryMode, NetworkTUIState, NetworkTui, SecretKey, allowlist}; @@ -89,7 +90,7 @@ pub async fn build_app( let mut rng = ChaCha20Rng::from_seed(sha256(&seed_preimage)); SecretKey::generate(&mut rng) }); - let identity = psyche_core::NodeIdentity::new( + let identity = NodeIdentity::new( wallet_keypair.pubkey().to_bytes(), *identity_secret_key.public().as_bytes(), ); @@ -239,7 +240,7 @@ impl App { // sanity checks — skip credential validation when checkpoint upload is disabled if !self.state_options.checkpoint_config.skip_upload { - let Model::LLM(ref llm) = start_coordinator_state.model; + let llm = start_coordinator_state.model; if llm.checkpoint_source != CheckpointSource::Ephemeral { match CheckpointData::from_fixed_vec(&llm.checkpoint_data) { Ok(CheckpointData::Hub { ref repo_id, .. }) => { @@ -269,7 +270,7 @@ impl App { .join_run( coordinator_instance_pubkey, coordinator_account, - psyche_core::NodeIdentity::new(signer.to_bytes(), *p2p_identity.as_bytes()), + NodeIdentity::new(signer.to_bytes(), *p2p_identity.as_bytes()), self.authorizer, ) .await?; @@ -301,7 +302,7 @@ impl App { self.metrics, ); - let id = psyche_core::NodeIdentity::new(signer.to_bytes(), *p2p_identity.as_bytes()); + let id = NodeIdentity::new(signer.to_bytes(), *p2p_identity.as_bytes()); loop { select! { @@ -328,7 +329,7 @@ impl App { None }; - let pending_clients_ids: Option> = coordinator_state_in_waiting_for_members + let pending_clients_ids: Option> = coordinator_state_in_waiting_for_members .as_ref() .map(|state| state.clients_state.get_active_clients_ids().collect()); diff --git a/architectures/decentralized/solana-client/src/main.rs b/architectures/decentralized/solana-client/src/main.rs index a109d7f79..717871495 100644 --- a/architectures/decentralized/solana-client/src/main.rs +++ b/architectures/decentralized/solana-client/src/main.rs @@ -12,8 +12,8 @@ use anchor_client::{ use anyhow::{Result, bail}; use clap::{Args, Parser, Subcommand}; use psyche_client::{TrainArgs, print_identity_keys}; -use psyche_coordinator::model::{CheckpointSource, Model}; -use psyche_coordinator::model_extra_data::CheckpointData; +use psyche_coordinator::model::CheckpointSource; +use psyche_core::CheckpointData; use psyche_event_sourcing::{EventStore, FileBackend, RunStarted}; use psyche_network::SecretKey; use psyche_solana_rpc::SolanaBackend; @@ -295,10 +295,7 @@ async fn async_main() -> Result<()> { .coordinator; if model { - #[allow(irrefutable_let_patterns)] - let Model::LLM(model_config) = coordinator_account_state.model else { - bail!("Model is not an LLM, unsure how to predownload."); - }; + let model_config = coordinator_account_state.model; if model_config.checkpoint_source == CheckpointSource::Ephemeral { bail!("Can't predownload model with ephemeral checkpoint.") diff --git a/architectures/decentralized/solana-common/Cargo.toml b/architectures/decentralized/solana-common/Cargo.toml index 2b05524f2..091cab8d4 100644 --- a/architectures/decentralized/solana-common/Cargo.toml +++ b/architectures/decentralized/solana-common/Cargo.toml @@ -13,11 +13,12 @@ psyche-coordinator.workspace = true psyche-core.workspace = true psyche-event-sourcing.workspace = true psyche-solana-authorizer.workspace = true -psyche-solana-coordinator.workspace = true +psyche-solana-coordinator = { workspace = true, features = ["client"] } psyche-solana-treasurer.workspace = true psyche-watcher.workspace = true tokio.workspace = true tracing.workspace = true backon = "1.4.1" +postcard.workspace = true solana-account-decoder-client-types = "=2.1.4" solana-transaction-status-client-types = "=2.1.4" diff --git a/architectures/decentralized/solana-common/src/backend.rs b/architectures/decentralized/solana-common/src/backend.rs index cac841d0d..61eaa3ad6 100644 --- a/architectures/decentralized/solana-common/src/backend.rs +++ b/architectures/decentralized/solana-common/src/backend.rs @@ -1,6 +1,6 @@ use crate::instructions::{self, coordinator_tick}; use crate::retry::{RetryError, retry_function_with_params}; -use anchor_client::anchor_lang::{AccountDeserialize, AnchorSerialize}; +use anchor_client::anchor_lang::AccountDeserialize; use anchor_client::solana_sdk::hash::hash; use anchor_client::solana_sdk::instruction::Instruction; use anchor_client::solana_sdk::program_pack::Pack; @@ -20,8 +20,10 @@ use anchor_client::{ }; use anyhow::{Context, Result, anyhow}; use futures_util::StreamExt; -use psyche_coordinator::model::{self, CheckpointBytes}; -use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks}; +use psyche_coordinator::coordinator::{Coordinator, HealthChecks}; +use psyche_coordinator::model::CheckpointBytes; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_coordinator::types::CommitteeProof; use psyche_core::IntegrationTestLogMarker; use psyche_event_sourcing::event; use psyche_event_sourcing::events::RpcCallType; @@ -183,7 +185,7 @@ impl SolanaBackend { self, run_id: String, coordinator_account: Pubkey, - ) -> Result { + ) -> anyhow::Result { let (tx_update, rx_update) = broadcast::channel(32); let commitment_config = self.get_commitment_config(); @@ -274,7 +276,7 @@ impl SolanaBackend { &self, coordinator_instance: Pubkey, coordinator_account: Pubkey, - id: psyche_core::NodeIdentity, + id: NodeIdentity, authorizer: Option, ) -> Result { let coordinator_instance_state = @@ -336,8 +338,7 @@ impl SolanaBackend { &coordinator_account, &user, witness, - AnchorSerialize::try_to_vec(&metadata) - .expect("failed to serialize WitnessMetadata"), + postcard::to_stdvec(&metadata).expect("failed to serialize WitnessMetadata"), ), RpcCallType::Witness, ), @@ -367,7 +368,7 @@ impl SolanaBackend { &self, coordinator_instance: Pubkey, coordinator_account: Pubkey, - id: psyche_core::NodeIdentity, + id: NodeIdentity, check: CommitteeProof, ) { let user = self.get_payer(); @@ -713,7 +714,7 @@ impl WatcherBackend for SolanaBackendRunner { Ok(()) } - async fn send_checkpoint(&mut self, checkpoint: model::CheckpointBytes) -> Result<()> { + async fn send_checkpoint(&mut self, checkpoint: CheckpointBytes) -> Result<()> { self.backend .send_checkpoint(self.instance, self.account, checkpoint); Ok(()) diff --git a/architectures/decentralized/solana-common/src/instructions.rs b/architectures/decentralized/solana-common/src/instructions.rs index 206de72b6..85fd63abf 100644 --- a/architectures/decentralized/solana-common/src/instructions.rs +++ b/architectures/decentralized/solana-common/src/instructions.rs @@ -5,6 +5,12 @@ use anchor_client::solana_sdk::instruction::Instruction; use anchor_client::solana_sdk::pubkey::Pubkey; use anchor_spl::associated_token; use anchor_spl::token; +use psyche_coordinator::coordinator::CoordinatorConfig; +use psyche_coordinator::coordinator::CoordinatorProgress; +use psyche_coordinator::coordinator::Witness; +use psyche_coordinator::model::Model; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_coordinator::types::CommitteeProof; pub fn coordinator_init_coordinator( payer: &Pubkey, @@ -57,9 +63,9 @@ pub fn coordinator_update( run_id: &str, coordinator_account: &Pubkey, main_authority: &Pubkey, - config: Option, - model: Option, - progress: Option, + config: Option, + model: Option, + progress: Option, ) -> Instruction { let coordinator_instance = psyche_solana_coordinator::find_coordinator_instance(run_id); anchor_instruction( @@ -99,7 +105,7 @@ pub fn coordinator_join_run( coordinator_instance: &Pubkey, coordinator_account: &Pubkey, authorization: &Pubkey, - client_id: psyche_core::NodeIdentity, + client_id: NodeIdentity, ) -> Instruction { anchor_instruction( psyche_solana_coordinator::ID, @@ -135,7 +141,7 @@ pub fn coordinator_witness( coordinator_instance: &Pubkey, coordinator_account: &Pubkey, user: &Pubkey, - witness: psyche_coordinator::Witness, + witness: Witness, metadata: Vec, ) -> Instruction { anchor_instruction( @@ -159,7 +165,7 @@ pub fn coordinator_warmup_witness( coordinator_instance: &Pubkey, coordinator_account: &Pubkey, user: &Pubkey, - witness: psyche_coordinator::Witness, + witness: Witness, ) -> Instruction { anchor_instruction( psyche_solana_coordinator::ID, @@ -181,7 +187,7 @@ pub fn coordinator_cooldown_witness( coordinator_instance: &Pubkey, coordinator_account: &Pubkey, user: &Pubkey, - witness: psyche_coordinator::Witness, + witness: Witness, ) -> Instruction { anchor_instruction( psyche_solana_coordinator::ID, @@ -203,8 +209,8 @@ pub fn coordinator_health_check( coordinator_instance: &Pubkey, coordinator_account: &Pubkey, user: &Pubkey, - client_id: psyche_core::NodeIdentity, - check: psyche_coordinator::CommitteeProof, + client_id: NodeIdentity, + check: CommitteeProof, ) -> Instruction { anchor_instruction( psyche_solana_coordinator::ID, diff --git a/architectures/decentralized/solana-coordinator/Cargo.lock b/architectures/decentralized/solana-coordinator/Cargo.lock index 4be8d3ae5..901de87b1 100644 --- a/architectures/decentralized/solana-coordinator/Cargo.lock +++ b/architectures/decentralized/solana-coordinator/Cargo.lock @@ -239,21 +239,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - [[package]] name = "anyhow" version = "1.0.98" @@ -389,17 +374,6 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" -[[package]] -name = "async-trait" -version = "0.1.83" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.90", -] - [[package]] name = "atomic-polyfill" version = "1.0.3" @@ -662,11 +636,7 @@ version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ - "android-tzdata", - "iana-time-zone", "num-traits", - "serde", - "windows-targets", ] [[package]] @@ -711,12 +681,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" -[[package]] -name = "core-foundation-sys" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" - [[package]] name = "cpufeatures" version = "0.2.16" @@ -851,16 +815,6 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010" -[[package]] -name = "deranged" -version = "0.3.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" -dependencies = [ - "powerfmt", - "serde", -] - [[package]] name = "derivation-path" version = "0.2.0" @@ -1057,6 +1011,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "hash32" version = "0.2.1" @@ -1066,12 +1032,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - [[package]] name = "hashbrown" version = "0.13.2" @@ -1110,12 +1070,6 @@ dependencies = [ "unicode-segmentation", ] -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - [[package]] name = "hmac" version = "0.8.1" @@ -1146,29 +1100,6 @@ dependencies = [ "hmac 0.8.1", ] -[[package]] -name = "iana-time-zone" -version = "0.1.61" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "wasm-bindgen", - "windows-core", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", -] - [[package]] name = "ident_case" version = "1.0.1" @@ -1181,17 +1112,6 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9007da9cacbd3e6343da136e98b0d2df013f553d35bdec8b518f07bea768e19c" -[[package]] -name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", - "serde", -] - [[package]] name = "indexmap" version = "2.7.0" @@ -1200,7 +1120,6 @@ checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", "hashbrown 0.15.2", - "serde", ] [[package]] @@ -1377,12 +1296,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-conv" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" - [[package]] name = "num-derive" version = "0.4.2" @@ -1558,12 +1471,6 @@ dependencies = [ "serde", ] -[[package]] -name = "powerfmt" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" - [[package]] name = "ppv-lite86" version = "0.2.20" @@ -1605,14 +1512,13 @@ name = "psyche-coordinator" version = "0.2.0" dependencies = [ "anchor-lang", - "anyhow", - "async-trait", + "anchor-lang-idl", + "bitvec", "bytemuck", - "cfg_eval", - "psyche-core", + "data-encoding", + "fnv", + "rand 0.9.2", "serde", - "serde_json", - "serde_with", "ts-rs", ] @@ -1623,12 +1529,10 @@ dependencies = [ "anchor-lang", "anchor-lang-idl", "anyhow", - "bitvec", "bytemuck", - "data-encoding", "fast-math", - "fnv", "postcard", + "psyche-coordinator", "serde", "serde_arrays", "sha2 0.10.8", @@ -1674,6 +1578,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "radium" version = "0.7.0" @@ -1704,6 +1614,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", +] + [[package]] name = "rand_chacha" version = "0.2.2" @@ -1724,6 +1644,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", +] + [[package]] name = "rand_core" version = "0.5.1" @@ -1742,6 +1672,15 @@ dependencies = [ "getrandom 0.2.15", ] +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "rand_hc" version = "0.2.0" @@ -1881,16 +1820,9 @@ version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ - "base64 0.22.1", - "chrono", - "hex", - "indexmap 1.9.3", - "indexmap 2.7.0", "serde", "serde_derive", - "serde_json", "serde_with_macros", - "time", ] [[package]] @@ -3078,37 +3010,6 @@ dependencies = [ "syn 2.0.90", ] -[[package]] -name = "time" -version = "0.3.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" -dependencies = [ - "deranged", - "itoa", - "num-conv", - "powerfmt", - "serde", - "time-core", - "time-macros", -] - -[[package]] -name = "time-core" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" - -[[package]] -name = "time-macros" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" -dependencies = [ - "num-conv", - "time-core", -] - [[package]] name = "tinyvec" version = "1.8.0" @@ -3160,7 +3061,7 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap 2.7.0", + "indexmap", "serde", "serde_spanned", "toml_datetime", @@ -3253,6 +3154,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-bindgen" version = "0.2.97" @@ -3327,15 +3237,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "windows-core" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" -dependencies = [ - "windows-targets", -] - [[package]] name = "windows-sys" version = "0.59.0" @@ -3418,6 +3319,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + [[package]] name = "wyz" version = "0.5.1" diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/Cargo.toml b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/Cargo.toml index c540da791..2318dfd4e 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/Cargo.toml +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/Cargo.toml @@ -16,6 +16,7 @@ no-idl = [] no-log-ix-name = [] idl-build = ["anchor-lang/idl-build"] anchor-debug = [] +client = ["psyche-coordinator/client", "dep:serde", "dep:ts-rs"] [lints.rust] unexpected_cfgs = { level = "allow", check-cfg = [ @@ -32,10 +33,11 @@ bytemuck = { version = "1", features = [ ] } psyche-core = { path = "../../../../../shared/core" } psyche-coordinator = { path = "../../../../../shared/coordinator" } -serde = { version = "1.0.209", features = ["derive"] } psyche-solana-authorizer = { path = "../../../solana-authorizer/programs/solana-authorizer", features = [ "cpi", ] } -ts-rs = { git = "https://github.com/arilotter/ts-rs.git", rev = "20c5bef07f3aa8ca3aa1b5cdad11ecdd60c64d34", features = [ + +ts-rs = { optional = true, git = "https://github.com/arilotter/ts-rs.git", rev = "20c5bef07f3aa8ca3aa1b5cdad11ecdd60c64d34", features = [ "psyche-impl", -] } +] } # client only deps +serde = { optional = true, version = "1.0.209", features = ["derive"] } diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/client.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/client.rs index e47ebffb4..5c01e55f3 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/client.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/client.rs @@ -3,10 +3,7 @@ use std::fmt::Debug; use anchor_lang::prelude::*; use bytemuck::Pod; use bytemuck::Zeroable; -use psyche_core::NodeIdentity; -use serde::Deserialize; -use serde::Serialize; -use ts_rs::TS; +use psyche_coordinator::node_identity::NodeIdentity; #[derive( Clone, @@ -17,12 +14,13 @@ use ts_rs::TS; Pod, AnchorSerialize, AnchorDeserialize, - Serialize, - Deserialize, - TS, )] +#[cfg_attr( + feature = "client", + derive(serde::Serialize, serde::Deserialize, ts_rs::TS) +)] +#[cfg_attr(feature = "client", ts(rename = "SolanaClient"))] #[repr(C)] -#[ts(rename = "SolanaClient")] pub struct Client { pub id: NodeIdentity, pub _unused: [u8; 8], diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/clients_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/clients_state.rs index b4b1da544..b8bfc6427 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/clients_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/clients_state.rs @@ -1,27 +1,18 @@ use anchor_lang::prelude::*; use bytemuck::Pod; use bytemuck::Zeroable; -use psyche_core::FixedVec; -use psyche_core::NodeIdentity; +use psyche_coordinator::fixed_vec::FixedVec; +use psyche_coordinator::node_identity::NodeIdentity; use psyche_core::SizedIterator; -use serde::Deserialize; -use serde::Serialize; -use ts_rs::TS; use crate::SOLANA_MAX_NUM_PENDING_CLIENTS; use crate::client::Client; use crate::program_error::ProgramError; -#[derive( - Debug, - Clone, - Copy, - Zeroable, - AnchorSerialize, - AnchorDeserialize, - Serialize, - Deserialize, - TS, +#[derive(Debug, Clone, Copy, Zeroable, AnchorSerialize, AnchorDeserialize)] +#[cfg_attr( + feature = "client", + derive(serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub struct ClientsState { @@ -31,16 +22,10 @@ pub struct ClientsState { pub future_epoch_rates: ClientsEpochRates, } -#[derive( - Debug, - Clone, - Copy, - Zeroable, - AnchorSerialize, - AnchorDeserialize, - Serialize, - Deserialize, - TS, +#[derive(Debug, Clone, Copy, Zeroable, AnchorSerialize, AnchorDeserialize)] +#[cfg_attr( + feature = "client", + derive(serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub struct ClientsEpochRates { diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index c2b3ff83c..20a39204a 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -1,38 +1,29 @@ use anchor_lang::prelude::*; use bytemuck::Pod; use bytemuck::Zeroable; -use psyche_coordinator::ClientState; -use psyche_coordinator::Coordinator; -use psyche_coordinator::CoordinatorConfig; -use psyche_coordinator::CoordinatorProgress; -use psyche_coordinator::HealthChecks; -use psyche_coordinator::RunState; -use psyche_coordinator::TickResult; -use psyche_coordinator::Witness; +use psyche_coordinator::coordinator::ClientState; +use psyche_coordinator::coordinator::Coordinator; +use psyche_coordinator::coordinator::CoordinatorConfig; +use psyche_coordinator::coordinator::CoordinatorProgress; +use psyche_coordinator::coordinator::HealthChecks; +use psyche_coordinator::coordinator::RunState; +use psyche_coordinator::coordinator::TickResult; +use psyche_coordinator::coordinator::Witness; +use psyche_coordinator::fixed_string::FixedString; use psyche_coordinator::model::CheckpointBytes; use psyche_coordinator::model::Model; -use psyche_core::FixedString; -use psyche_core::NodeIdentity; -use psyche_core::SmallBoolean; -use psyche_core::sha256v; -use serde::Deserialize; -use serde::Serialize; -use ts_rs::TS; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_coordinator::sha::sha256v; +use psyche_coordinator::small_boolean::SmallBoolean; use crate::ProgramError; use crate::client::Client; use crate::clients_state::ClientsState; -#[derive( - Debug, - Clone, - Copy, - Zeroable, - AnchorSerialize, - AnchorDeserialize, - Serialize, - Deserialize, - TS, +#[derive(Clone, Copy, Zeroable, AnchorSerialize, AnchorDeserialize)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub struct CoordinatorInstanceState { @@ -40,10 +31,14 @@ pub struct CoordinatorInstanceState { pub clients_state: ClientsState, pub is_warmup_first_tick: SmallBoolean, pub is_training_first_tick: SmallBoolean, + + // TODO move this into run extra? doesn't need to be on-chain. pub client_version: FixedString<96>, } -unsafe impl Pod for CoordinatorInstanceState {} +unsafe impl Pod for CoordinatorInstanceState { + // NOT SAFE. DELETE ME. +} impl CoordinatorInstanceState { fn get_random_seed(clock: &Clock) -> u64 { diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index fdb596d39..3f8725d8e 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -10,33 +10,21 @@ pub use client::Client; pub use instance_state::CoordinatorInstanceState; use logic::*; pub use program_error::ProgramError; -use psyche_coordinator::Committee; -use psyche_coordinator::CommitteeProof; -use psyche_coordinator::CoordinatorConfig; -use psyche_coordinator::CoordinatorProgress; -use psyche_coordinator::SOLANA_MAX_NUM_CLIENTS; -use psyche_coordinator::SOLANA_MAX_STRING_LEN; -use psyche_coordinator::Witness; -use psyche_coordinator::WitnessBloom; -use psyche_coordinator::WitnessProof; +use psyche_coordinator::coordinator::SOLANA_MAX_NUM_CLIENTS; +use psyche_coordinator::coordinator::SOLANA_MAX_STRING_LEN; use psyche_coordinator::model::Model; -use psyche_core::MerkleRoot; -use psyche_core::NodeIdentity; -use serde::Deserialize; -use serde::Serialize; -use ts_rs::TS; declare_id!("4SHugWqSXwKE5fqDchkJcPEqnoZE22VYKtSTVm7axbT7"); pub const SOLANA_MAX_NUM_PENDING_CLIENTS: usize = SOLANA_MAX_NUM_CLIENTS; -pub fn bytes_from_string(str: &str) -> &[u8] { +pub fn bytes_from_str(str: &str) -> &[u8] { &str.as_bytes()[..SOLANA_MAX_STRING_LEN.min(str.len())] } pub fn find_coordinator_instance(run_id: &str) -> Pubkey { Pubkey::find_program_address( - &[CoordinatorInstance::SEEDS_PREFIX, bytes_from_string(run_id)], + &[CoordinatorInstance::SEEDS_PREFIX, bytes_from_str(run_id)], &crate::ID, ) .0 @@ -122,7 +110,10 @@ pub fn coordinator_account_from_bytes_mut( #[account(zero_copy)] #[repr(C)] -#[derive(Serialize, Deserialize, TS)] +#[cfg_attr( + feature = "client", + derive(serde::Serialize, serde::Deserialize, ts_rs::TS) +)] pub struct CoordinatorAccount { pub version: u64, pub state: CoordinatorInstanceState, @@ -158,11 +149,21 @@ impl CoordinatorInstance { pub const SEEDS_PREFIX: &'static [u8] = b"coordinator"; } +use psyche_coordinator::coordinator::CoordinatorConfig; +use psyche_coordinator::coordinator::CoordinatorProgress; +use psyche_coordinator::coordinator::Witness; +use psyche_coordinator::coordinator::WitnessBloom; +use psyche_coordinator::fixed_string::FixedString; +use psyche_coordinator::hash_wrapper::HashWrapper; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_coordinator::types::Committee; +use psyche_coordinator::types::CommitteeProof; +use psyche_coordinator::types::WitnessProof; + #[program] pub mod psyche_solana_coordinator { use super::*; - use psyche_core::FixedString; pub fn init_coordinator( context: Context, @@ -248,7 +249,7 @@ pub mod psyche_solana_coordinator { proof: WitnessProof, participant_bloom: WitnessBloom, broadcast_bloom: WitnessBloom, - broadcast_merkle: MerkleRoot, + broadcast_merkle: HashWrapper, metadata: Vec, ) -> Result<()> { let mut account = ctx.accounts.coordinator_account.load_mut()?; @@ -269,7 +270,7 @@ pub mod psyche_solana_coordinator { proof: WitnessProof, participant_bloom: WitnessBloom, broadcast_bloom: WitnessBloom, - broadcast_merkle: MerkleRoot, + broadcast_merkle: HashWrapper, ) -> Result<()> { let mut account = ctx.accounts.coordinator_account.load_mut()?; account.increment_nonce(); @@ -289,7 +290,7 @@ pub mod psyche_solana_coordinator { proof: WitnessProof, participant_bloom: WitnessBloom, broadcast_bloom: WitnessBloom, - broadcast_merkle: MerkleRoot, + broadcast_merkle: HashWrapper, ) -> Result<()> { let mut account = ctx.accounts.coordinator_account.load_mut()?; account.increment_nonce(); @@ -344,7 +345,7 @@ pub struct OwnerCoordinatorAccounts<'info> { #[account( seeds = [ CoordinatorInstance::SEEDS_PREFIX, - bytes_from_string(&coordinator_instance.run_id) + &bytes_from_str(&coordinator_instance.run_id) ], bump = coordinator_instance.bump, constraint = coordinator_instance.main_authority == authority.key() @@ -367,7 +368,7 @@ pub struct PermissionlessCoordinatorAccounts<'info> { #[account( seeds = [ CoordinatorInstance::SEEDS_PREFIX, - bytes_from_string(&coordinator_instance.run_id) + &bytes_from_str(&coordinator_instance.run_id) ], bump = coordinator_instance.bump )] diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/free_coordinator.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/free_coordinator.rs index bb1ef770b..185d195c8 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/free_coordinator.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/free_coordinator.rs @@ -2,7 +2,7 @@ use anchor_lang::prelude::*; use crate::CoordinatorAccount; use crate::CoordinatorInstance; -use crate::bytes_from_string; +use crate::bytes_from_str; use crate::program_error::ProgramError; #[derive(Accounts)] @@ -19,7 +19,7 @@ pub struct FreeCoordinatorAccounts<'info> { mut, seeds = [ CoordinatorInstance::SEEDS_PREFIX, - bytes_from_string(&coordinator_instance.run_id) + &bytes_from_str(&coordinator_instance.run_id) ], bump = coordinator_instance.bump, constraint = coordinator_instance.main_authority == authority.key(), diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/init_coordinator.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/init_coordinator.rs index 5bf9195ae..86985da36 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/init_coordinator.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/init_coordinator.rs @@ -1,11 +1,11 @@ use anchor_lang::prelude::*; -use psyche_coordinator::SOLANA_RUN_ID_MAX_LEN; -use psyche_core::FixedString; +use psyche_coordinator::coordinator::SOLANA_RUN_ID_MAX_LEN; +use psyche_coordinator::fixed_string::FixedString; use crate::CoordinatorAccount; use crate::CoordinatorInstance; use crate::ProgramError; -use crate::bytes_from_string; +use crate::bytes_from_str; #[derive(Accounts)] #[instruction(params: InitCoordinatorParams)] @@ -19,7 +19,7 @@ pub struct InitCoordinatorAccounts<'info> { space = 8 + CoordinatorInstance::INIT_SPACE, seeds = [ CoordinatorInstance::SEEDS_PREFIX, - bytes_from_string(¶ms.run_id) + &bytes_from_str(¶ms.run_id) ], bump )] diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/join_run.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/join_run.rs index 201a5f6e4..f382254bd 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/join_run.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/join_run.rs @@ -1,10 +1,10 @@ use anchor_lang::prelude::*; -use psyche_core::NodeIdentity; +use psyche_coordinator::node_identity::NodeIdentity; use psyche_solana_authorizer::state::Authorization; use crate::CoordinatorAccount; use crate::CoordinatorInstance; -use crate::bytes_from_string; +use crate::bytes_from_str; use crate::program_error::ProgramError; pub const JOIN_RUN_AUTHORIZATION_SCOPE: &[u8] = b"CoordinatorJoinRun"; @@ -27,7 +27,7 @@ pub struct JoinRunAccounts<'info> { #[account( seeds = [ CoordinatorInstance::SEEDS_PREFIX, - bytes_from_string(&coordinator_instance.run_id) + &bytes_from_str(&coordinator_instance.run_id) ], bump = coordinator_instance.bump, )] diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/program_error.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/program_error.rs index 1554e6802..41aff6378 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/program_error.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/program_error.rs @@ -1,5 +1,5 @@ use anchor_lang::prelude::*; -use psyche_coordinator::CoordinatorError; +use psyche_coordinator::coordinator::CoordinatorError; #[error_code] pub enum ProgramError { diff --git a/architectures/decentralized/solana-tooling/Cargo.toml b/architectures/decentralized/solana-tooling/Cargo.toml index 8599afb20..dce3012c0 100644 --- a/architectures/decentralized/solana-tooling/Cargo.toml +++ b/architectures/decentralized/solana-tooling/Cargo.toml @@ -13,13 +13,16 @@ anchor-lang = { git = "https://github.com/coral-xyz/anchor.git", rev = "a7a23eea anchor-spl = { git = "https://github.com/coral-xyz/anchor.git", rev = "a7a23eea308440a9fa9cb79cee7bddd30ab163d5" } psyche-core = { path = "../../../shared/core" } -psyche-coordinator = { path = "../../../shared/coordinator" } +psyche-coordinator = { path = "../../../shared/coordinator", features = [ + "client", +] } psyche-solana-authorizer = { path = "../solana-authorizer/programs/solana-authorizer", features = [ "cpi", ] } psyche-solana-coordinator = { path = "../solana-coordinator/programs/solana-coordinator", features = [ "cpi", + "client", ] } psyche-solana-treasurer = { path = "../solana-treasurer/programs/solana-treasurer", features = [ "cpi", diff --git a/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs b/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs index 25e0aaa18..ce3ab837c 100644 --- a/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs +++ b/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs @@ -1,10 +1,10 @@ use anchor_lang::InstructionData; use anchor_lang::ToAccountMetas; use anyhow::Result; -use psyche_coordinator::CoordinatorConfig; -use psyche_coordinator::CoordinatorProgress; +use psyche_coordinator::coordinator::CoordinatorConfig; +use psyche_coordinator::coordinator::CoordinatorProgress; use psyche_coordinator::model::Model; -use psyche_core::NodeIdentity; +use psyche_coordinator::node_identity::NodeIdentity; use psyche_solana_coordinator::accounts::FreeCoordinatorAccounts; use psyche_solana_coordinator::accounts::InitCoordinatorAccounts; use psyche_solana_coordinator::accounts::JoinRunAccounts; diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs index c22daed52..d295fe1c1 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs @@ -1,11 +1,10 @@ -use psyche_coordinator::Round; -use psyche_coordinator::RunState; +use psyche_coordinator::coordinator::Round; +use psyche_coordinator::coordinator::RunState; +use psyche_coordinator::fixed_string::FixedString; +use psyche_coordinator::fixed_vec::FixedVec; use psyche_coordinator::model::CheckpointSource; -use psyche_coordinator::model::Model; -use psyche_coordinator::model_extra_data::CheckpointData; -use psyche_core::FixedString; -use psyche_core::FixedVec; -use psyche_core::SmallBoolean; +use psyche_coordinator::small_boolean::SmallBoolean; +use psyche_core::CheckpointData; use psyche_solana_coordinator::CoordinatorAccount; use psyche_solana_coordinator::coordinator_account_from_bytes; @@ -30,25 +29,26 @@ pub async fn run() { assert_eq!(coordinator.run_state_start_unix_timestamp, 0); assert_eq!(coordinator.pending_pause, SmallBoolean::FALSE); // Coordinator model - match coordinator.model { - Model::LLM(llm) => { - assert_eq!(llm.max_seq_len, 2048); - assert_eq!(llm.cold_start_warmup_steps, 0); - assert_eq!(llm.checkpoint_source, CheckpointSource::Stored); - { - let checkpoint_data = - CheckpointData::from_fixed_vec(&llm.checkpoint_data) - .unwrap(); - match checkpoint_data { - CheckpointData::Hub { repo_id, revision } => { - assert_eq!(repo_id, "emozilla/llama2-1.1b-gqa-init"); - assert_eq!(revision, None); - }, - _ => panic!("Expected Hub checkpoint data"), - } - } - }, - }; + + assert_eq!(coordinator.model.max_seq_len, 2048); + assert_eq!(coordinator.model.cold_start_warmup_steps, 0); + assert_eq!( + coordinator.model.checkpoint_source, + CheckpointSource::Stored + ); + { + let checkpoint_data = + CheckpointData::from_fixed_vec(&coordinator.model.checkpoint_data) + .unwrap(); + match checkpoint_data { + CheckpointData::Hub { repo_id, revision } => { + assert_eq!(repo_id, "emozilla/llama2-1.1b-gqa-init"); + assert_eq!(revision, None); + }, + _ => panic!("Expected Hub checkpoint data"), + } + } + // Coordinator config assert_eq!(coordinator.config.warmup_time, 15); assert_eq!(coordinator.config.cooldown_time, 30); diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs index 6b9195b88..cbc6d4455 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs @@ -1,12 +1,11 @@ -use psyche_coordinator::CoordinatorConfig; -use psyche_coordinator::RunState; -use psyche_coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; -use psyche_coordinator::WitnessProof; +use psyche_coordinator::coordinator::CoordinatorConfig; +use psyche_coordinator::coordinator::RunState; +use psyche_coordinator::coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; use psyche_coordinator::model::CheckpointSource; -use psyche_coordinator::model::LLM; use psyche_coordinator::model::Model; -use psyche_coordinator::model_extra_data::CheckpointData; -use psyche_core::NodeIdentity; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_coordinator::types::WitnessProof; +use psyche_core::CheckpointData; use psyche_solana_authorizer::logic::AuthorizationGrantorUpdateParams; use psyche_solana_coordinator::CoordinatorAccount; use psyche_solana_coordinator::instruction::Witness; @@ -103,12 +102,12 @@ pub async fn run() { total_steps: 100, waiting_for_members_extra_time: 3, }), - Some(Model::LLM(LLM { + Some(Model { checkpoint_source: CheckpointSource::Stored, checkpoint_data: CheckpointData::Dummy.to_fixed_vec(), max_seq_len: 4096, cold_start_warmup_steps: 0, - })), + }), None, // no explicit progress ) .await diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs index d416660de..82950b0ee 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs @@ -1,12 +1,11 @@ -use psyche_coordinator::CommitteeSelection; -use psyche_coordinator::CoordinatorConfig; -use psyche_coordinator::SOLANA_MAX_NUM_WITNESSES; -use psyche_coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; +use psyche_coordinator::committee_selection::CommitteeSelection; +use psyche_coordinator::coordinator::CoordinatorConfig; +use psyche_coordinator::coordinator::SOLANA_MAX_NUM_WITNESSES; +use psyche_coordinator::coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; use psyche_coordinator::model::CheckpointSource; -use psyche_coordinator::model::LLM; use psyche_coordinator::model::Model; -use psyche_coordinator::model_extra_data::CheckpointData; -use psyche_core::NodeIdentity; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_core::CheckpointData; use psyche_solana_authorizer::logic::AuthorizationGrantorUpdateParams; use psyche_solana_coordinator::CoordinatorAccount; use psyche_solana_coordinator::instruction::Witness; @@ -100,12 +99,12 @@ pub async fn run() { as u8, total_steps: 100, }), - Some(Model::LLM(LLM { + Some(Model { checkpoint_source: CheckpointSource::Stored, checkpoint_data: CheckpointData::Dummy.to_fixed_vec(), max_seq_len: 4096, cold_start_warmup_steps: 0, - })), + }), None, // no explicit progress ) .await diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs index 0e8dff005..ec4500457 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs @@ -1,9 +1,8 @@ -use psyche_coordinator::CoordinatorConfig; -use psyche_coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; +use psyche_coordinator::coordinator::CoordinatorConfig; +use psyche_coordinator::coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; use psyche_coordinator::model::CheckpointSource; -use psyche_coordinator::model::LLM; use psyche_coordinator::model::Model; -use psyche_coordinator::model_extra_data::CheckpointData; +use psyche_core::CheckpointData; use psyche_solana_coordinator::CoordinatorAccount; use psyche_solana_tooling::create_memnet_endpoint::create_memnet_endpoint; use psyche_solana_tooling::process_treasurer_instructions::process_treasurer_run_create; @@ -47,12 +46,12 @@ pub async fn run() { waiting_for_members_extra_time: WAITING_FOR_MEMBERS_EXTRA_SECONDS as u8, }), - model: Some(Model::LLM(LLM { + model: Some(Model { checkpoint_source: CheckpointSource::Stored, checkpoint_data: CheckpointData::Dummy.to_fixed_vec(), max_seq_len: 4096, cold_start_warmup_steps: 0, - })), + }), progress: None, epoch_earning_rate_total_shared: Some(66), epoch_slashing_rate_per_client: None, diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs index 148a12add..239f0e1a7 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs @@ -1,14 +1,13 @@ use std::vec; -use psyche_coordinator::CommitteeSelection; -use psyche_coordinator::CoordinatorConfig; -use psyche_coordinator::SOLANA_MAX_NUM_WITNESSES; -use psyche_coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; +use psyche_coordinator::committee_selection::CommitteeSelection; +use psyche_coordinator::coordinator::CoordinatorConfig; +use psyche_coordinator::coordinator::SOLANA_MAX_NUM_WITNESSES; +use psyche_coordinator::coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; use psyche_coordinator::model::CheckpointSource; -use psyche_coordinator::model::LLM; use psyche_coordinator::model::Model; -use psyche_coordinator::model_extra_data::CheckpointData; -use psyche_core::NodeIdentity; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_core::CheckpointData; use psyche_solana_authorizer::logic::AuthorizationGranteeUpdateParams; use psyche_solana_authorizer::logic::AuthorizationGrantorUpdateParams; use psyche_solana_coordinator::CoordinatorAccount; @@ -222,12 +221,12 @@ pub async fn run() { total_steps: 100, waiting_for_members_extra_time: 3, }), - model: Some(Model::LLM(LLM { + model: Some(Model { checkpoint_source: CheckpointSource::Stored, checkpoint_data: CheckpointData::Dummy.to_fixed_vec(), max_seq_len: 4096, cold_start_warmup_steps: 0, - })), + }), progress: None, epoch_earning_rate_total_shared: Some( earned_point_per_epoch_total_shared, diff --git a/architectures/decentralized/solana-treasurer/Cargo.lock b/architectures/decentralized/solana-treasurer/Cargo.lock index 58d9afc91..56d1da879 100644 --- a/architectures/decentralized/solana-treasurer/Cargo.lock +++ b/architectures/decentralized/solana-treasurer/Cargo.lock @@ -239,21 +239,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - [[package]] name = "anyhow" version = "1.0.94" @@ -389,17 +374,6 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" -[[package]] -name = "async-trait" -version = "0.1.83" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.90", -] - [[package]] name = "atomic-polyfill" version = "1.0.3" @@ -662,11 +636,7 @@ version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ - "android-tzdata", - "iana-time-zone", "num-traits", - "serde", - "windows-targets", ] [[package]] @@ -711,12 +681,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" -[[package]] -name = "core-foundation-sys" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" - [[package]] name = "cpufeatures" version = "0.2.16" @@ -851,16 +815,6 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" -[[package]] -name = "deranged" -version = "0.3.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" -dependencies = [ - "powerfmt", - "serde", -] - [[package]] name = "derivation-path" version = "0.2.0" @@ -1057,6 +1011,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "hash32" version = "0.2.1" @@ -1066,12 +1032,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - [[package]] name = "hashbrown" version = "0.13.2" @@ -1110,12 +1070,6 @@ dependencies = [ "unicode-segmentation", ] -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - [[package]] name = "hmac" version = "0.8.1" @@ -1146,29 +1100,6 @@ dependencies = [ "hmac 0.8.1", ] -[[package]] -name = "iana-time-zone" -version = "0.1.61" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "wasm-bindgen", - "windows-core", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", -] - [[package]] name = "ident_case" version = "1.0.1" @@ -1181,17 +1112,6 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9007da9cacbd3e6343da136e98b0d2df013f553d35bdec8b518f07bea768e19c" -[[package]] -name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", - "serde", -] - [[package]] name = "indexmap" version = "2.7.0" @@ -1200,7 +1120,6 @@ checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", "hashbrown 0.15.2", - "serde", ] [[package]] @@ -1377,12 +1296,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-conv" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" - [[package]] name = "num-derive" version = "0.4.2" @@ -1558,12 +1471,6 @@ dependencies = [ "serde", ] -[[package]] -name = "powerfmt" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" - [[package]] name = "ppv-lite86" version = "0.2.20" @@ -1605,13 +1512,13 @@ name = "psyche-coordinator" version = "0.2.0" dependencies = [ "anchor-lang", - "anyhow", - "async-trait", + "anchor-lang-idl", + "bitvec", "bytemuck", - "cfg_eval", - "psyche-core", + "data-encoding", + "fnv", + "rand 0.9.2", "serde", - "serde_with", "ts-rs", ] @@ -1622,12 +1529,10 @@ dependencies = [ "anchor-lang", "anchor-lang-idl", "anyhow", - "bitvec", "bytemuck", - "data-encoding", "fast-math", - "fnv", "postcard", + "psyche-coordinator", "serde", "serde_arrays", "sha2 0.10.8", @@ -1687,6 +1592,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "radium" version = "0.7.0" @@ -1717,6 +1628,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", +] + [[package]] name = "rand_chacha" version = "0.2.2" @@ -1737,6 +1658,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", +] + [[package]] name = "rand_core" version = "0.5.1" @@ -1755,6 +1686,15 @@ dependencies = [ "getrandom 0.2.15", ] +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "rand_hc" version = "0.2.0" @@ -1894,16 +1834,9 @@ version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ - "base64 0.22.1", - "chrono", - "hex", - "indexmap 1.9.3", - "indexmap 2.7.0", "serde", "serde_derive", - "serde_json", "serde_with_macros", - "time", ] [[package]] @@ -3091,37 +3024,6 @@ dependencies = [ "syn 2.0.90", ] -[[package]] -name = "time" -version = "0.3.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" -dependencies = [ - "deranged", - "itoa", - "num-conv", - "powerfmt", - "serde", - "time-core", - "time-macros", -] - -[[package]] -name = "time-core" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" - -[[package]] -name = "time-macros" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" -dependencies = [ - "num-conv", - "time-core", -] - [[package]] name = "tinyvec" version = "1.8.0" @@ -3173,7 +3075,7 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap 2.7.0", + "indexmap", "serde", "serde_spanned", "toml_datetime", @@ -3183,7 +3085,7 @@ dependencies = [ [[package]] name = "ts-rs" version = "10.1.0" -source = "git+https://github.com/arilotter/ts-rs.git?rev=92ce1752227fec9bb868ad8f25b26f110a795099#92ce1752227fec9bb868ad8f25b26f110a795099" +source = "git+https://github.com/arilotter/ts-rs.git?rev=20c5bef07f3aa8ca3aa1b5cdad11ecdd60c64d34#20c5bef07f3aa8ca3aa1b5cdad11ecdd60c64d34" dependencies = [ "anchor-lang", "bytemuck", @@ -3196,7 +3098,7 @@ dependencies = [ [[package]] name = "ts-rs-macros" version = "10.1.0" -source = "git+https://github.com/arilotter/ts-rs.git?rev=92ce1752227fec9bb868ad8f25b26f110a795099#92ce1752227fec9bb868ad8f25b26f110a795099" +source = "git+https://github.com/arilotter/ts-rs.git?rev=20c5bef07f3aa8ca3aa1b5cdad11ecdd60c64d34#20c5bef07f3aa8ca3aa1b5cdad11ecdd60c64d34" dependencies = [ "proc-macro2", "quote", @@ -3266,6 +3168,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-bindgen" version = "0.2.97" @@ -3340,15 +3251,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "windows-core" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" -dependencies = [ - "windows-targets", -] - [[package]] name = "windows-sys" version = "0.59.0" @@ -3431,6 +3333,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + [[package]] name = "wyz" version = "0.5.1" diff --git a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_create.rs b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_create.rs index 790bd90b7..3cf4577ce 100644 --- a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_create.rs +++ b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_create.rs @@ -3,7 +3,7 @@ use anchor_spl::associated_token::AssociatedToken; use anchor_spl::token::Mint; use anchor_spl::token::Token; use anchor_spl::token::TokenAccount; -use psyche_coordinator::SOLANA_RUN_ID_MAX_LEN; +use psyche_coordinator::coordinator::SOLANA_RUN_ID_MAX_LEN; use psyche_solana_coordinator::cpi::accounts::InitCoordinatorAccounts; use psyche_solana_coordinator::cpi::init_coordinator; use psyche_solana_coordinator::logic::InitCoordinatorParams; @@ -119,7 +119,7 @@ pub fn run_create_processor( InitCoordinatorParams { main_authority: context.accounts.run.key(), join_authority: params.join_authority, - run_id: params.run_id.clone(), + run_id: params.run_id, client_version: params.client_version.clone(), }, )?; diff --git a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_update.rs b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_update.rs index 99890c9e8..ea18798f6 100644 --- a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_update.rs +++ b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_update.rs @@ -1,6 +1,6 @@ use anchor_lang::prelude::*; -use psyche_coordinator::CoordinatorConfig; -use psyche_coordinator::CoordinatorProgress; +use psyche_coordinator::coordinator::CoordinatorConfig; +use psyche_coordinator::coordinator::CoordinatorProgress; use psyche_coordinator::model::Model; use psyche_solana_coordinator::CoordinatorAccount; use psyche_solana_coordinator::CoordinatorInstance; diff --git a/architectures/decentralized/testing/src/utils.rs b/architectures/decentralized/testing/src/utils.rs index 59dd46f12..a07cb02f2 100644 --- a/architectures/decentralized/testing/src/utils.rs +++ b/architectures/decentralized/testing/src/utils.rs @@ -8,10 +8,10 @@ use anchor_client::{ }, }; use psyche_coordinator::{ - NUM_STORED_ROUNDS, Round, RunState, - model::{CheckpointSource, Model}, + coordinator::{Client, NUM_STORED_ROUNDS, Round, RunState, SOLANA_MAX_NUM_CLIENTS}, + fixed_vec::FixedVec, + model::CheckpointSource, }; -use psyche_core::FixedVec; use psyche_solana_coordinator::SOLANA_MAX_NUM_PENDING_CLIENTS; use std::env; use std::path::PathBuf; @@ -46,7 +46,7 @@ impl SolanaTestClient { let program = client.program(psyche_solana_coordinator::ID).unwrap(); let seeds = &[ psyche_solana_coordinator::CoordinatorInstance::SEEDS_PREFIX, - psyche_solana_coordinator::bytes_from_string(&run_id), + psyche_solana_coordinator::bytes_from_str(&run_id), ]; let (instance, _) = Pubkey::find_program_address(seeds, &program.id()); let coordinator_instance: psyche_solana_coordinator::CoordinatorInstance = @@ -82,9 +82,7 @@ impl SolanaTestClient { pub async fn get_checkpoint(&self) -> CheckpointSource { let coordinator = self.get_coordinator_account().await; - match coordinator.state.coordinator.model { - Model::LLM(llm) => llm.checkpoint_source, - } + coordinator.state.coordinator.model.checkpoint_source } pub async fn get_clients( @@ -94,9 +92,7 @@ impl SolanaTestClient { coordinator.state.clients_state.clients } - pub async fn get_current_epoch_clients( - &self, - ) -> FixedVec { + pub async fn get_current_epoch_clients(&self) -> FixedVec { let coordinator = self.get_coordinator_account().await; coordinator.state.coordinator.epoch_state.clients } diff --git a/architectures/decentralized/testing/tests/integration_tests.rs b/architectures/decentralized/testing/tests/integration_tests.rs index f691daab2..167b71294 100644 --- a/architectures/decentralized/testing/tests/integration_tests.rs +++ b/architectures/decentralized/testing/tests/integration_tests.rs @@ -10,7 +10,8 @@ use std::{path::PathBuf, sync::Arc, time::Duration}; use anchor_client::solana_sdk::signature::{Keypair, Signer}; use bollard::container::StartContainerOptions; use bollard::{Docker, container::KillContainerOptions}; -use psyche_coordinator::{RunState, model::CheckpointSource}; +use psyche_coordinator::coordinator::RunState; +use psyche_coordinator::model::CheckpointSource; use psyche_core::IntegrationTestLogMarker; use psyche_decentralized_testing::docker_setup::e2e_testing_setup_rpc_fallback; use psyche_decentralized_testing::{ diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index e7b4237f3..6e6aa2b0e 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -2,7 +2,7 @@ use crate::{CheckpointConfig, WandBInfo}; use anyhow::{Context, Result, anyhow, bail}; use clap::Args; -use psyche_coordinator::model_extra_data::ModelExtraData; +use psyche_core::ModelExtraData; use psyche_eval::tasktype_from_name; use psyche_modeling::Devices; use psyche_network::{DiscoveryMode, RelayKind, SecretKey}; diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index 5ab70917c..15873de6d 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -5,8 +5,11 @@ use crate::{ }; use anyhow::anyhow; use anyhow::{Error, Result, bail}; -use psyche_coordinator::{Commitment, CommitteeSelection, Coordinator, RunState}; -use psyche_core::IntegrationTestLogMarker; +use psyche_coordinator::{ + committee_selection::CommitteeSelection, + coordinator::{Coordinator, RunState}, +}; +use psyche_core::{Commitment, IntegrationTestLogMarker}; use psyche_event_sourcing::event; use psyche_metrics::{ClientMetrics, ClientRoleInRound, PeerConnection}; diff --git a/shared/client/src/fetch_data.rs b/shared/client/src/fetch_data.rs index dd79d2d5b..4652b334a 100644 --- a/shared/client/src/fetch_data.rs +++ b/shared/client/src/fetch_data.rs @@ -1,5 +1,5 @@ -use psyche_coordinator::{Coordinator, get_batch_ids_for_node}; -use psyche_core::{BatchId, NodeIdentity}; +use psyche_coordinator::{coordinator::Coordinator, node_identity::NodeIdentity}; +use psyche_core::{BatchId, get_batch_ids_for_node}; use psyche_data_provider::{DataProvider, TokenizedDataProvider}; use psyche_event_sourcing::event; use psyche_modeling::{Batch, BatchData, BatchDataCPU}; diff --git a/shared/client/src/lib.rs b/shared/client/src/lib.rs index e10aa62b8..b1e6bed99 100644 --- a/shared/client/src/lib.rs +++ b/shared/client/src/lib.rs @@ -9,7 +9,7 @@ pub use cli::{TrainArgs, prepare_environment, print_identity_keys, read_identity pub use client::Client; pub use protocol::{Broadcast, BroadcastType, Finished, NC, TrainingResult}; pub use state::{ - CheckpointConfig, CheckpointUploader, InitRunError, ModelExtraData, RoundState, RunInitConfig, + CheckpointConfig, CheckpointUploader, InitRunError, RoundState, RunInitConfig, RunInitConfigAndIO, }; pub use tui::{ClientTUI, ClientTUIState}; diff --git a/shared/client/src/protocol.rs b/shared/client/src/protocol.rs index 539775202..26f436f11 100644 --- a/shared/client/src/protocol.rs +++ b/shared/client/src/protocol.rs @@ -1,5 +1,5 @@ -use psyche_coordinator::{Commitment, CommitteeProof}; -use psyche_core::{BatchId, MerkleRoot}; +use psyche_coordinator::{hash_wrapper::HashWrapper, types::CommitteeProof}; +use psyche_core::{BatchId, Commitment}; use psyche_network::{BlobTicket, NetworkConnection, TransmittableDownload}; use serde::{Deserialize, Serialize}; @@ -13,7 +13,7 @@ pub struct TrainingResult { #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Finished { - pub broadcast_merkle: MerkleRoot, + pub broadcast_merkle: HashWrapper, pub warmup: bool, } diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index ab39708d0..277763190 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -1,7 +1,13 @@ +use super::{ + CheckpointConfig, + evals::{ModelTaskRunner, RunningEvals}, +}; use crate::CheckpointUploader; use psyche_coordinator::{ - CheckpointerSelection, Coordinator, model::Model, model_extra_data::CheckpointData, + checkpointer_selection::CheckpointerSelection, + coordinator::{Coordinator, CoordinatorError}, }; +use psyche_core::CheckpointData; use psyche_data_provider::{GcsManifestMetadata, UploadError, upload_to_gcs, upload_to_hub}; use psyche_event_sourcing::event; #[cfg(feature = "python")] @@ -27,11 +33,6 @@ use tokio::{ use tracing::error; use tracing::{Instrument, info, info_span, warn}; -use super::{ - CheckpointConfig, - evals::{ModelTaskRunner, RunningEvals}, -}; - #[derive(Error, Debug)] pub enum CooldownError { #[error("no trainers available for checkpointing!")] @@ -44,7 +45,7 @@ pub enum CooldownError { Checkpoint(#[from] CheckpointError), #[error("error in cooldown step: {0}")] - CoordinatorError(#[from] psyche_coordinator::CoordinatorError), + CoordinatorError(#[from] CoordinatorError), } pub struct CooldownStepMetadata { @@ -144,8 +145,8 @@ impl CooldownStepMetadata { let epoch = state.progress.epoch as u32; let checkpoint_extra_files = self.checkpoint_extra_files.clone(); let checkpoint_info = self.checkpoint_info.clone(); - let Model::LLM(ref llm) = state.model; - let checkpoint_data = llm.decode_checkpoint(); + let llm = state.model; + let checkpoint_data = CheckpointData::from_fixed_vec(&llm.checkpoint_data).ok(); let tx_model = self.tx_model.clone(); let model_task_runner = self.model_task_runner.clone(); let delete_queue = self.delete_queue.clone(); diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index d0aa33bfb..104ac77a9 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -1,12 +1,13 @@ use crate::{WandBInfo, fetch_data::DataFetcher}; -pub use psyche_coordinator::model_extra_data::ModelExtraData; use psyche_coordinator::{ - Coordinator, HealthChecks, - model::{self, HttpLLMTrainingDataLocation, LLMTrainingDataLocation}, - model_extra_data::{CONFIG_PREFIX, CheckpointData, MODEL_CONFIG_FILENAME}, + coordinator::{Coordinator, HealthChecks}, + model::CheckpointSource, + node_identity::NodeIdentity, }; use psyche_core::{ - Barrier, CancellableBarrier, IntegrationTestLogMarker, NodeIdentity, Shuffle, TokenSize, + Barrier, CONFIG_PREFIX, CancellableBarrier, CheckpointData, HttpLLMTrainingDataLocation, + IntegrationTestLogMarker, LLMArchitecture, LLMTrainingDataLocation, LLMTrainingDataType, + MODEL_CONFIG_FILENAME, ModelExtraData, Shuffle, TokenSize, }; use psyche_data_provider::{ DataProvider, DataProviderTcpClient, DownloadError, DummyDataProvider, @@ -205,9 +206,9 @@ impl RunInitConfigAndIO { )); } - let model::Model::LLM(llm) = state.model; + let llm = state.model; - if matches!(llm.checkpoint_source, model::CheckpointSource::Ephemeral) { + if matches!(llm.checkpoint_source, CheckpointSource::Ephemeral) { return Err(InitRunError::ModelIsEphemeral); } @@ -218,7 +219,7 @@ impl RunInitConfigAndIO { info!("Using model extra data override from CLI"); config } else { - match llm.decode_checkpoint() { + match CheckpointData::from_fixed_vec(&llm.checkpoint_data).ok() { Some(CheckpointData::Gcs { bucket, .. }) => { let path = format!("{}/{}", CONFIG_PREFIX, MODEL_CONFIG_FILENAME); debug!("Fetching model extra data from gs://{}/{}", bucket, path); @@ -324,13 +325,13 @@ impl RunInitConfigAndIO { let model_future: JoinHandle> = match &model_extra_data .architecture { - model::LLMArchitecture::HfLlama - | model::LLMArchitecture::HfDeepseek - | model::LLMArchitecture::HfAuto - | model::LLMArchitecture::Torchtitan => { - let checkpoint_data = llm.decode_checkpoint(); + LLMArchitecture::HfLlama + | LLMArchitecture::HfDeepseek + | LLMArchitecture::HfAuto + | LLMArchitecture::Torchtitan => { + let checkpoint_data = CheckpointData::from_fixed_vec(&llm.checkpoint_data).ok(); let is_dummy = matches!(checkpoint_data, Some(CheckpointData::Dummy)); - let is_p2p = matches!(llm.checkpoint_source, model::CheckpointSource::P2P); + let is_p2p = matches!(llm.checkpoint_source, CheckpointSource::P2P); if is_dummy { tokio::spawn(async move { @@ -394,14 +395,13 @@ impl RunInitConfigAndIO { debug!("Got p2p info, model_config: {}", model_config); let model_config = match model_extra_data.architecture { - model::LLMArchitecture::HfLlama => { + LLMArchitecture::HfLlama => { AutoConfig::Llama(serde_json::from_str(&model_config)?) } - model::LLMArchitecture::HfDeepseek => { + LLMArchitecture::HfDeepseek => { AutoConfig::Deepseek(serde_json::from_str(&model_config)?) } - model::LLMArchitecture::HfAuto - | model::LLMArchitecture::Torchtitan => { + LLMArchitecture::HfAuto | LLMArchitecture::Torchtitan => { #[cfg(feature = "python")] { AutoConfig::Auto(serde_json::from_str::< @@ -567,7 +567,7 @@ impl RunInitConfigAndIO { init_config.eval_task_max_docs, // if doing python fsdp we only have one effective dp rank for inference if init_config.data_parallelism > 1 - && model_extra_data.architecture == model::LLMArchitecture::HfAuto + && model_extra_data.architecture == LLMArchitecture::HfAuto { 1 } else { @@ -578,7 +578,7 @@ impl RunInitConfigAndIO { let serialized_config = source.serialize_config()?; let attn_implementation: Option = match model_extra_data.data_type { - model::LLMTrainingDataType::Finetuning => { + LLMTrainingDataType::Finetuning => { #[cfg(feature = "parallelism")] { // use varlen backend if available @@ -588,13 +588,13 @@ impl RunInitConfigAndIO { #[cfg(not(feature = "parallelism"))] None } - model::LLMTrainingDataType::Pretraining => None, + LLMTrainingDataType::Pretraining => None, }; let raw_loaded_model_type: RawLoadedModelType = match model_extra_data .architecture { - model::LLMArchitecture::HfAuto | model::LLMArchitecture::Torchtitan => { + LLMArchitecture::HfAuto | LLMArchitecture::Torchtitan => { #[cfg(feature = "python")] { let dp = init_config.data_parallelism; @@ -683,7 +683,7 @@ impl RunInitConfigAndIO { ModelLoadError::NoDeviceForRank(rank, devices) })?; match architecture { - model::LLMArchitecture::HfLlama => { + LLMArchitecture::HfLlama => { LlamaForCausalLM::from_pretrained( &source.try_into()?, Some(Kind::BFloat16), @@ -694,7 +694,7 @@ impl RunInitConfigAndIO { ) .map(|x| Box::new(x) as Box) } - model::LLMArchitecture::HfDeepseek => { + LLMArchitecture::HfDeepseek => { DeepseekForCausalLM::from_pretrained( &source.try_into()?, Some(Kind::BFloat16), @@ -705,8 +705,8 @@ impl RunInitConfigAndIO { ) .map(|x| Box::new(x) as Box) } - model::LLMArchitecture::HfAuto - | model::LLMArchitecture::Torchtitan => { + LLMArchitecture::HfAuto + | LLMArchitecture::Torchtitan => { unreachable!() } } diff --git a/shared/client/src/state/mod.rs b/shared/client/src/state/mod.rs index a7d773da4..c63aa1d19 100644 --- a/shared/client/src/state/mod.rs +++ b/shared/client/src/state/mod.rs @@ -13,7 +13,7 @@ mod train; mod warmup; mod witness; -pub use init::{InitRunError, ModelExtraData, RunInitConfig, RunInitConfigAndIO}; +pub use init::{InitRunError, RunInitConfig, RunInitConfigAndIO}; pub use round_state::RoundState; pub use steps::{ApplyMessageOutcome, RunManager}; pub use types::{ diff --git a/shared/client/src/state/prompt.rs b/shared/client/src/state/prompt.rs index 0cb29fe08..833a30910 100644 --- a/shared/client/src/state/prompt.rs +++ b/shared/client/src/state/prompt.rs @@ -1,6 +1,6 @@ use crate::state::prompt_texts::get_prompt_texts; -use psyche_coordinator::MAX_TOKENS_TO_SEND; -use psyche_core::FixedVec; +use psyche_coordinator::coordinator::MAX_TOKENS_TO_SEND; +use psyche_coordinator::fixed_vec::FixedVec; use psyche_modeling::{CausalLM, EosToks}; use psyche_modeling::{LogitsProcessor, Sampling, Trainer}; use std::sync::{Mutex, RwLock}; diff --git a/shared/client/src/state/round_state.rs b/shared/client/src/state/round_state.rs index 1a1bd1ce1..942339c5d 100644 --- a/shared/client/src/state/round_state.rs +++ b/shared/client/src/state/round_state.rs @@ -1,17 +1,19 @@ use crate::{Finished, TrainingResult, fetch_data::BatchIdSet}; +use super::types::PayloadState; use psyche_coordinator::{ - Commitment, CommitteeProof, CommitteeSelection, WitnessBloom, WitnessProof, + committee_selection::CommitteeSelection, + coordinator::WitnessBloom, + node_identity::NodeIdentity, + types::{CommitteeProof, WitnessProof}, }; -use psyche_core::{BatchId, NodeIdentity}; +use psyche_core::{BatchId, Commitment}; use psyche_modeling::DistroResult; use std::{ collections::{BTreeMap, HashMap}, sync::{Arc, Mutex}, }; -use super::types::PayloadState; - pub struct RoundState { pub height: u32, pub step: u32, diff --git a/shared/client/src/state/stats.rs b/shared/client/src/state/stats.rs index 8e1c4f83e..36bfabeb9 100644 --- a/shared/client/src/state/stats.rs +++ b/shared/client/src/state/stats.rs @@ -1,7 +1,6 @@ -use psyche_coordinator::{ - Coordinator, MAX_TOKENS_TO_SEND, WitnessEvalResult, WitnessMetadata, model, -}; -use psyche_core::{BoundedQueue, FixedVec, LearningRateSchedule}; +use crate::state::evals::{EnumModelTask, PROMPT_TASK_NAME}; +use psyche_coordinator::coordinator::Coordinator; +use psyche_core::{BoundedQueue, LearningRateSchedule, WitnessEvalResult, WitnessMetadata}; use psyche_metrics::ClientMetrics; use psyche_modeling::Trainer; use psyche_network::P2PEndpointInfo; @@ -10,8 +9,6 @@ use tokenizers::Tokenizer; use tracing::{debug, trace, warn}; use wandb::{DataValue, LogData}; -use crate::state::evals::{EnumModelTask, PROMPT_TASK_NAME}; - use super::evals::ModelTaskRunner; pub struct StatsLogger { @@ -178,22 +175,19 @@ impl StatsLogger { pub fn get_witness_metadata(&self, state: &Coordinator) -> WitnessMetadata { let bandwidth_total: f64 = self.endpoint_info.iter().map(|v| v.bandwidth).sum(); - let evals = { - let mut evals: FixedVec = Default::default(); - for (key, val) in self.current_eval_results() { - let value = WitnessEvalResult::new_trunc_name(&key, no_nan(val as f32, 0.0)); - if evals.push(value).is_err() { - // fixedvec is full, that's ok! nothing we can do. - break; - } - } - evals - }; - + let evals = self + .current_eval_results() + .into_iter() + .filter(|(_, value)| !value.is_nan()) + .map(|(name, value)| WitnessEvalResult { + name, + value: value as f32, + }) + .collect(); let prompt_results = self.get_prompt_results(); let prompt_index = self.get_prompt_index(); - // NOTE: no NaNs allowed in borsh serialized data. + // NOTE: no NaNs allowed let tokens_per_sec = self.global_tokens_per_second(state); WitnessMetadata { step: state.progress.step, @@ -256,22 +250,20 @@ impl StatsLogger { pub fn global_tokens_per_second(&self, state: &Coordinator) -> f32 { match self.step_durations.is_empty() { true => 0., - false => match &state.model { - model::Model::LLM(_) => { - let tokens = state.get_target_global_batch_size(state.current_round()) as u32 - * state.get_sequence_length() - * self.step_durations.len() as u32; - let seconds = self - .step_durations - .iter() - .fold(0f32, |acc, ele| acc + ele.as_secs_f32()); - if seconds == 0.0 { - 0.0 - } else { - tokens as f32 / seconds - } + false => { + let tokens = state.get_target_global_batch_size(state.current_round()) as u32 + * state.get_sequence_length() + * self.step_durations.len() as u32; + let seconds = self + .step_durations + .iter() + .fold(0f32, |acc, ele| acc + ele.as_secs_f32()); + if seconds == 0.0 { + 0.0 + } else { + tokens as f32 / seconds } - }, + } } } @@ -312,13 +304,13 @@ impl StatsLogger { } // clear tokens_to_send buffer - pub fn get_prompt_results(&self) -> FixedVec { - let mut results = FixedVec::new(); + pub fn get_prompt_results(&self) -> Vec { + let mut results = Vec::new(); for eval_task in self.model_task_runner.tasks().iter().flatten() { if let EnumModelTask::PromptTask(prompt_task) = &eval_task.task { { let tokens = prompt_task.tokens_to_send.read().unwrap(); - results.extend(tokens.iter().cloned()).unwrap(); + results.extend(tokens.iter().cloned()); } if let Ok(decoded) = prompt_task .tokenizer @@ -360,9 +352,7 @@ fn total_tokens(state: &Coordinator) -> u64 { .current_round() .map(|y| y.data_index) .unwrap_or_default() - * match &state.model { - model::Model::LLM(llm) => llm.max_seq_len as u64, - } + * state.model.max_seq_len as u64 } fn perplexity(loss: f32) -> f32 { diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index 17af6b01a..ec6253514 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -3,10 +3,28 @@ use crate::{ state::{train::FinishedTrainers, types::DeserializeError}, }; +use super::{ + FinishedBroadcast, RunInitConfigAndIO, + cooldown::{CooldownError, CooldownStep, CooldownStepMetadata}, + evals::EvalError, + init::InitRunError, + round_state::RoundState, + stats::StatsLogger, + train::{TrainError, TrainingStep, TrainingStepMetadata}, + types::PayloadState, + warmup::{WarmupStep, WarmupStepMetadata}, + witness::{WitnessStep, WitnessStepMetadata, WitnessingError}, +}; use iroh_blobs::api::Tag; -use psyche_coordinator::CheckpointerSelection; -use psyche_coordinator::{Committee, Coordinator, RunState, Witness, WitnessProof}; -use psyche_core::{IntegrationTestLogMarker, MerkleRoot, MerkleTree, NodeIdentity, sha256}; +use psyche_coordinator::{ + checkpointer_selection::CheckpointerSelection, + coordinator::{Coordinator, RunState, Witness}, + hash_wrapper::HashWrapper, + node_identity::NodeIdentity, + sha::sha256, + types::{Committee, WitnessProof}, +}; +use psyche_core::{IntegrationTestLogMarker, MerkleTree}; use psyche_event_sourcing::event; use psyche_modeling::{DistroResult, Trainer}; use psyche_network::{BlobTicket, Hash, P2PEndpointInfo, TransmittableDistroResult}; @@ -24,19 +42,6 @@ use tokio::{ }; use tracing::{Instrument, debug, info, trace, trace_span, warn}; -use super::{ - FinishedBroadcast, RunInitConfigAndIO, - cooldown::{CooldownError, CooldownStep, CooldownStepMetadata}, - evals::EvalError, - init::InitRunError, - round_state::RoundState, - stats::StatsLogger, - train::{TrainError, TrainingStep, TrainingStepMetadata}, - types::PayloadState, - warmup::{WarmupStep, WarmupStepMetadata}, - witness::{WitnessStep, WitnessStepMetadata, WitnessingError}, -}; - pub struct StepStateMachine { identity: NodeIdentity, @@ -195,7 +200,7 @@ impl StepStateMachine { .map(|i| i as u64) } - fn get_merkle_root(&self, broadcasts: &[[u8; 32]]) -> MerkleRoot { + fn get_merkle_root(&self, broadcasts: &[[u8; 32]]) -> HashWrapper { MerkleTree::new(broadcasts) .get_root() .cloned() diff --git a/shared/client/src/state/train.rs b/shared/client/src/state/train.rs index 6208b2ff9..397c3f4f1 100644 --- a/shared/client/src/state/train.rs +++ b/shared/client/src/state/train.rs @@ -3,12 +3,23 @@ use crate::{ state::types::{DeserializeError, PayloadState}, }; +use super::{ + evals::{MaybeRunningEvals, ModelTaskRunner}, + round_state::RoundState, + types::DistroBroadcastAndPayload, +}; use futures::{StreamExt, future::try_join_all, stream::FuturesUnordered}; use psyche_coordinator::{ - BLOOM_FALSE_RATE, Commitment, CommitteeSelection, Coordinator, CoordinatorError, HealthChecks, - assign_data_for_state, get_batch_ids_for_node, get_batch_ids_for_round, model, + bloom::Bloom, + committee_selection::CommitteeSelection, + coordinator::{BLOOM_FALSE_RATE, Coordinator, CoordinatorError, HealthChecks}, + model::CheckpointSource, + node_identity::NodeIdentity, +}; +use psyche_core::{ + BatchId, Commitment, IntegrationTestLogMarker, assign_data_for_state, get_batch_ids_for_node, + get_batch_ids_for_round, select_consensus_commitment_by_witnesses, }; -use psyche_core::{BatchId, Bloom, IntegrationTestLogMarker, NodeIdentity}; use psyche_event_sourcing::event; use psyche_modeling::{ ApplyDistroResultError, Batch, BatchData, DistroResult, TrainOutput, Trainer, @@ -32,12 +43,6 @@ use tokio::{sync::mpsc, task::JoinHandle}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, trace, trace_span, warn}; -use super::{ - evals::{MaybeRunningEvals, ModelTaskRunner}, - round_state::RoundState, - types::DistroBroadcastAndPayload, -}; - #[derive(Debug)] pub struct FinishedTrainers { pub evals_or_trainers: MaybeRunningEvals, @@ -523,12 +528,10 @@ impl TrainingStepMetadata { .witnesses .len() as u16, ); - let (cold_start_warmup_steps, checkpoint_is_p2p) = match &state.model { - model::Model::LLM(llm) => ( - llm.cold_start_warmup_steps, - matches!(llm.checkpoint_source, model::CheckpointSource::P2P), - ), - }; + let (cold_start_warmup_steps, checkpoint_is_p2p) = ( + state.model.cold_start_warmup_steps, + matches!(state.model.checkpoint_source, CheckpointSource::P2P), + ); let warmup_lr_between = state.get_cold_start_warmup_bounds(); // coordinator has already advanced to the next round (unless we're in cooldown) but we haven't started ours yet. @@ -584,7 +587,7 @@ impl TrainingStepMetadata { } }; trace!("Commitments for batch {batch_id}: {batch_commitments:?}"); - let consensus = match Coordinator::select_consensus_commitment_by_witnesses( + let consensus = match select_consensus_commitment_by_witnesses( &batch_commitments .iter() .map(|x| x.1.0) diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 94ae80f60..a7400c666 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -1,8 +1,10 @@ use std::path::PathBuf; use google_cloud_storage::client::{Storage, StorageControl}; -use psyche_coordinator::CommitteeProof; -use psyche_core::{BatchId, MerkleRoot, NodeIdentity}; +use psyche_coordinator::{ + hash_wrapper::HashWrapper, node_identity::NodeIdentity, types::CommitteeProof, +}; +use psyche_core::BatchId; use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; use psyche_modeling::DistroResult; use psyche_network::{BlobTicket, TransmittableDistroResult}; @@ -119,7 +121,7 @@ pub struct DistroBroadcastAndPayload { pub struct FinishedBroadcast { pub step: u32, - pub merkle: MerkleRoot, + pub merkle: HashWrapper, pub commitment_data_hash: [u8; 32], pub proof: CommitteeProof, pub warmup: bool, diff --git a/shared/client/src/state/witness.rs b/shared/client/src/state/witness.rs index f5da4c21f..a6b47eae2 100644 --- a/shared/client/src/state/witness.rs +++ b/shared/client/src/state/witness.rs @@ -1,5 +1,13 @@ -use psyche_coordinator::{Coordinator, Witness, WitnessMetadata}; -use psyche_core::{MerkleRoot, MerkleTree, NodeIdentity}; +use super::{ + evals::{EvalError, MaybeRunningEvals, ModelTaskRunner, RunningEvals}, + round_state::RoundState, +}; +use psyche_coordinator::{ + coordinator::{Coordinator, Witness}, + hash_wrapper::HashWrapper, + node_identity::NodeIdentity, +}; +use psyche_core::{MerkleTree, WitnessMetadata}; use psyche_watcher::OpportunisticData; use thiserror::Error; use tokio::{ @@ -8,11 +16,6 @@ use tokio::{ }; use tracing::{info, trace}; -use super::{ - evals::{EvalError, MaybeRunningEvals, ModelTaskRunner, RunningEvals}, - round_state::RoundState, -}; - #[derive(Debug, Error)] pub enum WitnessingError { #[error("Failed to stop evals")] @@ -99,7 +102,7 @@ impl WitnessStep { } let merkle = MerkleTree::new(&previous_round.broadcasts); - let broadcast_merkle = merkle.get_root().cloned().unwrap_or(MerkleRoot::default()); + let broadcast_merkle = merkle.get_root().cloned().unwrap_or(HashWrapper::default()); let (participant_bloom, broadcast_bloom) = previous_round.blooms.lock().unwrap().unwrap_or_default(); diff --git a/shared/client/src/tui.rs b/shared/client/src/tui.rs index 62a74f969..e46638c36 100644 --- a/shared/client/src/tui.rs +++ b/shared/client/src/tui.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use psyche_coordinator::Committee; +use psyche_coordinator::types::Committee; use psyche_tui::ratatui::{ buffer::Buffer, layout::{Constraint, Layout, Rect}, diff --git a/shared/coordinator/Cargo.toml b/shared/coordinator/Cargo.toml index 1c427e92d..ae87ab349 100644 --- a/shared/coordinator/Cargo.toml +++ b/shared/coordinator/Cargo.toml @@ -1,16 +1,31 @@ [package] name = "psyche-coordinator" -version.workspace = true edition = "2021" +version.workspace = true [dependencies] -psyche-core.workspace = true -async-trait.workspace = true anchor-lang.workspace = true +anchor-lang-idl.workspace = true bytemuck.workspace = true -serde_with.workspace = true -anyhow.workspace = true -serde_json.workspace = true -serde.workspace = true -cfg_eval = "0.1.2" -ts-rs.workspace = true +fnv = "1.0.7" +bitvec = { git = "https://github.com/arilotter/bitvec", rev = "d33a2437f810ee4229457dfd1137d807914671f8" } + +# client feature +serde = { workspace = true, optional = true } +ts-rs = { workspace = true, optional = true } +rand = { workspace = true, optional = true } +data-encoding = { version = "2.8.0", optional = true } + +[dev-dependencies] +postcard.workspace = true + + +[features] +# non-required for on-chain. serialization, nice debug/display impls, etc. +client = [ + "bitvec/serde", + "dep:serde", + "dep:ts-rs", + "dep:rand", + "dep:data-encoding", +] diff --git a/shared/core/src/bloom.rs b/shared/coordinator/src/bloom.rs similarity index 89% rename from shared/core/src/bloom.rs rename to shared/coordinator/src/bloom.rs index 551c27237..45a0f1b32 100644 --- a/shared/core/src/bloom.rs +++ b/shared/coordinator/src/bloom.rs @@ -1,17 +1,9 @@ -use anchor_lang::prelude::{borsh::BorshSerialize, *}; -use anchor_lang_idl::{ - build::IdlBuild, - types::{ - IdlArrayLen, IdlDefinedFields, IdlField, IdlRepr, IdlReprModifier, IdlType, IdlTypeDef, - IdlTypeDefTy, - }, -}; +use anchor_lang::prelude::*; + use bitvec::array::BitArray; use bytemuck::Zeroable; use fnv::FnvHasher; -use serde::{Deserialize, Deserializer, Serialize}; -use std::{collections::BTreeMap, fmt, hash::Hasher}; -use ts_rs::TS; +use std::{fmt, hash::Hasher}; // Modified from https://github.com/solana-labs/solana/blob/27eff8408b7223bb3c4ab70523f8a8dca3ca6645/bloom/src/bloom.rs @@ -21,16 +13,23 @@ pub trait BloomHashIndex { fn hash_at_index(&self, hash_index: u64) -> u64; } -#[derive(Clone, PartialEq, Eq, Copy, Zeroable, TS)] +#[derive(Clone, PartialEq, Eq, Copy, Zeroable)] +#[cfg_attr(feature = "client", derive(ts_rs::TS))] #[repr(C)] pub struct Bloom { pub keys: [u64; K], pub bits: BitArrayWrapper, } -#[derive(Clone, PartialEq, Eq, Copy, Default, Serialize, Deserialize, TS)] +#[derive(Clone, PartialEq, Eq, Copy, Default)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) +)] #[repr(transparent)] -pub struct BitArrayWrapper(#[ts(type = "number[]")] pub BitArray<[u64; U]>); +pub struct BitArrayWrapper( + #[cfg_attr(feature = "client", ts(type = "number[]"))] pub BitArray<[u64; U]>, +); impl AnchorSerialize for BitArrayWrapper { fn serialize(&self, writer: &mut W) -> std::io::Result<()> { @@ -52,33 +51,6 @@ impl AnchorDeserialize for BitArrayWrapper { } } -impl IdlBuild for BitArrayWrapper { - fn create_type() -> Option { - Some(IdlTypeDef { - name: format!("BitArrayWrapper{U}").to_string(), - docs: vec!["A wrapper around BitArray for serialization".to_string()], - serialization: Default::default(), - repr: Some(IdlRepr::Transparent), - generics: vec![], - ty: IdlTypeDefTy::Struct { - fields: Some(IdlDefinedFields::Named(vec![IdlField { - name: "0".to_string(), - docs: vec!["The underlying bit array".to_string()], - ty: IdlType::Array(Box::new(IdlType::U64), IdlArrayLen::Value(U)), - }])), - }, - }) - } - - fn insert_types(_types: &mut BTreeMap) { - // no inner types in idl - } - - fn get_full_path() -> String { - format!("{}::BitArrayWrapper{}", module_path!(), U) - } -} - unsafe impl Zeroable for BitArrayWrapper {} impl BitArrayWrapper { @@ -96,56 +68,12 @@ impl Default for Bloom { } } -impl Serialize for Bloom { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut state = serializer.serialize_struct("Bloom", 3)?; - state.serialize_field("keys", &self.keys.to_vec())?; - state.serialize_field("bits", &self.bits)?; - state.end() - } -} - -impl<'de, const U: usize, const K: usize> Deserialize<'de> for Bloom { - fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - struct BloomHelper { - keys: Vec, - bits: BitArrayWrapper, - } - - let helper = BloomHelper::deserialize(deserializer)?; - - if helper.keys.len() != K { - return Err(serde::de::Error::custom(format!( - "Expected {} keys, got {}", - K, - helper.keys.len() - ))); - } - - let mut keys = [0u64; K]; - keys.copy_from_slice(&helper.keys); - - Ok(Bloom { - keys, - bits: helper.bits, - }) - } -} - impl AnchorSerialize for Bloom { fn serialize(&self, writer: &mut W) -> std::io::Result<()> { for key in &self.keys { AnchorSerialize::serialize(key, writer)?; } - BorshSerialize::serialize(&self.bits, writer) + borsh::BorshSerialize::serialize(&self.bits, writer) } } @@ -160,76 +88,6 @@ impl AnchorDeserialize for Bloom { } } -impl IdlBuild for Bloom { - fn create_type() -> Option { - Some(IdlTypeDef { - name: format!("Bloom{U}_{K}"), - docs: vec![ - "A Bloom filter implementation with configurable size and number of hash functions" - .to_string(), - ], - serialization: Default::default(), - repr: Some(IdlRepr::C(IdlReprModifier { - packed: false, - align: None, - })), - generics: vec![], - ty: IdlTypeDefTy::Struct { - fields: Some(IdlDefinedFields::Named(vec![ - IdlField { - name: "keys".to_string(), - docs: vec!["Hash function keys".to_string()], - ty: IdlType::Array(Box::new(IdlType::U64), IdlArrayLen::Value(K)), - }, - IdlField { - name: "bits".to_string(), - docs: vec!["Bit array for the Bloom filter".to_string()], - ty: IdlType::Defined { - name: format!("BitArrayWrapper{U}"), - generics: vec![], - }, - }, - ])), - }, - }) - } - - fn insert_types(types: &mut BTreeMap) { - if let Some(ty) = BitArrayWrapper::::create_type() { - types.insert(ty.name.clone(), ty); - } - } - - fn get_full_path() -> String { - format!("{}::Bloom{}_{}", module_path!(), U, K) - } -} - -impl fmt::Debug for Bloom { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Bloom {{ keys.len: {}, bits: ", self.keys.len())?; - const MAX_PRINT_BITS: usize = 10; - - if Self::max_bits() <= MAX_PRINT_BITS { - // Print individual bits for small filters - for i in 0..Self::max_bits() { - match self.bits.0.get(i) { - Some(x) => write!(f, "{}", *x as u8)?, - None => write!(f, "X")?, - } - } - } else { - // Print byte array for larger filters - let words = self.bits.0.as_raw_slice(); - for byte in words.iter() { - write!(f, "{byte:016x}")?; // full u64 output - } - } - - write!(f, " }}") - } -} - impl Bloom { pub const fn max_bits() -> usize { U * std::mem::size_of::() * 8 @@ -251,7 +109,7 @@ impl Bloom { /// `keysize` bytes. /// /// See . - #[cfg(feature = "rand")] + #[cfg(feature = "client")] pub fn random(num_items: usize, false_rate: f64) -> Self { use rand::Rng; let m = Self::num_bits(num_items as f64, false_rate); @@ -260,26 +118,13 @@ impl Bloom { Self::new(num_bits, &keys) } - #[cfg(feature = "rand")] + #[cfg(feature = "client")] fn num_bits(num_items: f64, false_rate: f64) -> f64 { let n = num_items; let p = false_rate; ((n * p.ln()) / (1f64 / 2f64.powf(2f64.ln())).ln()).ceil() } - #[cfg(feature = "rand")] - #[allow(dead_code)] - fn num_keys(num_bits: f64, num_items: f64) -> f64 { - let n = num_items; - let m = num_bits; - // infinity as usize is zero in rust 1.43 but 2^64-1 in rust 1.45; ensure it's zero here - if n == 0.0 { - 0.0 - } else { - 1f64.max(((m / n) * 2f64.ln()).round()) - } - } - fn pos(&self, key: &T, k: u64) -> u64 { key.hash_at_index(k) .checked_rem(self.bits.0.len() as u64) @@ -324,6 +169,157 @@ impl> BloomHashIndex for T { } } +impl fmt::Debug for Bloom { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Bloom {{ keys.len: {}, bits: ", self.keys.len())?; + const MAX_PRINT_BITS: usize = 10; + + if Self::max_bits() <= MAX_PRINT_BITS { + // Print individual bits for small filters + for i in 0..Self::max_bits() { + match self.bits.0.get(i) { + Some(x) => write!(f, "{}", *x as u8)?, + None => write!(f, "X")?, + } + } + } else { + // Print byte array for larger filters + let words = self.bits.0.as_raw_slice(); + for byte in words.iter() { + write!(f, "{byte:016x}")?; // full u64 output + } + } + + write!(f, " }}") + } +} + +#[cfg(feature = "client")] +use anchor_lang_idl::{ + build::IdlBuild, + types::{ + IdlArrayLen, IdlDefinedFields, IdlField, IdlRepr, IdlReprModifier, IdlType, IdlTypeDef, + IdlTypeDefTy, + }, +}; +#[cfg(feature = "client")] +impl IdlBuild for BitArrayWrapper { + fn create_type() -> Option { + Some(IdlTypeDef { + name: format!("BitArrayWrapper{U}").to_string(), + docs: vec!["A wrapper around BitArray for serialization".to_string()], + serialization: Default::default(), + repr: Some(IdlRepr::Transparent), + generics: vec![], + ty: IdlTypeDefTy::Struct { + fields: Some(IdlDefinedFields::Named(vec![IdlField { + name: "0".to_string(), + docs: vec!["The underlying bit array".to_string()], + ty: IdlType::Array(Box::new(IdlType::U64), IdlArrayLen::Value(U)), + }])), + }, + }) + } + + fn insert_types(_types: &mut std::collections::BTreeMap) { + // no inner types in idl + } + + fn get_full_path() -> String { + format!("{}::BitArrayWrapper{}", module_path!(), U) + } +} +#[cfg(feature = "client")] +impl serde::Serialize for Bloom { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("Bloom", 3)?; + state.serialize_field("keys", &self.keys.to_vec())?; + state.serialize_field("bits", &self.bits)?; + state.end() + } +} + +#[cfg(feature = "client")] +impl<'de, const U: usize, const K: usize> serde::Deserialize<'de> for Bloom { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + #[derive(serde::Deserialize)] + struct BloomHelper { + keys: Vec, + bits: BitArrayWrapper, + } + + let helper = BloomHelper::deserialize(deserializer)?; + + if helper.keys.len() != K { + return Err(serde::de::Error::custom(format!( + "Expected {} keys, got {}", + K, + helper.keys.len() + ))); + } + + let mut keys = [0u64; K]; + keys.copy_from_slice(&helper.keys); + + Ok(Bloom { + keys, + bits: helper.bits, + }) + } +} +#[cfg(feature = "client")] +impl IdlBuild for Bloom { + fn create_type() -> Option { + Some(IdlTypeDef { + name: format!("Bloom{U}_{K}"), + docs: vec![ + "A Bloom filter implementation with configurable size and number of hash functions" + .to_string(), + ], + serialization: Default::default(), + repr: Some(IdlRepr::C(IdlReprModifier { + packed: false, + align: None, + })), + generics: vec![], + ty: IdlTypeDefTy::Struct { + fields: Some(IdlDefinedFields::Named(vec![ + IdlField { + name: "keys".to_string(), + docs: vec!["Hash function keys".to_string()], + ty: IdlType::Array(Box::new(IdlType::U64), IdlArrayLen::Value(K)), + }, + IdlField { + name: "bits".to_string(), + docs: vec!["Bit array for the Bloom filter".to_string()], + ty: IdlType::Defined { + name: format!("BitArrayWrapper{U}"), + generics: vec![], + }, + }, + ])), + }, + }) + } + + fn insert_types(types: &mut std::collections::BTreeMap) { + if let Some(ty) = BitArrayWrapper::::create_type() { + types.insert(ty.name.clone(), ty); + } + } + + fn get_full_path() -> String { + format!("{}::Bloom{}_{}", module_path!(), U, K) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/shared/coordinator/src/checkpointer_selection.rs b/shared/coordinator/src/checkpointer_selection.rs index 64193d12a..fcee2c52d 100644 --- a/shared/coordinator/src/checkpointer_selection.rs +++ b/shared/coordinator/src/checkpointer_selection.rs @@ -1,7 +1,10 @@ use std::cmp::max; -use crate::{Coordinator, CoordinatorError, coordinator::SOLANA_MAX_NUM_CHECKPOINTERS}; -use psyche_core::{compute_shuffled_index, sha256, sha256v}; +use crate::{ + coordinator::{Coordinator, CoordinatorError, Round, SOLANA_MAX_NUM_CHECKPOINTERS}, + sha::{sha256, sha256v}, + swap_or_not::compute_shuffled_index, +}; use super::types::salts; @@ -46,7 +49,7 @@ impl CheckpointerSelection { pub(crate) fn get_round_by_offset( coordinator: &Coordinator, offset: isize, -) -> Result<&crate::Round, CoordinatorError> { +) -> Result<&Round, CoordinatorError> { match offset { -2 => coordinator.previous_previous_round(), -1 => coordinator.previous_round(), diff --git a/shared/coordinator/src/committee_selection.rs b/shared/coordinator/src/committee_selection.rs index ca1c350ee..5194cf694 100644 --- a/shared/coordinator/src/committee_selection.rs +++ b/shared/coordinator/src/committee_selection.rs @@ -1,5 +1,7 @@ -use crate::{Client, Coordinator, CoordinatorError, SOLANA_MAX_NUM_WITNESSES}; -use psyche_core::{NodeIdentity, compute_shuffled_index, sha256}; +use crate::coordinator::{Client, Coordinator, CoordinatorError, SOLANA_MAX_NUM_WITNESSES}; +use crate::node_identity::NodeIdentity; +use crate::sha::{sha256, sha256v}; +use crate::swap_or_not::compute_shuffled_index; use super::checkpointer_selection::get_round_by_offset; use super::types::{Committee, CommitteeProof, WitnessProof, salts}; @@ -147,7 +149,7 @@ impl CommitteeSelection { fn compute_shuffled_index(&self, index: u64, salt: &str) -> u64 { let mut seed = [0u8; 32]; - seed.copy_from_slice(&psyche_core::sha256v(&[&self.seed, salt.as_bytes()])); + seed.copy_from_slice(&sha256v(&[&self.seed, salt.as_bytes()])); compute_shuffled_index(index, self.total_nodes, &seed) } diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 2fcfe4c1c..930a9e184 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -1,17 +1,20 @@ -use crate::{ - CheckpointerSelection, Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, - model::{CheckpointBytes, CheckpointSource, Model}, -}; - -use anchor_lang::{ - AnchorDeserialize, AnchorSerialize, InitSpace, - prelude::{borsh, msg}, -}; +use std::collections::HashSet; +use std::hash::Hash; + +use anchor_lang::prelude::*; use bytemuck::{Pod, Zeroable}; -use psyche_core::{Bloom, FixedString, FixedVec, MerkleRoot, NodeIdentity, SmallBoolean, sha256}; -use serde::{Deserialize, Serialize}; -use std::{collections::HashSet, hash::Hash}; -use ts_rs::TS; + +use crate::bloom::Bloom; +use crate::checkpointer_selection::CheckpointerSelection; +use crate::committee_selection::CommitteeSelection; +use crate::fixed_string::FixedString; +use crate::fixed_vec::FixedVec; +use crate::hash_wrapper::HashWrapper; +use crate::model::{CheckpointBytes, CheckpointSource, Model}; +use crate::node_identity::NodeIdentity; +use crate::sha::sha256; +use crate::small_boolean::SmallBoolean; +use crate::types::{Committee, CommitteeProof, WitnessProof}; pub const SOLANA_MAX_STRING_LEN: usize = 64; pub const SOLANA_MAX_URL_STRING_LEN: usize = 192; @@ -20,7 +23,6 @@ pub const SOLANA_MAX_NUM_WITNESSES: usize = 32; pub const SOLANA_MAX_NUM_CHECKPOINTERS: usize = 16; // run_id must be at most 32 bytes because of PDA constraints pub const SOLANA_RUN_ID_MAX_LEN: usize = 32; - pub const BLOOM_FALSE_RATE: f64 = 0.01f64; pub const WITNESS_QUORUM_RAIO: f64 = 2.0f64 / 3.0f64; pub const WAITING_FOR_MEMBERS_EXTRA_SECONDS: u64 = 10; @@ -31,18 +33,11 @@ pub const MAX_TOKENS_TO_SEND: usize = 16; pub type WitnessBloom = Bloom<16, 8>; #[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, - InitSpace, - TS, + Clone, Copy, Default, PartialEq, Zeroable, AnchorDeserialize, AnchorSerialize, InitSpace, +)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(u8)] pub enum RunState { @@ -58,18 +53,11 @@ pub enum RunState { } #[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, - InitSpace, - TS, + Clone, Copy, Default, PartialEq, Zeroable, AnchorDeserialize, AnchorSerialize, InitSpace, +)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(u8)] pub enum ClientState { @@ -80,18 +68,11 @@ pub enum ClientState { Ejected = 3, } -#[derive( - Clone, - Debug, - Zeroable, - Default, - Copy, - Serialize, - Deserialize, - AnchorDeserialize, - AnchorSerialize, - TS, +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] +#[derive(Clone, Zeroable, Default, Copy, AnchorDeserialize, AnchorSerialize)] #[repr(C)] pub struct Client { pub id: NodeIdentity, @@ -116,19 +97,11 @@ impl Hash for Client { } } -#[derive( - Clone, - Default, - Debug, - Zeroable, - Copy, - Serialize, - Deserialize, - AnchorSerialize, - AnchorDeserialize, - PartialEq, - TS, +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] +#[derive(Clone, Default, Zeroable, Copy, AnchorSerialize, AnchorDeserialize, PartialEq)] #[repr(C)] pub struct Round { pub witnesses: FixedVec, @@ -140,76 +113,17 @@ pub struct Round { pub tie_breaker_tasks: u16, } -#[derive( - Clone, - Debug, - Zeroable, - Default, - Copy, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, - PartialEq, - TS, +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] +#[derive(Clone, Zeroable, Default, Copy, AnchorDeserialize, AnchorSerialize, PartialEq)] #[repr(C)] pub struct Witness { pub proof: WitnessProof, pub participant_bloom: WitnessBloom, pub broadcast_bloom: WitnessBloom, - pub broadcast_merkle: MerkleRoot, -} - -#[derive( - Clone, - Copy, - Zeroable, - AnchorSerialize, - AnchorDeserialize, - Serialize, - Deserialize, - TS, - Default, - Debug, -)] -#[repr(C)] -pub struct WitnessMetadata { - pub step: u32, - pub tokens_per_sec: f32, - pub bandwidth_per_sec: f32, - pub loss: f32, - pub evals: FixedVec, - pub prompt_results: FixedVec, - pub prompt_index: u8, - pub efficency: f32, -} - -#[derive( - Clone, - Copy, - Zeroable, - AnchorSerialize, - AnchorDeserialize, - Serialize, - Deserialize, - TS, - Default, - Debug, -)] -#[repr(C)] -pub struct WitnessEvalResult { - pub name: FixedString<32>, - pub value: f32, -} - -impl WitnessEvalResult { - pub fn new_trunc_name(name: &str, value: f32) -> Self { - Self { - name: FixedString::from_str_truncated(name), - value, - } - } + pub broadcast_merkle: HashWrapper, } #[derive(Clone, Copy, Debug)] @@ -237,8 +151,10 @@ pub type HealthChecks = Vec<(NodeIdentity, CommitteeProof)>; pub const NUM_STORED_ROUNDS: usize = 4; -#[derive( - Clone, Debug, Zeroable, Copy, Serialize, Deserialize, AnchorDeserialize, AnchorSerialize, TS, +#[derive(Clone, Zeroable, Copy, AnchorDeserialize, AnchorSerialize)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub struct CoordinatorConfig { @@ -263,8 +179,10 @@ pub struct CoordinatorConfig { pub waiting_for_members_extra_time: u8, } -#[derive( - Clone, Debug, Zeroable, Copy, Serialize, Deserialize, AnchorSerialize, AnchorDeserialize, TS, +#[derive(Clone, Zeroable, Copy, AnchorSerialize, AnchorDeserialize)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub struct CoordinatorEpochState { @@ -286,8 +204,10 @@ pub struct CoordinatorEpochState { pub checkpointed: bool, } -#[derive( - Clone, Debug, Zeroable, Copy, Serialize, Deserialize, AnchorSerialize, AnchorDeserialize, TS, +#[derive(Clone, Zeroable, Copy, AnchorSerialize, AnchorDeserialize)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub struct CoordinatorProgress { @@ -296,8 +216,10 @@ pub struct CoordinatorProgress { pub epoch_start_data_index: u64, } -#[derive( - Clone, Debug, Zeroable, Copy, Serialize, Deserialize, AnchorSerialize, AnchorDeserialize, TS, +#[derive(Clone, Copy, AnchorSerialize, AnchorDeserialize)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub struct Coordinator { @@ -309,20 +231,26 @@ pub struct Coordinator { pub config: CoordinatorConfig, - #[serde(default)] + #[cfg_attr(feature = "client", serde(default))] pub progress: CoordinatorProgress, - #[serde(default)] + #[cfg_attr(feature = "client", serde(default))] pub epoch_state: CoordinatorEpochState, // note, gets zeroed at the start of every epoch (not persistent through epochs) - #[serde(default)] + #[cfg_attr(feature = "client", serde(default))] pub run_state_start_unix_timestamp: u64, - #[serde(default)] + #[cfg_attr(feature = "client", serde(default))] pub pending_pause: SmallBoolean, } -unsafe impl Pod for Coordinator {} +unsafe impl Pod for Coordinator { + // BAD BAD BAD UNSAFE WRONG FIXME TODO XXX +} + +unsafe impl Zeroable for Coordinator { + // BAD BAD BAD UNSAFE WRONG FIXME TODO XXX +} impl TryFrom for RunState { type Error = CoordinatorError; @@ -342,20 +270,6 @@ impl TryFrom for RunState { } } -impl From for usize { - fn from(val: RunState) -> Self { - match val { - RunState::Uninitialized => 0, - RunState::WaitingForMembers => 1, - RunState::Warmup => 2, - RunState::RoundTrain => 3, - RunState::RoundWitness => 4, - RunState::Cooldown => 5, - RunState::Finished => 6, - RunState::Paused => 7, - } - } -} impl PartialEq for Client { fn eq(&self, other: &Self) -> bool { self.id == other.id @@ -646,10 +560,9 @@ impl Coordinator { return Err(CoordinatorError::InvalidWitness); } - let Model::LLM(llm) = &mut self.model; - match llm.checkpoint_source { + match self.model.checkpoint_source { CheckpointSource::Stored | CheckpointSource::P2P => { - llm.checkpoint_data = checkpoint_data; + self.model.checkpoint_data = checkpoint_data; } CheckpointSource::Ephemeral => {} } @@ -695,7 +608,7 @@ impl Coordinator { } } - pub fn resume(&mut self, unix_timestamp: u64) -> Result<(), CoordinatorError> { + pub fn resume(&mut self, unix_timestamp: u64) -> std::result::Result<(), CoordinatorError> { if self.run_state != RunState::Paused { return Err(CoordinatorError::CannotResume); } @@ -707,7 +620,7 @@ impl Coordinator { &self, id: &NodeIdentity, proof: &CommitteeProof, - ) -> Result { + ) -> std::result::Result { let round = self .previous_round() .ok_or(CoordinatorError::NoActiveRound)?; @@ -750,7 +663,10 @@ impl Coordinator { } } - pub fn trainer_healthy(&self, id: &NodeIdentity) -> Result { + pub fn trainer_healthy( + &self, + id: &NodeIdentity, + ) -> std::result::Result { let prev_round_witnesses = &self .previous_round() .ok_or(CoordinatorError::NoActiveRound)? @@ -775,28 +691,6 @@ impl Coordinator { score } - pub fn select_consensus_commitment_by_witnesses( - commitments: &[Commitment], - witnesses: &[Witness], - witness_quorum: u16, - ) -> Option { - let mut scores = vec![0; commitments.len()]; - for witness in witnesses { - for (index, commitment) in commitments.iter().enumerate() { - if witness.broadcast_bloom.contains(&commitment.data_hash) { - scores[index] += 1; - break; - } - } - } - scores - .into_iter() - .enumerate() - .filter(|(_, score)| *score >= witness_quorum) - .max_by_key(|(_, score)| *score) - .map(|(index, _)| index) - } - pub fn current_round(&self) -> Option<&Round> { self.epoch_state .rounds @@ -886,9 +780,7 @@ impl Coordinator { } pub fn get_sequence_length(&self) -> u32 { - match &self.model { - Model::LLM(llm) => llm.max_seq_len, - } + self.model.max_seq_len } pub fn get_target_global_batch_size(&self, round: Option<&Round>) -> u16 { @@ -906,8 +798,7 @@ impl Coordinator { } pub fn get_cold_start_warmup_bounds(&self) -> Option<(u32, u32)> { - let Model::LLM(llm) = &self.model; - let cold_start_warmup_steps = llm.cold_start_warmup_steps; + let cold_start_warmup_steps = self.model.cold_start_warmup_steps; if self.epoch_state.cold_start_epoch.is_false() || cold_start_warmup_steps == 0 { return None; } @@ -919,16 +810,15 @@ impl Coordinator { /// Check that cold_start_warmup_steps can be completed within a single epoch. pub fn check_cold_start_warmup_steps(&self) -> bool { - let Model::LLM(llm) = &self.model; - if llm.cold_start_warmup_steps == 0 { + if self.model.cold_start_warmup_steps == 0 { return true; } let training_time = self.config.epoch_time - self.config.warmup_time; let estimated_training_rounds = training_time / self.config.max_round_train_time; - if llm.cold_start_warmup_steps as u64 > estimated_training_rounds { + if self.model.cold_start_warmup_steps as u64 > estimated_training_rounds { msg!( "cold_start_warmup_steps ({}) exceeds estimated training rounds per epoch ((epoch_time={} - warmup_time={}) / max_round_train_time={} = {})", - llm.cold_start_warmup_steps, + self.model.cold_start_warmup_steps, self.config.epoch_time, self.config.warmup_time, self.config.max_round_train_time, @@ -979,11 +869,10 @@ impl Coordinator { .clients .iter() .any(|client| pending_clients_unordered.contains(&client.id)); - if all_prev_clients_disconnected { - let Model::LLM(llm) = &mut self.model; - if llm.checkpoint_source == CheckpointSource::P2P { - llm.checkpoint_source = CheckpointSource::Stored; - } + if all_prev_clients_disconnected + && self.model.checkpoint_source == CheckpointSource::P2P + { + self.model.checkpoint_source = CheckpointSource::Stored; } let cold_start_epoch = self.epoch_state.cold_start_epoch; @@ -1111,9 +1000,8 @@ impl Coordinator { self.move_clients_to_exited(height); // we've completed an epoch, switch to P2P from now on - let Model::LLM(llm) = &mut self.model; - if llm.checkpoint_source == CheckpointSource::Stored { - llm.checkpoint_source = CheckpointSource::P2P; + if self.model.checkpoint_source == CheckpointSource::Stored { + self.model.checkpoint_source = CheckpointSource::P2P; } if self.pending_pause.is_true() { @@ -1252,7 +1140,7 @@ impl CoordinatorConfig { } #[inline(always)] - pub fn check_error(&self) -> Result<(), ConfigError> { + pub fn check_error(&self) -> std::result::Result<(), ConfigError> { if self.epoch_time == 0 { return Err(ConfigError::EpochTime); } diff --git a/shared/coordinator/src/fixed_string.rs b/shared/coordinator/src/fixed_string.rs new file mode 100644 index 000000000..15953eef7 --- /dev/null +++ b/shared/coordinator/src/fixed_string.rs @@ -0,0 +1,170 @@ +use std::fmt::Display; + +use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; +use bytemuck::Zeroable; + +#[derive(Clone, Copy, AnchorSerialize, AnchorDeserialize, PartialEq, Eq, InitSpace, Zeroable)] +#[cfg_attr( + feature = "client", + derive(serde::Serialize, serde::Deserialize, ts_rs::TS) +)] +pub struct FixedString( + #[cfg_attr( + feature = "client", + serde( + serialize_with = "serde_serialize_string", + deserialize_with = "serde_deserialize_string" + ) + )] + #[cfg_attr(feature = "client", ts(as = "String"))] + [u8; L], +); + +impl std::fmt::Debug for FixedString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let used_bytes = match self.0.iter().position(|&b| b == 0) { + Some(null_pos) => null_pos, + None => L, + }; + + let zero_bytes = L - used_bytes; + + let string_content = String::from(self); + + write!( + f, + "\"{string_content}\" ({used_bytes}/{L} bytes, {zero_bytes} zeroes)" + ) + } +} + +impl Display for FixedString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", String::from(self)) + } +} + +impl Default for FixedString { + fn default() -> Self { + Self([0u8; L]) + } +} + +impl TryFrom<&str> for FixedString { + type Error = &'static str; + + fn try_from(value: &str) -> Result { + let bytes = value.as_bytes(); + if bytes.len() > L { + return Err("str does not fit in FixedString"); + } + Ok(Self::from_str_truncated(value)) + } +} + +impl TryFrom<&String> for FixedString { + type Error = &'static str; + + fn try_from(value: &String) -> Result { + let bytes = value.as_bytes(); + if bytes.len() > L { + return Err("str does not fit in FixedString"); + } + Ok(Self::from_str_truncated(value)) + } +} + +impl FixedString { + pub fn new() -> Self { + Default::default() + } + + pub fn from_str_truncated(s: &str) -> Self { + let mut array = [0u8; L]; + let bytes = s.as_bytes(); + let len = bytes.len().min(L); + array[..len].copy_from_slice(&bytes[..len]); + Self(array) + } + + pub fn is_empty(&self) -> bool { + self.0[0] == 0 + } +} + +impl From<&FixedString> for String { + fn from(value: &FixedString) -> Self { + let sliced = match value.0.iter().position(|&b| b == 0) { + Some(null_pos) => &value.0[0..null_pos], + None => &value.0, + }; + String::from_utf8_lossy(sliced).to_string() + } +} + +impl From<[u8; L]> for FixedString { + fn from(value: [u8; L]) -> Self { + Self(value) + } +} + +impl From> for [u8; L] { + fn from(value: FixedString) -> Self { + value.0 + } +} + +#[cfg(feature = "client")] +pub fn serde_serialize_string( + run_id: &[u8], + serializer: S, +) -> std::result::Result +where + S: serde::Serializer, +{ + // Convert bytes to string, trimming null bytes + let s = String::from_utf8_lossy(run_id) + .trim_matches(char::from(0)) + .to_string(); + serializer.serialize_str(&s) +} + +#[cfg(feature = "client")] +pub fn serde_deserialize_string<'de, D, const N: usize>( + deserializer: D, +) -> std::result::Result<[u8; N], D::Error> +where + D: serde::Deserializer<'de>, +{ + let s: String = ::deserialize(deserializer)?; + let mut bytes = [0u8; N]; + let len = std::cmp::min(s.len(), N); + bytes[..len].copy_from_slice(&s.as_bytes()[..len]); + Ok(bytes) +} + +#[cfg(all(test, feature = "client"))] +mod test { + use serde::{Deserialize, Serialize}; + + use super::*; + + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] + struct MyStrStruct { + #[serde( + serialize_with = "serde_serialize_string", + deserialize_with = "serde_deserialize_string" + )] + field: [u8; 64], + } + + #[test] + fn test_serialize_deserialize_string() { + let my_struct = MyStrStruct { field: [1u8; 64] }; + + let bytes = postcard::to_stdvec(&my_struct).unwrap(); + let deserialized_struct: MyStrStruct = postcard::from_bytes(&bytes).unwrap(); + + assert_eq!(my_struct, deserialized_struct); + } +} diff --git a/shared/core/src/fixed_vec.rs b/shared/coordinator/src/fixed_vec.rs similarity index 96% rename from shared/core/src/fixed_vec.rs rename to shared/coordinator/src/fixed_vec.rs index cd17be654..746b38447 100644 --- a/shared/core/src/fixed_vec.rs +++ b/shared/coordinator/src/fixed_vec.rs @@ -1,11 +1,13 @@ use anchor_lang::{AnchorDeserialize, AnchorSerialize, prelude::borsh}; use bytemuck::Zeroable; -use serde::{Deserialize, Serialize}; use std::ops::{Deref, DerefMut, Range, RangeFrom, RangeFull, RangeTo}; + +#[cfg(feature = "client")] use ts_rs::TS; -#[derive(Clone, Copy, Zeroable, AnchorSerialize, AnchorDeserialize, PartialEq, TS)] -#[ts(type = "Array", bound = "T: TS")] +#[derive(Clone, Copy, Zeroable, AnchorSerialize, AnchorDeserialize, PartialEq)] +#[cfg_attr(feature = "client", derive(TS))] +#[cfg_attr(feature = "client", ts(type = "Array", bound = "T: TS"))] #[repr(C)] pub struct FixedVec { data: [T; N], @@ -317,7 +319,8 @@ impl TryFrom<[T; M]> for Fixe } } -impl Serialize for FixedVec { +#[cfg(feature = "client")] +impl serde::Serialize for FixedVec { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, @@ -327,7 +330,8 @@ impl Serialize for FixedVec } } -impl<'de, T: Deserialize<'de> + Default + Copy, const N: usize> Deserialize<'de> +#[cfg(feature = "client")] +impl<'de, T: serde::Deserialize<'de> + Default + Copy, const N: usize> serde::Deserialize<'de> for FixedVec { fn deserialize(deserializer: D) -> Result diff --git a/shared/coordinator/src/hash_wrapper.rs b/shared/coordinator/src/hash_wrapper.rs new file mode 100644 index 000000000..d32483112 --- /dev/null +++ b/shared/coordinator/src/hash_wrapper.rs @@ -0,0 +1,43 @@ +use anchor_lang::prelude::*; +use bytemuck::Zeroable; + +/// This wrapper is used to implement the `Space` trait for the actual hash. +#[derive( + AnchorSerialize, AnchorDeserialize, PartialEq, Eq, Clone, Default, Zeroable, Copy, InitSpace, +)] +#[cfg_attr( + feature = "client", + derive(serde::Serialize, serde::Deserialize, ts_rs::TS) +)] +pub struct HashWrapper { + pub inner: [u8; 32], +} + +impl HashWrapper { + pub fn new(inner: [u8; 32]) -> Self { + Self { inner } + } + + #[cfg(feature = "client")] + pub fn fmt_short(&self) -> String { + data_encoding::HEXLOWER.encode(&self.inner[..5]) + } + + #[cfg(feature = "client")] + pub fn fmt_full(&self) -> String { + data_encoding::HEXLOWER.encode(&self.inner) + } +} + +#[cfg(feature = "client")] +impl std::fmt::Debug for HashWrapper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "HashWrapper({})", self.fmt_full()) + } +} + +impl AsRef<[u8]> for HashWrapper { + fn as_ref(&self) -> &[u8] { + &self.inner + } +} diff --git a/shared/coordinator/src/lib.rs b/shared/coordinator/src/lib.rs index 12f682c70..be0f8d20c 100644 --- a/shared/coordinator/src/lib.rs +++ b/shared/coordinator/src/lib.rs @@ -1,25 +1,15 @@ #![allow(unexpected_cfgs)] -mod checkpointer_selection; -mod commitment; -mod committee_selection; -mod coordinator; -mod data_selection; +pub mod bloom; +pub mod checkpointer_selection; +pub mod committee_selection; +pub mod coordinator; +pub mod fixed_string; +pub mod fixed_vec; +pub mod hash_wrapper; pub mod model; -pub mod model_extra_data; -mod types; - -pub use checkpointer_selection::CheckpointerSelection; -pub use commitment::Commitment; -pub use committee_selection::CommitteeSelection; -pub use coordinator::{ - BLOOM_FALSE_RATE, Client, ClientState, Coordinator, CoordinatorConfig, CoordinatorEpochState, - CoordinatorError, CoordinatorProgress, HealthChecks, MAX_TOKENS_TO_SEND, NUM_STORED_ROUNDS, - Round, RunState, SOLANA_MAX_NUM_CLIENTS, SOLANA_MAX_NUM_WITNESSES, SOLANA_MAX_STRING_LEN, - SOLANA_RUN_ID_MAX_LEN, TickResult, WAITING_FOR_MEMBERS_EXTRA_SECONDS, Witness, WitnessBloom, - WitnessEvalResult, WitnessMetadata, -}; -pub use data_selection::{ - assign_data_for_state, get_batch_ids_for_node, get_batch_ids_for_round, get_data_index_for_step, -}; -pub use types::{Committee, CommitteeProof, WitnessProof, salts}; +pub mod node_identity; +pub mod sha; +pub mod small_boolean; +pub mod swap_or_not; +pub mod types; diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs index 25c82b000..4768a9aa8 100644 --- a/shared/coordinator/src/model.rs +++ b/shared/coordinator/src/model.rs @@ -1,279 +1,105 @@ -use crate::model_extra_data::{CHECKPOINT_DATA_MAX_LEN, CheckpointData}; -use crate::{SOLANA_MAX_STRING_LEN, coordinator::SOLANA_MAX_URL_STRING_LEN}; +use std::fmt::Display; -use anchor_lang::{ - AnchorDeserialize, AnchorSerialize, InitSpace, - prelude::{borsh, msg}, -}; -use bytemuck::{Zeroable, ZeroableInOption}; -use psyche_core::{FixedString, FixedVec, Shuffle, TokenSize}; -use serde::{Deserialize, Serialize}; -use ts_rs::TS; +use anchor_lang::prelude::*; +use bytemuck::Zeroable; -/// Opaque byte blob holding borsh-serialized [`CheckpointData`]. +use crate::fixed_vec::FixedVec; + +pub const CHECKPOINT_DATA_MAX_LEN: usize = 256; +/// Opaque byte blob holding serialized [`CheckpointData`]. pub type CheckpointBytes = FixedVec; -#[derive( - Clone, Debug, Copy, Zeroable, AnchorDeserialize, AnchorSerialize, Serialize, Deserialize, TS, +#[derive(AnchorSerialize, AnchorDeserialize, Clone, Zeroable, Copy)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] -pub enum Model { - LLM(LLM), +pub struct Model { + pub max_seq_len: u32, + pub cold_start_warmup_steps: u32, + pub checkpoint_source: CheckpointSource, + #[cfg_attr(feature = "client", serde(default))] + pub checkpoint_data: FixedVec, } -unsafe impl ZeroableInOption for Model {} - -#[derive( - Clone, - Debug, - Copy, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, - InitSpace, - TS, - PartialEq, +#[derive(AnchorSerialize, AnchorDeserialize, Clone, Zeroable, Copy, PartialEq, Default)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] -pub enum LLMArchitecture { - HfLlama, - HfDeepseek, - HfAuto, - Torchtitan, +pub enum CheckpointSource { + Ephemeral, + #[default] + Stored, + P2P, } -impl std::fmt::Display for LLMArchitecture { +#[cfg(feature = "client")] +impl std::fmt::Display for CheckpointSource { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - LLMArchitecture::HfLlama => f.write_str("HfLlama"), - LLMArchitecture::HfDeepseek => f.write_str("HfDeepseek"), - LLMArchitecture::HfAuto => f.write_str("HfAuto"), - LLMArchitecture::Torchtitan => f.write_str("Torchtitan"), - } - } -} - -#[derive( - Clone, - Debug, - Copy, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, - InitSpace, - PartialEq, - TS, -)] -#[repr(C)] -pub enum LLMTrainingDataType { - Pretraining, - Finetuning, -} - -#[derive( - AnchorSerialize, - AnchorDeserialize, - InitSpace, - Serialize, - Deserialize, - Clone, - Debug, - Zeroable, - Copy, - TS, -)] -#[repr(C)] -#[allow(clippy::large_enum_variant)] -#[derive(Default)] -pub enum LLMTrainingDataLocation { - #[default] - Dummy, - Server(FixedString<{ SOLANA_MAX_STRING_LEN }>), - Local(FixedString<{ SOLANA_MAX_URL_STRING_LEN }>), - Http(HttpLLMTrainingDataLocation), - /// link to a JSON file that deserializes to a Vec - WeightedHttp(FixedString<{ SOLANA_MAX_URL_STRING_LEN }>), - Preprocessed(FixedString<{ SOLANA_MAX_URL_STRING_LEN }>), -} - -#[derive( - AnchorSerialize, - AnchorDeserialize, - InitSpace, - Serialize, - Deserialize, - Clone, - Debug, - Zeroable, - Copy, - TS, -)] -#[repr(C)] -#[allow(clippy::large_enum_variant)] -pub struct HttpLLMTrainingDataLocation { - pub location: HttpTrainingDataLocation, - pub token_size_in_bytes: TokenSize, - pub shuffle: Shuffle, -} - -/// these are deserialized from JSON -#[derive(Serialize, Deserialize, Clone, Debug, Copy)] -pub struct LLMTrainingDataLocationAndWeight { - pub location: LLMTrainingDataLocation, - pub weight: f32, -} - -impl Default for LLMTrainingDataLocationAndWeight { - fn default() -> Self { - Self { - location: Default::default(), - weight: 1.0, + CheckpointSource::Ephemeral => write!(f, "Ephemeral"), + CheckpointSource::Stored => write!(f, "Stored"), + CheckpointSource::P2P => write!(f, "P2P"), } } } -impl From - for FixedVec -{ - fn from(location: LLMTrainingDataLocation) -> Self { - FixedVec::from_iter([LLMTrainingDataLocationAndWeight { - location, - weight: 1.0, - }]) - } -} - -impl LLMTrainingDataLocationAndWeight { - pub fn new(location: LLMTrainingDataLocation, weight: f32) -> Self { - Self { location, weight } - } -} - -/// NOTE: Support for Vecs of URLs is not enabled because of the large size it would support. -#[derive( - AnchorSerialize, - AnchorDeserialize, - InitSpace, - Serialize, - Deserialize, - Clone, - Debug, - Zeroable, - Copy, - TS, -)] -#[repr(C)] -#[allow(clippy::large_enum_variant)] -pub enum HttpTrainingDataLocation { - SingleUrl(FixedString<{ SOLANA_MAX_URL_STRING_LEN }>), - NumberedFiles { - url_template: FixedString<{ SOLANA_MAX_STRING_LEN }>, - start_index: u32, - n_left_pad_zeros: u8, - num_files: u32, - }, - Gcp { - bucket_name: FixedString<{ SOLANA_MAX_STRING_LEN }>, - - /// 0 len === no filter - filter_directory: FixedString<{ SOLANA_MAX_URL_STRING_LEN }>, - }, -} - -#[derive( - AnchorSerialize, AnchorDeserialize, Serialize, Deserialize, Clone, Debug, Zeroable, Copy, TS, -)] -#[repr(C)] -pub struct LLM { - pub max_seq_len: u32, - pub cold_start_warmup_steps: u32, - pub checkpoint_source: CheckpointSource, - #[serde(default)] - pub checkpoint_data: FixedVec, -} - -impl LLM { - pub fn dummy() -> Self { +#[cfg(feature = "client")] +impl Model { + pub fn dummy(checkpoint_data: FixedVec) -> Self { Self { checkpoint_source: CheckpointSource::Stored, - checkpoint_data: CheckpointData::Dummy.to_fixed_vec(), + checkpoint_data, max_seq_len: 2048, cold_start_warmup_steps: 0, } } - - /// Decode the opaque checkpoint bytes into a [`CheckpointData`]. - pub fn decode_checkpoint(&self) -> Option { - CheckpointData::from_fixed_vec(&self.checkpoint_data).ok() - } } -#[derive( - AnchorSerialize, - AnchorDeserialize, - Serialize, - Deserialize, - Clone, - Debug, - Zeroable, - Copy, - TS, - PartialEq, - Default, -)] -#[repr(C)] -pub enum CheckpointSource { - Ephemeral, - #[default] - Stored, - P2P, +#[derive(Debug)] +pub enum ModelError { + ZeroSeqLen, + CheckpointEphemeral, + CheckpointEmpty, } -impl std::fmt::Display for CheckpointSource { +impl Display for ModelError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - CheckpointSource::Ephemeral => write!(f, "Ephemeral"), - CheckpointSource::Stored => write!(f, "Stored"), - CheckpointSource::P2P => write!(f, "P2P"), - } + write!( + f, + "{}", + match self { + ModelError::ZeroSeqLen => "model check failed: max_seq_len is 0.", + ModelError::CheckpointEphemeral => + "model check failed: ephemeral checkpoint not allowed", + ModelError::CheckpointEmpty => "model check failed: checkpoint data is empty", + } + ) } } impl Model { pub fn check(&self) -> bool { - match self { - Model::LLM(llm) => { - if llm.max_seq_len == 0 { - msg!("model check failed: max_seq_len is 0."); - return false; - } - - if matches!(llm.checkpoint_source, CheckpointSource::Ephemeral) { - msg!("model check failed: bad checkpoint (ephemeral)"); - return false; - } + self.check_error().is_ok() + } - let bad_checkpoint = match CheckpointData::from_fixed_vec(&llm.checkpoint_data) { - Ok(CheckpointData::Dummy) => false, - Ok(CheckpointData::Hub { repo_id, .. }) => repo_id.is_empty(), - Ok(CheckpointData::Gcs { bucket, .. }) => bucket.is_empty(), - Err(_) => { - msg!("model check failed: could not deserialize checkpoint data"); - true - } - }; + #[inline(always)] + pub fn check_error(&self) -> std::result::Result<(), ModelError> { + if self.max_seq_len == 0 { + return Err(ModelError::ZeroSeqLen); + } - if bad_checkpoint { - msg!("model check failed: bad checkpoint"); - return false; - } + if matches!(self.checkpoint_source, CheckpointSource::Ephemeral) { + return Err(ModelError::CheckpointEphemeral); + } - true - } + if self.checkpoint_data.is_empty() { + return Err(ModelError::CheckpointEmpty); } + + Ok(()) } } diff --git a/shared/core/src/node_identity.rs b/shared/coordinator/src/node_identity.rs similarity index 84% rename from shared/core/src/node_identity.rs rename to shared/coordinator/src/node_identity.rs index 421f0148d..8fb908fd9 100644 --- a/shared/core/src/node_identity.rs +++ b/shared/coordinator/src/node_identity.rs @@ -1,22 +1,13 @@ use std::fmt::{Debug, Display}; -use anchor_lang::{Space, prelude::*}; +use anchor_lang::prelude::*; use bytemuck::{Pod, Zeroable}; -use serde::{Deserialize, Serialize}; -use ts_rs::TS; - #[derive( - Clone, - Copy, - Default, - Zeroable, - Pod, - AnchorSerialize, - AnchorDeserialize, - Serialize, - Deserialize, - TS, - Eq, + Clone, Copy, Default, Zeroable, Pod, AnchorSerialize, AnchorDeserialize, Eq, InitSpace, +)] +#[cfg_attr( + feature = "client", + derive(serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub struct NodeIdentity { @@ -45,6 +36,7 @@ impl NodeIdentity { /// In non-Solana usage, we don't have a signer - so /// both signer and p2p_identity are the same pubkey. + #[cfg(feature = "client")] pub fn from_single_key(key: [u8; 32]) -> Self { Self { signer: key, @@ -84,7 +76,3 @@ impl Debug for NodeIdentity { write!(f, ")") } } - -impl Space for NodeIdentity { - const INIT_SPACE: usize = 64; -} diff --git a/shared/coordinator/src/sha.rs b/shared/coordinator/src/sha.rs new file mode 100644 index 000000000..67d133d26 --- /dev/null +++ b/shared/coordinator/src/sha.rs @@ -0,0 +1,30 @@ +use anchor_lang::solana_program::hash::{hash, hashv}; + +pub fn sha256(data: &[u8]) -> [u8; 32] { + let hash_result = hash(data); + hash_result.to_bytes() +} + +pub fn sha256v(data: &[&[u8]]) -> [u8; 32] { + let hash_result = hashv(data); + hash_result.to_bytes() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sha256() { + let data = b"Hello, world!"; + let hash = sha256(data); + assert_eq!( + hash, + [ + 0x31, 0x5f, 0x5b, 0xdb, 0x76, 0xd0, 0x78, 0xc4, 0x3b, 0x8a, 0xc0, 0x06, 0x4e, 0x4a, + 0x01, 0x64, 0x61, 0x2b, 0x1f, 0xce, 0x77, 0xc8, 0x69, 0x34, 0x5b, 0xfc, 0x94, 0xc7, + 0x58, 0x94, 0xed, 0xd3 + ] + ); + } +} diff --git a/shared/core/src/small_boolean.rs b/shared/coordinator/src/small_boolean.rs similarity index 86% rename from shared/core/src/small_boolean.rs rename to shared/coordinator/src/small_boolean.rs index 7ff9d8d7c..d8172bd8d 100644 --- a/shared/core/src/small_boolean.rs +++ b/shared/coordinator/src/small_boolean.rs @@ -2,23 +2,13 @@ use std::fmt::Debug; use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; use bytemuck::{Pod, Zeroable}; -use serde::{Deserialize, Serialize}; -use ts_rs::TS; #[derive( - Copy, - Clone, - Eq, - PartialEq, - Hash, - Zeroable, - Pod, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, - InitSpace, - TS, + Copy, Clone, Eq, PartialEq, Hash, Zeroable, Pod, AnchorDeserialize, AnchorSerialize, InitSpace, +)] +#[cfg_attr( + feature = "client", + derive(serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(transparent)] pub struct SmallBoolean(pub u8); diff --git a/shared/core/src/swap_or_not.rs b/shared/coordinator/src/swap_or_not.rs similarity index 98% rename from shared/core/src/swap_or_not.rs rename to shared/coordinator/src/swap_or_not.rs index 45827e426..c42c95ce7 100644 --- a/shared/core/src/swap_or_not.rs +++ b/shared/coordinator/src/swap_or_not.rs @@ -1,4 +1,4 @@ -use crate::sha256::sha256v; +use crate::sha::sha256v; const SHUFFLE_ROUND_COUNT: u8 = 90; diff --git a/shared/coordinator/src/types.rs b/shared/coordinator/src/types.rs index ac49815d7..8fdbcd228 100644 --- a/shared/coordinator/src/types.rs +++ b/shared/coordinator/src/types.rs @@ -1,8 +1,7 @@ -use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; +use anchor_lang::prelude::*; use bytemuck::Zeroable; -use psyche_core::SmallBoolean; -use serde::{Deserialize, Serialize}; -use ts_rs::TS; + +use crate::small_boolean::SmallBoolean; /// Salt constants for deterministic shuffling pub mod salts { @@ -11,17 +10,10 @@ pub mod salts { pub const COOLDOWN: &str = "cooldown"; } -#[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, +#[derive(Clone, Copy, Default, PartialEq, Zeroable, AnchorDeserialize, AnchorSerialize)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub enum Committee { @@ -31,6 +23,7 @@ pub enum Committee { Trainer, } +#[cfg(feature = "client")] impl std::fmt::Display for Committee { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -41,17 +34,10 @@ impl std::fmt::Display for Committee { } } -#[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, +#[derive(Clone, Copy, Default, PartialEq, Zeroable, AnchorDeserialize, AnchorSerialize)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub struct CommitteeProof { @@ -61,18 +47,11 @@ pub struct CommitteeProof { } #[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, - InitSpace, - TS, + Clone, Copy, Default, PartialEq, Zeroable, AnchorDeserialize, AnchorSerialize, InitSpace, +)] +#[cfg_attr( + feature = "client", + derive(Debug, serde::Serialize, serde::Deserialize, ts_rs::TS) )] #[repr(C)] pub struct WitnessProof { diff --git a/shared/core/Cargo.toml b/shared/core/Cargo.toml index 7e74f7135..ae7c07dbb 100644 --- a/shared/core/Cargo.toml +++ b/shared/core/Cargo.toml @@ -12,14 +12,10 @@ fast-math.workspace = true rand = { workspace = true, optional = true } serde.workspace = true postcard.workspace = true -fnv = "1.0.7" serde_arrays = "0.1.0" -bitvec = { git = "https://github.com/arilotter/bitvec", rev = "d33a2437f810ee4229457dfd1137d807914671f8", features = [ - "serde", - "std", -] } +serde_json.workspace = true ts-rs.workspace = true -data-encoding = "2.8.0" +psyche-coordinator = { workspace = true, features = ["client"] } [target.'cfg(not(target_os = "solana"))'.dependencies] sha2.workspace = true diff --git a/shared/coordinator/src/commitment.rs b/shared/core/src/commitment.rs similarity index 61% rename from shared/coordinator/src/commitment.rs rename to shared/core/src/commitment.rs index f8aa4ef64..1dba2c4fe 100644 --- a/shared/coordinator/src/commitment.rs +++ b/shared/core/src/commitment.rs @@ -1,8 +1,8 @@ -use anchor_lang::{AnchorDeserialize, AnchorSerialize, prelude::borsh}; use bytemuck::Zeroable; +use psyche_coordinator::coordinator::Witness; use serde::{Deserialize, Deserializer, Serialize}; -#[derive(Clone, Debug, Zeroable, Copy, AnchorDeserialize, AnchorSerialize)] +#[derive(Clone, Debug, Zeroable, Copy)] #[repr(C)] pub struct Commitment { pub data_hash: [u8; 32], @@ -45,3 +45,25 @@ impl<'de> Deserialize<'de> for Commitment { }) } } + +pub fn select_consensus_commitment_by_witnesses( + commitments: &[Commitment], + witnesses: &[Witness], + witness_quorum: u16, +) -> Option { + let mut scores = vec![0; commitments.len()]; + for witness in witnesses { + for (index, commitment) in commitments.iter().enumerate() { + if witness.broadcast_bloom.contains(&commitment.data_hash) { + scores[index] += 1; + break; + } + } + } + scores + .into_iter() + .enumerate() + .filter(|(_, score)| *score >= witness_quorum) + .max_by_key(|(_, score)| *score) + .map(|(index, _)| index) +} diff --git a/shared/core/src/coordinator.rs b/shared/core/src/coordinator.rs new file mode 100644 index 000000000..0241db775 --- /dev/null +++ b/shared/core/src/coordinator.rs @@ -0,0 +1,114 @@ +use serde::{Deserialize, Serialize}; +use ts_rs::TS; + +use crate::{Shuffle, TokenSize}; + +#[derive(Clone, Serialize, Deserialize, TS, Default, Debug)] +pub struct WitnessMetadata { + pub step: u32, + pub tokens_per_sec: f32, + pub bandwidth_per_sec: f32, + pub loss: f32, + pub evals: Vec, + pub prompt_results: Vec, + pub prompt_index: u8, + pub efficency: f32, +} + +#[derive(Clone, Serialize, Deserialize, TS, Default, Debug)] +pub struct WitnessEvalResult { + pub name: String, + pub value: f32, +} + +#[derive(Clone, Debug, Copy, Serialize, Deserialize, TS, PartialEq)] +pub enum LLMArchitecture { + HfLlama, + HfDeepseek, + HfAuto, + Torchtitan, +} + +impl std::fmt::Display for LLMArchitecture { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LLMArchitecture::HfLlama => f.write_str("HfLlama"), + LLMArchitecture::HfDeepseek => f.write_str("HfDeepseek"), + LLMArchitecture::HfAuto => f.write_str("HfAuto"), + LLMArchitecture::Torchtitan => f.write_str("Torchtitan"), + } + } +} + +#[derive(Clone, Debug, Copy, Serialize, Deserialize, PartialEq, TS)] +pub enum LLMTrainingDataType { + Pretraining, + Finetuning, +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS, Default)] +pub enum LLMTrainingDataLocation { + #[default] + Dummy, + Server(String), + Local(String), + Http(HttpLLMTrainingDataLocation), + /// link to a JSON file that deserializes to a Vec + WeightedHttp(String), + Preprocessed(String), +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub struct HttpLLMTrainingDataLocation { + pub location: HttpTrainingDataLocation, + pub token_size_in_bytes: TokenSize, + pub shuffle: Shuffle, +} + +/// these are deserialized from JSON +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct LLMTrainingDataLocationAndWeight { + pub location: LLMTrainingDataLocation, + pub weight: f32, +} + +impl Default for LLMTrainingDataLocationAndWeight { + fn default() -> Self { + Self { + location: Default::default(), + weight: 1.0, + } + } +} + +impl From for Vec { + fn from(location: LLMTrainingDataLocation) -> Self { + vec![LLMTrainingDataLocationAndWeight { + location, + weight: 1.0, + }] + } +} + +impl LLMTrainingDataLocationAndWeight { + pub fn new(location: LLMTrainingDataLocation, weight: f32) -> Self { + Self { location, weight } + } +} + +#[derive(Serialize, Deserialize, Clone, Debug, TS)] +pub enum HttpTrainingDataLocation { + SingleUrl(String), + NumberedFiles { + url_template: String, + start_index: u32, + n_left_pad_zeros: u8, + num_files: u32, + }, + Gcp { + bucket_name: String, + + /// 0 len === no filter + filter_directory: Option, + }, +} diff --git a/shared/coordinator/src/data_selection.rs b/shared/core/src/data_selection.rs similarity index 95% rename from shared/coordinator/src/data_selection.rs rename to shared/core/src/data_selection.rs index 2dc566d3e..20d312aee 100644 --- a/shared/coordinator/src/data_selection.rs +++ b/shared/core/src/data_selection.rs @@ -1,8 +1,13 @@ -use crate::{Committee, CommitteeSelection, Coordinator, Round}; - -use psyche_core::{BatchId, ClosedInterval, NodeIdentity, deterministic_shuffle}; +use psyche_coordinator::{ + committee_selection::CommitteeSelection, + coordinator::{Coordinator, Round}, + node_identity::NodeIdentity, + types::Committee, +}; use std::{collections::BTreeMap, fmt}; +use crate::{BatchId, ClosedInterval, deterministic_shuffle}; + /// Assigns data batches to nodes based on committee roles. pub fn assign_data_for_state( coordinator: &Coordinator, @@ -131,9 +136,11 @@ pub fn get_data_index_for_step(coordinator: &Coordinator, target_step: u32) -> u #[cfg(test)] mod tests { use super::*; - use crate::{Client, ClientState, CommitteeSelection, Coordinator}; use bytemuck::Zeroable; - use psyche_core::{FixedVec, NodeIdentity}; + use psyche_coordinator::{ + coordinator::{Client, ClientState}, + fixed_vec::FixedVec, + }; fn create_test_coordinator( num_nodes: usize, diff --git a/shared/core/src/data_shuffle.rs b/shared/core/src/data_shuffle.rs index 09252cbe8..a77fd6c37 100644 --- a/shared/core/src/data_shuffle.rs +++ b/shared/core/src/data_shuffle.rs @@ -1,23 +1,7 @@ -use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; -use bytemuck::Zeroable; use serde::{Deserialize, Serialize}; use ts_rs::TS; -#[derive( - AnchorSerialize, - AnchorDeserialize, - InitSpace, - Serialize, - Deserialize, - Clone, - Debug, - Zeroable, - Copy, - PartialEq, - TS, -)] -#[repr(C)] -#[derive(Default)] +#[derive(Serialize, Deserialize, Clone, Debug, Copy, PartialEq, TS, Default)] pub enum Shuffle { #[default] DontShuffle, diff --git a/shared/core/src/fixed_string.rs b/shared/core/src/fixed_string.rs deleted file mode 100644 index 3ec110574..000000000 --- a/shared/core/src/fixed_string.rs +++ /dev/null @@ -1,109 +0,0 @@ -use std::fmt::Display; - -use anchor_lang::{ - AnchorDeserialize, AnchorSerialize, InitSpace, - prelude::{borsh, thiserror}, -}; -use bytemuck::Zeroable; -use serde::{Deserialize, Serialize}; -use ts_rs::TS; - -use crate::serde_utils::{serde_deserialize_string, serde_serialize_string}; - -#[derive(thiserror::Error, Debug)] -#[error("string of length {} doesn't fit in FixedString<{}>", 0.0, 0.1)] -pub struct FixedStringError((usize, usize)); - -#[derive( - Serialize, - Deserialize, - Clone, - Copy, - TS, - AnchorSerialize, - AnchorDeserialize, - PartialEq, - Eq, - InitSpace, - Zeroable, -)] -#[repr(C)] -pub struct FixedString( - #[serde( - serialize_with = "serde_serialize_string", - deserialize_with = "serde_deserialize_string" - )] - #[ts(as = "String")] - [u8; L], -); - -impl std::fmt::Debug for FixedString { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let used_bytes = match self.0.iter().position(|&b| b == 0) { - Some(null_pos) => null_pos, - None => L, - }; - - let zero_bytes = L - used_bytes; - - let string_content = String::from(self); - - write!( - f, - "\"{string_content}\" ({used_bytes}/{L} bytes, {zero_bytes} zeroes)" - ) - } -} - -impl Display for FixedString { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", String::from(self)) - } -} - -impl Default for FixedString { - fn default() -> Self { - Self([0u8; L]) - } -} - -impl FixedString { - pub fn new() -> Self { - Default::default() - } - pub fn from_str_truncated(s: &str) -> Self { - let mut array = [0u8; L]; - let bytes = s.as_bytes(); - let len = bytes.len().min(L); - array[..len].copy_from_slice(&bytes[..len]); - Self(array) - } - - pub fn is_empty(&self) -> bool { - self.0[0] == 0 - } -} - -impl TryFrom<&str> for FixedString { - type Error = FixedStringError; - - fn try_from(s: &str) -> Result { - let mut array = [0u8; L]; - let bytes = s.as_bytes(); - if bytes.len() > L { - return Err(FixedStringError((bytes.len(), L))); - } - array[..bytes.len()].copy_from_slice(bytes); - Ok(Self(array)) - } -} - -impl From<&FixedString> for String { - fn from(value: &FixedString) -> Self { - let sliced = match value.0.iter().position(|&b| b == 0) { - Some(null_pos) => &value.0[0..null_pos], - None => &value.0, - }; - String::from_utf8_lossy(sliced).to_string() - } -} diff --git a/shared/core/src/lib.rs b/shared/core/src/lib.rs index 88ea2a932..c12dcfc51 100644 --- a/shared/core/src/lib.rs +++ b/shared/core/src/lib.rs @@ -1,31 +1,38 @@ #![allow(unexpected_cfgs)] mod batch_id; -mod bloom; mod bounded_queue; mod boxed_future; mod cancellable_barrier; +mod commitment; +mod coordinator; +mod data_selection; mod data_shuffle; mod definitions; mod deterministic_shuffle; -mod fixed_string; -mod fixed_vec; mod interval_tree; mod lcg; mod merkle_tree; -mod node_identity; +mod model_extra_data; mod running_average; mod serde_utils; -mod sha256; mod similarity; mod sized_iterator; -mod small_boolean; -mod swap_or_not; mod testing; mod token_size; +pub use coordinator::{ + HttpLLMTrainingDataLocation, HttpTrainingDataLocation, LLMArchitecture, + LLMTrainingDataLocation, LLMTrainingDataLocationAndWeight, LLMTrainingDataType, + WitnessEvalResult, WitnessMetadata, +}; + +pub use commitment::{Commitment, select_consensus_commitment_by_witnesses}; +pub use data_selection::{ + assign_data_for_state, get_batch_ids_for_node, get_batch_ids_for_round, get_data_index_for_step, +}; + pub use batch_id::BatchId; -pub use bloom::Bloom; pub use bounded_queue::BoundedQueue; pub use boxed_future::BoxedFuture; pub use cancellable_barrier::{Barrier, CancellableBarrier, CancelledBarrier}; @@ -35,24 +42,19 @@ pub use definitions::{ OptimizerDefinition, }; pub use deterministic_shuffle::deterministic_shuffle; -pub use fixed_string::FixedString; -pub use fixed_vec::FixedVec; pub use interval_tree::{ClosedInterval, IntervalTree}; pub use lcg::LCG; -pub use merkle_tree::{HashWrapper as MerkleRoot, MerkleTree, OwnedProof, Proof}; -pub use node_identity::NodeIdentity; -pub use running_average::RunningAverage; -pub use serde_utils::{ - serde_deserialize_optional_string, serde_deserialize_string, serde_deserialize_vec_to_array, - serde_serialize_array_as_vec, serde_serialize_optional_string, serde_serialize_string, +pub use merkle_tree::{MerkleTree, OwnedProof, Proof}; +pub use model_extra_data::{ + CHECKPOINT_DATA_MAX_LEN, CONFIG_PREFIX, CheckpointData, MODEL_CONFIG_FILENAME, ModelExtraData, + RunMetadata, }; -pub use sha256::{sha256, sha256v}; +pub use running_average::RunningAverage; +pub use serde_utils::{serde_deserialize_vec_to_array, serde_serialize_array_as_vec}; pub use similarity::{ DistanceThresholds, hamming_distance, is_similar, jaccard_distance, manhattan_distance, }; pub use sized_iterator::SizedIterator; -pub use small_boolean::SmallBoolean; -pub use swap_or_not::compute_shuffled_index; pub use testing::IntegrationTestLogMarker; pub use token_size::TokenSize; diff --git a/shared/core/src/merkle_tree.rs b/shared/core/src/merkle_tree.rs index 5b5924a1d..2d0857345 100644 --- a/shared/core/src/merkle_tree.rs +++ b/shared/core/src/merkle_tree.rs @@ -1,13 +1,8 @@ #![allow(clippy::manual_is_multiple_of)] -use std::fmt::Debug; - -use crate::sha256::sha256v; - use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; -use bytemuck::Zeroable; -use serde::{Deserialize, Serialize}; -use ts_rs::TS; +use psyche_coordinator::{hash_wrapper::HashWrapper, sha::sha256v}; +use std::fmt::Debug; // from https://github.com/solana-labs/solana/blob/27eff8408b7223bb3c4ab70523f8a8dca3ca6645/merkle-tree/src/merkle_tree.rs @@ -32,54 +27,6 @@ macro_rules! hash_intermediate { } } -/// This wrapper is used to implement the `Space` trait for the actual hash. -#[derive( - AnchorSerialize, - AnchorDeserialize, - Serialize, - Deserialize, - PartialEq, - Eq, - Clone, - Default, - Zeroable, - Copy, - TS, -)] -pub struct HashWrapper { - pub inner: [u8; 32], -} - -impl HashWrapper { - pub fn new(inner: [u8; 32]) -> Self { - Self { inner } - } - - pub fn fmt_short(&self) -> String { - data_encoding::HEXLOWER.encode(&self.inner[..5]) - } - - pub fn fmt_full(&self) -> String { - data_encoding::HEXLOWER.encode(&self.inner) - } -} - -impl Debug for HashWrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "HashWrapper({})", self.fmt_full()) - } -} - -impl AsRef<[u8]> for HashWrapper { - fn as_ref(&self) -> &[u8] { - &self.inner - } -} - -impl anchor_lang::Space for HashWrapper { - const INIT_SPACE: usize = 32; -} - #[derive(Debug)] pub struct MerkleTree { leaf_count: usize, @@ -93,17 +40,7 @@ pub struct ProofEntry<'a>( Option<&'a HashWrapper>, ); -#[derive( - Debug, - PartialEq, - Eq, - Clone, - Serialize, - Deserialize, - AnchorDeserialize, - AnchorSerialize, - InitSpace, -)] +#[derive(Debug, PartialEq, Eq, Clone, AnchorDeserialize, AnchorSerialize, InitSpace)] pub struct OwnedProofEntry { target: HashWrapper, left_sibling: Option, @@ -134,18 +71,7 @@ impl<'a> From> for OwnedProofEntry { #[derive(Debug, Default, PartialEq, Eq)] pub struct Proof<'a>(Vec>); -#[derive( - Debug, - Default, - PartialEq, - Eq, - Clone, - AnchorDeserialize, - AnchorSerialize, - Deserialize, - Serialize, - InitSpace, -)] +#[derive(Debug, Default, PartialEq, Eq, Clone, AnchorDeserialize, AnchorSerialize, InitSpace)] pub struct OwnedProof { #[max_len(SOLANA_MAX_PROOFS_LEN)] entries: Vec, diff --git a/shared/coordinator/src/model_extra_data.rs b/shared/core/src/model_extra_data.rs similarity index 70% rename from shared/coordinator/src/model_extra_data.rs rename to shared/core/src/model_extra_data.rs index 627dceefc..c95575d8a 100644 --- a/shared/coordinator/src/model_extra_data.rs +++ b/shared/core/src/model_extra_data.rs @@ -1,8 +1,12 @@ -use anchor_lang::prelude::borsh; -use psyche_core::{FixedVec, LearningRateSchedule, OptimizerDefinition}; +use psyche_coordinator::fixed_vec::FixedVec; use serde::{Deserialize, Serialize}; -use crate::model::{LLMArchitecture, LLMTrainingDataLocation, LLMTrainingDataType}; +use crate::{ + ConstantLR, LearningRateSchedule, OptimizerDefinition, + coordinator::{ + HttpTrainingDataLocation, LLMArchitecture, LLMTrainingDataLocation, LLMTrainingDataType, + }, +}; /// Path within the bucket where config is stored pub const CONFIG_PREFIX: &str = "config"; @@ -42,7 +46,7 @@ impl Default for ModelExtraData { architecture: LLMArchitecture::HfLlama, data_type: LLMTrainingDataType::Pretraining, data_location: LLMTrainingDataLocation::default(), - lr_schedule: LearningRateSchedule::Constant(psyche_core::ConstantLR::default()), + lr_schedule: LearningRateSchedule::Constant(ConstantLR::default()), optimizer: OptimizerDefinition::Dummy, run_metadata: None, checkpoint: CheckpointData::default(), @@ -84,18 +88,15 @@ impl ModelExtraData { LLMTrainingDataLocation::Dummy => false, LLMTrainingDataLocation::Server(url) => url.is_empty(), LLMTrainingDataLocation::Local(_) => false, - LLMTrainingDataLocation::Http(http_loc) => { - use crate::model::HttpTrainingDataLocation; - match &http_loc.location { - HttpTrainingDataLocation::SingleUrl(url) => url.is_empty(), - HttpTrainingDataLocation::NumberedFiles { - url_template, - num_files, - .. - } => url_template.is_empty() || *num_files == 0, - HttpTrainingDataLocation::Gcp { bucket_name, .. } => bucket_name.is_empty(), - } - } + LLMTrainingDataLocation::Http(http_loc) => match &http_loc.location { + HttpTrainingDataLocation::SingleUrl(url) => url.is_empty(), + HttpTrainingDataLocation::NumberedFiles { + url_template, + num_files, + .. + } => url_template.is_empty() || *num_files == 0, + HttpTrainingDataLocation::Gcp { bucket_name, .. } => bucket_name.is_empty(), + }, LLMTrainingDataLocation::WeightedHttp(url) => url.is_empty(), LLMTrainingDataLocation::Preprocessed(url) => url.is_empty(), }; @@ -114,9 +115,7 @@ impl ModelExtraData { /// Off-chain checkpoint data that gets serialized into opaque bytes for on-chain storage. /// This decouples the on-chain account layout from storage backend details. -#[derive( - Debug, Clone, Serialize, Deserialize, borsh::BorshSerialize, borsh::BorshDeserialize, Default, -)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub enum CheckpointData { #[default] Dummy, @@ -134,19 +133,16 @@ pub const CHECKPOINT_DATA_MAX_LEN: usize = 256; impl CheckpointData { pub fn to_fixed_vec(&self) -> FixedVec { - let bytes = - borsh::to_vec(self).expect("CheckpointData borsh serialization should not fail"); - let mut fv = FixedVec::new(); - for b in bytes { - fv.push(b) - .expect("CheckpointData serialized size exceeds CHECKPOINT_DATA_MAX_LEN"); - } - fv + let bytes = postcard::to_stdvec(self) + .expect("CheckpointData postcard serialization should not fail"); + + FixedVec::try_from_iter(bytes) + .expect("CheckpointData serialized size exceeds CHECKPOINT_DATA_MAX_LEN") } pub fn from_fixed_vec( fv: &FixedVec, - ) -> Result { - borsh::BorshDeserialize::try_from_slice(&fv[..]) + ) -> Result { + postcard::from_bytes(&fv[..]) } } diff --git a/shared/core/src/serde_utils.rs b/shared/core/src/serde_utils.rs index 623959951..cfda24a02 100644 --- a/shared/core/src/serde_utils.rs +++ b/shared/core/src/serde_utils.rs @@ -1,66 +1,5 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; -pub fn serde_serialize_string( - run_id: &[u8], - serializer: S, -) -> std::result::Result -where - S: Serializer, -{ - // Convert bytes to string, trimming null bytes - let s = String::from_utf8_lossy(run_id) - .trim_matches(char::from(0)) - .to_string(); - serializer.serialize_str(&s) -} - -pub fn serde_deserialize_string<'de, D, const N: usize>( - deserializer: D, -) -> std::result::Result<[u8; N], D::Error> -where - D: Deserializer<'de>, -{ - let s: String = ::deserialize(deserializer)?; - let mut bytes = [0u8; N]; - let len = std::cmp::min(s.len(), N); - bytes[..len].copy_from_slice(&s.as_bytes()[..len]); - Ok(bytes) -} - -pub fn serde_serialize_optional_string( - str_bytes: &Option<[u8; N]>, - serializer: S, -) -> std::result::Result -where - S: Serializer, -{ - if let Some(run_id) = str_bytes { - let s = String::from_utf8_lossy(run_id) - .trim_matches(char::from(0)) - .to_string(); - serializer.serialize_some(&s) - } else { - serializer.serialize_none() - } -} - -pub fn serde_deserialize_optional_string<'de, D, const N: usize>( - deserializer: D, -) -> std::result::Result, D::Error> -where - D: Deserializer<'de>, -{ - let s: Option = Option::deserialize(deserializer)?; - if let Some(s) = s { - let mut bytes = [0u8; N]; - let len = std::cmp::min(s.len(), N); - bytes[..len].copy_from_slice(&s.as_bytes()[..len]); - Ok(Some(bytes)) - } else { - Ok(None) - } -} - pub fn serde_serialize_array_as_vec( array: &[T], serializer: S, @@ -84,59 +23,3 @@ where arr[..len].copy_from_slice(&vec[..len]); Ok(arr) } - -#[cfg(test)] -mod test { - use super::*; - - #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] - struct MyOptionalStrStruct { - #[serde( - serialize_with = "serde_serialize_optional_string", - deserialize_with = "serde_deserialize_optional_string", - default - )] - field: Option<[u8; 64]>, - } - - #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] - struct MyStrStruct { - #[serde( - serialize_with = "serde_serialize_string", - deserialize_with = "serde_deserialize_string" - )] - field: [u8; 64], - } - - #[test] - fn test_serialize_deserialize_optional_string_some() { - let my_struct = MyOptionalStrStruct { - field: Some([1u8; 64]), - }; - - let bytes = postcard::to_stdvec(&my_struct).unwrap(); - let deserialized_struct: MyOptionalStrStruct = postcard::from_bytes(&bytes).unwrap(); - - assert_eq!(my_struct, deserialized_struct); - } - - #[test] - fn test_serialize_deserialize_optional_string_none() { - let my_struct = MyOptionalStrStruct { field: None }; - - let bytes = postcard::to_stdvec(&my_struct).unwrap(); - let deserialized_struct: MyOptionalStrStruct = postcard::from_bytes(&bytes).unwrap(); - - assert_eq!(my_struct, deserialized_struct); - } - - #[test] - fn test_serialize_deserialize_string() { - let my_struct = MyStrStruct { field: [1u8; 64] }; - - let bytes = postcard::to_stdvec(&my_struct).unwrap(); - let deserialized_struct: MyStrStruct = postcard::from_bytes(&bytes).unwrap(); - - assert_eq!(my_struct, deserialized_struct); - } -} diff --git a/shared/core/src/sha256.rs b/shared/core/src/sha256.rs deleted file mode 100644 index 169b9c8a9..000000000 --- a/shared/core/src/sha256.rs +++ /dev/null @@ -1,56 +0,0 @@ -#[cfg(not(target_os = "solana"))] -use sha2::{Digest, Sha256}; - -#[cfg(target_os = "solana")] -use anchor_lang::solana_program::hash::{hash, hashv}; - -pub fn sha256(data: &[u8]) -> [u8; 32] { - #[cfg(not(target_os = "solana"))] - { - let mut hasher = Sha256::new(); - hasher.update(data); - hasher.finalize().into() - } - - #[cfg(target_os = "solana")] - { - let hash_result = hash(data); - hash_result.to_bytes() - } -} - -pub fn sha256v(data: &[&[u8]]) -> [u8; 32] { - #[cfg(not(target_os = "solana"))] - { - let mut hasher = Sha256::new(); - for val in data { - hasher.update(val) - } - hasher.finalize().into() - } - - #[cfg(target_os = "solana")] - { - let hash_result = hashv(data); - hash_result.to_bytes() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sha256() { - let data = b"Hello, world!"; - let hash = sha256(data); - assert_eq!( - hash, - [ - 0x31, 0x5f, 0x5b, 0xdb, 0x76, 0xd0, 0x78, 0xc4, 0x3b, 0x8a, 0xc0, 0x06, 0x4e, 0x4a, - 0x01, 0x64, 0x61, 0x2b, 0x1f, 0xce, 0x77, 0xc8, 0x69, 0x34, 0x5b, 0xfc, 0x94, 0xc7, - 0x58, 0x94, 0xed, 0xd3 - ] - ); - } -} diff --git a/shared/data-provider/examples/tcp.rs b/shared/data-provider/examples/tcp.rs index 48f452418..79064cebe 100644 --- a/shared/data-provider/examples/tcp.rs +++ b/shared/data-provider/examples/tcp.rs @@ -2,8 +2,10 @@ use anyhow::{Result, bail}; use async_trait::async_trait; use bytemuck::Zeroable; use futures::future::try_join_all; -use parquet::data_type::AsBytes; -use psyche_coordinator::{Coordinator, HealthChecks, model}; +use psyche_coordinator::{ + coordinator::{Coordinator, HealthChecks}, + model::CheckpointBytes, +}; use psyche_core::BatchId; use psyche_data_provider::{ DataProviderTcpClient, DataProviderTcpServer, LengthKnownDataProvider, TokenizedData, @@ -33,7 +35,7 @@ impl WatcherBackend for DummyBackend { bail!("Data provider does not send health check"); } - async fn send_checkpoint(&mut self, _checkpoint: model::CheckpointBytes) -> anyhow::Result<()> { + async fn send_checkpoint(&mut self, _checkpoint: CheckpointBytes) -> anyhow::Result<()> { bail!("Data provider does not send checkpoints"); } } diff --git a/shared/data-provider/src/http.rs b/shared/data-provider/src/http.rs index 860c26ca6..38cf7ba71 100644 --- a/shared/data-provider/src/http.rs +++ b/shared/data-provider/src/http.rs @@ -2,8 +2,7 @@ use std::{str::FromStr, time::Duration}; use anyhow::{Context, Result, anyhow, bail}; use futures::future::join_all; -use psyche_coordinator::model::HttpTrainingDataLocation; -use psyche_core::{BatchId, Shuffle, TokenSize}; +use psyche_core::{BatchId, HttpTrainingDataLocation, Shuffle, TokenSize}; use rand::seq::SliceRandom; use rand_chacha::ChaCha8Rng; use rand_chacha::rand_core::SeedableRng; @@ -187,30 +186,13 @@ impl FileURLs { n_left_pad_zeros, num_files, } => { - Self::from_template( - &String::from(url_template), - *start_index, - *n_left_pad_zeros, - *num_files, - ) - .await + Self::from_template(url_template, *start_index, *n_left_pad_zeros, *num_files).await } HttpTrainingDataLocation::SingleUrl(url) => Self::from_list(&[String::from(url)]).await, HttpTrainingDataLocation::Gcp { bucket_name, filter_directory, - } => { - let filter_directory = String::from(filter_directory); - Self::from_gcp_bucket( - &String::from(bucket_name), - if filter_directory.is_empty() { - None - } else { - Some(filter_directory) - }, - ) - .await - } + } => Self::from_gcp_bucket(bucket_name, filter_directory.clone()).await, } } } diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index 5972a00bb..9eb6bf60f 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -7,7 +7,7 @@ use hf_hub::{ tokio::{ApiError, UploadSource}, }, }; -use psyche_coordinator::model_extra_data::ModelExtraData; +use psyche_core::ModelExtraData; use std::{path::PathBuf, time::Instant}; use tracing::{debug, error, info}; diff --git a/shared/data-provider/src/remote/server.rs b/shared/data-provider/src/remote/server.rs index defe86125..b5f77d4d6 100644 --- a/shared/data-provider/src/remote/server.rs +++ b/shared/data-provider/src/remote/server.rs @@ -1,6 +1,6 @@ use anyhow::Result; use bytemuck::Zeroable; -use psyche_coordinator::Coordinator; +use psyche_coordinator::coordinator::Coordinator; use psyche_core::BatchId; use psyche_network::{ClientNotification, PublicKey, TcpServer}; use psyche_watcher::Backend; diff --git a/shared/data-provider/src/weighted/http.rs b/shared/data-provider/src/weighted/http.rs index 41abf4674..44c3bc147 100644 --- a/shared/data-provider/src/weighted/http.rs +++ b/shared/data-provider/src/weighted/http.rs @@ -1,5 +1,4 @@ -use psyche_coordinator::model::HttpLLMTrainingDataLocation; -use psyche_core::Shuffle; +use psyche_core::{HttpLLMTrainingDataLocation, Shuffle}; use crate::http::{FileURLs, HttpDataProvider}; @@ -75,8 +74,9 @@ pub enum HttpProviderConfigs { #[cfg(test)] mod tests { use anyhow::Result; - use psyche_coordinator::model::{HttpLLMTrainingDataLocation, HttpTrainingDataLocation}; - use psyche_core::{BatchId, Shuffle, TokenSize}; + use psyche_core::{ + BatchId, HttpLLMTrainingDataLocation, HttpTrainingDataLocation, Shuffle, TokenSize, + }; use std::{ fs::{self, File}, io::Write, @@ -148,7 +148,7 @@ mod tests { ( HttpLLMTrainingDataLocation { location: HttpTrainingDataLocation::NumberedFiles { - url_template: url_template.as_str().try_into().unwrap(), + url_template: url_template.as_str().into(), start_index: 0, n_left_pad_zeros: 3, num_files: files.len() as u32, diff --git a/shared/event-sourcing/src/events.rs b/shared/event-sourcing/src/events.rs index 8d1bd41e4..b7f079731 100644 --- a/shared/event-sourcing/src/events.rs +++ b/shared/event-sourcing/src/events.rs @@ -3,7 +3,7 @@ use derive_more::Display; use first_class_variants::first_class_variants; use iroh::EndpointId; use iroh_blobs::Hash as BlobHash; -use psyche_coordinator::RunState; +use psyche_coordinator::coordinator::RunState; use psyche_core::BatchId; use serde::{Deserialize, Serialize}; diff --git a/shared/event-sourcing/src/projection.rs b/shared/event-sourcing/src/projection.rs index 54473a274..33162b333 100644 --- a/shared/event-sourcing/src/projection.rs +++ b/shared/event-sourcing/src/projection.rs @@ -1,17 +1,16 @@ +use crate::events::{ + Client, Cooldown, CoordinatorEvent, Event, EventData, P2P, ResourceSnapshot, RpcCallType, + Train, Warmup, +}; use chrono::{DateTime, Utc}; use indexmap::IndexMap; use iroh::EndpointId; use iroh_blobs::Hash as BlobHash; -use psyche_coordinator::{RunState, model::CheckpointSource}; +use psyche_coordinator::{coordinator::RunState, model::CheckpointSource}; use psyche_core::BatchId; use psyche_metrics::SelectedPath; use std::collections::{BTreeMap, HashMap, HashSet}; -use crate::events::{ - Client, Cooldown, CoordinatorEvent, Event, EventData, P2P, ResourceSnapshot, RpcCallType, - Train, Warmup, -}; - // ── Coordinator ─────────────────────────────────────────────────────────────── #[derive(Debug, Clone)] diff --git a/shared/event-sourcing/src/store.rs b/shared/event-sourcing/src/store.rs index 59d5e0efb..9d48c3c3a 100644 --- a/shared/event-sourcing/src/store.rs +++ b/shared/event-sourcing/src/store.rs @@ -288,7 +288,7 @@ impl FileWriterState { mod tests { use super::*; use crate::{event, events::*}; - use psyche_coordinator::RunState; + use psyche_coordinator::coordinator::RunState; use psyche_core::{BatchId, ClosedInterval}; use serial_test::serial; use std::fs; diff --git a/shared/event-sourcing/src/timeline.rs b/shared/event-sourcing/src/timeline.rs index 856300a5b..d38efc204 100644 --- a/shared/event-sourcing/src/timeline.rs +++ b/shared/event-sourcing/src/timeline.rs @@ -5,7 +5,9 @@ use std::time::SystemTime; use chrono::{DateTime, Utc}; use indexmap::IndexMap; -use psyche_coordinator::{CommitteeSelection, Coordinator, assign_data_for_state}; +use psyche_coordinator::committee_selection::CommitteeSelection; +use psyche_coordinator::coordinator::Coordinator; +use psyche_core::assign_data_for_state; use crate::events::Event; use crate::projection::{ClusterProjection, ClusterSnapshot, CoordinatorStateSnapshot}; @@ -55,9 +57,8 @@ fn coordinator_to_snapshot( timestamp: DateTime, coord: &Coordinator, ) -> CoordinatorStateSnapshot { - let checkpoint_source = match coord.model { - psyche_coordinator::model::Model::LLM(llm) => llm.checkpoint_source, - }; + let checkpoint_source = coord.model.checkpoint_source; + let client_ids: Vec = coord .epoch_state .clients diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs index 32981f725..cf6b2b88e 100644 --- a/shared/watcher/src/traits.rs +++ b/shared/watcher/src/traits.rs @@ -1,5 +1,9 @@ use anyhow::Result; -use psyche_coordinator::{Coordinator, HealthChecks, Witness, WitnessMetadata, model}; +use psyche_coordinator::{ + coordinator::{Coordinator, HealthChecks, Witness}, + model::CheckpointBytes, +}; +use psyche_core::WitnessMetadata; use serde::{Deserialize, Serialize}; #[allow(clippy::large_enum_variant)] @@ -28,5 +32,5 @@ pub trait Backend: Send + Sync { async fn wait_for_new_state(&mut self) -> Result; async fn send_witness(&mut self, opportunistic_data: OpportunisticData) -> Result<()>; async fn send_health_check(&mut self, health_check: HealthChecks) -> Result<()>; - async fn send_checkpoint(&mut self, checkpoint: model::CheckpointBytes) -> Result<()>; + async fn send_checkpoint(&mut self, checkpoint: CheckpointBytes) -> Result<()>; } diff --git a/shared/watcher/src/tui.rs b/shared/watcher/src/tui.rs index fe525f5a7..fb0d06ce2 100644 --- a/shared/watcher/src/tui.rs +++ b/shared/watcher/src/tui.rs @@ -1,15 +1,15 @@ -use std::{ - fmt::{Display, Formatter}, - time::{Duration, Instant, SystemTime, UNIX_EPOCH}, -}; - -use psyche_coordinator::{Coordinator, RunState, model::Model}; +use psyche_coordinator::coordinator::{Coordinator, RunState}; +use psyche_core::CheckpointData; use psyche_tui::ratatui::{ buffer::Buffer, layout::{Constraint, Layout, Rect}, text::Line, widgets::{Block, Paragraph, Widget}, }; +use std::{ + fmt::{Display, Formatter}, + time::{Duration, Instant, SystemTime, UNIX_EPOCH}, +}; #[derive(Default, Debug)] pub struct CoordinatorTui; @@ -182,14 +182,14 @@ impl From<&Coordinator> for CoordinatorTuiState { .iter() .map(|c| format!("{:?}", c.id)) .collect(), - model_checkpoint: match &value.model { - Model::LLM(l) => format!( - "{} ({})", - l.checkpoint_source, - l.decode_checkpoint() - .map_or("unknown".to_string(), |d| format!("{:?}", d)) - ), - }, + model_checkpoint: format!( + "{} ({})", + value.model.checkpoint_source, + CheckpointData::from_fixed_vec(&value.model.checkpoint_data) + .ok() + .map_or("unknown".to_string(), |d| format!("{:?}", d)) + ), + exited_clients: value.epoch_state.exited_clients.len(), pending_pause: value.pending_pause.is_true(), } diff --git a/shared/watcher/src/watcher.rs b/shared/watcher/src/watcher.rs index 1c630e640..a99d641c4 100644 --- a/shared/watcher/src/watcher.rs +++ b/shared/watcher/src/watcher.rs @@ -1,6 +1,6 @@ use crate::traits::Backend; use anyhow::Result; -use psyche_coordinator::{Client, Coordinator, RunState}; +use psyche_coordinator::coordinator::{Client, Coordinator, RunState}; use std::collections::HashMap; use std::hash::{DefaultHasher, Hasher}; diff --git a/tools/rust-tools/preview-lr/Cargo.toml b/tools/rust-tools/preview-lr/Cargo.toml index 3d9fd18ca..8ad9c4f3d 100644 --- a/tools/rust-tools/preview-lr/Cargo.toml +++ b/tools/rust-tools/preview-lr/Cargo.toml @@ -8,6 +8,7 @@ anyhow.workspace = true clap.workspace = true clap-markdown.workspace = true plotters = "0.3.7" -psyche-coordinator.workspace = true +psyche-coordinator = { workspace = true, features = ["client"] } +psyche-core.workspace = true serde.workspace = true toml.workspace = true diff --git a/tools/rust-tools/preview-lr/src/main.rs b/tools/rust-tools/preview-lr/src/main.rs index 51c52a3c6..9ea555363 100644 --- a/tools/rust-tools/preview-lr/src/main.rs +++ b/tools/rust-tools/preview-lr/src/main.rs @@ -1,6 +1,7 @@ use clap::Parser; use plotters::prelude::*; -use psyche_coordinator::{CoordinatorConfig, model_extra_data::ModelExtraData}; +use psyche_coordinator::coordinator::CoordinatorConfig; +use psyche_core::ModelExtraData; use serde::Deserialize; use std::path::PathBuf; diff --git a/tools/rust-tools/run-manager/src/commands/can_join.rs b/tools/rust-tools/run-manager/src/commands/can_join.rs index 2f8e00137..b60951383 100644 --- a/tools/rust-tools/run-manager/src/commands/can_join.rs +++ b/tools/rust-tools/run-manager/src/commands/can_join.rs @@ -4,8 +4,8 @@ use anyhow::Result; use anyhow::bail; use async_trait::async_trait; use clap::Args; -use psyche_coordinator::RunState; +use psyche_coordinator::coordinator::RunState; use psyche_solana_rpc::SolanaBackend; #[derive(Debug, Clone, Args)] diff --git a/tools/rust-tools/run-manager/src/commands/run/checkpoint.rs b/tools/rust-tools/run-manager/src/commands/run/checkpoint.rs index 062f77271..77b76684a 100644 --- a/tools/rust-tools/run-manager/src/commands/run/checkpoint.rs +++ b/tools/rust-tools/run-manager/src/commands/run/checkpoint.rs @@ -2,8 +2,8 @@ use crate::commands::Command; use anyhow::Result; use async_trait::async_trait; use clap::Args; -use psyche_coordinator::model_extra_data::CheckpointData; +use psyche_core::CheckpointData; use psyche_solana_rpc::SolanaBackend; use psyche_solana_rpc::instructions; diff --git a/tools/rust-tools/run-manager/src/commands/run/create_run.rs b/tools/rust-tools/run-manager/src/commands/run/create_run.rs index 8408379e8..e531dc0cf 100644 --- a/tools/rust-tools/run-manager/src/commands/run/create_run.rs +++ b/tools/rust-tools/run-manager/src/commands/run/create_run.rs @@ -9,7 +9,7 @@ use anyhow::Result; use anyhow::bail; use async_trait::async_trait; use clap::Args; -use psyche_coordinator::SOLANA_RUN_ID_MAX_LEN; +use psyche_coordinator::coordinator::SOLANA_RUN_ID_MAX_LEN; use crate::commands::Command; use psyche_solana_rpc::SolanaBackend; diff --git a/tools/rust-tools/run-manager/src/commands/run/update_config.rs b/tools/rust-tools/run-manager/src/commands/run/update_config.rs index 0bec159d5..253098530 100644 --- a/tools/rust-tools/run-manager/src/commands/run/update_config.rs +++ b/tools/rust-tools/run-manager/src/commands/run/update_config.rs @@ -1,21 +1,22 @@ use crate::commands::Command; use async_trait::async_trait; +use psyche_coordinator::{ + coordinator::{CoordinatorConfig, CoordinatorProgress}, + model::{CheckpointSource, Model}, +}; +use psyche_core::{ + CONFIG_PREFIX, CheckpointData, MODEL_CONFIG_FILENAME, ModelExtraData, get_data_index_for_step, +}; use std::path::PathBuf; +use crate::{SolanaBackend, instructions}; use anyhow::{Context, Result, bail}; use clap::Args; -use psyche_coordinator::{ - CoordinatorConfig, CoordinatorProgress, get_data_index_for_step, - model::{CheckpointSource, Model}, - model_extra_data::{CONFIG_PREFIX, CheckpointData, MODEL_CONFIG_FILENAME, ModelExtraData}, -}; use psyche_data_provider::upload_json_to_gcs; use psyche_solana_treasurer::logic::RunUpdateParams; use serde::{Deserialize, Serialize}; use tracing::info; -use crate::{SolanaBackend, instructions}; - #[derive(Debug, Clone, Args)] #[command()] pub struct CommandUpdateConfig { @@ -91,18 +92,16 @@ impl Command for CommandUpdateConfig { }; if let (Some(ref model_extra_data), Some(ref mut model)) = (&model_extra_data, &mut model) { - let Model::LLM(llm) = model; - llm.checkpoint_data = model_extra_data.checkpoint.to_fixed_vec(); - llm.checkpoint_source = CheckpointSource::Stored; + model.checkpoint_data = model_extra_data.checkpoint.to_fixed_vec(); + model.checkpoint_source = CheckpointSource::Stored; } model = if switch_to_hub { - let Model::LLM(mut llm) = - model.unwrap_or(coordinator_account_state.state.coordinator.model); - if llm.checkpoint_source == CheckpointSource::P2P { - llm.checkpoint_source = CheckpointSource::Stored; + let mut model = model.unwrap_or(coordinator_account_state.state.coordinator.model); + if model.checkpoint_source == CheckpointSource::P2P { + model.checkpoint_source = CheckpointSource::Stored; } - Some(Model::LLM(llm)) + Some(model) } else { model }; @@ -119,8 +118,8 @@ impl Command for CommandUpdateConfig { // Upload model extra data to GCS or hub repo depending of the model checkpoint if !skip_upload_model_extra_data { if let Some(model_extra_data) = model_extra_data { - let Model::LLM(llm) = &coordinator_account_state.state.coordinator.model; - match llm.decode_checkpoint() { + let llm = &coordinator_account_state.state.coordinator.model; + match CheckpointData::from_fixed_vec(&llm.checkpoint_data).ok() { Some(CheckpointData::Gcs { bucket, .. }) => { let path = format!("{}/{}", CONFIG_PREFIX, MODEL_CONFIG_FILENAME); info!("Uploading model extra data to gs://{}/{}", bucket, path); diff --git a/tools/rust-tools/run-manager/src/docker/coordinator_client.rs b/tools/rust-tools/run-manager/src/docker/coordinator_client.rs index a116b3932..2fa826794 100644 --- a/tools/rust-tools/run-manager/src/docker/coordinator_client.rs +++ b/tools/rust-tools/run-manager/src/docker/coordinator_client.rs @@ -3,9 +3,8 @@ use anchor_client::solana_sdk::{ }; use anchor_lang::AccountDeserialize; use anyhow::{Context, Result}; -use psyche_coordinator::RunState; -use psyche_coordinator::model::Model; -use psyche_coordinator::model_extra_data::CheckpointData; +use psyche_coordinator::coordinator::RunState; +use psyche_core::CheckpointData; use psyche_solana_authorizer::state::Authorization; use psyche_solana_coordinator::{ CoordinatorInstance, coordinator_account_from_bytes, find_coordinator_instance, @@ -227,7 +226,7 @@ impl CoordinatorClient { let coordinator_account = coordinator_account_from_bytes(&coordinator_account_data.data) .context("Failed to deserialize CoordinatorAccount")?; - let Model::LLM(ref llm) = coordinator_account.state.coordinator.model; + let llm = coordinator_account.state.coordinator.model; CheckpointData::from_fixed_vec(&llm.checkpoint_data) .map_err(|e| anyhow::anyhow!("Failed to decode checkpoint data: {e}")) } diff --git a/tools/rust-tools/run-manager/src/docker/manager.rs b/tools/rust-tools/run-manager/src/docker/manager.rs index 7377c73a8..fd02623a8 100644 --- a/tools/rust-tools/run-manager/src/docker/manager.rs +++ b/tools/rust-tools/run-manager/src/docker/manager.rs @@ -2,7 +2,8 @@ use anchor_client::solana_sdk::bs58; use anchor_client::solana_sdk::pubkey::Pubkey; use anchor_client::solana_sdk::signature::{EncodableKey, Keypair, Signer}; use anyhow::{Context, Result, anyhow, bail}; -use psyche_coordinator::model_extra_data::CheckpointData; +use psyche_coordinator::coordinator::RunState; +use psyche_core::CheckpointData; use std::io::{BufRead, BufReader, Cursor}; use std::path::PathBuf; use std::process::{Command, Stdio}; @@ -14,7 +15,6 @@ use crate::docker::coordinator_client::CoordinatorClient; use crate::get_env_var; use crate::load_and_apply_env_file; use crate::load_wallet_key; -use psyche_coordinator::RunState; const RETRY_DELAY_SECS: u64 = 5; const VERSION_MISMATCH_EXIT_CODE: i32 = 10; diff --git a/tools/rust-tools/run-manager/tests/common/mod.rs b/tools/rust-tools/run-manager/tests/common/mod.rs index 0aec6e196..29266ddd7 100644 --- a/tools/rust-tools/run-manager/tests/common/mod.rs +++ b/tools/rust-tools/run-manager/tests/common/mod.rs @@ -9,7 +9,7 @@ use anchor_client::{ }, }; use anyhow::{Context, Result, bail}; -use psyche_coordinator::RunState; +use psyche_coordinator::coordinator::RunState; use psyche_solana_rpc::SolanaBackend; use std::sync::Arc; use std::{ diff --git a/tools/rust-tools/run-manager/tests/integration_tests.rs b/tools/rust-tools/run-manager/tests/integration_tests.rs index edb623d39..7c0cacbe8 100644 --- a/tools/rust-tools/run-manager/tests/integration_tests.rs +++ b/tools/rust-tools/run-manager/tests/integration_tests.rs @@ -8,7 +8,7 @@ use anchor_client::{ solana_sdk::{commitment_config::CommitmentConfig, signature::Signer}, }; use common::{TestClient, TestValidator, create_test_keypair}; -use psyche_coordinator::RunState; +use psyche_coordinator::coordinator::RunState; use psyche_solana_rpc::SolanaBackend; use run_manager::commands::{ Command, diff --git a/website/wasm/Cargo.toml b/website/wasm/Cargo.toml index 4b64f7472..fc25d9696 100644 --- a/website/wasm/Cargo.toml +++ b/website/wasm/Cargo.toml @@ -7,11 +7,14 @@ version.workspace = true crate-type = ["cdylib"] [dependencies] -psyche-solana-coordinator = { path = "../../architectures/decentralized/solana-coordinator/programs/solana-coordinator" } -psyche-coordinator.workspace = true +psyche-solana-coordinator = { path = "../../architectures/decentralized/solana-coordinator/programs/solana-coordinator", features = [ + "client", +] } +psyche-coordinator = { workspace = true, features = ["client"] } anchor-lang.workspace = true serde.workspace = true serde-wasm-bindgen = "0.6.5" wasm-bindgen = "=0.2.108" ts-rs.workspace = true psyche-core.workspace = true +postcard.workspace = true diff --git a/website/wasm/src/lib.rs b/website/wasm/src/lib.rs index 79d668d9a..f374a0bb3 100644 --- a/website/wasm/src/lib.rs +++ b/website/wasm/src/lib.rs @@ -1,7 +1,5 @@ -use psyche_coordinator::model::LLMArchitecture; -use psyche_coordinator::model_extra_data::CheckpointData; -use psyche_core::LearningRateSchedule; -use psyche_core::NodeIdentity; +use psyche_coordinator::node_identity::NodeIdentity; +use psyche_core::{CheckpointData, LLMArchitecture, LearningRateSchedule}; use psyche_solana_coordinator::{CoordinatorAccount, coordinator_account_from_bytes}; use serde::ser::Serialize; use ts_rs::TS; @@ -32,12 +30,11 @@ pub fn lr_at_step( Ok(lr.get_lr(step)) } -/// Decode borsh-serialized checkpoint_data bytes into a CheckpointData JSON value. +/// Decode postcard-serialized checkpoint_data bytes into a CheckpointData JSON value. /// Returns null for Dummy or if decoding fails. #[wasm_bindgen] pub fn decode_checkpoint_data(bytes: Vec) -> JsValue { - use anchor_lang::prelude::borsh::BorshDeserialize; - let Ok(data) = CheckpointData::try_from_slice(&bytes) else { + let Ok(data) = postcard::from_bytes(&bytes) else { return JsValue::NULL; }; match &data {