diff --git a/pgmock.go b/pgmock.go index 7bd09fc..4b44f8a 100644 --- a/pgmock.go +++ b/pgmock.go @@ -55,7 +55,7 @@ func (e *expectMessageStep) Step(backend *pgproto3.Backend) error { } type expectStartupMessageStep struct { - want *pgproto3.StartupMessage + want pgproto3.FrontendMessage any bool } @@ -85,11 +85,13 @@ func ExpectAnyMessage(want pgproto3.FrontendMessage) Step { } func expectMessage(want pgproto3.FrontendMessage, any bool) Step { - if want, ok := want.(*pgproto3.StartupMessage); ok { - return &expectStartupMessageStep{want: want, any: any} + switch msg := want.(type) { + case *pgproto3.StartupMessage, *pgproto3.CancelRequest, + *pgproto3.SSLRequest, *pgproto3.GSSEncRequest: + return &expectStartupMessageStep{want: msg, any: any} + default: + return &expectMessageStep{want: want, any: any} } - - return &expectMessageStep{want: want, any: any} } type sendMessageStep struct {