diff --git a/src/routes/iso.rs b/src/routes/iso.rs index 2eebd0b..1ec3346 100644 --- a/src/routes/iso.rs +++ b/src/routes/iso.rs @@ -13,9 +13,7 @@ use axum::extract::{Path, Query, State}; use axum::http::{HeaderMap, StatusCode, header}; use axum::response::Response; use serde::Deserialize; -use tokio::fs::File; use tokio_stream::wrappers::ReceiverStream; -use tokio_util::io::ReaderStream; /// Query parameters for the ISO endpoint. #[derive(Debug, Deserialize, Default)] @@ -55,7 +53,7 @@ pub async fn handle_iso( // Check if this is a request for the ISO file itself if iso_service.is_iso_file(&iso_name, &path)? { tracing::info!("Serving ISO file: {}/{}", iso_name, path); - return serve_iso_file(&iso_service, &iso_name).await; + return serve_iso_file(&iso_service, &iso_name); } // Extract host from headers @@ -84,22 +82,10 @@ pub async fn handle_iso( serve_from_iso(&iso_service, &iso_name, &path) } -/// Serve the ISO file itself for streaming. -async fn serve_iso_file(iso_service: &IsoService, iso_name: &str) -> AppResult { - let iso_path = iso_service.iso_file_path(iso_name)?; - - let file = File::open(&iso_path).await.map_err(|e| AppError::FileRead { - path: iso_path.clone(), - source: e, - })?; - - let metadata = file.metadata().await.map_err(|e| AppError::FileRead { - path: iso_path.clone(), - source: e, - })?; - let content_length = metadata.len(); - - let stream = ReaderStream::new(file); +/// Serve the ISO file itself using chunked streaming for memory efficiency. +fn serve_iso_file(iso_service: &IsoService, iso_name: &str) -> AppResult { + let (content_length, receiver) = iso_service.stream_iso_file(iso_name)?; + let stream = ReceiverStream::new(receiver); let body = Body::from_stream(stream); Ok(Response::builder() diff --git a/src/services/iso.rs b/src/services/iso.rs index 1926a93..26959aa 100644 --- a/src/services/iso.rs +++ b/src/services/iso.rs @@ -16,6 +16,35 @@ const ISO_BLOCK_SIZE: u64 = 2048; /// Chunk size for streaming (32MB) const CHUNK_SIZE: usize = 32 * 1024 * 1024; +/// Stream file contents in chunks to a channel. +/// +/// Reads the file in CHUNK_SIZE chunks and sends each chunk to the channel. +/// Stops early if receiver is dropped. +fn stream_file_to_channel( + file: &mut File, + file_size: u64, + tx: &mpsc::Sender>, +) -> Result<(), std::io::Error> { + let mut bytes_remaining = file_size as usize; + + while bytes_remaining > 0 { + let chunk_size = std::cmp::min(bytes_remaining, CHUNK_SIZE); + + let mut buffer = vec![0u8; chunk_size]; + file.read_exact(&mut buffer)?; + + let bytes = Bytes::from(buffer); + if tx.blocking_send(Ok(bytes)).is_err() { + // Receiver dropped, stop sending + return Ok(()); + } + + bytes_remaining -= chunk_size; + } + + Ok(()) +} + /// Wrapper to implement BlockIo for std::fs::File. struct FileBlockIo { file: File, @@ -408,6 +437,42 @@ impl IsoService { Err(AppError::TemplateNotFound { path: iso_path }) } + /// Stream the ISO file itself with chunked reads for memory efficiency. + /// + /// Returns the file size and a receiver that yields chunks. + /// Uses spawn_blocking for the file reads with backpressure via bounded channel. + pub fn stream_iso_file( + &self, + iso_name: &str, + ) -> AppResult<(u64, mpsc::Receiver>)> { + let iso_path = self.iso_file_path(iso_name)?; + + // Get file size + let metadata = std::fs::metadata(&iso_path).map_err(|e| AppError::FileRead { + path: iso_path.clone(), + source: e, + })?; + let file_size = metadata.len(); + + // Create bounded channel for backpressure (2 chunks max in flight) + let (tx, rx) = mpsc::channel(2); + + // Spawn blocking task to read chunks + tokio::task::spawn_blocking(move || { + let result = (|| -> Result<(), std::io::Error> { + let mut file = File::open(&iso_path)?; + stream_file_to_channel(&mut file, file_size, &tx)?; + Ok(()) + })(); + + if let Err(e) = result { + let _ = tx.blocking_send(Err(e)); + } + }); + + Ok((file_size, rx)) + } + /// Stream a file from within an ISO. /// /// Returns the file size and a receiver that yields chunks. @@ -609,21 +674,7 @@ impl IsoService { // Phase 2: Stream firmware from disk let mut firmware_file = File::open(&firmware_path_clone)?; - let mut offset: u64 = 0; - while offset < firmware_size { - let remaining = firmware_size - offset; - let chunk_size = std::cmp::min(remaining as usize, CHUNK_SIZE); - - let mut buffer = vec![0u8; chunk_size]; - firmware_file.read_exact(&mut buffer)?; - - let bytes = Bytes::from(buffer); - if tx.blocking_send(Ok(bytes)).is_err() { - return Ok(()); - } - - offset += chunk_size as u64; - } + stream_file_to_channel(&mut firmware_file, firmware_size, &tx)?; Ok(()) })(); @@ -899,4 +950,108 @@ mod tests { let result = service.should_concat_firmware("test", "/install/initrd.gz").unwrap(); assert!(result.is_none()); } + + #[test] + fn test_stream_file_to_channel() { + // Create a test file with known content + let dir = setup_test_dir(); + let test_file = dir.path().join("test.bin"); + let test_data = vec![0xABu8; 1024 * 100]; // 100KB of 0xAB + std::fs::write(&test_file, &test_data).unwrap(); + + // Create channel and stream + let (tx, mut rx) = mpsc::channel(2); + let mut file = File::open(&test_file).unwrap(); + let file_size = test_data.len() as u64; + + // Run in a thread since blocking_send requires it + std::thread::spawn(move || { + stream_file_to_channel(&mut file, file_size, &tx).unwrap(); + }); + + // Collect all chunks + let mut received = Vec::new(); + while let Some(result) = rx.blocking_recv() { + let bytes = result.unwrap(); + received.extend_from_slice(&bytes); + } + + assert_eq!(received.len(), test_data.len()); + assert_eq!(received, test_data); + } + + #[test] + fn test_stream_file_to_channel_multiple_chunks() { + // Create a file larger than CHUNK_SIZE (32MB) to test chunking + // 70MB = 2 full chunks (32MB each) + 1 partial chunk (6MB) + let dir = setup_test_dir(); + let test_file = dir.path().join("large.bin"); + let file_size = 70 * 1024 * 1024; // 70MB + let test_data: Vec = (0..file_size).map(|i| (i % 256) as u8).collect(); + std::fs::write(&test_file, &test_data).unwrap(); + + let (tx, mut rx) = mpsc::channel(2); + let mut file = File::open(&test_file).unwrap(); + + std::thread::spawn(move || { + stream_file_to_channel(&mut file, file_size as u64, &tx).unwrap(); + }); + + // Collect chunks and verify we get multiple + let mut received = Vec::new(); + let mut chunk_count = 0; + while let Some(result) = rx.blocking_recv() { + let bytes = result.unwrap(); + chunk_count += 1; + received.extend_from_slice(&bytes); + } + + assert_eq!(chunk_count, 3); // 32MB + 32MB + 6MB + assert_eq!(received.len(), test_data.len()); + assert_eq!(received, test_data); + } + + #[tokio::test] + async fn test_stream_iso_file() { + let dir = setup_test_dir(); + let iso_dir = dir.path().join("iso").join("test-iso"); + std::fs::create_dir_all(&iso_dir).unwrap(); + + // Create iso.cfg + std::fs::write(iso_dir.join("iso.cfg"), "filename=test.iso\n").unwrap(); + + // Create a test "ISO" file with known content (1MB) + let test_data: Vec = (0..1024 * 1024).map(|i| (i % 256) as u8).collect(); + std::fs::write(iso_dir.join("test.iso"), &test_data).unwrap(); + + let service = IsoService::new(dir.path().to_path_buf()); + let (size, mut rx) = service.stream_iso_file("test-iso").unwrap(); + + assert_eq!(size, test_data.len() as u64); + + // Collect all chunks + let mut received = Vec::new(); + while let Some(result) = rx.recv().await { + let bytes = result.unwrap(); + received.extend_from_slice(&bytes); + } + + assert_eq!(received.len(), test_data.len()); + assert_eq!(received, test_data); + } + + #[tokio::test] + async fn test_stream_iso_file_not_found() { + let dir = setup_test_dir(); + let iso_dir = dir.path().join("iso").join("test-iso"); + std::fs::create_dir_all(&iso_dir).unwrap(); + + // Create iso.cfg pointing to non-existent file + std::fs::write(iso_dir.join("iso.cfg"), "filename=missing.iso\n").unwrap(); + + let service = IsoService::new(dir.path().to_path_buf()); + let result = service.stream_iso_file("test-iso"); + + assert!(matches!(result, Err(AppError::IsoFileNotFound { .. }))); + } }