1
1
use std:: sync:: { Arc , OnceLock } ;
2
2
3
- use axum:: { extract:: State , http:: StatusCode , routing:: post, Json , Router } ;
3
+ use axum:: {
4
+ extract:: State ,
5
+ http:: { HeaderMap , StatusCode } ,
6
+ routing:: post,
7
+ Json , Router ,
8
+ } ;
4
9
use jsonwebtoken:: { get_current_timestamp, DecodingKey , EncodingKey , Validation } ;
5
10
use log:: { info, warn} ;
6
11
use ring:: rand:: { SecureRandom , SystemRandom } ;
7
12
use serde:: { Deserialize , Serialize } ;
13
+ use serde_repr:: { Deserialize_repr , Serialize_repr } ;
8
14
9
15
use crate :: { util, AppState } ;
10
16
11
17
#[ derive( Deserialize , Clone ) ]
12
18
pub struct AuthConfig {
13
19
route : String ,
20
+ refresh_subroute : String ,
14
21
secret_path : String ,
15
- valid_secs : u64 ,
22
+ valid_secs_refresh : u64 ,
23
+ valid_secs_session : u64 ,
16
24
}
17
25
18
26
#[ derive( Deserialize ) ]
@@ -21,11 +29,19 @@ pub struct AuthRequest {
21
29
password : String ,
22
30
}
23
31
32
+ #[ repr( u8 ) ]
33
+ #[ derive( Deserialize_repr , Serialize_repr , PartialEq , Eq ) ]
34
+ pub enum TokenKind {
35
+ Refresh = 0 ,
36
+ Session = 1 ,
37
+ }
38
+
24
39
#[ derive( Deserialize , Serialize ) ]
25
40
pub struct Claims {
26
- sub : String , // account id as a string
27
- crt : u64 , // creation timestamp in UTC
28
- exp : u64 , // expiration timestamp in UTC
41
+ sub : String , // account id as a string
42
+ crt : u64 , // creation timestamp in UTC
43
+ exp : u64 , // expiration timestamp in UTC
44
+ kind : TokenKind , // kind of token
29
45
}
30
46
31
47
static SECRET_KEY : OnceLock < Vec < u8 > > = OnceLock :: new ( ) ;
@@ -55,21 +71,33 @@ pub fn register(
55
71
rng : & SystemRandom ,
56
72
) -> Router < Arc < AppState > > {
57
73
let route = & config. route ;
74
+ let refresh_route = util:: get_subroute ( route, & config. refresh_subroute ) ;
58
75
info ! ( "Registering auth route @ {}" , route) ;
76
+ info ! ( "\t Refresh route @ {}" , refresh_route) ;
59
77
check_secret ( & config. secret_path , rng) ;
60
- routes. route ( route, post ( do_auth) )
78
+ routes
79
+ . route ( route, post ( do_auth) )
80
+ . route ( & refresh_route, post ( do_refresh) )
61
81
}
62
82
63
- fn gen_jwt ( account_id : i64 , valid_secs : u64 ) -> Result < String , String > {
83
+ fn gen_jwt ( auth_config : & AuthConfig , account_id : i64 , kind : TokenKind ) -> Result < String , String > {
64
84
let secret = SECRET_KEY . get ( ) . unwrap ( ) ;
65
85
let key = EncodingKey :: from_secret ( secret) ;
86
+
87
+ let valid_secs = match kind {
88
+ TokenKind :: Refresh => auth_config. valid_secs_refresh ,
89
+ TokenKind :: Session => auth_config. valid_secs_session ,
90
+ } ;
91
+
66
92
let crt = get_current_timestamp ( ) ;
67
93
let exp = crt + valid_secs;
68
94
let claims = Claims {
69
95
sub : account_id. to_string ( ) ,
70
96
crt,
71
97
exp,
98
+ kind,
72
99
} ;
100
+
73
101
jsonwebtoken:: encode ( & jsonwebtoken:: Header :: default ( ) , & claims, & key)
74
102
. map_err ( |e| format ! ( "JWT error: {}" , e) )
75
103
}
@@ -85,7 +113,7 @@ fn get_validator(account_id: Option<i64>) -> Validation {
85
113
validation
86
114
}
87
115
88
- pub fn validate_jwt ( jwt : & str ) -> Result < i64 , String > {
116
+ pub fn validate_jwt ( jwt : & str , kind : TokenKind ) -> Result < i64 , String > {
89
117
let Some ( secret) = SECRET_KEY . get ( ) else {
90
118
return Err ( "Auth module not initialized" . to_string ( ) ) ;
91
119
} ;
@@ -102,6 +130,10 @@ pub fn validate_jwt(jwt: &str) -> Result<i64, String> {
102
130
return Err ( "Expired JWT" . to_string ( ) ) ;
103
131
}
104
132
133
+ if token. claims . kind != kind {
134
+ return Err ( "Bad token kind" . to_string ( ) ) ;
135
+ }
136
+
105
137
match token. claims . sub . parse ( ) {
106
138
Ok ( id) => Ok ( id) ,
107
139
Err ( e) => Err ( format ! ( "Bad account ID: {}" , e) ) ,
@@ -118,8 +150,11 @@ async fn do_auth(
118
150
warn ! ( "Auth error: {}" , e) ;
119
151
( StatusCode :: UNAUTHORIZED , "Invalid credentials" . to_string ( ) )
120
152
} ) ?;
121
- let valid_secs = app. config . auth . as_ref ( ) . unwrap ( ) . valid_secs ;
122
- match gen_jwt ( account_id, valid_secs) {
153
+ match gen_jwt (
154
+ app. config . auth . as_ref ( ) . unwrap ( ) ,
155
+ account_id,
156
+ TokenKind :: Refresh ,
157
+ ) {
123
158
Ok ( jwt) => Ok ( jwt) ,
124
159
Err ( e) => {
125
160
warn ! ( "Auth error: {}" , e) ;
@@ -130,3 +165,30 @@ async fn do_auth(
130
165
}
131
166
}
132
167
}
168
+
169
+ async fn do_refresh (
170
+ State ( app) : State < Arc < AppState > > ,
171
+ headers : HeaderMap ,
172
+ ) -> Result < String , ( StatusCode , String ) > {
173
+ assert ! ( app. is_tls) ;
174
+ let db = app. db . lock ( ) . await ;
175
+ // TODO validate the refresh token against the last password reset timestamp
176
+ let account_id = match util:: validate_authed_request ( & headers, TokenKind :: Refresh ) {
177
+ Ok ( id) => id,
178
+ Err ( e) => return Err ( ( StatusCode :: UNAUTHORIZED , e) ) ,
179
+ } ;
180
+ match gen_jwt (
181
+ app. config . auth . as_ref ( ) . unwrap ( ) ,
182
+ account_id,
183
+ TokenKind :: Session ,
184
+ ) {
185
+ Ok ( jwt) => Ok ( jwt) ,
186
+ Err ( e) => {
187
+ warn ! ( "Refresh error: {}" , e) ;
188
+ Err ( (
189
+ StatusCode :: INTERNAL_SERVER_ERROR ,
190
+ "Server error" . to_string ( ) ,
191
+ ) )
192
+ }
193
+ }
194
+ }
0 commit comments