diff --git a/package.yaml b/package.yaml index 0f8e8e8..33cbb4d 100644 --- a/package.yaml +++ b/package.yaml @@ -31,6 +31,7 @@ dependencies: - http-types - bytestring - bytestring-builder +- unordered-containers - text - directory - case-insensitive diff --git a/src/Network/VCR.hs b/src/Network/VCR.hs index ff631bb..f804f3d 100644 --- a/src/Network/VCR.hs +++ b/src/Network/VCR.hs @@ -28,12 +28,12 @@ import qualified Network.Wai.Handler.Warp as Warp import Control.Applicative ((<**>)) import Network.VCR.Middleware (die, middleware) import Network.VCR.Types (Cassette, Mode (..), Options (..), - emptyCassette, parseOptions) + emptyCassette, parseOptions, readCassette) import Options.Applicative (execParser, fullDesc, header, helper, info, progDesc) import System.Environment (getArgs) -import Data.Yaml (decodeFileEither, encodeFile) +import Data.Yaml (encodeFile) import System.Directory (doesFileExist) import System.IO (BufferMode (..), hSetBuffering, stdout) @@ -55,7 +55,7 @@ run options@Options { mode, cassettePath, port } = do exists <- doesFileExist cassettePath when (not exists) $ encodeFile cassettePath (emptyCassette $ T.pack endpoint) _ -> pure () - cas <- decodeFileEither cassettePath + cas <- readCassette cassettePath case cas of Left err -> die $ "Cassette: " <> cassettePath <> " couldn't be decoded or found! " <> (show err) Right cassette -> do diff --git a/src/Network/VCR/Middleware.hs b/src/Network/VCR/Middleware.hs index c79e500..74714b4 100644 --- a/src/Network/VCR/Middleware.hs +++ b/src/Network/VCR/Middleware.hs @@ -5,6 +5,7 @@ module Network.VCR.Middleware where import Control.Monad (when) +import Data.Aeson (decode, Value (..), Object) import Data.ByteString.Builder (toLazyByteString) import qualified Data.ByteString.Char8 as BS import qualified Data.ByteString.Lazy.Char8 as LBS @@ -22,7 +23,7 @@ import Data.Yaml (decodeFileEither, encode, import qualified Network.HTTP.Types as HT import Network.VCR.Types (ApiCall (..), Cassette (..), Mode (..), SavedRequest (..), - SavedResponse (..), emptyCassette) + SavedResponse (..), emptyCassette, modifyBody') import qualified Network.Wai as Wai import Data.CaseInsensitive (mk) @@ -46,7 +47,7 @@ middleware Record { endpoint } = recordingMiddleware endpoint -- `filePath` cassette file recordingMiddleware :: String -> IORef Cassette -> FilePath -> Wai.Middleware recordingMiddleware endpoint cassetteIORef filePath app req respond = do - cassette@Cassette { apiCalls, ignoredHeaders }<- readIORef cassetteIORef + cassette@Cassette { apiCalls, ignoredHeaders, ignoredBodyFields }<- readIORef cassetteIORef (req', body) <- getRequestBody req -- Construct a request that can be sent to the actual remote API, by replacing the host in the request with the endpoint -- passed as an argument to the middleware @@ -54,7 +55,7 @@ recordingMiddleware endpoint cassetteIORef filePath app req respond = do -- delegate to http-proxy app app newRequest $ \response -> do -- Save the request that we have received from the remote API - let savedRequest = buildRequest ignoredHeaders req' (LBS.fromChunks body) + let savedRequest = buildRequest ignoredHeaders ignoredBodyFields req' (LBS.fromChunks body) -- Since reading the response body consumes it, we can't just reuse the response (status, headers, reBody) <- getResponseBody response savedResponse <- buildResponse reBody response @@ -75,10 +76,11 @@ findAnyResponse cassetteIORef savedRequest = do note ("The request: " <> tshow savedRequest <> " is not recorded! Ignored headers: " <> tshow ignoredHeaders) $ find (\c -> request c == savedRequest) apiCalls --- | A policy for obtaining response which expects the request to be issued in the order they were recorded. + +-- | A policy for obtaining r esponse which expects the request to be issued in the order they were recorded. consumeRequestsInOrder :: FindResponse consumeRequestsInOrder cassetteIORef savedRequest = do - cassette@Cassette { apiCalls, ignoredHeaders } <- readIORef cassetteIORef + cassette@Cassette { apiCalls, ignoredHeaders, ignoredBodyFields } <- readIORef cassetteIORef case apiCalls of c : rest -> do if request c == savedRequest then do @@ -86,6 +88,7 @@ consumeRequestsInOrder cassetteIORef savedRequest = do pure $ Right $ response c else pure $ Left $ "Expected a different request: " <> tshow savedRequest + <> ", needed to match: " <> (tshow $ request c) [] -> pure $ Left "No more requests recorded!" @@ -93,9 +96,9 @@ consumeRequestsInOrder cassetteIORef savedRequest = do -- a 500 error will be thrown replayingMiddleware :: FindResponse -> IORef Cassette -> FilePath -> Wai.Middleware replayingMiddleware findResponse cassetteIORef filePath app req respond = do - cassette@Cassette { apiCalls, ignoredHeaders } <- readIORef cassetteIORef + cassette@Cassette { apiCalls, ignoredHeaders, ignoredBodyFields } <- readIORef cassetteIORef b <- Wai.strictRequestBody req - let savedRequest = buildRequest ignoredHeaders req b + let savedRequest = buildRequest ignoredHeaders ignoredBodyFields req b -- Find an existing ApiCall according to the FindResponse policy findResponse cassetteIORef savedRequest >>= \case -- if a request is found, respond with the saved response @@ -128,9 +131,9 @@ modifyEndpoint endpoint req = req endpointError e = error $ "Error parsing endpoint as URI, " <> show e noHostError = error "No host could be extracted from the endpoint" -buildRequest :: [Text] -> Wai.Request -> LBS.ByteString -> SavedRequest -buildRequest ignoredHeaders r body = - SavedRequest +buildRequest :: [Text] -> [Text] -> Wai.Request -> LBS.ByteString -> SavedRequest +buildRequest ignoredHeaders ignoredBodyFields r body = + modifyBody' ignoredBodyFields $ SavedRequest { methodName = TE.decodeUtf8 $ Wai.requestMethod r , headers = reqHeaders , url = TE.decodeUtf8 $ Wai.rawPathInfo r @@ -138,8 +141,8 @@ buildRequest ignoredHeaders r body = , body = LBS.toStrict body } where - reqHeaders = filter (\(key, value) -> elem key ignoredHeaders') (Wai.requestHeaders r) - ignoredHeaders' = mk . BE.encodeUtf8 <$> ignoredHeaders + reqHeaders = filter (\(key, value) -> elem key ignoredHeaders') (Wai.requestHeaders r) + ignoredHeaders' = mk . BE.encodeUtf8 <$> ignoredHeaders buildResponse :: LBS.ByteString -> Wai.Response -> IO SavedResponse buildResponse body response = do diff --git a/src/Network/VCR/Types.hs b/src/Network/VCR/Types.hs index bf96a57..a06111e 100644 --- a/src/Network/VCR/Types.hs +++ b/src/Network/VCR/Types.hs @@ -1,18 +1,25 @@ -{-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} module Network.VCR.Types where +import Data.Bifunctor (first) import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as L +import Data.HashMap.Strict (HashMap) +import qualified Data.HashMap.Strict as HashMap +import Data.Maybe (fromJust) import Data.Text (Text) import qualified Data.Text.Encoding as BE (decodeUtf8, encodeUtf8) import qualified Data.Text.Lazy.Encoding as BEL (decodeUtf8, encodeUtf8) +import Data.Yaml (ParseException (AesonException), decodeFileEither) import Control.Monad (mzero) import Data.Aeson (FromJSON, ToJSON, Value (..), object, + decode, encode, eitherDecode, parseJSON, toJSON, (.:), (.=)) import Data.CaseInsensitive (foldedCase, mk) import GHC.Generics (Generic) @@ -72,9 +79,10 @@ data ApiCall = ApiCall } deriving (Show, Eq, Generic, ToJSON, FromJSON) data Cassette = Cassette - { endpoint :: Text - , apiCalls :: [ApiCall] - , ignoredHeaders :: [Text] + { endpoint :: Text + , apiCalls :: [ApiCall] + , ignoredHeaders :: [Text] + , ignoredBodyFields :: [Text] } deriving (Show, Eq, Generic, ToJSON, FromJSON) @@ -140,5 +148,41 @@ toHeader (name, value) = (mk $ BE.encodeUtf8 name, BE.encodeUtf8 value) emptyCassette :: Text -> Cassette -emptyCassette endpoint = Cassette { endpoint = endpoint, apiCalls = [], ignoredHeaders = [] } - +emptyCassette endpoint = Cassette { endpoint = endpoint, apiCalls = [], ignoredHeaders = [], ignoredBodyFields = [] } + +-- utility methods for matching with ignored fields in the body + +type ModResult a = Either ParseException a + +decodeExc :: FromJSON a => L.ByteString -> ModResult a +decodeExc = first AesonException . eitherDecode + +-- | Remove the ignored fields from the request's body +modifyBody :: [Text] -> SavedRequest -> ModResult SavedRequest +modifyBody ignoredBodyFields (SavedRequest mn hs url ps body) = SavedRequest mn hs url ps <$> body' + where ignoredBodyFieldsMap = HashMap.fromList $ zip ignoredBodyFields (repeat Null) + bodyDecoded :: ModResult (HashMap Text Value) = decodeExc $ L.fromStrict body + bodyDiff = HashMap.difference <$> bodyDecoded <*> (Right ignoredBodyFieldsMap) + body' = (L.toStrict . encode) <$> bodyDiff + +-- | An error-throwing versiion of modifyBody +modifyBody' :: [Text] -> SavedRequest -> SavedRequest +modifyBody' ignoredBodyFields request@(SavedRequest _ _ _ _ body) = + case modifyBody ignoredBodyFields request of + Right r -> r + Left exc -> error $ "Error when parsing the body of the request (" <> (show body) <> "): " <> (show exc) + +-- | Remove the ignored fields from all of the saved requests +-- this allows us to later match the requests without the ignored fields taken into account. +modifyCassette :: Cassette -> ModResult Cassette +modifyCassette (Cassette endpoint apiCalls ignoredHeaders ignoredBodyFields) = + Cassette endpoint <$> apiCalls' <*> (Right ignoredHeaders) <*> (Right ignoredBodyFields) + where + modifyApiCall ignoredBodyFields (ApiCall savedRequest savedResponse) = ApiCall <$> (modifyBody ignoredBodyFields savedRequest) <*> (Right savedResponse) + apiCalls' = sequence $ (modifyApiCall ignoredBodyFields) <$> apiCalls + +-- | Read the cassette, removing the ignored fields from all saved requests' bodies +readCassette :: String -> IO (ModResult Cassette) +readCassette filePath = do + cassette <- decodeFileEither filePath + pure $ cassette >>= modifyCassette