Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 202 additions & 4 deletions crates/fetchkit/src/fetchers/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ impl Fetcher for DefaultFetcher {
};

// THREAT[TM-SSRF-010]: Follow redirects manually so every hop is re-validated.
let response =
let (response, redirect_chain) =
send_request_following_redirects(parsed_url, reqwest_method, headers, options).await?;

let status_code = response.status().as_u16();
Expand Down Expand Up @@ -250,6 +250,7 @@ impl Fetcher for DefaultFetcher {
etag: meta.etag,
filename: meta.filename,
method: Some("HEAD".to_string()),
redirect_chain,
..Default::default()
});
}
Expand All @@ -265,6 +266,7 @@ impl Fetcher for DefaultFetcher {
last_modified: meta.last_modified,
etag: meta.etag,
filename: meta.filename,
redirect_chain,
error: Some(
"Binary content is not supported. Only textual content (HTML, text, JSON, etc.) can be fetched."
.to_string(),
Expand All @@ -282,6 +284,9 @@ impl Fetcher for DefaultFetcher {
// Convert to string
let content = String::from_utf8_lossy(&body).to_string();

// Detect paywall before content is moved by conversion
let is_paywall = detect_paywall(&content);

// Determine format and convert if needed
// THREAT[TM-DOS-006]: Conversion input is bounded by max_body_size
let is_html_content = is_html(&meta.content_type, &content);
Expand Down Expand Up @@ -335,6 +340,9 @@ impl Fetcher for DefaultFetcher {
final_content.push_str(TRUNCATION_MESSAGE);
}

// Compute quality signals
let word_count = count_words(&final_content);

Ok(FetchResponse {
url: final_url,
status_code,
Expand All @@ -347,6 +355,9 @@ impl Fetcher for DefaultFetcher {
content: Some(final_content),
truncated: if truncated { Some(true) } else { None },
metadata: page_metadata,
word_count: Some(word_count),
redirect_chain,
is_paywall: if is_paywall { Some(true) } else { None },
..Default::default()
})
}
Expand Down Expand Up @@ -383,7 +394,7 @@ impl Fetcher for DefaultFetcher {
};

// THREAT[TM-SSRF-010]: Follow redirects manually with IP validation at each hop
let response =
let (response, redirect_chain) =
send_request_following_redirects(parsed_url, reqwest_method, headers, options).await?;

let status_code = response.status().as_u16();
Expand All @@ -401,6 +412,7 @@ impl Fetcher for DefaultFetcher {
etag: meta.etag,
filename: meta.filename,
method: Some("HEAD".to_string()),
redirect_chain,
..Default::default()
});
}
Expand All @@ -426,19 +438,22 @@ impl Fetcher for DefaultFetcher {
truncated: if truncated { Some(true) } else { None },
saved_path: Some(save_result.path),
bytes_written: Some(save_result.bytes_written),
redirect_chain,
// No inline content when saving to file
..Default::default()
})
}
}

