Skip to content

Commit 89a0735

Browse files
Refactoring how some things work and add doc tests for middleware updates
1 parent 3eb3ad4 commit 89a0735

File tree

4 files changed

+168
-28
lines changed

4 files changed

+168
-28
lines changed

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,18 @@ dirs = "5"
2626
chrono = "0.4"
2727
tokio = { version = "1", features = ["fs"] }
2828
tower = { version = "0.4", optional = true }
29-
axum = { version = ">= 0.7.2", optional = true }
29+
axum = { version = ">= 0.8", optional = true }
3030
futures-core = { version = "0.3", optional = true }
3131
http = "1"
3232
bytes = { version = "1", optional = true }
3333
thiserror = "1"
3434
mauth-core = "0.6"
35+
tracing = { version = "0.1", optional = true }
3536

3637
[dev-dependencies]
3738
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
3839

3940
[features]
40-
axum-service = ["tower", "futures-core", "axum", "bytes"]
41+
axum-service = ["tower", "futures-core", "axum", "bytes", "tracing"]
4142
tracing-otel-26 = ["reqwest-tracing/opentelemetry_0_26"]
4243
tracing-otel-27 = ["reqwest-tracing/opentelemetry_0_27"]

README.md

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ the MAuth protocol, and verify the responses. Usage example:
77
release any code to Production or deploy in a Client-accessible environment without getting
88
approval for the full stack used through the Architecture and Security groups.
99

10+
## Outgoing Requests
11+
1012
```no_run
1113
use mauth_client::MAuthInfo;
1214
use reqwest::Client;
@@ -49,6 +51,8 @@ match client.get("https://www.example.com/").send().await {
4951
# }
5052
```
5153

54+
## Incoming Requests
55+
5256
The optional `axum-service` feature provides for a Tower Layer and Service that will
5357
authenticate incoming requests via MAuth V2 or V1 and provide to the lower layers a
5458
validated app_uuid from the request via the `ValidatedRequestDetails` struct. Note that
@@ -66,6 +70,104 @@ a different response, or respond to the lack of the extension in another way, yo
6670
use a more manual mechanism to check for the extension and decide how to proceed if it
6771
is not present.
6872

73+
### Examples for `RequiredMAuthValidationLayer`
74+
75+
```no_run
76+
# async fn run_server() {
77+
use mauth_client::{
78+
axum_service::RequiredMAuthValidationLayer,
79+
validate_incoming::ValidatedRequestDetails,
80+
};
81+
use axum::{http::StatusCode, Router, routing::get, serve};
82+
use tokio::net::TcpListener;
83+
84+
// If there is not a valid mauth signature, this function will never run at all, and
85+
// the request will return an empty 401 Unauthorized
86+
async fn foo() -> StatusCode {
87+
StatusCode::OK
88+
}
89+
90+
// In addition to returning a 401 Unauthorized without running if there is not a valid
91+
// MAuth signature, this also makes the validated requesting app UUID available to
92+
// the function
93+
async fn bar(details: ValidatedRequestDetails) -> StatusCode {
94+
println!("Got a request from app with UUID: {}", details.app_uuid);
95+
StatusCode::OK
96+
}
97+
98+
// This function will run regardless of whether or not there is a mauth signature
99+
async fn baz() -> StatusCode {
100+
StatusCode::OK
101+
}
102+
103+
// Attaching the baz route handler after the layer means the layer is not run for
104+
// requests to that path, so no mauth checking will be performed for that route and
105+
// any other routes attached after the layer
106+
let router = Router::new()
107+
.route("/foo", get(foo))
108+
.route("/bar", get(bar))
109+
.layer(RequiredMAuthValidationLayer::from_default_file().unwrap())
110+
.route("/baz", get(baz));
111+
let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap();
112+
serve(listener, router).await.unwrap();
113+
# }
114+
```
115+
116+
### Examples for `OptionalMAuthValidationLayer`
117+
118+
```no_run
119+
# async fn run_server() {
120+
use mauth_client::{
121+
axum_service::OptionalMAuthValidationLayer,
122+
validate_incoming::ValidatedRequestDetails,
123+
};
124+
use axum::{http::StatusCode, Router, routing::get, serve};
125+
use tokio::net::TcpListener;
126+
127+
// This request will run no matter what the authorization status is
128+
async fn foo() -> StatusCode {
129+
StatusCode::OK
130+
}
131+
132+
// If there is not a valid mauth signature, this function will never run at all, and
133+
// the request will return an empty 401 Unauthorized
134+
async fn bar(_: ValidatedRequestDetails) -> StatusCode {
135+
StatusCode::OK
136+
}
137+
138+
// In addition to returning a 401 Unauthorized without running if there is not a valid
139+
// MAuth signature, this also makes the validated requesting app UUID available to
140+
// the function
141+
async fn baz(details: ValidatedRequestDetails) -> StatusCode {
142+
println!("Got a request from app with UUID: {}", details.app_uuid);
143+
StatusCode::OK
144+
}
145+
146+
// This request will run whether or not there is a valid mauth signature, but the Option
147+
// provided can be used to tell you whether there was a valid signature, so you can
148+
// implement things like multiple possible types of authentication or behavior other than
149+
// a 401 return if there is no authentication
150+
async fn bam(optional_details: Option<ValidatedRequestDetails>) -> StatusCode {
151+
match optional_details {
152+
Some(details) => println!("Got a request from app with UUID: {}", details.app_uuid),
153+
None => println!("Got a request without a valid mauth signature"),
154+
}
155+
StatusCode::OK
156+
}
157+
158+
let router = Router::new()
159+
.route("/foo", get(foo))
160+
.route("/bar", get(bar))
161+
.route("/baz", get(baz))
162+
.route("/bam", get(bam))
163+
.layer(OptionalMAuthValidationLayer::from_default_file().unwrap());
164+
let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap();
165+
serve(listener, router).await.unwrap();
166+
# }
167+
```
168+
169+
### OpenTelemetry Integration
170+
69171
There are also optional features `tracing-otel-26` and `tracing-otel-27` that pair with
70172
the `axum-service` feature to ensure that any outgoing requests for credentials that take
71173
place in the context of an incoming web request also include the proper OpenTelemetry span

