diff --git a/CHANGELOG.md b/CHANGELOG.md
index 052ae6d..c52d370 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,17 @@ All notable changes to this project will be documented in this file.
 
 ## [Unreleased]
 
+### Changed
+- Stream cache writes to disk via `std::io::copy` instead of buffering the full payload in memory
+- `download_async()` now preserves raw bytes, matching `download()`
+- Default blocking HTTP clients are reused across reads and content-length probes
+- S3 status failures now use structured errors instead of string parsing
+- S3 readers now stream data through a bounded channel instead of materializing the full object in memory
+
+### Added
+- Added `bzip2_decompress` benchmark coverage
+- Added a benchmark helper script for comparing gzip backend feature flags and bz2 decompression
+
 ## v0.20.1 -- 2025-12-18
 
 ### Changed
diff --git a/Cargo.toml b/Cargo.toml
index 7c31989..7ee60e5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -79,7 +79,7 @@ s3 = ["rust-s3"]
 gz = ["gz-zlib-rs"]
 # internal feature to enable gzip support
 any_gz = []
-gz-miniz = ["any_gz", "flate2/miniz_oxide"]
+gz-miniz = ["any_gz", "flate2/rust_backend"]
 gz-zlib-rs = ["any_gz", "flate2/zlib-rs"]
 gz-zlib-ng = ["any_gz", "flate2/zlib-ng"]
 gz-zlib-cloudflare = ["any_gz", "flate2/cloudflare_zlib"]
@@ -128,6 +128,11 @@ name = "gzip_decompress"
 harness = false
 required-features = ["any_gz"]
 
+[[bench]]
+name = "bzip2_decompress"
+harness = false
+required-features = ["bz"]
+
 # This list only includes examples which require additional features to run. There are more in the examples directory.
 [[example]]
 name = "s3_operations"
diff --git a/README.md b/README.md
index 84f2cb3..10d94da 100644
--- a/README.md
+++ b/README.md
@@ -191,6 +191,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         "local_data.csv.gz"
     ).await?;
 
+    // download_async preserves the remote bytes.
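    // Illustrative check (assumes the gzip download above): since the raw bytes
    // are preserved rather than decompressed, the file on disk still starts with
    // the gzip magic bytes.
    //
    //     let raw = std::fs::read("local_data.csv.gz")?;
    //     assert!(raw.starts_with(&[0x1f, 0x8b]));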
+
     Ok(())
 }
 ```
@@ -295,6 +297,7 @@ match oneio::get_reader("file.txt") {
     Ok(reader) => { /* use reader */ },
     Err(OneIoError::Io(e)) => { /* filesystem error */ },
     Err(OneIoError::Network(e)) => { /* network error */ },
+    Err(OneIoError::Status { service, code }) => { /* remote status error */ },
     Err(OneIoError::NotSupported(msg)) => { /* feature not compiled */ },
 }
 ```
diff --git a/benches/bzip2_decompress.rs b/benches/bzip2_decompress.rs
new file mode 100644
index 0000000..54be42f
--- /dev/null
+++ b/benches/bzip2_decompress.rs
@@ -0,0 +1,53 @@
+mod common;
+
+use std::hint::black_box;
+use std::io::{Read, Write};
+
+use bzip2::read::BzDecoder;
+use bzip2::write::BzEncoder;
+use bzip2::Compression;
+use criterion::{criterion_group, criterion_main, Criterion, Throughput};
+
+fn build_bzip2_fixture() -> (Vec<u8>, usize, String) {
+    let corpus = common::build_text_corpus();
+    let mut encoder = BzEncoder::new(Vec::new(), Compression::default());
+    encoder.write_all(&corpus).unwrap();
+    let compressed = encoder.finish().unwrap();
+    let fixture = common::write_fixture("bzip2.txt.bz2", &compressed);
+
+    (
+        compressed,
+        corpus.len(),
+        fixture.to_string_lossy().into_owned(),
+    )
+}
+
+fn bench_bzip2_decompress(c: &mut Criterion) {
+    let (input, output_len, fixture_path) = build_bzip2_fixture();
+
+    let mut group = c.benchmark_group("bzip2_decompress");
+    group.throughput(Throughput::Bytes(output_len as u64));
+
+    group.bench_function("raw_decoder", |b| {
+        b.iter(|| {
+            let mut reader = BzDecoder::new(input.as_slice());
+            let mut out = Vec::with_capacity(output_len);
+            reader.read_to_end(&mut out).unwrap();
+            black_box(out.len())
+        })
+    });
+
+    group.bench_function("oneio_get_reader", |b| {
+        b.iter(|| {
+            let mut reader = oneio::get_reader(&fixture_path).unwrap();
+            let mut out = Vec::with_capacity(output_len);
+            reader.read_to_end(&mut out).unwrap();
+            black_box(out.len())
+        })
+    });
+
+    group.finish();
+}
+
+criterion_group!(benches, bench_bzip2_decompress);
+criterion_main!(benches);
diff --git a/benches/common/mod.rs b/benches/common/mod.rs
new file mode 100644
index 0000000..7013938
--- /dev/null
+++ b/benches/common/mod.rs
@@ -0,0 +1,38 @@
+use std::fs;
+use std::io::Write;
+use std::path::PathBuf;
+
+const TARGET_CORPUS_SIZE: usize = 16 * 1024 * 1024;
+
+pub fn build_text_corpus() -> Vec<u8> {
+    let mut data = Vec::with_capacity(TARGET_CORPUS_SIZE);
+    let mut seq = 0_u64;
+
+    while data.len() < TARGET_CORPUS_SIZE {
+        writeln!(
+            &mut data,
+            "{seq},AS{:05},AS{:05},peer=route-views.eqix,next-hop=192.0.2.{},med={},local-pref={},community={}:{}",
+            (seq % 64512) + 100,
+            ((seq * 7) % 64512) + 100,
+            (seq % 254) + 1,
+            seq % 1000,
+            100 + (seq % 200),
+            64512 + (seq % 64),
+            100 + (seq % 4096)
+        )
+        .unwrap();
+        seq += 1;
+    }
+
+    data.truncate(TARGET_CORPUS_SIZE);
+    data
+}
+
+pub fn write_fixture(name: &str, bytes: &[u8]) -> PathBuf {
+    let fixture_dir = PathBuf::from("target/bench-fixtures");
+    fs::create_dir_all(&fixture_dir).unwrap();
+
+    let path = fixture_dir.join(name);
+    fs::write(&path, bytes).unwrap();
+    path
+}
diff --git a/benches/gzip_decompress.rs b/benches/gzip_decompress.rs
index a22e413..df45b1e 100644
--- a/benches/gzip_decompress.rs
+++ b/benches/gzip_decompress.rs
@@ -1,42 +1,67 @@
-use std::fs::File;
+mod common;
+
 use std::hint::black_box;
-use std::io::Read;
+use std::io::{Read, Write};
 
-use criterion::{criterion_group, criterion_main, BatchSize, Criterion, Throughput};
+use criterion::{criterion_group, criterion_main, Criterion, Throughput};
 use flate2::read::GzDecoder;
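// Example invocations for the benchmarks above (assuming the feature flags defined in Cargo.toml):
//   cargo bench --bench gzip_decompress --no-default-features --features gz-zlib-rs
//   cargo bench --bench bzip2_decompress --no-default-features --features bz
// or run every backend in one pass via scripts/bench_decompression_backends.sh.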
+use flate2::write::GzEncoder; +use flate2::Compression; + +#[cfg(feature = "gz-miniz")] +const GZIP_BACKEND: &str = "miniz_oxide"; +#[cfg(all(not(feature = "gz-miniz"), feature = "gz-zlib-rs"))] +const GZIP_BACKEND: &str = "zlib-rs"; +#[cfg(all( + not(feature = "gz-miniz"), + not(feature = "gz-zlib-rs"), + feature = "gz-zlib-ng" +))] +const GZIP_BACKEND: &str = "zlib-ng"; +#[cfg(all( + not(feature = "gz-miniz"), + not(feature = "gz-zlib-rs"), + not(feature = "gz-zlib-ng"), + feature = "gz-zlib-cloudflare" +))] +const GZIP_BACKEND: &str = "cloudflare-zlib"; + +fn build_gzip_fixture() -> (Vec, usize, String) { + let corpus = common::build_text_corpus(); + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&corpus).unwrap(); + let compressed = encoder.finish().unwrap(); + let fixture = common::write_fixture(&format!("gzip-{GZIP_BACKEND}.txt.gz"), &compressed); -// Benchmark gzip decompression using flate2 with the selected backend. -// To run with default (miniz_oxide) backend: -// cargo bench --bench gzip_decompress --no-default-features --features gz-miniz -// To run with zlib-rs backend: -// cargo bench --bench gzip_decompress --no-default-features --features gz-zlib-rs -// To compare, run both commands and compare Criterion reports. - -fn load_gz_bytes() -> Vec { - let mut f = File::open("tests/test_data.txt.gz").expect("missing tests/test_data.txt.gz"); - let mut buf = Vec::new(); - f.read_to_end(&mut buf).unwrap(); - buf + ( + compressed, + corpus.len(), + fixture.to_string_lossy().into_owned(), + ) } fn bench_gzip_decompress(c: &mut Criterion) { - let input = load_gz_bytes(); + let (input, output_len, fixture_path) = build_gzip_fixture(); let mut group = c.benchmark_group("gzip_decompress"); - group.throughput(Throughput::Bytes(input.len() as u64)); - - group.bench_function("flate2_gz_decode", |b| { - b.iter_batched( - || input.clone(), - |bytes| { - let reader = GzDecoder::new(bytes.as_slice()); - let mut out = Vec::with_capacity(128 * 1024); - let mut r = reader; - r.read_to_end(&mut out).unwrap(); - black_box(out) - }, - BatchSize::SmallInput, - ) + group.throughput(Throughput::Bytes(output_len as u64)); + + group.bench_function(format!("raw_decoder/{GZIP_BACKEND}"), |b| { + b.iter(|| { + let mut reader = GzDecoder::new(input.as_slice()); + let mut out = Vec::with_capacity(output_len); + reader.read_to_end(&mut out).unwrap(); + black_box(out.len()) + }) + }); + + group.bench_function(format!("oneio_get_reader/{GZIP_BACKEND}"), |b| { + b.iter(|| { + let mut reader = oneio::get_reader(&fixture_path).unwrap(); + let mut out = Vec::with_capacity(output_len); + reader.read_to_end(&mut out).unwrap(); + black_box(out.len()) + }) }); group.finish(); diff --git a/examples/s3_operations.rs b/examples/s3_operations.rs index 1552a94..3d162b2 100644 --- a/examples/s3_operations.rs +++ b/examples/s3_operations.rs @@ -32,27 +32,18 @@ fn main() { info!("error if file does not exist"); let res = s3_stats("oneio-test", "test/README___NON_EXISTS.md"); assert!(res.is_err()); - assert_eq!( - false, - s3_exists("oneio-test", "test/README___NON_EXISTS.md").unwrap() - ); - assert_eq!(true, s3_exists("oneio-test", "test/README.md").unwrap()); + assert!(!s3_exists("oneio-test", "test/README___NON_EXISTS.md").unwrap()); + assert!(s3_exists("oneio-test", "test/README.md").unwrap()); info!("copy S3 file to a different location"); let res = s3_copy("oneio-test", "test/README.md", "test/README-temporary.md"); assert!(res.is_ok()); - assert_eq!( - true, - 
s3_exists("oneio-test", "test/README-temporary.md").unwrap() - ); + assert!(s3_exists("oneio-test", "test/README-temporary.md").unwrap()); info!("delete temporary copied S3 file"); let res = s3_delete("oneio-test", "test/README-temporary.md"); assert!(res.is_ok()); - assert_eq!( - false, - s3_exists("oneio-test", "test/README-temporary.md").unwrap() - ); + assert!(!s3_exists("oneio-test", "test/README-temporary.md").unwrap()); info!("list S3 files"); let res = s3_list("oneio-test", "test/", Some("/".to_string()), false).unwrap(); diff --git a/scripts/bench_decompression_backends.sh b/scripts/bench_decompression_backends.sh new file mode 100755 index 0000000..bdfaaba --- /dev/null +++ b/scripts/bench_decompression_backends.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +set -euo pipefail + +criterion_args=(-- --warm-up-time 1 --measurement-time 5 --sample-size 20) + +run_case() { + local name="$1" + shift + echo + echo "== ${name} ==" + cargo bench "$@" "${criterion_args[@]}" +} + +run_case "gzip miniz_oxide" --bench gzip_decompress --no-default-features --features gz-miniz +run_case "gzip zlib-rs" --bench gzip_decompress --no-default-features --features gz-zlib-rs +run_case "gzip zlib-ng" --bench gzip_decompress --no-default-features --features gz-zlib-ng +run_case "gzip cloudflare-zlib" --bench gzip_decompress --no-default-features --features gz-zlib-cloudflare +run_case "bzip2" --bench bzip2_decompress --no-default-features --features bz diff --git a/src/error.rs b/src/error.rs index edc4a79..cf9d500 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,10 @@ pub enum OneIoError { #[error("{0}")] Network(Box), + /// Structured status errors from remote services + #[error("{service} status error: {code}")] + Status { service: &'static str, code: u16 }, + /// Feature not supported/compiled #[error("Not supported: {0}")] NotSupported(String), diff --git a/src/lib.rs b/src/lib.rs index a415c50..21733f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -191,6 +191,8 @@ async fn main() -> Result<(), Box> { "local_data.csv.gz" ).await?; + // download_async preserves the remote bytes. + Ok(()) } ``` @@ -296,6 +298,7 @@ match oneio::get_reader("file.txt") { Ok(reader) => { /* use reader */ }, Err(OneIoError::Io(e)) => { /* filesystem error */ }, Err(OneIoError::Network(e)) => { /* network error */ }, + Err(OneIoError::Status { service, code }) => { /* remote status error */ }, Err(OneIoError::NotSupported(msg)) => { /* feature not compiled */ }, } ``` diff --git a/src/oneio/mod.rs b/src/oneio/mod.rs index 886a080..1c6b6b7 100644 --- a/src/oneio/mod.rs +++ b/src/oneio/mod.rs @@ -21,12 +21,8 @@ use std::path::Path; use futures::StreamExt; /// Extracts the protocol from a given path. 
-pub(crate) fn get_protocol(path: &str) -> Option { - let parts = path.split("://").collect::>(); - if parts.len() < 2 { - return None; - } - Some(parts[0].to_string()) +pub(crate) fn get_protocol(path: &str) -> Option<&str> { + path.split_once("://").map(|(protocol, _)| protocol) } pub fn get_writer_raw(path: &str) -> Result, OneIoError> { @@ -40,7 +36,7 @@ pub fn get_writer_raw(path: &str) -> Result, OneIoError> { pub fn get_reader_raw(path: &str) -> Result, OneIoError> { let raw_reader: Box = match get_protocol(path) { - Some(protocol) => match protocol.as_str() { + Some(protocol) => match protocol { #[cfg(feature = "http")] "http" | "https" => { let response = remote::get_http_reader_raw(path, None)?; @@ -54,7 +50,7 @@ pub fn get_reader_raw(path: &str) -> Result, OneIoError> { #[cfg(feature = "s3")] "s3" | "r2" => { let (bucket, path) = s3::s3_url_parse(path)?; - Box::new(s3::s3_reader(bucket.as_str(), path.as_str())?) + s3::s3_reader(bucket.as_str(), path.as_str())? } _ => { return Err(OneIoError::NotSupported(path.to_string())); @@ -131,11 +127,9 @@ pub fn get_cache_reader( // read all to cache file, no encode/decode happens let mut reader = get_reader_raw(path)?; - let mut data: Vec = vec![]; - reader.read_to_end(&mut data)?; let mut writer = get_writer_raw(cache_file_path.as_str())?; - writer.write_all(&data)?; - drop(writer); + std::io::copy(&mut reader, &mut writer)?; + writer.flush()?; // return reader from cache file get_reader(cache_file_path.as_str()) @@ -163,7 +157,7 @@ pub fn get_cache_reader( /// }; /// ``` pub fn get_writer(path: &str) -> Result, OneIoError> { - let output_file = BufWriter::new(File::create(path)?); + let output_file = get_writer_raw(path)?; let file_type = path.rsplit('.').next().unwrap_or(""); get_compression_writer(output_file, file_type) @@ -255,24 +249,7 @@ pub fn get_content_length(path: &str) -> Result { match get_protocol(path) { #[cfg(feature = "http")] Some(protocol) if protocol == "http" || protocol == "https" => { - #[cfg(feature = "rustls")] - crypto::ensure_default_provider()?; - - // HEAD request to get Content-Length - let client = reqwest::blocking::Client::new(); - let response = client.head(path).send()?; - - response - .headers() - .get("content-length") - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse().ok()) - .ok_or_else(|| { - OneIoError::NotSupported( - "Cannot determine file size - server doesn't provide Content-Length" - .to_string(), - ) - }) + remote::get_http_content_length(path) } #[cfg(feature = "ftp")] Some(protocol) if protocol == "ftp" => { @@ -464,8 +441,8 @@ pub async fn read_to_string_async(path: &str) -> Result { /// Downloads a file asynchronously from a URL to a local path /// -/// This is the async version of `download()`. It supports all protocols and -/// handles decompression if needed. +/// This is the async version of `download()`. It preserves the raw bytes from +/// the source, matching the synchronous `download()` behavior. 
/// /// # Arguments /// * `url` - Source URL to download from @@ -491,20 +468,18 @@ pub async fn read_to_string_async(path: &str) -> Result { #[cfg(feature = "async")] pub async fn download_async(url: &str, path: &str) -> Result<(), OneIoError> { use tokio::fs::File; - use tokio::io::AsyncWriteExt; + use tokio::io::{copy, AsyncWriteExt}; - let mut reader = get_reader_async(url).await?; - let mut file = File::create(path).await?; - - let mut buffer = vec![0u8; 8192]; - loop { - let bytes_read = reader.read(&mut buffer).await?; - if bytes_read == 0 { - break; + if let Some(parent) = Path::new(path).parent() { + if !parent.as_os_str().is_empty() { + tokio::fs::create_dir_all(parent).await?; } - file.write_all(&buffer[..bytes_read]).await?; } + let mut reader = get_async_reader_raw(url).await?; + let mut file = File::create(path).await?; + copy(&mut reader, &mut file).await?; + file.flush().await?; Ok(()) } @@ -612,8 +587,10 @@ fn get_async_compression_reader( #[cfg(test)] mod tests { use super::*; + #[cfg(any(feature = "any_gz", feature = "http"))] use std::io::Read; + #[cfg(any(feature = "any_gz", feature = "http", feature = "async"))] const TEST_TEXT: &str = "OneIO test file.\nThis is a test."; #[cfg(feature = "any_gz")] diff --git a/src/oneio/remote.rs b/src/oneio/remote.rs index 45ad996..a6403ca 100644 --- a/src/oneio/remote.rs +++ b/src/oneio/remote.rs @@ -5,6 +5,11 @@ use crate::OneIoError; #[cfg(feature = "http")] use reqwest::blocking::Client; use std::io::Read; +#[cfg(feature = "http")] +use std::sync::OnceLock; + +#[cfg(feature = "http")] +static DEFAULT_HTTP_CLIENT: OnceLock> = OnceLock::new(); #[cfg(feature = "ftp")] pub(crate) fn get_ftp_reader_raw(path: &str) -> Result, OneIoError> { @@ -15,69 +20,87 @@ pub(crate) fn get_ftp_reader_raw(path: &str) -> Result, One #[cfg(feature = "rustls")] super::crypto::ensure_default_provider()?; - let parts = path.split('/').collect::>(); - let socket = match parts[2].contains(':') { - true => parts[2].to_string(), - false => format!("{}:21", parts[2]), + let path_without_scheme = path + .strip_prefix("ftp://") + .ok_or_else(|| OneIoError::NotSupported(path.to_string()))?; + let (host, remote_path) = path_without_scheme + .split_once('/') + .ok_or_else(|| OneIoError::NotSupported(path.to_string()))?; + let socket = match host.contains(':') { + true => host.to_string(), + false => format!("{host}:21"), }; - let path = parts[3..].join("/"); let mut ftp_stream = suppaftp::FtpStream::connect(socket)?; // use anonymous login ftp_stream.login("anonymous", "oneio")?; ftp_stream.transfer_type(suppaftp::types::FileType::Binary)?; - let reader = Box::new(ftp_stream.retr_as_stream(path.as_str())?); + let reader = Box::new(ftp_stream.retr_as_stream(remote_path)?); Ok(reader) } #[cfg(feature = "http")] -pub(crate) fn get_http_reader_raw( - path: &str, - opt_client: Option, -) -> Result { +fn build_default_http_client() -> Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::USER_AGENT, + reqwest::header::HeaderValue::from_static("oneio"), + ); + headers.insert( + reqwest::header::CONTENT_LENGTH, + reqwest::header::HeaderValue::from_static("0"), + ); + #[cfg(feature = "cli")] + headers.insert( + reqwest::header::CACHE_CONTROL, + reqwest::header::HeaderValue::from_static("no-cache"), + ); + + #[cfg(any(feature = "rustls", feature = "native-tls"))] + { + let accept_invalid_certs = matches!( + std::env::var("ONEIO_ACCEPT_INVALID_CERTS") + .unwrap_or_default() + .to_lowercase() + .as_str(), + "true" | "yes" | "y" 
| "1" + ); + Client::builder() + .default_headers(headers) + .danger_accept_invalid_certs(accept_invalid_certs) + .build() + } + + #[cfg(not(any(feature = "rustls", feature = "native-tls")))] + { + Client::builder().default_headers(headers).build() + } +} + +#[cfg(feature = "http")] +fn default_http_client() -> Result { dotenvy::dotenv().ok(); #[cfg(feature = "rustls")] super::crypto::ensure_default_provider()?; + match DEFAULT_HTTP_CLIENT.get_or_init(|| build_default_http_client().map_err(|e| e.to_string())) + { + Ok(client) => Ok(client.clone()), + Err(message) => Err(OneIoError::Network(Box::new(std::io::Error::other( + message.clone(), + )))), + } +} + +#[cfg(feature = "http")] +pub(crate) fn get_http_reader_raw( + path: &str, + opt_client: Option, +) -> Result { let client = match opt_client { Some(c) => c, - None => { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - reqwest::header::USER_AGENT, - reqwest::header::HeaderValue::from_static("oneio"), - ); - headers.insert( - reqwest::header::CONTENT_LENGTH, - reqwest::header::HeaderValue::from_static("0"), - ); - #[cfg(feature = "cli")] - headers.insert( - reqwest::header::CACHE_CONTROL, - reqwest::header::HeaderValue::from_static("no-cache"), - ); - - #[cfg(any(feature = "rustls", feature = "native-tls"))] - { - let accept_invalid_certs = matches!( - std::env::var("ONEIO_ACCEPT_INVALID_CERTS") - .unwrap_or_default() - .to_lowercase() - .as_str(), - "true" | "yes" | "y" | "1" - ); - Client::builder() - .default_headers(headers) - .danger_accept_invalid_certs(accept_invalid_certs) - .build()? - } - - #[cfg(not(any(feature = "rustls", feature = "native-tls")))] - { - Client::builder().default_headers(headers).build()? - } - } + None => default_http_client()?, }; let res = client .execute(client.get(path).build()?)? @@ -162,6 +185,23 @@ pub fn get_http_reader( get_compression_reader(raw_reader, file_type) } +#[cfg(feature = "http")] +pub(crate) fn get_http_content_length(path: &str) -> Result { + let client = default_http_client()?; + let response = client.head(path).send()?.error_for_status()?; + + response + .headers() + .get("content-length") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse().ok()) + .ok_or_else(|| { + OneIoError::NotSupported( + "Cannot determine file size - server doesn't provide Content-Length".to_string(), + ) + }) +} + /// Downloads a file from a remote location to a local path. 
/// /// # Arguments @@ -198,33 +238,28 @@ pub fn download( opt_client: Option, ) -> Result<(), OneIoError> { match get_protocol(remote_path) { - None => { - return Err(OneIoError::NotSupported(remote_path.to_string())); + #[cfg(feature = "http")] + Some("http" | "https") => { + let mut writer = get_writer_raw(local_path)?; + let mut response = get_http_reader_raw(remote_path, opt_client)?; + response.copy_to(&mut writer)?; + Ok(()) } - Some(protocol) => match protocol.as_str() { - #[cfg(feature = "http")] - "http" | "https" => { - let mut writer = get_writer_raw(local_path)?; - let mut response = get_http_reader_raw(remote_path, opt_client)?; - response.copy_to(&mut writer)?; - } - #[cfg(feature = "ftp")] - "ftp" => { - let mut writer = get_writer_raw(local_path)?; - let mut reader = get_ftp_reader_raw(remote_path)?; - std::io::copy(&mut reader, &mut writer)?; - } - #[cfg(feature = "s3")] - "s3" => { - let (bucket, path) = crate::oneio::s3::s3_url_parse(remote_path)?; - crate::oneio::s3::s3_download(bucket.as_str(), path.as_str(), local_path)?; - } - _ => { - return Err(OneIoError::NotSupported(remote_path.to_string())); - } - }, - }; - Ok(()) + #[cfg(feature = "ftp")] + Some("ftp") => { + let mut writer = get_writer_raw(local_path)?; + let mut reader = get_ftp_reader_raw(remote_path)?; + std::io::copy(&mut reader, &mut writer)?; + Ok(()) + } + #[cfg(feature = "s3")] + Some("s3" | "r2") => { + let (bucket, path) = crate::oneio::s3::s3_url_parse(remote_path)?; + crate::oneio::s3::s3_download(bucket.as_str(), path.as_str(), local_path)?; + Ok(()) + } + Some(_) | None => Err(OneIoError::NotSupported(remote_path.to_string())), + } } /// Downloads a file from a remote path and saves it locally with retry mechanism. @@ -291,19 +326,17 @@ pub fn download_with_retry( /// an `Err` variant with a `OneIoError` is returned. pub(crate) fn remote_file_exists(path: &str) -> Result { match get_protocol(path) { - Some(protocol) => match protocol.as_str() { + Some(protocol) => match protocol { "http" | "https" => { - #[cfg(feature = "rustls")] - super::crypto::ensure_default_provider()?; - - let client = Client::builder() + let client = default_http_client()?; + let res = client + .head(path) .timeout(std::time::Duration::from_secs(2)) - .build()?; - let res = client.head(path).send()?; + .send()?; Ok(res.status().is_success()) } #[cfg(feature = "s3")] - "s3" => { + "s3" | "r2" => { let (bucket, path) = crate::oneio::s3::s3_url_parse(path)?; let res = crate::oneio::s3::s3_exists(bucket.as_str(), path.as_str())?; Ok(res) diff --git a/src/oneio/s3.rs b/src/oneio/s3.rs index 45b1788..3030df6 100644 --- a/src/oneio/s3.rs +++ b/src/oneio/s3.rs @@ -10,7 +10,8 @@ use crate::OneIoError; use s3::creds::Credentials; use s3::serde_types::{HeadObjectResult, ListBucketResult}; use s3::{Bucket, Region}; -use std::io::{Cursor, Read}; +use std::io::{Cursor, Read, Write}; +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; /// Checks if the necessary environment variables for AWS S3 are set. /// @@ -83,13 +84,125 @@ pub fn s3_env_check() -> Result<(), OneIoError> { /// /// Returns a `Result` containing the bucket and key as a tuple, or a `OneIoError` if parsing fails. 
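// For example (illustrative): "s3://oneio-test/test/README.md" parses to
// ("oneio-test", "test/README.md"); r2:// URLs follow the same bucket/key shape.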
 pub fn s3_url_parse(path: &str) -> Result<(String, String), OneIoError> {
-    let parts = path.split('/').collect::<Vec<&str>>();
-    if parts.len() < 3 {
+    let (_, remaining) = path
+        .split_once("://")
+        .ok_or_else(|| OneIoError::NotSupported(format!("Invalid S3 URL: {path}")))?;
+    let (bucket, key) = remaining
+        .split_once('/')
+        .ok_or_else(|| OneIoError::NotSupported(format!("Invalid S3 URL: {path}")))?;
+    if bucket.is_empty() || key.is_empty() {
         return Err(OneIoError::NotSupported(format!("Invalid S3 URL: {path}")));
     }
-    let bucket = parts[2];
-    let key = parts[3..].join("/");
-    Ok((bucket.to_string(), key))
+    Ok((bucket.to_string(), key.to_string()))
+}
+
+enum StreamMessage {
+    Chunk(Vec<u8>),
+    Error(String),
+    Eof,
+}
+
+struct StreamWriter {
+    sender: SyncSender<StreamMessage>,
+    closed: bool,
+}
+
+impl StreamWriter {
+    fn new(sender: SyncSender<StreamMessage>) -> Self {
+        Self {
+            sender,
+            closed: false,
+        }
+    }
+
+    fn send_error(&mut self, err: std::io::Error) -> std::io::Result<()> {
+        self.closed = true;
+        self.sender
+            .send(StreamMessage::Error(err.to_string()))
+            .map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "stream closed"))
+    }
+
+    fn close(&mut self) -> std::io::Result<()> {
+        if self.closed {
+            return Ok(());
+        }
+        self.closed = true;
+        self.sender
+            .send(StreamMessage::Eof)
+            .map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "stream closed"))
+    }
+}
+
+impl Write for StreamWriter {
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        self.sender
+            .send(StreamMessage::Chunk(buf.to_vec()))
+            .map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "stream closed"))?;
+        Ok(buf.len())
+    }
+
+    fn flush(&mut self) -> std::io::Result<()> {
+        Ok(())
+    }
+}
+
+impl Drop for StreamWriter {
+    fn drop(&mut self) {
+        let _ = self.close();
+    }
+}
+
+struct StreamReader {
+    receiver: Receiver<StreamMessage>,
+    current_chunk: Cursor<Vec<u8>>,
+    done: bool,
+}
+
+impl StreamReader {
+    fn new(receiver: Receiver<StreamMessage>) -> Self {
+        Self {
+            receiver,
+            current_chunk: Cursor::new(Vec::new()),
+            done: false,
+        }
+    }
+}
+
+impl Read for StreamReader {
+    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+        if buf.is_empty() {
+            return Ok(0);
+        }
+
+        loop {
+            let bytes_read = self.current_chunk.read(buf)?;
+            if bytes_read > 0 {
+                return Ok(bytes_read);
+            }
+
+            if self.done {
+                return Ok(0);
+            }
+
+            match self.receiver.recv() {
+                Ok(StreamMessage::Chunk(chunk)) => {
+                    self.current_chunk = Cursor::new(chunk);
+                }
+                Ok(StreamMessage::Error(message)) => {
+                    self.done = true;
+                    return Err(std::io::Error::other(message));
+                }
+                Ok(StreamMessage::Eof) => {
+                    self.done = true;
+                    return Ok(0);
+                }
+                Err(_) => {
+                    self.done = true;
+                    return Err(std::io::Error::other("S3 stream closed unexpectedly"));
+                }
+            }
+        }
+    }
 }
 
 /// Creates an S3 bucket object with the specified bucket name.
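The two halves above are tied together by a bounded `std::sync::mpsc::sync_channel`: once the channel is full, the producing thread blocks in `send` until the reader drains a chunk, so only a handful of chunks are ever buffered. A minimal standalone sketch of that backpressure pattern (names here are illustrative, not part of the crate):

```rust
use std::sync::mpsc::sync_channel;
use std::thread;

fn main() {
    // Bounded channel: at most 2 chunks buffered between producer and consumer.
    let (tx, rx) = sync_channel::<Vec<u8>>(2);

    let producer = thread::spawn(move || {
        for i in 0..10u8 {
            // Blocks here whenever the consumer falls 2 chunks behind.
            tx.send(vec![i; 4]).unwrap();
        }
        // Dropping tx closes the channel, which the consumer sees as end-of-stream.
    });

    let mut total = 0usize;
    for chunk in rx {
        total += chunk.len();
    }
    producer.join().unwrap();
    assert_eq!(total, 40);
}
```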
@@ -172,9 +285,26 @@ pub fn s3_bucket(bucket: &str) -> Result { /// ``` pub fn s3_reader(bucket: &str, path: &str) -> Result, OneIoError> { let bucket = s3_bucket(bucket)?; - let object = bucket.get_object(path)?; - let buf: Vec = object.to_vec(); - Ok(Box::new(Cursor::new(buf))) + let path = path.to_string(); + let (sender, receiver) = sync_channel(8); + + std::thread::spawn(move || { + let mut writer = StreamWriter::new(sender); + match bucket.get_object_to_writer(path, &mut writer) { + Ok(200..=299) => { + let _ = writer.close(); + } + Ok(code) => { + let _ = + writer.send_error(std::io::Error::other(format!("S3 status error: {code}"))); + } + Err(err) => { + let _ = writer.send_error(std::io::Error::other(err.to_string())); + } + } + }); + + Ok(Box::new(StreamReader::new(receiver))) } /// Uploads a file to an S3 bucket at the specified path. @@ -316,9 +446,10 @@ pub fn s3_download(bucket: &str, s3_path: &str, file_path: &str) -> Result<(), O let res: u16 = bucket.get_object_to_writer(s3_path, &mut output_file)?; match res { 200..=299 => Ok(()), - _ => Err(OneIoError::Network(Box::new(std::io::Error::other( - format!("S3 HTTP error: {res}"), - )))), + _ => Err(OneIoError::Status { + service: "S3", + code: res, + }), } } @@ -360,9 +491,10 @@ pub fn s3_stats(bucket: &str, path: &str) -> Result Ok(head_object), - _ => Err(OneIoError::Network(Box::new(std::io::Error::other( - format!("S3 HTTP error: {code}"), - )))), + _ => Err(OneIoError::Status { + service: "S3", + code, + }), } } @@ -394,25 +526,11 @@ pub fn s3_stats(bucket: &str, path: &str) -> Result Result { match s3_stats(bucket, path) { Ok(_) => Ok(true), - Err(err) => { - // Check if this is a 404 network error by parsing the status code - if let OneIoError::Network(boxed_err) = &err { - let error_msg = boxed_err.to_string(); - if error_msg.starts_with("S3 HTTP error: ") { - // Parse the status code from the structured error message - if let Some(code_str) = error_msg.strip_prefix("S3 HTTP error: ") { - if let Ok(status_code) = code_str.parse::() { - return match status_code { - 404 => Ok(false), // Not Found - // 403 Forbidden means permission denied; propagate as error - _ => Err(err), // Other errors should propagate - }; - } - } - } - } - Err(err) - } + Err(OneIoError::Status { + service: "S3", + code: 404, + }) => Ok(false), + Err(err) => Err(err), } } @@ -480,6 +598,7 @@ pub fn s3_list( #[cfg(test)] mod tests { use super::*; + use std::io::{Read, Write}; #[test] fn test_s3_url_parse() { @@ -536,4 +655,38 @@ mod tests { } } } + + #[test] + fn test_stream_reader_reads_in_order() { + let (sender, receiver) = sync_channel(2); + let writer_thread = std::thread::spawn(move || { + let mut writer = StreamWriter::new(sender); + writer.write_all(b"hello ").unwrap(); + writer.write_all(b"world").unwrap(); + writer.close().unwrap(); + }); + + let mut reader = StreamReader::new(receiver); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + writer_thread.join().unwrap(); + + assert_eq!(output, "hello world"); + } + + #[test] + fn test_stream_reader_propagates_error() { + let (sender, receiver) = sync_channel(2); + let mut writer = StreamWriter::new(sender); + writer.write_all(b"hello").unwrap(); + writer + .send_error(std::io::Error::other("stream failed")) + .unwrap(); + + let mut reader = StreamReader::new(receiver); + let mut buf = [0; 5]; + reader.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, b"hello"); + assert!(reader.read(&mut [0; 1]).is_err()); + } } diff --git a/src/oneio/utils.rs 
b/src/oneio/utils.rs index 0a56dba..a9fd856 100644 --- a/src/oneio/utils.rs +++ b/src/oneio/utils.rs @@ -116,6 +116,6 @@ mod tests { assert_eq!(data.purpose, "test".to_string()); assert_eq!(data.version, 1); assert_eq!(data.meta.float, 1.1); - assert_eq!(data.meta.success, true); + assert!(data.meta.success); } } diff --git a/tests/async_integration.rs b/tests/async_integration.rs index db36ad0..3ad16ef 100644 --- a/tests/async_integration.rs +++ b/tests/async_integration.rs @@ -3,7 +3,6 @@ #![cfg(feature = "async")] -use oneio; use tokio::io::AsyncReadExt; const TEST_TEXT: &str = "OneIO test file.\nThis is a test."; @@ -81,3 +80,20 @@ async fn async_download_http_to_file() { let _ = std::fs::remove_file(tmp_path); } + +#[cfg(feature = "any_gz")] +#[tokio::test] +async fn async_download_preserves_compressed_bytes() { + let tmp_path = "tests/_tmp_async_download.txt.gz"; + let _ = std::fs::remove_file(tmp_path); + + oneio::download_async("tests/test_data.txt.gz", tmp_path) + .await + .unwrap(); + + let original = std::fs::read("tests/test_data.txt.gz").unwrap(); + let downloaded = std::fs::read(tmp_path).unwrap(); + assert_eq!(downloaded, original); + + let _ = std::fs::remove_file(tmp_path); +} diff --git a/tests/basic_integration.rs b/tests/basic_integration.rs index d4416ba..fb9793d 100644 --- a/tests/basic_integration.rs +++ b/tests/basic_integration.rs @@ -1,7 +1,6 @@ //! Basic integration tests using only default features (gz, bz, http) //! These tests should always pass with `cargo test` -use oneio; use std::io::{Read, Write}; const TEST_TEXT: &str = "OneIO test file.\nThis is a test.";
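Because status failures are now carried as a structured variant rather than a formatted string, downstream code can branch on the numeric code without parsing messages. A small illustrative sketch (the helper name is hypothetical, not part of the crate):

```rust
use oneio::OneIoError;

/// Illustrative: only retry when the failure is a transient 5xx status.
fn is_retryable(err: &OneIoError) -> bool {
    matches!(err, OneIoError::Status { code, .. } if (500..=599).contains(code))
}
```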