Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 72 additions & 27 deletions connectrpc/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ pub type ClientBody = BoxBody<Bytes, ConnectError>;

/// Wrap a fully-known buffer in a [`ClientBody`].
///
/// Used by unary, server-streaming, and client-streaming calls where the
/// complete request body is available before sending.
/// Used by unary and server-streaming calls where the complete request body
/// is available before sending.
#[inline]
pub fn full_body(b: Bytes) -> ClientBody {
Full::new(b).map_err(|never| match never {}).boxed()
Expand Down Expand Up @@ -2160,9 +2160,10 @@ where

/// A request body that pulls envelope-encoded frames from an mpsc channel.
///
/// Used as the request body for bidirectional streaming calls. [`BidiStream::send`]
/// pushes encoded envelopes to the channel's sender half; dropping the sender
/// (via [`BidiStream::close_send`]) closes the body, signalling EOF to the server.
/// Used as the request body for bidirectional and client-streaming calls.
/// [`BidiStream::send`] pushes encoded envelopes to the channel's sender half;
/// dropping the sender (via [`BidiStream::close_send`]) closes the body,
/// signalling EOF to the server.
struct ChannelBody {
rx: tokio::sync::mpsc::Receiver<Result<Bytes, ConnectError>>,
}
Expand Down Expand Up @@ -2529,6 +2530,11 @@ where
// ChannelBody gets polled as sends happen, independent of when the
// caller first calls message(). See RecvState doc for the deadlock
// this avoids.
//
// Uses tokio::spawn directly (not spawn_detached) because
// RecvState::Pending needs JoinHandle<Result<...>>. If wasm32+client
// becomes supported, factor this into a spawn_with_result helper that
// bridges via oneshot on wasm.
let response_fut = transport.send(http_request);
let response_task = tokio::spawn(async move {
response_fut
Expand Down Expand Up @@ -2558,6 +2564,13 @@ where
/// Sends multiple request messages as envelope-framed data and receives a single
/// envelope-framed response with END_STREAM. Returns a [`UnaryResponse`] containing
/// the decoded response message along with headers and trailers.
///
/// The request body is streamed: each item from the iterator is encoded into
/// an envelope and pushed to a bounded mpsc channel that backs the HTTP
/// request body. The transport begins sending as soon as the first envelope
/// is ready instead of waiting for the iterator to be fully drained, so peak
/// memory stays around `channel_depth * envelope_size` rather than the full
/// concatenated body.
pub async fn call_client_stream<T, Req, RespView>(
transport: &T,
config: &ClientConfig,
Expand All @@ -2583,7 +2596,11 @@ where
.parse()
.map_err(|e| ConnectError::internal(format!("invalid URI: {e}")))?;

// Encode each request message as an envelope and concatenate
// Channel-backed request body. Depth 32 matches `call_bidi_stream` and
// gives natural backpressure on HTTP/2 flow control.
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, ConnectError>>(32);
let body: ClientBody = ChannelBody { rx }.boxed();

let compression_for_encoder = config.request_compression.as_ref().map(|enc| {
(
std::sync::Arc::new(config.compression.clone()),
Expand All @@ -2594,22 +2611,6 @@ where
compression_for_encoder,
config.compression_policy.with_override(options.compress),
);
let mut body_buf = BytesMut::new();
for request in requests {
let msg_bytes = match config.codec_format {
CodecFormat::Proto => request.encode_to_bytes(),
CodecFormat::Json => {
let buf = serde_json::to_vec(&request).map_err(|e| {
ConnectError::internal(format!("failed to encode JSON request: {e}"))
})?;
Bytes::from(buf)
}
};

tokio_util::codec::Encoder::encode(&mut encoder, msg_bytes, &mut body_buf)?;
}

let request_body = body_buf.freeze();

// Compute deadline BEFORE sending, matching Go's ctx.Deadline() semantics
let deadline = options.timeout.map(|t| std::time::Instant::now() + t);
Expand All @@ -2625,15 +2626,59 @@ where
}

let http_request = builder
.body(full_body(request_body))
.body(body)
.map_err(|e| ConnectError::internal(format!("failed to build request: {e}")))?;

// Drive the transport send concurrently with the iterator drain below.
// Without this, a transport whose send() future contains the actual I/O
// would not read from the channel until awaited, deadlocking once the
// channel filled. The response is bridged back via a oneshot so the
// awaitee is uniform across architectures.
let response_fut = transport.send(http_request);
let (resp_tx, resp_rx) =
tokio::sync::oneshot::channel::<Result<Response<T::ResponseBody>, ConnectError>>();
let _ = crate::spawn_detached(async move {
let result = response_fut
.await
.map_err(|e| ConnectError::unavailable(format!("request failed: {e}")));
let _ = resp_tx.send(result);
});

// Enforce client-side deadline on send + parse.
with_deadline(deadline, async {
let response = transport
.send(http_request)
.await
.map_err(|e| ConnectError::unavailable(format!("request failed: {e}")))?;
// Drain the iterator, encoding each request and pushing its envelope
// into the channel. The iterator is synchronous, so the only awaits
// here are tx.send(...), which provides backpressure via the channel
// depth.
for request in requests {
let msg_bytes = match config.codec_format {
CodecFormat::Proto => request.encode_to_bytes(),
CodecFormat::Json => {
let buf = serde_json::to_vec(&request).map_err(|e| {
ConnectError::internal(format!("failed to encode JSON request: {e}"))
})?;
Bytes::from(buf)
}
};

let mut envelope_buf = BytesMut::new();
tokio_util::codec::Encoder::encode(&mut encoder, msg_bytes, &mut envelope_buf)?;

if tx.send(Ok(envelope_buf.freeze())).await.is_err() {
// Receiver dropped: the spawned send task has finished, either
// because the transport failed or the server responded before
// we finished sending. Stop draining and let the response
// task surface the actual error/result.
break;
}
}

drop(tx);

// Await the response now that the request body has been fully sent.
let response = resp_rx.await.map_err(|_| {
ConnectError::internal("transport send task dropped without producing a response")
})??;

// For gRPC, the response is envelope-framed like a unary gRPC response
// (single data envelope + trailers). Reuse parse_grpc_unary_response.
Expand Down
27 changes: 27 additions & 0 deletions connectrpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,33 @@
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]

/// Spawn a detached background future on the ambient executor.
///
/// On native targets this dispatches via [`tokio::spawn`] and returns the join
/// handle. On `wasm32` there is no tokio runtime, so the future is dispatched
/// via [`wasm_bindgen_futures::spawn_local`] and `None` is returned (no
/// joinable handle available).
///
/// The `Send` bound is required on native (`tokio::spawn`) but relaxed on
/// wasm32 (`spawn_local` is single-threaded).
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn spawn_detached<F>(future: F) -> Option<tokio::task::JoinHandle<()>>
where
F: std::future::Future<Output = ()> + Send + 'static,
{
Some(tokio::spawn(future))
}

/// wasm32 variant — see non-wasm docs above.
#[cfg(target_arch = "wasm32")]
pub(crate) fn spawn_detached<F>(future: F) -> Option<tokio::task::JoinHandle<()>>
where
F: std::future::Future<Output = ()> + 'static,
{
wasm_bindgen_futures::spawn_local(future);
None
}

// Core modules (always available)
pub mod codec;
pub mod compression;
Expand Down
34 changes: 1 addition & 33 deletions connectrpc/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2285,38 +2285,6 @@ where
/// data after a size-limit error or other decoder failure.
const MAX_DRAIN_BYTES: usize = 1024 * 1024; // 1 MiB

/// Spawn a detached background future on the ambient executor.
///
/// On native (non-wasm) targets this dispatches via `tokio::spawn` and returns
/// the join handle. On `wasm32-unknown-unknown` (Cloudflare Workers, browsers,
/// etc.) there is no tokio runtime, so the future is dispatched via
/// `wasm_bindgen_futures::spawn_local`, which does not produce a joinable
/// handle — the function returns `None` in that case.
///
/// On both targets the executor takes ownership of the future. Dropping the
/// returned handle does not cancel the task; callers must not rely on
/// `.abort()` either (see `_reader_task` for the HTTP/1.1 drain rationale).
///
/// The bound on `F` is `Send + 'static` on native (required by `tokio::spawn`)
/// and `'static` on wasm32 (`spawn_local` runs on a single thread). Avoid
/// non-`Send` state in futures that must compile on both targets.
#[cfg(not(target_arch = "wasm32"))]
fn spawn_detached<F>(future: F) -> Option<tokio::task::JoinHandle<()>>
where
F: std::future::Future<Output = ()> + Send + 'static,
{
Some(tokio::spawn(future))
}

#[cfg(target_arch = "wasm32")]
fn spawn_detached<F>(future: F) -> Option<tokio::task::JoinHandle<()>>
where
F: std::future::Future<Output = ()> + 'static,
{
wasm_bindgen_futures::spawn_local(future);
None
}

/// Spawn a background task that reads envelope-framed messages from an HTTP
/// body and forwards them to a channel.
///
Expand Down Expand Up @@ -2430,7 +2398,7 @@ where

// The reader runs detached — it has to outlive the response stream so it
// can finish draining the request body.
let reader_task = spawn_detached(reader_future);
let reader_task = crate::spawn_detached(reader_future);

let request_stream: BoxStream<Result<Bytes, ConnectError>> =
Box::pin(futures::stream::unfold(rx, |mut rx| async {
Expand Down
81 changes: 81 additions & 0 deletions tests/streaming/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,87 @@ mod tests {
assert_eq!(msg.data, "");
}

/// Regression: `call_client_stream` must stream the request body
/// frame-by-frame instead of buffering the whole concatenated payload
/// into a single Frame. Each iterator item should produce its own body
/// frame (one envelope per channel push).
#[tokio::test]
async fn client_stream_request_body_is_streamed() {
use bytes::Bytes;
use connectrpc::client::{BoxFuture, ClientBody, ClientTransport};
use http::{Request, Response};
use http_body::Body;
use std::pin::Pin;
use std::sync::Mutex;

#[derive(Clone)]
struct FrameCountingTransport {
frame_sizes: Arc<Mutex<Vec<usize>>>,
}

impl ClientTransport for FrameCountingTransport {
type ResponseBody = http_body_util::Empty<Bytes>;
type Error = ConnectError;

fn send(
&self,
request: Request<ClientBody>,
) -> BoxFuture<'static, Result<Response<Self::ResponseBody>, Self::Error>> {
let recorded = self.frame_sizes.clone();
Box::pin(async move {
let mut body = request.into_body();
let mut sizes = Vec::new();
while let Some(frame) =
std::future::poll_fn(|cx| Pin::new(&mut body).poll_frame(cx)).await
{
let frame: http_body::Frame<Bytes> = frame?;
if let Ok(data) = frame.into_data() {
sizes.push(data.len());
}
}
*recorded.lock().unwrap() = sizes;
// Short-circuit: the call will surface this as Unavailable.
// The assertion is on the captured request framing.
Err(ConnectError::unavailable("recorded; not forwarded"))
})
}
}

let frames: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));
let transport = FrameCountingTransport {
frame_sizes: frames.clone(),
};
let config = ClientConfig::new("http://localhost/".parse().unwrap());
let client = EchoServiceClient::new(transport, config);

let messages: Vec<EchoRequest> = (0..5)
.map(|i| EchoRequest {
sequence: i,
data: format!("msg-{i}"),
..Default::default()
})
.collect();

// Expected to fail with the forced transport error. call_client_stream
// awaits the oneshot internally, so frames are fully captured by return.
let _ = client.client_stream(messages).await;

let captured = frames.lock().unwrap().clone();
assert_eq!(
captured.len(),
5,
"expected one body frame per request message, got {} (sizes: {captured:?})",
captured.len(),
);
// Every envelope carries at minimum the 5-byte header.
for size in &captured {
assert!(
*size >= 5,
"envelope frame too small ({size} bytes) — header alone is 5 bytes",
);
}
}

/// Tests bidi streaming at the server level by sending envelope-framed
/// messages over raw HTTP. This verifies that the server-side bidi handler
/// correctly receives and echoes all messages.
Expand Down
Loading