src/axum_service.rs

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
//! Structs and impls related to providing a Tower Service and Layer to verify incoming requests
22
3-
use axum::extract::{FromRequestParts, Request};
3+
use axum::{
4+
body::Body,
5+
extract::{FromRequestParts, OptionalFromRequestParts, Request},
6+
response::IntoResponse,
7+
};
48
use futures_core::future::BoxFuture;
5-
use http::{request::Parts, StatusCode};
9+
use http::{request::Parts, Response, StatusCode};
10+
use std::convert::Infallible;
611
use std::error::Error;
712
use std::task::{Context, Poll};
813
use tower::{Layer, Service};
14+
use tracing::error;
915

1016
use crate::validate_incoming::ValidatedRequestDetails;
1117
use crate::{
@@ -27,24 +33,31 @@ where
2733
S: Service<Request> + Send + Clone + 'static,
2834
S::Future: Send + 'static,
2935
S::Error: Into<Box<dyn Error + Sync + Send>>,
36+
S::Response: Into<Response<Body>>,
3037
{
31-
type Response = S::Response;
32-
type Error = Box<dyn Error + Sync + Send>;
38+
type Response = Response<Body>;
39+
type Error = S::Error;
3340
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
3441

3542
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
36-
self.service.poll_ready(cx).map_err(|e| e.into())
43+
self.service.poll_ready(cx)
3744
}
3845

3946
fn call(&mut self, request: Request) -> Self::Future {
4047
let mut cloned = self.clone();
4148
Box::pin(async move {
4249
match cloned.mauth_info.validate_request(request).await {
4350
Ok(valid_request) => match cloned.service.call(valid_request).await {
44-
Ok(response) => Ok(response),
45-
Err(err) => Err(err.into()),
51+
Ok(response) => Ok(response.into()),
52+
Err(err) => Err(err),
4653
},
47-
Err(err) => Err(Box::new(err) as Box<dyn Error + Send + Sync>),
54+
Err(err) => {
55+
error!(
56+
error = ?err,
57+
"Failed to validate MAuth signature, rejecting request"
58+
);
59+
Ok(StatusCode::UNAUTHORIZED.into_response())
60+
}
4861
}
4962
})
5063
}
@@ -121,23 +134,18 @@ where
121134
S::Error: Into<Box<dyn Error + Sync + Send>>,
122135
{
123136
type Response = S::Response;
124-
type Error = Box<dyn Error + Sync + Send>;
137+
type Error = S::Error;
125138
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
126139

127140
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128-
self.service.poll_ready(cx).map_err(|e| e.into())
141+
self.service.poll_ready(cx)
129142
}
130143

131144
fn call(&mut self, request: Request) -> Self::Future {
132145
let mut cloned = self.clone();
133146
Box::pin(async move {
134-
match cloned.mauth_info.validate_request_optionally(request).await {
135-
Ok(valid_request) => match cloned.service.call(valid_request).await {
136-
Ok(response) => Ok(response),
137-
Err(err) => Err(err.into()),
138-
},
139-
Err(err) => Err(Box::new(err) as Box<dyn Error + Send + Sync>),
140-
}
147+
let processed_request = cloned.mauth_info.validate_request_optionally(request).await;
148+
cloned.service.call(processed_request).await
141149
})
142150
}
143151
}
@@ -192,8 +200,10 @@ impl OptionalMAuthValidationLayer {
192200
}
193201
}
194202

