From 02ddcfb690033d6d1ffbc6fb6f9e255d1dea3fd0 Mon Sep 17 00:00:00 2001 From: Mingwei Zhang Date: Fri, 6 Mar 2026 06:50:53 -0800 Subject: [PATCH 1/2] fix: stream remote io paths and add decompression benchmarks Update async download to preserve raw bytes and stream cache and S3 reads instead of buffering whole payloads. Reuse default HTTP clients for blocking requests, add structured S3 status errors, expand decompression benchmarks for gzip backends and bz2, and refresh tests and changelog for the new behavior. --- CHANGELOG.md | 11 ++ Cargo.toml | 7 +- benches/bzip2_decompress.rs | 53 ++++++ benches/common/mod.rs | 38 ++++ benches/gzip_decompress.rs | 85 +++++---- examples/s3_operations.rs | 17 +- scripts/bench_decompression_backends.sh | 18 ++ src/error.rs | 4 + src/lib.rs | 3 + src/oneio/mod.rs | 63 +++---- src/oneio/remote.rs | 193 ++++++++++++--------- src/oneio/s3.rs | 221 ++++++++++++++++++++---- src/oneio/utils.rs | 2 +- tests/async_integration.rs | 18 +- tests/basic_integration.rs | 1 - 15 files changed, 530 insertions(+), 204 deletions(-) create mode 100644 benches/bzip2_decompress.rs create mode 100644 benches/common/mod.rs create mode 100755 scripts/bench_decompression_backends.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index 052ae6d..c52d370 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,17 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Changed +- Stream cache writes to disk via `std::io::copy` instead of buffering the full payload in memory +- `download_async()` now preserves raw bytes, matching `download()` +- Default blocking HTTP clients are reused across reads and content-length probes +- S3 status failures now use structured errors instead of string parsing +- S3 readers now stream data through a bounded channel instead of materializing the full object in memory + +### Added +- Added `bzip2_decompress` benchmark coverage +- Added a benchmark helper script for comparing gzip backend feature flags and bz2 decompression + ## v0.20.1 -- 2025-12-18 ### Changed diff --git a/Cargo.toml b/Cargo.toml index 7c31989..7ee60e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,7 +79,7 @@ s3 = ["rust-s3"] gz = ["gz-zlib-rs"] # internal feature to enable gzip support any_gz = [] -gz-miniz = ["any_gz", "flate2/miniz_oxide"] +gz-miniz = ["any_gz", "flate2/rust_backend"] gz-zlib-rs = ["any_gz", "flate2/zlib-rs"] gz-zlib-ng = ["any_gz", "flate2/zlib-ng"] gz-zlib-cloudflare = ["any_gz", "flate2/cloudflare_zlib"] @@ -128,6 +128,11 @@ name = "gzip_decompress" harness = false required-features = ["any_gz"] +[[bench]] +name = "bzip2_decompress" +harness = false +required-features = ["bz"] + # This list only includes examples which require additional features to run. These are more in the examples' directory. [[example]] name = "s3_operations" diff --git a/benches/bzip2_decompress.rs b/benches/bzip2_decompress.rs new file mode 100644 index 0000000..54be42f --- /dev/null +++ b/benches/bzip2_decompress.rs @@ -0,0 +1,53 @@ +mod common; + +use std::hint::black_box; +use std::io::{Read, Write}; + +use bzip2::read::BzDecoder; +use bzip2::write::BzEncoder; +use bzip2::Compression; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; + +fn build_bzip2_fixture() -> (Vec, usize, String) { + let corpus = common::build_text_corpus(); + let mut encoder = BzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&corpus).unwrap(); + let compressed = encoder.finish().unwrap(); + let fixture = common::write_fixture("bzip2.txt.bz2", &compressed); + + ( + compressed, + corpus.len(), + fixture.to_string_lossy().into_owned(), + ) +} + +fn bench_bzip2_decompress(c: &mut Criterion) { + let (input, output_len, fixture_path) = build_bzip2_fixture(); + + let mut group = c.benchmark_group("bzip2_decompress"); + group.throughput(Throughput::Bytes(output_len as u64)); + + group.bench_function("raw_decoder", |b| { + b.iter(|| { + let mut reader = BzDecoder::new(input.as_slice()); + let mut out = Vec::with_capacity(output_len); + reader.read_to_end(&mut out).unwrap(); + black_box(out.len()) + }) + }); + + group.bench_function("oneio_get_reader", |b| { + b.iter(|| { + let mut reader = oneio::get_reader(&fixture_path).unwrap(); + let mut out = Vec::with_capacity(output_len); + reader.read_to_end(&mut out).unwrap(); + black_box(out.len()) + }) + }); + + group.finish(); +} + +criterion_group!(benches, bench_bzip2_decompress); +criterion_main!(benches); diff --git a/benches/common/mod.rs b/benches/common/mod.rs new file mode 100644 index 0000000..7013938 --- /dev/null +++ b/benches/common/mod.rs @@ -0,0 +1,38 @@ +use std::fs; +use std::io::Write; +use std::path::PathBuf; + +const TARGET_CORPUS_SIZE: usize = 16 * 1024 * 1024; + +pub fn build_text_corpus() -> Vec { + let mut data = Vec::with_capacity(TARGET_CORPUS_SIZE); + let mut seq = 0_u64; + + while data.len() < TARGET_CORPUS_SIZE { + writeln!( + &mut data, + "{seq},AS{:05},AS{:05},peer=route-views.eqix,next-hop=192.0.2.{},med={},local-pref={},community={}:{}", + (seq % 64512) + 100, + ((seq * 7) % 64512) + 100, + (seq % 254) + 1, + seq % 1000, + 100 + (seq % 200), + 64512 + (seq % 64), + 100 + (seq % 4096) + ) + .unwrap(); + seq += 1; + } + + data.truncate(TARGET_CORPUS_SIZE); + data +} + +pub fn write_fixture(name: &str, bytes: &[u8]) -> PathBuf { + let fixture_dir = PathBuf::from("target/bench-fixtures"); + fs::create_dir_all(&fixture_dir).unwrap(); + + let path = fixture_dir.join(name); + fs::write(&path, bytes).unwrap(); + path +} diff --git a/benches/gzip_decompress.rs b/benches/gzip_decompress.rs index a22e413..df45b1e 100644 --- a/benches/gzip_decompress.rs +++ b/benches/gzip_decompress.rs @@ -1,42 +1,67 @@ -use std::fs::File; +mod common; + use std::hint::black_box; -use std::io::Read; +use std::io::{Read, Write}; -use criterion::{criterion_group, criterion_main, BatchSize, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use flate2::read::GzDecoder; +use flate2::write::GzEncoder; +use flate2::Compression; + +#[cfg(feature = "gz-miniz")] +const GZIP_BACKEND: &str = "miniz_oxide"; +#[cfg(all(not(feature = "gz-miniz"), feature = "gz-zlib-rs"))] +const GZIP_BACKEND: &str = "zlib-rs"; +#[cfg(all( + not(feature = "gz-miniz"), + not(feature = "gz-zlib-rs"), + feature = "gz-zlib-ng" +))] +const GZIP_BACKEND: &str = "zlib-ng"; +#[cfg(all( + not(feature = "gz-miniz"), + not(feature = "gz-zlib-rs"), + not(feature = "gz-zlib-ng"), + feature = "gz-zlib-cloudflare" +))] +const GZIP_BACKEND: &str = "cloudflare-zlib"; + +fn build_gzip_fixture() -> (Vec, usize, String) { + let corpus = common::build_text_corpus(); + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&corpus).unwrap(); + let compressed = encoder.finish().unwrap(); + let fixture = common::write_fixture(&format!("gzip-{GZIP_BACKEND}.txt.gz"), &compressed); -// Benchmark gzip decompression using flate2 with the selected backend. -// To run with default (miniz_oxide) backend: -// cargo bench --bench gzip_decompress --no-default-features --features gz-miniz -// To run with zlib-rs backend: -// cargo bench --bench gzip_decompress --no-default-features --features gz-zlib-rs -// To compare, run both commands and compare Criterion reports. - -fn load_gz_bytes() -> Vec { - let mut f = File::open("tests/test_data.txt.gz").expect("missing tests/test_data.txt.gz"); - let mut buf = Vec::new(); - f.read_to_end(&mut buf).unwrap(); - buf + ( + compressed, + corpus.len(), + fixture.to_string_lossy().into_owned(), + ) } fn bench_gzip_decompress(c: &mut Criterion) { - let input = load_gz_bytes(); + let (input, output_len, fixture_path) = build_gzip_fixture(); let mut group = c.benchmark_group("gzip_decompress"); - group.throughput(Throughput::Bytes(input.len() as u64)); - - group.bench_function("flate2_gz_decode", |b| { - b.iter_batched( - || input.clone(), - |bytes| { - let reader = GzDecoder::new(bytes.as_slice()); - let mut out = Vec::with_capacity(128 * 1024); - let mut r = reader; - r.read_to_end(&mut out).unwrap(); - black_box(out) - }, - BatchSize::SmallInput, - ) + group.throughput(Throughput::Bytes(output_len as u64)); + + group.bench_function(format!("raw_decoder/{GZIP_BACKEND}"), |b| { + b.iter(|| { + let mut reader = GzDecoder::new(input.as_slice()); + let mut out = Vec::with_capacity(output_len); + reader.read_to_end(&mut out).unwrap(); + black_box(out.len()) + }) + }); + + group.bench_function(format!("oneio_get_reader/{GZIP_BACKEND}"), |b| { + b.iter(|| { + let mut reader = oneio::get_reader(&fixture_path).unwrap(); + let mut out = Vec::with_capacity(output_len); + reader.read_to_end(&mut out).unwrap(); + black_box(out.len()) + }) }); group.finish(); diff --git a/examples/s3_operations.rs b/examples/s3_operations.rs index 1552a94..3d162b2 100644 --- a/examples/s3_operations.rs +++ b/examples/s3_operations.rs @@ -32,27 +32,18 @@ fn main() { info!("error if file does not exist"); let res = s3_stats("oneio-test", "test/README___NON_EXISTS.md"); assert!(res.is_err()); - assert_eq!( - false, - s3_exists("oneio-test", "test/README___NON_EXISTS.md").unwrap() - ); - assert_eq!(true, s3_exists("oneio-test", "test/README.md").unwrap()); + assert!(!s3_exists("oneio-test", "test/README___NON_EXISTS.md").unwrap()); + assert!(s3_exists("oneio-test", "test/README.md").unwrap()); info!("copy S3 file to a different location"); let res = s3_copy("oneio-test", "test/README.md", "test/README-temporary.md"); assert!(res.is_ok()); - assert_eq!( - true, - s3_exists("oneio-test", "test/README-temporary.md").unwrap() - ); + assert!(s3_exists("oneio-test", "test/README-temporary.md").unwrap()); info!("delete temporary copied S3 file"); let res = s3_delete("oneio-test", "test/README-temporary.md"); assert!(res.is_ok()); - assert_eq!( - false, - s3_exists("oneio-test", "test/README-temporary.md").unwrap() - ); + assert!(!s3_exists("oneio-test", "test/README-temporary.md").unwrap()); info!("list S3 files"); let res = s3_list("oneio-test", "test/", Some("/".to_string()), false).unwrap(); diff --git a/scripts/bench_decompression_backends.sh b/scripts/bench_decompression_backends.sh new file mode 100755 index 0000000..bdfaaba --- /dev/null +++ b/scripts/bench_decompression_backends.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +set -euo pipefail + +criterion_args=(-- --warm-up-time 1 --measurement-time 5 --sample-size 20) + +run_case() { + local name="$1" + shift + echo + echo "== ${name} ==" + cargo bench "$@" "${criterion_args[@]}" +} + +run_case "gzip miniz_oxide" --bench gzip_decompress --no-default-features --features gz-miniz +run_case "gzip zlib-rs" --bench gzip_decompress --no-default-features --features gz-zlib-rs +run_case "gzip zlib-ng" --bench gzip_decompress --no-default-features --features gz-zlib-ng +run_case "gzip cloudflare-zlib" --bench gzip_decompress --no-default-features --features gz-zlib-cloudflare +run_case "bzip2" --bench bzip2_decompress --no-default-features --features bz diff --git a/src/error.rs b/src/error.rs index edc4a79..cf9d500 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,10 @@ pub enum OneIoError { #[error("{0}")] Network(Box), + /// Structured status errors from remote services + #[error("{service} status error: {code}")] + Status { service: &'static str, code: u16 }, + /// Feature not supported/compiled #[error("Not supported: {0}")] NotSupported(String), diff --git a/src/lib.rs b/src/lib.rs index a415c50..21733f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -191,6 +191,8 @@ async fn main() -> Result<(), Box> { "local_data.csv.gz" ).await?; + // download_async preserves the remote bytes. + Ok(()) } ``` @@ -296,6 +298,7 @@ match oneio::get_reader("file.txt") { Ok(reader) => { /* use reader */ }, Err(OneIoError::Io(e)) => { /* filesystem error */ }, Err(OneIoError::Network(e)) => { /* network error */ }, + Err(OneIoError::Status { service, code }) => { /* remote status error */ }, Err(OneIoError::NotSupported(msg)) => { /* feature not compiled */ }, } ``` diff --git a/src/oneio/mod.rs b/src/oneio/mod.rs index 886a080..1c6b6b7 100644 --- a/src/oneio/mod.rs +++ b/src/oneio/mod.rs @@ -21,12 +21,8 @@ use std::path::Path; use futures::StreamExt; /// Extracts the protocol from a given path. -pub(crate) fn get_protocol(path: &str) -> Option { - let parts = path.split("://").collect::>(); - if parts.len() < 2 { - return None; - } - Some(parts[0].to_string()) +pub(crate) fn get_protocol(path: &str) -> Option<&str> { + path.split_once("://").map(|(protocol, _)| protocol) } pub fn get_writer_raw(path: &str) -> Result, OneIoError> { @@ -40,7 +36,7 @@ pub fn get_writer_raw(path: &str) -> Result, OneIoError> { pub fn get_reader_raw(path: &str) -> Result, OneIoError> { let raw_reader: Box = match get_protocol(path) { - Some(protocol) => match protocol.as_str() { + Some(protocol) => match protocol { #[cfg(feature = "http")] "http" | "https" => { let response = remote::get_http_reader_raw(path, None)?; @@ -54,7 +50,7 @@ pub fn get_reader_raw(path: &str) -> Result, OneIoError> { #[cfg(feature = "s3")] "s3" | "r2" => { let (bucket, path) = s3::s3_url_parse(path)?; - Box::new(s3::s3_reader(bucket.as_str(), path.as_str())?) + s3::s3_reader(bucket.as_str(), path.as_str())? } _ => { return Err(OneIoError::NotSupported(path.to_string())); @@ -131,11 +127,9 @@ pub fn get_cache_reader( // read all to cache file, no encode/decode happens let mut reader = get_reader_raw(path)?; - let mut data: Vec = vec![]; - reader.read_to_end(&mut data)?; let mut writer = get_writer_raw(cache_file_path.as_str())?; - writer.write_all(&data)?; - drop(writer); + std::io::copy(&mut reader, &mut writer)?; + writer.flush()?; // return reader from cache file get_reader(cache_file_path.as_str()) @@ -163,7 +157,7 @@ pub fn get_cache_reader( /// }; /// ``` pub fn get_writer(path: &str) -> Result, OneIoError> { - let output_file = BufWriter::new(File::create(path)?); + let output_file = get_writer_raw(path)?; let file_type = path.rsplit('.').next().unwrap_or(""); get_compression_writer(output_file, file_type) @@ -255,24 +249,7 @@ pub fn get_content_length(path: &str) -> Result { match get_protocol(path) { #[cfg(feature = "http")] Some(protocol) if protocol == "http" || protocol == "https" => { - #[cfg(feature = "rustls")] - crypto::ensure_default_provider()?; - - // HEAD request to get Content-Length - let client = reqwest::blocking::Client::new(); - let response = client.head(path).send()?; - - response - .headers() - .get("content-length") - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse().ok()) - .ok_or_else(|| { - OneIoError::NotSupported( - "Cannot determine file size - server doesn't provide Content-Length" - .to_string(), - ) - }) + remote::get_http_content_length(path) } #[cfg(feature = "ftp")] Some(protocol) if protocol == "ftp" => { @@ -464,8 +441,8 @@ pub async fn read_to_string_async(path: &str) -> Result { /// Downloads a file asynchronously from a URL to a local path /// -/// This is the async version of `download()`. It supports all protocols and -/// handles decompression if needed. +/// This is the async version of `download()`. It preserves the raw bytes from +/// the source, matching the synchronous `download()` behavior. /// /// # Arguments /// * `url` - Source URL to download from @@ -491,20 +468,18 @@ pub async fn read_to_string_async(path: &str) -> Result { #[cfg(feature = "async")] pub async fn download_async(url: &str, path: &str) -> Result<(), OneIoError> { use tokio::fs::File; - use tokio::io::AsyncWriteExt; + use tokio::io::{copy, AsyncWriteExt}; - let mut reader = get_reader_async(url).await?; - let mut file = File::create(path).await?; - - let mut buffer = vec![0u8; 8192]; - loop { - let bytes_read = reader.read(&mut buffer).await?; - if bytes_read == 0 { - break; + if let Some(parent) = Path::new(path).parent() { + if !parent.as_os_str().is_empty() { + tokio::fs::create_dir_all(parent).await?; } - file.write_all(&buffer[..bytes_read]).await?; } + let mut reader = get_async_reader_raw(url).await?; + let mut file = File::create(path).await?; + copy(&mut reader, &mut file).await?; + file.flush().await?; Ok(()) } @@ -612,8 +587,10 @@ fn get_async_compression_reader( #[cfg(test)] mod tests { use super::*; + #[cfg(any(feature = "any_gz", feature = "http"))] use std::io::Read; + #[cfg(any(feature = "any_gz", feature = "http", feature = "async"))] const TEST_TEXT: &str = "OneIO test file.\nThis is a test."; #[cfg(feature = "any_gz")] diff --git a/src/oneio/remote.rs b/src/oneio/remote.rs index 45ad996..a6403ca 100644 --- a/src/oneio/remote.rs +++ b/src/oneio/remote.rs @@ -5,6 +5,11 @@ use crate::OneIoError; #[cfg(feature = "http")] use reqwest::blocking::Client; use std::io::Read; +#[cfg(feature = "http")] +use std::sync::OnceLock; + +#[cfg(feature = "http")] +static DEFAULT_HTTP_CLIENT: OnceLock> = OnceLock::new(); #[cfg(feature = "ftp")] pub(crate) fn get_ftp_reader_raw(path: &str) -> Result, OneIoError> { @@ -15,69 +20,87 @@ pub(crate) fn get_ftp_reader_raw(path: &str) -> Result, One #[cfg(feature = "rustls")] super::crypto::ensure_default_provider()?; - let parts = path.split('/').collect::>(); - let socket = match parts[2].contains(':') { - true => parts[2].to_string(), - false => format!("{}:21", parts[2]), + let path_without_scheme = path + .strip_prefix("ftp://") + .ok_or_else(|| OneIoError::NotSupported(path.to_string()))?; + let (host, remote_path) = path_without_scheme + .split_once('/') + .ok_or_else(|| OneIoError::NotSupported(path.to_string()))?; + let socket = match host.contains(':') { + true => host.to_string(), + false => format!("{host}:21"), }; - let path = parts[3..].join("/"); let mut ftp_stream = suppaftp::FtpStream::connect(socket)?; // use anonymous login ftp_stream.login("anonymous", "oneio")?; ftp_stream.transfer_type(suppaftp::types::FileType::Binary)?; - let reader = Box::new(ftp_stream.retr_as_stream(path.as_str())?); + let reader = Box::new(ftp_stream.retr_as_stream(remote_path)?); Ok(reader) } #[cfg(feature = "http")] -pub(crate) fn get_http_reader_raw( - path: &str, - opt_client: Option, -) -> Result { +fn build_default_http_client() -> Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::USER_AGENT, + reqwest::header::HeaderValue::from_static("oneio"), + ); + headers.insert( + reqwest::header::CONTENT_LENGTH, + reqwest::header::HeaderValue::from_static("0"), + ); + #[cfg(feature = "cli")] + headers.insert( + reqwest::header::CACHE_CONTROL, + reqwest::header::HeaderValue::from_static("no-cache"), + ); + + #[cfg(any(feature = "rustls", feature = "native-tls"))] + { + let accept_invalid_certs = matches!( + std::env::var("ONEIO_ACCEPT_INVALID_CERTS") + .unwrap_or_default() + .to_lowercase() + .as_str(), + "true" | "yes" | "y" | "1" + ); + Client::builder() + .default_headers(headers) + .danger_accept_invalid_certs(accept_invalid_certs) + .build() + } + + #[cfg(not(any(feature = "rustls", feature = "native-tls")))] + { + Client::builder().default_headers(headers).build() + } +} + +#[cfg(feature = "http")] +fn default_http_client() -> Result { dotenvy::dotenv().ok(); #[cfg(feature = "rustls")] super::crypto::ensure_default_provider()?; + match DEFAULT_HTTP_CLIENT.get_or_init(|| build_default_http_client().map_err(|e| e.to_string())) + { + Ok(client) => Ok(client.clone()), + Err(message) => Err(OneIoError::Network(Box::new(std::io::Error::other( + message.clone(), + )))), + } +} + +#[cfg(feature = "http")] +pub(crate) fn get_http_reader_raw( + path: &str, + opt_client: Option, +) -> Result { let client = match opt_client { Some(c) => c, - None => { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - reqwest::header::USER_AGENT, - reqwest::header::HeaderValue::from_static("oneio"), - ); - headers.insert( - reqwest::header::CONTENT_LENGTH, - reqwest::header::HeaderValue::from_static("0"), - ); - #[cfg(feature = "cli")] - headers.insert( - reqwest::header::CACHE_CONTROL, - reqwest::header::HeaderValue::from_static("no-cache"), - ); - - #[cfg(any(feature = "rustls", feature = "native-tls"))] - { - let accept_invalid_certs = matches!( - std::env::var("ONEIO_ACCEPT_INVALID_CERTS") - .unwrap_or_default() - .to_lowercase() - .as_str(), - "true" | "yes" | "y" | "1" - ); - Client::builder() - .default_headers(headers) - .danger_accept_invalid_certs(accept_invalid_certs) - .build()? - } - - #[cfg(not(any(feature = "rustls", feature = "native-tls")))] - { - Client::builder().default_headers(headers).build()? - } - } + None => default_http_client()?, }; let res = client .execute(client.get(path).build()?)? @@ -162,6 +185,23 @@ pub fn get_http_reader( get_compression_reader(raw_reader, file_type) } +#[cfg(feature = "http")] +pub(crate) fn get_http_content_length(path: &str) -> Result { + let client = default_http_client()?; + let response = client.head(path).send()?.error_for_status()?; + + response + .headers() + .get("content-length") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse().ok()) + .ok_or_else(|| { + OneIoError::NotSupported( + "Cannot determine file size - server doesn't provide Content-Length".to_string(), + ) + }) +} + /// Downloads a file from a remote location to a local path. /// /// # Arguments @@ -198,33 +238,28 @@ pub fn download( opt_client: Option, ) -> Result<(), OneIoError> { match get_protocol(remote_path) { - None => { - return Err(OneIoError::NotSupported(remote_path.to_string())); + #[cfg(feature = "http")] + Some("http" | "https") => { + let mut writer = get_writer_raw(local_path)?; + let mut response = get_http_reader_raw(remote_path, opt_client)?; + response.copy_to(&mut writer)?; + Ok(()) } - Some(protocol) => match protocol.as_str() { - #[cfg(feature = "http")] - "http" | "https" => { - let mut writer = get_writer_raw(local_path)?; - let mut response = get_http_reader_raw(remote_path, opt_client)?; - response.copy_to(&mut writer)?; - } - #[cfg(feature = "ftp")] - "ftp" => { - let mut writer = get_writer_raw(local_path)?; - let mut reader = get_ftp_reader_raw(remote_path)?; - std::io::copy(&mut reader, &mut writer)?; - } - #[cfg(feature = "s3")] - "s3" => { - let (bucket, path) = crate::oneio::s3::s3_url_parse(remote_path)?; - crate::oneio::s3::s3_download(bucket.as_str(), path.as_str(), local_path)?; - } - _ => { - return Err(OneIoError::NotSupported(remote_path.to_string())); - } - }, - }; - Ok(()) + #[cfg(feature = "ftp")] + Some("ftp") => { + let mut writer = get_writer_raw(local_path)?; + let mut reader = get_ftp_reader_raw(remote_path)?; + std::io::copy(&mut reader, &mut writer)?; + Ok(()) + } + #[cfg(feature = "s3")] + Some("s3" | "r2") => { + let (bucket, path) = crate::oneio::s3::s3_url_parse(remote_path)?; + crate::oneio::s3::s3_download(bucket.as_str(), path.as_str(), local_path)?; + Ok(()) + } + Some(_) | None => Err(OneIoError::NotSupported(remote_path.to_string())), + } } /// Downloads a file from a remote path and saves it locally with retry mechanism. @@ -291,19 +326,17 @@ pub fn download_with_retry( /// an `Err` variant with a `OneIoError` is returned. pub(crate) fn remote_file_exists(path: &str) -> Result { match get_protocol(path) { - Some(protocol) => match protocol.as_str() { + Some(protocol) => match protocol { "http" | "https" => { - #[cfg(feature = "rustls")] - super::crypto::ensure_default_provider()?; - - let client = Client::builder() + let client = default_http_client()?; + let res = client + .head(path) .timeout(std::time::Duration::from_secs(2)) - .build()?; - let res = client.head(path).send()?; + .send()?; Ok(res.status().is_success()) } #[cfg(feature = "s3")] - "s3" => { + "s3" | "r2" => { let (bucket, path) = crate::oneio::s3::s3_url_parse(path)?; let res = crate::oneio::s3::s3_exists(bucket.as_str(), path.as_str())?; Ok(res) diff --git a/src/oneio/s3.rs b/src/oneio/s3.rs index 45b1788..3030df6 100644 --- a/src/oneio/s3.rs +++ b/src/oneio/s3.rs @@ -10,7 +10,8 @@ use crate::OneIoError; use s3::creds::Credentials; use s3::serde_types::{HeadObjectResult, ListBucketResult}; use s3::{Bucket, Region}; -use std::io::{Cursor, Read}; +use std::io::{Cursor, Read, Write}; +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; /// Checks if the necessary environment variables for AWS S3 are set. /// @@ -83,13 +84,125 @@ pub fn s3_env_check() -> Result<(), OneIoError> { /// /// Returns a `Result` containing the bucket and key as a tuple, or a `OneIoError` if parsing fails. pub fn s3_url_parse(path: &str) -> Result<(String, String), OneIoError> { - let parts = path.split('/').collect::>(); - if parts.len() < 3 { + let (_, remaining) = path + .split_once("://") + .ok_or_else(|| OneIoError::NotSupported(format!("Invalid S3 URL: {path}")))?; + let (bucket, key) = remaining + .split_once('/') + .ok_or_else(|| OneIoError::NotSupported(format!("Invalid S3 URL: {path}")))?; + if bucket.is_empty() || key.is_empty() { return Err(OneIoError::NotSupported(format!("Invalid S3 URL: {path}"))); } - let bucket = parts[2]; - let key = parts[3..].join("/"); - Ok((bucket.to_string(), key)) + Ok((bucket.to_string(), key.to_string())) +} + +enum StreamMessage { + Chunk(Vec), + Error(String), + Eof, +} + +struct StreamWriter { + sender: SyncSender, + closed: bool, +} + +impl StreamWriter { + fn new(sender: SyncSender) -> Self { + Self { + sender, + closed: false, + } + } + + fn send_error(&mut self, err: std::io::Error) -> std::io::Result<()> { + self.closed = true; + self.sender + .send(StreamMessage::Error(err.to_string())) + .map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "stream closed")) + } + + fn close(&mut self) -> std::io::Result<()> { + if self.closed { + return Ok(()); + } + self.closed = true; + self.sender + .send(StreamMessage::Eof) + .map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "stream closed")) + } +} + +impl Write for StreamWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.sender + .send(StreamMessage::Chunk(buf.to_vec())) + .map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "stream closed"))?; + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +impl Drop for StreamWriter { + fn drop(&mut self) { + let _ = self.close(); + } +} + +struct StreamReader { + receiver: Receiver, + current_chunk: Cursor>, + done: bool, +} + +impl StreamReader { + fn new(receiver: Receiver) -> Self { + Self { + receiver, + current_chunk: Cursor::new(Vec::new()), + done: false, + } + } +} + +impl Read for StreamReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + if buf.is_empty() { + return Ok(0); + } + + loop { + let bytes_read = self.current_chunk.read(buf)?; + if bytes_read > 0 { + return Ok(bytes_read); + } + + if self.done { + return Ok(0); + } + + match self.receiver.recv() { + Ok(StreamMessage::Chunk(chunk)) => { + self.current_chunk = Cursor::new(chunk); + } + Ok(StreamMessage::Error(message)) => { + self.done = true; + return Err(std::io::Error::other(message)); + } + Ok(StreamMessage::Eof) => { + self.done = true; + return Ok(0); + } + Err(_) => { + self.done = true; + return Err(std::io::Error::other("S3 stream closed unexpectedly")); + } + } + } + } } /// Creates an S3 bucket object with the specified bucket name. @@ -172,9 +285,26 @@ pub fn s3_bucket(bucket: &str) -> Result { /// ``` pub fn s3_reader(bucket: &str, path: &str) -> Result, OneIoError> { let bucket = s3_bucket(bucket)?; - let object = bucket.get_object(path)?; - let buf: Vec = object.to_vec(); - Ok(Box::new(Cursor::new(buf))) + let path = path.to_string(); + let (sender, receiver) = sync_channel(8); + + std::thread::spawn(move || { + let mut writer = StreamWriter::new(sender); + match bucket.get_object_to_writer(path, &mut writer) { + Ok(200..=299) => { + let _ = writer.close(); + } + Ok(code) => { + let _ = + writer.send_error(std::io::Error::other(format!("S3 status error: {code}"))); + } + Err(err) => { + let _ = writer.send_error(std::io::Error::other(err.to_string())); + } + } + }); + + Ok(Box::new(StreamReader::new(receiver))) } /// Uploads a file to an S3 bucket at the specified path. @@ -316,9 +446,10 @@ pub fn s3_download(bucket: &str, s3_path: &str, file_path: &str) -> Result<(), O let res: u16 = bucket.get_object_to_writer(s3_path, &mut output_file)?; match res { 200..=299 => Ok(()), - _ => Err(OneIoError::Network(Box::new(std::io::Error::other( - format!("S3 HTTP error: {res}"), - )))), + _ => Err(OneIoError::Status { + service: "S3", + code: res, + }), } } @@ -360,9 +491,10 @@ pub fn s3_stats(bucket: &str, path: &str) -> Result Ok(head_object), - _ => Err(OneIoError::Network(Box::new(std::io::Error::other( - format!("S3 HTTP error: {code}"), - )))), + _ => Err(OneIoError::Status { + service: "S3", + code, + }), } } @@ -394,25 +526,11 @@ pub fn s3_stats(bucket: &str, path: &str) -> Result Result { match s3_stats(bucket, path) { Ok(_) => Ok(true), - Err(err) => { - // Check if this is a 404 network error by parsing the status code - if let OneIoError::Network(boxed_err) = &err { - let error_msg = boxed_err.to_string(); - if error_msg.starts_with("S3 HTTP error: ") { - // Parse the status code from the structured error message - if let Some(code_str) = error_msg.strip_prefix("S3 HTTP error: ") { - if let Ok(status_code) = code_str.parse::() { - return match status_code { - 404 => Ok(false), // Not Found - // 403 Forbidden means permission denied; propagate as error - _ => Err(err), // Other errors should propagate - }; - } - } - } - } - Err(err) - } + Err(OneIoError::Status { + service: "S3", + code: 404, + }) => Ok(false), + Err(err) => Err(err), } } @@ -480,6 +598,7 @@ pub fn s3_list( #[cfg(test)] mod tests { use super::*; + use std::io::{Read, Write}; #[test] fn test_s3_url_parse() { @@ -536,4 +655,38 @@ mod tests { } } } + + #[test] + fn test_stream_reader_reads_in_order() { + let (sender, receiver) = sync_channel(2); + let writer_thread = std::thread::spawn(move || { + let mut writer = StreamWriter::new(sender); + writer.write_all(b"hello ").unwrap(); + writer.write_all(b"world").unwrap(); + writer.close().unwrap(); + }); + + let mut reader = StreamReader::new(receiver); + let mut output = String::new(); + reader.read_to_string(&mut output).unwrap(); + writer_thread.join().unwrap(); + + assert_eq!(output, "hello world"); + } + + #[test] + fn test_stream_reader_propagates_error() { + let (sender, receiver) = sync_channel(2); + let mut writer = StreamWriter::new(sender); + writer.write_all(b"hello").unwrap(); + writer + .send_error(std::io::Error::other("stream failed")) + .unwrap(); + + let mut reader = StreamReader::new(receiver); + let mut buf = [0; 5]; + reader.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, b"hello"); + assert!(reader.read(&mut [0; 1]).is_err()); + } } diff --git a/src/oneio/utils.rs b/src/oneio/utils.rs index 0a56dba..a9fd856 100644 --- a/src/oneio/utils.rs +++ b/src/oneio/utils.rs @@ -116,6 +116,6 @@ mod tests { assert_eq!(data.purpose, "test".to_string()); assert_eq!(data.version, 1); assert_eq!(data.meta.float, 1.1); - assert_eq!(data.meta.success, true); + assert!(data.meta.success); } } diff --git a/tests/async_integration.rs b/tests/async_integration.rs index db36ad0..3ad16ef 100644 --- a/tests/async_integration.rs +++ b/tests/async_integration.rs @@ -3,7 +3,6 @@ #![cfg(feature = "async")] -use oneio; use tokio::io::AsyncReadExt; const TEST_TEXT: &str = "OneIO test file.\nThis is a test."; @@ -81,3 +80,20 @@ async fn async_download_http_to_file() { let _ = std::fs::remove_file(tmp_path); } + +#[cfg(feature = "any_gz")] +#[tokio::test] +async fn async_download_preserves_compressed_bytes() { + let tmp_path = "tests/_tmp_async_download.txt.gz"; + let _ = std::fs::remove_file(tmp_path); + + oneio::download_async("tests/test_data.txt.gz", tmp_path) + .await + .unwrap(); + + let original = std::fs::read("tests/test_data.txt.gz").unwrap(); + let downloaded = std::fs::read(tmp_path).unwrap(); + assert_eq!(downloaded, original); + + let _ = std::fs::remove_file(tmp_path); +} diff --git a/tests/basic_integration.rs b/tests/basic_integration.rs index d4416ba..fb9793d 100644 --- a/tests/basic_integration.rs +++ b/tests/basic_integration.rs @@ -1,7 +1,6 @@ //! Basic integration tests using only default features (gz, bz, http) //! These tests should always pass with `cargo test` -use oneio; use std::io::{Read, Write}; const TEST_TEXT: &str = "OneIO test file.\nThis is a test."; From 2edadf6cb59d8939d83efdeed556367bdc1e4c14 Mon Sep 17 00:00:00 2001 From: Mingwei Zhang Date: Fri, 6 Mar 2026 06:51:11 -0800 Subject: [PATCH 2/2] docs: update README.md from lib.rs documentation --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 84f2cb3..10d94da 100644 --- a/README.md +++ b/README.md @@ -191,6 +191,8 @@ async fn main() -> Result<(), Box> { "local_data.csv.gz" ).await?; + // download_async preserves the remote bytes. + Ok(()) } ``` @@ -295,6 +297,7 @@ match oneio::get_reader("file.txt") { Ok(reader) => { /* use reader */ }, Err(OneIoError::Io(e)) => { /* filesystem error */ }, Err(OneIoError::Network(e)) => { /* network error */ }, + Err(OneIoError::Status { service, code }) => { /* remote status error */ }, Err(OneIoError::NotSupported(msg)) => { /* feature not compiled */ }, } ```