/// Returns `(response, redirect_chain)` where redirect_chain lists intermediate URLs.
async fn send_request_following_redirects(
initial_url: Url,
method: reqwest::Method,
headers: HeaderMap,
options: &FetchOptions,
) -> Result<reqwest::Response, FetchError> {
) -> Result<(reqwest::Response, Vec<String>), FetchError> {
let mut current_url = initial_url;
let mut redirect_chain = Vec::new();

for redirect_count in 0..=MAX_REDIRECTS {
let client = build_client_for_url(&current_url, headers.clone(), options)?;
Expand All @@ -449,7 +464,7 @@ async fn send_request_following_redirects(
.map_err(FetchError::from_reqwest)?;

let Some(next_url) = redirect_target(&current_url, &response, options)? else {
return Ok(response);
return Ok((response, redirect_chain));
};

if redirect_count == MAX_REDIRECTS {
Expand All @@ -463,6 +478,7 @@ async fn send_request_following_redirects(
"Following redirect with IP validation"
);

redirect_chain.push(current_url.to_string());
current_url = next_url;
}

Expand Down Expand Up @@ -650,6 +666,36 @@ async fn read_body_with_timeout(
}
}

/// Count words in text content.
fn count_words(text: &str) -> u64 {
text.split_whitespace().count() as u64
}

/// Common paywall indicators in raw HTML content.
const PAYWALL_INDICATORS: &[&str] = &[
"paywall",
"subscribe to read",
"subscribe to continue",
"subscription required",
"premium content",
"members only",
"sign in to read",
"log in to read",
"create a free account",
"already a subscriber",
"unlock this article",
"get unlimited access",
"start your free trial",
];

/// Heuristic paywall detection from raw HTML.
fn detect_paywall(html: &str) -> bool {
let lower = html.to_lowercase();
PAYWALL_INDICATORS
.iter()
.any(|indicator| lower.contains(indicator))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1048,4 +1094,156 @@ mod tests {
assert_eq!(response.status_code, 304);
assert!(response.content.is_none());
}

#[test]
fn test_count_words() {
assert_eq!(count_words("hello world"), 2);
assert_eq!(count_words(""), 0);
assert_eq!(count_words(" one two three "), 3);
assert_eq!(count_words("word"), 1);
}

#[test]
fn test_detect_paywall() {
assert!(detect_paywall("<div class=\"paywall\">Subscribe</div>"));
assert!(detect_paywall("<p>Subscribe to read the full article</p>"));
assert!(detect_paywall("<span>Already a subscriber? Log in</span>"));
assert!(detect_paywall("<div>Unlock this article</div>"));
assert!(!detect_paywall("<p>This is a normal article</p>"));
assert!(!detect_paywall("<h1>Hello World</h1><p>Free content</p>"));
}

#[tokio::test]
async fn test_word_count_in_response() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/article"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("Hello world this is a test")
.insert_header("content-type", "text/plain"),
)
.mount(&server)
.await;

let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/article", server.uri()));
let response = fetcher.fetch(&request, &options).await.unwrap();

assert_eq!(response.word_count, Some(6));
}

#[tokio::test]
async fn test_redirect_chain_tracked() {
let destination = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/final"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("arrived")
.insert_header("content-type", "text/plain"),
)
.mount(&destination)
.await;

let origin = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/start"))
.respond_with(
ResponseTemplate::new(302)
.insert_header("location", format!("{}/final", destination.uri())),
)
.mount(&origin)
.await;

let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/start", origin.uri()));
let response = fetcher.fetch(&request, &options).await.unwrap();

assert_eq!(response.status_code, 200);
assert_eq!(response.redirect_chain.len(), 1);
assert!(response.redirect_chain[0].contains("/start"));
}

#[tokio::test]
async fn test_no_redirect_chain_for_direct_response() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/direct"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("direct")
.insert_header("content-type", "text/plain"),
)
.mount(&server)
.await;

let fetcher = DefaultFetcher::new();
let options = FetchOptions {
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/direct", server.uri()));
let response = fetcher.fetch(&request, &options).await.unwrap();

assert!(response.redirect_chain.is_empty());
}

#[tokio::test]
async fn test_paywall_detection() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/paywalled"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("<html><body><div class='paywall'>Subscribe to read the full article</div><p>Preview...</p></body></html>")
.insert_header("content-type", "text/html"),
)
.mount(&server)
.await;

let fetcher = DefaultFetcher::new();
let options = FetchOptions {
enable_markdown: true,
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/paywalled", server.uri())).as_markdown();
let response = fetcher.fetch(&request, &options).await.unwrap();

assert_eq!(response.is_paywall, Some(true));
}

#[tokio::test]
async fn test_no_paywall_for_normal_content() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/free"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("<html><body><p>This is free content</p></body></html>")
.insert_header("content-type", "text/html"),
)
.mount(&server)
.await;

let fetcher = DefaultFetcher::new();
let options = FetchOptions {
enable_markdown: true,
dns_policy: DnsPolicy::allow_all(),
..Default::default()
};
let request = FetchRequest::new(format!("{}/free", server.uri())).as_markdown();
let response = fetcher.fetch(&request, &options).await.unwrap();

assert!(response.is_paywall.is_none());
}
}
12 changes: 12 additions & 0 deletions crates/fetchkit/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,18 @@ pub struct FetchResponse {
/// Structured page metadata extracted from HTML
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<PageMetadata>,

/// Word count of the final content
#[serde(skip_serializing_if = "Option::is_none")]
pub word_count: Option<u64>,

/// Chain of URLs followed during redirects (empty if no redirects)
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub redirect_chain: Vec<String>,

/// Heuristic paywall detection (soft signal, not guaranteed)
#[serde(skip_serializing_if = "Option::is_none")]
pub is_paywall: Option<bool>,
}

#[cfg(test)]
Expand Down
Loading