Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions src/routes/iso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ use axum::extract::{Path, Query, State};
use axum::http::{HeaderMap, StatusCode, header};
use axum::response::Response;
use serde::Deserialize;
use tokio::fs::File;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::io::ReaderStream;

/// Query parameters for the ISO endpoint.
#[derive(Debug, Deserialize, Default)]
Expand Down Expand Up @@ -55,7 +53,7 @@ pub async fn handle_iso(
// Check if this is a request for the ISO file itself
if iso_service.is_iso_file(&iso_name, &path)? {
tracing::info!("Serving ISO file: {}/{}", iso_name, path);
return serve_iso_file(&iso_service, &iso_name).await;
return serve_iso_file(&iso_service, &iso_name);
}

// Extract host from headers
Expand Down Expand Up @@ -84,22 +82,10 @@ pub async fn handle_iso(
serve_from_iso(&iso_service, &iso_name, &path)
}

/// Serve the ISO file itself for streaming.
async fn serve_iso_file(iso_service: &IsoService, iso_name: &str) -> AppResult<Response> {
let iso_path = iso_service.iso_file_path(iso_name)?;

let file = File::open(&iso_path).await.map_err(|e| AppError::FileRead {
path: iso_path.clone(),
source: e,
})?;

let metadata = file.metadata().await.map_err(|e| AppError::FileRead {
path: iso_path.clone(),
source: e,
})?;
let content_length = metadata.len();

let stream = ReaderStream::new(file);
/// Serve the ISO file itself using chunked streaming for memory efficiency.
fn serve_iso_file(iso_service: &IsoService, iso_name: &str) -> AppResult<Response> {
let (content_length, receiver) = iso_service.stream_iso_file(iso_name)?;
let stream = ReceiverStream::new(receiver);
let body = Body::from_stream(stream);

Ok(Response::builder()
Expand Down
185 changes: 170 additions & 15 deletions src/services/iso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,35 @@ const ISO_BLOCK_SIZE: u64 = 2048;
/// Chunk size for streaming (32MB)
const CHUNK_SIZE: usize = 32 * 1024 * 1024;

/// Stream file contents in chunks to a channel.
///
/// Reads the file in CHUNK_SIZE chunks and sends each chunk to the channel.
/// Stops early if receiver is dropped.
fn stream_file_to_channel(
file: &mut File,
file_size: u64,
tx: &mpsc::Sender<Result<Bytes, std::io::Error>>,
) -> Result<(), std::io::Error> {
let mut bytes_remaining = file_size as usize;

while bytes_remaining > 0 {
let chunk_size = std::cmp::min(bytes_remaining, CHUNK_SIZE);

let mut buffer = vec![0u8; chunk_size];
file.read_exact(&mut buffer)?;

let bytes = Bytes::from(buffer);
if tx.blocking_send(Ok(bytes)).is_err() {
// Receiver dropped, stop sending
return Ok(());
}

bytes_remaining -= chunk_size;
}

Ok(())
}

/// Wrapper to implement BlockIo for std::fs::File.
struct FileBlockIo {
file: File,
Expand Down Expand Up @@ -408,6 +437,42 @@ impl IsoService {
Err(AppError::TemplateNotFound { path: iso_path })
}

/// Stream the ISO file itself with chunked reads for memory efficiency.
///
/// Returns the file size and a receiver that yields chunks.
/// Uses spawn_blocking for the file reads with backpressure via bounded channel.
pub fn stream_iso_file(
&self,
iso_name: &str,
) -> AppResult<(u64, mpsc::Receiver<Result<Bytes, std::io::Error>>)> {
let iso_path = self.iso_file_path(iso_name)?;

// Get file size
let metadata = std::fs::metadata(&iso_path).map_err(|e| AppError::FileRead {
path: iso_path.clone(),
source: e,
})?;
let file_size = metadata.len();

// Create bounded channel for backpressure (2 chunks max in flight)
let (tx, rx) = mpsc::channel(2);

// Spawn blocking task to read chunks
tokio::task::spawn_blocking(move || {
let result = (|| -> Result<(), std::io::Error> {
let mut file = File::open(&iso_path)?;
stream_file_to_channel(&mut file, file_size, &tx)?;
Ok(())
})();

if let Err(e) = result {
let _ = tx.blocking_send(Err(e));
}
});

Ok((file_size, rx))
}

/// Stream a file from within an ISO.
///
/// Returns the file size and a receiver that yields chunks.
Expand Down Expand Up @@ -609,21 +674,7 @@ impl IsoService {

// Phase 2: Stream firmware from disk
let mut firmware_file = File::open(&firmware_path_clone)?;
let mut offset: u64 = 0;
while offset < firmware_size {
let remaining = firmware_size - offset;
let chunk_size = std::cmp::min(remaining as usize, CHUNK_SIZE);

let mut buffer = vec![0u8; chunk_size];
firmware_file.read_exact(&mut buffer)?;

let bytes = Bytes::from(buffer);
if tx.blocking_send(Ok(bytes)).is_err() {
return Ok(());
}

offset += chunk_size as u64;
}
stream_file_to_channel(&mut firmware_file, firmware_size, &tx)?;

Ok(())
})();
Expand Down Expand Up @@ -899,4 +950,108 @@ mod tests {
let result = service.should_concat_firmware("test", "/install/initrd.gz").unwrap();
assert!(result.is_none());
}

#[test]
fn test_stream_file_to_channel() {
// Create a test file with known content
let dir = setup_test_dir();
let test_file = dir.path().join("test.bin");
let test_data = vec![0xABu8; 1024 * 100]; // 100KB of 0xAB
std::fs::write(&test_file, &test_data).unwrap();

// Create channel and stream
let (tx, mut rx) = mpsc::channel(2);
let mut file = File::open(&test_file).unwrap();
let file_size = test_data.len() as u64;

// Run in a thread since blocking_send requires it
std::thread::spawn(move || {
stream_file_to_channel(&mut file, file_size, &tx).unwrap();
});

// Collect all chunks
let mut received = Vec::new();
while let Some(result) = rx.blocking_recv() {
let bytes = result.unwrap();
received.extend_from_slice(&bytes);
}

assert_eq!(received.len(), test_data.len());
assert_eq!(received, test_data);
}

#[test]
fn test_stream_file_to_channel_multiple_chunks() {
// Create a file larger than CHUNK_SIZE (32MB) to test chunking
// 70MB = 2 full chunks (32MB each) + 1 partial chunk (6MB)
let dir = setup_test_dir();
let test_file = dir.path().join("large.bin");
let file_size = 70 * 1024 * 1024; // 70MB
let test_data: Vec<u8> = (0..file_size).map(|i| (i % 256) as u8).collect();
std::fs::write(&test_file, &test_data).unwrap();

let (tx, mut rx) = mpsc::channel(2);
let mut file = File::open(&test_file).unwrap();

std::thread::spawn(move || {
stream_file_to_channel(&mut file, file_size as u64, &tx).unwrap();
});

// Collect chunks and verify we get multiple
let mut received = Vec::new();
let mut chunk_count = 0;
while let Some(result) = rx.blocking_recv() {
let bytes = result.unwrap();
chunk_count += 1;
received.extend_from_slice(&bytes);
}

assert_eq!(chunk_count, 3); // 32MB + 32MB + 6MB
assert_eq!(received.len(), test_data.len());
assert_eq!(received, test_data);
}

#[tokio::test]
async fn test_stream_iso_file() {
let dir = setup_test_dir();
let iso_dir = dir.path().join("iso").join("test-iso");
std::fs::create_dir_all(&iso_dir).unwrap();

// Create iso.cfg
std::fs::write(iso_dir.join("iso.cfg"), "filename=test.iso\n").unwrap();

// Create a test "ISO" file with known content (1MB)
let test_data: Vec<u8> = (0..1024 * 1024).map(|i| (i % 256) as u8).collect();
std::fs::write(iso_dir.join("test.iso"), &test_data).unwrap();

let service = IsoService::new(dir.path().to_path_buf());
let (size, mut rx) = service.stream_iso_file("test-iso").unwrap();

assert_eq!(size, test_data.len() as u64);

// Collect all chunks
let mut received = Vec::new();
while let Some(result) = rx.recv().await {
let bytes = result.unwrap();
received.extend_from_slice(&bytes);
}

assert_eq!(received.len(), test_data.len());
assert_eq!(received, test_data);
}

#[tokio::test]
async fn test_stream_iso_file_not_found() {
let dir = setup_test_dir();
let iso_dir = dir.path().join("iso").join("test-iso");
std::fs::create_dir_all(&iso_dir).unwrap();

// Create iso.cfg pointing to non-existent file
std::fs::write(iso_dir.join("iso.cfg"), "filename=missing.iso\n").unwrap();

let service = IsoService::new(dir.path().to_path_buf());
let result = service.stream_iso_file("test-iso");

assert!(matches!(result, Err(AppError::IsoFileNotFound { .. })));
}
}