Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 69 additions & 21 deletions packages/blitz-net/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
//!
//! Provides an implementation of the [`blitz_traits::net::NetProvider`] trait.

use blitz_traits::net::{Body, Bytes, NetHandler, NetProvider, NetWaker, Request};
// use blitz_traits::net::{Body, Bytes, NetHandler, NetProvider, NetWaker, Request};
use blitz_traits::net::{AbortSignal, Body, Bytes, NetHandler, NetProvider, NetWaker, Request};
use data_url::DataUrl;
use std::sync::Arc;
use std::{marker::PhantomData, pin::Pin, sync::Arc, task::Poll};
use tokio::runtime::Handle;

#[cfg(feature = "cache")]
Expand Down Expand Up @@ -102,16 +103,6 @@ impl Provider {
})
}

async fn fetch_with_handler(
client: Client,
request: Request,
handler: Box<dyn NetHandler>,
) -> Result<(), ProviderError> {
let (response_url, bytes) = Self::fetch_inner(client, request).await?;
handler.bytes(response_url, bytes);
Ok(())
}

#[allow(clippy::type_complexity)]
pub fn fetch_with_callback(
&self,
Expand Down Expand Up @@ -155,7 +146,7 @@ impl Provider {
}

impl NetProvider for Provider {
fn fetch(&self, doc_id: usize, request: Request, handler: Box<dyn NetHandler>) {
fn fetch(&self, doc_id: usize, mut request: Request, handler: Box<dyn NetHandler>) {
let client = self.client.clone();

#[cfg(feature = "debug_log")]
Expand All @@ -166,23 +157,80 @@ impl NetProvider for Provider {
#[cfg(feature = "debug_log")]
let url = request.url.to_string();

let _res = Self::fetch_with_handler(client, request, handler).await;

#[cfg(feature = "debug_log")]
if let Err(e) = _res {
eprintln!("Error fetching {url}: {e:?}");
let signal = request.signal.take();
let result = if let Some(signal) = signal {
AbortFetch::new(
signal,
Box::pin(async move { Self::fetch_inner(client, request).await }),
)
.await
} else {
println!("Success {url}");
}
Self::fetch_inner(client, request).await
};

// Call the waker to notify of completed network request
waker.wake(doc_id)
waker.wake(doc_id);

match result {
Ok((response_url, bytes)) => {
handler.bytes(response_url, bytes);
#[cfg(feature = "debug_log")]
println!("Success {url}");
}
Err(e) => {
#[cfg(feature = "debug_log")]
eprintln!("Error fetching {url}: {e:?}");
#[cfg(not(feature = "debug_log"))]
let _ = e;
}
};
});
}
}

/// A future that is cancellable using an AbortSignal
struct AbortFetch<F, T> {
signal: AbortSignal,
future: F,
_rt: PhantomData<T>,
}

impl<F, T> AbortFetch<F, T> {
fn new(signal: AbortSignal, future: F) -> Self {
Self {
signal,
future,
_rt: PhantomData,
}
}
}

impl<F, T> Future for AbortFetch<F, T>
where
F: Future + Unpin + Send + 'static,
F::Output: Send + Into<Result<T, ProviderError>> + 'static,
T: Unpin,
{
type Output = Result<T, ProviderError>;

fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
if self.signal.aborted() {
return Poll::Ready(Err(ProviderError::Abort));
}

match Pin::new(&mut self.future).poll(cx) {
Poll::Ready(output) => Poll::Ready(output.into()),
Poll::Pending => Poll::Pending,
}
}
}

#[derive(Debug)]
pub enum ProviderError {
Abort,
Io(std::io::Error),
DataUrl(data_url::DataUrlError),
DataUrlBase64(data_url::forgiving_base64::InvalidBase64),
Expand Down
1 change: 1 addition & 0 deletions packages/blitz-traits/src/navigation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ impl NavigationOptions {
content_type: self.content_type,
headers: HeaderMap::new(),
body: self.document_resource,
signal: None,
}
}
}
50 changes: 50 additions & 0 deletions packages/blitz-traits/src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ use serde::{
Serialize,
ser::{SerializeSeq, SerializeTuple},
};
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use std::{ops::Deref, path::PathBuf};
pub use url::Url;

Expand Down Expand Up @@ -43,6 +47,7 @@ pub struct Request {
pub content_type: String,
pub headers: HeaderMap,
pub body: Body,
pub signal: Option<AbortSignal>,
}
impl Request {
/// A get request to the specified Url and an empty body
Expand All @@ -53,8 +58,14 @@ impl Request {
content_type: String::new(),
headers: HeaderMap::new(),
body: Body::Empty,
signal: None,
}
}

pub fn signal(mut self, signal: AbortSignal) -> Self {
self.signal = Some(signal);
self
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -148,3 +159,42 @@ pub struct DummyNetProvider;
impl NetProvider for DummyNetProvider {
fn fetch(&self, _doc_id: usize, _request: Request, _handler: Box<dyn NetHandler>) {}
}

/// The AbortController interface represents a controller object that
/// allows you to abort one or more Web requests as and when desired.
///
/// https://developer.mozilla.org/en-US/docs/Web/API/AbortController
#[derive(Debug, Default)]
pub struct AbortController {
pub signal: AbortSignal,
}

impl AbortController {
/// The abort() method of the AbortController interface aborts
/// an asynchronous operation before it has completed.
/// This is able to abort fetch requests.
///
/// https://developer.mozilla.org/en-US/docs/Web/API/AbortController/abort
pub fn abort(self) {
self.signal.0.store(true, Ordering::SeqCst);
}
}

/// The AbortSignal interface represents a signal object that allows you to
/// communicate with an asynchronous operation (such as a fetch request) and
/// abort it if required via an AbortController object.
///
/// https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal
#[derive(Debug, Default, Clone)]
pub struct AbortSignal(Arc<AtomicBool>);

impl AbortSignal {
/// The aborted read-only property returns a value that indicates whether
/// the asynchronous operations the signal is communicating with are
/// aborted (true) or not (false).
///
/// https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal/aborted
pub fn aborted(&self) -> bool {
self.0.load(Ordering::SeqCst)
}
}
Loading