@@ -59,6 +59,7 @@ pub struct PendingWebsocketMessage {
5959
6060pub struct SharedStateInner {
6161 ups : PubSub ,
62+ gateway_id : Uuid ,
6263 receiver_subject : String ,
6364 in_flight_requests : HashMap < RequestId , InFlightRequest > ,
6465}
@@ -74,6 +75,7 @@ impl SharedState {
7475
7576 Self ( Arc :: new ( SharedStateInner {
7677 ups,
78+ gateway_id,
7779 receiver_subject,
7880 in_flight_requests : HashMap :: new ( ) ,
7981 } ) )
@@ -160,6 +162,7 @@ impl SharedState {
160162 } ;
161163
162164 let payload = protocol:: ToClientTunnelMessage {
165+ gateway_id : * self . gateway_id . as_bytes ( ) ,
163166 request_id : request_id. clone ( ) ,
164167 message_id,
165168 // Only send reply to subject on the first message for this request. This reduces
@@ -179,8 +182,8 @@ impl SharedState {
179182 } ) ;
180183
181184 // Send message
182- let message = protocol:: ToClient :: ToClientTunnelMessage ( payload) ;
183- let message_serialized = versioned:: ToClient :: wrap_latest ( message)
185+ let message = protocol:: ToRunner :: ToClientTunnelMessage ( payload) ;
186+ let message_serialized = versioned:: ToRunner :: wrap_latest ( message)
184187 . serialize_with_embedded_version ( PROTOCOL_VERSION ) ?;
185188
186189 if let ( Some ( hs) , Some ( ws_msg_index) ) = ( & mut req. hibernation_state , ws_msg_index) {
@@ -221,105 +224,48 @@ impl SharedState {
221224 ) ;
222225
223226 match versioned:: ToGateway :: deserialize_with_embedded_version ( & msg. payload ) {
224- Ok ( protocol:: ToGateway { message : msg } ) => {
225- tracing:: debug!(
226- request_id=?Uuid :: from_bytes( msg. request_id) ,
227- message_id=?Uuid :: from_bytes( msg. message_id) ,
228- "successfully deserialized message"
229- ) ;
230-
231- let Some ( mut in_flight) =
232- self . in_flight_requests . get_async ( & msg. request_id ) . await
227+ Ok ( protocol:: ToGateway :: ToGatewayKeepAlive ) => {
228+ // TODO:
229+ // let prev_len = in_flight.pending_msgs.len();
230+ //
231+ // tracing::debug!(message_id=?Uuid::from_bytes(msg.message_id), "received tunnel ack");
232+ //
233+ // in_flight
234+ // .pending_msgs
235+ // .retain(|m| m.message_id != msg.message_id);
236+ //
237+ // if prev_len == in_flight.pending_msgs.len() {
238+ // tracing::warn!(
239+ // request_id=?Uuid::from_bytes(msg.request_id),
240+ // message_id=?Uuid::from_bytes(msg.message_id),
241+ // "pending message does not exist or ack received after message body"
242+ // )
243+ // } else {
244+ // tracing::debug!(
245+ // request_id=?Uuid::from_bytes(msg.request_id),
246+ // message_id=?Uuid::from_bytes(msg.message_id),
247+ // "received TunnelAck, removed from pending"
248+ // );
249+ // }
250+ }
251+ Ok ( protocol:: ToGateway :: ToServerTunnelMessage ( msg) ) => {
252+ let Some ( in_flight) = self . in_flight_requests . get_async ( & msg. request_id ) . await
233253 else {
234254 tracing:: warn!(
235255 request_id=?Uuid :: from_bytes( msg. request_id) ,
236256 message_id=?Uuid :: from_bytes( msg. message_id) ,
237- "in flight has already been disconnected, cannot ack message"
257+ "in flight has already been disconnected, dropping message"
238258 ) ;
239259 continue ;
240260 } ;
241261
242- if let protocol:: ToServerTunnelMessageKind :: TunnelAck = & msg. message_kind {
243- let prev_len = in_flight. pending_msgs . len ( ) ;
244-
245- tracing:: debug!( message_id=?Uuid :: from_bytes( msg. message_id) , "received tunnel ack" ) ;
246-
247- in_flight
248- . pending_msgs
249- . retain ( |m| m. message_id != msg. message_id ) ;
250-
251- if prev_len == in_flight. pending_msgs . len ( ) {
252- tracing:: warn!(
253- request_id=?Uuid :: from_bytes( msg. request_id) ,
254- message_id=?Uuid :: from_bytes( msg. message_id) ,
255- "pending message does not exist or ack received after message body"
256- )
257- } else {
258- tracing:: debug!(
259- request_id=?Uuid :: from_bytes( msg. request_id) ,
260- message_id=?Uuid :: from_bytes( msg. message_id) ,
261- "received TunnelAck, removed from pending"
262- ) ;
263- }
264- } else {
265- // Send message to the request handler to emulate the real network action
266- tracing:: debug!(
267- request_id=?Uuid :: from_bytes( msg. request_id) ,
268- message_id=?Uuid :: from_bytes( msg. message_id) ,
269- "forwarding message to request handler"
270- ) ;
271- let _ = in_flight. msg_tx . send ( msg. message_kind . clone ( ) ) . await ;
272-
273- // Send ack back to runner
274- let ups_clone = self . ups . clone ( ) ;
275- let receiver_subject = in_flight. receiver_subject . clone ( ) ;
276- let request_id = msg. request_id ;
277- let message_id = msg. message_id ;
278- let ack_message = protocol:: ToClient :: ToClientTunnelMessage (
279- protocol:: ToClientTunnelMessage {
280- request_id,
281- message_id,
282- gateway_reply_to : None ,
283- message_kind : protocol:: ToClientTunnelMessageKind :: TunnelAck ,
284- } ,
285- ) ;
286- let ack_message_serialized =
287- match versioned:: ToClient :: wrap_latest ( ack_message)
288- . serialize_with_embedded_version ( PROTOCOL_VERSION )
289- {
290- Ok ( x) => x,
291- Err ( err) => {
292- tracing:: error!( ?err, "failed to serialize ack" ) ;
293- continue ;
294- }
295- } ;
296- tokio:: spawn ( async move {
297- match ups_clone
298- . publish (
299- & receiver_subject,
300- & ack_message_serialized,
301- PublishOpts :: one ( ) ,
302- )
303- . await
304- {
305- Ok ( _) => {
306- tracing:: debug!(
307- request_id=?Uuid :: from_bytes( request_id) ,
308- message_id=?Uuid :: from_bytes( message_id) ,
309- "sent TunnelAck to runner"
310- ) ;
311- }
312- Err ( err) => {
313- tracing:: warn!(
314- ?err,
315- request_id=?Uuid :: from_bytes( request_id) ,
316- message_id=?Uuid :: from_bytes( message_id) ,
317- "failed to send TunnelAck to runner"
318- ) ;
319- }
320- }
321- } ) ;
322- }
262+ // Send message to the request handler to emulate the real network action
263+ tracing:: debug!(
264+ request_id=?Uuid :: from_bytes( msg. request_id) ,
265+ message_id=?Uuid :: from_bytes( msg. message_id) ,
266+ "forwarding message to request handler"
267+ ) ;
268+ let _ = in_flight. msg_tx . send ( msg. message_kind . clone ( ) ) . await ;
323269 }
324270 Err ( err) => {
325271 tracing:: error!( ?err, "failed to parse message" ) ;
0 commit comments