Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 31 additions & 29 deletions crates/synapse-infra/src/adapters/diffusion_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,13 @@ impl DiffusionAdapter {
};
let retina = HolographicRetina::with_config(genesis.clone(), retina_config);

// Initialize U-Net
// Initialize U-Net with 4 channels to match Holographic Transport standard
let unet_config = UNetConfig {
in_channels: 4,
out_channels: 4,
base_channels: 64,
layers: 2, // Lightweight
context_dim: Some(AXIOM_DIMENSION),
};
let mut unet = UNetAdapter::new(unet_config, &device).map_err(AdapterError::Candle)?;
unet.init_random().map_err(AdapterError::Candle)?;
Expand Down Expand Up @@ -334,9 +335,6 @@ impl DiffusionAdapter {
}

/// Perform one diffusion step (denoising)
///
/// This is a simplified DDPM-style step. In production, you'd use
/// a trained U-Net or Transformer model for noise prediction.
pub async fn step(&self) -> Result<DiffusionStepResult, AdapterError> {
let mut state = self.state.lock().await;

Expand Down Expand Up @@ -369,26 +367,32 @@ impl DiffusionAdapter {

// --- REAL NEURAL NETWORK PREDICTION ---
// Reshape latent to (1, 4, H, W) for U-Net
// Assuming latent_dim = 512 -> 4 * 8 * 16
let channels = 4;
let spatial_size = current_latent.len() / channels;
// Try to find reasonable dimensions
let h = if spatial_size > 0 { (spatial_size as f32).sqrt() as usize } else { 1 };
let w = if h > 0 { spatial_size / h } else { 1 };

// Fallback if dimensions don't match perfectly or are zero
let shape = if channels * h * w == current_latent.len() && current_latent.len() > 0 {
(1, channels, h, w)
let total_elements = current_latent.len();

// Find best square-ish dimensions for the given channels
let spatial_elements = total_elements / channels;
let h = if spatial_elements > 0 { (spatial_elements as f32).sqrt().floor() as usize } else { 1 };
let w = if h > 0 { spatial_elements / h } else { 1 };

let latent_tensor = if channels * h * w == total_elements {
Tensor::from_vec(current_latent.clone(), (1, channels, h, w), &self.device)
.map_err(AdapterError::Candle)?
} else {
(1, 1, 1, current_latent.len())
// Fallback to simple reshaping if not divisible
Tensor::from_vec(current_latent.clone(), (1, 1, 1, total_elements), &self.device)
.map_err(AdapterError::Candle)?
};

let latent_tensor = Tensor::from_vec(current_latent.clone(), shape, &self.device)
.map_err(AdapterError::Candle)?;
let timestep_tensor = Tensor::new(&[t as f32], &self.device)
.map_err(AdapterError::Candle)?;

let noise_pred_tensor = self.unet.predict_noise(&latent_tensor, &timestep_tensor)
// Conditioning: Project current latent to AXIOM_DIMENSION as semantic context
let mut conditioning = vec![0.0f32; AXIOM_DIMENSION];
let copy_len = current_latent.len().min(AXIOM_DIMENSION);
conditioning[..copy_len].copy_from_slice(&current_latent[..copy_len]);

let noise_pred_tensor = self.unet.predict_noise_conditioned(&latent_tensor, &timestep_tensor, &conditioning)
.map_err(AdapterError::Candle)?;

let mut predicted_noise = noise_pred_tensor.flatten_all()
Expand Down Expand Up @@ -575,6 +579,7 @@ mod tests {
fn create_test_adapter() -> DiffusionAdapter {
let config = DiffusionConfig {
num_steps: 10, // Reduced for tests
latent_dim: 256, // 4 * 8 * 8
..Default::default()
};
DiffusionAdapter::new(config).unwrap()
Expand All @@ -584,13 +589,13 @@ mod tests {
async fn test_init_from_noise() {
let adapter = create_test_adapter();
let latent = adapter.init_from_noise(12345).await.unwrap();
assert_eq!(latent.len(), 512);
assert_eq!(latent.len(), 256);
}

#[tokio::test]
async fn test_guidance_scale_computation() {
let adapter = create_test_adapter();
let latent: Vec<f32> = (0..512).map(|i| (i as f32 * 0.01).sin()).collect();
let latent: Vec<f32> = (0..256).map(|i| (i as f32 * 0.01).sin()).collect();
let scale = adapter.compute_guidance_scale(&latent).unwrap();
assert!(scale >= adapter.config.guidance_min);
assert!(scale <= adapter.config.guidance_max);
Expand All @@ -601,16 +606,14 @@ mod tests {
let adapter = create_test_adapter();
adapter.init_from_noise(54321).await.unwrap();

// Just verify we can take steps and get valid data back
// Note: This test may have numerical stability issues in some environments
match adapter.step().await {
Ok(result) => {
assert_eq!(result.latent.len(), 512);
assert_eq!(result.latent.len(), 256);
assert!(result.guidance_scale > 0.0);
}
Err(e) => {
// Log but don't fail - numerical issues may occur
eprintln!("Step error (may be expected): {:?}", e);
eprintln!("Step error: {:?}", e);
panic!("Step failed");
}
}
}
Expand All @@ -634,25 +637,24 @@ mod tests {
let adapter = create_test_adapter();
adapter.init_from_noise(11111).await.unwrap();

// Basic resonance check should work
match adapter.check_resonance().await {
Ok(resonance) => {
assert!(resonance.total_resonance >= 0.0 && resonance.total_resonance <= 1.0);
}
Err(e) => {
eprintln!("Resonance check error (may be expected): {:?}", e);
eprintln!("Resonance check error: {:?}", e);
panic!("Resonance check failed");
}
}
}

#[tokio::test]
async fn test_maternal_gradient() {
let adapter = create_test_adapter();
let latent: Vec<f32> = (0..512).map(|i| (i as f32 * 0.02).sin()).collect();
let latent: Vec<f32> = (0..256).map(|i| (i as f32 * 0.02).sin()).collect();
let gradient = adapter.compute_maternal_gradient(&latent).unwrap();

assert_eq!(gradient.len(), 512);
// Gradient should be non-zero
assert_eq!(gradient.len(), 256);
let sum: f32 = gradient.iter().map(|x| x.abs()).sum();
assert!(sum > 0.0);
}
Expand Down
Loading
Loading