Skip to content

Commit 19ed47c

Browse files
wip
1 parent 4013e64 commit 19ed47c

File tree

5 files changed

+187
-112
lines changed

5 files changed

+187
-112
lines changed

dmq-node/src/DMQ/NodeToClient.hs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ module DMQ.NodeToClient
1616
import Data.Aeson qualified as Aeson
1717
import Data.ByteString.Lazy (ByteString)
1818
import Data.Functor.Contravariant ((>$<))
19+
import Data.Typeable
1920
import Data.Void
2021
import Data.Word
2122

@@ -132,20 +133,22 @@ data Apps ntcAddr m a =
132133
-- | Construct applications for the node-to-client protocols
133134
--
134135
ntcApps
135-
:: forall crypto idx ntcAddr failure m.
136+
:: forall crypto idx ntcAddr m.
136137
( MonadThrow m
137138
, MonadThread m
138139
, MonadSTM m
139140
, Crypto crypto
140141
, Aeson.ToJSON ntcAddr
141142
, Aeson.ToJSON (MempoolAddFail (Sig crypto))
143+
, Show (MempoolAddFail (Sig crypto))
142144
, ShowProxy (MempoolAddFail (Sig crypto))
143145
, ShowProxy (Sig crypto)
146+
, Typeable crypto
144147
)
145148
=> (forall ev. Aeson.ToJSON ev => Tracer m (WithEventType ev))
146149
-> Configuration
147150
-> TxSubmissionMempoolReader SigId (Sig crypto) idx m
148-
-> MempoolWriter SigId (Sig crypto) failure idx m
151+
-> MempoolWriter SigId (Sig crypto) idx m
149152
-> Word16
150153
-> Codecs crypto m
151154
-> Apps ntcAddr m ()
Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
{-# LANGUAGE FlexibleContexts #-}
2-
{-# LANGUAGE OverloadedStrings #-}
3-
{-# LANGUAGE StandaloneDeriving #-}
1+
{-# LANGUAGE FlexibleContexts #-}
2+
{-# LANGUAGE OverloadedStrings #-}
3+
{-# LANGUAGE StandaloneDeriving #-}
4+
{-# LANGUAGE TypeApplications #-}
5+
{-# LANGUAGE UndecidableInstances #-}
46

57
module DMQ.NodeToClient.LocalMsgSubmission where
68

79
import Control.Concurrent.Class.MonadSTM
10+
import Control.Monad.Class.MonadThrow
811
import Control.Tracer
912
import Data.Aeson (ToJSON (..), object, (.=))
1013
import Data.Aeson qualified as Aeson
14+
import Data.Typeable
1115

1216
import DMQ.Protocol.LocalMsgSubmission.Server
1317
import DMQ.Protocol.LocalMsgSubmission.Type
@@ -16,55 +20,78 @@ import Ouroboros.Network.TxSubmission.Mempool.Simple
1620
-- | Local transaction submission server, for adding txs to the 'Mempool'
1721
--
1822
localMsgSubmissionServer ::
19-
MonadSTM m
20-
=> (sig -> sigid)
23+
forall msgid msg idx m.
24+
( MonadSTM m
25+
, MonadThrow m
26+
, Typeable msgid
27+
, Typeable msg
28+
, Show msgid
29+
, Show (MempoolAddFail msg))
30+
=> (msg -> msgid)
2131
-- ^ get message id
22-
-> Tracer m (TraceLocalMsgSubmission sig sigid)
23-
-> MempoolWriter sigid sig failure idx m
32+
-> Tracer m (TraceLocalMsgSubmission msg msgid)
33+
-> MempoolWriter msgid msg idx m
2434
-- ^ duplicate error tag in case the mempool returns the empty list on failure
25-
-> m (LocalMsgSubmissionServer sig m ())
35+
-> m (LocalMsgSubmissionServer msg m ())
2636
localMsgSubmissionServer getMsgId tracer MempoolWriter { mempoolAddTxs } =
2737
pure server
2838
where
29-
process (sigid, e@(SubmitFail reason)) =
30-
(e, server) <$ traceWith tracer (TraceSubmitFailure sigid reason)
31-
process (sigid, success) =
32-
(success, server) <$ traceWith tracer (TraceSubmitAccept sigid)
39+
process (Left (msgid, reason)) = do
40+
traceWith tracer (TraceSubmitFailure msgid reason)
41+
throwIO $ MsgValidationException msgid reason
42+
process (Right [(msgid, e@(SubmitFail reason))]) =
43+
(e, server) <$ traceWith tracer (TraceSubmitFailure msgid reason)
44+
process (Right [(msgid, SubmitSuccess)]) =
45+
(SubmitSuccess, server) <$ traceWith tracer (TraceSubmitAccept msgid)
46+
process _ = throwIO (TooManyMessages @msgid @msg)
3347

3448
server = LocalTxSubmissionServer {
35-
recvMsgSubmitTx = \sig -> do
36-
traceWith tracer $ TraceReceivedMsg (getMsgId sig)
37-
process . head =<< mempoolAddTxs [sig]
49+
recvMsgSubmitTx = \msg -> do
50+
traceWith tracer $ TraceReceivedMsg (getMsgId msg)
51+
process =<< mempoolAddTxs [msg]
3852

3953
, recvMsgDone = ()
4054
}
4155

4256

43-
data TraceLocalMsgSubmission sig sigid =
44-
TraceReceivedMsg sigid
57+
data TraceLocalMsgSubmission msg msgid =
58+
TraceReceivedMsg msgid
4559
-- ^ A signature was received.
46-
| TraceSubmitFailure sigid (MempoolAddFail sig)
47-
| TraceSubmitAccept sigid
60+
| TraceSubmitFailure msgid (MempoolAddFail msg)
61+
| TraceSubmitAccept msgid
4862

4963
deriving instance
50-
(Show sig, Show sigid, Show (MempoolAddFail sig))
51-
=> Show (TraceLocalMsgSubmission sig sigid)
64+
(Show msg, Show msgid, Show (MempoolAddFail msg))
65+
=> Show (TraceLocalMsgSubmission msg msgid)
5266

53-
instance (ToJSON sigid, ToJSON (MempoolAddFail sig))
54-
=> ToJSON (TraceLocalMsgSubmission sig sigid) where
55-
toJSON (TraceReceivedMsg sigid) =
67+
68+
69+
data MsgSubmissionServerException msgid msg =
70+
MsgValidationException msgid (MempoolAddFail msg)
71+
| TooManyMessages
72+
73+
deriving instance (Show (MempoolAddFail msg), Show msgid)
74+
=> Show (MsgSubmissionServerException msgid msg)
75+
76+
instance (Typeable msgid, Typeable msg, Show (MempoolAddFail msg), Show msgid)
77+
=> Exception (MsgSubmissionServerException msgid msg) where
78+
79+
80+
instance (ToJSON msgid, ToJSON (MempoolAddFail msg))
81+
=> ToJSON (TraceLocalMsgSubmission msg msgid) where
82+
toJSON (TraceReceivedMsg msgid) =
5683
-- TODO: once we have verbosity levels, we could include the full tx, for
5784
-- now one can use `TraceSendRecv` tracer for the mini-protocol to see full
5885
-- msgs.
5986
object [ "kind" .= Aeson.String "TraceReceivedMsg"
60-
, "sigId" .= sigid
87+
, "sigId" .= msgid
6188
]
62-
toJSON (TraceSubmitFailure sigid reject) =
89+
toJSON (TraceSubmitFailure msgid reject) =
6390
object [ "kind" .= Aeson.String "TraceSubmitFailure"
64-
, "sigId" .= sigid
91+
, "sigId" .= msgid
6592
, "reason" .= reject
6693
]
67-
toJSON (TraceSubmitAccept sigid) =
94+
toJSON (TraceSubmitAccept msgid) =
6895
object [ "kind" .= Aeson.String "TraceSubmitAccept"
69-
, "sigId" .= sigid
96+
, "sigId" .= msgid
7097
]

dmq-node/src/DMQ/Protocol/LocalMsgSubmission/Codec.hs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ import Codec.CBOR.Encoding qualified as CBOR
1010
import Codec.CBOR.Read qualified as CBOR
1111
import Control.Monad.Class.MonadST
1212
import Data.ByteString.Lazy (ByteString)
13+
import Data.Text qualified as T
14+
import Data.Tuple (swap)
1315
import Text.Printf
1416

17+
import Cardano.Binary
1518
import Cardano.KESAgent.KES.Crypto (Crypto (..))
1619

1720
import DMQ.Protocol.LocalMsgSubmission.Type
@@ -35,7 +38,25 @@ codecLocalMsgSubmission =
3538

3639
encodeReject :: MempoolAddFail (Sig crypto) -> CBOR.Encoding
3740
encodeReject = \case
38-
SigInvalid reason -> CBOR.encodeListLen 2 <> CBOR.encodeWord 0 <> CBOR.encodeString reason
41+
SigInvalid reason -> CBOR.encodeListLen 2 <> CBOR.encodeWord 0 <> e
42+
where
43+
e = case reason of
44+
InvalidKESSignature ocertKESPeriod sigKESPeriod err -> mconcat [
45+
CBOR.encodeListLen 4, CBOR.encodeWord 0, toCBOR ocertKESPeriod, toCBOR sigKESPeriod, CBOR.encodeString (T.pack err)
46+
]
47+
InvalidSignatureOCERT ocertN sigKESPeriod err -> mconcat [
48+
CBOR.encodeListLen 4, CBOR.encodeWord 1, CBOR.encodeWord64 ocertN, toCBOR sigKESPeriod, CBOR.encodeString (T.pack err)
49+
]
50+
KESBeforeStartOCERT startKESPeriod sigKESPeriod -> mconcat [
51+
CBOR.encodeListLen 3, CBOR.encodeWord 2, toCBOR startKESPeriod, toCBOR sigKESPeriod
52+
]
53+
KESAfterEndOCERT endKESPeriod sigKESPeriod -> mconcat [
54+
CBOR.encodeListLen 3, CBOR.encodeWord 3, toCBOR endKESPeriod, toCBOR sigKESPeriod
55+
]
56+
UnrecognizedPool -> CBOR.encodeListLen 1 <> CBOR.encodeWord 4
57+
ExpiredPool -> CBOR.encodeListLen 1 <> CBOR.encodeWord 5
58+
NotInitialized -> CBOR.encodeListLen 1 <> CBOR.encodeWord 6
59+
ClockSkew -> CBOR.encodeListLen 1 <> CBOR.encodeWord 7
3960
SigDuplicate -> CBOR.encodeListLen 1 <> CBOR.encodeWord 1
4061
SigExpired -> CBOR.encodeListLen 1 <> CBOR.encodeWord 2
4162
SigResultOther reason
@@ -46,7 +67,22 @@ decodeReject = do
4667
len <- CBOR.decodeListLen
4768
tag <- CBOR.decodeWord
4869
case (tag, len) of
49-
(0, 2) -> SigInvalid <$> CBOR.decodeString
70+
(0, 2) -> SigInvalid <$> decSigValidError
71+
where
72+
decSigValidError :: CBOR.Decoder s SigValidationError
73+
decSigValidError = do
74+
lenTag <- (,) <$> CBOR.decodeListLen <*> CBOR.decodeWord
75+
case swap lenTag of
76+
(0, 4) -> InvalidKESSignature <$> fromCBOR <*> fromCBOR <*> (T.unpack <$> CBOR.decodeString)
77+
(1, 4) -> InvalidSignatureOCERT <$> CBOR.decodeWord64 <*> fromCBOR <*> (T.unpack <$> CBOR.decodeString)
78+
(2, 3) -> KESBeforeStartOCERT <$> fromCBOR <*> fromCBOR
79+
(3, 4) -> KESAfterEndOCERT <$> fromCBOR <*> fromCBOR
80+
(4, 1) -> pure UnrecognizedPool
81+
(5, 1) -> pure ExpiredPool
82+
(6, 1) -> pure NotInitialized
83+
(7, 1) -> pure ClockSkew
84+
_otherwise -> fail $ printf "unrecognized (tag,len) = (%d, %d) when decoding SigInvalid tag" tag len
85+
5086
(1, 1) -> pure SigDuplicate
5187
(2, 1) -> pure SigExpired
5288
(3, 2) -> SigResultOther <$> CBOR.decodeString

0 commit comments

Comments
 (0)