195-
#[async_trait::async_trait]
196-
impl<S> FromRequestParts<S> for ValidatedRequestDetails {
203+
impl<S> FromRequestParts<S> for ValidatedRequestDetails
204+
where
205+
S: Send + Sync,
206+
{
197207
type Rejection = StatusCode;
198208

199209
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
@@ -204,3 +214,17 @@ impl<S> FromRequestParts<S> for ValidatedRequestDetails {
204214
.ok_or(StatusCode::UNAUTHORIZED)
205215
}
206216
}
217+
218+
impl<S> OptionalFromRequestParts<S> for ValidatedRequestDetails
219+
where
220+
S: Send + Sync,
221+
{
222+
type Rejection = Infallible;
223+
224+
async fn from_request_parts(
225+
parts: &mut Parts,
226+
_state: &S,
227+
) -> Result<Option<Self>, Self::Rejection> {
228+
Ok(parts.extensions.get::<ValidatedRequestDetails>().cloned())
229+
}
230+
}

src/validate_incoming.rs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use crate::{MAuthInfo, CLIENT, PUBKEY_CACHE};
22
use axum::extract::Request;
3+
use bytes::Bytes;
34
use chrono::prelude::*;
45
use mauth_core::verifier::Verifier;
56
use thiserror::Error;
7+
use tracing::error;
68
use uuid::Uuid;
79

810
/// This struct holds the app UUID for a validated request. It is meant to be used with the
@@ -59,15 +61,26 @@ impl MAuthInfo {
5961
}
6062
}
6163

62-
pub(crate) async fn validate_request_optionally(
63-
&self,
64-
req: Request,
65-
) -> Result<Request, axum::Error> {
64+
pub(crate) async fn validate_request_optionally(&self, req: Request) -> Request {
6665
let (mut parts, body) = req.into_parts();
6766
if parts.headers.contains_key(MAUTH_V2_SIGNATURE_HEADER)
6867
|| parts.headers.contains_key(MAUTH_V1_SIGNATURE_HEADER)
6968
{
70-
let body_bytes = axum::body::to_bytes(body, usize::MAX).await?;
69+
// By my reading of the code for this it should never fail, since we are passing
70+
// MAX for the limit. But just to be safe, we will log the error and proceed with
71+
// an empty body just in case instead of unwrapping. This would cause the body to
72+
// be unavailable to the lower layers, but they would probably also fail to get it
73+
// anyways since we just did here.
74+
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
75+
Ok(bytes) => bytes,
76+
Err(err) => {
77+
error!(
78+
error = ?err,
79+
"Failed to retrieve request body, continuing with empty body"
80+
);
81+
Bytes::new()
82+
}
83+
};
7184

7285
match self.validate_request_v2(&parts, &body_bytes).await {
7386
Ok(host_app_uuid) => {
@@ -95,9 +108,9 @@ impl MAuthInfo {
95108

96109
let new_body = axum::body::Body::from(body_bytes);
97110
let new_request = Request::from_parts(parts, new_body);
98-
Ok(new_request)
111+
new_request
99112
} else {
100-
Ok(Request::from_parts(parts, body))
113+
Request::from_parts(parts, body)
101114
}
102115
}
103116

0 commit comments

Comments
 (0)