diff --git a/go.mod b/go.mod index 0a4ca6fbf..9c6edea1c 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/klauspost/compress v1.18.0 github.com/maypok86/otter/v2 v2.2.1 github.com/mitchellh/mapstructure v1.5.0 + github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/pquerna/xjwt v0.3.0 github.com/pquerna/xjwt/xkeyset v0.0.0-20241217022915-10fc997b2a9f github.com/segmentio/ksuid v1.0.4 @@ -51,7 +52,7 @@ require ( go.uber.org/zap v1.27.0 golang.org/x/crypto v0.34.0 golang.org/x/net v0.35.0 - golang.org/x/oauth2 v0.26.0 + golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.11.0 golang.org/x/sys v0.38.0 golang.org/x/term v0.29.0 @@ -86,6 +87,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/golang/protobuf v1.5.4 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -108,6 +110,7 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect github.com/tklauser/go-sysconf v0.3.16 // indirect github.com/tklauser/numcpus v0.11.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 // indirect diff --git a/go.sum b/go.sum index 4a1c84dad..ae97267be 100644 --- a/go.sum +++ b/go.sum @@ -106,6 +106,8 @@ github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod 
h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -117,6 +119,8 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -157,6 +161,8 @@ github.com/maypok86/otter/v2 v2.2.1 h1:hnGssisMFkdisYcvQ8L019zpYQcdtPse+g0ps2i7c github.com/maypok86/otter/v2 v2.2.1/go.mod h1:1NKY9bY+kB5jwCXBJfE59u+zAwOt6C7ni1FTlFFMqVs= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= +github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/opentracing/opentracing-go v1.1.0/go.mod 
h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= @@ -219,6 +225,8 @@ github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYI github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI= github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw= github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= @@ -294,8 +302,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE= -golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -330,8 +338,8 @@ golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools 
v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE= -golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/cli/commands.go b/pkg/cli/commands.go index ae3f324e2..40ed588c1 100644 --- a/pkg/cli/commands.go +++ b/pkg/cli/commands.go @@ -26,6 +26,7 @@ import ( v1 "github.com/conductorone/baton-sdk/pb/c1/connector_wrapper/v1" baton_v1 "github.com/conductorone/baton-sdk/pb/c1/connectorapi/baton/v1" "github.com/conductorone/baton-sdk/pkg/connectorrunner" + mcpPkg "github.com/conductorone/baton-sdk/pkg/mcp" "github.com/conductorone/baton-sdk/pkg/crypto" "github.com/conductorone/baton-sdk/pkg/field" "github.com/conductorone/baton-sdk/pkg/logging" @@ -599,6 +600,74 @@ func MakeGRPCServerCommand[T field.Configurable]( } } +func MakeMCPServerCommand[T field.Configurable]( + ctx context.Context, + name string, + v *viper.Viper, + confschema field.Configuration, + getconnector GetConnectorFunc2[T], +) func(*cobra.Command, []string) error { + return func(cmd *cobra.Command, args []string) error { + err := v.BindPFlags(cmd.Flags()) + if err != nil { + return err + } + + runCtx, err := initLogger( + ctx, + name, + 
logging.WithLogFormat(v.GetString("log-format")), + logging.WithLogLevel(v.GetString("log-level")), + ) + if err != nil { + return err + } + + runCtx, otelShutdown, err := initOtel(runCtx, name, v, nil) + if err != nil { + return err + } + defer func() { + if otelShutdown == nil { + return + } + shutdownCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(otelShutdownTimeout)) + defer cancel() + err := otelShutdown(shutdownCtx) + if err != nil { + zap.L().Error("error shutting down otel", zap.Error(err)) + } + }() + + l := ctxzap.Extract(runCtx) + l.Debug("starting MCP server") + + readFromPath := true + decodeOpts := field.WithAdditionalDecodeHooks(field.FileUploadDecodeHook(readFromPath)) + t, err := MakeGenericConfiguration[T](v, decodeOpts) + if err != nil { + return fmt.Errorf("failed to make configuration: %w", err) + } + + if err := field.Validate(confschema, t, field.WithAuthMethod(v.GetString("auth-method"))); err != nil { + return err + } + + c, err := getconnector(runCtx, t, RunTimeOpts{}) + if err != nil { + return err + } + + mcpServer, err := mcpPkg.NewMCPServer(runCtx, name, c) + if err != nil { + return fmt.Errorf("failed to create MCP server: %w", err) + } + + l.Info("MCP server starting on stdio") + return mcpServer.Serve(runCtx) + } +} + func MakeCapabilitiesCommand[T field.Configurable]( ctx context.Context, name string, diff --git a/pkg/config/config.go b/pkg/config/config.go index d93045243..c4343f05a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -207,6 +207,17 @@ func DefineConfigurationV2[T field.Configurable]( return nil, nil, err } + _, err = cli.AddCommand(mainCMD, v, &schema, &cobra.Command{ + Use: "mcp", + Short: "Run as MCP server (stdio transport)", + Long: "Run the connector as an MCP (Model Context Protocol) server using stdio transport. 
This allows AI assistants to interact with the connector.", + RunE: cli.MakeMCPServerCommand(ctx, connectorName, v, confschema, connector), + }) + + if err != nil { + return nil, nil, err + } + _, err = cli.AddCommand(mainCMD, v, &schema, &cobra.Command{ Use: "capabilities", Short: "Get connector capabilities", diff --git a/pkg/mcp/convert.go b/pkg/mcp/convert.go new file mode 100644 index 000000000..b8c496d90 --- /dev/null +++ b/pkg/mcp/convert.go @@ -0,0 +1,74 @@ +package mcp + +import ( + "encoding/json" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +var ( + marshaler = protojson.MarshalOptions{ + EmitUnpopulated: false, + UseProtoNames: true, + } + + unmarshaler = protojson.UnmarshalOptions{ + DiscardUnknown: true, + } +) + +// protoToMap converts a proto message to a map[string]any. +func protoToMap(msg proto.Message) (map[string]any, error) { + if msg == nil { + return nil, nil + } + jsonBytes, err := marshaler.Marshal(msg) + if err != nil { + return nil, err + } + var result map[string]any + if err := json.Unmarshal(jsonBytes, &result); err != nil { + return nil, err + } + return result, nil +} + +// protoToJSON converts a proto message to a JSON string. +func protoToJSON(msg proto.Message) (string, error) { + if msg == nil { + return "{}", nil + } + jsonBytes, err := marshaler.Marshal(msg) + if err != nil { + return "", err + } + return string(jsonBytes), nil +} + +// protoListToMaps converts a slice of proto messages to a slice of maps. +func protoListToMaps[T proto.Message](list []T) ([]map[string]any, error) { + result := make([]map[string]any, 0, len(list)) + for _, item := range list { + m, err := protoToMap(item) + if err != nil { + return nil, err + } + result = append(result, m) + } + return result, nil +} + +// jsonToProto unmarshals a JSON string into a proto message. 
+func jsonToProto(jsonStr string, msg proto.Message) error { + return unmarshaler.Unmarshal([]byte(jsonStr), msg) +} + +// mapToProto converts a map to a proto message by going through JSON. +func mapToProto(m map[string]any, msg proto.Message) error { + jsonBytes, err := json.Marshal(m) + if err != nil { + return err + } + return unmarshaler.Unmarshal(jsonBytes, msg) +} diff --git a/pkg/mcp/handlers.go b/pkg/mcp/handlers.go new file mode 100644 index 000000000..1913b6744 --- /dev/null +++ b/pkg/mcp/handlers.go @@ -0,0 +1,470 @@ +package mcp + +import ( + "context" + "fmt" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" +) + +const defaultPageSize = 50 + +// Input/Output types for handlers. + +type EmptyInput struct{} + +type PaginationInput struct { + PageSize int `json:"page_size,omitempty" jsonschema:"description=Number of items per page (default 50)"` + PageToken string `json:"page_token,omitempty" jsonschema:"description=Pagination token from previous response"` +} + +type ResourceInput struct { + ResourceType string `json:"resource_type" jsonschema:"required,description=The resource type (e.g. user or group)"` + ResourceID string `json:"resource_id" jsonschema:"required,description=The resource ID"` +} + +type ResourcePaginationInput struct { + ResourceType string `json:"resource_type" jsonschema:"required,description=The resource type"` + ResourceID string `json:"resource_id" jsonschema:"required,description=The resource ID"` + PageSize int `json:"page_size,omitempty" jsonschema:"description=Number of items per page (default 50)"` + PageToken string `json:"page_token,omitempty" jsonschema:"description=Pagination token from previous response"` +} + +type ListResourcesInput struct { + ResourceTypeID string `json:"resource_type_id" jsonschema:"required,description=The resource type ID to list (e.g. 
user or group)"` + ParentResourceType string `json:"parent_resource_type,omitempty" jsonschema:"description=Parent resource type (optional)"` + ParentResourceID string `json:"parent_resource_id,omitempty" jsonschema:"description=Parent resource ID (optional)"` + PageSize int `json:"page_size,omitempty" jsonschema:"description=Number of items per page (default 50)"` + PageToken string `json:"page_token,omitempty" jsonschema:"description=Pagination token from previous response"` +} + +type GrantInput struct { + EntitlementResourceType string `json:"entitlement_resource_type" jsonschema:"required,description=Resource type of the entitlement"` + EntitlementResourceID string `json:"entitlement_resource_id" jsonschema:"required,description=Resource ID of the entitlement"` + EntitlementID string `json:"entitlement_id" jsonschema:"required,description=The entitlement ID"` + PrincipalResourceType string `json:"principal_resource_type" jsonschema:"required,description=Resource type of the principal (e.g. 
user or group)"` + PrincipalResourceID string `json:"principal_resource_id" jsonschema:"required,description=Resource ID of the principal"` +} + +type RevokeInput struct { + GrantID string `json:"grant_id" jsonschema:"required,description=The grant ID to revoke"` + EntitlementResourceType string `json:"entitlement_resource_type" jsonschema:"required,description=Resource type of the entitlement"` + EntitlementResourceID string `json:"entitlement_resource_id" jsonschema:"required,description=Resource ID of the entitlement"` + EntitlementID string `json:"entitlement_id" jsonschema:"required,description=The entitlement ID"` + PrincipalResourceType string `json:"principal_resource_type" jsonschema:"required,description=Resource type of the principal"` + PrincipalResourceID string `json:"principal_resource_id" jsonschema:"required,description=Resource ID of the principal"` +} + +type CreateResourceInput struct { + ResourceType string `json:"resource_type" jsonschema:"required,description=The resource type to create"` + DisplayName string `json:"display_name" jsonschema:"required,description=Display name for the new resource"` + ParentResourceType string `json:"parent_resource_type,omitempty" jsonschema:"description=Parent resource type (optional)"` + ParentResourceID string `json:"parent_resource_id,omitempty" jsonschema:"description=Parent resource ID (optional)"` +} + +type DeleteResourceInput struct { + ResourceType string `json:"resource_type" jsonschema:"required,description=The resource type"` + ResourceID string `json:"resource_id" jsonschema:"required,description=The resource ID to delete"` +} + +type CreateTicketInput struct { + SchemaID string `json:"schema_id" jsonschema:"required,description=The ticket schema ID"` + DisplayName string `json:"display_name" jsonschema:"required,description=Display name for the ticket"` + Description string `json:"description,omitempty" jsonschema:"description=Description of the ticket"` +} + +type GetTicketInput struct { + 
TicketID string `json:"ticket_id" jsonschema:"required,description=The ticket ID"` +} + +// Output types. + +type MetadataOutput struct { + Metadata map[string]any `json:"metadata"` +} + +type ValidateOutput struct { + Valid bool `json:"valid"` + Annotations any `json:"annotations,omitempty"` +} + +type ListResourceTypesOutput struct { + ResourceTypes []map[string]any `json:"resource_types"` + NextPageToken string `json:"next_page_token,omitempty"` + HasMore bool `json:"has_more"` +} + +type ListResourcesOutput struct { + Resources []map[string]any `json:"resources"` + NextPageToken string `json:"next_page_token,omitempty"` + HasMore bool `json:"has_more"` +} + +type ResourceOutput struct { + Resource map[string]any `json:"resource"` +} + +type ListEntitlementsOutput struct { + Entitlements []map[string]any `json:"entitlements"` + NextPageToken string `json:"next_page_token,omitempty"` + HasMore bool `json:"has_more"` +} + +type ListGrantsOutput struct { + Grants []map[string]any `json:"grants"` + NextPageToken string `json:"next_page_token,omitempty"` + HasMore bool `json:"has_more"` +} + +type GrantOutput struct { + Grants []map[string]any `json:"grants"` +} + +type SuccessOutput struct { + Success bool `json:"success"` +} + +type ListTicketSchemasOutput struct { + Schemas []map[string]any `json:"schemas"` + NextPageToken string `json:"next_page_token,omitempty"` + HasMore bool `json:"has_more"` +} + +type TicketOutput struct { + Ticket map[string]any `json:"ticket"` +} + +// Handler implementations. 
+ +func (m *MCPServer) handleGetMetadata(ctx context.Context, req *mcp.CallToolRequest, input EmptyInput) (*mcp.CallToolResult, MetadataOutput, error) { + resp, err := m.connector.GetMetadata(ctx, &v2.ConnectorServiceGetMetadataRequest{}) + if err != nil { + return nil, MetadataOutput{}, fmt.Errorf("failed to get metadata: %w", err) + } + + result, err := protoToMap(resp.GetMetadata()) + if err != nil { + return nil, MetadataOutput{}, fmt.Errorf("failed to serialize metadata: %w", err) + } + + return nil, MetadataOutput{Metadata: result}, nil +} + +func (m *MCPServer) handleValidate(ctx context.Context, req *mcp.CallToolRequest, input EmptyInput) (*mcp.CallToolResult, ValidateOutput, error) { + resp, err := m.connector.Validate(ctx, &v2.ConnectorServiceValidateRequest{}) + if err != nil { + return nil, ValidateOutput{}, fmt.Errorf("validation failed: %w", err) + } + + return nil, ValidateOutput{ + Valid: true, + Annotations: resp.GetAnnotations(), + }, nil +} + +func (m *MCPServer) handleListResourceTypes(ctx context.Context, req *mcp.CallToolRequest, input PaginationInput) (*mcp.CallToolResult, ListResourceTypesOutput, error) { + pageSize := uint32(defaultPageSize) + if input.PageSize > 0 { + pageSize = uint32(input.PageSize) + } + + resp, err := m.connector.ListResourceTypes(ctx, &v2.ResourceTypesServiceListResourceTypesRequest{ + PageSize: pageSize, + PageToken: input.PageToken, + }) + if err != nil { + return nil, ListResourceTypesOutput{}, fmt.Errorf("failed to list resource types: %w", err) + } + + resourceTypes, err := protoListToMaps(resp.GetList()) + if err != nil { + return nil, ListResourceTypesOutput{}, fmt.Errorf("failed to serialize resource types: %w", err) + } + + return nil, ListResourceTypesOutput{ + ResourceTypes: resourceTypes, + NextPageToken: resp.GetNextPageToken(), + HasMore: resp.GetNextPageToken() != "", + }, nil +} + +func (m *MCPServer) handleListResources(ctx context.Context, req *mcp.CallToolRequest, input ListResourcesInput) 
(*mcp.CallToolResult, ListResourcesOutput, error) { + pageSize := uint32(defaultPageSize) + if input.PageSize > 0 { + pageSize = uint32(input.PageSize) + } + + var parentResourceID *v2.ResourceId + if input.ParentResourceType != "" && input.ParentResourceID != "" { + parentResourceID = &v2.ResourceId{ + ResourceType: input.ParentResourceType, + Resource: input.ParentResourceID, + } + } + + resp, err := m.connector.ListResources(ctx, &v2.ResourcesServiceListResourcesRequest{ + ResourceTypeId: input.ResourceTypeID, + ParentResourceId: parentResourceID, + PageSize: pageSize, + PageToken: input.PageToken, + }) + if err != nil { + return nil, ListResourcesOutput{}, fmt.Errorf("failed to list resources: %w", err) + } + + resources, err := protoListToMaps(resp.GetList()) + if err != nil { + return nil, ListResourcesOutput{}, fmt.Errorf("failed to serialize resources: %w", err) + } + + return nil, ListResourcesOutput{ + Resources: resources, + NextPageToken: resp.GetNextPageToken(), + HasMore: resp.GetNextPageToken() != "", + }, nil +} + +func (m *MCPServer) handleGetResource(ctx context.Context, req *mcp.CallToolRequest, input ResourceInput) (*mcp.CallToolResult, ResourceOutput, error) { + resp, err := m.connector.GetResource(ctx, &v2.ResourceGetterServiceGetResourceRequest{ + ResourceId: &v2.ResourceId{ + ResourceType: input.ResourceType, + Resource: input.ResourceID, + }, + }) + if err != nil { + return nil, ResourceOutput{}, fmt.Errorf("failed to get resource: %w", err) + } + + resource, err := protoToMap(resp.GetResource()) + if err != nil { + return nil, ResourceOutput{}, fmt.Errorf("failed to serialize resource: %w", err) + } + + return nil, ResourceOutput{Resource: resource}, nil +} + +func (m *MCPServer) handleListEntitlements(ctx context.Context, req *mcp.CallToolRequest, input ResourcePaginationInput) (*mcp.CallToolResult, ListEntitlementsOutput, error) { + pageSize := uint32(defaultPageSize) + if input.PageSize > 0 { + pageSize = uint32(input.PageSize) + } + + 
resp, err := m.connector.ListEntitlements(ctx, &v2.EntitlementsServiceListEntitlementsRequest{ + Resource: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: input.ResourceType, + Resource: input.ResourceID, + }, + }, + PageSize: pageSize, + PageToken: input.PageToken, + }) + if err != nil { + return nil, ListEntitlementsOutput{}, fmt.Errorf("failed to list entitlements: %w", err) + } + + entitlements, err := protoListToMaps(resp.GetList()) + if err != nil { + return nil, ListEntitlementsOutput{}, fmt.Errorf("failed to serialize entitlements: %w", err) + } + + return nil, ListEntitlementsOutput{ + Entitlements: entitlements, + NextPageToken: resp.GetNextPageToken(), + HasMore: resp.GetNextPageToken() != "", + }, nil +} + +func (m *MCPServer) handleListGrants(ctx context.Context, req *mcp.CallToolRequest, input ResourcePaginationInput) (*mcp.CallToolResult, ListGrantsOutput, error) { + pageSize := uint32(defaultPageSize) + if input.PageSize > 0 { + pageSize = uint32(input.PageSize) + } + + resp, err := m.connector.ListGrants(ctx, &v2.GrantsServiceListGrantsRequest{ + Resource: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: input.ResourceType, + Resource: input.ResourceID, + }, + }, + PageSize: pageSize, + PageToken: input.PageToken, + }) + if err != nil { + return nil, ListGrantsOutput{}, fmt.Errorf("failed to list grants: %w", err) + } + + grants, err := protoListToMaps(resp.GetList()) + if err != nil { + return nil, ListGrantsOutput{}, fmt.Errorf("failed to serialize grants: %w", err) + } + + return nil, ListGrantsOutput{ + Grants: grants, + NextPageToken: resp.GetNextPageToken(), + HasMore: resp.GetNextPageToken() != "", + }, nil +} + +func (m *MCPServer) handleGrant(ctx context.Context, req *mcp.CallToolRequest, input GrantInput) (*mcp.CallToolResult, GrantOutput, error) { + resp, err := m.connector.Grant(ctx, &v2.GrantManagerServiceGrantRequest{ + Entitlement: &v2.Entitlement{ + Resource: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: 
input.EntitlementResourceType, + Resource: input.EntitlementResourceID, + }, + }, + Id: input.EntitlementID, + }, + Principal: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: input.PrincipalResourceType, + Resource: input.PrincipalResourceID, + }, + }, + }) + if err != nil { + return nil, GrantOutput{}, fmt.Errorf("grant failed: %w", err) + } + + grants, err := protoListToMaps(resp.GetGrants()) + if err != nil { + return nil, GrantOutput{}, fmt.Errorf("failed to serialize grants: %w", err) + } + + return nil, GrantOutput{Grants: grants}, nil +} + +func (m *MCPServer) handleRevoke(ctx context.Context, req *mcp.CallToolRequest, input RevokeInput) (*mcp.CallToolResult, SuccessOutput, error) { + _, err := m.connector.Revoke(ctx, &v2.GrantManagerServiceRevokeRequest{ + Grant: &v2.Grant{ + Id: input.GrantID, + Entitlement: &v2.Entitlement{ + Resource: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: input.EntitlementResourceType, + Resource: input.EntitlementResourceID, + }, + }, + Id: input.EntitlementID, + }, + Principal: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: input.PrincipalResourceType, + Resource: input.PrincipalResourceID, + }, + }, + }, + }) + if err != nil { + return nil, SuccessOutput{}, fmt.Errorf("revoke failed: %w", err) + } + + return nil, SuccessOutput{Success: true}, nil +} + +func (m *MCPServer) handleCreateResource(ctx context.Context, req *mcp.CallToolRequest, input CreateResourceInput) (*mcp.CallToolResult, ResourceOutput, error) { + var parentResource *v2.Resource + if input.ParentResourceType != "" && input.ParentResourceID != "" { + parentResource = &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: input.ParentResourceType, + Resource: input.ParentResourceID, + }, + } + } + + resp, err := m.connector.CreateResource(ctx, &v2.CreateResourceRequest{ + Resource: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: input.ResourceType, + }, + DisplayName: input.DisplayName, + ParentResourceId: parentResource.GetId(), + }, + }) + if 
err != nil { + return nil, ResourceOutput{}, fmt.Errorf("create resource failed: %w", err) + } + + resource, err := protoToMap(resp.GetCreated()) + if err != nil { + return nil, ResourceOutput{}, fmt.Errorf("failed to serialize resource: %w", err) + } + + return nil, ResourceOutput{Resource: resource}, nil +} + +func (m *MCPServer) handleDeleteResource(ctx context.Context, req *mcp.CallToolRequest, input DeleteResourceInput) (*mcp.CallToolResult, SuccessOutput, error) { + _, err := m.connector.DeleteResource(ctx, &v2.DeleteResourceRequest{ + ResourceId: &v2.ResourceId{ + ResourceType: input.ResourceType, + Resource: input.ResourceID, + }, + }) + if err != nil { + return nil, SuccessOutput{}, fmt.Errorf("delete resource failed: %w", err) + } + + return nil, SuccessOutput{Success: true}, nil +} + +func (m *MCPServer) handleListTicketSchemas(ctx context.Context, req *mcp.CallToolRequest, input EmptyInput) (*mcp.CallToolResult, ListTicketSchemasOutput, error) { + resp, err := m.connector.ListTicketSchemas(ctx, &v2.TicketsServiceListTicketSchemasRequest{}) + if err != nil { + return nil, ListTicketSchemasOutput{}, fmt.Errorf("failed to list ticket schemas: %w", err) + } + + schemas, err := protoListToMaps(resp.GetList()) + if err != nil { + return nil, ListTicketSchemasOutput{}, fmt.Errorf("failed to serialize ticket schemas: %w", err) + } + + return nil, ListTicketSchemasOutput{ + Schemas: schemas, + NextPageToken: resp.GetNextPageToken(), + HasMore: resp.GetNextPageToken() != "", + }, nil +} + +func (m *MCPServer) handleCreateTicket(ctx context.Context, req *mcp.CallToolRequest, input CreateTicketInput) (*mcp.CallToolResult, TicketOutput, error) { + resp, err := m.connector.CreateTicket(ctx, &v2.TicketsServiceCreateTicketRequest{ + Schema: &v2.TicketSchema{ + Id: input.SchemaID, + }, + Request: &v2.TicketRequest{ + DisplayName: input.DisplayName, + Description: input.Description, + }, + }) + if err != nil { + return nil, TicketOutput{}, fmt.Errorf("create ticket 
failed: %w", err) + } + + ticket, err := protoToMap(resp.GetTicket()) + if err != nil { + return nil, TicketOutput{}, fmt.Errorf("failed to serialize ticket: %w", err) + } + + return nil, TicketOutput{Ticket: ticket}, nil +} + +func (m *MCPServer) handleGetTicket(ctx context.Context, req *mcp.CallToolRequest, input GetTicketInput) (*mcp.CallToolResult, TicketOutput, error) { + resp, err := m.connector.GetTicket(ctx, &v2.TicketsServiceGetTicketRequest{ + Id: input.TicketID, + }) + if err != nil { + return nil, TicketOutput{}, fmt.Errorf("get ticket failed: %w", err) + } + + ticket, err := protoToMap(resp.GetTicket()) + if err != nil { + return nil, TicketOutput{}, fmt.Errorf("failed to serialize ticket: %w", err) + } + + return nil, TicketOutput{Ticket: ticket}, nil +} diff --git a/pkg/mcp/server.go b/pkg/mcp/server.go new file mode 100644 index 000000000..cfc7b4d63 --- /dev/null +++ b/pkg/mcp/server.go @@ -0,0 +1,171 @@ +package mcp + +import ( + "context" + "fmt" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" + "github.com/conductorone/baton-sdk/pkg/types" +) + +// MCPServer wraps a ConnectorServer and exposes its functionality via MCP. +type MCPServer struct { + connector types.ConnectorServer + server *mcp.Server + caps *v2.ConnectorCapabilities +} + +// NewMCPServer creates a new MCP server that wraps the given ConnectorServer. +func NewMCPServer(ctx context.Context, name string, connector types.ConnectorServer) (*MCPServer, error) { + // Get connector metadata to determine capabilities. 
+ metaResp, err := connector.GetMetadata(ctx, &v2.ConnectorServiceGetMetadataRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to get connector metadata: %w", err) + } + + s := mcp.NewServer( + &mcp.Implementation{ + Name: name, + Version: "1.0.0", + }, + nil, + ) + + m := &MCPServer{ + connector: connector, + server: s, + caps: metaResp.GetMetadata().GetCapabilities(), + } + + m.registerTools() + return m, nil +} + +// Serve starts the MCP server on stdio. +func (m *MCPServer) Serve(ctx context.Context) error { + return m.server.Run(ctx, &mcp.StdioTransport{}) +} + +// registerTools registers all MCP tools based on connector capabilities. +func (m *MCPServer) registerTools() { + // Always register read-only tools. + m.registerReadTools() + + // Register provisioning tools if the connector supports provisioning. + if m.hasCapability(v2.Capability_CAPABILITY_PROVISION) { + m.registerProvisioningTools() + } + + // Register ticketing tools if the connector supports ticketing. + if m.hasCapability(v2.Capability_CAPABILITY_TICKETING) { + m.registerTicketingTools() + } +} + +// hasCapability checks if the connector has the given capability. +func (m *MCPServer) hasCapability(cap v2.Capability) bool { + if m.caps == nil { + return false + } + for _, c := range m.caps.GetConnectorCapabilities() { + if c == cap { + return true + } + } + return false +} + +// registerReadTools registers read-only tools that are always available. 
+func (m *MCPServer) registerReadTools() { + // get_metadata + mcp.AddTool(m.server, &mcp.Tool{ + Name: "get_metadata", + Description: "Get connector metadata including display name, description, and capabilities", + }, m.handleGetMetadata) + + // validate + mcp.AddTool(m.server, &mcp.Tool{ + Name: "validate", + Description: "Validate the connector configuration and connectivity", + }, m.handleValidate) + + // list_resource_types + mcp.AddTool(m.server, &mcp.Tool{ + Name: "list_resource_types", + Description: "List all resource types supported by this connector", + }, m.handleListResourceTypes) + + // list_resources + mcp.AddTool(m.server, &mcp.Tool{ + Name: "list_resources", + Description: "List resources of a specific type", + }, m.handleListResources) + + // get_resource + mcp.AddTool(m.server, &mcp.Tool{ + Name: "get_resource", + Description: "Get a specific resource by its type and ID", + }, m.handleGetResource) + + // list_entitlements + mcp.AddTool(m.server, &mcp.Tool{ + Name: "list_entitlements", + Description: "List entitlements (permissions, roles, memberships) for a resource", + }, m.handleListEntitlements) + + // list_grants + mcp.AddTool(m.server, &mcp.Tool{ + Name: "list_grants", + Description: "List grants (who has what access) for a resource", + }, m.handleListGrants) +} + +// registerProvisioningTools registers tools for provisioning operations. 
+func (m *MCPServer) registerProvisioningTools() { + // grant + mcp.AddTool(m.server, &mcp.Tool{ + Name: "grant", + Description: "Grant an entitlement to a principal (user or group)", + }, m.handleGrant) + + // revoke + mcp.AddTool(m.server, &mcp.Tool{ + Name: "revoke", + Description: "Revoke a grant from a principal", + }, m.handleRevoke) + + // create_resource + mcp.AddTool(m.server, &mcp.Tool{ + Name: "create_resource", + Description: "Create a new resource", + }, m.handleCreateResource) + + // delete_resource + mcp.AddTool(m.server, &mcp.Tool{ + Name: "delete_resource", + Description: "Delete a resource", + }, m.handleDeleteResource) +} + +// registerTicketingTools registers tools for ticketing operations. +func (m *MCPServer) registerTicketingTools() { + // list_ticket_schemas + mcp.AddTool(m.server, &mcp.Tool{ + Name: "list_ticket_schemas", + Description: "List available ticket schemas", + }, m.handleListTicketSchemas) + + // create_ticket + mcp.AddTool(m.server, &mcp.Tool{ + Name: "create_ticket", + Description: "Create a new ticket", + }, m.handleCreateTicket) + + // get_ticket + mcp.AddTool(m.server, &mcp.Tool{ + Name: "get_ticket", + Description: "Get a ticket by ID", + }, m.handleGetTicket) +} diff --git a/vendor/github.com/google/jsonschema-go/LICENSE b/vendor/github.com/google/jsonschema-go/LICENSE new file mode 100644 index 000000000..1cb53e9df --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 JSON Schema Go Project Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright 
notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/annotations.go b/vendor/github.com/google/jsonschema-go/jsonschema/annotations.go new file mode 100644 index 000000000..d4dd6436b --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/annotations.go @@ -0,0 +1,76 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import "maps" + +// An annotations tracks certain properties computed by keywords that are used by validation. +// ("Annotation" is the spec's term.) +// In particular, the unevaluatedItems and unevaluatedProperties keywords need to know which +// items and properties were evaluated (validated successfully). +type annotations struct { + allItems bool // all items were evaluated + endIndex int // 1+largest index evaluated by prefixItems + evaluatedIndexes map[int]bool // set of indexes evaluated by contains + allProperties bool // all properties were evaluated + evaluatedProperties map[string]bool // set of properties evaluated by various keywords +} + +// noteIndex marks i as evaluated. +func (a *annotations) noteIndex(i int) { + if a.evaluatedIndexes == nil { + a.evaluatedIndexes = map[int]bool{} + } + a.evaluatedIndexes[i] = true +} + +// noteEndIndex marks items with index less than end as evaluated. 
+func (a *annotations) noteEndIndex(end int) { + if end > a.endIndex { + a.endIndex = end + } +} + +// noteProperty marks prop as evaluated. +func (a *annotations) noteProperty(prop string) { + if a.evaluatedProperties == nil { + a.evaluatedProperties = map[string]bool{} + } + a.evaluatedProperties[prop] = true +} + +// noteProperties marks all the properties in props as evaluated. +func (a *annotations) noteProperties(props map[string]bool) { + a.evaluatedProperties = merge(a.evaluatedProperties, props) +} + +// merge adds b's annotations to a. +// a must not be nil. +func (a *annotations) merge(b *annotations) { + if b == nil { + return + } + if b.allItems { + a.allItems = true + } + if b.endIndex > a.endIndex { + a.endIndex = b.endIndex + } + a.evaluatedIndexes = merge(a.evaluatedIndexes, b.evaluatedIndexes) + if b.allProperties { + a.allProperties = true + } + a.evaluatedProperties = merge(a.evaluatedProperties, b.evaluatedProperties) +} + +// merge adds t's keys to s and returns s. +// If s is nil, it returns a copy of t. +func merge[K comparable](s, t map[K]bool) map[K]bool { + if s == nil { + return maps.Clone(t) + } + maps.Copy(s, t) + return s +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/doc.go b/vendor/github.com/google/jsonschema-go/jsonschema/doc.go new file mode 100644 index 000000000..a34bab725 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/doc.go @@ -0,0 +1,101 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +/* +Package jsonschema is an implementation of the [JSON Schema specification], +a JSON-based format for describing the structure of JSON data. +The package can be used to read schemas for code generation, and to validate +data using the draft 2020-12 specification. Validation with other drafts +or custom meta-schemas is not supported. 
+ +Construct a [Schema] as you would any Go struct (for example, by writing +a struct literal), or unmarshal a JSON schema into a [Schema] in the usual +way (with [encoding/json], for instance). It can then be used for code +generation or other purposes without further processing. +You can also infer a schema from a Go struct. + +# Resolution + +A Schema can refer to other schemas, both inside and outside itself. These +references must be resolved before a schema can be used for validation. +Call [Schema.Resolve] to obtain a resolved schema (called a [Resolved]). +If the schema has external references, pass a [ResolveOptions] with a [Loader] +to load them. To validate default values in a schema, set +[ResolveOptions.ValidateDefaults] to true. + +# Validation + +Call [Resolved.Validate] to validate a JSON value. The value must be a +Go value that looks like the result of unmarshaling a JSON value into an +[any] or a struct. For example, the JSON value + + {"name": "Al", "scores": [90, 80, 100]} + +could be represented as the Go value + + map[string]any{ + "name": "Al", + "scores": []any{90, 80, 100}, + } + +or as a value of this type: + + type Player struct { + Name string `json:"name"` + Scores []int `json:"scores"` + } + +# Inference + +The [For] function returns a [Schema] describing the given Go type. +Each field in the struct becomes a property of the schema. +The values of "json" tags are respected: the field's property name is taken +from the tag, and fields omitted from the JSON are omitted from the schema as +well. 
+For example, `jsonschema.For[Player]()` returns this schema:
+
+	{
+	    "properties": {
+	        "name": {
+	            "type": "string"
+	        },
+	        "scores": {
+	            "type": "array",
+	            "items": {"type": "integer"}
+	        }
+	    },
+	    "required": ["name", "scores"],
+	    "additionalProperties": {"not": {}}
+	}
+
+Use the "jsonschema" struct tag to provide a description for the property:
+
+	type Player struct {
+		Name   string `json:"name" jsonschema:"player name"`
+		Scores []int  `json:"scores" jsonschema:"scores of player's games"`
+	}
+
+# Deviations from the specification
+
+Regular expressions are processed with Go's regexp package, which differs
+from ECMA 262, most significantly in not supporting back-references.
+See [this table of differences] for more.
+
+The "format" keyword described in [section 7 of the validation spec] is recorded
+in the Schema, but is ignored during validation.
+It does not even produce [annotations].
+Use the "pattern" keyword instead: it will work more reliably across JSON Schema
+implementations. See [learnjsonschema.com] for more recommendations about "format".
+
+The content keywords described in [section 8 of the validation spec]
+are recorded in the schema, but ignored during validation.
+ +[JSON Schema specification]: https://json-schema.org +[section 7 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 +[section 8 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 +[learnjsonschema.com]: https://www.learnjsonschema.com/2020-12/format-annotation/format/ +[this table of differences]: https://github.com/dlclark/regexp2?tab=readme-ov-file#compare-regexp-and-regexp2 +[annotations]: https://json-schema.org/draft/2020-12/json-schema-core#name-annotations +*/ +package jsonschema diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/infer.go b/vendor/github.com/google/jsonschema-go/jsonschema/infer.go new file mode 100644 index 000000000..ae624ad09 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/infer.go @@ -0,0 +1,248 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains functions that infer a schema from a Go type. + +package jsonschema + +import ( + "fmt" + "log/slog" + "maps" + "math/big" + "reflect" + "regexp" + "time" +) + +// ForOptions are options for the [For] and [ForType] functions. +type ForOptions struct { + // If IgnoreInvalidTypes is true, fields that can't be represented as a JSON + // Schema are ignored instead of causing an error. + // This allows callers to adjust the resulting schema using custom knowledge. + // For example, an interface type where all the possible implementations are + // known can be described with "oneof". + IgnoreInvalidTypes bool + + // TypeSchemas maps types to their schemas. + // If [For] encounters a type that is a key in this map, the + // corresponding value is used as the resulting schema (after cloning to + // ensure uniqueness). 
+	// Types in this map override the default translations, as described
+	// in [For]'s documentation.
+	TypeSchemas map[reflect.Type]*Schema
+}
+
+// For constructs a JSON schema object for the given type argument.
+// If non-nil, the provided options configure certain aspects of this construction,
+// described below.
+//
+// It translates Go types into compatible JSON schema types, as follows.
+// These defaults can be overridden by [ForOptions.TypeSchemas].
+//
+//   - Strings have schema type "string".
+//   - Bools have schema type "boolean".
+//   - Signed and unsigned integer types have schema type "integer".
+//   - Floating point types have schema type "number".
+//   - Slices and arrays have schema type "array", and a corresponding schema
+//     for items.
+//   - Maps with string key have schema type "object", and corresponding
+//     schema for additionalProperties.
+//   - Structs have schema type "object", and disallow additionalProperties.
+//     Their properties are derived from exported struct fields, using the
+//     struct field JSON name. Fields that are marked "omitempty" are
+//     considered optional; all other fields become required properties.
+//   - Some types in the standard library that implement json.Marshaler
+//     translate to schemas that match the values to which they marshal.
+//     For example, [time.Time] translates to the schema for strings.
+//
+// For will return an error if there is a cycle in the types.
+//
+// By default, For returns an error if t contains (possibly recursively) any of the
+// following Go types, as they are incompatible with the JSON schema spec.
+// If [ForOptions.IgnoreInvalidTypes] is true, then these types are ignored instead.
+//   - maps with key other than 'string'
+//   - function types
+//   - channel types
+//   - complex numbers
+//   - unsafe pointers
+//
+// This function recognizes struct field tags named "jsonschema".
+// A jsonschema tag on a field is used as the description for the corresponding property.
+// For future compatibility, descriptions must not start with "WORD=", where WORD is a +// sequence of non-whitespace characters. +func For[T any](opts *ForOptions) (*Schema, error) { + if opts == nil { + opts = &ForOptions{} + } + schemas := maps.Clone(initialSchemaMap) + // Add types from the options. They override the default ones. + maps.Copy(schemas, opts.TypeSchemas) + s, err := forType(reflect.TypeFor[T](), map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) + if err != nil { + var z T + return nil, fmt.Errorf("For[%T](): %w", z, err) + } + return s, nil +} + +// ForType is like [For], but takes a [reflect.Type] +func ForType(t reflect.Type, opts *ForOptions) (*Schema, error) { + schemas := maps.Clone(initialSchemaMap) + // Add types from the options. They override the default ones. + maps.Copy(schemas, opts.TypeSchemas) + s, err := forType(t, map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) + if err != nil { + return nil, fmt.Errorf("ForType(%s): %w", t, err) + } + return s, nil +} + +func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas map[reflect.Type]*Schema) (*Schema, error) { + // Follow pointers: the schema for *T is almost the same as for T, except that + // an explicit JSON "null" is allowed for the pointer. 
+ allowNull := false + for t.Kind() == reflect.Pointer { + allowNull = true + t = t.Elem() + } + + // Check for cycles + // User defined types have a name, so we can skip those that are natively defined + if t.Name() != "" { + if seen[t] { + return nil, fmt.Errorf("cycle detected for type %v", t) + } + seen[t] = true + defer delete(seen, t) + } + + if s := schemas[t]; s != nil { + return s.CloneSchemas(), nil + } + + var ( + s = new(Schema) + err error + ) + + switch t.Kind() { + case reflect.Bool: + s.Type = "boolean" + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Uintptr: + s.Type = "integer" + + case reflect.Float32, reflect.Float64: + s.Type = "number" + + case reflect.Interface: + // Unrestricted + + case reflect.Map: + if t.Key().Kind() != reflect.String { + if ignore { + return nil, nil // ignore + } + return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) + } + if t.Key().Kind() != reflect.String { + } + s.Type = "object" + s.AdditionalProperties, err = forType(t.Elem(), seen, ignore, schemas) + if err != nil { + return nil, fmt.Errorf("computing map value schema: %v", err) + } + if ignore && s.AdditionalProperties == nil { + // Ignore if the element type is invalid. + return nil, nil + } + + case reflect.Slice, reflect.Array: + s.Type = "array" + s.Items, err = forType(t.Elem(), seen, ignore, schemas) + if err != nil { + return nil, fmt.Errorf("computing element schema: %v", err) + } + if ignore && s.Items == nil { + // Ignore if the element type is invalid. 
+ return nil, nil + } + if t.Kind() == reflect.Array { + s.MinItems = Ptr(t.Len()) + s.MaxItems = Ptr(t.Len()) + } + + case reflect.String: + s.Type = "string" + + case reflect.Struct: + s.Type = "object" + // no additional properties are allowed + s.AdditionalProperties = falseSchema() + for _, field := range reflect.VisibleFields(t) { + if field.Anonymous { + continue + } + + info := fieldJSONInfo(field) + if info.omit { + continue + } + if s.Properties == nil { + s.Properties = make(map[string]*Schema) + } + fs, err := forType(field.Type, seen, ignore, schemas) + if err != nil { + return nil, err + } + if ignore && fs == nil { + // Skip fields of invalid type. + continue + } + if tag, ok := field.Tag.Lookup("jsonschema"); ok { + if tag == "" { + return nil, fmt.Errorf("empty jsonschema tag on struct field %s.%s", t, field.Name) + } + if disallowedPrefixRegexp.MatchString(tag) { + return nil, fmt.Errorf("tag must not begin with 'WORD=': %q", tag) + } + fs.Description = tag + } + s.Properties[info.name] = fs + if !info.settings["omitempty"] && !info.settings["omitzero"] { + s.Required = append(s.Required, info.name) + } + } + + default: + if ignore { + // Ignore. + return nil, nil + } + return nil, fmt.Errorf("type %v is unsupported by jsonschema", t) + } + if allowNull && s.Type != "" { + s.Types = []string{"null", s.Type} + s.Type = "" + } + return s, nil +} + +// initialSchemaMap holds types from the standard library that have MarshalJSON methods. +var initialSchemaMap = make(map[reflect.Type]*Schema) + +func init() { + ss := &Schema{Type: "string"} + initialSchemaMap[reflect.TypeFor[time.Time]()] = ss + initialSchemaMap[reflect.TypeFor[slog.Level]()] = ss + initialSchemaMap[reflect.TypeFor[big.Int]()] = &Schema{Types: []string{"null", "string"}} + initialSchemaMap[reflect.TypeFor[big.Rat]()] = ss + initialSchemaMap[reflect.TypeFor[big.Float]()] = ss +} + +// Disallow jsonschema tag values beginning "WORD=", for future expansion. 
+var disallowedPrefixRegexp = regexp.MustCompile("^[^ \t\n]*=") diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/json_pointer.go b/vendor/github.com/google/jsonschema-go/jsonschema/json_pointer.go new file mode 100644 index 000000000..ed1b16991 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/json_pointer.go @@ -0,0 +1,146 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements JSON Pointers. +// A JSON Pointer is a path that refers to one JSON value within another. +// If the path is empty, it refers to the root value. +// Otherwise, it is a sequence of slash-prefixed strings, like "/points/1/x", +// selecting successive properties (for JSON objects) or items (for JSON arrays). +// For example, when applied to this JSON value: +// { +// "points": [ +// {"x": 1, "y": 2}, +// {"x": 3, "y": 4} +// ] +// } +// +// the JSON Pointer "/points/1/x" refers to the number 3. +// See the spec at https://datatracker.ietf.org/doc/html/rfc6901. + +package jsonschema + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" +) + +var ( + jsonPointerEscaper = strings.NewReplacer("~", "~0", "/", "~1") + jsonPointerUnescaper = strings.NewReplacer("~0", "~", "~1", "/") +) + +func escapeJSONPointerSegment(s string) string { + return jsonPointerEscaper.Replace(s) +} + +func unescapeJSONPointerSegment(s string) string { + return jsonPointerUnescaper.Replace(s) +} + +// parseJSONPointer splits a JSON Pointer into a sequence of segments. It doesn't +// convert strings to numbers, because that depends on the traversal: a segment +// is treated as a number when applied to an array, but a string when applied to +// an object. See section 4 of the spec. 
+func parseJSONPointer(ptr string) (segments []string, err error) { + if ptr == "" { + return nil, nil + } + if ptr[0] != '/' { + return nil, fmt.Errorf("JSON Pointer %q does not begin with '/'", ptr) + } + // Unlike file paths, consecutive slashes are not coalesced. + // Split is nicer than Cut here, because it gets a final "/" right. + segments = strings.Split(ptr[1:], "/") + if strings.Contains(ptr, "~") { + // Undo the simple escaping rules that allow one to include a slash in a segment. + for i := range segments { + segments[i] = unescapeJSONPointerSegment(segments[i]) + } + } + return segments, nil +} + +// dereferenceJSONPointer returns the Schema that sptr points to within s, +// or an error if none. +// This implementation suffices for JSON Schema: pointers are applied only to Schemas, +// and refer only to Schemas. +func dereferenceJSONPointer(s *Schema, sptr string) (_ *Schema, err error) { + defer wrapf(&err, "JSON Pointer %q", sptr) + + segments, err := parseJSONPointer(sptr) + if err != nil { + return nil, err + } + v := reflect.ValueOf(s) + for _, seg := range segments { + switch v.Kind() { + case reflect.Pointer: + v = v.Elem() + if !v.IsValid() { + return nil, errors.New("navigated to nil reference") + } + fallthrough // if valid, can only be a pointer to a Schema + + case reflect.Struct: + // The segment must refer to a field in a Schema. + if v.Type() != reflect.TypeFor[Schema]() { + return nil, fmt.Errorf("navigated to non-Schema %s", v.Type()) + } + v = lookupSchemaField(v, seg) + if !v.IsValid() { + return nil, fmt.Errorf("no schema field %q", seg) + } + case reflect.Slice, reflect.Array: + // The segment must be an integer without leading zeroes that refers to an item in the + // slice or array. 
+ if seg == "-" { + return nil, errors.New("the JSON Pointer array segment '-' is not supported") + } + if len(seg) > 1 && seg[0] == '0' { + return nil, fmt.Errorf("segment %q has leading zeroes", seg) + } + n, err := strconv.Atoi(seg) + if err != nil { + return nil, fmt.Errorf("invalid int: %q", seg) + } + if n < 0 || n >= v.Len() { + return nil, fmt.Errorf("index %d is out of bounds for array of length %d", n, v.Len()) + } + v = v.Index(n) + // Cannot be invalid. + case reflect.Map: + // The segment must be a key in the map. + v = v.MapIndex(reflect.ValueOf(seg)) + if !v.IsValid() { + return nil, fmt.Errorf("no key %q in map", seg) + } + default: + return nil, fmt.Errorf("value %s (%s) is not a schema, slice or map", v, v.Type()) + } + } + if s, ok := v.Interface().(*Schema); ok { + return s, nil + } + return nil, fmt.Errorf("does not refer to a schema, but to a %s", v.Type()) +} + +// lookupSchemaField returns the value of the field with the given name in v, +// or the zero value if there is no such field or it is not of type Schema or *Schema. +func lookupSchemaField(v reflect.Value, name string) reflect.Value { + if name == "type" { + // The "type" keyword may refer to Type or Types. + // At most one will be non-zero. + if t := v.FieldByName("Type"); !t.IsZero() { + return t + } + return v.FieldByName("Types") + } + if sf, ok := schemaFieldMap[name]; ok { + return v.FieldByIndex(sf.Index) + } + return reflect.Value{} +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/resolve.go b/vendor/github.com/google/jsonschema-go/jsonschema/resolve.go new file mode 100644 index 000000000..ece9be880 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/resolve.go @@ -0,0 +1,548 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +// This file deals with preparing a schema for validation, including various checks, +// optimizations, and the resolution of cross-schema references. + +package jsonschema + +import ( + "errors" + "fmt" + "net/url" + "reflect" + "regexp" + "strings" +) + +// A Resolved consists of a [Schema] along with associated information needed to +// validate documents against it. +// A Resolved has been validated against its meta-schema, and all its references +// (the $ref and $dynamicRef keywords) have been resolved to their referenced Schemas. +// Call [Schema.Resolve] to obtain a Resolved from a Schema. +type Resolved struct { + root *Schema + // map from $ids to their schemas + resolvedURIs map[string]*Schema + // map from schemas to additional info computed during resolution + resolvedInfos map[*Schema]*resolvedInfo +} + +func newResolved(s *Schema) *Resolved { + return &Resolved{ + root: s, + resolvedURIs: map[string]*Schema{}, + resolvedInfos: map[*Schema]*resolvedInfo{}, + } +} + +// resolvedInfo holds information specific to a schema that is computed by [Schema.Resolve]. +type resolvedInfo struct { + s *Schema + // The JSON Pointer path from the root schema to here. + // Used in errors. + path string + // The schema's base schema. + // If the schema is the root or has an ID, its base is itself. + // Otherwise, its base is the innermost enclosing schema whose base + // is itself. + // Intuitively, a base schema is one that can be referred to with a + // fragmentless URI. + base *Schema + // The URI for the schema, if it is the root or has an ID. + // Otherwise nil. + // Invariants: + // s.base.uri != nil. + // s.base == s <=> s.uri != nil + uri *url.URL + // The schema to which Ref refers. + resolvedRef *Schema + + // If the schema has a dynamic ref, exactly one of the next two fields + // will be non-zero after successful resolution. + // The schema to which the dynamic ref refers when it acts lexically. 
+ resolvedDynamicRef *Schema + // The anchor to look up on the stack when the dynamic ref acts dynamically. + dynamicRefAnchor string + + // The following fields are independent of arguments to Schema.Resolved, + // so they could live on the Schema. We put them here for simplicity. + + // The set of required properties. + isRequired map[string]bool + + // Compiled regexps. + pattern *regexp.Regexp + patternProperties map[*regexp.Regexp]*Schema + + // Map from anchors to subschemas. + anchors map[string]anchorInfo +} + +// Schema returns the schema that was resolved. +// It must not be modified. +func (r *Resolved) Schema() *Schema { return r.root } + +// schemaString returns a short string describing the schema. +func (r *Resolved) schemaString(s *Schema) string { + if s.ID != "" { + return s.ID + } + info := r.resolvedInfos[s] + if info.path != "" { + return info.path + } + return "" +} + +// A Loader reads and unmarshals the schema at uri, if any. +type Loader func(uri *url.URL) (*Schema, error) + +// ResolveOptions are options for [Schema.Resolve]. +type ResolveOptions struct { + // BaseURI is the URI relative to which the root schema should be resolved. + // If non-empty, must be an absolute URI (one that starts with a scheme). + // It is resolved (in the URI sense; see [url.ResolveReference]) with root's + // $id property. + // If the resulting URI is not absolute, then the schema cannot contain + // relative URI references. + BaseURI string + // Loader loads schemas that are referred to by a $ref but are not under the + // root schema (remote references). + // If nil, resolving a remote reference will return an error. + Loader Loader + // ValidateDefaults determines whether to validate values of "default" keywords + // against their schemas. + // The [JSON Schema specification] does not require this, but it is recommended + // if defaults will be used. 
+ // + // [JSON Schema specification]: https://json-schema.org/understanding-json-schema/reference/annotations + ValidateDefaults bool +} + +// Resolve resolves all references within the schema and performs other tasks that +// prepare the schema for validation. +// If opts is nil, the default values are used. +// The schema must not be changed after Resolve is called. +// The same schema may be resolved multiple times. +func (root *Schema) Resolve(opts *ResolveOptions) (*Resolved, error) { + // There are up to five steps required to prepare a schema to validate. + // 1. Load: read the schema from somewhere and unmarshal it. + // This schema (root) may have been loaded or created in memory, but other schemas that + // come into the picture in step 4 will be loaded by the given loader. + // 2. Check: validate the schema against a meta-schema, and perform other well-formedness checks. + // Precompute some values along the way. + // 3. Resolve URIs: determine the base URI of the root and all its subschemas, and + // resolve (in the URI sense) all identifiers and anchors with their bases. This step results + // in a map from URIs to schemas within root. + // 4. Resolve references: all refs in the schemas are replaced with the schema they refer to. + // 5. (Optional.) If opts.ValidateDefaults is true, validate the defaults. 
+ r := &resolver{loaded: map[string]*Resolved{}} + if opts != nil { + r.opts = *opts + } + var base *url.URL + if r.opts.BaseURI == "" { + base = &url.URL{} // so we can call ResolveReference on it + } else { + var err error + base, err = url.Parse(r.opts.BaseURI) + if err != nil { + return nil, fmt.Errorf("parsing base URI: %w", err) + } + } + + if r.opts.Loader == nil { + r.opts.Loader = func(uri *url.URL) (*Schema, error) { + return nil, errors.New("cannot resolve remote schemas: no loader passed to Schema.Resolve") + } + } + + resolved, err := r.resolve(root, base) + if err != nil { + return nil, err + } + if r.opts.ValidateDefaults { + if err := resolved.validateDefaults(); err != nil { + return nil, err + } + } + // TODO: before we return, throw away anything we don't need for validation. + return resolved, nil +} + +// A resolver holds the state for resolution. +type resolver struct { + opts ResolveOptions + // A cache of loaded and partly resolved schemas. (They may not have had their + // refs resolved.) The cache ensures that the loader will never be called more + // than once with the same URI, and that reference cycles are handled properly. + loaded map[string]*Resolved +} + +func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { + if baseURI.Fragment != "" { + return nil, fmt.Errorf("base URI %s must not have a fragment", baseURI) + } + rs := newResolved(s) + + if err := s.check(rs.resolvedInfos); err != nil { + return nil, err + } + + if err := resolveURIs(rs, baseURI); err != nil { + return nil, err + } + + // Remember the schema by both the URI we loaded it from and its canonical name, + // which may differ if the schema has an $id. + // We must set the map before calling resolveRefs, or ref cycles will cause unbounded recursion. 
+ r.loaded[baseURI.String()] = rs + r.loaded[rs.resolvedInfos[s].uri.String()] = rs + + if err := r.resolveRefs(rs); err != nil { + return nil, err + } + return rs, nil +} + +func (root *Schema) check(infos map[*Schema]*resolvedInfo) error { + // Check for structural validity. Do this first and fail fast: + // bad structure will cause other code to panic. + if err := root.checkStructure(infos); err != nil { + return err + } + + var errs []error + report := func(err error) { errs = append(errs, err) } + + for ss := range root.all() { + ss.checkLocal(report, infos) + } + return errors.Join(errs...) +} + +// checkStructure verifies that root and its subschemas form a tree. +// It also assigns each schema a unique path, to improve error messages. +func (root *Schema) checkStructure(infos map[*Schema]*resolvedInfo) error { + assert(len(infos) == 0, "non-empty infos") + + var check func(reflect.Value, []byte) error + check = func(v reflect.Value, path []byte) error { + // For the purpose of error messages, the root schema has path "root" + // and other schemas' paths are their JSON Pointer from the root. + p := "root" + if len(path) > 0 { + p = string(path) + } + s := v.Interface().(*Schema) + if s == nil { + return fmt.Errorf("jsonschema: schema at %s is nil", p) + } + if info, ok := infos[s]; ok { + // We've seen s before. + // The schema graph at root is not a tree, but it needs to + // be because a schema's base must be unique. + // A cycle would also put Schema.all into an infinite recursion. + return fmt.Errorf("jsonschema: schemas at %s do not form a tree; %s appears more than once (also at %s)", + root, info.path, p) + } + infos[s] = &resolvedInfo{s: s, path: p} + + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + // A field that contains an individual schema. + // A nil is valid: it just means the field isn't present. 
+ if !fv.IsNil() { + if err := check(fv, fmt.Appendf(path, "/%s", info.jsonName)); err != nil { + return err + } + } + + case schemaSliceType: + for i := range fv.Len() { + if err := check(fv.Index(i), fmt.Appendf(path, "/%s/%d", info.jsonName, i)); err != nil { + return err + } + } + + case schemaMapType: + iter := fv.MapRange() + for iter.Next() { + key := escapeJSONPointerSegment(iter.Key().String()) + if err := check(iter.Value(), fmt.Appendf(path, "/%s/%s", info.jsonName, key)); err != nil { + return err + } + } + } + + } + return nil + } + + return check(reflect.ValueOf(root), make([]byte, 0, 256)) +} + +// checkLocal checks s for validity, independently of other schemas it may refer to. +// Since checking a regexp involves compiling it, checkLocal saves those compiled regexps +// in the schema for later use. +// It appends the errors it finds to errs. +func (s *Schema) checkLocal(report func(error), infos map[*Schema]*resolvedInfo) { + addf := func(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + report(fmt.Errorf("jsonschema.Schema: %s: %s", s, msg)) + } + + if s == nil { + addf("nil subschema") + return + } + if err := s.basicChecks(); err != nil { + report(err) + return + } + + // TODO: validate the schema's properties, + // ideally by jsonschema-validating it against the meta-schema. + + // Some properties are present so that Schemas can round-trip, but we do not + // validate them. + // Currently, it's just the $vocabulary property. + // As a special case, we can validate the 2020-12 meta-schema. + if s.Vocabulary != nil && s.Schema != draft202012 { + addf("cannot validate a schema with $vocabulary") + } + + info := infos[s] + + // Check and compile regexps. 
+ if s.Pattern != "" { + re, err := regexp.Compile(s.Pattern) + if err != nil { + addf("pattern: %v", err) + } else { + info.pattern = re + } + } + if len(s.PatternProperties) > 0 { + info.patternProperties = map[*regexp.Regexp]*Schema{} + for reString, subschema := range s.PatternProperties { + re, err := regexp.Compile(reString) + if err != nil { + addf("patternProperties[%q]: %v", reString, err) + continue + } + info.patternProperties[re] = subschema + } + } + + // Build a set of required properties, to avoid quadratic behavior when validating + // a struct. + if len(s.Required) > 0 { + info.isRequired = map[string]bool{} + for _, r := range s.Required { + info.isRequired[r] = true + } + } +} + +// resolveURIs resolves the ids and anchors in all the schemas of root, relative +// to baseURI. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2, section +// 8.2.1. +// +// Every schema has a base URI and a parent base URI. +// +// The parent base URI is the base URI of the lexically enclosing schema, or for +// a root schema, the URI it was loaded from or the one supplied to [Schema.Resolve]. +// +// If the schema has no $id property, the base URI of a schema is that of its parent. +// If the schema does have an $id, it must be a URI, possibly relative. The schema's +// base URI is the $id resolved (in the sense of [url.URL.ResolveReference]) against +// the parent base. +// +// As an example, consider this schema loaded from http://a.com/root.json (quotes omitted): +// +// { +// allOf: [ +// {$id: "sub1.json", minLength: 5}, +// {$id: "http://b.com", minimum: 10}, +// {not: {maximum: 20}} +// ] +// } +// +// The base URIs are as follows. Schema locations are expressed in the JSON Pointer notation. 
+//
+//	schema	base URI
+//	root	http://a.com/root.json
+//	allOf/0	http://a.com/sub1.json
+//	allOf/1	http://b.com (absolute $id; doesn't matter that it's not under the loaded URI)
+//	allOf/2	http://a.com/root.json (inherited from parent)
+//	allOf/2/not	http://a.com/root.json (inherited from parent)
+func resolveURIs(rs *Resolved, baseURI *url.URL) error {
+	var resolve func(s, base *Schema) error
+	resolve = func(s, base *Schema) error {
+		info := rs.resolvedInfos[s]
+		baseInfo := rs.resolvedInfos[base]
+
+		// ids are scoped to the root.
+		if s.ID != "" {
+			// A non-empty ID establishes a new base.
+			idURI, err := url.Parse(s.ID)
+			if err != nil {
+				return err
+			}
+			if idURI.Fragment != "" {
+				return fmt.Errorf("$id %s must not have a fragment", s.ID)
+			}
+			// The base URI for this schema is its $id resolved against the parent base.
+			info.uri = baseInfo.uri.ResolveReference(idURI)
+			if !info.uri.IsAbs() {
+				return fmt.Errorf("$id %s does not resolve to an absolute URI (base is %q)", s.ID, baseInfo.uri)
+			}
+			rs.resolvedURIs[info.uri.String()] = s
+			base = s // needed for anchors
+			baseInfo = rs.resolvedInfos[base]
+		}
+		info.base = base
+
+		// Anchors and dynamic anchors are URI fragments that are scoped to their base.
+		// We treat them as keys in a map stored within the schema.
+		setAnchor := func(anchor string, dynamic bool) error {
+			if anchor != "" {
+				if _, ok := baseInfo.anchors[anchor]; ok {
+					return fmt.Errorf("duplicate anchor %q in %s", anchor, baseInfo.uri)
+				}
+				if baseInfo.anchors == nil {
+					baseInfo.anchors = map[string]anchorInfo{}
+				}
+				baseInfo.anchors[anchor] = anchorInfo{s, dynamic}
+			}
+			return nil
+		}
+
+		// A duplicate anchor must fail resolution; do not drop setAnchor's error.
+		if err := setAnchor(s.Anchor, false); err != nil {
+			return err
+		}
+		if err := setAnchor(s.DynamicAnchor, true); err != nil {
+			return err
+		}
+
+		for c := range s.children() {
+			if err := resolve(c, base); err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+
+	// Set the root URI to the base for now. If the root has an $id, this will change.
+ rs.resolvedInfos[rs.root].uri = baseURI + // The original base, even if changed, is still a valid way to refer to the root. + rs.resolvedURIs[baseURI.String()] = rs.root + + return resolve(rs.root, rs.root) +} + +// resolveRefs replaces every ref in the schemas with the schema it refers to. +// A reference that doesn't resolve within the schema may refer to some other schema +// that needs to be loaded. +func (r *resolver) resolveRefs(rs *Resolved) error { + for s := range rs.root.all() { + info := rs.resolvedInfos[s] + if s.Ref != "" { + refSchema, _, err := r.resolveRef(rs, s, s.Ref) + if err != nil { + return err + } + // Whether or not the anchor referred to by $ref fragment is dynamic, + // the ref still treats it lexically. + info.resolvedRef = refSchema + } + if s.DynamicRef != "" { + refSchema, frag, err := r.resolveRef(rs, s, s.DynamicRef) + if err != nil { + return err + } + if frag != "" { + // The dynamic ref's fragment points to a dynamic anchor. + // We must resolve the fragment at validation time. + info.dynamicRefAnchor = frag + } else { + // There is no dynamic anchor in the lexically referenced schema, + // so the dynamic ref behaves like a lexical ref. + info.resolvedDynamicRef = refSchema + } + } + } + return nil +} + +// resolveRef resolves the reference ref, which is either s.Ref or s.DynamicRef. +func (r *resolver) resolveRef(rs *Resolved, s *Schema, ref string) (_ *Schema, dynamicFragment string, err error) { + refURI, err := url.Parse(ref) + if err != nil { + return nil, "", err + } + // URI-resolve the ref against the current base URI to get a complete URI. + base := rs.resolvedInfos[s].base + refURI = rs.resolvedInfos[base].uri.ResolveReference(refURI) + // The non-fragment part of a ref URI refers to the base URI of some schema. + // This part is the same for dynamic refs too: their non-fragment part resolves + // lexically. + u := *refURI + u.Fragment = "" + fraglessRefURI := &u + // Look it up locally. 
+	referencedSchema := rs.resolvedURIs[fraglessRefURI.String()]
+	if referencedSchema == nil {
+		// The schema is remote. Maybe we've already loaded it.
+		// We assume that the non-fragment part of refURI refers to a top-level schema
+		// document. That is, we don't support the case exemplified by
+		// http://foo.com/bar.json/baz, where the document is in bar.json and
+		// the reference points to a subschema within it.
+		// TODO: support that case.
+		if lrs := r.loaded[fraglessRefURI.String()]; lrs != nil {
+			referencedSchema = lrs.root
+		} else {
+			// Try to load the schema.
+			ls, err := r.opts.Loader(fraglessRefURI)
+			if err != nil {
+				return nil, "", fmt.Errorf("loading %s: %w", fraglessRefURI, err)
+			}
+			lrs, err := r.resolve(ls, fraglessRefURI)
+			if err != nil {
+				return nil, "", err
+			}
+			referencedSchema = lrs.root
+			assert(referencedSchema != nil, "nil referenced schema")
+			// Copy the resolvedInfos from lrs into rs, without overwriting
+			// (hence we can't use maps.Insert).
+			for s, i := range lrs.resolvedInfos {
+				if rs.resolvedInfos[s] == nil {
+					rs.resolvedInfos[s] = i
+				}
+			}
+		}
+	}
+
+	frag := refURI.Fragment
+	// Look up frag in refSchema.
+	// frag is either a JSON Pointer or the name of an anchor.
+	// A JSON Pointer is either the empty string or begins with a '/',
+	// whereas anchors are always non-empty strings that don't contain slashes.
+	if frag != "" && !strings.HasPrefix(frag, "/") {
+		resInfo := rs.resolvedInfos[referencedSchema]
+		info, found := resInfo.anchors[frag]
+
+		if !found {
+			// Name the schema that was actually searched (the referenced one),
+			// not the referring schema s.
+			return nil, "", fmt.Errorf("no anchor %q in %s", frag, referencedSchema)
+		}
+		if info.dynamic {
+			dynamicFragment = frag
+		}
+		return info.schema, dynamicFragment, nil
+	}
+	// frag is a JSON Pointer.
+ s, err = dereferenceJSONPointer(referencedSchema, frag) + return s, "", err +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/schema.go b/vendor/github.com/google/jsonschema-go/jsonschema/schema.go new file mode 100644 index 000000000..3b4db9a6e --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/schema.go @@ -0,0 +1,436 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import ( + "bytes" + "cmp" + "encoding/json" + "errors" + "fmt" + "iter" + "maps" + "math" + "reflect" + "slices" +) + +// A Schema is a JSON schema object. +// It corresponds to the 2020-12 draft, as described in https://json-schema.org/draft/2020-12, +// specifically: +// - https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-01 +// - https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01 +// +// A Schema value may have non-zero values for more than one field: +// all relevant non-zero fields are used for validation. +// There is one exception to provide more Go type-safety: the Type and Types fields +// are mutually exclusive. +// +// Since this struct is a Go representation of a JSON value, it inherits JSON's +// distinction between nil and empty. Nil slices and maps are considered absent, +// but empty ones are present and affect validation. For example, +// +// Schema{Enum: nil} +// +// is equivalent to an empty schema, so it validates every instance. But +// +// Schema{Enum: []any{}} +// +// requires equality to some slice element, so it vacuously rejects every instance. +type Schema struct { + // core + ID string `json:"$id,omitempty"` + Schema string `json:"$schema,omitempty"` + Ref string `json:"$ref,omitempty"` + Comment string `json:"$comment,omitempty"` + Defs map[string]*Schema `json:"$defs,omitempty"` + // definitions is deprecated but still allowed. 
It is a synonym for $defs. + Definitions map[string]*Schema `json:"definitions,omitempty"` + + Anchor string `json:"$anchor,omitempty"` + DynamicAnchor string `json:"$dynamicAnchor,omitempty"` + DynamicRef string `json:"$dynamicRef,omitempty"` + Vocabulary map[string]bool `json:"$vocabulary,omitempty"` + + // metadata + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Default json.RawMessage `json:"default,omitempty"` + Deprecated bool `json:"deprecated,omitempty"` + ReadOnly bool `json:"readOnly,omitempty"` + WriteOnly bool `json:"writeOnly,omitempty"` + Examples []any `json:"examples,omitempty"` + + // validation + // Use Type for a single type, or Types for multiple types; never both. + Type string `json:"-"` + Types []string `json:"-"` + Enum []any `json:"enum,omitempty"` + // Const is *any because a JSON null (Go nil) is a valid value. + Const *any `json:"const,omitempty"` + MultipleOf *float64 `json:"multipleOf,omitempty"` + Minimum *float64 `json:"minimum,omitempty"` + Maximum *float64 `json:"maximum,omitempty"` + ExclusiveMinimum *float64 `json:"exclusiveMinimum,omitempty"` + ExclusiveMaximum *float64 `json:"exclusiveMaximum,omitempty"` + MinLength *int `json:"minLength,omitempty"` + MaxLength *int `json:"maxLength,omitempty"` + Pattern string `json:"pattern,omitempty"` + + // arrays + PrefixItems []*Schema `json:"prefixItems,omitempty"` + Items *Schema `json:"items,omitempty"` + MinItems *int `json:"minItems,omitempty"` + MaxItems *int `json:"maxItems,omitempty"` + AdditionalItems *Schema `json:"additionalItems,omitempty"` + UniqueItems bool `json:"uniqueItems,omitempty"` + Contains *Schema `json:"contains,omitempty"` + MinContains *int `json:"minContains,omitempty"` // *int, not int: default is 1, not 0 + MaxContains *int `json:"maxContains,omitempty"` + UnevaluatedItems *Schema `json:"unevaluatedItems,omitempty"` + + // objects + MinProperties *int `json:"minProperties,omitempty"` + MaxProperties *int 
`json:"maxProperties,omitempty"` + Required []string `json:"required,omitempty"` + DependentRequired map[string][]string `json:"dependentRequired,omitempty"` + Properties map[string]*Schema `json:"properties,omitempty"` + PatternProperties map[string]*Schema `json:"patternProperties,omitempty"` + AdditionalProperties *Schema `json:"additionalProperties,omitempty"` + PropertyNames *Schema `json:"propertyNames,omitempty"` + UnevaluatedProperties *Schema `json:"unevaluatedProperties,omitempty"` + + // logic + AllOf []*Schema `json:"allOf,omitempty"` + AnyOf []*Schema `json:"anyOf,omitempty"` + OneOf []*Schema `json:"oneOf,omitempty"` + Not *Schema `json:"not,omitempty"` + + // conditional + If *Schema `json:"if,omitempty"` + Then *Schema `json:"then,omitempty"` + Else *Schema `json:"else,omitempty"` + DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` + + // other + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 + ContentEncoding string `json:"contentEncoding,omitempty"` + ContentMediaType string `json:"contentMediaType,omitempty"` + ContentSchema *Schema `json:"contentSchema,omitempty"` + + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 + Format string `json:"format,omitempty"` + + // Extra allows for additional keywords beyond those specified. + Extra map[string]any `json:"-"` +} + +// falseSchema returns a new Schema tree that fails to validate any value. +func falseSchema() *Schema { + return &Schema{Not: &Schema{}} +} + +// anchorInfo records the subschema to which an anchor refers, and whether +// the anchor keyword is $anchor or $dynamicAnchor. +type anchorInfo struct { + schema *Schema + dynamic bool +} + +// String returns a short description of the schema. 
+func (s *Schema) String() string { + if s.ID != "" { + return s.ID + } + if a := cmp.Or(s.Anchor, s.DynamicAnchor); a != "" { + return fmt.Sprintf("anchor %s", a) + } + return "" +} + +// CloneSchemas returns a copy of s. +// The copy is shallow except for sub-schemas, which are themelves copied with CloneSchemas. +// This allows both s and s.CloneSchemas() to appear as sub-schemas of the same parent. +func (s *Schema) CloneSchemas() *Schema { + if s == nil { + return nil + } + s2 := *s + v := reflect.ValueOf(&s2) + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + sscss := fv.Interface().(*Schema) + fv.Set(reflect.ValueOf(sscss.CloneSchemas())) + + case schemaSliceType: + slice := fv.Interface().([]*Schema) + slice = slices.Clone(slice) + for i, ss := range slice { + slice[i] = ss.CloneSchemas() + } + fv.Set(reflect.ValueOf(slice)) + + case schemaMapType: + m := fv.Interface().(map[string]*Schema) + m = maps.Clone(m) + for k, ss := range m { + m[k] = ss.CloneSchemas() + } + fv.Set(reflect.ValueOf(m)) + } + } + return &s2 +} + +func (s *Schema) basicChecks() error { + if s.Type != "" && s.Types != nil { + return errors.New("both Type and Types are set; at most one should be") + } + if s.Defs != nil && s.Definitions != nil { + return errors.New("both Defs and Definitions are set; at most one should be") + } + return nil +} + +type schemaWithoutMethods Schema // doesn't implement json.{Unm,M}arshaler + +func (s *Schema) MarshalJSON() ([]byte, error) { + if err := s.basicChecks(); err != nil { + return nil, err + } + + // Marshal either Type or Types as "type". 
+ var typ any + switch { + case s.Type != "": + typ = s.Type + case s.Types != nil: + typ = s.Types + } + ms := struct { + Type any `json:"type,omitempty"` + *schemaWithoutMethods + }{ + Type: typ, + schemaWithoutMethods: (*schemaWithoutMethods)(s), + } + bs, err := marshalStructWithMap(&ms, "Extra") + if err != nil { + return nil, err + } + // Marshal {} as true and {"not": {}} as false. + // It is wasteful to do this here instead of earlier, but much easier. + switch { + case bytes.Equal(bs, []byte(`{}`)): + bs = []byte("true") + case bytes.Equal(bs, []byte(`{"not":true}`)): + bs = []byte("false") + } + return bs, nil +} + +func (s *Schema) UnmarshalJSON(data []byte) error { + // A JSON boolean is a valid schema. + var b bool + if err := json.Unmarshal(data, &b); err == nil { + if b { + // true is the empty schema, which validates everything. + *s = Schema{} + } else { + // false is the schema that validates nothing. + *s = *falseSchema() + } + return nil + } + + ms := struct { + Type json.RawMessage `json:"type,omitempty"` + Const json.RawMessage `json:"const,omitempty"` + MinLength *integer `json:"minLength,omitempty"` + MaxLength *integer `json:"maxLength,omitempty"` + MinItems *integer `json:"minItems,omitempty"` + MaxItems *integer `json:"maxItems,omitempty"` + MinProperties *integer `json:"minProperties,omitempty"` + MaxProperties *integer `json:"maxProperties,omitempty"` + MinContains *integer `json:"minContains,omitempty"` + MaxContains *integer `json:"maxContains,omitempty"` + + *schemaWithoutMethods + }{ + schemaWithoutMethods: (*schemaWithoutMethods)(s), + } + if err := unmarshalStructWithMap(data, &ms, "Extra"); err != nil { + return err + } + // Unmarshal "type" as either Type or Types. 
+ var err error + if len(ms.Type) > 0 { + switch ms.Type[0] { + case '"': + err = json.Unmarshal(ms.Type, &s.Type) + case '[': + err = json.Unmarshal(ms.Type, &s.Types) + default: + err = fmt.Errorf(`invalid value for "type": %q`, ms.Type) + } + } + if err != nil { + return err + } + + unmarshalAnyPtr := func(p **any, raw json.RawMessage) error { + if len(raw) == 0 { + return nil + } + if bytes.Equal(raw, []byte("null")) { + *p = new(any) + return nil + } + return json.Unmarshal(raw, p) + } + + // Setting Const to a pointer to null will marshal properly, but won't + // unmarshal: the *any is set to nil, not a pointer to nil. + if err := unmarshalAnyPtr(&s.Const, ms.Const); err != nil { + return err + } + + set := func(dst **int, src *integer) { + if src != nil { + *dst = Ptr(int(*src)) + } + } + + set(&s.MinLength, ms.MinLength) + set(&s.MaxLength, ms.MaxLength) + set(&s.MinItems, ms.MinItems) + set(&s.MaxItems, ms.MaxItems) + set(&s.MinProperties, ms.MinProperties) + set(&s.MaxProperties, ms.MaxProperties) + set(&s.MinContains, ms.MinContains) + set(&s.MaxContains, ms.MaxContains) + + return nil +} + +type integer int32 // for the integer-valued fields of Schema + +func (ip *integer) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + // nothing to do + return nil + } + // If there is a decimal point, src is a floating-point number. + var i int64 + if bytes.ContainsRune(data, '.') { + var f float64 + if err := json.Unmarshal(data, &f); err != nil { + return errors.New("not a number") + } + i = int64(f) + if float64(i) != f { + return errors.New("not an integer value") + } + } else { + if err := json.Unmarshal(data, &i); err != nil { + return errors.New("cannot be unmarshaled into an int") + } + } + // Ensure behavior is the same on both 32-bit and 64-bit systems. + if i < math.MinInt32 || i > math.MaxInt32 { + return errors.New("integer is out of range") + } + *ip = integer(i) + return nil +} + +// Ptr returns a pointer to a new variable whose value is x. 
+func Ptr[T any](x T) *T { return &x } + +// every applies f preorder to every schema under s including s. +// The second argument to f is the path to the schema appended to the argument path. +// It stops when f returns false. +func (s *Schema) every(f func(*Schema) bool) bool { + return f(s) && s.everyChild(func(s *Schema) bool { return s.every(f) }) +} + +// everyChild reports whether f is true for every immediate child schema of s. +func (s *Schema) everyChild(f func(*Schema) bool) bool { + v := reflect.ValueOf(s) + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + // A field that contains an individual schema. A nil is valid: it just means the field isn't present. + c := fv.Interface().(*Schema) + if c != nil && !f(c) { + return false + } + + case schemaSliceType: + slice := fv.Interface().([]*Schema) + for _, c := range slice { + if !f(c) { + return false + } + } + + case schemaMapType: + // Sort keys for determinism. + m := fv.Interface().(map[string]*Schema) + for _, k := range slices.Sorted(maps.Keys(m)) { + if !f(m[k]) { + return false + } + } + } + } + return true +} + +// all wraps every in an iterator. +func (s *Schema) all() iter.Seq[*Schema] { + return func(yield func(*Schema) bool) { s.every(yield) } +} + +// children wraps everyChild in an iterator. 
+func (s *Schema) children() iter.Seq[*Schema] { + return func(yield func(*Schema) bool) { s.everyChild(yield) } +} + +var ( + schemaType = reflect.TypeFor[*Schema]() + schemaSliceType = reflect.TypeFor[[]*Schema]() + schemaMapType = reflect.TypeFor[map[string]*Schema]() +) + +type structFieldInfo struct { + sf reflect.StructField + jsonName string +} + +var ( + // the visible fields of Schema that have a JSON name, sorted by that name + schemaFieldInfos []structFieldInfo + // map from JSON name to field + schemaFieldMap = map[string]reflect.StructField{} +) + +func init() { + for _, sf := range reflect.VisibleFields(reflect.TypeFor[Schema]()) { + info := fieldJSONInfo(sf) + if !info.omit { + schemaFieldInfos = append(schemaFieldInfos, structFieldInfo{sf, info.name}) + } + } + slices.SortFunc(schemaFieldInfos, func(i1, i2 structFieldInfo) int { + return cmp.Compare(i1.jsonName, i2.jsonName) + }) + for _, info := range schemaFieldInfos { + schemaFieldMap[info.jsonName] = info.sf + } +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/util.go b/vendor/github.com/google/jsonschema-go/jsonschema/util.go new file mode 100644 index 000000000..5cfa27dc6 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/util.go @@ -0,0 +1,463 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import ( + "bytes" + "cmp" + "encoding/binary" + "encoding/json" + "fmt" + "hash/maphash" + "math" + "math/big" + "reflect" + "slices" + "strings" + "sync" +) + +// Equal reports whether two Go values representing JSON values are equal according +// to the JSON Schema spec. +// The values must not contain cycles. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-4.2.2. +// It behaves like reflect.DeepEqual, except that numbers are compared according +// to mathematical equality. 
+func Equal(x, y any) bool { + return equalValue(reflect.ValueOf(x), reflect.ValueOf(y)) +} + +func equalValue(x, y reflect.Value) bool { + // Copied from src/reflect/deepequal.go, omitting the visited check (because JSON + // values are trees). + if !x.IsValid() || !y.IsValid() { + return x.IsValid() == y.IsValid() + } + + // Treat numbers specially. + rx, ok1 := jsonNumber(x) + ry, ok2 := jsonNumber(y) + if ok1 && ok2 { + return rx.Cmp(ry) == 0 + } + if x.Kind() != y.Kind() { + return false + } + switch x.Kind() { + case reflect.Array: + if x.Len() != y.Len() { + return false + } + for i := range x.Len() { + if !equalValue(x.Index(i), y.Index(i)) { + return false + } + } + return true + case reflect.Slice: + if x.IsNil() != y.IsNil() { + return false + } + if x.Len() != y.Len() { + return false + } + if x.UnsafePointer() == y.UnsafePointer() { + return true + } + // Special case for []byte, which is common. + if x.Type().Elem().Kind() == reflect.Uint8 && x.Type() == y.Type() { + return bytes.Equal(x.Bytes(), y.Bytes()) + } + for i := range x.Len() { + if !equalValue(x.Index(i), y.Index(i)) { + return false + } + } + return true + case reflect.Interface: + if x.IsNil() || y.IsNil() { + return x.IsNil() == y.IsNil() + } + return equalValue(x.Elem(), y.Elem()) + case reflect.Pointer: + if x.UnsafePointer() == y.UnsafePointer() { + return true + } + return equalValue(x.Elem(), y.Elem()) + case reflect.Struct: + t := x.Type() + if t != y.Type() { + return false + } + for i := range t.NumField() { + sf := t.Field(i) + if !sf.IsExported() { + continue + } + if !equalValue(x.FieldByIndex(sf.Index), y.FieldByIndex(sf.Index)) { + return false + } + } + return true + case reflect.Map: + if x.IsNil() != y.IsNil() { + return false + } + if x.Len() != y.Len() { + return false + } + if x.UnsafePointer() == y.UnsafePointer() { + return true + } + iter := x.MapRange() + for iter.Next() { + vx := iter.Value() + vy := y.MapIndex(iter.Key()) + if !vy.IsValid() || !equalValue(vx, vy) 
{ + return false + } + } + return true + case reflect.Func: + if x.Type() != y.Type() { + return false + } + if x.IsNil() && y.IsNil() { + return true + } + panic("cannot compare functions") + case reflect.String: + return x.String() == y.String() + case reflect.Bool: + return x.Bool() == y.Bool() + // Ints, uints and floats handled in jsonNumber, at top of function. + default: + panic(fmt.Sprintf("unsupported kind: %s", x.Kind())) + } +} + +// hashValue adds v to the data hashed by h. v must not have cycles. +// hashValue panics if the value contains functions or channels, or maps whose +// key type is not string. +// It ignores unexported fields of structs. +// Calls to hashValue with the equal values (in the sense +// of [Equal]) result in the same sequence of values written to the hash. +func hashValue(h *maphash.Hash, v reflect.Value) { + // TODO: replace writes of basic types with WriteComparable in 1.24. + + writeUint := func(u uint64) { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], u) + h.Write(buf[:]) + } + + var write func(reflect.Value) + write = func(v reflect.Value) { + if r, ok := jsonNumber(v); ok { + // We want 1.0 and 1 to hash the same. + // big.Rats are always normalized, so they will be. + // We could do this more efficiently by handling the int and float cases + // separately, but that's premature. + writeUint(uint64(r.Sign() + 1)) + h.Write(r.Num().Bytes()) + h.Write(r.Denom().Bytes()) + return + } + switch v.Kind() { + case reflect.Invalid: + h.WriteByte(0) + case reflect.String: + h.WriteString(v.String()) + case reflect.Bool: + if v.Bool() { + h.WriteByte(1) + } else { + h.WriteByte(0) + } + case reflect.Complex64, reflect.Complex128: + c := v.Complex() + writeUint(math.Float64bits(real(c))) + writeUint(math.Float64bits(imag(c))) + case reflect.Array, reflect.Slice: + // Although we could treat []byte more efficiently, + // JSON values are unlikely to contain them. 
+ writeUint(uint64(v.Len())) + for i := range v.Len() { + write(v.Index(i)) + } + case reflect.Interface, reflect.Pointer: + write(v.Elem()) + case reflect.Struct: + t := v.Type() + for i := range t.NumField() { + if sf := t.Field(i); sf.IsExported() { + write(v.FieldByIndex(sf.Index)) + } + } + case reflect.Map: + if v.Type().Key().Kind() != reflect.String { + panic("map with non-string key") + } + // Sort the keys so the hash is deterministic. + keys := v.MapKeys() + // Write the length. That distinguishes between, say, two consecutive + // maps with disjoint keys from one map that has the items of both. + writeUint(uint64(len(keys))) + slices.SortFunc(keys, func(x, y reflect.Value) int { return cmp.Compare(x.String(), y.String()) }) + for _, k := range keys { + write(k) + write(v.MapIndex(k)) + } + // Ints, uints and floats handled in jsonNumber, at top of function. + default: + panic(fmt.Sprintf("unsupported kind: %s", v.Kind())) + } + } + + write(v) +} + +// jsonNumber converts a numeric value or a json.Number to a [big.Rat]. +// If v is not a number, it returns nil, false. +func jsonNumber(v reflect.Value) (*big.Rat, bool) { + r := new(big.Rat) + switch { + case !v.IsValid(): + return nil, false + case v.CanInt(): + r.SetInt64(v.Int()) + case v.CanUint(): + r.SetUint64(v.Uint()) + case v.CanFloat(): + r.SetFloat64(v.Float()) + default: + jn, ok := v.Interface().(json.Number) + if !ok { + return nil, false + } + if _, ok := r.SetString(jn.String()); !ok { + // This can fail in rare cases; for example, "1e9999999". + // That is a valid JSON number, since the spec puts no limit on the size + // of the exponent. + return nil, false + } + } + return r, true +} + +// jsonType returns a string describing the type of the JSON value, +// as described in the JSON Schema specification: +// https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1. +// It returns "", false if the value is not valid JSON. 
+func jsonType(v reflect.Value) (string, bool) { + if !v.IsValid() { + // Not v.IsNil(): a nil []any is still a JSON array. + return "null", true + } + if v.CanInt() || v.CanUint() { + return "integer", true + } + if v.CanFloat() { + if _, f := math.Modf(v.Float()); f == 0 { + return "integer", true + } + return "number", true + } + switch v.Kind() { + case reflect.Bool: + return "boolean", true + case reflect.String: + return "string", true + case reflect.Slice, reflect.Array: + return "array", true + case reflect.Map, reflect.Struct: + return "object", true + default: + return "", false + } +} + +func assert(cond bool, msg string) { + if !cond { + panic("assertion failed: " + msg) + } +} + +// marshalStructWithMap marshals its first argument to JSON, treating the field named +// mapField as an embedded map. The first argument must be a pointer to +// a struct. The underlying type of mapField must be a map[string]any, and it must have +// a "-" json tag, meaning it will not be marshaled. +// +// For example, given this struct: +// +// type S struct { +// A int +// Extra map[string] any `json:"-"` +// } +// +// and this value: +// +// s := S{A: 1, Extra: map[string]any{"B": 2}} +// +// the call marshalJSONWithMap(s, "Extra") would return +// +// {"A": 1, "B": 2} +// +// It is an error if the map contains the same key as another struct field's +// JSON name. +// +// marshalStructWithMap calls json.Marshal on a value of type T, so T must not +// have a MarshalJSON method that calls this function, on pain of infinite regress. +// +// Note that there is a similar function in mcp/util.go, but they are not the same. +// Here the function requires `-` json tag, does not clear the mapField map, +// and handles embedded struct due to the implementation of jsonNames in this package. +// +// TODO: avoid this restriction on T by forcing it to marshal in a default way. +// See https://go.dev/play/p/EgXKJHxEx_R. 
+func marshalStructWithMap[T any](s *T, mapField string) ([]byte, error) { + // Marshal the struct and the map separately, and concatenate the bytes. + // This strategy is dramatically less complicated than + // constructing a synthetic struct or map with the combined keys. + if s == nil { + return []byte("null"), nil + } + s2 := *s + vMapField := reflect.ValueOf(&s2).Elem().FieldByName(mapField) + mapVal := vMapField.Interface().(map[string]any) + + // Check for duplicates. + names := jsonNames(reflect.TypeFor[T]()) + for key := range mapVal { + if names[key] { + return nil, fmt.Errorf("map key %q duplicates struct field", key) + } + } + + structBytes, err := json.Marshal(s2) + if err != nil { + return nil, fmt.Errorf("marshalStructWithMap(%+v): %w", s, err) + } + if len(mapVal) == 0 { + return structBytes, nil + } + mapBytes, err := json.Marshal(mapVal) + if err != nil { + return nil, err + } + if len(structBytes) == 2 { // must be "{}" + return mapBytes, nil + } + // "{X}" + "{Y}" => "{X,Y}" + res := append(structBytes[:len(structBytes)-1], ',') + res = append(res, mapBytes[1:]...) + return res, nil +} + +// unmarshalStructWithMap is the inverse of marshalStructWithMap. +// T has the same restrictions as in that function. +// +// Note that there is a similar function in mcp/util.go, but they are not the same. +// Here jsonNames also returns fields from embedded structs, hence this function +// handles embedded structs as well. +func unmarshalStructWithMap[T any](data []byte, v *T, mapField string) error { + // Unmarshal into the struct, ignoring unknown fields. + if err := json.Unmarshal(data, v); err != nil { + return err + } + // Unmarshal into the map. + m := map[string]any{} + if err := json.Unmarshal(data, &m); err != nil { + return err + } + // Delete from the map the fields of the struct. 
+ for n := range jsonNames(reflect.TypeFor[T]()) { + delete(m, n) + } + if len(m) != 0 { + reflect.ValueOf(v).Elem().FieldByName(mapField).Set(reflect.ValueOf(m)) + } + return nil +} + +var jsonNamesMap sync.Map // from reflect.Type to map[string]bool + +// jsonNames returns the set of JSON object keys that t will marshal into, +// including fields from embedded structs in t. +// t must be a struct type. +// +// Note that there is a similar function in mcp/util.go, but they are not the same +// Here the function recurses over embedded structs and includes fields from them. +func jsonNames(t reflect.Type) map[string]bool { + // Lock not necessary: at worst we'll duplicate work. + if val, ok := jsonNamesMap.Load(t); ok { + return val.(map[string]bool) + } + m := map[string]bool{} + for i := range t.NumField() { + field := t.Field(i) + // handle embedded structs + if field.Anonymous { + fieldType := field.Type + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + for n := range jsonNames(fieldType) { + m[n] = true + } + continue + } + info := fieldJSONInfo(field) + if !info.omit { + m[info.name] = true + } + } + jsonNamesMap.Store(t, m) + return m +} + +type jsonInfo struct { + omit bool // unexported or first tag element is "-" + name string // Go field name or first tag element. Empty if omit is true. + settings map[string]bool // "omitempty", "omitzero", etc. +} + +// fieldJSONInfo reports information about how encoding/json +// handles the given struct field. +// If the field is unexported, jsonInfo.omit is true and no other jsonInfo field +// is populated. +// If the field is exported and has no tag, then name is the field's name and all +// other fields are false. +// Otherwise, the information is obtained from the tag. 
+func fieldJSONInfo(f reflect.StructField) jsonInfo { + if !f.IsExported() { + return jsonInfo{omit: true} + } + info := jsonInfo{name: f.Name} + if tag, ok := f.Tag.Lookup("json"); ok { + name, rest, found := strings.Cut(tag, ",") + // "-" means omit, but "-," means the name is "-" + if name == "-" && !found { + return jsonInfo{omit: true} + } + if name != "" { + info.name = name + } + if len(rest) > 0 { + info.settings = map[string]bool{} + for _, s := range strings.Split(rest, ",") { + info.settings[s] = true + } + } + } + return info +} + +// wrapf wraps *errp with the given formatted message if *errp is not nil. +func wrapf(errp *error, format string, args ...any) { + if *errp != nil { + *errp = fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), *errp) + } +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/validate.go b/vendor/github.com/google/jsonschema-go/jsonschema/validate.go new file mode 100644 index 000000000..b895bbd41 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/validate.go @@ -0,0 +1,789 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import ( + "encoding/json" + "errors" + "fmt" + "hash/maphash" + "iter" + "math" + "math/big" + "reflect" + "slices" + "strings" + "sync" + "unicode/utf8" +) + +// The value of the "$schema" keyword for the version that we can validate. +const draft202012 = "https://json-schema.org/draft/2020-12/schema" + +// Validate validates the instance, which must be a JSON value, against the schema. +// It returns nil if validation is successful or an error if it is not. +// If the schema type is "object", instance can be a map[string]any or a struct. 
+func (rs *Resolved) Validate(instance any) error { + if s := rs.root.Schema; s != "" && s != draft202012 { + return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) + } + st := &state{rs: rs} + return st.validate(reflect.ValueOf(instance), st.rs.root, nil) +} + +// validateDefaults walks the schema tree. If it finds a default, it validates it +// against the schema containing it. +// +// TODO(jba): account for dynamic refs. This algorithm simple-mindedly +// treats each schema with a default as its own root. +func (rs *Resolved) validateDefaults() error { + if s := rs.root.Schema; s != "" && s != draft202012 { + return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) + } + st := &state{rs: rs} + for s := range rs.root.all() { + // We checked for nil schemas in [Schema.Resolve]. + assert(s != nil, "nil schema") + if s.DynamicRef != "" { + return fmt.Errorf("jsonschema: %s: validateDefaults does not support dynamic refs", rs.schemaString(s)) + } + if s.Default != nil { + var d any + if err := json.Unmarshal(s.Default, &d); err != nil { + return fmt.Errorf("unmarshaling default value of schema %s: %w", rs.schemaString(s), err) + } + if err := st.validate(reflect.ValueOf(d), s, nil); err != nil { + return err + } + } + } + return nil +} + +// state is the state of single call to ResolvedSchema.Validate. +type state struct { + rs *Resolved + // stack holds the schemas from recursive calls to validate. + // These are the "dynamic scopes" used to resolve dynamic references. + // https://json-schema.org/draft/2020-12/json-schema-core#scopes + stack []*Schema +} + +// validate validates the reflected value of the instance. +func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations) (err error) { + defer wrapf(&err, "validating %s", st.rs.schemaString(schema)) + + // Maintain a stack for dynamic schema resolution. 
+ st.stack = append(st.stack, schema) // push + defer func() { + st.stack = st.stack[:len(st.stack)-1] // pop + }() + + // We checked for nil schemas in [Schema.Resolve]. + assert(schema != nil, "nil schema") + + // Step through interfaces and pointers. + for instance.Kind() == reflect.Pointer || instance.Kind() == reflect.Interface { + instance = instance.Elem() + } + + schemaInfo := st.rs.resolvedInfos[schema] + + // type: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1 + if schema.Type != "" || schema.Types != nil { + gotType, ok := jsonType(instance) + if !ok { + return fmt.Errorf("type: %v of type %[1]T is not a valid JSON value", instance) + } + if schema.Type != "" { + // "number" subsumes integers + if !(gotType == schema.Type || + gotType == "integer" && schema.Type == "number") { + return fmt.Errorf("type: %v has type %q, want %q", instance, gotType, schema.Type) + } + } else { + if !(slices.Contains(schema.Types, gotType) || (gotType == "integer" && slices.Contains(schema.Types, "number"))) { + return fmt.Errorf("type: %v has type %q, want one of %q", + instance, gotType, strings.Join(schema.Types, ", ")) + } + } + } + // enum: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.2 + if schema.Enum != nil { + ok := false + for _, e := range schema.Enum { + if equalValue(reflect.ValueOf(e), instance) { + ok = true + break + } + } + if !ok { + return fmt.Errorf("enum: %v does not equal any of: %v", instance, schema.Enum) + } + } + + // const: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.3 + if schema.Const != nil { + if !equalValue(reflect.ValueOf(*schema.Const), instance) { + return fmt.Errorf("const: %v does not equal %v", instance, *schema.Const) + } + } + + // numbers: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.2 + if schema.MultipleOf != nil || schema.Minimum != nil || schema.Maximum 
!= nil || schema.ExclusiveMinimum != nil || schema.ExclusiveMaximum != nil { + n, ok := jsonNumber(instance) + if ok { // these keywords don't apply to non-numbers + if schema.MultipleOf != nil { + // TODO: validate MultipleOf as non-zero. + // The test suite assumes floats. + nf, _ := n.Float64() // don't care if it's exact or not + if _, f := math.Modf(nf / *schema.MultipleOf); f != 0 { + return fmt.Errorf("multipleOf: %s is not a multiple of %f", n, *schema.MultipleOf) + } + } + + m := new(big.Rat) // reuse for all of the following + cmp := func(f float64) int { return n.Cmp(m.SetFloat64(f)) } + + if schema.Minimum != nil && cmp(*schema.Minimum) < 0 { + return fmt.Errorf("minimum: %s is less than %f", n, *schema.Minimum) + } + if schema.Maximum != nil && cmp(*schema.Maximum) > 0 { + return fmt.Errorf("maximum: %s is greater than %f", n, *schema.Maximum) + } + if schema.ExclusiveMinimum != nil && cmp(*schema.ExclusiveMinimum) <= 0 { + return fmt.Errorf("exclusiveMinimum: %s is less than or equal to %f", n, *schema.ExclusiveMinimum) + } + if schema.ExclusiveMaximum != nil && cmp(*schema.ExclusiveMaximum) >= 0 { + return fmt.Errorf("exclusiveMaximum: %s is greater than or equal to %f", n, *schema.ExclusiveMaximum) + } + } + } + + // strings: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.3 + if instance.Kind() == reflect.String && (schema.MinLength != nil || schema.MaxLength != nil || schema.Pattern != "") { + str := instance.String() + n := utf8.RuneCountInString(str) + if schema.MinLength != nil { + if m := *schema.MinLength; n < m { + return fmt.Errorf("minLength: %q contains %d Unicode code points, fewer than %d", str, n, m) + } + } + if schema.MaxLength != nil { + if m := *schema.MaxLength; n > m { + return fmt.Errorf("maxLength: %q contains %d Unicode code points, more than %d", str, n, m) + } + } + + if schema.Pattern != "" && !schemaInfo.pattern.MatchString(str) { + return fmt.Errorf("pattern: %q does not match 
regular expression %q", str, schema.Pattern) + } + } + + var anns annotations // all the annotations for this call and child calls + + // $ref: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.1 + if schema.Ref != "" { + if err := st.validate(instance, schemaInfo.resolvedRef, &anns); err != nil { + return err + } + } + + // $dynamicRef: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2 + if schema.DynamicRef != "" { + // The ref behaves lexically or dynamically, but not both. + assert((schemaInfo.resolvedDynamicRef == nil) != (schemaInfo.dynamicRefAnchor == ""), + "DynamicRef not resolved properly") + if schemaInfo.resolvedDynamicRef != nil { + // Same as $ref. + if err := st.validate(instance, schemaInfo.resolvedDynamicRef, &anns); err != nil { + return err + } + } else { + // Dynamic behavior. + // Look for the base of the outermost schema on the stack with this dynamic + // anchor. (Yes, outermost: the one farthest from here. This the opposite + // of how ordinary dynamic variables behave.) + // Why the base of the schema being validated and not the schema itself? + // Because the base is the scope for anchors. In fact it's possible to + // refer to a schema that is not on the stack, but a child of some base + // on the stack. + // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. 
+ var dynamicSchema *Schema + for _, s := range st.stack { + base := st.rs.resolvedInfos[s].base + info, ok := st.rs.resolvedInfos[base].anchors[schemaInfo.dynamicRefAnchor] + if ok && info.dynamic { + dynamicSchema = info.schema + break + } + } + if dynamicSchema == nil { + return fmt.Errorf("missing dynamic anchor %q", schemaInfo.dynamicRefAnchor) + } + if err := st.validate(instance, dynamicSchema, &anns); err != nil { + return err + } + } + } + + // logic + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2 + // These must happen before arrays and objects because if they evaluate an item or property, + // then the unevaluatedItems/Properties schemas don't apply to it. + // See https://json-schema.org/draft/2020-12/json-schema-core#section-11.2, paragraph 4. + // + // If any of these fail, then validation fails, even if there is an unevaluatedXXX + // keyword in the schema. The spec is unclear about this, but that is the intention. + + valid := func(s *Schema, anns *annotations) bool { return st.validate(instance, s, anns) == nil } + + if schema.AllOf != nil { + for _, ss := range schema.AllOf { + if err := st.validate(instance, ss, &anns); err != nil { + return err + } + } + } + if schema.AnyOf != nil { + // We must visit them all, to collect annotations. + ok := false + for _, ss := range schema.AnyOf { + if valid(ss, &anns) { + ok = true + } + } + if !ok { + return fmt.Errorf("anyOf: did not validate against any of %v", schema.AnyOf) + } + } + if schema.OneOf != nil { + // Exactly one. + var okSchema *Schema + for _, ss := range schema.OneOf { + if valid(ss, &anns) { + if okSchema != nil { + return fmt.Errorf("oneOf: validated against both %v and %v", okSchema, ss) + } + okSchema = ss + } + } + if okSchema == nil { + return fmt.Errorf("oneOf: did not validate against any of %v", schema.OneOf) + } + } + if schema.Not != nil { + // Ignore annotations from "not". 
+ if valid(schema.Not, nil) { + return fmt.Errorf("not: validated against %v", schema.Not) + } + } + if schema.If != nil { + var ss *Schema + if valid(schema.If, &anns) { + ss = schema.Then + } else { + ss = schema.Else + } + if ss != nil { + if err := st.validate(instance, ss, &anns); err != nil { + return err + } + } + } + + // arrays + // TODO(jba): consider arrays of structs. + if instance.Kind() == reflect.Array || instance.Kind() == reflect.Slice { + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.1 + // This validate call doesn't collect annotations for the items of the instance; they are separate + // instances in their own right. + // TODO(jba): if the test suite doesn't cover this case, add a test. For example, nested arrays. + for i, ischema := range schema.PrefixItems { + if i >= instance.Len() { + break // shorter is OK + } + if err := st.validate(instance.Index(i), ischema, nil); err != nil { + return err + } + } + anns.noteEndIndex(min(len(schema.PrefixItems), instance.Len())) + + if schema.Items != nil { + for i := len(schema.PrefixItems); i < instance.Len(); i++ { + if err := st.validate(instance.Index(i), schema.Items, nil); err != nil { + return err + } + } + // Note that all the items in this array have been validated. + anns.allItems = true + } + + nContains := 0 + if schema.Contains != nil { + for i := range instance.Len() { + if err := st.validate(instance.Index(i), schema.Contains, nil); err == nil { + nContains++ + anns.noteIndex(i) + } + } + if nContains == 0 && (schema.MinContains == nil || *schema.MinContains > 0) { + return fmt.Errorf("contains: %s does not have an item matching %s", instance, schema.Contains) + } + } + + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.4 + // TODO(jba): check that these next four keywords' values are integers. 
+ if schema.MinContains != nil && schema.Contains != nil { + if m := *schema.MinContains; nContains < m { + return fmt.Errorf("minContains: contains validated %d items, less than %d", nContains, m) + } + } + if schema.MaxContains != nil && schema.Contains != nil { + if m := *schema.MaxContains; nContains > m { + return fmt.Errorf("maxContains: contains validated %d items, greater than %d", nContains, m) + } + } + if schema.MinItems != nil { + if m := *schema.MinItems; instance.Len() < m { + return fmt.Errorf("minItems: array length %d is less than %d", instance.Len(), m) + } + } + if schema.MaxItems != nil { + if m := *schema.MaxItems; instance.Len() > m { + return fmt.Errorf("maxItems: array length %d is greater than %d", instance.Len(), m) + } + } + if schema.UniqueItems { + if instance.Len() > 1 { + // Hash each item and compare the hashes. + // If two hashes differ, the items differ. + // If two hashes are the same, compare the collisions for equality. + // (The same logic as hash table lookup.) + // TODO(jba): Use container/hash.Map when it becomes available (https://go.dev/issue/69559), + hashes := map[uint64][]int{} // from hash to indices + seed := maphash.MakeSeed() + for i := range instance.Len() { + item := instance.Index(i) + var h maphash.Hash + h.SetSeed(seed) + hashValue(&h, item) + hv := h.Sum64() + if sames := hashes[hv]; len(sames) > 0 { + for _, j := range sames { + if equalValue(item, instance.Index(j)) { + return fmt.Errorf("uniqueItems: array items %d and %d are equal", i, j) + } + } + } + hashes[hv] = append(hashes[hv], i) + } + } + } + + // https://json-schema.org/draft/2020-12/json-schema-core#section-11.2 + if schema.UnevaluatedItems != nil && !anns.allItems { + // Apply this subschema to all items in the array that haven't been successfully validated. + // That includes validations by subschemas on the same instance, like allOf. 
+ for i := anns.endIndex; i < instance.Len(); i++ { + if !anns.evaluatedIndexes[i] { + if err := st.validate(instance.Index(i), schema.UnevaluatedItems, nil); err != nil { + return err + } + } + } + anns.allItems = true + } + } + + // objects + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.2 + // Validating structs is problematic. See https://github.com/google/jsonschema-go/issues/23. + if instance.Kind() == reflect.Struct { + return errors.New("cannot validate against a struct; see https://github.com/google/jsonschema-go/issues/23 for details") + } + if instance.Kind() == reflect.Map { + if kt := instance.Type().Key(); kt.Kind() != reflect.String { + return fmt.Errorf("map key type %s is not a string", kt) + } + // Track the evaluated properties for just this schema, to support additionalProperties. + // If we used anns here, then we'd be including properties evaluated in subschemas + // from allOf, etc., which additionalProperties shouldn't observe. + evalProps := map[string]bool{} + for prop, subschema := range schema.Properties { + val := property(instance, prop) + if !val.IsValid() { + // It's OK if the instance doesn't have the property. + continue + } + // If the instance is a struct and an optional property has the zero + // value, then we could interpret it as present or missing. Be generous: + // assume it's missing, and thus always validates successfully. + if instance.Kind() == reflect.Struct && val.IsZero() && !schemaInfo.isRequired[prop] { + continue + } + if err := st.validate(val, subschema, nil); err != nil { + return err + } + evalProps[prop] = true + } + if len(schema.PatternProperties) > 0 { + for prop, val := range properties(instance) { + // Check every matching pattern. 
+ for re, schema := range schemaInfo.patternProperties { + if re.MatchString(prop) { + if err := st.validate(val, schema, nil); err != nil { + return err + } + evalProps[prop] = true + } + } + } + } + if schema.AdditionalProperties != nil { + // Special case for a better error message when additional properties is + // 'falsy' + // + // If additionalProperties is {"not":{}} (which is how we + // unmarshal "false"), we can produce a better error message that + // summarizes all the extra properties. Otherwise, we fall back to the + // default validation. + // + // Note: this is much faster than comparing with falseSchema using Equal. + isFalsy := schema.AdditionalProperties.Not != nil && reflect.ValueOf(*schema.AdditionalProperties.Not).IsZero() + if isFalsy { + var disallowed []string + for prop := range properties(instance) { + if !evalProps[prop] { + disallowed = append(disallowed, prop) + } + } + if len(disallowed) > 0 { + return fmt.Errorf("unexpected additional properties %q", disallowed) + } + } else { + // Apply to all properties not handled above. + for prop, val := range properties(instance) { + if !evalProps[prop] { + if err := st.validate(val, schema.AdditionalProperties, nil); err != nil { + return err + } + evalProps[prop] = true + } + } + } + } + anns.noteProperties(evalProps) + if schema.PropertyNames != nil { + // Note: properties unnecessarily fetches each value. We could define a propertyNames function + // if performance ever matters. 
+ for prop := range properties(instance) { + if err := st.validate(reflect.ValueOf(prop), schema.PropertyNames, nil); err != nil { + return err + } + } + } + + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.5 + var min, max int + if schema.MinProperties != nil || schema.MaxProperties != nil { + min, max = numPropertiesBounds(instance, schemaInfo.isRequired) + } + if schema.MinProperties != nil { + if n, m := max, *schema.MinProperties; n < m { + return fmt.Errorf("minProperties: object has %d properties, less than %d", n, m) + } + } + if schema.MaxProperties != nil { + if n, m := min, *schema.MaxProperties; n > m { + return fmt.Errorf("maxProperties: object has %d properties, greater than %d", n, m) + } + } + + hasProperty := func(prop string) bool { + return property(instance, prop).IsValid() + } + + missingProperties := func(props []string) []string { + var missing []string + for _, p := range props { + if !hasProperty(p) { + missing = append(missing, p) + } + } + return missing + } + + if schema.Required != nil { + if m := missingProperties(schema.Required); len(m) > 0 { + return fmt.Errorf("required: missing properties: %q", m) + } + } + if schema.DependentRequired != nil { + // "Validation succeeds if, for each name that appears in both the instance + // and as a name within this keyword's value, every item in the corresponding + // array is also the name of a property in the instance." §6.5.4 + for dprop, reqs := range schema.DependentRequired { + if hasProperty(dprop) { + if m := missingProperties(reqs); len(m) > 0 { + return fmt.Errorf("dependentRequired[%q]: missing properties %q", dprop, m) + } + } + } + } + + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2.2.4 + if schema.DependentSchemas != nil { + // This does not collect annotations, although it seems like it should. 
+ for dprop, ss := range schema.DependentSchemas { + if hasProperty(dprop) { + // TODO: include dependentSchemas[dprop] in the errors. + err := st.validate(instance, ss, &anns) + if err != nil { + return err + } + } + } + } + if schema.UnevaluatedProperties != nil && !anns.allProperties { + // This looks a lot like AdditionalProperties, but depends on in-place keywords like allOf + // in addition to sibling keywords. + for prop, val := range properties(instance) { + if !anns.evaluatedProperties[prop] { + if err := st.validate(val, schema.UnevaluatedProperties, nil); err != nil { + return err + } + } + } + // The spec says the annotation should be the set of evaluated properties, but we can optimize + // by setting a single boolean, since after this succeeds all properties will be validated. + // See https://json-schema.slack.com/archives/CT7FF623C/p1745592564381459. + anns.allProperties = true + } + } + + if callerAnns != nil { + // Our caller wants to know what we've validated. + callerAnns.merge(&anns) + } + return nil +} + +// resolveDynamicRef returns the schema referred to by the argument schema's +// $dynamicRef value. +// It returns an error if the dynamic reference has no referent. +// If there is no $dynamicRef, resolveDynamicRef returns nil, nil. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2. +func (st *state) resolveDynamicRef(schema *Schema) (*Schema, error) { + if schema.DynamicRef == "" { + return nil, nil + } + info := st.rs.resolvedInfos[schema] + // The ref behaves lexically or dynamically, but not both. + assert((info.resolvedDynamicRef == nil) != (info.dynamicRefAnchor == ""), + "DynamicRef not statically resolved properly") + if r := info.resolvedDynamicRef; r != nil { + // Same as $ref. + return r, nil + } + // Dynamic behavior. + // Look for the base of the outermost schema on the stack with this dynamic + // anchor. (Yes, outermost: the one farthest from here. 
This the opposite + // of how ordinary dynamic variables behave.) + // Why the base of the schema being validated and not the schema itself? + // Because the base is the scope for anchors. In fact it's possible to + // refer to a schema that is not on the stack, but a child of some base + // on the stack. + // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. + for _, s := range st.stack { + base := st.rs.resolvedInfos[s].base + info, ok := st.rs.resolvedInfos[base].anchors[info.dynamicRefAnchor] + if ok && info.dynamic { + return info.schema, nil + } + } + return nil, fmt.Errorf("missing dynamic anchor %q", info.dynamicRefAnchor) +} + +// ApplyDefaults modifies an instance by applying the schema's defaults to it. If +// a schema or sub-schema has a default, then a corresponding zero instance value +// is set to the default. +// +// The JSON Schema specification does not describe how defaults should be interpreted. +// This method honors defaults only on properties, and only those that are not required. +// If the instance is a map and the property is missing, the property is added to +// the map with the default. +// If the instance is a struct, the field corresponding to the property exists, and +// its value is zero, the field is set to the default. +// ApplyDefaults can panic if a default cannot be assigned to a field. +// +// The argument must be a pointer to the instance. +// (In case we decide that top-level defaults are meaningful.) +// +// It is recommended to first call Resolve with a ValidateDefaults option of true, +// then call this method, and lastly call Validate. +func (rs *Resolved) ApplyDefaults(instancep any) error { + // TODO(jba): consider what defaults on top-level or array instances might mean. + // TODO(jba): follow $ref and $dynamicRef + // TODO(jba): apply defaults on sub-schemas to corresponding sub-instances. 
+ st := &state{rs: rs} + return st.applyDefaults(reflect.ValueOf(instancep), rs.root) +} + +// Leave this as a potentially recursive helper function, because we'll surely want +// to apply defaults on sub-schemas someday. +func (st *state) applyDefaults(instancep reflect.Value, schema *Schema) (err error) { + defer wrapf(&err, "applyDefaults: schema %s, instance %v", st.rs.schemaString(schema), instancep) + + schemaInfo := st.rs.resolvedInfos[schema] + instance := instancep.Elem() + if instance.Kind() == reflect.Interface && instance.IsValid() { + // If we unmarshalled into 'any', the default object unmarshalling will be map[string]any. + instance = instance.Elem() + } + if instance.Kind() == reflect.Map || instance.Kind() == reflect.Struct { + if instance.Kind() == reflect.Map { + if kt := instance.Type().Key(); kt.Kind() != reflect.String { + return fmt.Errorf("map key type %s is not a string", kt) + } + } + for prop, subschema := range schema.Properties { + // Ignore defaults on required properties. (A required property shouldn't have a default.) + if schemaInfo.isRequired[prop] { + continue + } + val := property(instance, prop) + switch instance.Kind() { + case reflect.Map: + // If there is a default for this property, and the map key is missing, + // set the map value to the default. + if subschema.Default != nil && !val.IsValid() { + // Create an lvalue, since map values aren't addressable. + lvalue := reflect.New(instance.Type().Elem()) + if err := json.Unmarshal(subschema.Default, lvalue.Interface()); err != nil { + return err + } + instance.SetMapIndex(reflect.ValueOf(prop), lvalue.Elem()) + } + case reflect.Struct: + // If there is a default for this property, and the field exists but is zero, + // set the field to the default. 
+ if subschema.Default != nil && val.IsValid() && val.IsZero() { + if err := json.Unmarshal(subschema.Default, val.Addr().Interface()); err != nil { + return err + } + } + default: + panic(fmt.Sprintf("applyDefaults: property %s: bad value %s of kind %s", + prop, instance, instance.Kind())) + } + } + } + return nil +} + +// property returns the value of the property of v with the given name, or the invalid +// reflect.Value if there is none. +// If v is a map, the property is the value of the map whose key is name. +// If v is a struct, the property is the value of the field with the given name according +// to the encoding/json package (see [jsonName]). +// If v is anything else, property panics. +func property(v reflect.Value, name string) reflect.Value { + switch v.Kind() { + case reflect.Map: + return v.MapIndex(reflect.ValueOf(name)) + case reflect.Struct: + props := structPropertiesOf(v.Type()) + // Ignore nonexistent properties. + if sf, ok := props[name]; ok { + return v.FieldByIndex(sf.Index) + } + return reflect.Value{} + default: + panic(fmt.Sprintf("property(%q): bad value %s of kind %s", name, v, v.Kind())) + } +} + +// properties returns an iterator over the names and values of all properties +// in v, which must be a map or a struct. +// If a struct, zero-valued properties that are marked omitempty or omitzero +// are excluded. 
+func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { + return func(yield func(string, reflect.Value) bool) { + switch v.Kind() { + case reflect.Map: + for k, e := range v.Seq2() { + if !yield(k.String(), e) { + return + } + } + case reflect.Struct: + for name, sf := range structPropertiesOf(v.Type()) { + val := v.FieldByIndex(sf.Index) + if val.IsZero() { + info := fieldJSONInfo(sf) + if info.settings["omitempty"] || info.settings["omitzero"] { + continue + } + } + if !yield(name, val) { + return + } + } + default: + panic(fmt.Sprintf("bad value %s of kind %s", v, v.Kind())) + } + } +} + +// numPropertiesBounds returns bounds on the number of v's properties. +// v must be a map or a struct. +// If v is a map, both bounds are the map's size. +// If v is a struct, the max is the number of struct properties. +// But since we don't know whether a zero value indicates a missing optional property +// or not, be generous and use the number of non-zero properties as the min. +func numPropertiesBounds(v reflect.Value, isRequired map[string]bool) (int, int) { + switch v.Kind() { + case reflect.Map: + return v.Len(), v.Len() + case reflect.Struct: + sp := structPropertiesOf(v.Type()) + min := 0 + for prop, sf := range sp { + if !v.FieldByIndex(sf.Index).IsZero() || isRequired[prop] { + min++ + } + } + return min, len(sp) + default: + panic(fmt.Sprintf("properties: bad value: %s of kind %s", v, v.Kind())) + } +} + +// A propertyMap is a map from property name to struct field index. +type propertyMap = map[string]reflect.StructField + +var structProperties sync.Map // from reflect.Type to propertyMap + +// structPropertiesOf returns the JSON Schema properties for the struct type t. +// The caller must not mutate the result. +func structPropertiesOf(t reflect.Type) propertyMap { + // Mutex not necessary: at worst we'll recompute the same value. 
+ if props, ok := structProperties.Load(t); ok { + return props.(propertyMap) + } + props := map[string]reflect.StructField{} + for _, sf := range reflect.VisibleFields(t) { + if sf.Anonymous { + continue + } + info := fieldJSONInfo(sf) + if !info.omit { + props[info.name] = sf + } + } + structProperties.Store(t, props) + return props +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE b/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE new file mode 100644 index 000000000..508be9266 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Go MCP SDK Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go b/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go new file mode 100644 index 000000000..87665121c --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go @@ -0,0 +1,168 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "slices" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// TokenInfo holds information from a bearer token. +type TokenInfo struct { + Scopes []string + Expiration time.Time + // UserID is an optional identifier for the authenticated user. + // If set by a TokenVerifier, it can be used by transports to prevent + // session hijacking by ensuring that all requests for a given session + // come from the same user. + UserID string + // TODO: add standard JWT fields + Extra map[string]any +} + +// The error that a TokenVerifier should return if the token cannot be verified. +var ErrInvalidToken = errors.New("invalid token") + +// The error that a TokenVerifier should return for OAuth-specific protocol errors. +var ErrOAuth = errors.New("oauth error") + +// A TokenVerifier checks the validity of a bearer token, and extracts information +// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken. +// The HTTP request is provided in case verifying the token involves checking it. +type TokenVerifier func(ctx context.Context, token string, req *http.Request) (*TokenInfo, error) + +// RequireBearerTokenOptions are options for [RequireBearerToken]. +type RequireBearerTokenOptions struct { + // The URL for the resource server metadata OAuth flow, to be returned as part + // of the WWW-Authenticate header. + ResourceMetadataURL string + // The required scopes. 
+ Scopes []string +} + +type tokenInfoKey struct{} + +// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none. +func TokenInfoFromContext(ctx context.Context) *TokenInfo { + ti := ctx.Value(tokenInfoKey{}) + if ti == nil { + return nil + } + return ti.(*TokenInfo) +} + +// RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier. +// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds. +// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header +// is populated to enable [protected resource metadata]. +// +// [protected resource metadata]: https://datatracker.ietf.org/doc/rfc9728 +func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) func(http.Handler) http.Handler { + // Based on typescript-sdk/src/server/auth/middleware/bearerAuth.ts. + + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenInfo, errmsg, code := verify(r, verifier, opts) + if code != 0 { + if code == http.StatusUnauthorized || code == http.StatusForbidden { + if opts != nil && opts.ResourceMetadataURL != "" { + w.Header().Add("WWW-Authenticate", "Bearer resource_metadata="+opts.ResourceMetadataURL) + } + } + http.Error(w, errmsg, code) + return + } + r = r.WithContext(context.WithValue(r.Context(), tokenInfoKey{}, tokenInfo)) + handler.ServeHTTP(w, r) + }) + } +} + +func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenOptions) (_ *TokenInfo, errmsg string, code int) { + // Extract bearer token. + authHeader := req.Header.Get("Authorization") + fields := strings.Fields(authHeader) + if len(fields) != 2 || strings.ToLower(fields[0]) != "bearer" { + return nil, "no bearer token", http.StatusUnauthorized + } + + // Verify the token and get information from it. 
+ tokenInfo, err := verifier(req.Context(), fields[1], req) + if err != nil { + if errors.Is(err, ErrInvalidToken) { + return nil, err.Error(), http.StatusUnauthorized + } + if errors.Is(err, ErrOAuth) { + return nil, err.Error(), http.StatusBadRequest + } + return nil, err.Error(), http.StatusInternalServerError + } + + // Check scopes. All must be present. + if opts != nil { + // Note: quadratic, but N is small. + for _, s := range opts.Scopes { + if !slices.Contains(tokenInfo.Scopes, s) { + return nil, "insufficient scope", http.StatusForbidden + } + } + } + + // Check expiration. + if tokenInfo.Expiration.IsZero() { + return nil, "token missing expiration", http.StatusUnauthorized + } + if tokenInfo.Expiration.Before(time.Now()) { + return nil, "token expired", http.StatusUnauthorized + } + return tokenInfo, "", 0 +} + +// ProtectedResourceMetadataHandler returns an http.Handler that serves OAuth 2.0 +// protected resource metadata (RFC 9728) with CORS support. +// +// This handler allows cross-origin requests from any origin (Access-Control-Allow-Origin: *) +// because OAuth metadata is public information intended for client discovery (RFC 9728 §3.1). +// The metadata contains only non-sensitive configuration data about authorization servers +// and supported scopes. +// +// No validation of metadata fields is performed; ensure metadata accuracy at configuration time. +// +// For more sophisticated CORS policies or to restrict origins, wrap this handler with a +// CORS middleware like github.com/rs/cors or github.com/jub0bs/cors. +func ProtectedResourceMetadataHandler(metadata *oauthex.ProtectedResourceMetadata) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Set CORS headers for cross-origin client discovery. + // OAuth metadata is public information, so allowing any origin is safe. 
+ w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + // Handle CORS preflight requests + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + // Only GET allowed for metadata retrieval + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(metadata); err != nil { + http.Error(w, "Failed to encode metadata", http.StatusInternalServerError) + return + } + }) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go b/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go new file mode 100644 index 000000000..acadc51be --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go @@ -0,0 +1,123 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "bytes" + "errors" + "io" + "net/http" + "sync" + + "golang.org/x/oauth2" +) + +// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization +// is approved, or an error if not. +// The handler receives the HTTP request and response that triggered the authentication flow. +// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. +type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) + +// HTTPTransport is an [http.RoundTripper] that follows the MCP +// OAuth protocol when it encounters a 401 Unauthorized response. 
+type HTTPTransport struct { + handler OAuthHandler + mu sync.Mutex // protects opts.Base + opts HTTPTransportOptions +} + +// NewHTTPTransport returns a new [*HTTPTransport]. +// The handler is invoked when an HTTP request results in a 401 Unauthorized status. +// It is called only once per transport. Once a TokenSource is obtained, it is used +// for the lifetime of the transport; subsequent 401s are not processed. +func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) { + if handler == nil { + return nil, errors.New("handler cannot be nil") + } + t := &HTTPTransport{ + handler: handler, + } + if opts != nil { + t.opts = *opts + } + if t.opts.Base == nil { + t.opts.Base = http.DefaultTransport + } + return t, nil +} + +// HTTPTransportOptions are options to [NewHTTPTransport]. +type HTTPTransportOptions struct { + // Base is the [http.RoundTripper] to use. + // If nil, [http.DefaultTransport] is used. + Base http.RoundTripper +} + +func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + base := t.opts.Base + t.mu.Unlock() + + var ( + // If haveBody is set, the request has a nontrivial body, and we need avoid + // reading (or closing) it multiple times. In that case, bodyBytes is its + // content. + haveBody bool + bodyBytes []byte + ) + if req.Body != nil && req.Body != http.NoBody { + // if we're setting Body, we must mutate first. + req = req.Clone(req.Context()) + haveBody = true + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + // Now that we've read the request body, http.RoundTripper requires that we + // close it. 
+ req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + resp, err := base.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if _, ok := base.(*oauth2.Transport); ok { + // We failed to authorize even with a token source; give up. + return resp, nil + } + + resp.Body.Close() + // Try to authorize. + t.mu.Lock() + defer t.mu.Unlock() + // If we don't have a token source, get one by following the OAuth flow. + // (We may have obtained one while t.mu was not held above.) + // TODO: We hold the lock for the entire OAuth flow. This could be a long + // time. Is there a better way? + if _, ok := t.opts.Base.(*oauth2.Transport); !ok { + ts, err := t.handler(req, resp) + if err != nil { + return nil, err + } + t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} + } + + // If we don't have a body, the request is reusable, though it will be cloned + // by the base. However, if we've had to read the body, we must clone. + if haveBody { + req = req.Clone(req.Context()) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + return t.opts.Base.RoundTrip(req) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go new file mode 100644 index 000000000..627ffe7b6 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go @@ -0,0 +1,841 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" +) + +// Binder builds a connection configuration. +// This may be used in servers to generate a new configuration per connection. 
+// ConnectionOptions itself implements Binder returning itself unmodified, to +// allow for the simple cases where no per connection information is needed. +type Binder interface { + // Bind returns the ConnectionOptions to use when establishing the passed-in + // Connection. + // + // The connection is not ready to use when Bind is called, + // but Bind may close it without reading or writing to it. + Bind(context.Context, *Connection) ConnectionOptions +} + +// A BinderFunc implements the Binder interface for a standalone Bind function. +type BinderFunc func(context.Context, *Connection) ConnectionOptions + +func (f BinderFunc) Bind(ctx context.Context, c *Connection) ConnectionOptions { + return f(ctx, c) +} + +var _ Binder = BinderFunc(nil) + +// ConnectionOptions holds the options for new connections. +type ConnectionOptions struct { + // Framer allows control over the message framing and encoding. + // If nil, HeaderFramer will be used. + Framer Framer + // Preempter allows registration of a pre-queue message handler. + // If nil, no messages will be preempted. + Preempter Preempter + // Handler is used as the queued message handler for inbound messages. + // If nil, all responses will be ErrNotHandled. + Handler Handler + // OnInternalError, if non-nil, is called with any internal errors that occur + // while serving the connection, such as protocol errors or invariant + // violations. (If nil, internal errors result in panics.) + OnInternalError func(error) +} + +// Connection manages the jsonrpc2 protocol, connecting responses back to their +// calls. Connection is bidirectional; it does not have a designated server or +// client end. +// +// Note that the word 'Connection' is overloaded: the mcp.Connection represents +// the bidirectional stream of messages between client an server. The +// jsonrpc2.Connection layers RPC logic on top of that stream, dispatching RPC +// handlers, and correlating requests with responses from the peer. 
+// +// Some of the complexity of the Connection type is grown out of its usage in +// gopls: it could probably be simplified based on our usage in MCP. +type Connection struct { + seq int64 // must only be accessed using atomic operations + + stateMu sync.Mutex + state inFlightState // accessed only in updateInFlight + done chan struct{} // closed (under stateMu) when state.closed is true and all goroutines have completed + + writer Writer + handler Handler + + onInternalError func(error) + onDone func() +} + +// inFlightState records the state of the incoming and outgoing calls on a +// Connection. +type inFlightState struct { + connClosing bool // true when the Connection's Close method has been called + reading bool // true while the readIncoming goroutine is running + readErr error // non-nil when the readIncoming goroutine exits (typically io.EOF) + writeErr error // non-nil if a call to the Writer has failed with a non-canceled Context + + // closer shuts down and cleans up the Reader and Writer state, ideally + // interrupting any Read or Write call that is currently blocked. It is closed + // when the state is idle and one of: connClosing is true, readErr is non-nil, + // or writeErr is non-nil. + // + // After the closer has been invoked, the closer field is set to nil + // and the closeErr field is simultaneously set to its result. + closer io.Closer + closeErr error // error returned from closer.Close + + outgoingCalls map[ID]*AsyncCall // calls only + outgoingNotifications int // # of notifications awaiting "write" + + // incoming stores the total number of incoming calls and notifications + // that have not yet written or processed a result. + incoming int + + incomingByID map[ID]*incomingRequest // calls only + + // handlerQueue stores the backlog of calls and notifications that were not + // already handled by a preempter. + // The queue does not include the request currently being handled (if any). 
+ handlerQueue []*incomingRequest + handlerRunning bool +} + +// updateInFlight locks the state of the connection's in-flight requests, allows +// f to mutate that state, and closes the connection if it is idle and either +// is closing or has a read or write error. +func (c *Connection) updateInFlight(f func(*inFlightState)) { + c.stateMu.Lock() + defer c.stateMu.Unlock() + + s := &c.state + + f(s) + + select { + case <-c.done: + // The connection was already completely done at the start of this call to + // updateInFlight, so it must remain so. (The call to f should have noticed + // that and avoided making any updates that would cause the state to be + // non-idle.) + if !s.idle() { + panic("jsonrpc2: updateInFlight transitioned to non-idle when already done") + } + return + default: + } + + if s.idle() && s.shuttingDown(ErrUnknown) != nil { + if s.closer != nil { + s.closeErr = s.closer.Close() + s.closer = nil // prevent duplicate Close calls + } + if s.reading { + // The readIncoming goroutine is still running. Our call to Close should + // cause it to exit soon, at which point it will make another call to + // updateInFlight, set s.reading to false, and mark the Connection done. + } else { + // The readIncoming goroutine has exited, or never started to begin with. + // Since everything else is idle, we're completely done. + if c.onDone != nil { + c.onDone() + } + close(c.done) + } + } +} + +// idle reports whether the connection is in a state with no pending calls or +// notifications. +// +// If idle returns true, the readIncoming goroutine may still be running, +// but no other goroutines are doing work on behalf of the connection. +func (s *inFlightState) idle() bool { + return len(s.outgoingCalls) == 0 && s.outgoingNotifications == 0 && s.incoming == 0 && !s.handlerRunning +} + +// shuttingDown reports whether the connection is in a state that should +// disallow new (incoming and outgoing) calls. 
It returns either nil or +// an error that is or wraps the provided errClosing. +func (s *inFlightState) shuttingDown(errClosing error) error { + if s.connClosing { + // If Close has been called explicitly, it doesn't matter what state the + // Reader and Writer are in: we shouldn't be starting new work because the + // caller told us not to start new work. + return errClosing + } + if s.readErr != nil { + // If the read side of the connection is broken, we cannot read new call + // requests, and cannot read responses to our outgoing calls. + return fmt.Errorf("%w: %v", errClosing, s.readErr) + } + if s.writeErr != nil { + // If the write side of the connection is broken, we cannot write responses + // for incoming calls, and cannot write requests for outgoing calls. + return fmt.Errorf("%w: %v", errClosing, s.writeErr) + } + return nil +} + +// incomingRequest is used to track an incoming request as it is being handled +type incomingRequest struct { + *Request // the request being processed + ctx context.Context + cancel context.CancelFunc +} + +// Bind returns the options unmodified. +func (o ConnectionOptions) Bind(context.Context, *Connection) ConnectionOptions { + return o +} + +// A ConnectionConfig configures a bidirectional jsonrpc2 connection. +type ConnectionConfig struct { + Reader Reader // required + Writer Writer // required + Closer io.Closer // required + Preempter Preempter // optional + Bind func(*Connection) Handler // required + OnDone func() // optional + OnInternalError func(error) // optional +} + +// NewConnection creates a new [Connection] object and starts processing +// incoming messages. 
+func NewConnection(ctx context.Context, cfg ConnectionConfig) *Connection { + ctx = notDone{ctx} + + c := &Connection{ + state: inFlightState{closer: cfg.Closer}, + done: make(chan struct{}), + writer: cfg.Writer, + onDone: cfg.OnDone, + onInternalError: cfg.OnInternalError, + } + c.handler = cfg.Bind(c) + c.start(ctx, cfg.Reader, cfg.Preempter) + return c +} + +// bindConnection creates a new connection and runs it. +// +// This is used by the Dial and Serve functions to build the actual connection. +// +// The connection is closed automatically (and its resources cleaned up) when +// the last request has completed after the underlying ReadWriteCloser breaks, +// but it may be stopped earlier by calling Close (for a clean shutdown). +func bindConnection(bindCtx context.Context, rwc io.ReadWriteCloser, binder Binder, onDone func()) *Connection { + // TODO: Should we create a new event span here? + // This will propagate cancellation from ctx; should it? + ctx := notDone{bindCtx} + + c := &Connection{ + state: inFlightState{closer: rwc}, + done: make(chan struct{}), + onDone: onDone, + } + // It's tempting to set a finalizer on c to verify that the state has gone + // idle when the connection becomes unreachable. Unfortunately, the Binder + // interface makes that unsafe: it allows the Handler to close over the + // Connection, which could create a reference cycle that would cause the + // Connection to become uncollectable. 
+ + options := binder.Bind(bindCtx, c) + framer := options.Framer + if framer == nil { + framer = HeaderFramer() + } + c.handler = options.Handler + if c.handler == nil { + c.handler = defaultHandler{} + } + c.onInternalError = options.OnInternalError + + c.writer = framer.Writer(rwc) + reader := framer.Reader(rwc) + c.start(ctx, reader, options.Preempter) + return c +} + +func (c *Connection) start(ctx context.Context, reader Reader, preempter Preempter) { + c.updateInFlight(func(s *inFlightState) { + select { + case <-c.done: + // Bind already closed the connection; don't start a goroutine to read it. + return + default: + } + + // The goroutine started here will continue until the underlying stream is closed. + // + // (If the Binder closed the Connection already, this should error out and + // return almost immediately.) + s.reading = true + go c.readIncoming(ctx, reader, preempter) + }) +} + +// Notify invokes the target method but does not wait for a response. +// The params will be marshaled to JSON before sending over the wire, and will +// be handed to the method invoked. +func (c *Connection) Notify(ctx context.Context, method string, params any) (err error) { + attempted := false + + defer func() { + if attempted { + c.updateInFlight(func(s *inFlightState) { + s.outgoingNotifications-- + }) + } + }() + + c.updateInFlight(func(s *inFlightState) { + // If the connection is shutting down, allow outgoing notifications only if + // there is at least one call still in flight. The number of calls in flight + // cannot increase once shutdown begins, and allowing outgoing notifications + // may permit notifications that will cancel in-flight calls. 
+ if len(s.outgoingCalls) == 0 && len(s.incomingByID) == 0 { + err = s.shuttingDown(ErrClientClosing) + if err != nil { + return + } + } + s.outgoingNotifications++ + attempted = true + }) + if err != nil { + return err + } + + notify, err := NewNotification(method, params) + if err != nil { + return fmt.Errorf("marshaling notify parameters: %v", err) + } + + return c.write(ctx, notify) +} + +// Call invokes the target method and returns an object that can be used to await the response. +// The params will be marshaled to JSON before sending over the wire, and will +// be handed to the method invoked. +// You do not have to wait for the response, it can just be ignored if not needed. +// If sending the call failed, the response will be ready and have the error in it. +func (c *Connection) Call(ctx context.Context, method string, params any) *AsyncCall { + // Generate a new request identifier. + id := Int64ID(atomic.AddInt64(&c.seq, 1)) + + ac := &AsyncCall{ + id: id, + ready: make(chan struct{}), + } + // When this method returns, either ac is retired, or the request has been + // written successfully and the call is awaiting a response (to be provided by + // the readIncoming goroutine). + + call, err := NewCall(ac.id, method, params) + if err != nil { + ac.retire(&Response{ID: id, Error: fmt.Errorf("marshaling call parameters: %w", err)}) + return ac + } + + c.updateInFlight(func(s *inFlightState) { + err = s.shuttingDown(ErrClientClosing) + if err != nil { + return + } + if s.outgoingCalls == nil { + s.outgoingCalls = make(map[ID]*AsyncCall) + } + s.outgoingCalls[ac.id] = ac + }) + if err != nil { + ac.retire(&Response{ID: id, Error: err}) + return ac + } + + if err := c.write(ctx, call); err != nil { + // Sending failed. We will never get a response, so deliver a fake one if it + // wasn't already retired by the connection breaking. + c.Retire(ac, err) + } + return ac +} + +// Retire stops tracking the call, and reports err as its terminal error. 
+// +// Retire is safe to call multiple times: if the call is already no longer +// tracked, Retire is a no op. +func (c *Connection) Retire(ac *AsyncCall, err error) { + c.updateInFlight(func(s *inFlightState) { + if s.outgoingCalls[ac.id] == ac { + delete(s.outgoingCalls, ac.id) + ac.retire(&Response{ID: ac.id, Error: err}) + } else { + // ac was already retired elsewhere. + } + }) +} + +// Async, signals that the current jsonrpc2 request may be handled +// asynchronously to subsequent requests, when ctx is the request context. +// +// Async must be called at most once on each request's context (and its +// descendants). +func Async(ctx context.Context) { + if r, ok := ctx.Value(asyncKey).(*releaser); ok { + r.release(false) + } +} + +type asyncKeyType struct{} + +var asyncKey = asyncKeyType{} + +// A releaser implements concurrency safe 'releasing' of async requests. (A +// request is released when it is allowed to run concurrent with other +// requests, via a call to [Async].) +type releaser struct { + mu sync.Mutex + ch chan struct{} + released bool +} + +// release closes the associated channel. If soft is set, multiple calls to +// release are allowed. +func (r *releaser) release(soft bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.released { + if !soft { + panic("jsonrpc2.Async called multiple times") + } + } else { + close(r.ch) + r.released = true + } +} + +type AsyncCall struct { + id ID + ready chan struct{} // closed after response has been set + response *Response +} + +// ID used for this call. +// This can be used to cancel the call if needed. +func (ac *AsyncCall) ID() ID { return ac.id } + +// IsReady can be used to check if the result is already prepared. +// This is guaranteed to return true on a result for which Await has already +// returned, or a call that failed to send in the first place. 
+func (ac *AsyncCall) IsReady() bool { + select { + case <-ac.ready: + return true + default: + return false + } +} + +// retire processes the response to the call. +// +// It is an error to call retire more than once: retire is guarded by the +// connection's outgoingCalls map. +func (ac *AsyncCall) retire(response *Response) { + select { + case <-ac.ready: + panic(fmt.Sprintf("jsonrpc2: retire called twice for ID %v", ac.id)) + default: + } + + ac.response = response + close(ac.ready) +} + +// Await waits for (and decodes) the results of a Call. +// The response will be unmarshaled from JSON into the result. +// +// If the call is cancelled due to context cancellation, the result is +// ctx.Err(). +func (ac *AsyncCall) Await(ctx context.Context, result any) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ac.ready: + } + if ac.response.Error != nil { + return ac.response.Error + } + if result == nil { + return nil + } + return json.Unmarshal(ac.response.Result, result) +} + +// Cancel cancels the Context passed to the Handle call for the inbound message +// with the given ID. +// +// Cancel will not complain if the ID is not a currently active message, and it +// will not cause any messages that have not arrived yet with that ID to be +// cancelled. +func (c *Connection) Cancel(id ID) { + var req *incomingRequest + c.updateInFlight(func(s *inFlightState) { + req = s.incomingByID[id] + }) + if req != nil { + req.cancel() + } +} + +// Wait blocks until the connection is fully closed, but does not close it. +func (c *Connection) Wait() error { + return c.wait(true) +} + +// wait for the connection to close, and aggregates the most cause of its +// termination, if abnormal. +// +// The fromWait argument allows this logic to be shared with Close, where we +// only want to expose the closeErr. +// +// (Previously, Wait also only returned the closeErr, which was misleading if +// the connection was broken for another reason). 
+func (c *Connection) wait(fromWait bool) error { + var err error + <-c.done + c.updateInFlight(func(s *inFlightState) { + if fromWait { + if !errors.Is(s.readErr, io.EOF) { + err = s.readErr + } + if err == nil && !errors.Is(s.writeErr, io.EOF) { + err = s.writeErr + } + } + if err == nil { + err = s.closeErr + } + }) + return err +} + +// Close stops accepting new requests, waits for in-flight requests and enqueued +// Handle calls to complete, and then closes the underlying stream. +// +// After the start of a Close, notification requests (that lack IDs and do not +// receive responses) will continue to be passed to the Preempter, but calls +// with IDs will receive immediate responses with ErrServerClosing, and no new +// requests (not even notifications!) will be enqueued to the Handler. +func (c *Connection) Close() error { + // Stop handling new requests, and interrupt the reader (by closing the + // connection) as soon as the active requests finish. + c.updateInFlight(func(s *inFlightState) { s.connClosing = true }) + return c.wait(false) +} + +// readIncoming collects inbound messages from the reader and delivers them, either responding +// to outgoing calls or feeding requests to the queue. +func (c *Connection) readIncoming(ctx context.Context, reader Reader, preempter Preempter) { + var err error + for { + var msg Message + msg, err = reader.Read(ctx) + if err != nil { + break + } + + switch msg := msg.(type) { + case *Request: + c.acceptRequest(ctx, msg, preempter) + + case *Response: + c.updateInFlight(func(s *inFlightState) { + if ac, ok := s.outgoingCalls[msg.ID]; ok { + delete(s.outgoingCalls, msg.ID) + ac.retire(msg) + } else { + // TODO: How should we report unexpected responses? 
+ } + }) + + default: + c.internalErrorf("Read returned an unexpected message of type %T", msg) + } + } + + c.updateInFlight(func(s *inFlightState) { + s.reading = false + s.readErr = err + + // Retire any outgoing requests that were still in flight: with the Reader no + // longer being processed, they necessarily cannot receive a response. + for id, ac := range s.outgoingCalls { + ac.retire(&Response{ID: id, Error: err}) + } + s.outgoingCalls = nil + }) +} + +// acceptRequest either handles msg synchronously or enqueues it to be handled +// asynchronously. +func (c *Connection) acceptRequest(ctx context.Context, msg *Request, preempter Preempter) { + // In theory notifications cannot be cancelled, but we build them a cancel + // context anyway. + reqCtx, cancel := context.WithCancel(ctx) + req := &incomingRequest{ + Request: msg, + ctx: reqCtx, + cancel: cancel, + } + + // If the request is a call, add it to the incoming map so it can be + // cancelled (or responded) by ID. + var err error + c.updateInFlight(func(s *inFlightState) { + s.incoming++ + + if req.IsCall() { + if s.incomingByID[req.ID] != nil { + err = fmt.Errorf("%w: request ID %v already in use", ErrInvalidRequest, req.ID) + req.ID = ID{} // Don't misattribute this error to the existing request. + return + } + + if s.incomingByID == nil { + s.incomingByID = make(map[ID]*incomingRequest) + } + s.incomingByID[req.ID] = req + + // When shutting down, reject all new Call requests, even if they could + // theoretically be handled by the preempter. The preempter could return + // ErrAsyncResponse, which would increase the amount of work in flight + // when we're trying to ensure that it strictly decreases. 
+ err = s.shuttingDown(ErrServerClosing) + } + }) + if err != nil { + c.processResult("acceptRequest", req, nil, err) + return + } + + if preempter != nil { + result, err := preempter.Preempt(req.ctx, req.Request) + + if !errors.Is(err, ErrNotHandled) { + c.processResult("Preempt", req, result, err) + return + } + } + + c.updateInFlight(func(s *inFlightState) { + // If the connection is shutting down, don't enqueue anything to the + // handler — not even notifications. That ensures that if the handler + // continues to make progress, it will eventually become idle and + // close the connection. + err = s.shuttingDown(ErrServerClosing) + if err != nil { + return + } + + // We enqueue requests that have not been preempted to an unbounded slice. + // Unfortunately, we cannot in general limit the size of the handler + // queue: we have to read every response that comes in on the wire + // (because it may be responding to a request issued by, say, an + // asynchronous handler), and in order to get to that response we have + // to read all of the requests that came in ahead of it. + s.handlerQueue = append(s.handlerQueue, req) + if !s.handlerRunning { + // We start the handleAsync goroutine when it has work to do, and let it + // exit when the queue empties. + // + // Otherwise, in order to synchronize the handler we would need some other + // goroutine (probably readIncoming?) to explicitly wait for handleAsync + // to finish, and that would complicate error reporting: either the error + // report from the goroutine would be blocked on the handler emptying its + // queue (which was tried, and introduced a deadlock detected by + // TestCloseCallRace), or the error would need to be reported separately + // from synchronizing completion. Allowing the handler goroutine to exit + // when idle seems simpler than trying to implement either of those + // alternatives correctly. 
+ s.handlerRunning = true + go c.handleAsync() + } + }) + if err != nil { + c.processResult("acceptRequest", req, nil, err) + } +} + +// handleAsync invokes the handler on the requests in the handler queue +// sequentially until the queue is empty. +func (c *Connection) handleAsync() { + for { + var req *incomingRequest + c.updateInFlight(func(s *inFlightState) { + if len(s.handlerQueue) > 0 { + req, s.handlerQueue = s.handlerQueue[0], s.handlerQueue[1:] + } else { + s.handlerRunning = false + } + }) + if req == nil { + return + } + + // Only deliver to the Handler if not already canceled. + if err := req.ctx.Err(); err != nil { + c.updateInFlight(func(s *inFlightState) { + if s.writeErr != nil { + // Assume that req.ctx was canceled due to s.writeErr. + // TODO(#51365): use a Context API to plumb this through req.ctx. + err = fmt.Errorf("%w: %v", ErrServerClosing, s.writeErr) + } + }) + c.processResult("handleAsync", req, nil, err) + continue + } + + releaser := &releaser{ch: make(chan struct{})} + ctx := context.WithValue(req.ctx, asyncKey, releaser) + go func() { + defer releaser.release(true) + result, err := c.handler.Handle(ctx, req.Request) + c.processResult(c.handler, req, result, err) + }() + <-releaser.ch + } +} + +// processResult processes the result of a request and, if appropriate, sends a response. +func (c *Connection) processResult(from any, req *incomingRequest, result any, err error) error { + switch err { + case ErrNotHandled, ErrMethodNotFound: + // Add detail describing the unhandled method. + err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method) + } + + if result != nil && err != nil { + c.internalErrorf("%#v returned a non-nil result with a non-nil error for %s:\n%v\n%#v", from, req.Method, err, result) + result = nil // Discard the spurious result and respond with err. 
+ } + + if req.IsCall() { + if result == nil && err == nil { + err = c.internalErrorf("%#v returned a nil result and nil error for a %q Request that requires a Response", from, req.Method) + } + + response, respErr := NewResponse(req.ID, result, err) + + // The caller could theoretically reuse the request's ID as soon as we've + // sent the response, so ensure that it is removed from the incoming map + // before sending. + c.updateInFlight(func(s *inFlightState) { + delete(s.incomingByID, req.ID) + }) + if respErr == nil { + writeErr := c.write(notDone{req.ctx}, response) + if err == nil { + err = writeErr + } + } else { + err = c.internalErrorf("%#v returned a malformed result for %q: %w", from, req.Method, respErr) + } + } else { // req is a notification + if result != nil { + err = c.internalErrorf("%#v returned a non-nil result for a %q Request without an ID", from, req.Method) + } else if err != nil { + err = fmt.Errorf("%w: %q notification failed: %v", ErrInternal, req.Method, err) + } + } + if err != nil { + // TODO: can/should we do anything with this error beyond writing it to the event log? + // (Is this the right label to attach to the log?) + } + + // Cancel the request to free any associated resources. + req.cancel() + c.updateInFlight(func(s *inFlightState) { + if s.incoming == 0 { + panic("jsonrpc2: processResult called when incoming count is already zero") + } + s.incoming-- + }) + return nil +} + +// write is used by all things that write outgoing messages, including replies. +// it makes sure that writes are atomic +func (c *Connection) write(ctx context.Context, msg Message) error { + var err error + // Fail writes immediately if the connection is shutting down. + // + // TODO(rfindley): should we allow cancellation notifications through? It + // could be the case that writes can still succeed. 
+ c.updateInFlight(func(s *inFlightState) { + err = s.shuttingDown(ErrServerClosing) + }) + if err == nil { + err = c.writer.Write(ctx, msg) + } + + // For cancelled or rejected requests, we don't set the writeErr (which would + // break the connection). They can just be returned to the caller. + if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) { + // The call to Write failed, and since ctx.Err() is nil we can't attribute + // the failure (even indirectly) to Context cancellation. The writer appears + // to be broken, and future writes are likely to also fail. + // + // If the read side of the connection is also broken, we might not even be + // able to receive cancellation notifications. Since we can't reliably write + // the results of incoming calls and can't receive explicit cancellations, + // cancel the calls now. + c.updateInFlight(func(s *inFlightState) { + if s.writeErr == nil { + s.writeErr = err + for _, r := range s.incomingByID { + r.cancel() + } + } + }) + } + + return err +} + +// internalErrorf reports an internal error. By default it panics, but if +// c.onInternalError is non-nil it instead calls that and returns an error +// wrapping ErrInternal. +func (c *Connection) internalErrorf(format string, args ...any) error { + err := fmt.Errorf(format, args...) + if c.onInternalError == nil { + panic("jsonrpc2: " + err.Error()) + } + c.onInternalError(err) + + return fmt.Errorf("%w: %v", ErrInternal, err) +} + +// notDone is a context.Context wrapper that returns a nil Done channel. 
+type notDone struct{ ctx context.Context } + +func (ic notDone) Value(key any) any { + return ic.ctx.Value(key) +} + +func (notDone) Done() <-chan struct{} { return nil } +func (notDone) Err() error { return nil } +func (notDone) Deadline() (time.Time, bool) { return time.Time{}, false } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go new file mode 100644 index 000000000..46fcc9db9 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go @@ -0,0 +1,208 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + "sync" +) + +// Reader abstracts the transport mechanics from the JSON RPC protocol. +// A Conn reads messages from the reader it was provided on construction, +// and assumes that each call to Read fully transfers a single message, +// or returns an error. +// +// A reader is not safe for concurrent use, it is expected it will be used by +// a single Conn in a safe manner. +type Reader interface { + // Read gets the next message from the stream. + Read(context.Context) (Message, error) +} + +// Writer abstracts the transport mechanics from the JSON RPC protocol. +// A Conn writes messages using the writer it was provided on construction, +// and assumes that each call to Write fully transfers a single message, +// or returns an error. +// +// A writer must be safe for concurrent use, as writes may occur concurrently +// in practice: libraries may make calls or respond to requests asynchronously. +type Writer interface { + // Write sends a message to the stream. 
+ Write(context.Context, Message) error +} + +// Framer wraps low level byte readers and writers into jsonrpc2 message +// readers and writers. +// It is responsible for the framing and encoding of messages into wire form. +// +// TODO(rfindley): rethink the framer interface, as with JSONRPC2 batching +// there is a need for Reader and Writer to be correlated, and while the +// implementation of framing here allows that, it is not made explicit by the +// interface. +// +// Perhaps a better interface would be +// +// Frame(io.ReadWriteCloser) (Reader, Writer). +type Framer interface { + // Reader wraps a byte reader into a message reader. + Reader(io.Reader) Reader + // Writer wraps a byte writer into a message writer. + Writer(io.Writer) Writer +} + +// RawFramer returns a new Framer. +// The messages are sent with no wrapping, and rely on json decode consistency +// to determine message boundaries. +func RawFramer() Framer { return rawFramer{} } + +type rawFramer struct{} +type rawReader struct{ in *json.Decoder } +type rawWriter struct { + mu sync.Mutex + out io.Writer +} + +func (rawFramer) Reader(rw io.Reader) Reader { + return &rawReader{in: json.NewDecoder(rw)} +} + +func (rawFramer) Writer(rw io.Writer) Writer { + return &rawWriter{out: rw} +} + +func (r *rawReader) Read(ctx context.Context) (Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + var raw json.RawMessage + if err := r.in.Decode(&raw); err != nil { + return nil, err + } + msg, err := DecodeMessage(raw) + return msg, err +} + +func (w *rawWriter) Write(ctx context.Context, msg Message) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + data, err := EncodeMessage(msg) + if err != nil { + return fmt.Errorf("marshaling message: %v", err) + } + + w.mu.Lock() + defer w.mu.Unlock() + _, err = w.out.Write(data) + return err +} + +// HeaderFramer returns a new Framer. 
+// The messages are sent with HTTP content length and MIME type headers. +// This is the format used by LSP and others. +func HeaderFramer() Framer { return headerFramer{} } + +type headerFramer struct{} +type headerReader struct{ in *bufio.Reader } +type headerWriter struct { + mu sync.Mutex + out io.Writer +} + +func (headerFramer) Reader(rw io.Reader) Reader { + return &headerReader{in: bufio.NewReader(rw)} +} + +func (headerFramer) Writer(rw io.Writer) Writer { + return &headerWriter{out: rw} +} + +func (r *headerReader) Read(ctx context.Context) (Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + firstRead := true // to detect a clean EOF below + var contentLength int64 + // read the header, stop on the first empty line + for { + line, err := r.in.ReadString('\n') + if err != nil { + if err == io.EOF { + if firstRead && line == "" { + return nil, io.EOF // clean EOF + } + err = io.ErrUnexpectedEOF + } + return nil, fmt.Errorf("failed reading header line: %w", err) + } + firstRead = false + + line = strings.TrimSpace(line) + // check we have a header line + if line == "" { + break + } + colon := strings.IndexRune(line, ':') + if colon < 0 { + return nil, fmt.Errorf("invalid header line %q", line) + } + name, value := line[:colon], strings.TrimSpace(line[colon+1:]) + switch name { + case "Content-Length": + if contentLength, err = strconv.ParseInt(value, 10, 32); err != nil { + return nil, fmt.Errorf("failed parsing Content-Length: %v", value) + } + if contentLength <= 0 { + return nil, fmt.Errorf("invalid Content-Length: %v", contentLength) + } + default: + // ignoring unknown headers + } + } + if contentLength == 0 { + return nil, fmt.Errorf("missing Content-Length header") + } + data := make([]byte, contentLength) + _, err := io.ReadFull(r.in, data) + if err != nil { + return nil, err + } + msg, err := DecodeMessage(data) + return msg, err +} + +func (w *headerWriter) Write(ctx context.Context, msg Message) error { + 
select { + case <-ctx.Done(): + return ctx.Err() + default: + } + w.mu.Lock() + defer w.mu.Unlock() + + data, err := EncodeMessage(msg) + if err != nil { + return fmt.Errorf("marshaling message: %v", err) + } + _, err = fmt.Fprintf(w.out, "Content-Length: %v\r\n\r\n", len(data)) + if err == nil { + _, err = w.out.Write(data) + } + return err +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/jsonrpc2.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/jsonrpc2.go new file mode 100644 index 000000000..234e6ee3a --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/jsonrpc2.go @@ -0,0 +1,121 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package jsonrpc2 is a minimal implementation of the JSON RPC 2 spec. +// https://www.jsonrpc.org/specification +// It is intended to be compatible with other implementations at the wire level. +package jsonrpc2 + +import ( + "context" + "errors" +) + +var ( + // ErrIdleTimeout is returned when serving timed out waiting for new connections. + ErrIdleTimeout = errors.New("timed out waiting for new connections") + + // ErrNotHandled is returned from a Handler or Preempter to indicate it did + // not handle the request. + // + // If a Handler returns ErrNotHandled, the server replies with + // ErrMethodNotFound. + ErrNotHandled = errors.New("JSON RPC not handled") +) + +// Preempter handles messages on a connection before they are queued to the main +// handler. +// Primarily this is used for cancel handlers or notifications for which out of +// order processing is not an issue. +type Preempter interface { + // Preempt is invoked for each incoming request before it is queued for handling. + // + // If Preempt returns ErrNotHandled, the request will be queued, + // and eventually passed to a Handle call. 
+ // + // Otherwise, the result and error are processed as if returned by Handle. + // + // Preempt must not block. (The Context passed to it is for Values only.) + Preempt(ctx context.Context, req *Request) (result any, err error) +} + +// A PreempterFunc implements the Preempter interface for a standalone Preempt function. +type PreempterFunc func(ctx context.Context, req *Request) (any, error) + +func (f PreempterFunc) Preempt(ctx context.Context, req *Request) (any, error) { + return f(ctx, req) +} + +var _ Preempter = PreempterFunc(nil) + +// Handler handles messages on a connection. +type Handler interface { + // Handle is invoked sequentially for each incoming request that has not + // already been handled by a Preempter. + // + // If the Request has a nil ID, Handle must return a nil result, + // and any error may be logged but will not be reported to the caller. + // + // If the Request has a non-nil ID, Handle must return either a + // non-nil, JSON-marshalable result, or a non-nil error. + // + // The Context passed to Handle will be canceled if the + // connection is broken or the request is canceled or completed. + // (If Handle returns ErrAsyncResponse, ctx will remain uncanceled + // until either Cancel or Respond is called for the request's ID.) + Handle(ctx context.Context, req *Request) (result any, err error) +} + +type defaultHandler struct{} + +func (defaultHandler) Preempt(context.Context, *Request) (any, error) { + return nil, ErrNotHandled +} + +func (defaultHandler) Handle(context.Context, *Request) (any, error) { + return nil, ErrNotHandled +} + +// A HandlerFunc implements the Handler interface for a standalone Handle function. +type HandlerFunc func(ctx context.Context, req *Request) (any, error) + +func (f HandlerFunc) Handle(ctx context.Context, req *Request) (any, error) { + return f(ctx, req) +} + +var _ Handler = HandlerFunc(nil) + +// async is a small helper for operations with an asynchronous result that you +// can wait for. 
+type async struct { + ready chan struct{} // closed when done + firstErr chan error // 1-buffered; contains either nil or the first non-nil error +} + +func newAsync() *async { + var a async + a.ready = make(chan struct{}) + a.firstErr = make(chan error, 1) + a.firstErr <- nil + return &a +} + +func (a *async) done() { + close(a.ready) +} + +func (a *async) wait() error { + <-a.ready + err := <-a.firstErr + a.firstErr <- err + return err +} + +func (a *async) setError(err error) { + storedErr := <-a.firstErr + if storedErr == nil { + storedErr = err + } + a.firstErr <- storedErr +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go new file mode 100644 index 000000000..791e698d9 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go @@ -0,0 +1,212 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "encoding/json" + "errors" + "fmt" +) + +// ID is a Request identifier, which is defined by the spec to be a string, integer, or null. +// https://www.jsonrpc.org/specification#request_object +type ID struct { + value any +} + +// MakeID coerces the given Go value to an ID. The value should be the +// default JSON marshaling of a Request identifier: nil, float64, or string. +// +// Returns an error if the value type was not a valid Request ID type. +// +// TODO: ID can't be a json.Marshaler/Unmarshaler, because we want to omitzero. +// Simplify this package by making ID json serializable once we can rely on +// omitzero. 
+func MakeID(v any) (ID, error) { + switch v := v.(type) { + case nil: + return ID{}, nil + case float64: + return Int64ID(int64(v)), nil + case string: + return StringID(v), nil + } + return ID{}, fmt.Errorf("%w: invalid ID type %T", ErrParse, v) +} + +// Message is the interface to all jsonrpc2 message types. +// They share no common functionality, but are a closed set of concrete types +// that are allowed to implement this interface. The message types are *Request +// and *Response. +type Message interface { + // marshal builds the wire form from the API form. + // It is private, which makes the set of Message implementations closed. + marshal(to *wireCombined) +} + +// Request is a Message sent to a peer to request behavior. +// If it has an ID it is a call, otherwise it is a notification. +type Request struct { + // ID of this request, used to tie the Response back to the request. + // This will be nil for notifications. + ID ID + // Method is a string containing the method name to invoke. + Method string + // Params is either a struct or an array with the parameters of the method. + Params json.RawMessage + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the application to the underlying transport. + Extra any +} + +// Response is a Message used as a reply to a call Request. +// It will have the same ID as the call it is a response to. +type Response struct { + // result is the content of the response. + Result json.RawMessage + // err is set only if the call failed. + Error error + // id of the request this is a response to. + ID ID + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the underlying transport to the application. + Extra any +} + +// StringID creates a new string request identifier. +func StringID(s string) ID { return ID{value: s} } + +// Int64ID creates a new integer request identifier. 
+func Int64ID(i int64) ID { return ID{value: i} } + +// IsValid returns true if the ID is a valid identifier. +// The default value for ID will return false. +func (id ID) IsValid() bool { return id.value != nil } + +// Raw returns the underlying value of the ID. +func (id ID) Raw() any { return id.value } + +// NewNotification constructs a new Notification message for the supplied +// method and parameters. +func NewNotification(method string, params any) (*Request, error) { + p, merr := marshalToRaw(params) + return &Request{Method: method, Params: p}, merr +} + +// NewCall constructs a new Call message for the supplied ID, method and +// parameters. +func NewCall(id ID, method string, params any) (*Request, error) { + p, merr := marshalToRaw(params) + return &Request{ID: id, Method: method, Params: p}, merr +} + +func (msg *Request) IsCall() bool { return msg.ID.IsValid() } + +func (msg *Request) marshal(to *wireCombined) { + to.ID = msg.ID.value + to.Method = msg.Method + to.Params = msg.Params +} + +// NewResponse constructs a new Response message that is a reply to the +// supplied. If err is set result may be ignored. 
+func NewResponse(id ID, result any, rerr error) (*Response, error) { + r, merr := marshalToRaw(result) + return &Response{ID: id, Result: r, Error: rerr}, merr +} + +func (msg *Response) marshal(to *wireCombined) { + to.ID = msg.ID.value + to.Error = toWireError(msg.Error) + to.Result = msg.Result +} + +func toWireError(err error) *WireError { + if err == nil { + // no error, the response is complete + return nil + } + if err, ok := err.(*WireError); ok { + // already a wire error, just use it + return err + } + result := &WireError{Message: err.Error()} + var wrapped *WireError + if errors.As(err, &wrapped) { + // if we wrapped a wire error, keep the code from the wrapped error + // but the message from the outer error + result.Code = wrapped.Code + } + return result +} + +func EncodeMessage(msg Message) ([]byte, error) { + wire := wireCombined{VersionTag: wireVersion} + msg.marshal(&wire) + data, err := json.Marshal(&wire) + if err != nil { + return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + } + return data, nil +} + +// EncodeIndent is like EncodeMessage, but honors indents. +// TODO(rfindley): refactor so that this concern is handled independently. +// Perhaps we should pass in a json.Encoder? 
+func EncodeIndent(msg Message, prefix, indent string) ([]byte, error) { + wire := wireCombined{VersionTag: wireVersion} + msg.marshal(&wire) + data, err := json.MarshalIndent(&wire, prefix, indent) + if err != nil { + return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + } + return data, nil +} + +func DecodeMessage(data []byte) (Message, error) { + msg := wireCombined{} + if err := json.Unmarshal(data, &msg); err != nil { + return nil, fmt.Errorf("unmarshaling jsonrpc message: %w", err) + } + if msg.VersionTag != wireVersion { + return nil, fmt.Errorf("invalid message version tag %q; expected %q", msg.VersionTag, wireVersion) + } + id, err := MakeID(msg.ID) + if err != nil { + return nil, err + } + if msg.Method != "" { + // has a method, must be a call + return &Request{ + Method: msg.Method, + ID: id, + Params: msg.Params, + }, nil + } + // no method, should be a response + if !id.IsValid() { + return nil, ErrInvalidRequest + } + resp := &Response{ + ID: id, + Result: msg.Result, + } + // we have to check if msg.Error is nil to avoid a typed error + if msg.Error != nil { + resp.Error = msg.Error + } + return resp, nil +} + +func marshalToRaw(obj any) (json.RawMessage, error) { + if obj == nil { + return nil, nil + } + data, err := json.Marshal(obj) + if err != nil { + return nil, err + } + return json.RawMessage(data), nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/net.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/net.go new file mode 100644 index 000000000..05db06261 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/net.go @@ -0,0 +1,138 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +package jsonrpc2 + +import ( + "context" + "io" + "net" + "os" +) + +// This file contains implementations of the transport primitives that use the standard network +// package. + +// NetListenOptions is the optional arguments to the NetListen function. +type NetListenOptions struct { + NetListenConfig net.ListenConfig + NetDialer net.Dialer +} + +// NetListener returns a new Listener that listens on a socket using the net package. +func NetListener(ctx context.Context, network, address string, options NetListenOptions) (Listener, error) { + ln, err := options.NetListenConfig.Listen(ctx, network, address) + if err != nil { + return nil, err + } + return &netListener{net: ln}, nil +} + +// netListener is the implementation of Listener for connections made using the net package. +type netListener struct { + net net.Listener +} + +// Accept blocks waiting for an incoming connection to the listener. +func (l *netListener) Accept(context.Context) (io.ReadWriteCloser, error) { + return l.net.Accept() +} + +// Close will cause the listener to stop listening. It will not close any connections that have +// already been accepted. +func (l *netListener) Close() error { + addr := l.net.Addr() + err := l.net.Close() + if addr.Network() == "unix" { + rerr := os.Remove(addr.String()) + if rerr != nil && err == nil { + err = rerr + } + } + return err +} + +// Dialer returns a dialer that can be used to connect to the listener. +func (l *netListener) Dialer() Dialer { + return NetDialer(l.net.Addr().Network(), l.net.Addr().String(), net.Dialer{}) +} + +// NetDialer returns a Dialer using the supplied standard network dialer. 
+func NetDialer(network, address string, nd net.Dialer) Dialer { + return &netDialer{ + network: network, + address: address, + dialer: nd, + } +} + +type netDialer struct { + network string + address string + dialer net.Dialer +} + +func (n *netDialer) Dial(ctx context.Context) (io.ReadWriteCloser, error) { + return n.dialer.DialContext(ctx, n.network, n.address) +} + +// NetPipeListener returns a new Listener that listens using net.Pipe. +// It is only possibly to connect to it using the Dialer returned by the +// Dialer method, each call to that method will generate a new pipe the other +// side of which will be returned from the Accept call. +func NetPipeListener(ctx context.Context) (Listener, error) { + return &netPiper{ + done: make(chan struct{}), + dialed: make(chan io.ReadWriteCloser), + }, nil +} + +// netPiper is the implementation of Listener build on top of net.Pipes. +type netPiper struct { + done chan struct{} + dialed chan io.ReadWriteCloser +} + +// Accept blocks waiting for an incoming connection to the listener. +func (l *netPiper) Accept(context.Context) (io.ReadWriteCloser, error) { + // Block until the pipe is dialed or the listener is closed, + // preferring the latter if already closed at the start of Accept. + select { + case <-l.done: + return nil, net.ErrClosed + default: + } + select { + case rwc := <-l.dialed: + return rwc, nil + case <-l.done: + return nil, net.ErrClosed + } +} + +// Close will cause the listener to stop listening. It will not close any connections that have +// already been accepted. 
+func (l *netPiper) Close() error { + // unblock any accept calls that are pending + close(l.done) + return nil +} + +func (l *netPiper) Dialer() Dialer { + return l +} + +func (l *netPiper) Dial(ctx context.Context) (io.ReadWriteCloser, error) { + client, server := net.Pipe() + + select { + case l.dialed <- server: + return client, nil + + case <-l.done: + client.Close() + server.Close() + return nil, net.ErrClosed + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/serve.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/serve.go new file mode 100644 index 000000000..424163aaf --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/serve.go @@ -0,0 +1,330 @@ +// Copyright 2020 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "context" + "fmt" + "io" + "runtime" + "sync" + "sync/atomic" + "time" +) + +// Listener is implemented by protocols to accept new inbound connections. +type Listener interface { + // Accept accepts an inbound connection to a server. + // It blocks until either an inbound connection is made, or the listener is closed. + Accept(context.Context) (io.ReadWriteCloser, error) + + // Close closes the listener. + // Any blocked Accept or Dial operations will unblock and return errors. + Close() error + + // Dialer returns a dialer that can be used to connect to this listener + // locally. + // If a listener does not implement this it will return nil. + Dialer() Dialer +} + +// Dialer is used by clients to dial a server. +type Dialer interface { + // Dial returns a new communication byte stream to a listening server. + Dial(ctx context.Context) (io.ReadWriteCloser, error) +} + +// Server is a running server that is accepting incoming connections. 
+type Server struct { + listener Listener + binder Binder + async *async + + shutdownOnce sync.Once + closing int32 // atomic: set to nonzero when Shutdown is called +} + +// Dial uses the dialer to make a new connection, wraps the returned +// reader and writer using the framer to make a stream, and then builds +// a connection on top of that stream using the binder. +// +// The returned Connection will operate independently using the Preempter and/or +// Handler provided by the Binder, and will release its own resources when the +// connection is broken, but the caller may Close it earlier to stop accepting +// (or sending) new requests. +// +// If non-nil, the onDone function is called when the connection is closed. +func Dial(ctx context.Context, dialer Dialer, binder Binder, onDone func()) (*Connection, error) { + // dial a server + rwc, err := dialer.Dial(ctx) + if err != nil { + return nil, err + } + return bindConnection(ctx, rwc, binder, onDone), nil +} + +// NewServer starts a new server listening for incoming connections and returns +// it. +// This returns a fully running and connected server, it does not block on +// the listener. +// You can call Wait to block on the server, or Shutdown to get the sever to +// terminate gracefully. +// To notice incoming connections, use an intercepting Binder. +func NewServer(ctx context.Context, listener Listener, binder Binder) *Server { + server := &Server{ + listener: listener, + binder: binder, + async: newAsync(), + } + go server.run(ctx) + return server +} + +// Wait returns only when the server has shut down. +func (s *Server) Wait() error { + return s.async.wait() +} + +// Shutdown informs the server to stop accepting new connections. 
+func (s *Server) Shutdown() { + s.shutdownOnce.Do(func() { + atomic.StoreInt32(&s.closing, 1) + s.listener.Close() + }) +} + +// run accepts incoming connections from the listener, +// If IdleTimeout is non-zero, run exits after there are no clients for this +// duration, otherwise it exits only on error. +func (s *Server) run(ctx context.Context) { + defer s.async.done() + + var activeConns sync.WaitGroup + for { + rwc, err := s.listener.Accept(ctx) + if err != nil { + // Only Shutdown closes the listener. If we get an error after Shutdown is + // called, assume that was the cause and don't report the error; + // otherwise, report the error in case it is unexpected. + if atomic.LoadInt32(&s.closing) == 0 { + s.async.setError(err) + } + // We are done generating new connections for good. + break + } + + // A new inbound connection. + activeConns.Add(1) + _ = bindConnection(ctx, rwc, s.binder, activeConns.Done) // unregisters itself when done + } + activeConns.Wait() +} + +// NewIdleListener wraps a listener with an idle timeout. +// +// When there are no active connections for at least the timeout duration, +// calls to Accept will fail with ErrIdleTimeout. +// +// A connection is considered inactive as soon as its Close method is called. +func NewIdleListener(timeout time.Duration, wrap Listener) Listener { + l := &idleListener{ + wrapped: wrap, + timeout: timeout, + active: make(chan int, 1), + timedOut: make(chan struct{}), + idleTimer: make(chan *time.Timer, 1), + } + l.idleTimer <- time.AfterFunc(l.timeout, l.timerExpired) + return l +} + +type idleListener struct { + wrapped Listener + timeout time.Duration + + // Only one of these channels is receivable at any given time. + active chan int // count of active connections; closed when Close is called if not timed out + timedOut chan struct{} // closed when the idle timer expires + idleTimer chan *time.Timer // holds the timer only when idle +} + +// Accept accepts an incoming connection. 
+// +// If an incoming connection is accepted concurrent to the listener being closed +// due to idleness, the new connection is immediately closed. +func (l *idleListener) Accept(ctx context.Context) (io.ReadWriteCloser, error) { + rwc, err := l.wrapped.Accept(ctx) + + select { + case n, ok := <-l.active: + if err != nil { + if ok { + l.active <- n + } + return nil, err + } + if ok { + l.active <- n + 1 + } else { + // l.wrapped.Close Close has been called, but Accept returned a + // connection. This race can occur with concurrent Accept and Close calls + // with any net.Listener, and it is benign: since the listener was closed + // explicitly, it can't have also timed out. + } + return l.newConn(rwc), nil + + case <-l.timedOut: + if err == nil { + // Keeping the connection open would leave the listener simultaneously + // active and closed due to idleness, which would be contradictory and + // confusing. Close the connection and pretend that it never happened. + rwc.Close() + } else { + // In theory the timeout could have raced with an unrelated error return + // from Accept. However, ErrIdleTimeout is arguably still valid (since we + // would have closed due to the timeout independent of the error), and the + // harm from returning a spurious ErrIdleTimeout is negligible anyway. + } + return nil, ErrIdleTimeout + + case timer := <-l.idleTimer: + if err != nil { + // The idle timer doesn't run until it receives itself from the idleTimer + // channel, so it can't have called l.wrapped.Close yet and thus err can't + // be ErrIdleTimeout. Leave the idle timer as it was and return whatever + // error we got. + l.idleTimer <- timer + return nil, err + } + + if !timer.Stop() { + // Failed to stop the timer — the timer goroutine is in the process of + // firing. Send the timer back to the timer goroutine so that it can + // safely close the timedOut channel, and then wait for the listener to + // actually be closed before we return ErrIdleTimeout. 
+ l.idleTimer <- timer + rwc.Close() + <-l.timedOut + return nil, ErrIdleTimeout + } + + l.active <- 1 + return l.newConn(rwc), nil + } +} + +func (l *idleListener) Close() error { + select { + case _, ok := <-l.active: + if ok { + close(l.active) + } + + case <-l.timedOut: + // Already closed by the timer; take care not to double-close if the caller + // only explicitly invokes this Close method once, since the io.Closer + // interface explicitly leaves doubled Close calls undefined. + return ErrIdleTimeout + + case timer := <-l.idleTimer: + if !timer.Stop() { + // Couldn't stop the timer. It shouldn't take long to run, so just wait + // (so that the Listener is guaranteed to be closed before we return) + // and pretend that this call happened afterward. + // That way we won't leak any timers or goroutines when Close returns. + l.idleTimer <- timer + <-l.timedOut + return ErrIdleTimeout + } + close(l.active) + } + + return l.wrapped.Close() +} + +func (l *idleListener) Dialer() Dialer { + return l.wrapped.Dialer() +} + +func (l *idleListener) timerExpired() { + select { + case n, ok := <-l.active: + if ok { + panic(fmt.Sprintf("jsonrpc2: idleListener idle timer fired with %d connections still active", n)) + } else { + panic("jsonrpc2: Close finished with idle timer still running") + } + + case <-l.timedOut: + panic("jsonrpc2: idleListener idle timer fired more than once") + + case <-l.idleTimer: + // The timer for this very call! + } + + // Close the Listener with all channels still blocked to ensure that this call + // to l.wrapped.Close doesn't race with the one in l.Close. + defer close(l.timedOut) + l.wrapped.Close() +} + +func (l *idleListener) connClosed() { + select { + case n, ok := <-l.active: + if !ok { + // l is already closed, so it can't close due to idleness, + // and we don't need to track the number of active connections any more. 
+ return + } + n-- + if n == 0 { + l.idleTimer <- time.AfterFunc(l.timeout, l.timerExpired) + } else { + l.active <- n + } + + case <-l.timedOut: + panic("jsonrpc2: idleListener idle timer fired before last active connection was closed") + + case <-l.idleTimer: + panic("jsonrpc2: idleListener idle timer active before last active connection was closed") + } +} + +type idleListenerConn struct { + wrapped io.ReadWriteCloser + l *idleListener + closeOnce sync.Once +} + +func (l *idleListener) newConn(rwc io.ReadWriteCloser) *idleListenerConn { + c := &idleListenerConn{ + wrapped: rwc, + l: l, + } + + // A caller that forgets to call Close may disrupt the idleListener's + // accounting, even though the file descriptor for the underlying connection + // may eventually be garbage-collected anyway. + // + // Set a (best-effort) finalizer to verify that a Close call always occurs. + // (We will clear the finalizer explicitly in Close.) + runtime.SetFinalizer(c, func(c *idleListenerConn) { + panic("jsonrpc2: IdleListener connection became unreachable without a call to Close") + }) + + return c +} + +func (c *idleListenerConn) Read(p []byte) (int, error) { return c.wrapped.Read(p) } +func (c *idleListenerConn) Write(p []byte) (int, error) { return c.wrapped.Write(p) } + +func (c *idleListenerConn) Close() error { + defer c.closeOnce.Do(func() { + c.l.connClosed() + runtime.SetFinalizer(c, nil) + }) + return c.wrapped.Close() +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go new file mode 100644 index 000000000..c0a41bffb --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go @@ -0,0 +1,97 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +package jsonrpc2 + +import ( + "encoding/json" +) + +// This file contains the go forms of the wire specification. +// see http://www.jsonrpc.org/specification for details + +var ( + // ErrParse is used when invalid JSON was received by the server. + ErrParse = NewError(-32700, "parse error") + // ErrInvalidRequest is used when the JSON sent is not a valid Request object. + ErrInvalidRequest = NewError(-32600, "invalid request") + // ErrMethodNotFound should be returned by the handler when the method does + // not exist / is not available. + ErrMethodNotFound = NewError(-32601, "method not found") + // ErrInvalidParams should be returned by the handler when method + // parameter(s) were invalid. + ErrInvalidParams = NewError(-32602, "invalid params") + // ErrInternal indicates a failure to process a call correctly + ErrInternal = NewError(-32603, "internal error") + + // The following errors are not part of the json specification, but + // compliant extensions specific to this implementation. + + // ErrServerOverloaded is returned when a message was refused due to a + // server being temporarily unable to accept any new messages. + ErrServerOverloaded = NewError(-32000, "overloaded") + // ErrUnknown should be used for all non coded errors. + ErrUnknown = NewError(-32001, "unknown error") + // ErrServerClosing is returned for calls that arrive while the server is closing. + ErrServerClosing = NewError(-32004, "server is closing") + // ErrClientClosing is a dummy error returned for calls initiated while the client is closing. + ErrClientClosing = NewError(-32003, "client is closing") + + // The following errors have special semantics for MCP transports + + // ErrRejected may be wrapped to return errors from calls to Writer.Write + // that signal that the request was rejected by the transport layer as + // invalid. 
+ // + // Such failures do not indicate that the connection is broken, but rather + // should be returned to the caller to indicate that the specific request is + // invalid in the current context. + ErrRejected = NewError(-32005, "rejected by transport") +) + +const wireVersion = "2.0" + +// wireCombined has all the fields of both Request and Response. +// We can decode this and then work out which it is. +type wireCombined struct { + VersionTag string `json:"jsonrpc"` + ID any `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *WireError `json:"error,omitempty"` +} + +// WireError represents a structured error in a Response. +type WireError struct { + // Code is an error code indicating the type of failure. + Code int64 `json:"code"` + // Message is a short description of the error. + Message string `json:"message"` + // Data is optional structured data containing additional information about the error. + Data json.RawMessage `json:"data,omitempty"` +} + +// NewError returns an error that will encode on the wire correctly. +// The standard codes are made available from this package, this function should +// only be used to build errors for application specific codes as allowed by the +// specification. 
+func NewError(code int64, message string) error { + return &WireError{ + Code: code, + Message: message, + } +} + +func (err *WireError) Error() string { + return err.Message +} + +func (err *WireError) Is(other error) bool { + w, ok := other.(*WireError) + if !ok { + return false + } + return err.Code == w.Code +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/util.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/util.go new file mode 100644 index 000000000..4b5c325fa --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/util.go @@ -0,0 +1,44 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package util + +import ( + "cmp" + "fmt" + "iter" + "slices" +) + +// Helpers below are copied from gopls' moremaps package. + +// Sorted returns an iterator over the entries of m in key order. +func Sorted[M ~map[K]V, K cmp.Ordered, V any](m M) iter.Seq2[K, V] { + // TODO(adonovan): use maps.Sorted if proposal #68598 is accepted. + return func(yield func(K, V) bool) { + keys := KeySlice(m) + slices.Sort(keys) + for _, k := range keys { + if !yield(k, m[k]) { + break + } + } + } +} + +// KeySlice returns the keys of the map M, like slices.Collect(maps.Keys(m)). +func KeySlice[M ~map[K]V, K comparable, V any](m M) []K { + r := make([]K, 0, len(m)) + for k := range m { + r = append(r, k) + } + return r +} + +// Wrapf wraps *errp with the given formatted message if *errp is not nil. 
+func Wrapf(errp *error, format string, args ...any) { + if *errp != nil { + *errp = fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), *errp) + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/xcontext/xcontext.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/xcontext/xcontext.go new file mode 100644 index 000000000..849060d57 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/xcontext/xcontext.go @@ -0,0 +1,23 @@ +// Copyright 2019 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package xcontext is a package to offer the extra functionality we need +// from contexts that is not available from the standard context package. +package xcontext + +import ( + "context" + "time" +) + +// Detach returns a context that keeps all the values of its parent context +// but detaches from the cancellation and error handling. +func Detach(ctx context.Context) context.Context { return detachedContext{ctx} } + +type detachedContext struct{ parent context.Context } + +func (v detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false } +func (v detachedContext) Done() <-chan struct{} { return nil } +func (v detachedContext) Err() error { return nil } +func (v detachedContext) Value(key any) any { return v.parent.Value(key) } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go b/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go new file mode 100644 index 000000000..a9ea78fa8 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go @@ -0,0 +1,56 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package jsonrpc exposes part of a JSON-RPC v2 implementation +// for use by mcp transport authors. 
+package jsonrpc + +import "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + +type ( + // ID is a JSON-RPC request ID. + ID = jsonrpc2.ID + // Message is a JSON-RPC message. + Message = jsonrpc2.Message + // Request is a JSON-RPC request. + Request = jsonrpc2.Request + // Response is a JSON-RPC response. + Response = jsonrpc2.Response + // Error is a structured error in a JSON-RPC response. + Error = jsonrpc2.WireError +) + +// MakeID coerces the given Go value to an ID. The value should be the +// default JSON marshaling of a Request identifier: nil, float64, or string. +// +// Returns an error if the value type was not a valid Request ID type. +func MakeID(v any) (ID, error) { + return jsonrpc2.MakeID(v) +} + +// EncodeMessage serializes a JSON-RPC message to its wire format. +func EncodeMessage(msg Message) ([]byte, error) { + return jsonrpc2.EncodeMessage(msg) +} + +// DecodeMessage deserializes JSON-RPC wire format data into a Message. +// It returns either a Request or Response based on the message content. +func DecodeMessage(data []byte) (Message, error) { + return jsonrpc2.DecodeMessage(data) +} + +// Standard JSON-RPC 2.0 error codes. +// See https://www.jsonrpc.org/specification#error_object +const ( + // CodeParseError indicates invalid JSON was received by the server. + CodeParseError = -32700 + // CodeInvalidRequest indicates the JSON sent is not a valid Request object. + CodeInvalidRequest = -32600 + // CodeMethodNotFound indicates the method does not exist or is not available. + CodeMethodNotFound = -32601 + // CodeInvalidParams indicates invalid method parameter(s). + CodeInvalidParams = -32602 + // CodeInternalError indicates an internal JSON-RPC error. 
+ CodeInternalError = -32603 +) diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go new file mode 100644 index 000000000..2dc1a86c0 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go @@ -0,0 +1,1075 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "iter" + "log/slog" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +// A Client is an MCP client, which may be connected to an MCP server +// using the [Client.Connect] method. +type Client struct { + impl *Implementation + opts ClientOptions + logger *slog.Logger // TODO: file proposal to export this + mu sync.Mutex + roots *featureSet[*Root] + sessions []*ClientSession + sendingMethodHandler_ MethodHandler + receivingMethodHandler_ MethodHandler +} + +// NewClient creates a new [Client]. +// +// Use [Client.Connect] to connect it to an MCP server. +// +// The first argument must not be nil. +// +// If non-nil, the provided options configure the Client. +func NewClient(impl *Implementation, opts *ClientOptions) *Client { + if impl == nil { + panic("nil Implementation") + } + c := &Client{ + impl: impl, + logger: ensureLogger(nil), // ensure we have a logger + roots: newFeatureSet(func(r *Root) string { return r.URI }), + sendingMethodHandler_: defaultSendingMethodHandler, + receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], + } + if opts != nil { + c.opts = *opts + } + return c +} + +// ClientOptions configures the behavior of the client. 
+type ClientOptions struct { + // CreateMessageHandler handles incoming requests for sampling/createMessage. + // + // Setting CreateMessageHandler to a non-nil value automatically causes the + // client to advertise the sampling capability, with default value + // &SamplingCapabilities{}. If [ClientOptions.Capabilities] is set and has a + // non nil value for [ClientCapabilities.Sampling], that value overrides the + // inferred capability. + CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) + // ElicitationHandler handles incoming requests for elicitation/create. + // + // Setting ElicitationHandler to a non-nil value automatically causes the + // client to advertise the elicitation capability, with default value + // &ElicitationCapabilities{}. If [ClientOptions.Capabilities] is set and has + // a non nil value for [ClientCapabilities.Elicitation], that value + // overrides the inferred capability. + ElicitationHandler func(context.Context, *ElicitRequest) (*ElicitResult, error) + // Capabilities optionally configures the client's default capabilities, + // before any capabilities are inferred from other configuration. + // + // If Capabilities is nil, the default client capabilities are + // {"roots":{"listChanged":true}}, for historical reasons. Setting + // Capabilities to a non-nil value overrides this default. As a special case, + // to work around #607, Capabilities.Roots is ignored: set + // Capabilities.RootsV2 to configure the roots capability. This allows the + // "roots" capability to be disabled entirely. + // + // For example: + // - To disable the "roots" capability, use &ClientCapabilities{} + // - To configure "roots", but disable "listChanged" notifications, use + // &ClientCapabilities{RootsV2:&RootCapabilities{}}. 
+ // + // # Interaction with capability inference + // + // Sampling and elicitation capabilities are automatically added when their + // corresponding handlers are set, with the default value described at + // [ClientOptions.CreateMessageHandler] and + // [ClientOptions.ElicitationHandler]. If the Sampling or Elicitation fields + // are set in the Capabilities field, their values override the inferred + // value. + // + // For example, to configure elicitation modes: + // + // Capabilities: &ClientCapabilities{ + // Elicitation: &ElicitationCapabilities{ + // Form: &FormElicitationCapabilities{}, + // URL: &URLElicitationCapabilities{}, + // }, + // } + // + // Conversely, if Capabilities does not set a field (for example, if the + // Elicitation field is nil), the inferred elicitation capability will be + // used. + Capabilities *ClientCapabilities + // ElicitationCompleteHandler handles incoming notifications for notifications/elicitation/complete. + ElicitationCompleteHandler func(context.Context, *ElicitationCompleteNotificationRequest) + // Handlers for notifications from the server. + ToolListChangedHandler func(context.Context, *ToolListChangedRequest) + PromptListChangedHandler func(context.Context, *PromptListChangedRequest) + ResourceListChangedHandler func(context.Context, *ResourceListChangedRequest) + ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest) + LoggingMessageHandler func(context.Context, *LoggingMessageRequest) + ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest) + // If non-zero, defines an interval for regular "ping" requests. + // If the peer fails to respond to pings originating from the keepalive check, + // the session is automatically closed. + KeepAlive time.Duration +} + +// bind implements the binder[*ClientSession] interface, so that Clients can +// be connected using [connect]. 
+func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession { + assert(mcpConn != nil && conn != nil, "nil connection") + cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c, onClose: onClose} + if state != nil { + cs.state = *state + } + c.mu.Lock() + defer c.mu.Unlock() + c.sessions = append(c.sessions, cs) + return cs +} + +// disconnect implements the binder[*Client] interface, so that +// Clients can be connected using [connect]. +func (c *Client) disconnect(cs *ClientSession) { + c.mu.Lock() + defer c.mu.Unlock() + c.sessions = slices.DeleteFunc(c.sessions, func(cs2 *ClientSession) bool { + return cs2 == cs + }) +} + +// TODO: Consider exporting this type and its field. +type unsupportedProtocolVersionError struct { + version string +} + +func (e unsupportedProtocolVersionError) Error() string { + return fmt.Sprintf("unsupported protocol version: %q", e.version) +} + +// ClientSessionOptions is reserved for future use. +type ClientSessionOptions struct { + // protocolVersion overrides the protocol version sent in the initialize + // request, for testing. If empty, latestProtocolVersion is used. + protocolVersion string +} + +func (c *Client) capabilities(protocolVersion string) *ClientCapabilities { + // Start with user-provided capabilities as defaults, or use SDK defaults. + var caps *ClientCapabilities + if c.opts.Capabilities != nil { + // Deep copy the user-provided capabilities to avoid mutation. + caps = c.opts.Capabilities.clone() + } else { + // SDK defaults: roots with listChanged. + // (this was the default behavior at v1.0.0, and so cannot be changed) + caps = &ClientCapabilities{ + RootsV2: &RootCapabilities{ + ListChanged: true, + }, + } + } + + // Sync Roots from RootsV2 for backward compatibility (#607). + if caps.RootsV2 != nil { + caps.Roots = *caps.RootsV2 + } + + // Augment with sampling capability if handler is set. 
+ if c.opts.CreateMessageHandler != nil { + if caps.Sampling == nil { + caps.Sampling = &SamplingCapabilities{} + } + } + + // Augment with elicitation capability if handler is set. + if c.opts.ElicitationHandler != nil { + if caps.Elicitation == nil { + caps.Elicitation = &ElicitationCapabilities{} + // Form elicitation was added in 2025-11-25; for older versions, + // {} is treated the same as {"form":{}}. + if protocolVersion >= protocolVersion20251125 { + caps.Elicitation.Form = &FormElicitationCapabilities{} + } + } + } + return caps +} + +// Connect begins an MCP session by connecting to a server over the given +// transport. The resulting session is initialized, and ready to use. +// +// Typically, it is the responsibility of the client to close the connection +// when it is no longer needed. However, if the connection is closed by the +// server, calls or notifications will return an error wrapping +// [ErrConnectionClosed]. +func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOptions) (cs *ClientSession, err error) { + cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil) + if err != nil { + return nil, err + } + + protocolVersion := latestProtocolVersion + if opts != nil && opts.protocolVersion != "" { + protocolVersion = opts.protocolVersion + } + params := &InitializeParams{ + ProtocolVersion: protocolVersion, + ClientInfo: c.impl, + Capabilities: c.capabilities(protocolVersion), + } + req := &InitializeRequest{Session: cs, Params: params} + res, err := handleSend[*InitializeResult](ctx, methodInitialize, req) + if err != nil { + _ = cs.Close() + return nil, err + } + if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) { + return nil, unsupportedProtocolVersionError{res.ProtocolVersion} + } + cs.state.InitializeResult = res + if hc, ok := cs.mcpConn.(clientConnection); ok { + hc.sessionUpdated(cs.state) + } + req2 := &initializedClientRequest{Session: cs, Params: &InitializedParams{}} + if err := 
handleNotify(ctx, notificationInitialized, req2); err != nil { + _ = cs.Close() + return nil, err + } + + if c.opts.KeepAlive > 0 { + cs.startKeepalive(c.opts.KeepAlive) + } + + return cs, nil +} + +// A ClientSession is a logical connection with an MCP server. Its +// methods can be used to send requests or notifications to the server. Create +// a session by calling [Client.Connect]. +// +// Call [ClientSession.Close] to close the connection, or await server +// termination with [ClientSession.Wait]. +type ClientSession struct { + // Ensure that onClose is called at most once. + // We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the + // onClose callback triggers a re-entrant call to Close. + calledOnClose atomic.Bool + onClose func() + + conn *jsonrpc2.Connection + client *Client + keepaliveCancel context.CancelFunc + mcpConn Connection + + // No mutex is (currently) required to guard the session state, because it is + // only set synchronously during Client.Connect. + state clientSessionState + + // Pending URL elicitations waiting for completion notifications. + pendingElicitationsMu sync.Mutex + pendingElicitations map[string]chan struct{} +} + +type clientSessionState struct { + InitializeResult *InitializeResult +} + +func (cs *ClientSession) InitializeResult() *InitializeResult { return cs.state.InitializeResult } + +func (cs *ClientSession) ID() string { + if c, ok := cs.mcpConn.(hasSessionID); ok { + return c.SessionID() + } + return "" +} + +// Close performs a graceful close of the connection, preventing new requests +// from being handled, and waiting for ongoing requests to return. Close then +// terminates the connection. +// +// Close is idempotent and concurrency safe. +func (cs *ClientSession) Close() error { + // Note: keepaliveCancel access is safe without a mutex because: + // 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls) + // 2. 
context.CancelFunc is safe to call multiple times and from multiple goroutines + // 3. The keepalive goroutine calls Close on ping failure, but this is safe since + // Close is idempotent and conn.Close() handles concurrent calls correctly + if cs.keepaliveCancel != nil { + cs.keepaliveCancel() + } + err := cs.conn.Close() + + if cs.onClose != nil && cs.calledOnClose.CompareAndSwap(false, true) { + cs.onClose() + } + + return err +} + +// Wait waits for the connection to be closed by the server. +// Generally, clients should be responsible for closing the connection. +func (cs *ClientSession) Wait() error { + return cs.conn.Wait() +} + +// registerElicitationWaiter registers a waiter for an elicitation complete +// notification with the given elicitation ID. It returns two functions: an await +// function that waits for the notification or context cancellation, and a cleanup +// function that must be called to unregister the waiter. This must be called before +// triggering the elicitation to avoid a race condition where the notification +// arrives before the waiter is registered. +// +// The cleanup function must be called even if the await function is never called, +// to prevent leaking the registration. +func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await func(context.Context) error, cleanup func()) { + // Create a channel for this elicitation. + ch := make(chan struct{}, 1) + + // Register the channel. + cs.pendingElicitationsMu.Lock() + if cs.pendingElicitations == nil { + cs.pendingElicitations = make(map[string]chan struct{}) + } + cs.pendingElicitations[elicitationID] = ch + cs.pendingElicitationsMu.Unlock() + + // Return await and cleanup functions. 
+ await = func(ctx context.Context) error { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled while waiting for elicitation completion: %w", ctx.Err()) + case <-ch: + return nil + } + } + + cleanup = func() { + cs.pendingElicitationsMu.Lock() + delete(cs.pendingElicitations, elicitationID) + cs.pendingElicitationsMu.Unlock() + } + + return await, cleanup +} + +// startKeepalive starts the keepalive mechanism for this client session. +func (cs *ClientSession) startKeepalive(interval time.Duration) { + startKeepalive(cs, interval, &cs.keepaliveCancel) +} + +// AddRoots adds the given roots to the client, +// replacing any with the same URIs, +// and notifies any connected servers. +func (c *Client) AddRoots(roots ...*Root) { + // Only notify if something could change. + if len(roots) == 0 { + return + } + changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{}, + func() bool { c.roots.add(roots...); return true }) +} + +// RemoveRoots removes the roots with the given URIs, +// and notifies any connected servers if the list has changed. +// It is not an error to remove a nonexistent root. +func (c *Client) RemoveRoots(uris ...string) { + changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{}, + func() bool { return c.roots.remove(uris...) }) +} + +// changeAndNotify is called when a feature is added or removed. +// It calls change, which should do the work and report whether a change actually occurred. +// If there was a change, it notifies a snapshot of the sessions. +func changeAndNotify[P Params](c *Client, notification string, params P, change func() bool) { + var sessions []*ClientSession + // Lock for the change, but not for the notification. + c.mu.Lock() + if change() { + // Check if listChanged is enabled for this notification type. 
+ if c.shouldSendListChangedNotification(notification) { + sessions = slices.Clone(c.sessions) + } + } + c.mu.Unlock() + notifySessions(sessions, notification, params, c.logger) +} + +// shouldSendListChangedNotification checks if the client's capabilities allow +// sending the given list-changed notification. +func (c *Client) shouldSendListChangedNotification(notification string) bool { + // Get effective capabilities (considering user-provided defaults). + caps := c.opts.Capabilities + + switch notification { + case notificationRootsListChanged: + // If user didn't specify capabilities, default behavior sends notifications. + if caps == nil { + return true + } + // Check RootsV2 first (preferred), then fall back to Roots. + if caps.RootsV2 != nil { + return caps.RootsV2.ListChanged + } + return caps.Roots.ListChanged + default: + // Unknown notification, allow by default. + return true + } +} + +func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRootsResult, error) { + c.mu.Lock() + defer c.mu.Unlock() + roots := slices.Collect(c.roots.all()) + if roots == nil { + roots = []*Root{} // avoid JSON null + } + return &ListRootsResult{ + Roots: roots, + }, nil +} + +func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + if c.opts.CreateMessageHandler == nil { + // TODO: wrap or annotate this error? Pick a standard code? + return nil, &jsonrpc.Error{Code: codeUnsupportedMethod, Message: "client does not support CreateMessage"} + } + return c.opts.CreateMessageHandler(ctx, req) +} + +// urlElicitationMiddleware returns middleware that automatically handles URL elicitation +// required errors by executing the elicitation handler, waiting for completion notifications, +// and retrying the operation. 
+// +// This middleware should be added to clients that want automatic URL elicitation handling: +// +// client := mcp.NewClient(impl, opts) +// client.AddSendingMiddleware(mcp.urlElicitationMiddleware()) +// +// TODO(rfindley): this isn't strictly necessary for the SEP, but may be +// useful. Propose exporting it. +func urlElicitationMiddleware() Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + // Call the underlying handler. + res, err := next(ctx, method, req) + if err == nil { + return res, nil + } + + // Check if this is a URL elicitation required error. + var rpcErr *jsonrpc.Error + if !errors.As(err, &rpcErr) || rpcErr.Code != CodeURLElicitationRequired { + return res, err + } + + // Notifications don't support retries. + if strings.HasPrefix(method, "notifications/") { + return res, err + } + + // Extract the client session. + cs, ok := req.GetSession().(*ClientSession) + if !ok { + return res, err + } + + // Check if the client has an elicitation handler. + if cs.client.opts.ElicitationHandler == nil { + return res, err + } + + // Parse the elicitations from the error data. + var errorData struct { + Elicitations []*ElicitParams `json:"elicitations"` + } + if rpcErr.Data != nil { + if err := json.Unmarshal(rpcErr.Data, &errorData); err != nil { + return nil, fmt.Errorf("failed to parse URL elicitation error data: %w", err) + } + } + + // Validate that all elicitations are URL mode. + for _, elicit := range errorData.Elicitations { + mode := elicit.Mode + if mode == "" { + mode = "form" // Default mode. + } + if mode != "url" { + return nil, fmt.Errorf("URLElicitationRequired error must only contain URL mode elicitations, got %q", mode) + } + } + + // Register waiters for all elicitations before executing handlers + // to avoid race condition where notification arrives before waiter is registered. 
+ type waiter struct { + await func(context.Context) error + cleanup func() + } + waiters := make([]waiter, 0, len(errorData.Elicitations)) + for _, elicitParams := range errorData.Elicitations { + await, cleanup := cs.registerElicitationWaiter(elicitParams.ElicitationID) + waiters = append(waiters, waiter{await: await, cleanup: cleanup}) + } + + // Ensure cleanup happens even if we return early. + defer func() { + for _, w := range waiters { + w.cleanup() + } + }() + + // Execute the elicitation handler for each elicitation. + for _, elicitParams := range errorData.Elicitations { + elicitReq := newClientRequest(cs, elicitParams) + _, elicitErr := cs.client.elicit(ctx, elicitReq) + if elicitErr != nil { + return nil, fmt.Errorf("URL elicitation failed: %w", elicitErr) + } + } + + // Wait for all elicitations to complete. + for _, w := range waiters { + if err := w.await(ctx); err != nil { + return nil, err + } + } + + // All elicitations complete, retry the original operation. + return next(ctx, method, req) + } + } +} + +func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + if c.opts.ElicitationHandler == nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "client does not support elicitation"} + } + + // Validate the elicitation parameters based on the mode. + mode := req.Params.Mode + if mode == "" { + mode = "form" + } + + switch mode { + case "form": + if req.Params.URL != "" { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must not be set for form elicitation"} + } + schema, err := validateElicitSchema(req.Params.RequestedSchema) + if err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: err.Error()} + } + res, err := c.opts.ElicitationHandler(ctx, req) + if err != nil { + return nil, err + } + // Validate elicitation result content against requested schema. 
+ if schema != nil && res.Content != nil { + resolved, err := schema.Resolve(nil) + if err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to resolve requested schema: %v", err)} + } + if err := resolved.Validate(res.Content); err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("elicitation result content does not match requested schema: %v", err)} + } + err = resolved.ApplyDefaults(&res.Content) + if err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err)} + } + } + return res, nil + case "url": + if req.Params.RequestedSchema != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "requestedSchema must not be set for URL elicitation"} + } + if req.Params.URL == "" { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must be set for URL elicitation"} + } + // No schema validation for URL mode, just pass through to handler. + return c.opts.ElicitationHandler(ctx, req) + default: + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("unsupported elicitation mode: %q", mode)} + } +} + +// validateElicitSchema validates that the schema conforms to MCP elicitation schema requirements. +// Per the MCP specification, elicitation schemas are limited to flat objects with primitive properties only. 
+func validateElicitSchema(wireSchema any) (*jsonschema.Schema, error) { + if wireSchema == nil { + return nil, nil // nil schema is allowed + } + + var schema *jsonschema.Schema + if err := remarshal(wireSchema, &schema); err != nil { + return nil, err + } + if schema == nil { + return nil, nil + } + + // The root schema must be of type "object" if specified + if schema.Type != "" && schema.Type != "object" { + return nil, fmt.Errorf("elicit schema must be of type 'object', got %q", schema.Type) + } + + // Check if the schema has properties + if schema.Properties != nil { + for propName, propSchema := range schema.Properties { + if propSchema == nil { + continue + } + + if err := validateElicitProperty(propName, propSchema); err != nil { + return nil, err + } + } + } + + return schema, nil +} + +// validateElicitProperty validates a single property in an elicitation schema. +func validateElicitProperty(propName string, propSchema *jsonschema.Schema) error { + // Check if this property has nested properties (not allowed) + if len(propSchema.Properties) > 0 { + return fmt.Errorf("elicit schema property %q contains nested properties, only primitive properties are allowed", propName) + } + // Validate based on the property type - only primitives are supported + switch propSchema.Type { + case "string": + return validateElicitStringProperty(propName, propSchema) + case "number", "integer": + return validateElicitNumberProperty(propName, propSchema) + case "boolean": + return validateElicitBooleanProperty(propName, propSchema) + default: + return fmt.Errorf("elicit schema property %q has unsupported type %q, only string, number, integer, and boolean are allowed", propName, propSchema.Type) + } +} + +// validateElicitStringProperty validates string-type properties, including enums. 
+func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema) error { + // Handle enum validation (enums are a special case of strings) + if len(propSchema.Enum) > 0 { + // Enums must be string type (or untyped which defaults to string) + if propSchema.Type != "" && propSchema.Type != "string" { + return fmt.Errorf("elicit schema property %q has enum values but type is %q, enums are only supported for string type", propName, propSchema.Type) + } + // Enum values themselves are validated by the JSON schema library + // Validate enumNames if present - must match enum length + if propSchema.Extra != nil { + if enumNamesRaw, exists := propSchema.Extra["enumNames"]; exists { + // Type check enumNames - should be a slice + if enumNamesSlice, ok := enumNamesRaw.([]any); ok { + if len(enumNamesSlice) != len(propSchema.Enum) { + return fmt.Errorf("elicit schema property %q has %d enum values but %d enumNames, they must match", propName, len(propSchema.Enum), len(enumNamesSlice)) + } + } else { + return fmt.Errorf("elicit schema property %q has invalid enumNames type, must be an array", propName) + } + } + } + return nil + } + + // Validate format if specified - only specific formats are allowed + if propSchema.Format != "" { + allowedFormats := map[string]bool{ + "email": true, + "uri": true, + "date": true, + "date-time": true, + } + if !allowedFormats[propSchema.Format] { + return fmt.Errorf("elicit schema property %q has unsupported format %q, only email, uri, date, and date-time are allowed", propName, propSchema.Format) + } + } + + // Validate minLength constraint if specified + if propSchema.MinLength != nil { + if *propSchema.MinLength < 0 { + return fmt.Errorf("elicit schema property %q has invalid minLength %d, must be non-negative", propName, *propSchema.MinLength) + } + } + + // Validate maxLength constraint if specified + if propSchema.MaxLength != nil { + if *propSchema.MaxLength < 0 { + return fmt.Errorf("elicit schema property %q has 
invalid maxLength %d, must be non-negative", propName, *propSchema.MaxLength) + } + // Check that maxLength >= minLength if both are specified + if propSchema.MinLength != nil && *propSchema.MaxLength < *propSchema.MinLength { + return fmt.Errorf("elicit schema property %q has maxLength %d less than minLength %d", propName, *propSchema.MaxLength, *propSchema.MinLength) + } + } + + return validateDefaultProperty[string](propName, propSchema) +} + +// validateElicitNumberProperty validates number and integer-type properties. +func validateElicitNumberProperty(propName string, propSchema *jsonschema.Schema) error { + if propSchema.Minimum != nil && propSchema.Maximum != nil { + if *propSchema.Maximum < *propSchema.Minimum { + return fmt.Errorf("elicit schema property %q has maximum %g less than minimum %g", propName, *propSchema.Maximum, *propSchema.Minimum) + } + } + + intDefaultError := validateDefaultProperty[int](propName, propSchema) + floatDefaultError := validateDefaultProperty[float64](propName, propSchema) + if intDefaultError != nil && floatDefaultError != nil { + return fmt.Errorf("elicit schema property %q has default value that cannot be interpreted as an int or float", propName) + } + + return nil +} + +// validateElicitBooleanProperty validates boolean-type properties. +func validateElicitBooleanProperty(propName string, propSchema *jsonschema.Schema) error { + return validateDefaultProperty[bool](propName, propSchema) +} + +func validateDefaultProperty[T any](propName string, propSchema *jsonschema.Schema) error { + // Validate default value if specified - must be a valid T + if propSchema.Default != nil { + var defaultValue T + if err := json.Unmarshal(propSchema.Default, &defaultValue); err != nil { + return fmt.Errorf("elicit schema property %q has invalid default value, must be a %T: %v", propName, defaultValue, err) + } + } + return nil +} + +// AddSendingMiddleware wraps the current sending method handler using the provided +// middleware. 
Middleware is applied from right to left, so that the first one is +// executed first. +// +// For example, AddSendingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Sending middleware is called when a request is sent. It is useful for tasks +// such as tracing, metrics, and adding progress tokens. +func (c *Client) AddSendingMiddleware(middleware ...Middleware) { + c.mu.Lock() + defer c.mu.Unlock() + addMiddleware(&c.sendingMethodHandler_, middleware) +} + +// AddReceivingMiddleware wraps the current receiving method handler using +// the provided middleware. Middleware is applied from right to left, so that the +// first one is executed first. +// +// For example, AddReceivingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Receiving middleware is called when a request is received. It is useful for tasks +// such as authentication, request logging and metrics. +func (c *Client) AddReceivingMiddleware(middleware ...Middleware) { + c.mu.Lock() + defer c.mu.Unlock() + addMiddleware(&c.receivingMethodHandler_, middleware) +} + +// clientMethodInfos maps from the RPC method name to serverMethodInfos. +// +// The 'allowMissingParams' values are extracted from the protocol schema. +// TODO(rfindley): actually load and validate the protocol schema, rather than +// curating these method flags. 
+var clientMethodInfos = map[string]methodInfo{ + methodComplete: newClientMethodInfo(clientSessionMethod((*ClientSession).Complete), 0), + methodPing: newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK), + methodListRoots: newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK), + methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0), + methodElicit: newClientMethodInfo(clientMethod((*Client).elicit), missingParamsOK), + notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK), + notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), + notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), + notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK), + notificationResourceUpdated: newClientMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), notification|missingParamsOK), + notificationLoggingMessage: newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification), + notificationProgress: newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification), + notificationElicitationComplete: newClientMethodInfo(clientMethod((*Client).callElicitationCompleteHandler), notification|missingParamsOK), +} + +func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { + return serverMethodInfos +} + +func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { + return clientMethodInfos +} + +func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { + if req.IsCall() { + jsonrpc2.Async(ctx) + } + return handleReceive(ctx, cs, req) +} + +func (cs *ClientSession) sendingMethodHandler() MethodHandler { + 
cs.client.mu.Lock()
+ defer cs.client.mu.Unlock()
+ return cs.client.sendingMethodHandler_
+}
+
+func (cs *ClientSession) receivingMethodHandler() MethodHandler {
+ cs.client.mu.Lock()
+ defer cs.client.mu.Unlock()
+ return cs.client.receivingMethodHandler_
+}
+
+// getConn implements [Session.getConn].
+func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn }
+
+func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) {
+ return &emptyResult{}, nil
+}
+
+// cancel is a placeholder: cancellation is handled by the jsonrpc2 package.
+//
+// It should never be invoked in practice because cancellation is preempted,
+// but having its signature here facilitates the construction of methodInfo
+// that can be used to validate incoming cancellation notifications.
+func (*ClientSession) cancel(context.Context, *CancelledParams) (Result, error) {
+ return nil, nil
+}
+
+func newClientRequest[P Params](cs *ClientSession, params P) *ClientRequest[P] {
+ return &ClientRequest[P]{Session: cs, Params: params}
+}
+
+// Ping makes an MCP "ping" request to the server.
+func (cs *ClientSession) Ping(ctx context.Context, params *PingParams) error {
+ _, err := handleSend[*emptyResult](ctx, methodPing, newClientRequest(cs, orZero[Params](params)))
+ return err
+}
+
+// ListPrompts lists prompts that are currently available on the server.
+func (cs *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) {
+ return handleSend[*ListPromptsResult](ctx, methodListPrompts, newClientRequest(cs, orZero[Params](params)))
+}
+
+// GetPrompt gets a prompt from the server.
+func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) {
+ return handleSend[*GetPromptResult](ctx, methodGetPrompt, newClientRequest(cs, orZero[Params](params)))
+}
+
+// ListTools lists tools that are currently available on the server. 
+func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { + return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) +} + +// CallTool calls the tool with the given parameters. +// +// The params.Arguments can be any value that marshals into a JSON object. +func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) { + if params == nil { + params = new(CallToolParams) + } + if params.Arguments == nil { + // Avoid sending nil over the wire. + params.Arguments = map[string]any{} + } + return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) +} + +func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLoggingLevelParams) error { + _, err := handleSend[*emptyResult](ctx, methodSetLevel, newClientRequest(cs, orZero[Params](params))) + return err +} + +// ListResources lists the resources that are currently available on the server. +func (cs *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { + return handleSend[*ListResourcesResult](ctx, methodListResources, newClientRequest(cs, orZero[Params](params))) +} + +// ListResourceTemplates lists the resource templates that are currently available on the server. +func (cs *ClientSession) ListResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) { + return handleSend[*ListResourceTemplatesResult](ctx, methodListResourceTemplates, newClientRequest(cs, orZero[Params](params))) +} + +// ReadResource asks the server to read a resource and return its contents. 
+func (cs *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { + return handleSend[*ReadResourceResult](ctx, methodReadResource, newClientRequest(cs, orZero[Params](params))) +} + +func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) (*CompleteResult, error) { + return handleSend[*CompleteResult](ctx, methodComplete, newClientRequest(cs, orZero[Params](params))) +} + +// Subscribe sends a "resources/subscribe" request to the server, asking for +// notifications when the specified resource changes. +func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, methodSubscribe, newClientRequest(cs, orZero[Params](params))) + return err +} + +// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling +// a previous subscription. +func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, methodUnsubscribe, newClientRequest(cs, orZero[Params](params))) + return err +} + +func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) { + if h := c.opts.ToolListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (c *Client) callPromptChangedHandler(ctx context.Context, req *PromptListChangedRequest) (Result, error) { + if h := c.opts.PromptListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (c *Client) callResourceChangedHandler(ctx context.Context, req *ResourceListChangedRequest) (Result, error) { + if h := c.opts.ResourceListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ResourceUpdatedNotificationRequest) (Result, error) { + if h := c.opts.ResourceUpdatedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (c *Client) 
callLoggingHandler(ctx context.Context, req *LoggingMessageRequest) (Result, error) { + if h := c.opts.LoggingMessageHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, params *ProgressNotificationParams) (Result, error) { + if h := cs.client.opts.ProgressNotificationHandler; h != nil { + h(ctx, clientRequestFor(cs, params)) + } + return nil, nil +} + +func (c *Client) callElicitationCompleteHandler(ctx context.Context, req *ElicitationCompleteNotificationRequest) (Result, error) { + // Check if there's a pending elicitation waiting for this notification. + if cs, ok := req.GetSession().(*ClientSession); ok { + cs.pendingElicitationsMu.Lock() + if ch, exists := cs.pendingElicitations[req.Params.ElicitationID]; exists { + select { + case ch <- struct{}{}: + default: + // Channel already signaled. + } + } + cs.pendingElicitationsMu.Unlock() + } + + // Call the user's handler if provided. + if h := c.opts.ElicitationCompleteHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +// NotifyProgress sends a progress notification from the client to the server +// associated with this session. +// This can be used if the client is performing a long-running task that was +// initiated by the server. +func (cs *ClientSession) NotifyProgress(ctx context.Context, params *ProgressNotificationParams) error { + return handleNotify(ctx, notificationProgress, newClientRequest(cs, orZero[Params](params))) +} + +// Tools provides an iterator for all tools available on the server, +// automatically fetching pages and managing cursors. +// The params argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. 
+func (cs *ClientSession) Tools(ctx context.Context, params *ListToolsParams) iter.Seq2[*Tool, error] { + if params == nil { + params = &ListToolsParams{} + } + return paginate(ctx, params, cs.ListTools, func(res *ListToolsResult) []*Tool { + return res.Tools + }) +} + +// Resources provides an iterator for all resources available on the server, +// automatically fetching pages and managing cursors. +// The params argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. +func (cs *ClientSession) Resources(ctx context.Context, params *ListResourcesParams) iter.Seq2[*Resource, error] { + if params == nil { + params = &ListResourcesParams{} + } + return paginate(ctx, params, cs.ListResources, func(res *ListResourcesResult) []*Resource { + return res.Resources + }) +} + +// ResourceTemplates provides an iterator for all resource templates available on the server, +// automatically fetching pages and managing cursors. +// The params argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. +func (cs *ClientSession) ResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) iter.Seq2[*ResourceTemplate, error] { + if params == nil { + params = &ListResourceTemplatesParams{} + } + return paginate(ctx, params, cs.ListResourceTemplates, func(res *ListResourceTemplatesResult) []*ResourceTemplate { + return res.ResourceTemplates + }) +} + +// Prompts provides an iterator for all prompts available on the server, +// automatically fetching pages and managing cursors. +// The params argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. 
+func (cs *ClientSession) Prompts(ctx context.Context, params *ListPromptsParams) iter.Seq2[*Prompt, error] { + if params == nil { + params = &ListPromptsParams{} + } + return paginate(ctx, params, cs.ListPrompts, func(res *ListPromptsResult) []*Prompt { + return res.Prompts + }) +} + +// paginate is a generic helper function to provide a paginated iterator. +func paginate[P listParams, R listResult[T], T any](ctx context.Context, params P, listFunc func(context.Context, P) (R, error), items func(R) []*T) iter.Seq2[*T, error] { + return func(yield func(*T, error) bool) { + for { + res, err := listFunc(ctx, params) + if err != nil { + yield(nil, err) + return + } + for _, r := range items(res) { + if !yield(r, nil) { + return + } + } + nextCursorVal := res.nextCursorPtr() + if nextCursorVal == nil || *nextCursorVal == "" { + return + } + *params.cursorPtr() = *nextCursorVal + } + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/cmd.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/cmd.go new file mode 100644 index 000000000..b531eaf13 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/cmd.go @@ -0,0 +1,108 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "io" + "os/exec" + "syscall" + "time" +) + +var defaultTerminateDuration = 5 * time.Second // mutable for testing + +// A CommandTransport is a [Transport] that runs a command and communicates +// with it over stdin/stdout, using newline-delimited JSON. +type CommandTransport struct { + Command *exec.Cmd + // TerminateDuration controls how long Close waits after closing stdin + // for the process to exit before sending SIGTERM. + // If zero or negative, the default of 5s is used. + TerminateDuration time.Duration +} + +// Connect starts the command, and connects to it over stdin/stdout. 
+func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) { + stdout, err := t.Command.StdoutPipe() + if err != nil { + return nil, err + } + stdout = io.NopCloser(stdout) // close the connection by closing stdin, not stdout + stdin, err := t.Command.StdinPipe() + if err != nil { + return nil, err + } + if err := t.Command.Start(); err != nil { + return nil, err + } + td := t.TerminateDuration + if td <= 0 { + td = defaultTerminateDuration + } + return newIOConn(&pipeRWC{t.Command, stdout, stdin, td}), nil +} + +// A pipeRWC is an io.ReadWriteCloser that communicates with a subprocess over +// stdin/stdout pipes. +type pipeRWC struct { + cmd *exec.Cmd + stdout io.ReadCloser + stdin io.WriteCloser + terminateDuration time.Duration +} + +func (s *pipeRWC) Read(p []byte) (n int, err error) { + return s.stdout.Read(p) +} + +func (s *pipeRWC) Write(p []byte) (n int, err error) { + return s.stdin.Write(p) +} + +// Close closes the input stream to the child process, and awaits normal +// termination of the command. If the command does not exit, it is signalled to +// terminate, and then eventually killed. +func (s *pipeRWC) Close() error { + // Spec: + // "For the stdio transport, the client SHOULD initiate shutdown by:... + + // "...First, closing the input stream to the child process (the server)" + if err := s.stdin.Close(); err != nil { + return fmt.Errorf("closing stdin: %v", err) + } + resChan := make(chan error, 1) + go func() { + resChan <- s.cmd.Wait() + }() + // "...Waiting for the server to exit, or sending SIGTERM if the server does not exit within a reasonable time" + wait := func() (error, bool) { + select { + case err := <-resChan: + return err, true + case <-time.After(s.terminateDuration): + } + return nil, false + } + if err, ok := wait(); ok { + return err + } + // Note the condition here: if sending SIGTERM fails, don't wait and just + // move on to SIGKILL. 
+ if err := s.cmd.Process.Signal(syscall.SIGTERM); err == nil { + if err, ok := wait(); ok { + return err + } + } + // "...Sending SIGKILL if the server does not exit within a reasonable time after SIGTERM" + if err := s.cmd.Process.Kill(); err != nil { + return err + } + if err, ok := wait(); ok { + return err + } + return fmt.Errorf("unresponsive subprocess") +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go new file mode 100644 index 000000000..fb1a0d1e5 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go @@ -0,0 +1,289 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// TODO(findleyr): update JSON marshalling of all content types to preserve required fields. +// (See [TextContent.MarshalJSON], which handles this for text content). + +package mcp + +import ( + "encoding/json" + "errors" + "fmt" +) + +// A Content is a [TextContent], [ImageContent], [AudioContent], +// [ResourceLink], or [EmbeddedResource]. +type Content interface { + MarshalJSON() ([]byte, error) + fromWire(*wireContent) +} + +// TextContent is a textual content. +type TextContent struct { + Text string + Meta Meta + Annotations *Annotations +} + +func (c *TextContent) MarshalJSON() ([]byte, error) { + // Custom wire format to ensure the required "text" field is always included, even when empty. 
+ wire := struct { + Type string `json:"type"` + Text string `json:"text"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` + }{ + Type: "text", + Text: c.Text, + Meta: c.Meta, + Annotations: c.Annotations, + } + return json.Marshal(wire) +} + +func (c *TextContent) fromWire(wire *wireContent) { + c.Text = wire.Text + c.Meta = wire.Meta + c.Annotations = wire.Annotations +} + +// ImageContent contains base64-encoded image data. +type ImageContent struct { + Meta Meta + Annotations *Annotations + Data []byte // base64-encoded + MIMEType string +} + +func (c *ImageContent) MarshalJSON() ([]byte, error) { + // Custom wire format to ensure required fields are always included, even when empty. + data := c.Data + if data == nil { + data = []byte{} + } + wire := imageAudioWire{ + Type: "image", + MIMEType: c.MIMEType, + Data: data, + Meta: c.Meta, + Annotations: c.Annotations, + } + return json.Marshal(wire) +} + +func (c *ImageContent) fromWire(wire *wireContent) { + c.MIMEType = wire.MIMEType + c.Data = wire.Data + c.Meta = wire.Meta + c.Annotations = wire.Annotations +} + +// AudioContent contains base64-encoded audio data. +type AudioContent struct { + Data []byte + MIMEType string + Meta Meta + Annotations *Annotations +} + +func (c AudioContent) MarshalJSON() ([]byte, error) { + // Custom wire format to ensure required fields are always included, even when empty. + data := c.Data + if data == nil { + data = []byte{} + } + wire := imageAudioWire{ + Type: "audio", + MIMEType: c.MIMEType, + Data: data, + Meta: c.Meta, + Annotations: c.Annotations, + } + return json.Marshal(wire) +} + +func (c *AudioContent) fromWire(wire *wireContent) { + c.MIMEType = wire.MIMEType + c.Data = wire.Data + c.Meta = wire.Meta + c.Annotations = wire.Annotations +} + +// Custom wire format to ensure required fields are always included, even when empty. 
+type imageAudioWire struct { + Type string `json:"type"` + MIMEType string `json:"mimeType"` + Data []byte `json:"data"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` +} + +// ResourceLink is a link to a resource +type ResourceLink struct { + URI string + Name string + Title string + Description string + MIMEType string + Size *int64 + Meta Meta + Annotations *Annotations + // Icons for the resource link, if any. + Icons []Icon `json:"icons,omitempty"` +} + +func (c *ResourceLink) MarshalJSON() ([]byte, error) { + return json.Marshal(&wireContent{ + Type: "resource_link", + URI: c.URI, + Name: c.Name, + Title: c.Title, + Description: c.Description, + MIMEType: c.MIMEType, + Size: c.Size, + Meta: c.Meta, + Annotations: c.Annotations, + Icons: c.Icons, + }) +} + +func (c *ResourceLink) fromWire(wire *wireContent) { + c.URI = wire.URI + c.Name = wire.Name + c.Title = wire.Title + c.Description = wire.Description + c.MIMEType = wire.MIMEType + c.Size = wire.Size + c.Meta = wire.Meta + c.Annotations = wire.Annotations + c.Icons = wire.Icons +} + +// EmbeddedResource contains embedded resources. +type EmbeddedResource struct { + Resource *ResourceContents + Meta Meta + Annotations *Annotations +} + +func (c *EmbeddedResource) MarshalJSON() ([]byte, error) { + return json.Marshal(&wireContent{ + Type: "resource", + Resource: c.Resource, + Meta: c.Meta, + Annotations: c.Annotations, + }) +} + +func (c *EmbeddedResource) fromWire(wire *wireContent) { + c.Resource = wire.Resource + c.Meta = wire.Meta + c.Annotations = wire.Annotations +} + +// ResourceContents contains the contents of a specific resource or +// sub-resource. 
+type ResourceContents struct { + URI string `json:"uri"` + MIMEType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` + Blob []byte `json:"blob,omitempty"` + Meta Meta `json:"_meta,omitempty"` +} + +func (r *ResourceContents) MarshalJSON() ([]byte, error) { + // If we could assume Go 1.24, we could use omitzero for Blob and avoid this method. + if r.URI == "" { + return nil, errors.New("ResourceContents missing URI") + } + if r.Blob == nil { + // Text. Marshal normally. + type wireResourceContents ResourceContents // (lacks MarshalJSON method) + return json.Marshal((wireResourceContents)(*r)) + } + // Blob. + if r.Text != "" { + return nil, errors.New("ResourceContents has non-zero Text and Blob fields") + } + // r.Blob may be the empty slice, so marshal with an alternative definition. + br := struct { + URI string `json:"uri,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Blob []byte `json:"blob"` + Meta Meta `json:"_meta,omitempty"` + }{ + URI: r.URI, + MIMEType: r.MIMEType, + Blob: r.Blob, + Meta: r.Meta, + } + return json.Marshal(br) +} + +// wireContent is the wire format for content. +// It represents the protocol types TextContent, ImageContent, AudioContent, +// ResourceLink, and EmbeddedResource. +// The Type field distinguishes them. In the protocol, each type has a constant +// value for the field. +// At most one of Text, Data, Resource, and URI is non-zero. 
+type wireContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Data []byte `json:"data,omitempty"` + Resource *ResourceContents `json:"resource,omitempty"` + URI string `json:"uri,omitempty"` + Name string `json:"name,omitempty"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Size *int64 `json:"size,omitempty"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` + Icons []Icon `json:"icons,omitempty"` +} + +func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, error) { + var blocks []Content + for _, wire := range wires { + block, err := contentFromWire(wire, allow) + if err != nil { + return nil, err + } + blocks = append(blocks, block) + } + return blocks, nil +} + +func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) { + if wire == nil { + return nil, fmt.Errorf("nil content") + } + if allow != nil && !allow[wire.Type] { + return nil, fmt.Errorf("invalid content type %q", wire.Type) + } + switch wire.Type { + case "text": + v := new(TextContent) + v.fromWire(wire) + return v, nil + case "image": + v := new(ImageContent) + v.fromWire(wire) + return v, nil + case "audio": + v := new(AudioContent) + v.fromWire(wire) + return v, nil + case "resource_link": + v := new(ResourceLink) + v.fromWire(wire) + return v, nil + case "resource": + v := new(EmbeddedResource) + v.fromWire(wire) + return v, nil + } + return nil, fmt.Errorf("internal error: unrecognized content type %s", wire.Type) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go new file mode 100644 index 000000000..5c322c4a3 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go @@ -0,0 +1,429 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. 
+// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file is for SSE events. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events. + +package mcp + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "iter" + "maps" + "net/http" + "slices" + "strings" + "sync" +) + +// If true, MemoryEventStore will do frequent validation to check invariants, slowing it down. +// Enable for debugging. +const validateMemoryEventStore = false + +// An Event is a server-sent event. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields. +type Event struct { + Name string // the "event" field + ID string // the "id" field + Data []byte // the "data" field + Retry string // the "retry" field +} + +// Empty reports whether the Event is empty. +func (e Event) Empty() bool { + return e.Name == "" && e.ID == "" && len(e.Data) == 0 && e.Retry == "" +} + +// writeEvent writes the event to w, and flushes. +func writeEvent(w io.Writer, evt Event) (int, error) { + var b bytes.Buffer + if evt.Name != "" { + fmt.Fprintf(&b, "event: %s\n", evt.Name) + } + if evt.ID != "" { + fmt.Fprintf(&b, "id: %s\n", evt.ID) + } + if evt.Retry != "" { + fmt.Fprintf(&b, "retry: %s\n", evt.Retry) + } + fmt.Fprintf(&b, "data: %s\n\n", string(evt.Data)) + n, err := w.Write(b.Bytes()) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return n, err +} + +// scanEvents iterates SSE events in the given scanner. The iterated error is +// terminal: if encountered, the stream is corrupt or broken and should no +// longer be used. +// +// TODO(rfindley): consider a different API here that makes failure modes more +// apparent. 
+func scanEvents(r io.Reader) iter.Seq2[Event, error] { + scanner := bufio.NewScanner(r) + const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size + scanner.Buffer(nil, maxTokenSize) + + // TODO: investigate proper behavior when events are out of order, or have + // non-standard names. + var ( + eventKey = []byte("event") + idKey = []byte("id") + dataKey = []byte("data") + retryKey = []byte("retry") + ) + + return func(yield func(Event, error) bool) { + // iterate event from the wire. + // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#examples + // + // - `key: value` line records. + // - Consecutive `data: ...` fields are joined with newlines. + // - Unrecognized fields are ignored. Since we only care about 'event', 'id', and + // 'data', these are the only three we consider. + // - Lines starting with ":" are ignored. + // - Records are terminated with two consecutive newlines. + var ( + evt Event + dataBuf *bytes.Buffer // if non-nil, preceding field was also data + ) + flushData := func() { + if dataBuf != nil { + evt.Data = dataBuf.Bytes() + dataBuf = nil + } + } + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + flushData() + // \n\n is the record delimiter + if !evt.Empty() && !yield(evt, nil) { + return + } + evt = Event{} + continue + } + before, after, found := bytes.Cut(line, []byte{':'}) + if !found { + yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line))) + return + } + if !bytes.Equal(before, dataKey) { + flushData() + } + switch { + case bytes.Equal(before, eventKey): + evt.Name = strings.TrimSpace(string(after)) + case bytes.Equal(before, idKey): + evt.ID = strings.TrimSpace(string(after)) + case bytes.Equal(before, retryKey): + evt.Retry = strings.TrimSpace(string(after)) + case bytes.Equal(before, dataKey): + data := bytes.TrimSpace(after) + if dataBuf != nil { + dataBuf.WriteByte('\n') + dataBuf.Write(data) + } else { + dataBuf = new(bytes.Buffer) + 
dataBuf.Write(data) + } + } + } + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) + } + if !yield(Event{}, err) { + return + } + } + flushData() + if !evt.Empty() { + yield(evt, nil) + } + } +} + +// An EventStore tracks data for SSE streams. +// A single EventStore suffices for all sessions, since session IDs are +// globally unique. So one EventStore can be created per process, for +// all Servers in the process. +// Such a store is able to bound resource usage for the entire process. +// +// All of an EventStore's methods must be safe for use by multiple goroutines. +type EventStore interface { + // Open is called when a new stream is created. It may be used to ensure that + // the underlying data structure for the stream is initialized, making it + // ready to store and replay event streams. + Open(_ context.Context, sessionID, streamID string) error + + // Append appends data for an outgoing event to given stream, which is part of the + // given session. + Append(_ context.Context, sessionID, streamID string, data []byte) error + + // After returns an iterator over the data for the given session and stream, beginning + // just after the given index. + // + // Once the iterator yields a non-nil error, it will stop. + // After's iterator must return an error immediately if any data after index was + // dropped; it must not return partial results. + // The stream must have been opened previously (see [EventStore.Open]). + After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] + + // SessionClosed informs the store that the given session is finished, along + // with all of its streams. + // + // A store cannot rely on this method being called for cleanup. It should institute + // additional mechanisms, such as timeouts, to reclaim storage. 
+ SessionClosed(_ context.Context, sessionID string) error + + // There is no StreamClosed method. A server doesn't know when a stream is finished, because + // the client can always send a GET with a Last-Event-ID referring to the stream. +} + +// A dataList is a list of []byte. +// The zero dataList is ready to use. +type dataList struct { + size int // total size of data bytes + first int // the stream index of the first element in data + data [][]byte +} + +func (dl *dataList) appendData(d []byte) { + // Empty data consumes memory but doesn't increment size. However, it should + // be rare. + dl.data = append(dl.data, d) + dl.size += len(d) +} + +// removeFirst removes the first data item in dl, returning the size of the item. +// It panics if dl is empty. +func (dl *dataList) removeFirst() int { + if len(dl.data) == 0 { + panic("empty dataList") + } + r := len(dl.data[0]) + dl.size -= r + dl.data[0] = nil // help GC + dl.data = dl.data[1:] + dl.first++ + return r +} + +// A MemoryEventStore is an [EventStore] backed by memory. +type MemoryEventStore struct { + mu sync.Mutex + maxBytes int // max total size of all data + nBytes int // current total size of all data + store map[string]map[string]*dataList // session ID -> stream ID -> *dataList +} + +// MemoryEventStoreOptions are options for a [MemoryEventStore]. +type MemoryEventStoreOptions struct{} + +// MaxBytes returns the maximum number of bytes that the store will retain before +// purging data. +func (s *MemoryEventStore) MaxBytes() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.maxBytes +} + +// SetMaxBytes sets the maximum number of bytes the store will retain before purging +// data. The argument must not be negative. If it is zero, a suitable default will be used. +// SetMaxBytes can be called at any time. The size of the store will be adjusted +// immediately. 
+func (s *MemoryEventStore) SetMaxBytes(n int) { + s.mu.Lock() + defer s.mu.Unlock() + switch { + case n < 0: + panic("negative argument") + case n == 0: + s.maxBytes = defaultMaxBytes + default: + s.maxBytes = n + } + s.purge() +} + +const defaultMaxBytes = 10 << 20 // 10 MiB + +// NewMemoryEventStore creates a [MemoryEventStore] with the default value +// for MaxBytes. +func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { + return &MemoryEventStore{ + maxBytes: defaultMaxBytes, + store: make(map[string]map[string]*dataList), + } +} + +// Open implements [EventStore.Open]. It ensures that the underlying data +// structures for the given session are initialized and ready for use. +func (s *MemoryEventStore) Open(_ context.Context, sessionID, streamID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.init(sessionID, streamID) + return nil +} + +// init is an internal helper function that ensures the nested map structure for a +// given sessionID and streamID exists, creating it if necessary. It returns the +// dataList associated with the specified IDs. +// Requires s.mu. +func (s *MemoryEventStore) init(sessionID, streamID string) *dataList { + streamMap, ok := s.store[sessionID] + if !ok { + streamMap = make(map[string]*dataList) + s.store[sessionID] = streamMap + } + dl, ok := streamMap[streamID] + if !ok { + dl = &dataList{} + streamMap[streamID] = dl + } + return dl +} + +// Append implements [EventStore.Append] by recording data in memory. +func (s *MemoryEventStore) Append(_ context.Context, sessionID, streamID string, data []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + dl := s.init(sessionID, streamID) + // Purge before adding, so at least the current data item will be present. + // (That could result in nBytes > maxBytes, but we'll live with that.) 
+ s.purge() + dl.appendData(data) + s.nBytes += len(data) + return nil +} + +// ErrEventsPurged is the error that [EventStore.After] should return if the event just after the +// index is no longer available. +var ErrEventsPurged = errors.New("data purged") + +// After implements [EventStore.After]. +func (s *MemoryEventStore) After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] { + // Return the data items to yield. + // We must copy, because dataList.removeFirst nils out slice elements. + copyData := func() ([][]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + streamMap, ok := s.store[sessionID] + if !ok { + return nil, fmt.Errorf("MemoryEventStore.After: unknown session ID %q", sessionID) + } + dl, ok := streamMap[streamID] + if !ok { + return nil, fmt.Errorf("MemoryEventStore.After: unknown stream ID %v in session %q", streamID, sessionID) + } + start := index + 1 + if dl.first > start { + return nil, fmt.Errorf("MemoryEventStore.After: index %d, stream ID %v, session %q: %w", + index, streamID, sessionID, ErrEventsPurged) + } + return slices.Clone(dl.data[start-dl.first:]), nil + } + + return func(yield func([]byte, error) bool) { + ds, err := copyData() + if err != nil { + yield(nil, err) + return + } + for _, d := range ds { + if !yield(d, nil) { + return + } + } + } +} + +// SessionClosed implements [EventStore.SessionClosed]. +func (s *MemoryEventStore) SessionClosed(_ context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + for _, dl := range s.store[sessionID] { + s.nBytes -= dl.size + } + delete(s.store, sessionID) + s.validate() + return nil +} + +// purge removes data until no more than s.maxBytes bytes are in use. +// It must be called with s.mu held. +func (s *MemoryEventStore) purge() { + // Remove the first element of every dataList until below the max. 
+ for s.nBytes > s.maxBytes { + changed := false + for _, sm := range s.store { + for _, dl := range sm { + if dl.size > 0 { + r := dl.removeFirst() + if r > 0 { + changed = true + s.nBytes -= r + } + } + } + } + if !changed { + panic("no progress during purge") + } + } + s.validate() +} + +// validate checks that the store's data structures are valid. +// It must be called with s.mu held. +func (s *MemoryEventStore) validate() { + if !validateMemoryEventStore { + return + } + // Check that we're accounting for the size correctly. + n := 0 + for _, sm := range s.store { + for _, dl := range sm { + for _, d := range dl.data { + n += len(d) + } + } + } + if n != s.nBytes { + panic("sizes don't add up") + } +} + +// debugString returns a string containing the state of s. +// Used in tests. +func (s *MemoryEventStore) debugString() string { + s.mu.Lock() + defer s.mu.Unlock() + var b strings.Builder + for i, sess := range slices.Sorted(maps.Keys(s.store)) { + if i > 0 { + fmt.Fprintf(&b, "; ") + } + sm := s.store[sess] + for i, sid := range slices.Sorted(maps.Keys(sm)) { + if i > 0 { + fmt.Fprintf(&b, "; ") + } + dl := sm[sid] + fmt.Fprintf(&b, "%s %s first=%d", sess, sid, dl.first) + for _, d := range dl.data { + fmt.Fprintf(&b, " %s", d) + } + } + } + return b.String() +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/features.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/features.go new file mode 100644 index 000000000..438370fe5 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/features.go @@ -0,0 +1,114 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "iter" + "maps" + "slices" +) + +// This file contains implementations that are common to all features. +// A feature is an item provided to a peer. 
In the 2025-03-26 spec, +// the features are prompt, tool, resource and root. + +// A featureSet is a collection of features of type T. +// Every feature has a unique ID, and the spec never mentions +// an ordering for the List calls, so what it calls a "list" is actually a set. +// +// An alternative implementation would use an ordered map, but that's probably +// not necessary as adds and removes are rare, and usually batched. +type featureSet[T any] struct { + uniqueID func(T) string + features map[string]T + sortedKeys []string // lazily computed; nil after add or remove +} + +// newFeatureSet creates a new featureSet for features of type T. +// The argument function should return the unique ID for a single feature. +func newFeatureSet[T any](uniqueIDFunc func(T) string) *featureSet[T] { + return &featureSet[T]{ + uniqueID: uniqueIDFunc, + features: make(map[string]T), + } +} + +// add adds each feature to the set if it is not present, +// or replaces an existing feature. +func (s *featureSet[T]) add(fs ...T) { + for _, f := range fs { + s.features[s.uniqueID(f)] = f + } + s.sortedKeys = nil +} + +// remove removes all features with the given uids from the set if present, +// and returns whether any were removed. +// It is not an error to remove a nonexistent feature. +func (s *featureSet[T]) remove(uids ...string) bool { + changed := false + for _, uid := range uids { + if _, ok := s.features[uid]; ok { + changed = true + delete(s.features, uid) + } + } + if changed { + s.sortedKeys = nil + } + return changed +} + +// get returns the feature with the given uid. +// If there is none, it returns zero, false. +func (s *featureSet[T]) get(uid string) (T, bool) { + t, ok := s.features[uid] + return t, ok +} + +// len returns the number of features in the set. +func (s *featureSet[T]) len() int { return len(s.features) } + +// all returns an iterator over of all the features in the set +// sorted by unique ID. 
+func (s *featureSet[T]) all() iter.Seq[T] { + s.sortKeys() + return func(yield func(T) bool) { + s.yieldFrom(0, yield) + } +} + +// above returns an iterator over features in the set whose unique IDs are +// greater than `uid`, in ascending ID order. +func (s *featureSet[T]) above(uid string) iter.Seq[T] { + s.sortKeys() + index, found := slices.BinarySearch(s.sortedKeys, uid) + if found { + index++ + } + return func(yield func(T) bool) { + s.yieldFrom(index, yield) + } +} + +// sortKeys is a helper that maintains a sorted list of feature IDs. It +// computes this list lazily upon its first call after a modification, or +// if it's nil. +func (s *featureSet[T]) sortKeys() { + if s.sortedKeys != nil { + return + } + s.sortedKeys = slices.Sorted(maps.Keys(s.features)) +} + +// yieldFrom is a helper that iterates over the features in the set, +// starting at the given index, and calls the yield function for each one. +func (s *featureSet[T]) yieldFrom(index int, yield func(T) bool) { + for i := index; i < len(s.sortedKeys); i++ { + if !yield(s.features[s.sortedKeys[i]]) { + return + } + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go new file mode 100644 index 000000000..208427e22 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go @@ -0,0 +1,207 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "log/slog" + "sync" + "time" +) + +// Logging levels. 
+const ( + LevelDebug = slog.LevelDebug + LevelInfo = slog.LevelInfo + LevelNotice = (slog.LevelInfo + slog.LevelWarn) / 2 + LevelWarning = slog.LevelWarn + LevelError = slog.LevelError + LevelCritical = slog.LevelError + 4 + LevelAlert = slog.LevelError + 8 + LevelEmergency = slog.LevelError + 12 +) + +var slogToMCP = map[slog.Level]LoggingLevel{ + LevelDebug: "debug", + LevelInfo: "info", + LevelNotice: "notice", + LevelWarning: "warning", + LevelError: "error", + LevelCritical: "critical", + LevelAlert: "alert", + LevelEmergency: "emergency", +} + +var mcpToSlog = make(map[LoggingLevel]slog.Level) + +func init() { + for sl, ml := range slogToMCP { + mcpToSlog[ml] = sl + } +} + +func slogLevelToMCP(sl slog.Level) LoggingLevel { + if ml, ok := slogToMCP[sl]; ok { + return ml + } + return "debug" // for lack of a better idea +} + +func mcpLevelToSlog(ll LoggingLevel) slog.Level { + if sl, ok := mcpToSlog[ll]; ok { + return sl + } + // TODO: is there a better default? + return LevelDebug +} + +// compareLevels behaves like [cmp.Compare] for [LoggingLevel]s. +func compareLevels(l1, l2 LoggingLevel) int { + return cmp.Compare(mcpLevelToSlog(l1), mcpLevelToSlog(l2)) +} + +// LoggingHandlerOptions are options for a LoggingHandler. +type LoggingHandlerOptions struct { + // The value for the "logger" field of logging notifications. + LoggerName string + // Limits the rate at which log messages are sent. + // Excess messages are dropped. + // If zero, there is no rate limiting. + MinInterval time.Duration +} + +// A LoggingHandler is a [slog.Handler] for MCP. +type LoggingHandler struct { + opts LoggingHandlerOptions + ss *ServerSession + // Ensures that the buffer reset is atomic with the write (see Handle). + // A pointer so that clones share the mutex. See + // https://github.com/golang/example/blob/master/slog-handler-guide/README.md#getting-the-mutex-right. 
+ mu *sync.Mutex + lastMessageSent time.Time // for rate-limiting + buf *bytes.Buffer + handler slog.Handler +} + +// discardHandler is a slog.Handler that drops all logs. +// TODO: use slog.DiscardHandler when we require Go 1.24+. +type discardHandler struct{} + +func (discardHandler) Enabled(context.Context, slog.Level) bool { return false } +func (discardHandler) Handle(context.Context, slog.Record) error { return nil } +func (discardHandler) WithAttrs([]slog.Attr) slog.Handler { return discardHandler{} } +func (discardHandler) WithGroup(string) slog.Handler { return discardHandler{} } + +// ensureLogger returns l if non-nil, otherwise a discard logger. +func ensureLogger(l *slog.Logger) *slog.Logger { + if l != nil { + return l + } + return slog.New(discardHandler{}) +} + +// NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a +// [slog.JSONHandler]. +func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingHandler { + var buf bytes.Buffer + jsonHandler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { + // Remove level: it appears in LoggingMessageParams. + if a.Key == slog.LevelKey { + return slog.Attr{} + } + return a + }, + }) + lh := &LoggingHandler{ + ss: ss, + mu: new(sync.Mutex), + buf: &buf, + handler: jsonHandler, + } + if opts != nil { + lh.opts = *opts + } + return lh +} + +// Enabled implements [slog.Handler.Enabled] by comparing level to the [ServerSession]'s level. +func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { + // This is also checked in ServerSession.LoggingMessage, so checking it here + // is just an optimization that skips building the JSON. + h.ss.mu.Lock() + mcpLevel := h.ss.state.LogLevel + h.ss.mu.Unlock() + return level >= mcpLevelToSlog(mcpLevel) +} + +// WithAttrs implements [slog.Handler.WithAttrs]. 
+func (h *LoggingHandler) WithAttrs(as []slog.Attr) slog.Handler { + h2 := *h + h2.handler = h.handler.WithAttrs(as) + return &h2 +} + +// WithGroup implements [slog.Handler.WithGroup]. +func (h *LoggingHandler) WithGroup(name string) slog.Handler { + h2 := *h + h2.handler = h.handler.WithGroup(name) + return &h2 +} + +// Handle implements [slog.Handler.Handle] by writing the Record to a JSONHandler, +// then calling [ServerSession.LoggingMessage] with the result. +func (h *LoggingHandler) Handle(ctx context.Context, r slog.Record) error { + err := h.handle(ctx, r) + // TODO(jba): find a way to surface the error. + // The return value will probably be ignored. + return err +} + +func (h *LoggingHandler) handle(ctx context.Context, r slog.Record) error { + // Observe the rate limit. + // TODO(jba): use golang.org/x/time/rate. (We can't here because it would require adding + // golang.org/x/time to the go.mod file.) + h.mu.Lock() + skip := time.Since(h.lastMessageSent) < h.opts.MinInterval + h.mu.Unlock() + if skip { + return nil + } + + var err error + // Make the buffer reset atomic with the record write. + // We are careful here in the unlikely event that the handler panics. + // We don't want to hold the lock for the entire function, because Notify is + // an I/O operation. + // This can result in out-of-order delivery. + func() { + h.mu.Lock() + defer h.mu.Unlock() + h.buf.Reset() + err = h.handler.Handle(ctx, r) + }() + if err != nil { + return err + } + + h.mu.Lock() + h.lastMessageSent = time.Now() + h.mu.Unlock() + + params := &LoggingMessageParams{ + Logger: h.opts.LoggerName, + Level: slogLevelToMCP(r.Level), + Data: json.RawMessage(h.buf.Bytes()), + } + // We pass the argument context to Notify, even though slog.Handler.Handle's + // documentation says not to. + // In this case logging is a service to clients, not a means for debugging the + // server, so we want to cancel the log message. 
+ return h.ss.Log(ctx, params) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/mcp.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/mcp.go new file mode 100644 index 000000000..56e950b86 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/mcp.go @@ -0,0 +1,88 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The mcp package provides an SDK for writing model context protocol clients +// and servers. +// +// To get started, create either a [Client] or [Server], add features to it +// using `AddXXX` functions, and connect it to a peer using a [Transport]. +// +// For example, to run a simple server on the [StdioTransport]: +// +// server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) +// +// // Using the generic AddTool automatically populates the the input and output +// // schema of the tool. +// type args struct { +// Name string `json:"name" jsonschema:"the person to greet"` +// } +// mcp.AddTool(server, &mcp.Tool{ +// Name: "greet", +// Description: "say hi", +// }, func(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { +// return &mcp.CallToolResult{ +// Content: []mcp.Content{ +// &mcp.TextContent{Text: "Hi " + args.Name}, +// }, +// }, nil, nil +// }) +// +// // Run the server on the stdio transport. 
+// if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { +// log.Printf("Server failed: %v", err) +// } +// +// To connect to this server, use the [CommandTransport]: +// +// client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) +// transport := &mcp.CommandTransport{Command: exec.Command("myserver")} +// session, err := client.Connect(ctx, transport, nil) +// if err != nil { +// log.Fatal(err) +// } +// defer session.Close() +// +// params := &mcp.CallToolParams{ +// Name: "greet", +// Arguments: map[string]any{"name": "you"}, +// } +// res, err := session.CallTool(ctx, params) +// if err != nil { +// log.Fatalf("CallTool failed: %v", err) +// } +// +// # Clients, servers, and sessions +// +// In this SDK, both a [Client] and [Server] may handle many concurrent +// connections. Each time a client or server is connected to a peer using a +// [Transport], it creates a new session (either a [ClientSession] or +// [ServerSession]): +// +// Client Server +// ⇅ (jsonrpc2) ⇅ +// ClientSession ⇄ Client Transport ⇄ Server Transport ⇄ ServerSession +// +// The session types expose an API to interact with its peer. For example, +// [ClientSession.CallTool] or [ServerSession.ListRoots]. +// +// # Adding features +// +// Add MCP servers to your Client or Server using AddXXX methods (for example +// [Client.AddRoot] or [Server.AddPrompt]). If any peers are connected when +// AddXXX is called, they will receive a corresponding change notification +// (for example notifications/roots/list_changed). +// +// Adding tools is special: tools may be bound to ordinary Go functions by +// using the top-level generic [AddTool] function, which allows specifying an +// input and output type. When AddTool is used, the tool's input schema and +// output schema are automatically populated, and inputs are automatically +// validated. As a special case, if the output type is 'any', no output schema +// is generated. 
+// +// func double(_ context.Context, _ *mcp.CallToolRequest, in In) (*mcp.CallToolResult, Out, error) { +// return nil, Out{Answer: 2*in.Number}, nil +// } +// ... +// mcp.AddTool(server, &mcp.Tool{Name: "double"}, double) +package mcp diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/prompt.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/prompt.go new file mode 100644 index 000000000..62f38a36a --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/prompt.go @@ -0,0 +1,17 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" +) + +// A PromptHandler handles a call to prompts/get. +type PromptHandler func(context.Context, *GetPromptRequest) (*GetPromptResult, error) + +type serverPrompt struct { + prompt *Prompt + handler PromptHandler +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go new file mode 100644 index 000000000..26c8982f8 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go @@ -0,0 +1,1357 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +// Protocol types for version 2025-06-18. +// To see the schema changes from the previous version, run: +// +// prefix=https://raw.githubusercontent.com/modelcontextprotocol/modelcontextprotocol/refs/heads/main/schema +// sdiff -l <(curl $prefix/2025-03-26/schema.ts) <(curl $prefix/2025/06-18/schema.ts) + +import ( + "encoding/json" + "fmt" +) + +// Optional annotations for the client. The client can use annotations to inform +// how objects are used or displayed. +type Annotations struct { + // Describes who the intended customer of this object or data is. 
+ // + // It can include multiple entries to indicate content useful for multiple + // audiences (e.g., []Role{"user", "assistant"}). + Audience []Role `json:"audience,omitempty"` + // The moment the resource was last modified, as an ISO 8601 formatted string. + // + // Should be an ISO 8601 formatted string (e.g., "2025-01-12T15:00:58Z"). + // + // Examples: last activity timestamp in an open file, timestamp when the + // resource was attached, etc. + LastModified string `json:"lastModified,omitempty"` + // Describes how important this data is for operating the server. + // + // A value of 1 means "most important," and indicates that the data is + // effectively required, while 0 means "least important," and indicates that the + // data is entirely optional. + Priority float64 `json:"priority,omitempty"` +} + +// CallToolParams is used by clients to call a tool. +type CallToolParams struct { + // Meta is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Name is the name of the tool to call. + Name string `json:"name"` + // Arguments holds the tool arguments. It can hold any value that can be + // marshaled to JSON. + Arguments any `json:"arguments,omitempty"` +} + +// CallToolParamsRaw is passed to tool handlers on the server. Its arguments +// are not yet unmarshaled (hence "raw"), so that the handlers can perform +// unmarshaling themselves. +type CallToolParamsRaw struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Name is the name of the tool being called. + Name string `json:"name"` + // Arguments is the raw arguments received over the wire from the client. It + // is the responsibility of the tool handler to unmarshal and validate the + // Arguments (see [AddTool]). 
+ Arguments json.RawMessage `json:"arguments,omitempty"` +} + +// A CallToolResult is the server's response to a tool call. +// +// The [ToolHandler] and [ToolHandlerFor] handler functions return this result, +// though [ToolHandlerFor] populates much of it automatically as documented at +// each field. +type CallToolResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + + // A list of content objects that represent the unstructured result of the tool + // call. + // + // When using a [ToolHandlerFor] with structured output, if Content is unset + // it will be populated with JSON text content corresponding to the + // structured output value. + Content []Content `json:"content"` + + // StructuredContent is an optional value that represents the structured + // result of the tool call. It must marshal to a JSON object. + // + // When using a [ToolHandlerFor] with structured output, you should not + // populate this field. It will be automatically populated with the typed Out + // value. + StructuredContent any `json:"structuredContent,omitempty"` + + // IsError reports whether the tool call ended in an error. + // + // If not set, this is assumed to be false (the call was successful). + // + // Any errors that originate from the tool should be reported inside the + // Content field, with IsError set to true, not as an MCP protocol-level + // error response. Otherwise, the LLM would not be able to see that an error + // occurred and self-correct. + // + // However, any errors in finding the tool, an error indicating that the + // server does not support tool calls, or any other exceptional conditions, + // should be reported as an MCP error response. + // + // When using a [ToolHandlerFor], this field is automatically set when the + // tool handler returns an error, and the error string is included as text in + // the Content field. 
+ IsError bool `json:"isError,omitempty"` + + // The error passed to setError, if any. + // It is not marshaled, and therefore it is only visible on the server. + // Its only use is in server sending middleware, where it can be accessed + // with getError. + err error +} + +// TODO(#64): consider exposing setError (and getError), by adding an error +// field on CallToolResult. +func (r *CallToolResult) setError(err error) { + r.Content = []Content{&TextContent{Text: err.Error()}} + r.IsError = true + r.err = err +} + +// getError returns the error set with setError, or nil if none. +// This function always returns nil on clients. +func (r *CallToolResult) getError() error { + return r.err +} + +func (*CallToolResult) isResult() {} + +// UnmarshalJSON handles the unmarshalling of content into the Content +// interface. +func (x *CallToolResult) UnmarshalJSON(data []byte) error { + type res CallToolResult // avoid recursion + var wire struct { + res + Content []*wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { + return err + } + *x = CallToolResult(wire.res) + return nil +} + +func (x *CallToolParams) isParams() {} +func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } + +func (x *CallToolParamsRaw) isParams() {} +func (x *CallToolParamsRaw) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParamsRaw) SetProgressToken(t any) { setProgressToken(x, t) } + +type CancelledParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An optional string describing the reason for the cancellation. This may be + // logged or presented to the user. 
+	Reason string `json:"reason,omitempty"`
+	// The ID of the request to cancel.
+	//
+	// This must correspond to the ID of a request previously issued in the same
+	// direction.
+	RequestID any `json:"requestId"`
+}
+
+func (x *CancelledParams) isParams() {}
+func (x *CancelledParams) GetProgressToken() any { return getProgressToken(x) }
+func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) }
+
+// RootCapabilities describes a client's support for roots.
+type RootCapabilities struct {
+	// ListChanged reports whether the client supports notifications for
+	// changes to the roots list.
+	ListChanged bool `json:"listChanged,omitempty"`
+}
+
+// Capabilities a client may support. Known capabilities are defined here, in
+// this schema, but this is not a closed set: any client can define its own,
+// additional capabilities.
+type ClientCapabilities struct {
+
+	// NOTE: any addition to ClientCapabilities must also be reflected in
+	// [ClientCapabilities.clone].
+
+	// Experimental reports non-standard capabilities that the client supports.
+	Experimental map[string]any `json:"experimental,omitempty"`
+	// Roots describes the client's support for roots.
+	//
+	// Deprecated: use RootsV2. As described in #607, Roots should have been a
+	// pointer to a RootCapabilities value. Roots will continue to be
+	// populated, but any new fields will only be added in the RootsV2 field.
+	Roots struct {
+		// ListChanged reports whether the client supports notifications for
+		// changes to the roots list.
+		ListChanged bool `json:"listChanged,omitempty"`
+	} `json:"roots,omitempty"`
+	// RootsV2 is present if the client supports roots. It supersedes the
+	// deprecated Roots field (see #607) and should be used when capabilities
+	// are explicitly configured via [ClientOptions.Capabilities].
+	RootsV2 *RootCapabilities `json:"-"`
+	// Sampling is present if the client supports sampling from an LLM.
+ Sampling *SamplingCapabilities `json:"sampling,omitempty"` + // Elicitation is present if the client supports elicitation from the server. + Elicitation *ElicitationCapabilities `json:"elicitation,omitempty"` +} + +// clone returns a deep copy of the ClientCapabilities. +func (c *ClientCapabilities) clone() *ClientCapabilities { + cp := *c + cp.RootsV2 = shallowClone(c.RootsV2) + cp.Sampling = shallowClone(c.Sampling) + if c.Elicitation != nil { + x := *c.Elicitation + x.Form = shallowClone(c.Elicitation.Form) + x.URL = shallowClone(c.Elicitation.URL) + cp.Elicitation = &x + } + return &cp +} + +// shallowClone returns a shallow clone of *p, or nil if p is nil. +func shallowClone[T any](p *T) *T { + if p == nil { + return nil + } + x := *p + return &x +} + +func (c *ClientCapabilities) toV2() *clientCapabilitiesV2 { + return &clientCapabilitiesV2{ + ClientCapabilities: *c, + Roots: c.RootsV2, + } +} + +// clientCapabilitiesV2 is a version of ClientCapabilities that fixes the bug +// described in #607: Roots should have been a pointer to value type +// RootCapabilities. +type clientCapabilitiesV2 struct { + ClientCapabilities + Roots *RootCapabilities `json:"roots,omitempty"` +} + +func (c *clientCapabilitiesV2) toV1() *ClientCapabilities { + caps := c.ClientCapabilities + caps.RootsV2 = c.Roots + // Sync Roots from RootsV2 for backward compatibility (#607). + if caps.RootsV2 != nil { + caps.Roots = *caps.RootsV2 + } + return &caps +} + +type CompleteParamsArgument struct { + // The name of the argument + Name string `json:"name"` + // The value of the argument to use for completion matching. + Value string `json:"value"` +} + +// CompleteContext represents additional, optional context for completions. +type CompleteContext struct { + // Previously-resolved variables in a URI template or prompt. + Arguments map[string]string `json:"arguments,omitempty"` +} + +// CompleteReference represents a completion reference type (ref/prompt ref/resource). 
+// The Type field determines which other fields are relevant.
+type CompleteReference struct {
+	Type string `json:"type"`
+	// Name is relevant when Type is "ref/prompt".
+	Name string `json:"name,omitempty"`
+	// URI is relevant when Type is "ref/resource".
+	URI string `json:"uri,omitempty"`
+}
+
+func (r *CompleteReference) UnmarshalJSON(data []byte) error {
+	type wireCompleteReference CompleteReference // for naive unmarshaling
+	var r2 wireCompleteReference
+	if err := json.Unmarshal(data, &r2); err != nil {
+		return err
+	}
+	switch r2.Type {
+	case "ref/prompt", "ref/resource":
+		if r2.Type == "ref/prompt" && r2.URI != "" {
+			return fmt.Errorf("reference of type %q must not have a URI set", r2.Type)
+		}
+		if r2.Type == "ref/resource" && r2.Name != "" {
+			return fmt.Errorf("reference of type %q must not have a Name set", r2.Type)
+		}
+	default:
+		return fmt.Errorf("unrecognized reference type %q", r2.Type)
+	}
+	*r = CompleteReference(r2)
+	return nil
+}
+
+func (r *CompleteReference) MarshalJSON() ([]byte, error) {
+	// Validation for marshalling: ensure consistency before converting to JSON.
+	switch r.Type {
+	case "ref/prompt":
+		if r.URI != "" {
+			return nil, fmt.Errorf("reference of type %q must not have a URI set for marshalling", r.Type)
+		}
+	case "ref/resource":
+		if r.Name != "" {
+			return nil, fmt.Errorf("reference of type %q must not have a Name set for marshalling", r.Type)
+		}
+	default:
+		return nil, fmt.Errorf("unrecognized reference type %q for marshalling", r.Type)
+	}
+
+	type wireReference CompleteReference
+	return json.Marshal(wireReference(*r))
+}
+
+type CompleteParams struct {
+	// This property is reserved by the protocol to allow clients and servers to
+	// attach additional metadata to their responses.
+ Meta `json:"_meta,omitempty"` + // The argument's information + Argument CompleteParamsArgument `json:"argument"` + Context *CompleteContext `json:"context,omitempty"` + Ref *CompleteReference `json:"ref"` +} + +func (*CompleteParams) isParams() {} + +type CompletionResultDetails struct { + HasMore bool `json:"hasMore,omitempty"` + Total int `json:"total,omitempty"` + Values []string `json:"values"` +} + +// The server's response to a completion/complete request +type CompleteResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Completion CompletionResultDetails `json:"completion"` +} + +func (*CompleteResult) isResult() {} + +type CreateMessageParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // A request to include context from one or more MCP servers (including the + // caller), to be attached to the prompt. The client may ignore this request. + IncludeContext string `json:"includeContext,omitempty"` + // The maximum number of tokens to sample, as requested by the server. The + // client may choose to sample fewer tokens than requested. + MaxTokens int64 `json:"maxTokens"` + Messages []*SamplingMessage `json:"messages"` + // Optional metadata to pass through to the LLM provider. The format of this + // metadata is provider-specific. + Metadata any `json:"metadata,omitempty"` + // The server's preferences for which model to select. The client may ignore + // these preferences. + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + // An optional system prompt the server wants to use for sampling. The client + // may modify or omit this prompt. 
+ SystemPrompt string `json:"systemPrompt,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} + +func (x *CreateMessageParams) isParams() {} +func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to a sampling/create_message request from the server. +// The client should inform the user before returning the sampled message, to +// allow them to inspect the response (human in the loop) and decide whether to +// allow the server to see it. +type CreateMessageResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Content Content `json:"content"` + // The name of the model that generated the message. + Model string `json:"model"` + Role Role `json:"role"` + // The reason why sampling stopped, if known. + StopReason string `json:"stopReason,omitempty"` +} + +func (*CreateMessageResult) isResult() {} +func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { + type result CreateMessageResult // avoid recursion + var wire struct { + result + Content *wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + return err + } + *r = CreateMessageResult(wire.result) + return nil +} + +type GetPromptParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Arguments to use for templating the prompt. + Arguments map[string]string `json:"arguments,omitempty"` + // The name of the prompt or prompt template. 
+ Name string `json:"name"` +} + +func (x *GetPromptParams) isParams() {} +func (x *GetPromptParams) GetProgressToken() any { return getProgressToken(x) } +func (x *GetPromptParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The server's response to a prompts/get request from the client. +type GetPromptResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An optional description for the prompt. + Description string `json:"description,omitempty"` + Messages []*PromptMessage `json:"messages"` +} + +func (*GetPromptResult) isResult() {} + +// InitializeParams is sent by the client to initialize the session. +type InitializeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Capabilities describes the client's capabilities. + Capabilities *ClientCapabilities `json:"capabilities"` + // ClientInfo provides information about the client. + ClientInfo *Implementation `json:"clientInfo"` + // ProtocolVersion is the latest version of the Model Context Protocol that + // the client supports. + ProtocolVersion string `json:"protocolVersion"` +} + +func (p *InitializeParams) toV2() *initializeParamsV2 { + return &initializeParamsV2{ + InitializeParams: *p, + Capabilities: p.Capabilities.toV2(), + } +} + +// initializeParamsV2 works around the mistake in #607: Capabilities.Roots +// should have been a pointer. 
+type initializeParamsV2 struct { + InitializeParams + Capabilities *clientCapabilitiesV2 `json:"capabilities"` +} + +func (p *initializeParamsV2) toV1() *InitializeParams { + p1 := p.InitializeParams + if p.Capabilities != nil { + p1.Capabilities = p.Capabilities.toV1() + } + return &p1 +} + +func (x *InitializeParams) isParams() {} +func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } +func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// InitializeResult is sent by the server in response to an initialize request +// from the client. +type InitializeResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Capabilities describes the server's capabilities. + Capabilities *ServerCapabilities `json:"capabilities"` + // Instructions describing how to use the server and its features. + // + // This can be used by clients to improve the LLM's understanding of available + // tools, resources, etc. It can be thought of like a "hint" to the model. For + // example, this information may be added to the system prompt. + Instructions string `json:"instructions,omitempty"` + // The version of the Model Context Protocol that the server wants to use. This + // may not match the version that the client requested. If the client cannot + // support this version, it must disconnect. + ProtocolVersion string `json:"protocolVersion"` + ServerInfo *Implementation `json:"serverInfo"` +} + +func (*InitializeResult) isResult() {} + +type InitializedParams struct { + // Meta is reserved by the protocol to allow clients and servers to attach + // additional metadata to their responses. 
+ Meta `json:"_meta,omitempty"` +} + +func (x *InitializedParams) isParams() {} +func (x *InitializedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *InitializedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +type ListPromptsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. + Cursor string `json:"cursor,omitempty"` +} + +func (x *ListPromptsParams) isParams() {} +func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListPromptsParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a prompts/list request from the client. +type ListPromptsResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + Prompts []*Prompt `json:"prompts"` +} + +func (x *ListPromptsResult) isResult() {} +func (x *ListPromptsResult) nextCursorPtr() *string { return &x.NextCursor } + +type ListResourceTemplatesParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. 
+ Cursor string `json:"cursor,omitempty"` +} + +func (x *ListResourceTemplatesParams) isParams() {} +func (x *ListResourceTemplatesParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListResourceTemplatesParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListResourceTemplatesParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a resources/templates/list request from the client. +type ListResourceTemplatesResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + ResourceTemplates []*ResourceTemplate `json:"resourceTemplates"` +} + +func (x *ListResourceTemplatesResult) isResult() {} +func (x *ListResourceTemplatesResult) nextCursorPtr() *string { return &x.NextCursor } + +type ListResourcesParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. + Cursor string `json:"cursor,omitempty"` +} + +func (x *ListResourcesParams) isParams() {} +func (x *ListResourcesParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListResourcesParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a resources/list request from the client. +type ListResourcesResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. 
+ Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + Resources []*Resource `json:"resources"` +} + +func (x *ListResourcesResult) isResult() {} +func (x *ListResourcesResult) nextCursorPtr() *string { return &x.NextCursor } + +type ListRootsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *ListRootsParams) isParams() {} +func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to a roots/list request from the server. This result +// contains an array of Root objects, each representing a root directory or file +// that the server can operate on. +type ListRootsResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Roots []*Root `json:"roots"` +} + +func (*ListRootsResult) isResult() {} + +type ListToolsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. + Cursor string `json:"cursor,omitempty"` +} + +func (x *ListToolsParams) isParams() {} +func (x *ListToolsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a tools/list request from the client. 
+type ListToolsResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + Tools []*Tool `json:"tools"` +} + +func (x *ListToolsResult) isResult() {} +func (x *ListToolsResult) nextCursorPtr() *string { return &x.NextCursor } + +// The severity of a log message. +// +// These map to syslog message severities, as specified in RFC-5424: +// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 +type LoggingLevel string + +type LoggingMessageParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The data to be logged, such as a string message or an object. Any JSON + // serializable type is allowed here. + Data any `json:"data"` + // The severity of this log message. + Level LoggingLevel `json:"level"` + // An optional name of the logger issuing this message. + Logger string `json:"logger,omitempty"` +} + +func (x *LoggingMessageParams) isParams() {} +func (x *LoggingMessageParams) GetProgressToken() any { return getProgressToken(x) } +func (x *LoggingMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// Hints to use for model selection. +// +// Keys not declared here are currently left unspecified by the spec and are up +// to the client to interpret. +type ModelHint struct { + // A hint for a model name. + // + // The client should treat this as a substring of a model name; for example: - + // `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022` - `sonnet` + // should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc. 
- + // `claude` should match any Claude model + // + // The client may also map the string to a different provider's model name or a + // different model family, as long as it fills a similar niche; for example: - + // `gemini-1.5-flash` could match `claude-3-haiku-20240307` + Name string `json:"name,omitempty"` +} + +// The server's preferences for model selection, requested of the client during +// sampling. +// +// Because LLMs can vary along multiple dimensions, choosing the "best" model is +// rarely straightforward. Different models excel in different areas—some are +// faster but less capable, others are more capable but more expensive, and so +// on. This interface allows servers to express their priorities across multiple +// dimensions to help clients make an appropriate selection for their use case. +// +// These preferences are always advisory. The client may ignore them. It is also +// up to the client to decide how to interpret these preferences and how to +// balance them against other considerations. +type ModelPreferences struct { + // How much to prioritize cost when selecting a model. A value of 0 means cost + // is not important, while a value of 1 means cost is the most important factor. + CostPriority float64 `json:"costPriority,omitempty"` + // Optional hints to use for model selection. + // + // If multiple hints are specified, the client must evaluate them in order (such + // that the first match is taken). + // + // The client should prioritize these hints over the numeric priorities, but may + // still use the priorities to select from ambiguous matches. + Hints []*ModelHint `json:"hints,omitempty"` + // How much to prioritize intelligence and capabilities when selecting a model. + // A value of 0 means intelligence is not important, while a value of 1 means + // intelligence is the most important factor. 
+ IntelligencePriority float64 `json:"intelligencePriority,omitempty"` + // How much to prioritize sampling speed (latency) when selecting a model. A + // value of 0 means speed is not important, while a value of 1 means speed is + // the most important factor. + SpeedPriority float64 `json:"speedPriority,omitempty"` +} + +type PingParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *PingParams) isParams() {} +func (x *PingParams) GetProgressToken() any { return getProgressToken(x) } +func (x *PingParams) SetProgressToken(t any) { setProgressToken(x, t) } + +type ProgressNotificationParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The progress token which was given in the initial request, used to associate + // this notification with the request that is proceeding. + ProgressToken any `json:"progressToken"` + // An optional message describing the current progress. + Message string `json:"message,omitempty"` + // The progress thus far. This should increase every time progress is made, even + // if the total is unknown. + Progress float64 `json:"progress"` + // Total number of items to process (or total progress required), if known. + // Zero means unknown. + Total float64 `json:"total,omitempty"` +} + +func (*ProgressNotificationParams) isParams() {} + +// IconTheme specifies the theme an icon is designed for. +type IconTheme string + +const ( + // IconThemeLight indicates the icon is designed for a light background. + IconThemeLight IconTheme = "light" + // IconThemeDark indicates the icon is designed for a dark background. 
+	IconThemeDark IconTheme = "dark"
+)
+
+// Icon provides a visual identifier for resources, tools, prompts, and
+// implementations. See [/specification/draft/basic/index#icons] for notes on
+// icons.
+//
+// TODO(iamsurajbobade): update specification url from draft.
+type Icon struct {
+	// Source is a URI pointing to the icon resource (required). This can be:
+	// - An HTTP/HTTPS URL pointing to an image file
+	// - A data URI with base64-encoded image data
+	Source string `json:"src"`
+	// Optional MIME type if the server's type is missing or generic.
+	MIMEType string `json:"mimeType,omitempty"`
+	// Optional size specification (e.g., ["48x48"], ["any"] for scalable formats like SVG, or ["48x48", "96x96"] for multiple sizes)
+	Sizes []string `json:"sizes,omitempty"`
+	// Optional theme specifier. "light" indicates the icon is designed for a light
+	// background, "dark" indicates the icon is designed for a dark background.
+	Theme IconTheme `json:"theme,omitempty"`
+}
+
+// A prompt or prompt template that the server offers.
+type Prompt struct {
+	// See [specification/2025-06-18/basic/index#general-fields] for notes on _meta
+	// usage.
+	Meta `json:"_meta,omitempty"`
+	// A list of arguments to use for templating the prompt.
+	Arguments []*PromptArgument `json:"arguments,omitempty"`
+	// An optional description of what this prompt provides.
+	Description string `json:"description,omitempty"`
+	// Intended for programmatic or logical use, but used as a display name in past
+	// specs or fallback (if title isn't present).
+	Name string `json:"name"`
+	// Intended for UI and end-user contexts — optimized to be human-readable and
+	// easily understood, even by those unfamiliar with domain-specific terminology.
+	Title string `json:"title,omitempty"`
+	// Icons for the prompt, if any.
+	Icons []Icon `json:"icons,omitempty"`
+}
+
+// Describes an argument that a prompt can accept.
+type PromptArgument struct { + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. + Title string `json:"title,omitempty"` + // A human-readable description of the argument. + Description string `json:"description,omitempty"` + // Whether this argument must be provided. + Required bool `json:"required,omitempty"` +} + +type PromptListChangedParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *PromptListChangedParams) isParams() {} +func (x *PromptListChangedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *PromptListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// Describes a message returned as part of a prompt. +// +// This is similar to SamplingMessage, but also supports the embedding of +// resources from the MCP server. +type PromptMessage struct { + Content Content `json:"content"` + Role Role `json:"role"` +} + +// UnmarshalJSON handles the unmarshalling of content into the Content +// interface. +func (m *PromptMessage) UnmarshalJSON(data []byte) error { + type msg PromptMessage // avoid recursion + var wire struct { + msg + Content *wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.msg.Content, err = contentFromWire(wire.Content, nil); err != nil { + return err + } + *m = PromptMessage(wire.msg) + return nil +} + +type ReadResourceParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. 
+ Meta `json:"_meta,omitempty"` + // The URI of the resource to read. The URI can use any protocol; it is up to + // the server how to interpret it. + URI string `json:"uri"` +} + +func (x *ReadResourceParams) isParams() {} +func (x *ReadResourceParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ReadResourceParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The server's response to a resources/read request from the client. +type ReadResourceResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Contents []*ResourceContents `json:"contents"` +} + +func (*ReadResourceResult) isResult() {} + +// A known resource that the server is capable of reading. +type Resource struct { + // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta + // usage. + Meta `json:"_meta,omitempty"` + // Optional annotations for the client. + Annotations *Annotations `json:"annotations,omitempty"` + // A description of what this resource represents. + // + // This can be used by clients to improve the LLM's understanding of available + // resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // The size of the raw resource content, in bytes (i.e., before base64 encoding + // or any tokenization), if known. + // + // This can be used by Hosts to display file sizes and estimate context window + // usage. + Size int64 `json:"size,omitempty"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. 
+ // + // If not provided, the name should be used for display (except for Tool, where + // Annotations.Title should be given precedence over using name, if + // present). + Title string `json:"title,omitempty"` + // The URI of this resource. + URI string `json:"uri"` + // Icons for the resource, if any. + Icons []Icon `json:"icons,omitempty"` +} + +type ResourceListChangedParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *ResourceListChangedParams) isParams() {} +func (x *ResourceListChangedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ResourceListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// A template description for resources available on the server. +type ResourceTemplate struct { + // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta + // usage. + Meta `json:"_meta,omitempty"` + // Optional annotations for the client. + Annotations *Annotations `json:"annotations,omitempty"` + // A description of what this template is for. + // + // This can be used by clients to improve the LLM's understanding of available + // resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type for all resources that match this template. This should only be + // included if all resources matching this template have the same type. + MIMEType string `json:"mimeType,omitempty"` + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. 
+ // + // If not provided, the name should be used for display (except for Tool, where + // Annotations.Title should be given precedence over using name, if + // present). + Title string `json:"title,omitempty"` + // A URI template (according to RFC 6570) that can be used to construct resource + // URIs. + URITemplate string `json:"uriTemplate"` + // Icons for the resource template, if any. + Icons []Icon `json:"icons,omitempty"` +} + +// The sender or recipient of messages and data in a conversation. +type Role string + +// Represents a root directory or file that the server can operate on. +type Root struct { + // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta + // usage. + Meta `json:"_meta,omitempty"` + // An optional name for the root. This can be used to provide a human-readable + // identifier for the root, which may be useful for display purposes or for + // referencing the root in other parts of the application. + Name string `json:"name,omitempty"` + // The URI identifying the root. This *must* start with file:// for now. This + // restriction may be relaxed in future versions of the protocol to allow other + // URI schemes. + URI string `json:"uri"` +} + +type RootsListChangedParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *RootsListChangedParams) isParams() {} +func (x *RootsListChangedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// TODO: to be consistent with ServerCapabilities, move the capability types +// below directly above ClientCapabilities. + +// SamplingCapabilities describes the client's support for sampling. +type SamplingCapabilities struct{} + +// ElicitationCapabilities describes the capabilities for elicitation. 
+// If neither Form nor URL is set, the 'Form' capability is assumed.
+ Meta `json:"_meta,omitempty"` + // Optional additional tool information. + // + // Display name precedence order is: title, annotations.title, then name. + Annotations *ToolAnnotations `json:"annotations,omitempty"` + // A human-readable description of the tool. + // + // This can be used by clients to improve the LLM's understanding of available + // tools. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // InputSchema holds a JSON Schema object defining the expected parameters + // for the tool. + // + // From the server, this field may be set to any value that JSON-marshals to + // valid JSON schema (including json.RawMessage). However, for tools added + // using [AddTool], which automatically validates inputs and outputs, the + // schema must be in a draft the SDK understands. Currently, the SDK uses + // github.com/google/jsonschema-go for inference and validation, which only + // supports the 2020-12 draft of JSON schema. To do your own validation, use + // [Server.AddTool]. + // + // From the client, this field will hold the default JSON marshaling of the + // server's input schema (a map[string]any). + InputSchema any `json:"inputSchema"` + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // OutputSchema holds an optional JSON Schema object defining the structure + // of the tool's output returned in the StructuredContent field of a + // CallToolResult. + // + // From the server, this field may be set to any value that JSON-marshals to + // valid JSON schema (including json.RawMessage). However, for tools added + // using [AddTool], which automatically validates inputs and outputs, the + // schema must be in a draft the SDK understands. Currently, the SDK uses + // github.com/google/jsonschema-go for inference and validation, which only + // supports the 2020-12 draft of JSON schema. 
To do your own validation, use + // [Server.AddTool]. + // + // From the client, this field will hold the default JSON marshaling of the + // server's output schema (a map[string]any). + OutputSchema any `json:"outputSchema,omitempty"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. + // If not provided, Annotations.Title should be used for display if present, + // otherwise Name. + Title string `json:"title,omitempty"` + // Icons for the tool, if any. + Icons []Icon `json:"icons,omitempty"` +} + +// Additional properties describing a Tool to clients. +// +// NOTE: all properties in ToolAnnotations are hints. They are not +// guaranteed to provide a faithful description of tool behavior (including +// descriptive properties like title). +// +// Clients should never make tool use decisions based on ToolAnnotations +// received from untrusted servers. +type ToolAnnotations struct { + // If true, the tool may perform destructive updates to its environment. If + // false, the tool performs only additive updates. + // + // (This property is meaningful only when ReadOnlyHint == false.) + // + // Default: true + DestructiveHint *bool `json:"destructiveHint,omitempty"` + // If true, calling the tool repeatedly with the same arguments will have no + // additional effect on the its environment. + // + // (This property is meaningful only when ReadOnlyHint == false.) + // + // Default: false + IdempotentHint bool `json:"idempotentHint,omitempty"` + // If true, this tool may interact with an "open world" of external entities. If + // false, the tool's domain of interaction is closed. For example, the world of + // a web search tool is open, whereas that of a memory tool is not. + // + // Default: true + OpenWorldHint *bool `json:"openWorldHint,omitempty"` + // If true, the tool does not modify its environment. 
+ // + // Default: false + ReadOnlyHint bool `json:"readOnlyHint,omitempty"` + // A human-readable title for the tool. + Title string `json:"title,omitempty"` +} + +type ToolListChangedParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *ToolListChangedParams) isParams() {} +func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// Sent from the client to request resources/updated notifications from the +// server whenever a particular resource changes. +type SubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to subscribe to. + URI string `json:"uri"` +} + +func (*SubscribeParams) isParams() {} + +// Sent from the client to request cancellation of resources/updated +// notifications from the server. This should follow a previous +// resources/subscribe request. +type UnsubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` +} + +func (*UnsubscribeParams) isParams() {} + +// A notification from the server to the client, informing it that a resource +// has changed and may need to be read again. This should only be sent if the +// client previously sent a resources/subscribe request. +type ResourceUpdatedNotificationParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource that has been updated. 
This might be a sub-resource of the one that the client actually subscribed to. + URI string `json:"uri"` +} + +func (*ResourceUpdatedNotificationParams) isParams() {} + +// TODO(jba): add CompleteRequest and related types. + +// A request from the server to elicit additional information from the user via the client. +type ElicitParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The mode of elicitation to use. + // + // If unset, will be inferred from the other fields. + Mode string `json:"mode"` + // The message to present to the user. + Message string `json:"message"` + // A JSON schema object defining the requested elicitation schema. + // + // From the server, this field may be set to any value that can JSON-marshal + // to valid JSON schema (including json.RawMessage for raw schema values). + // Internally, the SDK uses github.com/google/jsonschema-go for validation, + // which only supports the 2020-12 draft of the JSON schema spec. + // + // From the client, this field will use the default JSON marshaling (a + // map[string]any). + // + // Only top-level properties are allowed, without nesting. + // + // This is only used for "form" elicitation. + RequestedSchema any `json:"requestedSchema,omitempty"` + // The URL to present to the user. + // + // This is only used for "url" elicitation. + URL string `json:"url,omitempty"` + // The ID of the elicitation. + // + // This is only used for "url" elicitation. + ElicitationID string `json:"elicitationId,omitempty"` +} + +func (x *ElicitParams) isParams() {} + +func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to an elicitation/create request from the server. 
+type ElicitResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The user action in response to the elicitation. + // - "accept": User submitted the form/confirmed the action + // - "decline": User explicitly declined the action + // - "cancel": User dismissed without making an explicit choice + Action string `json:"action"` + // The submitted form data, only present when action is "accept". + // Contains values matching the requested schema. + Content map[string]any `json:"content,omitempty"` +} + +func (*ElicitResult) isResult() {} + +// ElicitationCompleteParams is sent from the server to the client, informing it that an out-of-band elicitation interaction has completed. +type ElicitationCompleteParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The ID of the elicitation that has completed. This must correspond to the + // elicitationId from the original elicitation/create request. + ElicitationID string `json:"elicitationId"` +} + +func (*ElicitationCompleteParams) isParams() {} + +// An Implementation describes the name and version of an MCP implementation, with an optional +// title for UI representation. +type Implementation struct { + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. + Title string `json:"title,omitempty"` + Version string `json:"version"` + // WebsiteURL for the server, if any. + WebsiteURL string `json:"websiteUrl,omitempty"` + // Icons for the Server, if any. 
+ Icons []Icon `json:"icons,omitempty"` +} + +// CompletionCapabilities describes the server's support for argument autocompletion. +type CompletionCapabilities struct{} + +// LoggingCapabilities describes the server's support for sending log messages to the client. +type LoggingCapabilities struct{} + +// PromptCapabilities describes the server's support for prompts. +type PromptCapabilities struct { + // Whether this server supports notifications for changes to the prompt list. + ListChanged bool `json:"listChanged,omitempty"` +} + +// ResourceCapabilities describes the server's support for resources. +type ResourceCapabilities struct { + // ListChanged reports whether the client supports notifications for + // changes to the resource list. + ListChanged bool `json:"listChanged,omitempty"` + // Subscribe reports whether this server supports subscribing to resource + // updates. + Subscribe bool `json:"subscribe,omitempty"` +} + +// ToolCapabilities describes the server's support for tools. +type ToolCapabilities struct { + // ListChanged reports whether the client supports notifications for + // changes to the tool list. + ListChanged bool `json:"listChanged,omitempty"` +} + +// ServerCapabilities describes capabilities that a server supports. +type ServerCapabilities struct { + + // NOTE: any addition to ServerCapabilities must also be reflected in + // [ServerCapabilities.clone]. + + // Experimental reports non-standard capabilities that the server supports. + Experimental map[string]any `json:"experimental,omitempty"` + // Completions is present if the server supports argument autocompletion + // suggestions. + Completions *CompletionCapabilities `json:"completions,omitempty"` + // Logging is present if the server supports log messages. + Logging *LoggingCapabilities `json:"logging,omitempty"` + // Prompts is present if the server supports prompts. 
+ Prompts *PromptCapabilities `json:"prompts,omitempty"` + // Resources is present if the server supports resourcs. + Resources *ResourceCapabilities `json:"resources,omitempty"` + // Tools is present if the supports tools. + Tools *ToolCapabilities `json:"tools,omitempty"` +} + +// clone returns a deep copy of the ServerCapabilities. +func (c *ServerCapabilities) clone() *ServerCapabilities { + cp := *c + cp.Completions = shallowClone(c.Completions) + cp.Logging = shallowClone(c.Logging) + cp.Prompts = shallowClone(c.Prompts) + cp.Resources = shallowClone(c.Resources) + cp.Tools = shallowClone(c.Tools) + return &cp +} + +const ( + methodCallTool = "tools/call" + notificationCancelled = "notifications/cancelled" + methodComplete = "completion/complete" + methodCreateMessage = "sampling/createMessage" + methodElicit = "elicitation/create" + notificationElicitationComplete = "notifications/elicitation/complete" + methodGetPrompt = "prompts/get" + methodInitialize = "initialize" + notificationInitialized = "notifications/initialized" + methodListPrompts = "prompts/list" + methodListResourceTemplates = "resources/templates/list" + methodListResources = "resources/list" + methodListRoots = "roots/list" + methodListTools = "tools/list" + notificationLoggingMessage = "notifications/message" + methodPing = "ping" + notificationProgress = "notifications/progress" + notificationPromptListChanged = "notifications/prompts/list_changed" + methodReadResource = "resources/read" + notificationResourceListChanged = "notifications/resources/list_changed" + notificationResourceUpdated = "notifications/resources/updated" + notificationRootsListChanged = "notifications/roots/list_changed" + methodSetLevel = "logging/setLevel" + methodSubscribe = "resources/subscribe" + notificationToolListChanged = "notifications/tools/list_changed" + methodUnsubscribe = "resources/unsubscribe" +) diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go 
b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go new file mode 100644 index 000000000..f64d6fb62 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go @@ -0,0 +1,38 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file holds the request types. + +package mcp + +type ( + CallToolRequest = ServerRequest[*CallToolParamsRaw] + CompleteRequest = ServerRequest[*CompleteParams] + GetPromptRequest = ServerRequest[*GetPromptParams] + InitializedRequest = ServerRequest[*InitializedParams] + ListPromptsRequest = ServerRequest[*ListPromptsParams] + ListResourcesRequest = ServerRequest[*ListResourcesParams] + ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] + ListToolsRequest = ServerRequest[*ListToolsParams] + ProgressNotificationServerRequest = ServerRequest[*ProgressNotificationParams] + ReadResourceRequest = ServerRequest[*ReadResourceParams] + RootsListChangedRequest = ServerRequest[*RootsListChangedParams] + SubscribeRequest = ServerRequest[*SubscribeParams] + UnsubscribeRequest = ServerRequest[*UnsubscribeParams] +) + +type ( + CreateMessageRequest = ClientRequest[*CreateMessageParams] + ElicitRequest = ClientRequest[*ElicitParams] + initializedClientRequest = ClientRequest[*InitializedParams] + InitializeRequest = ClientRequest[*InitializeParams] + ListRootsRequest = ClientRequest[*ListRootsParams] + LoggingMessageRequest = ClientRequest[*LoggingMessageParams] + ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams] + PromptListChangedRequest = ClientRequest[*PromptListChangedParams] + ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams] + ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams] + ToolListChangedRequest = ClientRequest[*ToolListChangedParams] + 
ElicitationCompleteNotificationRequest = ClientRequest[*ElicitationCompleteParams] +) diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go new file mode 100644 index 000000000..dc657f5dd --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go @@ -0,0 +1,164 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/url" + "os" + "path/filepath" + "strings" + + "github.com/modelcontextprotocol/go-sdk/internal/util" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/yosida95/uritemplate/v3" +) + +// A serverResource associates a Resource with its handler. +type serverResource struct { + resource *Resource + handler ResourceHandler +} + +// A serverResourceTemplate associates a ResourceTemplate with its handler. +type serverResourceTemplate struct { + resourceTemplate *ResourceTemplate + handler ResourceHandler +} + +// A ResourceHandler is a function that reads a resource. +// It will be called when the client calls [ClientSession.ReadResource]. +// If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. +type ResourceHandler func(context.Context, *ReadResourceRequest) (*ReadResourceResult, error) + +// ResourceNotFoundError returns an error indicating that a resource being read could +// not be found. +func ResourceNotFoundError(uri string) error { + return &jsonrpc.Error{ + Code: CodeResourceNotFound, + Message: "Resource not found", + Data: json.RawMessage(fmt.Sprintf(`{"uri":%q}`, uri)), + } +} + +// readFileResource reads from the filesystem at a URI relative to dirFilepath, respecting +// the roots. +// dirFilepath and rootFilepaths are absolute filesystem paths. 
+func readFileResource(rawURI, dirFilepath string, rootFilepaths []string) ([]byte, error) { + uriFilepath, err := computeURIFilepath(rawURI, dirFilepath, rootFilepaths) + if err != nil { + return nil, err + } + + var data []byte + err = withFile(dirFilepath, uriFilepath, func(f *os.File) error { + var err error + data, err = io.ReadAll(f) + return err + }) + if os.IsNotExist(err) { + err = ResourceNotFoundError(rawURI) + } + return data, err +} + +// computeURIFilepath returns a path relative to dirFilepath. +// The dirFilepath and rootFilepaths are absolute file paths. +func computeURIFilepath(rawURI, dirFilepath string, rootFilepaths []string) (string, error) { + // We use "file path" to mean a filesystem path. + uri, err := url.Parse(rawURI) + if err != nil { + return "", err + } + if uri.Scheme != "file" { + return "", fmt.Errorf("URI is not a file: %s", uri) + } + if uri.Path == "" { + // A more specific error than the one below, to catch the + // common mistake "file://foo". + return "", errors.New("empty path") + } + // The URI's path is interpreted relative to dirFilepath, and in the local filesystem. + // It must not try to escape its directory. + uriFilepathRel, err := filepath.Localize(strings.TrimPrefix(uri.Path, "/")) + if err != nil { + return "", fmt.Errorf("%q cannot be localized: %w", uriFilepathRel, err) + } + + // Check roots, if there are any. + if len(rootFilepaths) > 0 { + // To check against the roots, we need an absolute file path, not relative to the directory. + // uriFilepath is local, so the joined path is under dirFilepath. + uriFilepathAbs := filepath.Join(dirFilepath, uriFilepathRel) + rootOK := false + // Check that the requested file path is under some root. + // Since both paths are absolute, that's equivalent to filepath.Rel constructing + // a local path. 
+ for _, rootFilepathAbs := range rootFilepaths { + if rel, err := filepath.Rel(rootFilepathAbs, uriFilepathAbs); err == nil && filepath.IsLocal(rel) { + rootOK = true + break + } + } + if !rootOK { + return "", fmt.Errorf("URI path %q is not under any root", uriFilepathAbs) + } + } + return uriFilepathRel, nil +} + +// fileRoots transforms the Roots obtained from the client into absolute paths on +// the local filesystem. +// TODO(jba): expose this functionality to user ResourceHandlers, +// so they don't have to repeat it. +func fileRoots(rawRoots []*Root) ([]string, error) { + var fileRoots []string + for _, r := range rawRoots { + fr, err := fileRoot(r) + if err != nil { + return nil, err + } + fileRoots = append(fileRoots, fr) + } + return fileRoots, nil +} + +// fileRoot returns the absolute path for Root. +func fileRoot(root *Root) (_ string, err error) { + defer util.Wrapf(&err, "root %q", root.URI) + + // Convert to absolute file path. + rurl, err := url.Parse(root.URI) + if err != nil { + return "", err + } + if rurl.Scheme != "file" { + return "", errors.New("not a file URI") + } + if rurl.Path == "" { + // A more specific error than the one below, to catch the + // common mistake "file://foo". + return "", errors.New("empty path") + } + // We don't want Localize here: we want an absolute path, which is not local. + fileRoot := filepath.Clean(filepath.FromSlash(rurl.Path)) + if !filepath.IsAbs(fileRoot) { + return "", errors.New("not an absolute path") + } + return fileRoot, nil +} + +// Matches reports whether the receiver's uri template matches the uri. 
+func (sr *serverResourceTemplate) Matches(uri string) bool { + tmpl, err := uritemplate.New(sr.resourceTemplate.URITemplate) + if err != nil { + return false + } + return tmpl.Regexp().MatchString(uri) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go new file mode 100644 index 000000000..4a35603c6 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go @@ -0,0 +1,29 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build go1.24 + +package mcp + +import ( + "errors" + "os" +) + +// withFile calls f on the file at join(dir, rel), +// protecting against path traversal attacks. +func withFile(dir, rel string, f func(*os.File) error) (err error) { + r, err := os.OpenRoot(dir) + if err != nil { + return err + } + defer r.Close() + file, err := r.Open(rel) + if err != nil { + return err + } + // Record error, in case f writes. + defer func() { err = errors.Join(err, file.Close()) }() + return f(file) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go new file mode 100644 index 000000000..d1f72eedc --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go @@ -0,0 +1,25 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build !go1.24 + +package mcp + +import ( + "errors" + "os" + "path/filepath" +) + +// withFile calls f on the file at join(dir, rel). +// It does not protect against path traversal attacks. 
+func withFile(dir, rel string, f func(*os.File) error) (err error) { + file, err := os.Open(filepath.Join(dir, rel)) + if err != nil { + return err + } + // Record error, in case f writes. + defer func() { err = errors.Join(err, file.Close()) }() + return f(file) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go new file mode 100644 index 000000000..1f7edf9c5 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go @@ -0,0 +1,1497 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/gob" + "encoding/json" + "errors" + "fmt" + "iter" + "log/slog" + "maps" + "net/url" + "path/filepath" + "reflect" + "slices" + "sync" + "sync/atomic" + "time" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/util" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/yosida95/uritemplate/v3" +) + +// DefaultPageSize is the default for [ServerOptions.PageSize]. +const DefaultPageSize = 1000 + +// A Server is an instance of an MCP server. +// +// Servers expose server-side MCP features, which can serve one or more MCP +// sessions by using [Server.Run]. 
+type Server struct { + // fixed at creation + impl *Implementation + opts ServerOptions + + mu sync.Mutex + prompts *featureSet[*serverPrompt] + tools *featureSet[*serverTool] + resources *featureSet[*serverResource] + resourceTemplates *featureSet[*serverResourceTemplate] + sessions []*ServerSession + sendingMethodHandler_ MethodHandler + receivingMethodHandler_ MethodHandler + resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send +} + +// ServerOptions is used to configure behavior of the server. +type ServerOptions struct { + // Optional instructions for connected clients. + Instructions string + // If non-nil, log server activity. + Logger *slog.Logger + // If non-nil, called when "notifications/initialized" is received. + InitializedHandler func(context.Context, *InitializedRequest) + // PageSize is the maximum number of items to return in a single page for + // list methods (e.g. ListTools). + // + // If zero, defaults to [DefaultPageSize]. + PageSize int + // If non-nil, called when "notifications/roots/list_changed" is received. + RootsListChangedHandler func(context.Context, *RootsListChangedRequest) + // If non-nil, called when "notifications/progress" is received. + ProgressNotificationHandler func(context.Context, *ProgressNotificationServerRequest) + // If non-nil, called when "completion/complete" is received. + CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error) + // If non-zero, defines an interval for regular "ping" requests. + // If the peer fails to respond to pings originating from the keepalive check, + // the session is automatically closed. + KeepAlive time.Duration + // Function called when a client session subscribes to a resource. + SubscribeHandler func(context.Context, *SubscribeRequest) error + // Function called when a client session unsubscribes from a resource. 
+ UnsubscribeHandler func(context.Context, *UnsubscribeRequest) error + + // Capabilities optionally configures the server's default capabilities, + // before any capabilities are inferred from other configuration or server + // features. + // + // If Capabilities is nil, the default server capabilities are {"logging":{}}, + // for historical reasons. Setting Capabilities to a non-nil value overrides + // this default. For example, setting Capabilities to `&ServerCapabilities{}` + // disables the logging capability. + // + // # Interaction with capability inference + // + // "tools", "prompts", and "resources" capabilities are automatically added when + // tools, prompts, or resources are added to the server (for example, via + // [Server.AddPrompt]), with default value `{"listChanged":true}`. Similarly, + // if the [ClientOptions.SubscribeHandler] or + // [ClientOptions.CompletionHandler] are set, the inferred capabilities are + // adjusted accordingly. + // + // Any non-nil field in Capabilities overrides the inferred value. + // For example: + // + // - To advertise the "tools" capability, even if no tools are added, set + // Capabilities.Tools to &ToolCapabilities{ListChanged:true}. + // - To disable tool list notifications, set Capabilities.Tools to + // &ToolCapabilities{}. + // + // Conversely, if Capabilities does not set a field (for example, if the + // Prompts field is nil), the inferred capability will be used. + Capabilities *ServerCapabilities + + // If true, advertises the prompts capability during initialization, + // even if no prompts have been registered. + // + // Deprecated: Use Capabilities instead. + HasPrompts bool + // If true, advertises the resources capability during initialization, + // even if no resources have been registered. + // + // Deprecated: Use Capabilities instead. + HasResources bool + // If true, advertises the tools capability during initialization, + // even if no tools have been registered. 
+ // + // Deprecated: Use Capabilities instead. + HasTools bool + + // GetSessionID provides the next session ID to use for an incoming request. + // If nil, a default randomly generated ID will be used. + // + // Session IDs should be globally unique across the scope of the server, + // which may span multiple processes in the case of distributed servers. + // + // As a special case, if GetSessionID returns the empty string, the + // Mcp-Session-Id header will not be set. + GetSessionID func() string +} + +// NewServer creates a new MCP server. The resulting server has no features: +// add features using the various Server.AddXXX methods, and the [AddTool] function. +// +// The server can be connected to one or more MCP clients using [Server.Run]. +// +// The first argument must not be nil. +// +// If non-nil, the provided options are used to configure the server. +func NewServer(impl *Implementation, options *ServerOptions) *Server { + if impl == nil { + panic("nil Implementation") + } + var opts ServerOptions + if options != nil { + opts = *options + } + options = nil // prevent reuse + if opts.PageSize < 0 { + panic(fmt.Errorf("invalid page size %d", opts.PageSize)) + } + if opts.PageSize == 0 { + opts.PageSize = DefaultPageSize + } + if opts.SubscribeHandler != nil && opts.UnsubscribeHandler == nil { + panic("SubscribeHandler requires UnsubscribeHandler") + } + if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil { + panic("UnsubscribeHandler requires SubscribeHandler") + } + + if opts.GetSessionID == nil { + opts.GetSessionID = randText + } + + if opts.Logger == nil { // ensure we have a logger + opts.Logger = ensureLogger(nil) + } + + return &Server{ + impl: impl, + opts: opts, + prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), + tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), + resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), + resourceTemplates: 
newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), + sendingMethodHandler_: defaultSendingMethodHandler, + receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], + resourceSubscriptions: make(map[string]map[*ServerSession]bool), + pendingNotifications: make(map[string]*time.Timer), + } +} + +// AddPrompt adds a [Prompt] to the server, or replaces one with the same name. +func (s *Server) AddPrompt(p *Prompt, h PromptHandler) { + // Assume there was a change, since add replaces existing items. + // (It's possible an item was replaced with an identical one, but not worth checking.) + s.changeAndNotify( + notificationPromptListChanged, + func() bool { s.prompts.add(&serverPrompt{p, h}); return true }) +} + +// RemovePrompts removes the prompts with the given names. +// It is not an error to remove a nonexistent prompt. +func (s *Server) RemovePrompts(names ...string) { + s.changeAndNotify(notificationPromptListChanged, func() bool { return s.prompts.remove(names...) }) +} + +// AddTool adds a [Tool] to the server, or replaces one with the same name. +// The Tool argument must not be modified after this call. +// +// The tool's input schema must be non-nil and have the type "object". For a tool +// that takes no input, or one where any input is valid, set [Tool.InputSchema] to +// `{"type": "object"}`, using your preferred library or `json.RawMessage`. +// +// If present, [Tool.OutputSchema] must also have type "object". +// +// When the handler is invoked as part of a CallTool request, req.Params.Arguments +// will be a json.RawMessage. +// +// Unmarshaling the arguments and validating them against the input schema are the +// caller's responsibility. +// +// Validating the result against the output schema, if any, is the caller's responsibility. +// +// Setting the result's Content, StructuredContent and IsError fields are the caller's +// responsibility. 
+// +// Most users should use the top-level function [AddTool], which handles all these +// responsibilities. +func (s *Server) AddTool(t *Tool, h ToolHandler) { + if err := validateToolName(t.Name); err != nil { + s.opts.Logger.Error(fmt.Sprintf("AddTool: invalid tool name %q: %v", t.Name, err)) + } + if t.InputSchema == nil { + // This prevents the tool author from forgetting to write a schema where + // one should be provided. If we papered over this by supplying the empty + // schema, then every input would be validated and the problem wouldn't be + // discovered until runtime, when the LLM sent bad data. + panic(fmt.Errorf("AddTool %q: missing input schema", t.Name)) + } + if s, ok := t.InputSchema.(*jsonschema.Schema); ok { + if s.Type != "object" { + panic(fmt.Errorf(`AddTool %q: input schema must have type "object"`, t.Name)) + } + } else { + var m map[string]any + if err := remarshal(t.InputSchema, &m); err != nil { + panic(fmt.Errorf("AddTool %q: can't marshal input schema to a JSON object: %v", t.Name, err)) + } + if typ := m["type"]; typ != "object" { + panic(fmt.Errorf(`AddTool %q: input schema must have type "object" (got %v)`, t.Name, typ)) + } + } + if t.OutputSchema != nil { + if s, ok := t.OutputSchema.(*jsonschema.Schema); ok { + if s.Type != "object" { + panic(fmt.Errorf(`AddTool %q: output schema must have type "object"`, t.Name)) + } + } else { + var m map[string]any + if err := remarshal(t.OutputSchema, &m); err != nil { + panic(fmt.Errorf("AddTool %q: can't marshal output schema to a JSON object: %v", t.Name, err)) + } + if typ := m["type"]; typ != "object" { + panic(fmt.Errorf(`AddTool %q: output schema must have type "object" (got %v)`, t.Name, typ)) + } + } + } + st := &serverTool{tool: t, handler: h} + // Assume there was a change, since add replaces existing tools. + // (It's possible a tool was replaced with an identical one, but not worth checking.) + // TODO: Batch these changes by size and time? The typescript SDK doesn't. 
+ // TODO: Surface notify error here? best not, in case we need to batch. + s.changeAndNotify(notificationToolListChanged, func() bool { s.tools.add(st); return true }) +} + +func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { + tt := *t + + // Special handling for an "any" input: treat as an empty object. + if reflect.TypeFor[In]() == reflect.TypeFor[any]() && t.InputSchema == nil { + tt.InputSchema = &jsonschema.Schema{Type: "object"} + } + + var inputResolved *jsonschema.Resolved + if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil { + return nil, nil, fmt.Errorf("input schema: %w", err) + } + + // Handling for zero values: + // + // If Out is a pointer type and we've derived the output schema from its + // element type, use the zero value of its element type in place of a typed + // nil. + var ( + elemZero any // only non-nil if Out is a pointer type + outputResolved *jsonschema.Resolved + ) + if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + var err error + elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved) + if err != nil { + return nil, nil, fmt.Errorf("output schema: %v", err) + } + } + + th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + var input json.RawMessage + if req.Params.Arguments != nil { + input = req.Params.Arguments + } + // Validate input and apply defaults. + var err error + input, err = applySchema(input, inputResolved) + if err != nil { + // TODO(#450): should this be considered a tool error? (and similar below) + return nil, fmt.Errorf("%w: validating \"arguments\": %v", jsonrpc2.ErrInvalidParams, err) + } + + // Unmarshal and validate args. + var in In + if input != nil { + if err := json.Unmarshal(input, &in); err != nil { + return nil, fmt.Errorf("%w: %v", jsonrpc2.ErrInvalidParams, err) + } + } + + // Call typed handler. 
+ res, out, err := h(ctx, req, in) + // Handle server errors appropriately: + // - If the handler returns a structured error (like jsonrpc.Error), return it directly + // - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true + // - This allows tools to distinguish between protocol errors and tool execution errors + if err != nil { + // Check if this is already a structured JSON-RPC error + if wireErr, ok := err.(*jsonrpc.Error); ok { + return nil, wireErr + } + // For regular errors, embed them in the tool result as per MCP spec + var errRes CallToolResult + errRes.setError(err) + return &errRes, nil + } + + if res == nil { + res = &CallToolResult{} + } + + // Marshal the output and put the RawMessage in the StructuredContent field. + var outval any = out + if elemZero != nil { + // Avoid typed nil, which will serialize as JSON null. + // Instead, use the zero value of the unpointered type. + var z Out + if any(out) == any(z) { // zero is only non-nil if Out is a pointer type + outval = elemZero + } + } + if outval != nil { + outbytes, err := json.Marshal(outval) + if err != nil { + return nil, fmt.Errorf("marshaling output: %w", err) + } + outJSON := json.RawMessage(outbytes) + // Validate the output JSON, and apply defaults. + // + // We validate against the JSON, rather than the output value, as + // some types may have custom JSON marshalling (issue #447). + outJSON, err = applySchema(outJSON, outputResolved) + if err != nil { + return nil, fmt.Errorf("validating tool output: %w", err) + } + res.StructuredContent = outJSON // avoid a second marshal over the wire + + // If the Content field isn't being used, return the serialized JSON in a + // TextContent block, as the spec suggests: + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. 
+ if res.Content == nil { + res.Content = []Content{&TextContent{ + Text: string(outJSON), + }} + } + } + return res, nil + } // end of handler + + return &tt, th, nil +} + +// setSchema sets the schema and resolved schema corresponding to the type T. +// +// If sfield is nil, the schema is derived from T. +// +// Pointers are treated equivalently to non-pointers when deriving the schema. +// If an indirection occurred to derive the schema, a non-nil zero value is +// returned to be used in place of the typed nil zero value. +// +// Note that if sfield already holds a schema, zero will be nil even if T is a +// pointer: if the user provided the schema, they may have intentionally +// derived it from the pointer type, and handling of zero values is up to them. +// +// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we +// should have a jsonschema.Zero(schema) helper? +func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) { + var internalSchema *jsonschema.Schema + if *sfield == nil { + rt := reflect.TypeFor[T]() + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + zero = reflect.Zero(rt).Interface() + } + // TODO: we should be able to pass nil opts here. + internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) + if err == nil { + *sfield = internalSchema + } + } else if err := remarshal(*sfield, &internalSchema); err != nil { + return zero, err + } + if err != nil { + return zero, err + } + *rfield, err = internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + return zero, err +} + +// AddTool adds a tool and typed tool handler to the server. +// +// If the tool's input schema is nil, it is set to the schema inferred from the +// In type parameter. Types are inferred from Go types, and property +// descriptions are read from the 'jsonschema' struct tag. Internally, the SDK +// uses the github.com/google/jsonschema-go package for inference and +// validation. 
The In type argument must be a map or a struct, so that its +// inferred JSON Schema has type "object", as required by the spec. As a +// special case, if the In type is 'any', the tool's input schema is set to an +// empty object schema value. +// +// If the tool's output schema is nil, and the Out type is not 'any', the +// output schema is set to the schema inferred from the Out type argument, +// which must also be a map or struct. If the Out type is 'any', the output +// schema is omitted. +// +// Unlike [Server.AddTool], AddTool does a lot automatically, and forces +// tools to conform to the MCP spec. See [ToolHandlerFor] for a detailed +// description of this automatic behavior. +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { + tt, hh, err := toolForErr(t, h) + if err != nil { + panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err)) + } + s.AddTool(tt, hh) +} + +// RemoveTools removes the tools with the given names. +// It is not an error to remove a nonexistent tool. +func (s *Server) RemoveTools(names ...string) { + s.changeAndNotify(notificationToolListChanged, func() bool { return s.tools.remove(names...) }) +} + +// AddResource adds a [Resource] to the server, or replaces one with the same URI. +// AddResource panics if the resource URI is invalid or not absolute (has an empty scheme). +func (s *Server) AddResource(r *Resource, h ResourceHandler) { + s.changeAndNotify(notificationResourceListChanged, + func() bool { + if _, err := url.Parse(r.URI); err != nil { + panic(err) // url.Parse includes the URI in the error + } + s.resources.add(&serverResource{r, h}) + return true + }) +} + +// RemoveResources removes the resources with the given URIs. +// It is not an error to remove a nonexistent resource. +func (s *Server) RemoveResources(uris ...string) { + s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resources.remove(uris...) 
}) +} + +// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces one with the same URI. +// AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme). +func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { + s.changeAndNotify(notificationResourceListChanged, + func() bool { + // Validate the URI template syntax + _, err := uritemplate.New(t.URITemplate) + if err != nil { + panic(fmt.Errorf("URI template %q is invalid: %w", t.URITemplate, err)) + } + s.resourceTemplates.add(&serverResourceTemplate{t, h}) + return true + }) +} + +// RemoveResourceTemplates removes the resource templates with the given URI templates. +// It is not an error to remove a nonexistent resource. +func (s *Server) RemoveResourceTemplates(uriTemplates ...string) { + s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resourceTemplates.remove(uriTemplates...) }) +} + +func (s *Server) capabilities() *ServerCapabilities { + s.mu.Lock() + defer s.mu.Unlock() + + // Start with user-provided capabilities as defaults, or use SDK defaults. + var caps *ServerCapabilities + if s.opts.Capabilities != nil { + // Deep copy the user-provided capabilities to avoid mutation. + caps = s.opts.Capabilities.clone() + } else { + // SDK defaults: only logging capability. + caps = &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + } + } + + // Augment with tools capability if tools exist or legacy HasTools is set. + if s.opts.HasTools || s.tools.len() > 0 { + if caps.Tools == nil { + caps.Tools = &ToolCapabilities{ListChanged: true} + } + } + + // Augment with prompts capability if prompts exist or legacy HasPrompts is set. + if s.opts.HasPrompts || s.prompts.len() > 0 { + if caps.Prompts == nil { + caps.Prompts = &PromptCapabilities{ListChanged: true} + } + } + + // Augment with resources capability if resources/templates exist or legacy HasResources is set. 
+ if s.opts.HasResources || s.resources.len() > 0 || s.resourceTemplates.len() > 0 { + if caps.Resources == nil { + caps.Resources = &ResourceCapabilities{ListChanged: true} + } + if s.opts.SubscribeHandler != nil { + caps.Resources.Subscribe = true + } + } + + // Augment with completions capability if handler is set. + if s.opts.CompletionHandler != nil { + if caps.Completions == nil { + caps.Completions = &CompletionCapabilities{} + } + } + + return caps +} + +func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteResult, error) { + if s.opts.CompletionHandler == nil { + return nil, jsonrpc2.ErrMethodNotFound + } + return s.opts.CompletionHandler(ctx, req) +} + +// Map from notification name to its corresponding params. The params have no fields, +// so a single struct can be reused. +var changeNotificationParams = map[string]Params{ + notificationToolListChanged: &ToolListChangedParams{}, + notificationPromptListChanged: &PromptListChangedParams{}, + notificationResourceListChanged: &ResourceListChangedParams{}, +} + +// How long to wait before sending a change notification. +const notificationDelay = 10 * time.Millisecond + +// changeAndNotify is called when a feature is added or removed. +// It calls change, which should do the work and report whether a change actually occurred. +// If there was a change, it sets a timer to send a notification. +// This debounces change notifications: a single notification is sent after +// multiple changes occur in close proximity. +func (s *Server) changeAndNotify(notification string, change func() bool) { + s.mu.Lock() + defer s.mu.Unlock() + if change() && s.shouldSendListChangedNotification(notification) { + // Reset the outstanding delayed call, if any. 
+ if t := s.pendingNotifications[notification]; t == nil { + s.pendingNotifications[notification] = time.AfterFunc(notificationDelay, func() { s.notifySessions(notification) }) + } else { + t.Reset(notificationDelay) + } + } +} + +// notifySessions sends the notification n to all existing sessions. +// It is called asynchronously by changeAndNotify. +func (s *Server) notifySessions(n string) { + s.mu.Lock() + sessions := slices.Clone(s.sessions) + s.pendingNotifications[n] = nil + s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock. + notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger) +} + +// shouldSendListChangedNotification checks if the server's capabilities allow +// sending the given list-changed notification. +func (s *Server) shouldSendListChangedNotification(notification string) bool { + // Get effective capabilities (considering user-provided defaults). + caps := s.opts.Capabilities + + switch notification { + case notificationToolListChanged: + // If user didn't specify capabilities, default behavior sends notifications. + if caps == nil || caps.Tools == nil { + return true + } + return caps.Tools.ListChanged + case notificationPromptListChanged: + if caps == nil || caps.Prompts == nil { + return true + } + return caps.Prompts.ListChanged + case notificationResourceListChanged: + if caps == nil || caps.Resources == nil { + return true + } + return caps.Resources.ListChanged + default: + // Unknown notification, allow by default. + return true + } +} + +// Sessions returns an iterator that yields the current set of server sessions. +// +// There is no guarantee that the iterator observes sessions that are added or +// removed during iteration. 
+func (s *Server) Sessions() iter.Seq[*ServerSession] { + s.mu.Lock() + clients := slices.Clone(s.sessions) + s.mu.Unlock() + return slices.Values(clients) +} + +func (s *Server) listPrompts(_ context.Context, req *ListPromptsRequest) (*ListPromptsResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListPromptsParams{} + } + return paginateList(s.prompts, s.opts.PageSize, req.Params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*serverPrompt) { + res.Prompts = []*Prompt{} // avoid JSON null + for _, p := range prompts { + res.Prompts = append(res.Prompts, p.prompt) + } + }) +} + +func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + s.mu.Lock() + prompt, ok := s.prompts.get(req.Params.Name) + s.mu.Unlock() + if !ok { + // Return a proper JSON-RPC error with the correct error code + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), + } + } + return prompt.handler(ctx, req) +} + +func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListToolsParams{} + } + return paginateList(s.tools, s.opts.PageSize, req.Params, &ListToolsResult{}, func(res *ListToolsResult, tools []*serverTool) { + res.Tools = []*Tool{} // avoid JSON null + for _, t := range tools { + res.Tools = append(res.Tools, t.tool) + } + }) +} + +func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + s.mu.Lock() + st, ok := s.tools.get(req.Params.Name) + s.mu.Unlock() + if !ok { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("unknown tool %q", req.Params.Name), + } + } + res, err := st.handler(ctx, req) + if err == nil && res != nil && res.Content == nil { + res2 := *res + res2.Content = []Content{} // avoid "null" + res = &res2 + } 
+ return res, err +} + +func (s *Server) listResources(_ context.Context, req *ListResourcesRequest) (*ListResourcesResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListResourcesParams{} + } + return paginateList(s.resources, s.opts.PageSize, req.Params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*serverResource) { + res.Resources = []*Resource{} // avoid JSON null + for _, r := range resources { + res.Resources = append(res.Resources, r.resource) + } + }) +} + +func (s *Server) listResourceTemplates(_ context.Context, req *ListResourceTemplatesRequest) (*ListResourceTemplatesResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListResourceTemplatesParams{} + } + return paginateList(s.resourceTemplates, s.opts.PageSize, req.Params, &ListResourceTemplatesResult{}, + func(res *ListResourceTemplatesResult, rts []*serverResourceTemplate) { + res.ResourceTemplates = []*ResourceTemplate{} // avoid JSON null + for _, rt := range rts { + res.ResourceTemplates = append(res.ResourceTemplates, rt.resourceTemplate) + } + }) +} + +func (s *Server) readResource(ctx context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + uri := req.Params.URI + // Look up the resource URI in the lists of resources and resource templates. + // This is a security check as well as an information lookup. + handler, mimeType, ok := s.lookupResourceHandler(uri) + if !ok { + // Don't expose the server configuration to the client. + // Treat an unregistered resource the same as a registered one that couldn't be found. + return nil, ResourceNotFoundError(uri) + } + res, err := handler(ctx, req) + if err != nil { + return nil, err + } + if res == nil || res.Contents == nil { + return nil, fmt.Errorf("reading resource %s: read handler returned nil information", uri) + } + // As a convenience, populate some fields. 
+	for _, c := range res.Contents {
+		if c.URI == "" {
+			c.URI = uri
+		}
+		if c.MIMEType == "" {
+			c.MIMEType = mimeType
+		}
+	}
+	return res, nil
+}
+
+// lookupResourceHandler returns the resource handler and MIME type for the resource or
+// resource template matching uri. If none, the last return value is false.
+func (s *Server) lookupResourceHandler(uri string) (ResourceHandler, string, bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	// Try resources first.
+	if r, ok := s.resources.get(uri); ok {
+		return r.handler, r.resource.MIMEType, true
+	}
+	// Look for matching template.
+	for rt := range s.resourceTemplates.all() {
+		if rt.Matches(uri) {
+			return rt.handler, rt.resourceTemplate.MIMEType, true
+		}
+	}
+	return nil, "", false
+}
+
+// fileResourceHandler returns a ResourceHandler that reads paths using dir as
+// a base directory.
+// It honors client roots and protects against path traversal attacks.
+//
+// The dir argument should be a filesystem path. It need not be absolute, but
+// that is recommended to avoid a dependency on the current working directory (the
+// check against client roots is done with an absolute path). If dir is not absolute
+// and the current working directory is unavailable, fileResourceHandler panics.
+//
+// Lexical path traversal attacks, where the path has ".." elements that escape dir,
+// are always caught. Go 1.24 and above also protects against symlink-based attacks,
+// where symlinks under dir lead out of the tree.
+func fileResourceHandler(dir string) ResourceHandler {
+	// Convert dir to an absolute path.
+	dirFilepath, err := filepath.Abs(dir)
+	if err != nil {
+		panic(err)
+	}
+	return func(ctx context.Context, req *ReadResourceRequest) (_ *ReadResourceResult, err error) {
+		defer util.Wrapf(&err, "reading resource %s", req.Params.URI)
+
+		// TODO(#25): use a memoizing API here.
+ rootRes, err := req.Session.ListRoots(ctx, nil) + if err != nil { + return nil, fmt.Errorf("listing roots: %w", err) + } + roots, err := fileRoots(rootRes.Roots) + if err != nil { + return nil, err + } + data, err := readFileResource(req.Params.URI, dirFilepath, roots) + if err != nil { + return nil, err + } + // TODO(jba): figure out mime type. Omit for now: Server.readResource will fill it in. + return &ReadResourceResult{Contents: []*ResourceContents{ + {URI: req.Params.URI, Blob: data}, + }}, nil + } +} + +// ResourceUpdated sends a notification to all clients that have subscribed to the +// resource specified in params. This method is the primary way for a +// server author to signal that a resource has changed. +func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error { + s.mu.Lock() + subscribedSessions := s.resourceSubscriptions[params.URI] + sessions := slices.Collect(maps.Keys(subscribedSessions)) + s.mu.Unlock() + notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger) + s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions)) + return nil +} + +func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyResult, error) { + if s.opts.SubscribeHandler == nil { + return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) + } + if err := s.opts.SubscribeHandler(ctx, req); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if s.resourceSubscriptions[req.Params.URI] == nil { + s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool) + } + s.resourceSubscriptions[req.Params.URI][req.Session] = true + s.opts.Logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) + + return &emptyResult{}, nil +} + +func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emptyResult, error) { + if 
s.opts.UnsubscribeHandler == nil { + return nil, jsonrpc2.ErrMethodNotFound + } + + if err := s.opts.UnsubscribeHandler(ctx, req); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if subscribedSessions, ok := s.resourceSubscriptions[req.Params.URI]; ok { + delete(subscribedSessions, req.Session) + if len(subscribedSessions) == 0 { + delete(s.resourceSubscriptions, req.Params.URI) + } + } + s.opts.Logger.Info("resource unsubscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) + + return &emptyResult{}, nil +} + +// Run runs the server over the given transport, which must be persistent. +// +// Run blocks until the client terminates the connection or the provided +// context is cancelled. If the context is cancelled, Run closes the connection. +// +// If tools have been added to the server before this call, then the server will +// advertise the capability for tools, including the ability to send list-changed notifications. +// If no tools have been added, the server will not have the tool capability. +// The same goes for other features like prompts and resources. +// +// Run is a convenience for servers that handle a single session (or one session at a time). +// It need not be called on servers that are used for multiple concurrent connections, +// as with [StreamableHTTPHandler]. 
+func (s *Server) Run(ctx context.Context, t Transport) error { + s.opts.Logger.Info("server run start") + ss, err := s.Connect(ctx, t, nil) + if err != nil { + s.opts.Logger.Error("server connect failed", "error", err) + return err + } + + ssClosed := make(chan error) + go func() { + ssClosed <- ss.Wait() + }() + + select { + case <-ctx.Done(): + ss.Close() + <-ssClosed // wait until waiting go routine above actually completes + s.opts.Logger.Error("server run cancelled", "error", ctx.Err()) + return ctx.Err() + case err := <-ssClosed: + if err != nil { + s.opts.Logger.Error("server session ended with error", "error", err) + } else { + s.opts.Logger.Info("server session ended") + } + return err + } +} + +// bind implements the binder[*ServerSession] interface, so that Servers can +// be connected using [connect]. +func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState, onClose func()) *ServerSession { + assert(mcpConn != nil && conn != nil, "nil connection") + ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s, onClose: onClose} + if state != nil { + ss.state = *state + } + s.mu.Lock() + s.sessions = append(s.sessions, ss) + s.mu.Unlock() + s.opts.Logger.Info("server session connected", "session_id", ss.ID()) + return ss +} + +// disconnect implements the binder[*ServerSession] interface, so that +// Servers can be connected using [connect]. +func (s *Server) disconnect(cc *ServerSession) { + s.mu.Lock() + defer s.mu.Unlock() + s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool { + return cc2 == cc + }) + + for _, subscribedSessions := range s.resourceSubscriptions { + delete(subscribedSessions, cc) + } + s.opts.Logger.Info("server session disconnected", "session_id", cc.ID()) +} + +// ServerSessionOptions configures the server session. 
+type ServerSessionOptions struct { + State *ServerSessionState + + onClose func() // used to clean up associated resources +} + +// Connect connects the MCP server over the given transport and starts handling +// messages. +// +// It returns a connection object that may be used to terminate the connection +// (with [Connection.Close]), or await client termination (with +// [Connection.Wait]). +// +// If opts.State is non-nil, it is the initial state for the server. +func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) { + var state *ServerSessionState + var onClose func() + if opts != nil { + state = opts.State + onClose = opts.onClose + } + + s.opts.Logger.Info("server connecting") + ss, err := connect(ctx, t, s, state, onClose) + if err != nil { + s.opts.Logger.Error("server connect error", "error", err) + return nil, err + } + return ss, nil +} + +// TODO: (nit) move all ServerSession methods below the ServerSession declaration. +func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) { + if params == nil { + // Since we use nilness to signal 'initialized' state, we must ensure that + // params are non-nil. 
+ params = new(InitializedParams) + } + var wasInit, wasInitd bool + ss.updateState(func(state *ServerSessionState) { + wasInit = state.InitializeParams != nil + wasInitd = state.InitializedParams != nil + if wasInit && !wasInitd { + state.InitializedParams = params + } + }) + + if !wasInit { + ss.server.opts.Logger.Error("initialized before initialize") + return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) + } + if wasInitd { + ss.server.opts.Logger.Error("duplicate initialized notification") + return nil, fmt.Errorf("duplicate %q received", notificationInitialized) + } + if ss.server.opts.KeepAlive > 0 { + ss.startKeepalive(ss.server.opts.KeepAlive) + } + if h := ss.server.opts.InitializedHandler; h != nil { + h(ctx, serverRequestFor(ss, params)) + } + ss.server.opts.Logger.Info("session initialized") + return nil, nil +} + +func (s *Server) callRootsListChangedHandler(ctx context.Context, req *RootsListChangedRequest) (Result, error) { + if h := s.opts.RootsListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, p *ProgressNotificationParams) (Result, error) { + if h := ss.server.opts.ProgressNotificationHandler; h != nil { + h(ctx, serverRequestFor(ss, p)) + } + return nil, nil +} + +// NotifyProgress sends a progress notification from the server to the client +// associated with this session. +// This is typically used to report on the status of a long-running request +// that was initiated by the client. +func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNotificationParams) error { + return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params))) +} + +func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { + return &ServerRequest[P]{Session: ss, Params: params} +} + +// A ServerSession is a logical connection from a single MCP client. 
Its
+// methods can be used to send requests or notifications to the client. Create
+// a session by calling [Server.Connect].
+//
+// Call [ServerSession.Close] to close the connection, or await client
+// termination with [ServerSession.Wait].
+type ServerSession struct {
+ // Ensure that onClose is called at most once.
+ // We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the
+ // onClose callback triggers a re-entrant call to Close.
+ calledOnClose atomic.Bool
+ onClose func()
+
+ server *Server
+ conn *jsonrpc2.Connection
+ mcpConn Connection
+ keepaliveCancel context.CancelFunc // TODO: theory around why keepaliveCancel need not be guarded
+
+ mu sync.Mutex
+ state ServerSessionState
+}
+
+func (ss *ServerSession) updateState(mut func(*ServerSessionState)) {
+ ss.mu.Lock()
+ mut(&ss.state)
+ copy := ss.state
+ ss.mu.Unlock()
+ if c, ok := ss.mcpConn.(serverConnection); ok {
+ c.sessionUpdated(copy)
+ }
+}
+
+// hasInitialized reports whether the server has received the initialized
+// notification.
+//
+// TODO(findleyr): use this to prevent change notifications.
+func (ss *ServerSession) hasInitialized() bool {
+ ss.mu.Lock()
+ defer ss.mu.Unlock()
+ return ss.state.InitializedParams != nil
+}
+
+// checkInitialized returns a formatted error if the server has not yet
+// received the initialized notification.
+func (ss *ServerSession) checkInitialized(method string) error {
+ if !ss.hasInitialized() {
+ // TODO(rfindley): enable this check.
+ // Right now it is flaky, because server tests don't await the initialized notification. 
+ // Perhaps requests should simply block until they have received the initialized notification + + // if strings.HasPrefix(method, "notifications/") { + // return fmt.Errorf("must not send %q before %q is received", method, notificationInitialized) + // } else { + // return fmt.Errorf("cannot call %q before %q is received", method, notificationInitialized) + // } + } + return nil +} + +func (ss *ServerSession) ID() string { + if c, ok := ss.mcpConn.(hasSessionID); ok { + return c.SessionID() + } + return "" +} + +// Ping pings the client. +func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error { + _, err := handleSend[*emptyResult](ctx, methodPing, newServerRequest(ss, orZero[Params](params))) + return err +} + +// ListRoots lists the client roots. +func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { + if err := ss.checkInitialized(methodListRoots); err != nil { + return nil, err + } + return handleSend[*ListRootsResult](ctx, methodListRoots, newServerRequest(ss, orZero[Params](params))) +} + +// CreateMessage sends a sampling request to the client. +func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { + if err := ss.checkInitialized(methodCreateMessage); err != nil { + return nil, err + } + if params == nil { + params = &CreateMessageParams{Messages: []*SamplingMessage{}} + } + if params.Messages == nil { + p2 := *params + p2.Messages = []*SamplingMessage{} // avoid JSON "null" + params = &p2 + } + return handleSend[*CreateMessageResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) +} + +// Elicit sends an elicitation request to the client asking for user input. 
+func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*ElicitResult, error) {
+ if err := ss.checkInitialized(methodElicit); err != nil {
+ return nil, err
+ }
+ if params == nil {
+ return nil, fmt.Errorf("%w: params cannot be nil", jsonrpc2.ErrInvalidParams)
+ }
+
+ if params.Mode == "" {
+ params2 := *params
+ if params.URL != "" || params.ElicitationID != "" {
+ params2.Mode = "url"
+ } else {
+ params2.Mode = "form"
+ }
+ params = &params2
+ }
+
+ if iparams := ss.InitializeParams(); iparams == nil || iparams.Capabilities == nil || iparams.Capabilities.Elicitation == nil {
+ return nil, fmt.Errorf("client does not support elicitation")
+ }
+ caps := ss.InitializeParams().Capabilities.Elicitation
+ switch params.Mode {
+ case "form":
+ if caps.Form == nil && caps.URL != nil {
+ // Note: if both 'Form' and 'URL' are nil, we assume the client supports
+ // form elicitation for backward compatibility.
+ return nil, errors.New(`client does not support "form" elicitation`)
+ }
+ case "url":
+ if caps.URL == nil {
+ return nil, errors.New(`client does not support "url" elicitation`)
+ }
+ }
+
+ res, err := handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params)))
+ if err != nil {
+ return nil, err
+ }
+
+ if params.RequestedSchema == nil {
+ return res, nil
+ }
+ schema, err := validateElicitSchema(params.RequestedSchema)
+ if err != nil {
+ return nil, err
+ }
+ if schema == nil {
+ return res, nil
+ }
+
+ resolved, err := schema.Resolve(nil)
+ if err != nil {
+ return nil, err
+ }
+ if err := resolved.Validate(res.Content); err != nil {
+ return nil, fmt.Errorf("elicitation result content does not match requested schema: %v", err)
+ }
+ err = resolved.ApplyDefaults(&res.Content)
+ if err != nil {
+ return nil, fmt.Errorf("failed to apply schema defaults to elicitation result: %v", err)
+ }
+
+ return res, nil
+}
+
+// Log sends a log message to the client. 
+// The message is not sent if the client has not called SetLevel, or if its level +// is below that of the last SetLevel. +func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error { + ss.mu.Lock() + logLevel := ss.state.LogLevel + ss.mu.Unlock() + if logLevel == "" { + // The spec is unclear, but seems to imply that no log messages are sent until the client + // sets the level. + // TODO(jba): read other SDKs, possibly file an issue. + return nil + } + if compareLevels(params.Level, logLevel) < 0 { + return nil + } + return handleNotify(ctx, notificationLoggingMessage, newServerRequest(ss, orZero[Params](params))) +} + +// AddSendingMiddleware wraps the current sending method handler using the provided +// middleware. Middleware is applied from right to left, so that the first one is +// executed first. +// +// For example, AddSendingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Sending middleware is called when a request is sent. It is useful for tasks +// such as tracing, metrics, and adding progress tokens. +func (s *Server) AddSendingMiddleware(middleware ...Middleware) { + s.mu.Lock() + defer s.mu.Unlock() + addMiddleware(&s.sendingMethodHandler_, middleware) +} + +// AddReceivingMiddleware wraps the current receiving method handler using +// the provided middleware. Middleware is applied from right to left, so that the +// first one is executed first. +// +// For example, AddReceivingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Receiving middleware is called when a request is received. It is useful for tasks +// such as authentication, request logging and metrics. +func (s *Server) AddReceivingMiddleware(middleware ...Middleware) { + s.mu.Lock() + defer s.mu.Unlock() + addMiddleware(&s.receivingMethodHandler_, middleware) +} + +// serverMethodInfos maps from the RPC method name to serverMethodInfos. 
+// +// The 'allowMissingParams' values are extracted from the protocol schema. +// TODO(rfindley): actually load and validate the protocol schema, rather than +// curating these method flags. +var serverMethodInfos = map[string]methodInfo{ + methodComplete: newServerMethodInfo(serverMethod((*Server).complete), 0), + methodInitialize: initializeMethodInfo(), + methodPing: newServerMethodInfo(serverSessionMethod((*ServerSession).ping), missingParamsOK), + methodListPrompts: newServerMethodInfo(serverMethod((*Server).listPrompts), missingParamsOK), + methodGetPrompt: newServerMethodInfo(serverMethod((*Server).getPrompt), 0), + methodListTools: newServerMethodInfo(serverMethod((*Server).listTools), missingParamsOK), + methodCallTool: newServerMethodInfo(serverMethod((*Server).callTool), 0), + methodListResources: newServerMethodInfo(serverMethod((*Server).listResources), missingParamsOK), + methodListResourceTemplates: newServerMethodInfo(serverMethod((*Server).listResourceTemplates), missingParamsOK), + methodReadResource: newServerMethodInfo(serverMethod((*Server).readResource), 0), + methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0), + methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), + methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), + notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK), + notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), + notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), + notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification), +} + +// initializeMethodInfo handles the workaround for #607: we must set +// params.Capabilities.RootsV2. 
+func initializeMethodInfo() methodInfo {
+ info := newServerMethodInfo(serverSessionMethod((*ServerSession).initialize), 0)
+ info.unmarshalParams = func(m json.RawMessage) (Params, error) {
+ var params *initializeParamsV2
+ if m != nil {
+ if err := json.Unmarshal(m, &params); err != nil {
+ return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, params, err)
+ }
+ }
+ if params == nil {
+ return nil, fmt.Errorf(`missing required "params"`)
+ }
+ return params.toV1(), nil
+ }
+ return info
+}
+
+func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos }
+
+func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos }
+
+func (ss *ServerSession) sendingMethodHandler() MethodHandler {
+ s := ss.server
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.sendingMethodHandler_
+}
+
+func (ss *ServerSession) receivingMethodHandler() MethodHandler {
+ s := ss.server
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.receivingMethodHandler_
+}
+
+// getConn implements [session.getConn].
+func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn }
+
+// handle invokes the method described by the given JSON RPC request.
+func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
+ ss.mu.Lock()
+ initialized := ss.state.InitializeParams != nil
+ ss.mu.Unlock()
+
+ // From the spec:
+ // "The client SHOULD NOT send requests other than pings before the server
+ // has responded to the initialize request." 
+ switch req.Method {
+ case methodInitialize, methodPing, notificationInitialized:
+ default:
+ if !initialized {
+ ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method)
+ return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method)
+ }
+ }
+
+ // modelcontextprotocol/go-sdk#26: handle calls asynchronously, and
+ // notifications synchronously, except for 'initialize' which shouldn't be
+ // asynchronous to other messages
+ if req.IsCall() && req.Method != methodInitialize {
+ jsonrpc2.Async(ctx)
+ }
+
+ // For the streamable transport, we need the request ID to correlate
+ // server->client calls and notifications to the incoming request from which
+ // they originated. See [idContextKey] for details.
+ ctx = context.WithValue(ctx, idContextKey{}, req.ID)
+ return handleReceive(ctx, ss, req)
+}
+
+// InitializeParams returns the InitializeParams provided during the client's
+// initial connection.
+func (ss *ServerSession) InitializeParams() *InitializeParams {
+ ss.mu.Lock()
+ defer ss.mu.Unlock()
+ return ss.state.InitializeParams
+}
+
+func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) {
+ if params == nil {
+ return nil, fmt.Errorf("%w: \"params\" must be provided", jsonrpc2.ErrInvalidParams)
+ }
+ ss.updateState(func(state *ServerSessionState) {
+ state.InitializeParams = params
+ })
+
+ s := ss.server
+ return &InitializeResult{
+ // TODO(rfindley): alter behavior when falling back to an older version:
+ // reject unsupported features.
+ ProtocolVersion: negotiatedVersion(params.ProtocolVersion),
+ Capabilities: s.capabilities(),
+ Instructions: s.opts.Instructions,
+ ServerInfo: s.impl,
+ }, nil
+}
+
+func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error) {
+ return &emptyResult{}, nil
+}
+
+// cancel is a placeholder: cancellation is handled by the jsonrpc2 package. 
+// +// It should never be invoked in practice because cancellation is preempted, +// but having its signature here facilitates the construction of methodInfo +// that can be used to validate incoming cancellation notifications. +func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) { + return nil, nil +} + +func (ss *ServerSession) setLevel(_ context.Context, params *SetLoggingLevelParams) (*emptyResult, error) { + ss.updateState(func(state *ServerSessionState) { + state.LogLevel = params.Level + }) + ss.server.opts.Logger.Info("client log level set", "level", params.Level) + return &emptyResult{}, nil +} + +// Close performs a graceful shutdown of the connection, preventing new +// requests from being handled, and waiting for ongoing requests to return. +// Close then terminates the connection. +// +// Close is idempotent and concurrency safe. +func (ss *ServerSession) Close() error { + if ss.keepaliveCancel != nil { + // Note: keepaliveCancel access is safe without a mutex because: + // 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls) + // 2. context.CancelFunc is safe to call multiple times and from multiple goroutines + // 3. The keepalive goroutine calls Close on ping failure, but this is safe since + // Close is idempotent and conn.Close() handles concurrent calls correctly + ss.keepaliveCancel() + } + err := ss.conn.Close() + + if ss.onClose != nil && ss.calledOnClose.CompareAndSwap(false, true) { + ss.onClose() + } + + return err +} + +// Wait waits for the connection to be closed by the client. +func (ss *ServerSession) Wait() error { + return ss.conn.Wait() +} + +// startKeepalive starts the keepalive mechanism for this server session. +func (ss *ServerSession) startKeepalive(interval time.Duration) { + startKeepalive(ss, interval, &ss.keepaliveCancel) +} + +// pageToken is the internal structure for the opaque pagination cursor. 
+// It will be Gob-encoded and then Base64-encoded for use as a string token.
+type pageToken struct {
+ LastUID string // The unique ID of the last resource seen.
+}
+
+// encodeCursor encodes a unique identifier (UID) into an opaque pagination cursor
+// by serializing a pageToken struct.
+func encodeCursor(uid string) (string, error) {
+ var buf bytes.Buffer
+ token := pageToken{LastUID: uid}
+ encoder := gob.NewEncoder(&buf)
+ if err := encoder.Encode(token); err != nil {
+ return "", fmt.Errorf("failed to encode page token: %w", err)
+ }
+ return base64.URLEncoding.EncodeToString(buf.Bytes()), nil
+}
+
+// decodeCursor decodes an opaque pagination cursor into the original pageToken struct.
+func decodeCursor(cursor string) (*pageToken, error) {
+ decodedBytes, err := base64.URLEncoding.DecodeString(cursor)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode cursor: %w", err)
+ }
+
+ var token pageToken
+ buf := bytes.NewBuffer(decodedBytes)
+ decoder := gob.NewDecoder(buf)
+ if err := decoder.Decode(&token); err != nil {
+ return nil, fmt.Errorf("failed to decode page token: %w, cursor: %v", err, cursor)
+ }
+ return &token, nil
+}
+
+// paginateList is a generic helper that returns a paginated slice of items
+// from a featureSet. It populates the provided result res with the items
+// and sets its next cursor for subsequent pages.
+// If there are no more pages, the next cursor within the result will be an empty string.
+func paginateList[P listParams, R listResult[T], T any](fs *featureSet[T], pageSize int, params P, res R, setFunc func(R, []T)) (R, error) {
+ var seq iter.Seq[T]
+ if params.cursorPtr() == nil || *params.cursorPtr() == "" {
+ seq = fs.all()
+ } else {
+ pageToken, err := decodeCursor(*params.cursorPtr())
+ // According to the spec, invalid cursors should return Invalid params. 
+ if err != nil { + var zero R + return zero, jsonrpc2.ErrInvalidParams + } + seq = fs.above(pageToken.LastUID) + } + var count int + var features []T + for f := range seq { + count++ + // If we've seen pageSize + 1 elements, we've gathered enough info to determine + // if there's a next page. Stop processing the sequence. + if count == pageSize+1 { + break + } + features = append(features, f) + } + setFunc(res, features) + // No remaining pages. + if count < pageSize+1 { + return res, nil + } + nextCursor, err := encodeCursor(fs.uniqueID(features[len(features)-1])) + if err != nil { + var zero R + return zero, err + } + *res.nextCursorPtr() = nextCursor + return res, nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/session.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/session.go new file mode 100644 index 000000000..dcf9888cc --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/session.go @@ -0,0 +1,29 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +// hasSessionID is the interface which, if implemented by connections, informs +// the session about their session ID. +// +// TODO(rfindley): remove SessionID methods from connections, when it doesn't +// make sense. Or remove it from the Sessions entirely: why does it even need +// to be exposed? +type hasSessionID interface { + SessionID() string +} + +// ServerSessionState is the state of a session. +type ServerSessionState struct { + // InitializeParams are the parameters from 'initialize'. + InitializeParams *InitializeParams `json:"initializeParams"` + + // InitializedParams are the parameters from 'notifications/initialized'. + InitializedParams *InitializedParams `json:"initializedParams"` + + // LogLevel is the logging level for the session. 
+ LogLevel LoggingLevel `json:"logLevel"` + + // TODO: resource subscriptions +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go new file mode 100644 index 000000000..d83eae7da --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go @@ -0,0 +1,610 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains code shared between client and server, including +// method handler and middleware definitions. +// +// Much of this is here so that we can factor out commonalities using +// generics. If this becomes unwieldy, it can perhaps be simplified with +// reflection. + +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "reflect" + "slices" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const ( + // latestProtocolVersion is the latest protocol version that this version of + // the SDK supports. + // + // It is the version that the client sends in the initialization request, and + // the default version used by the server. + latestProtocolVersion = protocolVersion20250618 + protocolVersion20251125 = "2025-11-25" // not yet released + protocolVersion20250618 = "2025-06-18" + protocolVersion20250326 = "2025-03-26" + protocolVersion20241105 = "2024-11-05" +) + +var supportedProtocolVersions = []string{ + protocolVersion20251125, + protocolVersion20250618, + protocolVersion20250326, + protocolVersion20241105, +} + +// negotiatedVersion returns the effective protocol version to use, given a +// client version. 
+func negotiatedVersion(clientVersion string) string { + // In general, prefer to use the clientVersion, but if we don't support the + // client's version, use the latest version. + // + // This handles the case where a new spec version is released, and the SDK + // does not support it yet. + if !slices.Contains(supportedProtocolVersions, clientVersion) { + return latestProtocolVersion + } + return clientVersion +} + +// A MethodHandler handles MCP messages. +// For methods, exactly one of the return values must be nil. +// For notifications, both must be nil. +type MethodHandler func(ctx context.Context, method string, req Request) (result Result, err error) + +// A Session is either a [ClientSession] or a [ServerSession]. +type Session interface { + // ID returns the session ID, or the empty string if there is none. + ID() string + + sendingMethodInfos() map[string]methodInfo + receivingMethodInfos() map[string]methodInfo + sendingMethodHandler() MethodHandler + receivingMethodHandler() MethodHandler + getConn() *jsonrpc2.Connection +} + +// Middleware is a function from [MethodHandler] to [MethodHandler]. +type Middleware func(MethodHandler) MethodHandler + +// addMiddleware wraps the handler in the middleware functions. +func addMiddleware(handlerp *MethodHandler, middleware []Middleware) { + for _, m := range slices.Backward(middleware) { + *handlerp = m(*handlerp) + } +} + +func defaultSendingMethodHandler(ctx context.Context, method string, req Request) (Result, error) { + info, ok := req.GetSession().sendingMethodInfos()[method] + if !ok { + // This can be called from user code, with an arbitrary value for method. + return nil, jsonrpc2.ErrNotHandled + } + params := req.GetParams() + if initParams, ok := params.(*InitializeParams); ok { + // Fix the marshaling of initialize params, to work around #607. + // + // The initialize params we produce should never be nil, nor have nil + // capabilities, so any panic here is a bug. 
+ params = initParams.toV2() + } + // Notifications don't have results. + if strings.HasPrefix(method, "notifications/") { + return nil, req.GetSession().getConn().Notify(ctx, method, params) + } + // Create the result to unmarshal into. + // The concrete type of the result is the return type of the receiving function. + res := info.newResult() + if err := call(ctx, req.GetSession().getConn(), method, params, res); err != nil { + return nil, err + } + return res, nil +} + +// Helper method to avoid typed nil. +func orZero[T any, P *U, U any](p P) T { + if p == nil { + var zero T + return zero + } + return any(p).(T) +} + +func handleNotify(ctx context.Context, method string, req Request) error { + mh := req.GetSession().sendingMethodHandler() + _, err := mh(ctx, method, req) + return err +} + +func handleSend[R Result](ctx context.Context, method string, req Request) (R, error) { + mh := req.GetSession().sendingMethodHandler() + // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. + res, err := mh(ctx, method, req) + if err != nil { + var z R + return z, err + } + return res.(R), nil +} + +// defaultReceivingMethodHandler is the initial MethodHandler for servers and clients, before being wrapped by middleware. +func defaultReceivingMethodHandler[S Session](ctx context.Context, method string, req Request) (Result, error) { + info, ok := req.GetSession().receivingMethodInfos()[method] + if !ok { + // This can be called from user code, with an arbitrary value for method. 
+ return nil, jsonrpc2.ErrNotHandled
+ }
+ return info.handleMethod(ctx, method, req)
+}
+
+func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Request) (Result, error) {
+ info, err := checkRequest(jreq, session.receivingMethodInfos())
+ if err != nil {
+ return nil, err
+ }
+ params, err := info.unmarshalParams(jreq.Params)
+ if err != nil {
+ return nil, fmt.Errorf("handling '%s': %w", jreq.Method, err)
+ }
+
+ mh := session.receivingMethodHandler()
+ re, _ := jreq.Extra.(*RequestExtra)
+ req := info.newRequest(session, params, re)
+ // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol.
+ res, err := mh(ctx, jreq.Method, req)
+ if err != nil {
+ return nil, err
+ }
+ return res, nil
+}
+
+// checkRequest checks the given request against the provided method info, to
+// ensure it is a valid MCP request.
+//
+// If valid, the relevant method info is returned. Otherwise, a non-nil error
+// is returned describing why the request is invalid.
+//
+// This is extracted from request handling so that it can be called in the
+// transport layer to preemptively reject bad requests.
+func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo, error) {
+ info, ok := infos[req.Method]
+ if !ok {
+ return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method)
+ }
+ if info.flags&notification != 0 && req.IsCall() {
+ return methodInfo{}, fmt.Errorf("%w: unexpected id for %q", jsonrpc2.ErrInvalidRequest, req.Method)
+ }
+ if info.flags&notification == 0 && !req.IsCall() {
+ return methodInfo{}, fmt.Errorf("%w: missing id for %q", jsonrpc2.ErrInvalidRequest, req.Method)
+ }
+ // missingParamsOK is checked here to catch the common case where "params" is
+ // missing entirely.
+ //
+ // However, it's checked again after unmarshalling to catch the rare but
+ // possible case where "params" is JSON null (see https://go.dev/issue/33835). 
+ if info.flags&missingParamsOK == 0 && len(req.Params) == 0 { + return methodInfo{}, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } + return info, nil +} + +// methodInfo is information about sending and receiving a method. +type methodInfo struct { + // flags is a collection of flags controlling how the JSONRPC method is + // handled. See individual flag values for documentation. + flags methodFlags + // Unmarshal params from the wire into a Params struct. + // Used on the receive side. + unmarshalParams func(json.RawMessage) (Params, error) + newRequest func(Session, Params, *RequestExtra) Request + // Run the code when a call to the method is received. + // Used on the receive side. + handleMethod MethodHandler + // Create a pointer to a Result struct. + // Used on the send side. + newResult func() Result +} + +// The following definitions support converting from typed to untyped method handlers. +// Type parameter meanings: +// - S: sessions +// - P: params +// - R: results + +// A typedMethodHandler is like a MethodHandler, but with type information. 
+type ( + typedClientMethodHandler[P Params, R Result] func(context.Context, *ClientRequest[P]) (R, error) + typedServerMethodHandler[P Params, R Result] func(context.Context, *ServerRequest[P]) (R, error) +) + +type paramsPtr[T any] interface { + *T + Params +} + +type methodFlags int + +const ( + notification methodFlags = 1 << iota // method is a notification, not request + missingParamsOK // params may be missing or null +) + +func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo { + mi := newMethodInfo[P, R](flags) + mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request { + r := &ClientRequest[P]{Session: s.(*ClientSession)} + if p != nil { + r.Params = p.(P) + } + return r + } + mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { + return d(ctx, req.(*ClientRequest[P])) + }) + return mi +} + +func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo { + mi := newMethodInfo[P, R](flags) + mi.newRequest = func(s Session, p Params, re *RequestExtra) Request { + r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: re} + if p != nil { + r.Params = p.(P) + } + return r + } + mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { + return d(ctx, req.(*ServerRequest[P])) + }) + return mi +} + +// newMethodInfo creates a methodInfo from a typedMethodHandler. +// +// If isRequest is set, the method is treated as a request rather than a +// notification. 
+func newMethodInfo[P paramsPtr[T], R Result, T any](flags methodFlags) methodInfo { + return methodInfo{ + flags: flags, + unmarshalParams: func(m json.RawMessage) (Params, error) { + var p P + if m != nil { + if err := json.Unmarshal(m, &p); err != nil { + return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) + } + } + // We must check missingParamsOK here, in addition to checkRequest, to + // catch the edge cases where "params" is set to JSON null. + // See also https://go.dev/issue/33835. + // + // We need to ensure that p is non-null to guard against crashes, as our + // internal code or externally provided handlers may assume that params + // is non-null. + if flags&missingParamsOK == 0 && p == nil { + return nil, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } + return orZero[Params](p), nil + }, + // newResult is used on the send side, to construct the value to unmarshal the result into. + // R is a pointer to a result struct. There is no way to "unpointer" it without reflection. + // TODO(jba): explore generic approaches to this, perhaps by treating R in + // the signature as the unpointered type. + newResult: func() Result { return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) }, + } +} + +// serverMethod is glue for creating a typedMethodHandler from a method on Server. +func serverMethod[P Params, R Result]( + f func(*Server, context.Context, *ServerRequest[P]) (R, error), +) typedServerMethodHandler[P, R] { + return func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return f(req.Session.server, ctx, req) + } +} + +// clientMethod is glue for creating a typedMethodHandler from a method on Client. 
+func clientMethod[P Params, R Result]( + f func(*Client, context.Context, *ClientRequest[P]) (R, error), +) typedClientMethodHandler[P, R] { + return func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return f(req.Session.client, ctx, req) + } +} + +// serverSessionMethod is glue for creating a typedServerMethodHandler from a method on ServerSession. +func serverSessionMethod[P Params, R Result](f func(*ServerSession, context.Context, P) (R, error)) typedServerMethodHandler[P, R] { + return func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return f(req.GetSession().(*ServerSession), ctx, req.Params) + } +} + +// clientSessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. +func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Context, P) (R, error)) typedClientMethodHandler[P, R] { + return func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return f(req.GetSession().(*ClientSession), ctx, req.Params) + } +} + +// MCP-specific error codes. +const ( + // CodeResourceNotFound indicates that a requested resource could not be found. + CodeResourceNotFound = -32002 + // CodeURLElicitationRequired indicates that the server requires URL elicitation + // before processing the request. The client should execute the elicitation handler + // with the elicitations provided in the error data. + CodeURLElicitationRequired = -32042 +) + +// URLElicitationRequiredError returns an error indicating that URL elicitation is required +// before the request can be processed. The elicitations parameter should contain the +// elicitation requests that must be completed. 
+func URLElicitationRequiredError(elicitations []*ElicitParams) error { + // Validate that all elicitations are URL mode + for _, elicit := range elicitations { + mode := elicit.Mode + if mode == "" { + mode = "form" // default mode + } + if mode != "url" { + panic(fmt.Sprintf("URLElicitationRequiredError requires all elicitations to be URL mode, got %q", mode)) + } + } + + data, err := json.Marshal(map[string]any{ + "elicitations": elicitations, + }) + if err != nil { + // This should never happen with valid ElicitParams + panic(fmt.Sprintf("failed to marshal elicitations: %v", err)) + } + return &jsonrpc.Error{ + Code: CodeURLElicitationRequired, + Message: "URL elicitation required", + Data: json.RawMessage(data), + } +} + +// Internal error codes +const ( + // The error code if the method exists and was called properly, but the peer does not support it. + // + // TODO(rfindley): this code is wrong, and we should fix it to be + // consistent with other SDKs. + codeUnsupportedMethod = -31001 +) + +// notifySessions calls Notify on all the sessions. +// Should be called on a copy of the peer sessions. +// The logger must be non-nil. +func notifySessions[S Session, P Params](sessions []S, method string, params P, logger *slog.Logger) { + if sessions == nil { + return + } + // Notify with the background context, so the messages are sent on the + // standalone stream. + // TODO: make this timeout configurable, or call handleNotify asynchronously. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // TODO: there's a potential spec violation here, when the feature list + // changes before the session (client or server) is initialized. 
+ for _, s := range sessions { + req := newRequest(s, params) + if err := handleNotify(ctx, method, req); err != nil { + logger.Warn(fmt.Sprintf("calling %s: %v", method, err)) + } + } +} + +func newRequest[S Session, P Params](s S, p P) Request { + switch s := any(s).(type) { + case *ClientSession: + return &ClientRequest[P]{Session: s, Params: p} + case *ServerSession: + return &ServerRequest[P]{Session: s, Params: p} + default: + panic("bad session") + } +} + +// Meta is additional metadata for requests, responses and other types. +type Meta map[string]any + +// GetMeta returns metadata from a value. +func (m Meta) GetMeta() map[string]any { return m } + +// SetMeta sets the metadata on a value. +func (m *Meta) SetMeta(x map[string]any) { *m = x } + +const progressTokenKey = "progressToken" + +func getProgressToken(p Params) any { + return p.GetMeta()[progressTokenKey] +} + +func setProgressToken(p Params, pt any) { + switch pt.(type) { + // Support int32 and int64 for atomic.IntNN. + case int, int32, int64, string: + default: + panic(fmt.Sprintf("progress token %v is of type %[1]T, not int or string", pt)) + } + m := p.GetMeta() + if m == nil { + m = map[string]any{} + } + m[progressTokenKey] = pt +} + +// A Request is a method request with parameters and additional information, such as the session. +// Request is implemented by [*ClientRequest] and [*ServerRequest]. +type Request interface { + isRequest() + GetSession() Session + GetParams() Params + // GetExtra returns the Extra field for ServerRequests, and nil for ClientRequests. + GetExtra() *RequestExtra +} + +// A ClientRequest is a request to a client. +type ClientRequest[P Params] struct { + Session *ClientSession + Params P +} + +// A ServerRequest is a request to a server. +type ServerRequest[P Params] struct { + Session *ServerSession + Params P + Extra *RequestExtra +} + +// RequestExtra is extra information included in requests, typically from +// the transport layer. 
+type RequestExtra struct { + TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any + Header http.Header // header from HTTP request, if any + + // If set, CloseSSEStream explicitly closes the current SSE request stream. + // + // [SEP-1699] introduced server-side SSE stream disconnection: for + // long-running requests, servers may opt to close the SSE stream and + // ask the client to retry at a later time. CloseSSEStream implements this + // feature; if RetryAfter is set, an event is sent with a `retry:` field + // to configure the reconnection delay. + // + // [SEP-1699]: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1699 + CloseSSEStream func(CloseSSEStreamArgs) +} + +// CloseSSEStreamArgs are arguments for [RequestExtra.CloseSSEStream]. +type CloseSSEStreamArgs struct { + // RetryAfter configures the reconnection delay sent to the client via the + // SSE retry field. If zero, no retry field is sent. + RetryAfter time.Duration +} + +func (*ClientRequest[P]) isRequest() {} +func (*ServerRequest[P]) isRequest() {} + +func (r *ClientRequest[P]) GetSession() Session { return r.Session } +func (r *ServerRequest[P]) GetSession() Session { return r.Session } + +func (r *ClientRequest[P]) GetParams() Params { return r.Params } +func (r *ServerRequest[P]) GetParams() Params { return r.Params } + +func (r *ClientRequest[P]) GetExtra() *RequestExtra { return nil } +func (r *ServerRequest[P]) GetExtra() *RequestExtra { return r.Extra } + +func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] { + return &ServerRequest[P]{Session: s, Params: p} +} + +func clientRequestFor[P Params](s *ClientSession, p P) *ClientRequest[P] { + return &ClientRequest[P]{Session: s, Params: p} +} + +// Params is a parameter (input) type for an MCP call or notification. +type Params interface { + // GetMeta returns metadata from a value. + GetMeta() map[string]any + // SetMeta sets the metadata on a value. 
+ SetMeta(map[string]any) + + // isParams discourages implementation of Params outside of this package. + isParams() +} + +// RequestParams is a parameter (input) type for an MCP request. +type RequestParams interface { + Params + + // GetProgressToken returns the progress token from the params' Meta field, or nil + // if there is none. + GetProgressToken() any + + // SetProgressToken sets the given progress token into the params' Meta field. + // It panics if its argument is not an int or a string. + SetProgressToken(any) +} + +// Result is a result of an MCP call. +type Result interface { + // isResult discourages implementation of Result outside of this package. + isResult() + + // GetMeta returns metadata from a value. + GetMeta() map[string]any + // SetMeta sets the metadata on a value. + SetMeta(map[string]any) +} + +// emptyResult is returned by methods that have no result, like ping. +// Those methods cannot return nil, because jsonrpc2 cannot handle nils. +type emptyResult struct{} + +func (*emptyResult) isResult() {} +func (*emptyResult) GetMeta() map[string]any { panic("should never be called") } +func (*emptyResult) SetMeta(map[string]any) { panic("should never be called") } + +type listParams interface { + // Returns a pointer to the param's Cursor field. + cursorPtr() *string +} + +type listResult[T any] interface { + // Returns a pointer to the param's NextCursor field. + nextCursorPtr() *string +} + +// keepaliveSession represents a session that supports keepalive functionality. +type keepaliveSession interface { + Ping(ctx context.Context, params *PingParams) error + Close() error +} + +// startKeepalive starts the keepalive mechanism for a session. +// It assigns the cancel function to the provided cancelPtr and starts a goroutine +// that sends ping messages at the specified interval. 
+func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + // Assign cancel function before starting goroutine to avoid race condition. + // We cannot return it because the caller may need to cancel during the + // window between goroutine scheduling and function return. + *cancelPtr = cancel + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2) + err := session.Ping(pingCtx, nil) + pingCancel() + if err != nil { + // Ping failed, close the session + _ = session.Close() + return + } + } + } + }() +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go new file mode 100644 index 000000000..7f644918b --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go @@ -0,0 +1,479 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "sync" + + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +// This file implements support for SSE (HTTP with server-sent events) +// transport server and client. +// https://modelcontextprotocol.io/specification/2024-11-05/basic/transports +// +// The transport is simple, at least relative to the new streamable transport +// introduced in the 2025-03-26 version of the spec. In short: +// +// 1. Sessions are initiated via a hanging GET request, which streams +// server->client messages as SSE 'message' events. +// 2. 
The first event in the SSE stream must be an 'endpoint' event that +// informs the client of the session endpoint. +// 3. The client POSTs client->server messages to the session endpoint. +// +// Therefore, each new GET request hands off its ResponseWriter to an +// [SSEServerTransport] type that abstracts the transport as follows: +// - Write writes a new event to the responseWriter, or fails if the GET has +// exited. +// - Read reads off a message queue that is pushed to via POST requests. +// - Close causes the hanging GET to exit. + +// SSEHandler is an http.Handler that serves SSE-based MCP sessions as defined by +// the [2024-11-05 version] of the MCP spec. +// +// [2024-11-05 version]: https://modelcontextprotocol.io/specification/2024-11-05/basic/transports +type SSEHandler struct { + getServer func(request *http.Request) *Server + opts SSEOptions + onConnection func(*ServerSession) // for testing; must not block + + mu sync.Mutex + sessions map[string]*SSEServerTransport +} + +// SSEOptions specifies options for an [SSEHandler]. +// For now, it is empty, but may be extended in future. +// https://github.com/modelcontextprotocol/go-sdk/issues/507 +type SSEOptions struct{} + +// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP +// sessions created via incoming HTTP requests. +// +// Sessions are created when the client issues a GET request to the server, +// which must accept text/event-stream responses (server-sent events). +// For each such request, a new [SSEServerTransport] is created with a distinct +// messages endpoint, and connected to the server returned by getServer. +// The SSEHandler also handles requests to the message endpoints, by +// delegating them to the relevant server transport. +// +// The getServer function may return a distinct [Server] for each new +// request, or reuse an existing server. If it returns nil, the handler +// will return a 400 Bad Request.
+func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptions) *SSEHandler { + s := &SSEHandler{ + getServer: getServer, + sessions: make(map[string]*SSEServerTransport), + } + + if opts != nil { + s.opts = *opts + } + + return s +} + +// A SSEServerTransport is a logical SSE session created through a hanging GET +// request. +// +// Use [SSEServerTransport.Connect] to initiate the flow of messages. +// +// When connected, it returns the following [Connection] implementation: +// - Writes are SSE 'message' events to the GET response. +// - Reads are received from POSTs to the session endpoint, via +// [SSEServerTransport.ServeHTTP]. +// - Close terminates the hanging GET. +// +// The transport is itself an [http.Handler]. It is the caller's responsibility +// to ensure that the resulting transport serves HTTP requests on the given +// session endpoint. +// +// Each SSEServerTransport may be connected (via [Server.Connect]) at most +// once, since [SSEServerTransport.ServeHTTP] serves messages to the connected +// session. +// +// Most callers should instead use an [SSEHandler], which transparently handles +// the delegation to SSEServerTransports. +type SSEServerTransport struct { + // Endpoint is the endpoint for this session, where the client can POST + // messages. + Endpoint string + + // Response is the hanging response body to the incoming GET request. + Response http.ResponseWriter + + // incoming is the queue of incoming messages. + // It is never closed, and by convention, incoming is non-nil if and only if + // the transport is connected. + incoming chan jsonrpc.Message + + // We must guard both pushes to the incoming queue and writes to the response + // writer, because incoming POST requests are arbitrarily concurrent and we + // need to ensure we don't write push to the queue, or write to the + // ResponseWriter, after the session GET request exits. 
+ mu sync.Mutex // also guards writes to Response + closed bool // set when the stream is closed + done chan struct{} // closed when the connection is closed +} + +// ServeHTTP handles POST requests to the transport endpoint. +func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if t.incoming == nil { + http.Error(w, "session not connected", http.StatusInternalServerError) + return + } + + // Read and parse the message. + data, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + // Optionally, we could just push the data onto a channel, and let the + // message fail to parse when it is read. This failure seems a bit more + // useful + msg, err := jsonrpc2.DecodeMessage(data) + if err != nil { + http.Error(w, "failed to parse body", http.StatusBadRequest) + return + } + if req, ok := msg.(*jsonrpc.Request); ok { + if _, err := checkRequest(req, serverMethodInfos); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } + select { + case t.incoming <- msg: + w.WriteHeader(http.StatusAccepted) + case <-t.done: + http.Error(w, "session closed", http.StatusBadRequest) + } +} + +// Connect sends the 'endpoint' event to the client. +// See [SSEServerTransport] for more details on the [Connection] implementation. +func (t *SSEServerTransport) Connect(context.Context) (Connection, error) { + if t.incoming != nil { + return nil, fmt.Errorf("already connected") + } + t.incoming = make(chan jsonrpc.Message, 100) + t.done = make(chan struct{}) + _, err := writeEvent(t.Response, Event{ + Name: "endpoint", + Data: []byte(t.Endpoint), + }) + if err != nil { + return nil, err + } + return &sseServerConn{t: t}, nil +} + +func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + sessionID := req.URL.Query().Get("sessionid") + + // TODO: consider checking Content-Type here. For now, we are lax. 
+ + // For POST requests, the message body is a message to send to a session. + if req.Method == http.MethodPost { + // Look up the session. + if sessionID == "" { + http.Error(w, "sessionid must be provided", http.StatusBadRequest) + return + } + h.mu.Lock() + session := h.sessions[sessionID] + h.mu.Unlock() + if session == nil { + http.Error(w, "session not found", http.StatusNotFound) + return + } + + session.ServeHTTP(w, req) + return + } + + if req.Method != http.MethodGet { + http.Error(w, "invalid method", http.StatusMethodNotAllowed) + return + } + + // GET requests create a new session, and serve messages over SSE. + + // TODO: it's not entirely documented whether we should check Accept here. + // Let's again be lax and assume the client will accept SSE. + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + sessionID = randText() + endpoint, err := req.URL.Parse("?sessionid=" + sessionID) + if err != nil { + http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError) + return + } + + transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w} + + // The session is terminated when the request exits. + h.mu.Lock() + h.sessions[sessionID] = transport + h.mu.Unlock() + defer func() { + h.mu.Lock() + delete(h.sessions, sessionID) + h.mu.Unlock() + }() + + server := h.getServer(req) + if server == nil { + // The getServer argument to NewSSEHandler returned nil. 
+ http.Error(w, "no server available", http.StatusBadRequest) + return + } + ss, err := server.Connect(req.Context(), transport, nil) + if err != nil { + http.Error(w, "connection failed", http.StatusInternalServerError) + return + } + if h.onConnection != nil { + h.onConnection(ss) + } + defer ss.Close() // close the transport when the GET exits + + select { + case <-req.Context().Done(): + case <-transport.done: + } +} + +// sseServerConn implements the [Connection] interface for a single [SSEServerTransport]. +// It hides the Connection interface from the SSEServerTransport API. +type sseServerConn struct { + t *SSEServerTransport +} + +// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) +func (s *sseServerConn) SessionID() string { return "" } + +// Read implements jsonrpc2.Reader. +func (s *sseServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case msg := <-s.t.incoming: + return msg, nil + case <-s.t.done: + return nil, io.EOF + } +} + +// Write implements jsonrpc2.Writer. +func (s *sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if ctx.Err() != nil { + return ctx.Err() + } + + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return err + } + + s.t.mu.Lock() + defer s.t.mu.Unlock() + + // Note that it is invalid to write to a ResponseWriter after ServeHTTP has + // exited, and so we must lock around this write and check isDone, which is + // set before the hanging GET exits. + if s.t.closed { + return io.EOF + } + + _, err = writeEvent(s.t.Response, Event{Name: "message", Data: data}) + return err +} + +// Close implements io.Closer, and closes the session. +// +// It must be safe to call Close more than once, as the close may +// asynchronously be initiated by either the server closing its connection, or +// by the hanging GET exiting. 
+func (s *sseServerConn) Close() error { + s.t.mu.Lock() + defer s.t.mu.Unlock() + if !s.t.closed { + s.t.closed = true + close(s.t.done) + } + return nil +} + +// An SSEClientTransport is a [Transport] that can communicate with an MCP +// endpoint serving the SSE transport defined by the 2024-11-05 version of the +// spec. +// +// https://modelcontextprotocol.io/specification/2024-11-05/basic/transports +type SSEClientTransport struct { + // Endpoint is the SSE endpoint to connect to. + Endpoint string + + // HTTPClient is the client to use for making HTTP requests. If nil, + // http.DefaultClient is used. + HTTPClient *http.Client +} + +// Connect connects through the client endpoint. +func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { + parsedURL, err := url.Parse(c.Endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint: %v", err) + } + req, err := http.NewRequestWithContext(ctx, "GET", c.Endpoint, nil) + if err != nil { + return nil, err + } + httpClient := c.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + req.Header.Set("Accept", "text/event-stream") + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + + msgEndpoint, err := func() (*url.URL, error) { + var evt Event + for evt, err = range scanEvents(resp.Body) { + break + } + if err != nil { + return nil, err + } + if evt.Name != "endpoint" { + return nil, fmt.Errorf("first event is %q, want %q", evt.Name, "endpoint") + } + raw := string(evt.Data) + return parsedURL.Parse(raw) + }() + if err != nil { + resp.Body.Close() + return nil, fmt.Errorf("missing endpoint: %v", err) + } + + // From here on, the stream takes ownership of resp.Body. 
+ s := &sseClientConn{ + client: httpClient, + msgEndpoint: msgEndpoint, + incoming: make(chan []byte, 100), + body: resp.Body, + done: make(chan struct{}), + } + + go func() { + defer s.Close() // close the transport when the GET exits + + for evt, err := range scanEvents(resp.Body) { + if err != nil { + return + } + select { + case s.incoming <- evt.Data: + case <-s.done: + return + } + } + }() + + return s, nil +} + +// An sseClientConn is a logical jsonrpc2 connection that implements the client +// half of the SSE protocol: +// - Writes are POSTS to the session endpoint. +// - Reads are SSE 'message' events, and pushes them onto a buffered channel. +// - Close terminates the GET request. +type sseClientConn struct { + client *http.Client // HTTP client to use for requests + msgEndpoint *url.URL // session endpoint for POSTs + incoming chan []byte // queue of incoming messages + + mu sync.Mutex + body io.ReadCloser // body of the hanging GET + closed bool // set when the stream is closed + done chan struct{} // closed when the stream is closed +} + +// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) +func (c *sseClientConn) SessionID() string { return "" } + +func (c *sseClientConn) isDone() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case <-c.done: + return nil, io.EOF + + case data := <-c.incoming: + // TODO(rfindley): do we really need to check this? We receive from c.done above. 
+ if c.isDone() { + return nil, io.EOF + } + msg, err := jsonrpc2.DecodeMessage(data) + if err != nil { + return nil, err + } + return msg, nil + } +} + +func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return err + } + if c.isDone() { + return io.EOF + } + req, err := http.NewRequestWithContext(ctx, "POST", c.msgEndpoint.String(), bytes.NewReader(data)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("failed to write: %s", resp.Status) + } + return nil +} + +func (c *sseClientConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.closed { + c.closed = true + _ = c.body.Close() + close(c.done) + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go new file mode 100644 index 000000000..b4b2fa310 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go @@ -0,0 +1,2040 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// NOTE: see streamable_server.go and streamable_client.go for detailed +// documentation of the streamable server design. +// TODO: move the client and server logic into those files. 
+ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "maps" + "math" + "math/rand/v2" + "net/http" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/xcontext" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const ( + protocolVersionHeader = "Mcp-Protocol-Version" + sessionIDHeader = "Mcp-Session-Id" + lastEventIDHeader = "Last-Event-ID" +) + +// A StreamableHTTPHandler is an http.Handler that serves streamable MCP +// sessions, as defined by the [MCP spec]. +// +// [MCP spec]: https://modelcontextprotocol.io/2025/03/26/streamable-http-transport.html +type StreamableHTTPHandler struct { + getServer func(*http.Request) *Server + opts StreamableHTTPOptions + + onTransportDeletion func(sessionID string) // for testing + + mu sync.Mutex + sessions map[string]*sessionInfo // keyed by session ID +} + +type sessionInfo struct { + session *ServerSession + transport *StreamableServerTransport + // userID is the user ID from the TokenInfo when the session was created. + // If non-empty, subsequent requests must have the same user ID to prevent + // session hijacking. + userID string + + // If timeout is set, automatically close the session after an idle period. + timeout time.Duration + timerMu sync.Mutex + refs int // reference count + timer *time.Timer +} + +// startPOST signals that a POST request for this session is starting (which +// carries a client->server message), pausing the session timeout if it was +// running. +// +// TODO: we may want to also pause the timer when resuming non-standalone SSE +// streams, but that is tricy to implement. Clients should generally make +// keepalive pings if they want to keep the session live. 
+func (i *sessionInfo) startPOST() { + if i.timeout <= 0 { + return + } + + i.timerMu.Lock() + defer i.timerMu.Unlock() + + if i.timer == nil { + return // timer stopped permanently + } + if i.refs == 0 { + i.timer.Stop() + } + i.refs++ +} + +// endPOST signals that a request for this session is ending, starting the +// timeout if there are no other requests running. +func (i *sessionInfo) endPOST() { + if i.timeout <= 0 { + return + } + + i.timerMu.Lock() + defer i.timerMu.Unlock() + + if i.timer == nil { + return // timer stopped permanently + } + + i.refs-- + assert(i.refs >= 0, "negative ref count") + if i.refs == 0 { + i.timer.Reset(i.timeout) + } +} + +// stopTimer stops the inactivity timer permanently. +func (i *sessionInfo) stopTimer() { + i.timerMu.Lock() + defer i.timerMu.Unlock() + if i.timer != nil { + i.timer.Stop() + i.timer = nil + } +} + +// StreamableHTTPOptions configures the StreamableHTTPHandler. +type StreamableHTTPOptions struct { + // Stateless controls whether the session is 'stateless'. + // + // A stateless server does not validate the Mcp-Session-Id header, and uses a + // temporary session with default initialization parameters. Any + // server->client request is rejected immediately as there's no way for the + // client to respond. Server->Client notifications may reach the client if + // they are made in the context of an incoming request, as described in the + // documentation for [StreamableServerTransport]. + Stateless bool + + // TODO(#148): support session retention (?) + + // JSONResponse causes streamable responses to return application/json rather + // than text/event-stream ([§2.1.5] of the spec). + // + // [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server + JSONResponse bool + + // Logger specifies the logger to use. + // If nil, do not log. + Logger *slog.Logger + + // EventStore enables stream resumption.
+ // + // If set, EventStore will be used to persist stream events and replay them + // upon stream resumption. + EventStore EventStore + + // SessionTimeout configures a timeout for idle sessions. + // + // When sessions receive no new HTTP requests from the client for this + // duration, they are automatically closed. + // + // If SessionTimeout is the zero value, idle sessions are never closed. + SessionTimeout time.Duration +} + +// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. +// +// The getServer function is used to create or look up servers for new +// sessions. It is OK for getServer to return the same server multiple times. +// If getServer returns nil, a 400 Bad Request will be served. +func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler { + h := &StreamableHTTPHandler{ + getServer: getServer, + sessions: make(map[string]*sessionInfo), + } + if opts != nil { + h.opts = *opts + } + + if h.opts.Logger == nil { // ensure we have a logger + h.opts.Logger = ensureLogger(nil) + } + + return h +} + +// closeAll closes all ongoing sessions, for tests. +// +// TODO(rfindley): investigate the best API for callers to configure their +// session lifecycle. (?) +// +// Should we allow passing in a session store? That would allow the handler to +// be stateless. +func (h *StreamableHTTPHandler) closeAll() { + // TODO: if we ever expose this outside of tests, we'll need to do better + // than simply collecting sessions while holding the lock: we need to prevent + // new sessions from being added. + // + // Currently, sessions remove themselves from h.sessions when closed, so we + // can't call Close while holding the lock. 
+ h.mu.Lock() + sessionInfos := slices.Collect(maps.Values(h.sessions)) + h.sessions = nil + h.mu.Unlock() + for _, s := range sessionInfos { + s.session.Close() + } +} + +func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Allow multiple 'Accept' headers. + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax + accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",") + var jsonOK, streamOK bool + for _, c := range accept { + switch strings.TrimSpace(c) { + case "application/json", "application/*": + jsonOK = true + case "text/event-stream", "text/*": + streamOK = true + case "*/*": + jsonOK = true + streamOK = true + } + } + + if req.Method == http.MethodGet { + if !streamOK { + http.Error(w, "Accept must contain 'text/event-stream' for GET requests", http.StatusBadRequest) + return + } + } else if (!jsonOK || !streamOK) && req.Method != http.MethodDelete { // TODO: consolidate with handling of http method below. + http.Error(w, "Accept must contain both 'application/json' and 'text/event-stream'", http.StatusBadRequest) + return + } + + sessionID := req.Header.Get(sessionIDHeader) + var sessInfo *sessionInfo + if sessionID != "" { + h.mu.Lock() + sessInfo = h.sessions[sessionID] + h.mu.Unlock() + if sessInfo == nil && !h.opts.Stateless { + // Unless we're in 'stateless' mode, which doesn't perform any Session-ID + // validation, we require that the session ID matches a known session. + // + // In stateless mode, a temporary transport will be created below. + http.Error(w, "session not found", http.StatusNotFound) + return + } + // Prevent session hijacking: if the session was created with a user ID, + // verify that subsequent requests come from the same user.
+ if sessInfo != nil && sessInfo.userID != "" { + tokenInfo := auth.TokenInfoFromContext(req.Context()) + if tokenInfo == nil || tokenInfo.UserID != sessInfo.userID { + http.Error(w, "session user mismatch", http.StatusForbidden) + return + } + } + } + + if req.Method == http.MethodDelete { + if sessionID == "" { + http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) + return + } + if sessInfo != nil { // sessInfo may be nil in stateless mode + // Closing the session also removes it from h.sessions, due to the + // onClose callback. + sessInfo.session.Close() + } + w.WriteHeader(http.StatusNoContent) + return + } + + switch req.Method { + case http.MethodPost, http.MethodGet: + if req.Method == http.MethodGet && (h.opts.Stateless || sessionID == "") { + http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) + return + } + default: + w.Header().Set("Allow", "GET, POST, DELETE") + http.Error(w, "Method Not Allowed: streamable MCP servers support GET, POST, and DELETE requests", http.StatusMethodNotAllowed) + return + } + + // [§2.7] of the spec (2025-06-18) states: + // + // "If using HTTP, the client MUST include the MCP-Protocol-Version: + // HTTP header on all subsequent requests to the MCP + // server, allowing the MCP server to respond based on the MCP protocol + // version. + // + // For example: MCP-Protocol-Version: 2025-06-18 + // The protocol version sent by the client SHOULD be the one negotiated during + // initialization. + // + // For backwards compatibility, if the server does not receive an + // MCP-Protocol-Version header, and has no other way to identify the version - + // for example, by relying on the protocol version negotiated during + // initialization - the server SHOULD assume protocol version 2025-03-26. + // + // If the server receives a request with an invalid or unsupported + // MCP-Protocol-Version, it MUST respond with 400 Bad Request." 
+ // + // Since this wasn't present in the 2025-03-26 version of the spec, this + // effectively means: + // 1. IF the client provides a version header, it must be a supported + // version. + // 2. In stateless mode, where we've lost the state of the initialize + // request, we assume that whatever the client tells us is the truth (or + // assume 2025-03-26 if the client doesn't say anything). + // + // This logic matches the typescript SDK. + // + // [§2.7]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + if !slices.Contains(supportedProtocolVersions, protocolVersion) { + http.Error(w, fmt.Sprintf("Bad Request: Unsupported protocol version (supported versions: %s)", strings.Join(supportedProtocolVersions, ",")), http.StatusBadRequest) + return + } + + if sessInfo == nil { + server := h.getServer(req) + if server == nil { + // The getServer argument to NewStreamableHTTPHandler returned nil. + http.Error(w, "no server available", http.StatusBadRequest) + return + } + if sessionID == "" { + // In stateless mode, sessionID may be nonempty even if there's no + // existing transport. + sessionID = server.opts.GetSessionID() + } + transport := &StreamableServerTransport{ + SessionID: sessionID, + Stateless: h.opts.Stateless, + EventStore: h.opts.EventStore, + jsonResponse: h.opts.JSONResponse, + logger: h.opts.Logger, + } + + // Sessions without a session ID are also stateless: there's no way to + // address them. + stateless := h.opts.Stateless || sessionID == "" + // To support stateless mode, we initialize the session with a default + // state, so that it doesn't reject subsequent requests. + var connectOpts *ServerSessionOptions + if stateless { + // Peek at the body to see if it is initialize or initialized. + // We want those to be handled as usual. 
+ var hasInitialize, hasInitialized bool
+ {
+ // TODO: verify that this allows protocol version negotiation for
+ // stateless servers.
+ body, err := io.ReadAll(req.Body)
+ if err != nil {
+ http.Error(w, "failed to read body", http.StatusInternalServerError)
+ return
+ }
+ req.Body.Close()
+
+ // Reset the body so that it can be read later.
+ req.Body = io.NopCloser(bytes.NewBuffer(body))
+
+ msgs, _, err := readBatch(body)
+ if err == nil {
+ for _, msg := range msgs {
+ if req, ok := msg.(*jsonrpc.Request); ok {
+ switch req.Method {
+ case methodInitialize:
+ hasInitialize = true
+ case notificationInitialized:
+ hasInitialized = true
+ }
+ }
+ }
+ }
+ }
+
+ // If we don't have InitializeParams or InitializedParams in the request,
+ // set the initial state to a default value.
+ state := new(ServerSessionState)
+ if !hasInitialize {
+ state.InitializeParams = &InitializeParams{
+ ProtocolVersion: protocolVersion,
+ }
+ }
+ if !hasInitialized {
+ state.InitializedParams = new(InitializedParams)
+ }
+ state.LogLevel = "info"
+ connectOpts = &ServerSessionOptions{
+ State: state,
+ }
+ } else {
+ // Cleanup is only required in stateful mode, as the transport is
+ // not stored in the map otherwise.
+ connectOpts = &ServerSessionOptions{
+ onClose: func() {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ if info, ok := h.sessions[transport.SessionID]; ok {
+ info.stopTimer()
+ delete(h.sessions, transport.SessionID)
+ if h.onTransportDeletion != nil {
+ h.onTransportDeletion(transport.SessionID)
+ }
+ }
+ },
+ }
+ }
+
+ // Pass req.Context() here, to allow middleware to add context values.
+ // The context is detached in the jsonrpc2 library when handling the
+ // long-running stream.
+ session, err := server.Connect(req.Context(), transport, connectOpts)
+ if err != nil {
+ http.Error(w, "failed connection", http.StatusInternalServerError)
+ return
+ }
+ // Capture the user ID from the token info to enable session hijacking
+ // prevention on subsequent requests.
+ var userID string + if tokenInfo := auth.TokenInfoFromContext(req.Context()); tokenInfo != nil { + userID = tokenInfo.UserID + } + sessInfo = &sessionInfo{ + session: session, + transport: transport, + userID: userID, + } + + if stateless { + // Stateless mode: close the session when the request exits. + defer session.Close() // close the fake session after handling the request + } else { + // Otherwise, save the transport so that it can be reused + + // Clean up the session when it times out. + // + // Note that the timer here may fire multiple times, but + // sessInfo.session.Close is idempotent. + if h.opts.SessionTimeout > 0 { + sessInfo.timeout = h.opts.SessionTimeout + sessInfo.timer = time.AfterFunc(sessInfo.timeout, func() { + sessInfo.session.Close() + }) + } + h.mu.Lock() + h.sessions[transport.SessionID] = sessInfo + h.mu.Unlock() + defer func() { + // If initialization failed, clean up the session (#578). + if session.InitializeParams() == nil { + // Initialization failed. + session.Close() + } + }() + } + } + + if req.Method == http.MethodPost { + sessInfo.startPOST() + defer sessInfo.endPOST() + } + + sessInfo.transport.ServeHTTP(w, req) +} + +// A StreamableServerTransport implements the server side of the MCP streamable +// transport. +// +// Each StreamableServerTransport must be connected (via [Server.Connect]) at +// most once, since [StreamableServerTransport.ServeHTTP] serves messages to +// the connected session. +// +// Reads from the streamable server connection receive messages from http POST +// requests from the client. Writes to the streamable server connection are +// sent either to the related stream, or to the standalone SSE stream, +// according to the following rules: +// - JSON-RPC responses to incoming requests are always routed to the +// appropriate HTTP response. 
+// - Requests or notifications made with a context.Context value derived from
+// an incoming request handler, are routed to the HTTP response
+// corresponding to that request, unless it has already terminated, in
+// which case they are routed to the standalone SSE stream.
+// - Requests or notifications made with a detached context.Context value are
+// routed to the standalone SSE stream.
+type StreamableServerTransport struct {
+ // SessionID is the ID of this session.
+ //
+ // If SessionID is the empty string, this is a 'stateless' session, which has
+ // limited ability to communicate with the client. Otherwise, the session ID
+ // must be globally unique, that is, different from any other session ID
+ // anywhere, past and future. (We recommend using a crypto random number
+ // generator to produce one, as with [crypto/rand.Text].)
+ SessionID string
+
+ // Stateless controls whether the session is 'stateless'. Server sessions
+ // connected to a stateless transport are disallowed from making outgoing
+ // requests.
+ //
+ // See also [StreamableHTTPOptions.Stateless].
+ Stateless bool
+
+ // EventStore enables stream resumption.
+ //
+ // If set, EventStore will be used to persist stream events and replay them
+ // upon stream resumption.
+ EventStore EventStore
+
+ // jsonResponse, if set, tells the server to prefer to respond to requests
+ // using application/json responses rather than text/event-stream.
+ //
+ // Specifically, responses will be application/json whenever an incoming POST
+ // request contains only a single message. In this case, notifications or
+ // requests made within the context of a server request will be sent to the
+ // standalone SSE stream, if any.
+ //
+ // TODO(rfindley): jsonResponse should be exported, since
+ // StreamableHTTPOptions.JSONResponse is exported, and we want to allow users
+ // to write their own streamable HTTP handler.
+ jsonResponse bool + + // optional logger provided through the [StreamableHTTPOptions.Logger]. + // + // TODO(rfindley): logger should be exported, since we want to allow users + // to write their own streamable HTTP handler. + logger *slog.Logger + + // connection is non-nil if and only if the transport has been connected. + connection *streamableServerConn +} + +// Connect implements the [Transport] interface. +func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, error) { + if t.connection != nil { + return nil, fmt.Errorf("transport already connected") + } + t.connection = &streamableServerConn{ + sessionID: t.SessionID, + stateless: t.Stateless, + eventStore: t.EventStore, + jsonResponse: t.jsonResponse, + logger: ensureLogger(t.logger), // see #556: must be non-nil + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + streams: make(map[string]*stream), + requestStreams: make(map[jsonrpc.ID]string), + } + // Stream 0 corresponds to the standalone SSE stream. + // + // It is always text/event-stream, since it must carry arbitrarily many + // messages. + var err error + t.connection.streams[""], err = t.connection.newStream(ctx, nil, "") + if err != nil { + return nil, err + } + return t.connection, nil +} + +type streamableServerConn struct { + sessionID string + stateless bool + jsonResponse bool + eventStore EventStore + + logger *slog.Logger + + incoming chan jsonrpc.Message // messages from the client to the server + + mu sync.Mutex // guards all fields below + + // Sessions are closed exactly once. + isDone bool + done chan struct{} + + // Sessions can have multiple logical connections (which we call streams), + // corresponding to HTTP requests. Additionally, streams may be resumed by + // subsequent HTTP requests, when the HTTP connection is terminated + // unexpectedly. 
+ // + // Therefore, we use a logical stream ID to key the stream state, and + // perform the accounting described below when incoming HTTP requests are + // handled. + + // streams holds the logical streams for this session, keyed by their ID. + // + // Lifecycle: streams persist until all of their responses are received from + // the server. + streams map[string]*stream + + // requestStreams maps incoming requests to their logical stream ID. + // + // Lifecycle: requestStreams persist until their response is received. + requestStreams map[jsonrpc.ID]string +} + +func (c *streamableServerConn) SessionID() string { + return c.sessionID +} + +// A stream is a single logical stream of SSE events within a server session. +// A stream begins with a client request, or with a client GET that has +// no Last-Event-ID header. +// +// A stream ends only when its session ends; we cannot determine its end otherwise, +// since a client may send a GET with a Last-Event-ID that references the stream +// at any time. +type stream struct { + // id is the logical ID for the stream, unique within a session. + // + // The standalone SSE stream has id "". + id string + + // logger is used for logging errors during stream operations. + logger *slog.Logger + + // mu guards the fields below, as well as storage of new messages in the + // connection's event store (if any). + mu sync.Mutex + + // If pendingJSONMessages is non-nil, this is a JSON stream and messages are + // collected here until the stream is complete, at which point they are + // flushed as a single JSON response. Note that the non-nilness of this field + // is significant, as it signals the expected content type. + // + // Note: if we remove support for batching, this could just be a bool. + pendingJSONMessages []json.RawMessage + + // w is the HTTP response writer for this stream. 
A non-nil w indicates + // that the stream is claimed by an HTTP request (the hanging POST or GET); + // it is set to nil when the request completes. + w http.ResponseWriter + + // done is closed to release the hanging HTTP request. + // + // Invariant: a non-nil done implies w is also non-nil, though the converse + // is not necessarily true: done is set to nil when it is closed, to avoid + // duplicate closure. + done chan struct{} + + // lastIdx is the index of the last written SSE event, for event ID generation. + // It starts at -1 since indices start at 0. + lastIdx int + + // protocolVersion is the protocol version for this stream. + protocolVersion string + + // requests is the set of unanswered incoming requests for the stream. + // + // Requests are removed when their response has been received. + // In practice, there is only one request, but in the 2025-03-26 version of + // the spec and earlier there was a concept of batching, in which POST + // payloads could hold multiple requests or responses. + requests map[jsonrpc.ID]struct{} +} + +// close sends a 'close' event to the client (if protocolVersion >= 2025-11-25 +// and reconnectAfter > 0) and closes the done channel. +// +// The done channel is set to nil after closing, so that done != nil implies +// the stream is active and done is open. This simplifies checks elsewhere. 
+func (s *stream) close(reconnectAfter time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done == nil { + return // stream not connected or already closed + } + if s.protocolVersion >= protocolVersion20251125 && reconnectAfter > 0 { + reconnectStr := strconv.FormatInt(reconnectAfter.Milliseconds(), 10) + if _, err := writeEvent(s.w, Event{ + Name: "close", + Retry: reconnectStr, + }); err != nil { + s.logger.Warn(fmt.Sprintf("Writing close event: %v", err)) + } + } + close(s.done) + s.done = nil +} + +// release releases the stream from its HTTP request, allowing it to be +// claimed by another request (e.g., for resumption). +func (s *stream) release() { + s.mu.Lock() + defer s.mu.Unlock() + s.w = nil + s.done = nil // may already be nil, if the stream is done or closed +} + +// deliverLocked writes data to the stream (for SSE) or stores it in +// pendingJSONMessages (for JSON mode). The eventID is used for SSE event ID; +// pass "" to omit. +// +// If responseTo is valid, it is removed from the requests map. When all +// requests have been responded to, the done channel is closed and set to nil. +// +// Returns true if the stream is now done (all requests have been responded to). +// The done value is always accurate, even if an error is returned. +// +// s.mu must be held when calling this method. +func (s *stream) deliverLocked(data []byte, eventID string, responseTo jsonrpc.ID) (done bool, err error) { + // First, record the response. We must do this *before* returning an error + // below, as even if the stream is disconnected we want to update our + // accounting. + if responseTo.IsValid() { + delete(s.requests, responseTo) + } + // Now, try to deliver the message to the client. + done = len(s.requests) == 0 && s.id != "" + if s.done == nil { + return done, fmt.Errorf("stream not connected or already closed") + } + if done { + defer func() { close(s.done); s.done = nil }() + } + // Try to write to the response. 
+ // + // If we get here, the request is still hanging (because s.done != nil + // implies s.w != nil), but may have been cancelled by the client/http layer: + // there's a brief race between request cancellation and releasing the + // stream. + if s.pendingJSONMessages != nil { + s.pendingJSONMessages = append(s.pendingJSONMessages, data) + if done { + // Flush all pending messages as JSON response. + var toWrite []byte + if len(s.pendingJSONMessages) == 1 { + toWrite = s.pendingJSONMessages[0] + } else { + toWrite, err = json.Marshal(s.pendingJSONMessages) + if err != nil { + return done, err + } + } + if _, err := s.w.Write(toWrite); err != nil { + return done, err + } + } + } else { + // SSE mode: write event to response writer. + s.lastIdx++ + if _, err := writeEvent(s.w, Event{Name: "message", Data: data, ID: eventID}); err != nil { + return done, err + } + } + return done, nil +} + +// doneLocked reports whether the stream is logically complete. +// +// s.requests was populated when reading the POST body, requests are deleted as +// they are responded to. Once all requests have been responded to, the stream +// is done. +// +// s.mu must be held while calling this function. 
+func (s *stream) doneLocked() bool { + return len(s.requests) == 0 && s.id != "" +} + +func (c *streamableServerConn) newStream(ctx context.Context, requests map[jsonrpc.ID]struct{}, id string) (*stream, error) { + if c.eventStore != nil { + if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { + return nil, err + } + } + return &stream{ + id: id, + requests: requests, + lastIdx: -1, // indices start at 0, incremented before each write + logger: c.logger, + }, nil +} + +// We track the incoming request ID inside the handler context using +// idContextValue, so that notifications and server->client calls that occur in +// the course of handling incoming requests are correlated with the incoming +// request that caused them, and can be dispatched as server-sent events to the +// correct HTTP request. +// +// Currently, this is implemented in [ServerSession.handle]. This is not ideal, +// because it means that a user of the MCP package couldn't implement the +// streamable transport, as they'd lack this privileged access. +// +// If we ever wanted to expose this mechanism, we have a few options: +// 1. Make ServerSession an interface, and provide an implementation of +// ServerSession to handlers that closes over the incoming request ID. +// 2. Expose a 'HandlerTransport' interface that allows transports to provide +// a handler middleware, so that we don't hard-code this behavior in +// ServerSession.handle. +// 3. Add a `func ForRequest(context.Context) jsonrpc.ID` accessor that lets +// any transport access the incoming request ID. +// +// For now, by giving only the StreamableServerTransport access to the request +// ID, we avoid having to make this API decision. +type idContextKey struct{} + +// ServeHTTP handles a single HTTP request for the session. 
+func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if t.connection == nil { + http.Error(w, "transport not connected", http.StatusInternalServerError) + return + } + switch req.Method { + case http.MethodGet: + t.connection.serveGET(w, req) + case http.MethodPost: + t.connection.servePOST(w, req) + default: + // Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP. + w.Header().Set("Allow", "GET, POST") + http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + return + } +} + +// serveGET streams messages to a hanging http GET, with stream ID and last +// message parsed from the Last-Event-ID header. +// +// It returns an HTTP status code and error message. +func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) { + // streamID "" corresponds to the default GET request. + streamID := "" + // By default, we haven't seen a last index. Since indices start at 0, we represent + // that by -1. This is incremented just before each event is written. + lastIdx := -1 + if len(req.Header.Values(lastEventIDHeader)) > 0 { + eid := req.Header.Get(lastEventIDHeader) + var ok bool + streamID, lastIdx, ok = parseEventID(eid) + if !ok { + http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest) + return + } + if c.eventStore == nil { + http.Error(w, "stream replay unsupported", http.StatusBadRequest) + return + } + } + + ctx := req.Context() + + // Read the protocol version from the header. For GET requests, this should + // always be present since GET only happens after initialization. 
+ protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + + stream, done := c.acquireStream(ctx, w, streamID, lastIdx, protocolVersion) + if stream == nil { + return + } + defer stream.release() + c.hangResponse(ctx, done) +} + +// hangResponse blocks the HTTP response until one of three conditions is met: +// - ctx is cancelled (the client disconnected or the request timed out) +// - done is closed (all responses have been sent, or the stream was explicitly closed) +// - the session is closed +// +// This keeps the HTTP connection open so that server-sent events can be +// written to the response. +func (c *streamableServerConn) hangResponse(ctx context.Context, done <-chan struct{}) { + select { + case <-ctx.Done(): + case <-done: + case <-c.done: + } +} + +// acquireStream replays all events since lastIdx, and acquires the ongoing +// stream, if any. If non-nil, the resulting stream will be registered for +// receiving new messages, and the stream's done channel will be closed when +// all related messages have been delivered. +// +// If any errors occur, they will be written to w and the resulting stream will +// be nil. The resulting stream may also be nil if the stream is complete. +// +// Importantly, this function must hold the stream mutex until done replaying +// all messages, so that no delivery or storage of new messages occurs while +// the stream is still replaying. +// +// protocolVersion is the protocol version for this stream, used to determine +// feature support (e.g. prime and close events were added in 2025-11-25). +func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx int, protocolVersion string) (*stream, chan struct{}) { + // if tempStream is set, the stream is done and we're just replaying messages. + // + // We record a temporary stream to claim exclusive replay rights. 
The spec + // (https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#resumability-and-redelivery) + // does not explicitly require exclusive replay, but we enforce it defensively. + tempStream := false + c.mu.Lock() + s, ok := c.streams[streamID] + if !ok { + // The stream is logically done, but claim exclusive rights to replay it by + // adding a temporary entry in the streams map. + // + // We create this entry with a non-nil w, to ensure it isn't claimed by + // another request before we lock it below. + tempStream = true + s = &stream{ + id: streamID, + w: w, + } + c.streams[streamID] = s + + // Since this stream is transient, we must clean up after replaying. + defer func() { + c.mu.Lock() + delete(c.streams, streamID) + c.mu.Unlock() + }() + } + c.mu.Unlock() + + s.mu.Lock() + defer s.mu.Unlock() + + // Check that this stream wasn't claimed by another request. + if !tempStream && s.w != nil { + http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict) + return nil, nil + } + + // Collect events to replay. Collect them all before writing, so that we + // have an opportunity to set the HTTP status code on an error. + // + // As indicated above, we must do that while holding stream.mu, so that no + // new messages are added to the eventstore until we've replayed all previous + // messages, and registered our delivery function. + var toReplay [][]byte + if c.eventStore != nil { + for data, err := range c.eventStore.After(ctx, c.SessionID(), s.id, lastIdx) { + if err != nil { + // We can't replay events, perhaps because the underlying event store + // has garbage collected its storage. + // + // We must be careful here: any 404 will signal to the client that the + // *session* is not found, rather than the stream. + // + // 400 is not really accurate, but should at least have no side effects. + // Other SDKs (typescript) do not have a mechanism for events to be purged. 
+ http.Error(w, "failed to replay events", http.StatusBadRequest) + return nil, nil + } + if len(data) > 0 { + toReplay = append(toReplay, data) + } + } + } + + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] + w.Header().Set("Connection", "keep-alive") + + if s.id == "" { + // Issue #410: the standalone SSE stream is likely not to receive messages + // for a long time. Ensure that headers are flushed. + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + + for _, data := range toReplay { + lastIdx++ + e := Event{Name: "message", Data: data} + if c.eventStore != nil { + e.ID = formatEventID(s.id, lastIdx) + } + if _, err := writeEvent(w, e); err != nil { + return nil, nil + } + } + + if tempStream || s.doneLocked() { + // Nothing more to do. + return nil, nil + } + + // The stream is not done: set up delivery state before the stream is + // unlocked, allowing the connection to write new events. + s.w = w + s.done = make(chan struct{}) + s.lastIdx = lastIdx + s.protocolVersion = protocolVersion + return s, s.done +} + +// servePOST handles an incoming message, and replies with either an outgoing +// message stream or single response object, depending on whether the +// jsonResponse option is set. +// +// It returns an HTTP status code and error message. +func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Request) { + if len(req.Header.Values(lastEventIDHeader)) > 0 { + http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest) + return + } + + // Read incoming messages. 
+ body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + if len(body) == 0 { + http.Error(w, "POST requires a non-empty body", http.StatusBadRequest) + return + } + // TODO(#674): once we've documented the support matrix for 2025-03-26 and + // earlier, drop support for matching entirely; that will simplify this + // logic. + incoming, isBatch, err := readBatch(body) + if err != nil { + http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) + return + } + + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + + if isBatch && protocolVersion >= protocolVersion20250618 { + http.Error(w, fmt.Sprintf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, protocolVersion), http.StatusBadRequest) + return + } + + // TODO(rfindley): no tests fail if we reject batch JSON requests entirely. + // We need to test this with older protocol versions. + // if isBatch && c.jsonResponse { + // http.Error(w, "server does not support batch requests", http.StatusBadRequest) + // return + // } + + calls := make(map[jsonrpc.ID]struct{}) + tokenInfo := auth.TokenInfoFromContext(req.Context()) + isInitialize := false + var initializeProtocolVersion string + for _, msg := range incoming { + if jreq, ok := msg.(*jsonrpc.Request); ok { + // Preemptively check that this is a valid request, so that we can fail + // the HTTP request. If we didn't do this, a request with a bad method or + // missing ID could be silently swallowed. + if _, err := checkRequest(jreq, serverMethodInfos); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if jreq.Method == methodInitialize { + isInitialize = true + // Extract the protocol version from InitializeParams. 
+ var params InitializeParams + if err := json.Unmarshal(jreq.Params, ¶ms); err == nil { + initializeProtocolVersion = params.ProtocolVersion + } + } + // Include metadata for all requests (including notifications). + jreq.Extra = &RequestExtra{ + TokenInfo: tokenInfo, + Header: req.Header, + } + if jreq.IsCall() { + calls[jreq.ID] = struct{}{} + // See the doc for CloseSSEStream: allow the request handler to + // explicitly close the ongoing stream. + jreq.Extra.(*RequestExtra).CloseSSEStream = func(args CloseSSEStreamArgs) { + c.mu.Lock() + streamID, ok := c.requestStreams[jreq.ID] + var stream *stream + if ok { + stream = c.streams[streamID] + } + c.mu.Unlock() + + if stream != nil { + stream.close(args.RetryAfter) + } + } + } + } + } + + // The prime and close events were added in protocol version 2025-11-25 (SEP-1699). + // Use the version from InitializeParams if this is an initialize request, + // otherwise use the protocol version header. + effectiveVersion := protocolVersion + if isInitialize && initializeProtocolVersion != "" { + effectiveVersion = initializeProtocolVersion + } + + // If we don't have any calls, we can just publish the incoming messages and return. + // No need to track a logical stream. + // + // See section [§2.1.4] of the spec: "If the server accepts the input, the + // server MUST return HTTP status code 202 Accepted with no body." + // + // [§2.1.4]: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server + if len(calls) == 0 { + for _, msg := range incoming { + select { + case c.incoming <- msg: + case <-c.done: + // The session is closing. Since we haven't yet written any data to the + // response, we can signal to the client that the session is gone. + http.Error(w, "session is closing", http.StatusNotFound) + return + } + } + w.WriteHeader(http.StatusAccepted) + return + } + + // Invariant: we have at least one call. + // + // Create a logical stream to track its responses. 
+ // Important: don't publish the incoming messages until the stream is
+ // registered, as the server may attempt to respond to incoming messages as
+ // soon as they're published.
+ stream, err := c.newStream(req.Context(), calls, randText())
+ if err != nil {
+ http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ // Set response headers. Accept was checked in [StreamableHTTPHandler].
+ w.Header().Set("Cache-Control", "no-cache, no-transform")
+ if c.jsonResponse {
+ w.Header().Set("Content-Type", "application/json")
+ } else {
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Connection", "keep-alive")
+ }
+ if c.sessionID != "" && isInitialize {
+ w.Header().Set(sessionIDHeader, c.sessionID)
+ }
+
+ // Set up stream delivery state.
+ stream.w = w
+ done := make(chan struct{})
+ stream.done = done
+ stream.protocolVersion = effectiveVersion
+ if c.jsonResponse {
+ // JSON mode: collect messages in pendingJSONMessages until done.
+ // Set pendingJSONMessages to a non-nil value to signal that this is an
+ // application/json stream.
+ stream.pendingJSONMessages = []json.RawMessage{}
+ } else {
+ // SSE mode: write a priming event if supported.
+ if c.eventStore != nil && effectiveVersion >= protocolVersion20251125 {
+ // Write a priming event, as defined by [§2.1.6] of the spec.
+ //
+ // [§2.1.6]: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server
+ //
+ // We must also write it to the event store in order for indexes to
+ // align.
+ if err := c.eventStore.Append(req.Context(), c.sessionID, stream.id, nil); err != nil {
+ c.logger.Warn(fmt.Sprintf("Storing priming event: %v", err))
+ }
+ stream.lastIdx++
+ e := Event{Name: "prime", ID: formatEventID(stream.id, stream.lastIdx)}
+ if _, err := writeEvent(w, e); err != nil {
+ c.logger.Warn(fmt.Sprintf("Writing priming event: %v", err))
+ }
+ }
+ }
+
+ // TODO(rfindley): if we have no event store, we should really cancel all
+ // remaining requests here, since the client will never get the results.
+ defer stream.release()
+
+ // The stream is now set up to deliver messages.
+ //
+ // Register it before publishing incoming messages.
+ c.mu.Lock()
+ c.streams[stream.id] = stream
+ for reqID := range calls {
+ c.requestStreams[reqID] = stream.id
+ }
+ c.mu.Unlock()
+
+ // Publish incoming messages.
+ for _, msg := range incoming {
+ select {
+ case c.incoming <- msg:
+ // Note: don't select on req.Context().Done() here, since we've already
+ // received the requests and may have already published a response message
+ // or notification. The client could resume the stream.
+ //
+ // In fact, this send could be in a separate goroutine.
+ case <-c.done:
+ // Session closed: we don't know if any data has been written, so it's
+ // too late to write a status code here.
+ return
+ }
+ }
+
+ c.hangResponse(req.Context(), done)
+}
+
+// Event IDs: encode both the logical connection ID and the index, as
+// <streamID>_<idx>, to be consistent with the typescript implementation.
+
+// formatEventID returns the event ID to use for the logical connection ID
+// streamID and message index idx.
+//
+// See also [parseEventID].
+func formatEventID(sid string, idx int) string {
+ return fmt.Sprintf("%s_%d", sid, idx)
+}
+
+// parseEventID parses a Last-Event-ID value into a logical stream id and
+// index.
+//
+// See also [formatEventID].
+func parseEventID(eventID string) (streamID string, idx int, ok bool) {
+ parts := strings.Split(eventID, "_")
+ if len(parts) != 2 {
+ return "", 0, false
+ }
+ streamID = parts[0]
+ idx, err := strconv.Atoi(parts[1])
+ if err != nil || idx < 0 {
+ return "", 0, false
+ }
+ return streamID, idx, true
+}
+
+// Read implements the [Connection] interface.
+func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error) {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case msg, ok := <-c.incoming:
+ if !ok {
+ return nil, io.EOF
+ }
+ return msg, nil
+ case <-c.done:
+ return nil, io.EOF
+ }
+}
+
+// Write implements the [Connection] interface.
+func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error {
+ // Throughout this function, note that any error that wraps ErrRejected
+ // indicates a failure that does not cause the connection to break.
+ //
+ // Most errors don't break the connection: unlike a true bidirectional
+ // stream, a failure to deliver to a stream is not an indication that the
+ // logical session is broken.
+ data, err := jsonrpc2.EncodeMessage(msg)
+ if err != nil {
+ return err
+ }
+
+ if req, ok := msg.(*jsonrpc.Request); ok && req.IsCall() && (c.stateless || c.sessionID == "") {
+ // Requests aren't possible with stateless servers, or when there's no session ID.
+ return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected)
+ }
+
+ // Find the incoming request that this write relates to, if any.
+ var (
+ relatedRequest jsonrpc.ID
+ responseTo jsonrpc.ID // if valid, the message is a response to this request
+ )
+ if resp, ok := msg.(*jsonrpc.Response); ok {
+ // If the message is a response, it relates to its request (of course).
+ relatedRequest = resp.ID
+ responseTo = resp.ID
+ } else {
+ // Otherwise, we check to see if it was made in the context of an
+ // ongoing request. This may not be the case if the request was made with
+ // an unrelated context.
+ if v := ctx.Value(idContextKey{}); v != nil { + relatedRequest = v.(jsonrpc.ID) + } + } + + // If the stream is application/json, but the message is not a response, we + // must send it out of band to the standalone SSE stream. + if c.jsonResponse && !responseTo.IsValid() { + relatedRequest = jsonrpc.ID{} + } + + // Write the message to the stream. + var s *stream + c.mu.Lock() + if relatedRequest.IsValid() { + if streamID, ok := c.requestStreams[relatedRequest]; ok { + s = c.streams[streamID] + } + } else { + s = c.streams[""] // standalone SSE stream + } + if responseTo.IsValid() { + // Once we've responded to a request, disallow related messages by removing + // the stream association. This also releases memory. + delete(c.requestStreams, responseTo) + } + sessionClosed := c.isDone + c.mu.Unlock() + + if s == nil { + // The request was made in the context of an ongoing request, but that + // request is complete. + // + // In the future, we could be less strict and allow the request to land on + // the standalone SSE stream. + return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected) + } + if sessionClosed { + return errors.New("session is closed") + } + + s.mu.Lock() + defer s.mu.Unlock() + + // Store in eventStore before delivering. + // TODO(rfindley): we should only append if the response is SSE, not JSON, by + // pushing down into the delivery layer. + delivered := false + var errs []error + if c.eventStore != nil { + if err := c.eventStore.Append(ctx, c.sessionID, s.id, data); err != nil { + errs = append(errs, err) + } else { + delivered = true + } + } + + // Compute eventID for SSE streams with event store. + // Use s.lastIdx + 1 because deliverLocked increments before writing. 
+ var eventID string + if c.eventStore != nil { + eventID = formatEventID(s.id, s.lastIdx+1) + } + + done, err := s.deliverLocked(data, eventID, responseTo) + if err != nil { + errs = append(errs, err) + } else { + delivered = true + } + + if done { + c.mu.Lock() + delete(c.streams, s.id) + c.mu.Unlock() + } + + if !delivered { + return fmt.Errorf("%w: undelivered message: %v", jsonrpc2.ErrRejected, errors.Join(errs...)) + } + return nil +} + +// Close implements the [Connection] interface. +func (c *streamableServerConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.isDone { + c.isDone = true + close(c.done) + if c.eventStore != nil { + // TODO: find a way to plumb a context here, or an event store with a long-running + // close operation can take arbitrary time. Alternative: impose a fixed timeout here. + return c.eventStore.SessionClosed(context.TODO(), c.sessionID) + } + } + return nil +} + +// A StreamableClientTransport is a [Transport] that can communicate with an MCP +// endpoint serving the streamable HTTP transport defined by the 2025-03-26 +// version of the spec. +type StreamableClientTransport struct { + Endpoint string + HTTPClient *http.Client + // MaxRetries is the maximum number of times to attempt a reconnect before giving up. + // It defaults to 5. To disable retries, use a negative number. + MaxRetries int + + // TODO(rfindley): propose exporting these. + // If strict is set, the transport is in 'strict mode', where any violation + // of the MCP spec causes a failure. + strict bool + // If logger is set, it is used to log aspects of the transport, such as spec + // violations that were ignored. + logger *slog.Logger +} + +// These settings are not (yet) exposed to the user in +// StreamableClientTransport. +const ( + // reconnectGrowFactor is the multiplicative factor by which the delay increases after each attempt. + // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. 
+ // It must be 1.0 or greater if MaxRetries is greater than 0. + reconnectGrowFactor = 1.5 + // reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely. + reconnectMaxDelay = 30 * time.Second +) + +var ( + // reconnectInitialDelay is the base delay for the first reconnect attempt. + // + // Mutable for testing. + reconnectInitialDelay = 1 * time.Second +) + +// Connect implements the [Transport] interface. +// +// The resulting [Connection] writes messages via POST requests to the +// transport URL with the Mcp-Session-Id header set, and reads messages from +// hanging requests. +// +// When closed, the connection issues a DELETE request to terminate the logical +// session. +func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, error) { + client := t.HTTPClient + if client == nil { + client = http.DefaultClient + } + maxRetries := t.MaxRetries + if maxRetries == 0 { + maxRetries = 5 + } else if maxRetries < 0 { + maxRetries = 0 + } + // Create a new cancellable context that will manage the connection's lifecycle. + // This is crucial for cleanly shutting down the background SSE listener by + // cancelling its blocking network operations, which prevents hangs on exit. + // + // This context should be detached from the incoming context: the standalone + // SSE request should not break when the connection context is done. + // + // For example, consider that the user may want to wait at most 5s to connect + // to the server, and therefore uses a context with a 5s timeout when calling + // client.Connect. Let's suppose that Connect returns after 1s, and the user + // starts using the resulting session. If we didn't detach here, the session + // would break after 4s, when the background SSE stream is terminated. 
+ // + // Instead, creating a cancellable context detached from the incoming context + // allows us to preserve context values (which may be necessary for auth + // middleware), yet only cancel the standalone stream when the connection is closed. + connCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) + conn := &streamableClientConn{ + url: t.Endpoint, + client: client, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + maxRetries: maxRetries, + strict: t.strict, + logger: ensureLogger(t.logger), // must be non-nil for safe logging + ctx: connCtx, + cancel: cancel, + failed: make(chan struct{}), + } + return conn, nil +} + +type streamableClientConn struct { + url string + client *http.Client + ctx context.Context // connection context, detached from Connect + cancel context.CancelFunc // cancels ctx + incoming chan jsonrpc.Message + maxRetries int + strict bool // from [StreamableClientTransport.strict] + logger *slog.Logger // from [StreamableClientTransport.logger] + + // Guard calls to Close, as it may be called multiple times. + closeOnce sync.Once + closeErr error + done chan struct{} // signal graceful termination + + // Logical reads are distributed across multiple http requests. Whenever any + // of them fails to process their response, we must break the connection, by + // failing the pending Read. + // + // Achieve this by storing the failure message, and signalling when reads are + // broken. See also [streamableClientConn.fail] and + // [streamableClientConn.failure]. + failOnce sync.Once + _failure error + failed chan struct{} // signal failure + + // Guard the initialization state. + mu sync.Mutex + initializedResult *InitializeResult + sessionID string +} + +// errSessionMissing distinguishes if the session is known to not be present on +// the server (see [streamableClientConn.fail]). +// +// TODO(rfindley): should we expose this error value (and its corresponding +// API) to the user? 
+// +// The spec says that if the server returns 404, clients should reestablish +// a session. For now, we delegate that to the user, but do they need a way to +// differentiate a 'NotFound' error from other errors? +var errSessionMissing = errors.New("session not found") + +var _ clientConnection = (*streamableClientConn)(nil) + +func (c *streamableClientConn) sessionUpdated(state clientSessionState) { + c.mu.Lock() + c.initializedResult = state.InitializeResult + c.mu.Unlock() + + // Start the standalone SSE stream as soon as we have the initialized + // result. + // + // § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. This can be + // used to open an SSE stream, allowing the server to communicate to the + // client, without the client first sending data via HTTP POST. + // + // We have to wait for initialized, because until we've received + // initialized, we don't know whether the server requires a sessionID. + // + // § 2.5: A server using the Streamable HTTP transport MAY assign a session + // ID at initialization time, by including it in an Mcp-Session-Id header + // on the HTTP response containing the InitializeResult. + c.connectStandaloneSSE() +} + +func (c *streamableClientConn) connectStandaloneSSE() { + resp, err := c.connectSSE(c.ctx, "", 0, true) + if err != nil { + // If the client didn't cancel the request, and failure breaks the logical + // session. + if c.ctx.Err() == nil { + c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err)) + } + return + } + + // [§2.2.3]: "The server MUST either return Content-Type: + // text/event-stream in response to this HTTP GET, or else return HTTP + // 405 Method Not Allowed, indicating that the server does not offer an + // SSE stream at this endpoint." 
+ // + // [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server + if resp.StatusCode == http.StatusMethodNotAllowed { + // The server doesn't support the standalone SSE stream. + resp.Body.Close() + return + } + if resp.StatusCode >= 400 && resp.StatusCode < 500 && !c.strict { + // modelcontextprotocol/go-sdk#393,#610: some servers return NotFound or + // other status codes instead of MethodNotAllowed for the standalone SSE + // stream. + // + // Treat this like MethodNotAllowed in non-strict mode. + c.logger.Warn(fmt.Sprintf("got %d instead of 405 for standalone SSE stream", resp.StatusCode)) + resp.Body.Close() + return + } + summary := "standalone SSE stream" + if err := c.checkResponse(summary, resp); err != nil { + c.fail(err) + return + } + go c.handleSSE(c.ctx, summary, resp, nil) +} + +// fail handles an asynchronous error while reading. +// +// If err is non-nil, it is terminal, and subsequent (or pending) Reads will +// fail. +// +// If err wraps errSessionMissing, the failure indicates that the session is no +// longer present on the server, and no final DELETE will be performed when +// closing the connection. +func (c *streamableClientConn) fail(err error) { + if err != nil { + c.failOnce.Do(func() { + c._failure = err + close(c.failed) + }) + } +} + +func (c *streamableClientConn) failure() error { + select { + case <-c.failed: + return c._failure + default: + return nil + } +} + +func (c *streamableClientConn) SessionID() string { + c.mu.Lock() + defer c.mu.Unlock() + return c.sessionID +} + +// Read implements the [Connection] interface. 
+func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { + if err := c.failure(); err != nil { + return nil, err + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-c.failed: + return nil, c.failure() + case <-c.done: + return nil, io.EOF + case msg := <-c.incoming: + return msg, nil + } +} + +// Write implements the [Connection] interface. +func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if err := c.failure(); err != nil { + return err + } + + var requestSummary string + var forCall *jsonrpc.Request + switch msg := msg.(type) { + case *jsonrpc.Request: + requestSummary = fmt.Sprintf("sending %q", msg.Method) + if msg.IsCall() { + forCall = msg + } + case *jsonrpc.Response: + requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) + default: + panic("unreachable") + } + + data, err := jsonrpc.EncodeMessage(msg) + if err != nil { + return fmt.Errorf("%s: %v", requestSummary, err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + c.setMCPHeaders(req) + + resp, err := c.client.Do(req) + if err != nil { + // Any error from client.Do means the request didn't reach the server. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) + } + + if err := c.checkResponse(requestSummary, resp); err != nil { + // Only fail the connection for non-transient errors. + // Transient errors (wrapped with ErrRejected) should not break the connection. 
+ if !errors.Is(err, jsonrpc2.ErrRejected) { + c.fail(err) + } + return err + } + + if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { + c.mu.Lock() + hadSessionID := c.sessionID + if hadSessionID == "" { + c.sessionID = sessionID + } + c.mu.Unlock() + if hadSessionID != "" && hadSessionID != sessionID { + resp.Body.Close() + return fmt.Errorf("mismatching session IDs %q and %q", hadSessionID, sessionID) + } + } + + if forCall == nil { + resp.Body.Close() + + // [§2.1.4]: "If the input is a JSON-RPC response or notification: + // If the server accepts the input, the server MUST return HTTP status code 202 Accepted with no body." + // + // [§2.1.4]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusAccepted { + errMsg := fmt.Sprintf("unexpected status code %d from non-call", resp.StatusCode) + // Some servers return 200, even with an empty json body. + // + // In strict mode, return an error to the caller. + c.logger.Warn(errMsg) + if c.strict { + return errors.New(errMsg) + } + } + return nil + } + + contentType := strings.TrimSpace(strings.SplitN(resp.Header.Get("Content-Type"), ";", 2)[0]) + switch contentType { + case "application/json": + go c.handleJSON(requestSummary, resp) + + case "text/event-stream": + var forCall *jsonrpc.Request + if jsonReq, ok := msg.(*jsonrpc.Request); ok && jsonReq.IsCall() { + forCall = jsonReq + } + // Handle the resulting stream. 
Note that ctx comes from the call, and + // therefore is already cancelled when the JSON-RPC request is cancelled + // (or rather, context cancellation is what *triggers* JSON-RPC + // cancellation) + go c.handleSSE(ctx, requestSummary, resp, forCall) + + default: + resp.Body.Close() + return fmt.Errorf("%s: unsupported content type %q", requestSummary, contentType) + } + return nil +} + +// testAuth controls whether a fake Authorization header is added to outgoing requests. +// TODO: replace with a better mechanism when client-side auth is in place. +var testAuth atomic.Bool + +func (c *streamableClientConn) setMCPHeaders(req *http.Request) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.initializedResult != nil { + req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) + } + if c.sessionID != "" { + req.Header.Set(sessionIDHeader, c.sessionID) + } + if testAuth.Load() { + req.Header.Set("Authorization", "Bearer foo") + } +} + +func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + c.fail(fmt.Errorf("%s: failed to read body: %v", requestSummary, err)) + return + } + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + c.fail(fmt.Errorf("%s: failed to decode response: %v", requestSummary, err)) + return + } + select { + case c.incoming <- msg: + case <-c.done: + // The connection was closed by the client; exit gracefully. + } +} + +// handleSSE manages the lifecycle of an SSE connection. It can be either +// persistent (for the main GET listener) or temporary (for a POST response). +// +// If forCall is set, it is the call that initiated the stream, and the +// stream is complete when we receive its response. Otherwise, this is the +// standalone stream. +func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc2.Request) { + for { + // Connection was successful. 
Continue the loop with the new response. + // + // TODO(#679): we should set a reasonable limit on the number of times + // we'll try getting a response for a given request, or enforce that we + // actually make progress. + // + // Eventually, if we don't get the response, we should stop trying and + // fail the request. + lastEventID, reconnectDelay, clientClosed := c.processStream(ctx, requestSummary, resp, forCall) + + // If the connection was closed by the client, we're done. + if clientClosed { + return + } + // If we don't have a last event ID, we can never get the call response, so + // there's nothing to resume. For the standalone stream, we can reconnect, + // but we may just miss messages. + if lastEventID == "" && forCall != nil { + return + } + + // The stream was interrupted or ended by the server. Attempt to reconnect. + newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false) + if err != nil { + // If the client didn't cancel this request, any failure to execute it + // breaks the logical MCP session. + if ctx.Err() == nil { + // All reconnection attempts failed: fail the connection. + c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) + } + return + } + + resp = newResp + if err := c.checkResponse(requestSummary, resp); err != nil { + c.fail(err) + return + } + } +} + +// checkResponse checks the status code of the provided response, and +// translates it into an error if the request was unsuccessful. +// +// The response body is close if a non-nil error is returned. +func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.Response) (err error) { + defer func() { + if err != nil { + resp.Body.Close() + } + }() + // §2.5.3: "The server MAY terminate the session at any time, after + // which it MUST respond to requests containing that session ID with HTTP + // 404 Not Found." 
+ if resp.StatusCode == http.StatusNotFound { + // Return an errSessionMissing to avoid sending a redundant DELETE when the + // session is already gone. + return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing) + } + // Transient server errors (502, 503, 504, 429) should not break the connection. + // Wrap them with ErrRejected so the jsonrpc2 layer doesn't set writeErr. + if isTransientHTTPStatus(resp.StatusCode) { + return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, http.StatusText(resp.StatusCode)) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode)) + } + return nil +} + +// processStream reads from a single response body, sending events to the +// incoming channel. It returns the ID of the last processed event and a flag +// indicating if the connection was closed by the client. If resp is nil, it +// returns "", false. +func (c *streamableClientConn) processStream(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) { + defer func() { + // Drain any remaining unprocessed body. This allows the connection to be re-used after closing. 
+ io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + for evt, err := range scanEvents(resp.Body) { + if err != nil { + if ctx.Err() != nil { + return "", 0, true // don't reconnect: client cancelled + } + break + } + + if evt.ID != "" { + lastEventID = evt.ID + } + + if evt.Retry != "" { + if n, err := strconv.ParseInt(evt.Retry, 10, 64); err == nil { + reconnectDelay = time.Duration(n) * time.Millisecond + } + } + // According to SSE spec, events with no name default to "message" + if evt.Name != "" && evt.Name != "message" { + continue + } + + msg, err := jsonrpc.DecodeMessage(evt.Data) + if err != nil { + c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err)) + return "", 0, true + } + + select { + case c.incoming <- msg: + // Check if this is the response to our call, which terminates the request. + // (it could also be a server->client request or notification). + if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil { + // TODO: we should never get a response when forReq is nil (the standalone SSE request). + // We should detect this case. + if jsonResp.ID == forCall.ID { + return "", 0, true + } + } + + case <-c.done: + // The connection was closed by the client; exit gracefully. + return "", 0, true + } + } + // The loop finished without an error, indicating the server closed the stream. + // + // If the lastEventID is "", the stream is not retryable and we should + // report a synthetic error for the call. + // + // Note that this is different from the cancellation case above, since the + // caller is still waiting for a response that will never come. + if lastEventID == "" && forCall != nil { + errmsg := &jsonrpc2.Response{ + ID: forCall.ID, + Error: fmt.Errorf("request terminated without response"), + } + select { + case c.incoming <- errmsg: + case <-c.done: + } + } + return lastEventID, reconnectDelay, false +} + +// connectSSE handles the logic of connecting a text/event-stream connection. 
+// +// If lastEventID is set, it is the last-event ID of a stream being resumed. +// +// If connection fails, connectSSE retries with an exponential backoff +// strategy. It returns a new, valid HTTP response if successful, or an error +// if all retries are exhausted. +// +// reconnectDelay is the delay set by the server using the SSE retry field, or +// 0. +// +// If initial is set, this is the initial attempt. +// +// If connectSSE exits due to context cancellation, the result is (nil, ctx.Err()). +func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID string, reconnectDelay time.Duration, initial bool) (*http.Response, error) { + var finalErr error + attempt := 0 + if !initial { + // We've already connected successfully once, so delay subsequent + // reconnections. Otherwise, if the server returns 200 but terminates the + // connection, we'll reconnect as fast as we can, ad infinitum. + // + // TODO: we should consider also setting a limit on total attempts for one + // logical request. + attempt = 1 + } + delay := calculateReconnectDelay(attempt) + if reconnectDelay > 0 { + delay = reconnectDelay // honor the server's requested initial delay + } + for ; attempt <= c.maxRetries; attempt++ { + select { + case <-c.done: + return nil, fmt.Errorf("connection closed by client during reconnect") + + case <-ctx.Done(): + // If the connection context is canceled, the request below will not + // succeed anyway. + return nil, ctx.Err() + + case <-time.After(delay): + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) + if err != nil { + return nil, err + } + c.setMCPHeaders(req) + if lastEventID != "" { + req.Header.Set(lastEventIDHeader, lastEventID) + } + req.Header.Set("Accept", "text/event-stream") + resp, err := c.client.Do(req) + if err != nil { + finalErr = err // Store the error and try again. 
+ delay = calculateReconnectDelay(attempt + 1) + continue + } + return resp, nil + } + } + // If the loop completes, all retries have failed, or the client is closing. + if finalErr != nil { + return nil, fmt.Errorf("connection failed after %d attempts: %w", c.maxRetries, finalErr) + } + return nil, fmt.Errorf("connection aborted after %d attempts", c.maxRetries) +} + +// Close implements the [Connection] interface. +func (c *streamableClientConn) Close() error { + c.closeOnce.Do(func() { + if errors.Is(c.failure(), errSessionMissing) { + // If the session is missing, no need to delete it. + } else { + req, err := http.NewRequestWithContext(c.ctx, http.MethodDelete, c.url, nil) + if err != nil { + c.closeErr = err + } else { + c.setMCPHeaders(req) + if _, err := c.client.Do(req); err != nil { + c.closeErr = err + } + } + } + + // Cancel any hanging network requests after cleanup. + c.cancel() + close(c.done) + }) + return c.closeErr +} + +// calculateReconnectDelay calculates a delay using exponential backoff with full jitter. +func calculateReconnectDelay(attempt int) time.Duration { + if attempt == 0 { + return 0 + } + // Calculate the exponential backoff using the grow factor. + backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt-1))) + // Cap the backoffDuration at maxDelay. + backoffDuration = min(backoffDuration, reconnectMaxDelay) + + // Use a full jitter using backoffDuration + jitter := rand.N(backoffDuration) + + return backoffDuration + jitter +} + +// isTransientHTTPStatus reports whether the HTTP status code indicates a +// transient server error that should not permanently break the connection. 
+func isTransientHTTPStatus(statusCode int) bool { + switch statusCode { + case http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 + http.StatusTooManyRequests: // 429 + return true + } + return false +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go new file mode 100644 index 000000000..41a100461 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go @@ -0,0 +1,226 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// TODO: move client-side streamable HTTP logic from streamable.go to this file. + +package mcp + +/* +Streamable HTTP Client Design + +This document describes the client-side implementation of the MCP streamable +HTTP transport, as defined by the MCP spec: +https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http + +# Overview + +The client-side streamable transport allows an MCP client to communicate with a +server over HTTP, sending messages via POST and receiving responses via either +JSON or server-sent events (SSE). 
The implementation consists of two main +components: + + ┌─────────────────────────────────────────────────────────────────┐ + │ [StreamableClientTransport] │ + │ Transport configuration; creates connections via Connect() │ + └─────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────────────────────────────────────┐ + │ [streamableClientConn] │ + │ Connection implementation; handles HTTP request/response │ + └─────────────────────────────────────────────────────────────────┘ + │ + ├──────────────────────────────────────┐ + ▼ ▼ + ┌─────────────────────────────────────────┐ ┌────────────────────────────────────┐ + │ POST request handlers │ │ Standalone SSE stream │ + │ (one per outgoing message/call) │ │ (server-initiated messages) │ + └─────────────────────────────────────────┘ └────────────────────────────────────┘ + +# Sessions + +The client maintains a session with the server, identified by a session ID +(Mcp-Session-Id header): + + - Session ID is received from the server after initialization + - Client includes the session ID in all subsequent requests + - Session ends when the client calls Close() (sends DELETE) or server returns 404 + +[streamableClientConn] stores the session state: + - [streamableClientConn.sessionID]: Server-assigned session identifier + - [streamableClientConn.initializedResult]: Protocol version and server capabilities + +# Connection Lifecycle + +1. Connect: [StreamableClientTransport.Connect] creates a [streamableClientConn] + with a detached context for the connection's lifetime. The context is detached + to prevent the standalone SSE stream from being cancelled when the original + Connect context times out. + +2. Initialize: The MCP client sends initialize/initialized messages. 
Upon + receiving [InitializeResult], the connection: + - Stores the negotiated protocol version for the Mcp-Protocol-Version header + - Captures the session ID from the Mcp-Session-Id response header + - Starts the standalone SSE stream via [streamableClientConn.connectStandaloneSSE] + +3. Operation: Messages are sent via POST, responses received via JSON or SSE. + +4. Close: [streamableClientConn.Close] sends a DELETE request to terminate + the session (unless the session is already gone), then cancels the connection + context to clean up the standalone SSE stream. + +# Sending Messages (Write) + +[streamableClientConn.Write] sends all outgoing messages via HTTP POST: + + POST /endpoint + Content-Type: application/json + Accept: application/json, text/event-stream + Mcp-Protocol-Version: + Mcp-Session-Id: + + + +The server may respond with: + - 202 Accepted: Message received, no response body (notifications/responses) + - 200 OK with application/json: Single JSON-RPC response + - 200 OK with text/event-stream: SSE stream of responses + +# Receiving Messages (Read) + +[streamableClientConn.Read] returns messages from the [streamableClientConn.incoming] +channel, which is populated by multiple concurrent goroutines: + +1. POST response handlers ([streamableClientConn.handleJSON] and + [streamableClientConn.handleSSE]): Process responses from POST requests + +2. 
Standalone SSE stream: Receives server-initiated requests and notifications + +The client handles both response formats: + - JSON: [streamableClientConn.handleJSON] reads body, decodes message + - SSE: [streamableClientConn.handleSSE] scans events, decodes each message + +# Standalone SSE Stream + +After initialization, [streamableClientConn.sessionUpdated] triggers +[streamableClientConn.connectStandaloneSSE] to open a GET request for +server-initiated messages: + + GET /endpoint + Accept: text/event-stream + Mcp-Session-Id: + +Stream behavior: + - Optional: Server may return 405 Method Not Allowed (spec-compliant) or + other 4xx errors (tolerated in non-strict mode for compatibility) + - Persistent: Runs for the connection lifetime in a background goroutine + - Resumable: Uses Last-Event-ID header on reconnection if server provides event IDs + - Reconnects: Automatic reconnection with exponential backoff on interruption + +# Stream Resumption + +When an SSE stream (standalone or POST response) is interrupted, the client +attempts to reconnect using [streamableClientConn.connectSSE]: + +Event ID tracking: + - [streamableClientConn.processStream] tracks the last received event ID + - On reconnection, the Last-Event-ID header is set to resume from that point + - Server replays missed events if it has an [EventStore] configured + +See [calculateReconnectDelay] for the reconnect delay details. + +Server-initiated reconnection (SEP-1699) + - SSE retry field: Sets the delay for the next reconnect attempt + - If server doesn't provide event IDs, non-standalone streams don't reconnect + +# Response Formats + +The client must handle two response formats from POST requests: + +1. application/json: Single JSON-RPC response + - Body contains one JSON-RPC message + - Handled by [streamableClientConn.handleJSON] + - Simpler but doesn't support streaming or server-initiated messages + +2. 
text/event-stream: SSE stream of messages + - Body contains SSE events with JSON-RPC messages + - Handled by [streamableClientConn.handleSSE] + - Supports multiple messages and server-initiated communication + - Stream completes when the response to the originating call is received + +# HTTP Methods + + - POST: Send JSON-RPC messages (requests, responses, notifications) + - Used by [streamableClientConn.Write] + - Response may be JSON or SSE + + - GET: Open or resume SSE stream for server-initiated messages + - Used by [streamableClientConn.connectSSE] + - Always expects text/event-stream response (or 405) + + - DELETE: Terminate the session + - Used by [streamableClientConn.Close] + - Skipped if session is already known to be gone ([errSessionMissing]) + +# Error Handling + +Errors are categorized and handled differently: + +1. Transient (recoverable via reconnection): + - Network interruption during SSE streaming + - Connection reset or timeout + - Triggers reconnection in [streamableClientConn.handleSSE] + +2. Terminal (breaks the connection): + - 404 Not Found: Session terminated by server ([errSessionMissing]) + - Message decode errors: Protocol violation + - Context cancellation: Client closed connection + - Mismatched session IDs: Protocol error + - See issue #683: our terminal errors are too strict. + +Terminal errors are stored via [streamableClientConn.fail] and returned by +subsequent [streamableClientConn.Read] calls. The [streamableClientConn.failed] +channel signals that the connection is broken. + +Special case: [errSessionMissing] indicates the server has terminated the session, +so [streamableClientConn.Close] skips the DELETE request. 
+ +# Protocol Version Header + +After initialization, all requests include: + + Mcp-Protocol-Version: + +This header (set by [streamableClientConn.setMCPHeaders]): + - Allows the server to handle requests per the negotiated protocol + - Is omitted before initialization completes + - Uses the version from [streamableClientConn.initializedResult] + +# Key Implementation Details + +[StreamableClientTransport] configuration: + - [StreamableClientTransport.Endpoint]: URL of the MCP server + - [StreamableClientTransport.HTTPClient]: Custom HTTP client (optional) + - [StreamableClientTransport.MaxRetries]: Reconnection attempts (default 5) + +[streamableClientConn] handles the [Connection] interface: + - [streamableClientConn.Read]: Returns messages from incoming channel + - [streamableClientConn.Write]: Sends messages via POST, starts response handlers + - [streamableClientConn.Close]: Sends DELETE, cancels context, closes done channel + +State management: + - [streamableClientConn.incoming]: Buffered channel for received messages + - [streamableClientConn.sessionID]: Server-assigned session identifier + - [streamableClientConn.initializedResult]: Cached for protocol version header + - [streamableClientConn.failed]: Channel closed on terminal error + - [streamableClientConn.done]: Channel closed on graceful shutdown + - [streamableClientConn.ctx]: Detached context for connection lifetime + - [streamableClientConn.cancel]: Cancels ctx to terminate SSE streams + +Context handling: + - Connection context is detached from [StreamableClientTransport.Connect] context + using [xcontext.Detach] to preserve context values (for auth middleware) while + preventing premature cancellation of the standalone SSE stream + - Individual POST requests use caller-provided contexts for cancellation +*/ diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_server.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_server.go new file mode 100644 index 
000000000..8a573e56a --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_server.go @@ -0,0 +1,160 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// TODO: move server-side streamable HTTP logic from streamable.go to this file. + +package mcp + +/* +Streamable HTTP Server Design + +This document describes the server-side implementation of the MCP streamable +HTTP transport, as defined by the MCP spec: +https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http + +# Overview + +The streamable HTTP transport enables MCP communication over HTTP, with +server-sent events (SSE) for server-to-client messages. The implementation +consists of several layered components: + + ┌─────────────────────────────────────────────────────────────────┐ + │ [StreamableHTTPHandler] │ + │ http.Handler that manages sessions and routes HTTP requests │ + └─────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────────────────────────────────────┐ + │ [StreamableServerTransport] │ + │ transport implementation, one per session; exposes ServeHTTP │ + └─────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────────────────────────────────────┐ + │ [streamableServerConn] │ + │ Connection implementation, handles message routing │ + └─────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────────────────────────────────────┐ + │ [stream] │ + │ Logical message channel within a session, may be resumed │ + └─────────────────────────────────────────────────────────────────┘ + +# Sessions + +As with other transports, a session represents a logical MCP connection between +a client and server. 
In the streamable transport, sessions are identified by a +unique session ID (Mcp-Session-Id header) and persist across multiple HTTP +requests. + +[StreamableHTTPHandler] maintains a map of active sessions ([sessionInfo]), +each containing: + - The [ServerSession] (MCP-level session state) + - The [StreamableServerTransport] (for message I/O) + - Optional timeout management for idle session cleanup + +Sessions are created on the first POST request (typically containing the +initialize request) and destroyed either by: + - Client sending a DELETE request + - Session timeout due to inactivity + - Server explicitly closing the session + +# Streams + +Within a session, there can be multiple concurrent "streams" - logical channels +for message delivery. This is distinct from HTTP streams; a single [stream] may +span multiple HTTP request/response cycles (via resumption). + +There are two types of streams: + +1. Optional standalone SSE stream (id = ""): + - Created when client sends a GET request to the endpoint + - Used for server-initiated messages (requests/notifications to client) + - Persists for the lifetime of the session + - Only one standalone stream per session + +2. 
Request streams (id = random string): + - Created for each POST request containing JSON-RPC calls + - Used to route responses back to the originating HTTP request + - Completed when all responses have been sent + - Can be resumed via GET with Last-Event-ID if interrupted + +# Message Routing + +When the server writes a message, it must be routed to the correct [stream]: + + - Responses: Routed to the stream that originated the request + - Requests/Notifications made during request handling: Routed to the same + stream as the triggering request (via context) + - Requests/Notifications made outside request handling: Routed to the + standalone SSE stream + +This routing is implemented using: + - [streamableServerConn.requestStreams] maps request IDs to stream IDs + - [idContextKey] is used to store the originating request ID in Context + - [streamableServerConn.streams] maps stream IDs to [stream] objects + +# Stream Resumption + +If an HTTP connection is interrupted (network issues, etc.), clients can +resume a stream by sending a GET request with the Last-Event-ID header. +This requires an [EventStore] to be configured on the server. + + - [EventStore.Open] is called when a new stream is created + - [EventStore.Append] is called for each message written to the stream + - [EventStore.After] is called to replay messages after a given index + - [EventStore.SessionClosed] is called when the session ends + +Event IDs are formatted as "_" to identify both the +stream and position within that stream (see [formatEventID] and [parseEventID]). + +# Stateless Mode + +For simpler deployments, the handler supports "stateless" mode +([StreamableHTTPOptions.Stateless]) where: + - No session ID validation is performed + - Each request creates a temporary session that's closed after the request + - Server-to-client requests are not supported (no way to receive response) + +This mode is useful for simple tool servers that don't need bidirectional +communication. 
+ +# Response Formats + +The server can respond to POST requests in two formats: + +1. text/event-stream (default): Messages sent as SSE events, supports + streaming multiple messages and server-initiated communication during + request handling. + +2. application/json ([StreamableHTTPOptions.JSONResponse]): Single JSON + response, simpler but doesn't support streaming. Server-initiated messages + during request handling go to the standalone SSE stream instead. + +# HTTP Methods + + - POST: Send JSON-RPC messages (requests, responses, notifications) + - GET: Open standalone SSE stream or resume an interrupted stream + - DELETE: Terminate the session + +# Key Implementation Details + +The [stream] struct manages delivery of messages to HTTP responses. + +Fields: + - [stream.w] is the ResponseWriter for the current HTTP response (non-nil indicates claimed) + - [stream.done] is closed to release the hanging HTTP request + - [stream.requests] tracks pending request IDs (stream completes when empty) + +Methods: + - [stream.deliverLocked] delivers a message to the stream + - [stream.close] sends a close event and releases the stream + - [stream.release] releases the stream from the HTTP request, allowing resumption + +[streamableServerConn] handles the [Connection] interface: + - [streamableServerConn.Read] receives messages from the incoming channel (fed by POST handlers) + - [streamableServerConn.Write] routes messages to appropriate streams + - [streamableServerConn.Close] terminates the session and notifies the [EventStore] +*/ diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go new file mode 100644 index 000000000..8aa7c3c0d --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go @@ -0,0 +1,139 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/google/jsonschema-go/jsonschema" +) + +// A ToolHandler handles a call to tools/call. +// +// This is a low-level API, for use with [Server.AddTool]. It does not do any +// pre- or post-processing of the request or result: the params contain raw +// arguments, no input validation is performed, and the result is returned to +// the user as-is, without any validation of the output. +// +// Most users will write a [ToolHandlerFor] and install it with the generic +// [AddTool] function. +// +// If ToolHandler returns an error, it is treated as a protocol error. By +// contrast, [ToolHandlerFor] automatically populates [CallToolResult.IsError] +// and [CallToolResult.Content] accordingly. +type ToolHandler func(context.Context, *CallToolRequest) (*CallToolResult, error) + +// A ToolHandlerFor handles a call to tools/call with typed arguments and results. +// +// Use [AddTool] to add a ToolHandlerFor to a server. +// +// Unlike [ToolHandler], [ToolHandlerFor] provides significant functionality +// out of the box, and enforces that the tool conforms to the MCP spec: +// - The In type provides a default input schema for the tool, though it may +// be overridden in [AddTool]. +// - The input value is automatically unmarshaled from req.Params.Arguments. +// - The input value is automatically validated against its input schema. +// Invalid input is rejected before getting to the handler. +// - If the Out type is not the empty interface [any], it provides the +// default output schema for the tool (which again may be overridden in +// [AddTool]). +// - The Out value is used to populate result.StructuredOutput. +// - If [CallToolResult.Content] is unset, it is populated with the JSON +// content of the output. +// - An error result is treated as a tool error, rather than a protocol +// error, and is therefore packed into CallToolResult.Content, with +// [IsError] set. 
+// +// For these reasons, most users can ignore the [CallToolRequest] argument and +// [CallToolResult] return values entirely. In fact, it is permissible to +// return a nil CallToolResult, if you only care about returning an output value +// or error. The effective result will be populated as described above. +type ToolHandlerFor[In, Out any] func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error) + +// A serverTool is a tool definition that is bound to a tool handler. +type serverTool struct { + tool *Tool + handler ToolHandler +} + +// applySchema validates whether data is valid JSON according to the provided +// schema, after applying schema defaults. +// +// Returns the JSON value augmented with defaults. +func applySchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) { + // TODO: use reflection to create the struct type to unmarshal into. + // Separate validation from assignment. + + // Use default JSON marshalling for validation. + // + // This avoids inconsistent representation due to custom marshallers, such as + // time.Time (issue #449). + // + // Additionally, unmarshalling into a map ensures that the resulting JSON is + // at least {}, even if data is empty. For example, arguments is technically + // an optional property of callToolParams, and we still want to apply the + // defaults in this case. + // + // TODO(rfindley): in which cases can resolved be nil? + if resolved != nil { + v := make(map[string]any) + if len(data) > 0 { + if err := json.Unmarshal(data, &v); err != nil { + return nil, fmt.Errorf("unmarshaling arguments: %w", err) + } + } + if err := resolved.ApplyDefaults(&v); err != nil { + return nil, fmt.Errorf("applying schema defaults:\n%w", err) + } + if err := resolved.Validate(&v); err != nil { + return nil, err + } + // We must re-marshal with the default values applied. 
+ var err error + data, err = json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("marshalling with defaults: %v", err) + } + } + return data, nil +} + +// validateToolName checks whether name is a valid tool name, reporting a +// non-nil error if not. +func validateToolName(name string) error { + if name == "" { + return fmt.Errorf("tool name cannot be empty") + } + if len(name) > 128 { + return fmt.Errorf("tool name exceeds maximum length of 128 characters (current: %d)", len(name)) + } + // For consistency with other SDKs, report characters in the order they appear + // in the name. + var invalidChars []string + seen := make(map[rune]bool) + for _, r := range name { + if !validToolNameRune(r) { + if !seen[r] { + invalidChars = append(invalidChars, fmt.Sprintf("%q", string(r))) + seen[r] = true + } + } + } + if len(invalidChars) > 0 { + return fmt.Errorf("tool name contains invalid characters: %s", strings.Join(invalidChars, ", ")) + } + return nil +} + +// validToolNameRune reports whether r is valid within tool names. +func validToolNameRune(r rune) bool { + return (r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || + r == '_' || r == '-' || r == '.' +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go new file mode 100644 index 000000000..25f1d5d05 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go @@ -0,0 +1,655 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "sync" + + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/xcontext" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +// ErrConnectionClosed is returned when sending a message to a connection that +// is closed or in the process of closing. +var ErrConnectionClosed = errors.New("connection closed") + +// A Transport is used to create a bidirectional connection between MCP client +// and server. +// +// Transports should be used for at most one call to [Server.Connect] or +// [Client.Connect]. +type Transport interface { + // Connect returns the logical JSON-RPC connection. + // + // It is called exactly once by [Server.Connect] or [Client.Connect]. + Connect(ctx context.Context) (Connection, error) +} + +// A Connection is a logical bidirectional JSON-RPC connection. +type Connection interface { + // Read reads the next message to process off the connection. + // + // Connections must allow Read to be called concurrently with Close. In + // particular, calling Close should unblock a Read waiting for input. + Read(context.Context) (jsonrpc.Message, error) + + // Write writes a new message to the connection. + // + // Write may be called concurrently, as calls or responses may occur + // concurrently in user code. + Write(context.Context, jsonrpc.Message) error + + // Close closes the connection. It is implicitly called whenever a Read or + // Write fails. + // + // Close may be called multiple times, potentially concurrently. + Close() error + + // TODO(#148): remove SessionID from this interface. + SessionID() string +} + +// A ClientConnection is a [Connection] that is specific to the MCP client. +// +// If client connections implement this interface, they may receive information +// about changes to the client session. +// +// TODO: should this interface be exported? 
+type clientConnection interface { + Connection + + // sessionUpdated is called whenever the client session state changes. + sessionUpdated(clientSessionState) +} + +// A serverConnection is a Connection that is specific to the MCP server. +// +// If server connections implement this interface, they receive information +// about changes to the server session. +// +// TODO: should this interface be exported? +type serverConnection interface { + Connection + sessionUpdated(ServerSessionState) +} + +// A StdioTransport is a [Transport] that communicates over stdin/stdout using +// newline-delimited JSON. +type StdioTransport struct{} + +// Connect implements the [Transport] interface. +func (*StdioTransport) Connect(context.Context) (Connection, error) { + return newIOConn(rwc{os.Stdin, nopCloserWriter{os.Stdout}}), nil +} + +// nopCloserWriter is an io.WriteCloser with a trivial Close method. +type nopCloserWriter struct { + io.Writer +} + +func (nopCloserWriter) Close() error { return nil } + +// An IOTransport is a [Transport] that communicates over separate +// io.ReadCloser and io.WriteCloser using newline-delimited JSON. +type IOTransport struct { + Reader io.ReadCloser + Writer io.WriteCloser +} + +// Connect implements the [Transport] interface. +func (t *IOTransport) Connect(context.Context) (Connection, error) { + return newIOConn(rwc{t.Reader, t.Writer}), nil +} + +// An InMemoryTransport is a [Transport] that communicates over an in-memory +// network connection, using newline-delimited JSON. +// +// InMemoryTransports should be constructed using [NewInMemoryTransports], +// which returns two transports connected to each other. +type InMemoryTransport struct { + rwc io.ReadWriteCloser +} + +// Connect implements the [Transport] interface. +func (t *InMemoryTransport) Connect(context.Context) (Connection, error) { + return newIOConn(t.rwc), nil +} + +// NewInMemoryTransports returns two [InMemoryTransport] objects that connect +// to each other. 
+// +// The resulting transports are symmetrical: use either to connect to a server, +// and then the other to connect to a client. Servers must be connected before +// clients, as the client initializes the MCP session during connection. +func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) { + c1, c2 := net.Pipe() + return &InMemoryTransport{c1}, &InMemoryTransport{c2} +} + +type binder[T handler, State any] interface { + // TODO(rfindley): the bind API has gotten too complicated. Simplify. + bind(Connection, *jsonrpc2.Connection, State, func()) T + disconnect(T) +} + +type handler interface { + handle(ctx context.Context, req *jsonrpc.Request) (any, error) +} + +func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) { + var zero H + mcpConn, err := t.Connect(ctx) + if err != nil { + return zero, err + } + // If logging is configured, write message logs. + reader, writer := jsonrpc2.Reader(mcpConn), jsonrpc2.Writer(mcpConn) + var ( + h H + preempter canceller + ) + bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler { + h = b.bind(mcpConn, conn, s, onClose) + preempter.conn = conn + return jsonrpc2.HandlerFunc(h.handle) + } + _ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{ + Reader: reader, + Writer: writer, + Closer: mcpConn, + Bind: bind, + Preempter: &preempter, + OnDone: func() { + b.disconnect(h) + }, + OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) }, + }) + assert(preempter.conn != nil, "unbound preempter") + return h, nil +} + +// A canceller is a jsonrpc2.Preempter that cancels in-flight requests on MCP +// cancelled notifications. +type canceller struct { + conn *jsonrpc2.Connection +} + +// Preempt implements [jsonrpc2.Preempter]. 
+func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) { + if req.Method == notificationCancelled { + var params CancelledParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return nil, err + } + id, err := jsonrpc2.MakeID(params.RequestID) + if err != nil { + return nil, err + } + go c.conn.Cancel(id) + } + return nil, jsonrpc2.ErrNotHandled +} + +// call executes and awaits a jsonrpc2 call on the given connection, +// translating errors into the mcp domain. +func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params Params, result Result) error { + // The "%w"s in this function expose jsonrpc.Error as part of the API. + call := conn.Call(ctx, method, params) + err := call.Await(ctx, result) + switch { + case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing): + return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err) + case ctx.Err() != nil: + // Notify the peer of cancellation. + err := conn.Notify(xcontext.Detach(ctx), notificationCancelled, &CancelledParams{ + Reason: ctx.Err().Error(), + RequestID: call.ID().Raw(), + }) + // By default, the jsonrpc2 library waits for graceful shutdown when the + // connection is closed, meaning it expects all outgoing and incoming + // requests to complete. However, for MCP this expectation is unrealistic, + // and can lead to hanging shutdown. For example, if a streamable client is + // killed, the server will not be able to detect this event, except via + // keepalive pings (if they are configured), and so outgoing calls may hang + // indefinitely. + // + // Therefore, we choose to eagerly retire calls, removing them from the + // outgoingCalls map, when the caller context is cancelled: if the caller + // will never receive the response, there's no need to track it. 
+ conn.Retire(call, ctx.Err()) + return errors.Join(ctx.Err(), err) + case err != nil: + return fmt.Errorf("calling %q: %w", method, err) + } + return nil +} + +// A LoggingTransport is a [Transport] that delegates to another transport, +// writing RPC logs to an io.Writer. +type LoggingTransport struct { + Transport Transport + Writer io.Writer +} + +// Connect connects the underlying transport, returning a [Connection] that writes +// logs to the configured destination. +func (t *LoggingTransport) Connect(ctx context.Context) (Connection, error) { + delegate, err := t.Transport.Connect(ctx) + if err != nil { + return nil, err + } + return &loggingConn{delegate: delegate, w: t.Writer}, nil +} + +type loggingConn struct { + delegate Connection + + mu sync.Mutex + w io.Writer +} + +func (c *loggingConn) SessionID() string { return c.delegate.SessionID() } + +// Read is a stream middleware that logs incoming messages. +func (s *loggingConn) Read(ctx context.Context) (jsonrpc.Message, error) { + msg, err := s.delegate.Read(ctx) + + if err != nil { + s.mu.Lock() + fmt.Fprintf(s.w, "read error: %v\n", err) + s.mu.Unlock() + } else { + data, err := jsonrpc2.EncodeMessage(msg) + s.mu.Lock() + if err != nil { + fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) + } + fmt.Fprintf(s.w, "read: %s\n", string(data)) + s.mu.Unlock() + } + + return msg, err +} + +// Write is a stream middleware that logs outgoing messages. 
+func (s *loggingConn) Write(ctx context.Context, msg jsonrpc.Message) error { + err := s.delegate.Write(ctx, msg) + if err != nil { + s.mu.Lock() + fmt.Fprintf(s.w, "write error: %v\n", err) + s.mu.Unlock() + } else { + data, err := jsonrpc2.EncodeMessage(msg) + s.mu.Lock() + if err != nil { + fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) + } + fmt.Fprintf(s.w, "write: %s\n", string(data)) + s.mu.Unlock() + } + return err +} + +func (s *loggingConn) Close() error { + return s.delegate.Close() +} + +// A rwc binds an io.ReadCloser and io.WriteCloser together to create an +// io.ReadWriteCloser. +type rwc struct { + rc io.ReadCloser + wc io.WriteCloser +} + +func (r rwc) Read(p []byte) (n int, err error) { + return r.rc.Read(p) +} + +func (r rwc) Write(p []byte) (n int, err error) { + return r.wc.Write(p) +} + +func (r rwc) Close() error { + rcErr := r.rc.Close() + + var wcErr error + if r.wc != nil { // we only allow a nil writer in unit tests + wcErr = r.wc.Close() + } + + return errors.Join(rcErr, wcErr) +} + +// An ioConn is a transport that delimits messages with newlines across +// a bidirectional stream, and supports JSON-RPC 2 message batching. +// +// See https://github.com/ndjson/ndjson-spec for discussion of newline +// delimited JSON. +// +// See [msgBatch] for more discussion of message batching. +type ioConn struct { + protocolVersion string // negotiated version, set during session initialization. + + writeMu sync.Mutex // guards Write, which must be concurrency safe. + rwc io.ReadWriteCloser // the underlying stream + + // incoming receives messages from the read loop started in [newIOConn]. + incoming <-chan msgOrErr + + // If outgoingBatch has a positive capacity, it will be used to batch requests + // and notifications before sending. + outgoingBatch []jsonrpc.Message + + // Unread messages in the last batch. Since reads are serialized, there is no + // need to guard here. 
+ queue []jsonrpc.Message + + // batches correlate incoming requests to the batch in which they arrived. + // Since writes may be concurrent to reads, we need to guard this with a mutex. + batchMu sync.Mutex + batches map[jsonrpc2.ID]*msgBatch // lazily allocated + + closeOnce sync.Once + closed chan struct{} + closeErr error +} + +type msgOrErr struct { + msg json.RawMessage + err error +} + +func newIOConn(rwc io.ReadWriteCloser) *ioConn { + var ( + incoming = make(chan msgOrErr) + closed = make(chan struct{}) + ) + // Start a goroutine for reads, so that we can select on the incoming channel + // in [ioConn.Read] and unblock the read as soon as Close is called (see #224). + // + // This leaks a goroutine if rwc.Read does not unblock after it is closed, + // but that is unavoidable since AFAIK there is no (easy and portable) way to + // guarantee that reads of stdin are unblocked when closed. + go func() { + dec := json.NewDecoder(rwc) + for { + var raw json.RawMessage + err := dec.Decode(&raw) + // If decoding was successful, check for trailing data at the end of the stream. + if err == nil { + // Read the next byte to check if there is trailing data. + var tr [1]byte + if n, readErr := dec.Buffered().Read(tr[:]); n > 0 { + // If read byte is not a newline, it is an error. + // Support both Unix (\n) and Windows (\r\n) line endings. 
+ if tr[0] != '\n' && tr[0] != '\r' { + err = fmt.Errorf("invalid trailing data at the end of stream") + } + } else if readErr != nil && readErr != io.EOF { + err = readErr + } + } + select { + case incoming <- msgOrErr{msg: raw, err: err}: + case <-closed: + return + } + if err != nil { + return + } + } + }() + return &ioConn{ + rwc: rwc, + incoming: incoming, + closed: closed, + } +} + +func (c *ioConn) SessionID() string { return "" } + +func (c *ioConn) sessionUpdated(state ServerSessionState) { + protocolVersion := "" + if state.InitializeParams != nil { + protocolVersion = state.InitializeParams.ProtocolVersion + } + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + c.protocolVersion = negotiatedVersion(protocolVersion) +} + +// addBatch records a msgBatch for an incoming batch payload. +// It returns an error if batch is malformed, containing previously seen IDs. +// +// See [msgBatch] for more. +func (t *ioConn) addBatch(batch *msgBatch) error { + t.batchMu.Lock() + defer t.batchMu.Unlock() + for id := range batch.unresolved { + if _, ok := t.batches[id]; ok { + return fmt.Errorf("%w: batch contains previously seen request %v", jsonrpc2.ErrInvalidRequest, id.Raw()) + } + } + for id := range batch.unresolved { + if t.batches == nil { + t.batches = make(map[jsonrpc2.ID]*msgBatch) + } + t.batches[id] = batch + } + return nil +} + +// updateBatch records a response in the message batch tracking the +// corresponding incoming call, if any. +// +// The second result reports whether resp was part of a batch. If this is true, +// the first result is nil if the batch is still incomplete, or the full set of +// batch responses if resp completed the batch. 
+func (t *ioConn) updateBatch(resp *jsonrpc.Response) ([]*jsonrpc.Response, bool) { + t.batchMu.Lock() + defer t.batchMu.Unlock() + + if batch, ok := t.batches[resp.ID]; ok { + idx, ok := batch.unresolved[resp.ID] + if !ok { + panic("internal error: inconsistent batches") + } + batch.responses[idx] = resp + delete(batch.unresolved, resp.ID) + delete(t.batches, resp.ID) + if len(batch.unresolved) == 0 { + return batch.responses, true + } + return nil, true + } + return nil, false +} + +// A msgBatch records information about an incoming batch of JSON-RPC 2 calls. +// +// The JSON-RPC 2 spec (https://www.jsonrpc.org/specification#batch) says: +// +// "The Server should respond with an Array containing the corresponding +// Response objects, after all of the batch Request objects have been +// processed. A Response object SHOULD exist for each Request object, except +// that there SHOULD NOT be any Response objects for notifications. The Server +// MAY process a batch rpc call as a set of concurrent tasks, processing them +// in any order and with any width of parallelism." +// +// Therefore, a msgBatch keeps track of outstanding calls and their responses. +// When there are no unresolved calls, the response payload is sent. +type msgBatch struct { + unresolved map[jsonrpc2.ID]int + responses []*jsonrpc.Response +} + +func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { + // As a matter of principle, enforce that reads on a closed context return an + // error. 
+ select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if len(t.queue) > 0 { + next := t.queue[0] + t.queue = t.queue[1:] + return next, nil + } + + var raw json.RawMessage + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case v := <-t.incoming: + if v.err != nil { + return nil, v.err + } + raw = v.msg + + case <-t.closed: + return nil, io.EOF + } + + msgs, batch, err := readBatch(raw) + if err != nil { + return nil, err + } + if batch && t.protocolVersion >= protocolVersion20250618 { + return nil, fmt.Errorf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, t.protocolVersion) + } + + t.queue = msgs[1:] + + if batch { + var respBatch *msgBatch // track incoming requests in the batch + for _, msg := range msgs { + if req, ok := msg.(*jsonrpc.Request); ok { + if respBatch == nil { + respBatch = &msgBatch{ + unresolved: make(map[jsonrpc2.ID]int), + } + } + if _, ok := respBatch.unresolved[req.ID]; ok { + return nil, fmt.Errorf("duplicate message ID %q", req.ID) + } + respBatch.unresolved[req.ID] = len(respBatch.responses) + respBatch.responses = append(respBatch.responses, nil) + } + } + if respBatch != nil { + // The batch contains one or more incoming requests to track. + if err := t.addBatch(respBatch); err != nil { + return nil, err + } + } + } + return msgs[0], err +} + +// readBatch reads batch data, which may be either a single JSON-RPC message, +// or an array of JSON-RPC messages. +func readBatch(data []byte) (msgs []jsonrpc.Message, isBatch bool, _ error) { + // Try to read an array of messages first. + var rawBatch []json.RawMessage + if err := json.Unmarshal(data, &rawBatch); err == nil { + if len(rawBatch) == 0 { + return nil, true, fmt.Errorf("empty batch") + } + for _, raw := range rawBatch { + msg, err := jsonrpc2.DecodeMessage(raw) + if err != nil { + return nil, true, err + } + msgs = append(msgs, msg) + } + return msgs, true, nil + } + // Try again with a single message. 
+ msg, err := jsonrpc2.DecodeMessage(data) + return []jsonrpc.Message{msg}, false, err +} + +func (t *ioConn) Write(ctx context.Context, msg jsonrpc.Message) error { + // As in [ioConn.Read], enforce that Writes on a closed context are an error. + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + t.writeMu.Lock() + defer t.writeMu.Unlock() + + // Batching support: if msg is a Response, it may have completed a batch, so + // check that first. Otherwise, it is a request or notification, and we may + // want to collect it into a batch before sending, if we're configured to use + // outgoing batches. + if resp, ok := msg.(*jsonrpc.Response); ok { + if batch, ok := t.updateBatch(resp); ok { + if len(batch) > 0 { + data, err := marshalMessages(batch) + if err != nil { + return err + } + data = append(data, '\n') + _, err = t.rwc.Write(data) + return err + } + return nil + } + } else if len(t.outgoingBatch) < cap(t.outgoingBatch) { + t.outgoingBatch = append(t.outgoingBatch, msg) + if len(t.outgoingBatch) == cap(t.outgoingBatch) { + data, err := marshalMessages(t.outgoingBatch) + t.outgoingBatch = t.outgoingBatch[:0] + if err != nil { + return err + } + data = append(data, '\n') + _, err = t.rwc.Write(data) + return err + } + return nil + } + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return fmt.Errorf("marshaling message: %v", err) + } + data = append(data, '\n') // newline delimited + _, err = t.rwc.Write(data) + return err +} + +func (t *ioConn) Close() error { + t.closeOnce.Do(func() { + t.closeErr = t.rwc.Close() + close(t.closed) + }) + return t.closeErr +} + +func marshalMessages[T jsonrpc.Message](msgs []T) ([]byte, error) { + var rawMsgs []json.RawMessage + for _, msg := range msgs { + raw, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return nil, fmt.Errorf("encoding batch message: %w", err) + } + rawMsgs = append(rawMsgs, raw) + } + return json.Marshal(rawMsgs) +} diff --git 
a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go new file mode 100644 index 000000000..5ada466e5 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go @@ -0,0 +1,43 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "crypto/rand" + "encoding/json" +) + +func assert(cond bool, msg string) { + if !cond { + panic(msg) + } +} + +// Copied from crypto/rand. +// TODO: once 1.24 is assured, just use crypto/rand. +const base32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" + +func randText() string { + // ⌈log₃₂ 2¹²⁸⌉ = 26 chars + src := make([]byte, 26) + rand.Read(src) + for i := range src { + src[i] = base32alphabet[src[i]%32] + } + return string(src) +} + +// remarshal marshals from to JSON, and then unmarshals into to, which must be +// a pointer type. +func remarshal(from, to any) error { + data, err := json.Marshal(from) + if err != nil { + return err + } + if err := json.Unmarshal(data, to); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go new file mode 100644 index 000000000..9aa0c8d7d --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go @@ -0,0 +1,187 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Authorization Server Metadata. +// See https://www.rfc-editor.org/rfc/rfc8414.html. 
+ +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "errors" + "fmt" + "net/http" +) + +// AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, +// as defined in [RFC 8414]. +// +// Not supported: +// - signed metadata +// +// Note: URL fields in this struct are validated by validateAuthServerMetaURLs to +// prevent XSS attacks. If you add a new URL field, you must also add it to that +// function. +// +// [RFC 8414]: https://tools.ietf.org/html/rfc8414) +type AuthServerMeta struct { + // GENERATED BY GEMINI 2.5. + + // Issuer is the REQUIRED URL identifying the authorization server. + Issuer string `json:"issuer"` + + // AuthorizationEndpoint is the REQUIRED URL of the server's OAuth 2.0 authorization endpoint. + AuthorizationEndpoint string `json:"authorization_endpoint"` + + // TokenEndpoint is the REQUIRED URL of the server's OAuth 2.0 token endpoint. + TokenEndpoint string `json:"token_endpoint"` + + // JWKSURI is the REQUIRED URL of the server's JSON Web Key Set [JWK] document. + JWKSURI string `json:"jwks_uri"` + + // RegistrationEndpoint is the RECOMMENDED URL of the server's OAuth 2.0 Dynamic Client Registration endpoint. + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + + // ScopesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // "scope" values that this server supports. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // ResponseTypesSupported is a REQUIRED JSON array of strings containing a list of the OAuth 2.0 + // "response_type" values that this server supports. + ResponseTypesSupported []string `json:"response_types_supported"` + + // ResponseModesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // "response_mode" values that this server supports. 
+ ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + + // GrantTypesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // grant type values that this server supports. + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + + // TokenEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing a list of + // client authentication methods supported by this token endpoint. + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` + + // TokenEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings containing + // a list of the JWS signing algorithms ("alg" values) supported by the token endpoint for + // the signature on the JWT used to authenticate the client. + TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + + // ServiceDocumentation is a RECOMMENDED URL of a page containing human-readable documentation + // for the service. + ServiceDocumentation string `json:"service_documentation,omitempty"` + + // UILocalesSupported is a RECOMMENDED JSON array of strings representing supported + // BCP47 [RFC5646] language tag values for display in the user interface. + UILocalesSupported []string `json:"ui_locales_supported,omitempty"` + + // OpPolicyURI is a RECOMMENDED URL that the server provides to the person registering + // the client to read about the server's operator policies. + OpPolicyURI string `json:"op_policy_uri,omitempty"` + + // OpTOSURI is a RECOMMENDED URL that the server provides to the person registering the + // client to read about the server's terms of service. + OpTOSURI string `json:"op_tos_uri,omitempty"` + + // RevocationEndpoint is a RECOMMENDED URL of the server's OAuth 2.0 revocation endpoint. 
+ RevocationEndpoint string `json:"revocation_endpoint,omitempty"` + + // RevocationEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing + // a list of client authentication methods supported by this revocation endpoint. + RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"` + + // RevocationEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings + // containing a list of the JWS signing algorithms ("alg" values) supported by the revocation + // endpoint for the signature on the JWT used to authenticate the client. + RevocationEndpointAuthSigningAlgValuesSupported []string `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"` + + // IntrospectionEndpoint is a RECOMMENDED URL of the server's OAuth 2.0 introspection endpoint. + IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` + + // IntrospectionEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing + // a list of client authentication methods supported by this introspection endpoint. + IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"` + + // IntrospectionEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings + // containing a list of the JWS signing algorithms ("alg" values) supported by the introspection + // endpoint for the signature on the JWT used to authenticate the client. + IntrospectionEndpointAuthSigningAlgValuesSupported []string `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"` + + // CodeChallengeMethodsSupported is a RECOMMENDED JSON array of strings containing a list of + // PKCE code challenge methods supported by this authorization server. 
+ CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` +} + +var wellKnownPaths = []string{ + "/.well-known/oauth-authorization-server", + "/.well-known/openid-configuration", +} + +// GetAuthServerMeta issues a GET request to retrieve authorization server metadata +// from an OAuth authorization server with the given issuerURL. +// +// It follows [RFC 8414]: +// - The well-known paths specified there are inserted into the URL's path, one at time. +// The first to succeed is used. +// - The Issuer field is checked against issuerURL. +// +// [RFC 8414]: https://tools.ietf.org/html/rfc8414 +func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (*AuthServerMeta, error) { + var errs []error + for _, p := range wellKnownPaths { + u, err := prependToPath(issuerURL, p) + if err != nil { + // issuerURL is bad; no point in continuing. + return nil, err + } + asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) + if err == nil { + if asm.Issuer != issuerURL { // section 3.3 + // Security violation; don't keep trying. + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) + } + + if len(asm.CodeChallengeMethodsSupported) == 0 { + return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) + } + + // Validate endpoint URLs to prevent XSS attacks (see #526). + if err := validateAuthServerMetaURLs(asm); err != nil { + return nil, err + } + + return asm, nil + } + errs = append(errs, err) + } + return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) +} + +// validateAuthServerMetaURLs validates all URL fields in AuthServerMeta +// to ensure they don't use dangerous schemes that could enable XSS attacks. 
+func validateAuthServerMetaURLs(asm *AuthServerMeta) error { + urls := []struct { + name string + value string + }{ + {"authorization_endpoint", asm.AuthorizationEndpoint}, + {"token_endpoint", asm.TokenEndpoint}, + {"jwks_uri", asm.JWKSURI}, + {"registration_endpoint", asm.RegistrationEndpoint}, + {"service_documentation", asm.ServiceDocumentation}, + {"op_policy_uri", asm.OpPolicyURI}, + {"op_tos_uri", asm.OpTOSURI}, + {"revocation_endpoint", asm.RevocationEndpoint}, + {"introspection_endpoint", asm.IntrospectionEndpoint}, + } + + for _, u := range urls { + if err := checkURLScheme(u.value); err != nil { + return fmt.Errorf("%s: %w", u.name, err) + } + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/dcr.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/dcr.go new file mode 100644 index 000000000..c64cb8cd4 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/dcr.go @@ -0,0 +1,261 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Authorization Server Metadata. +// See https://www.rfc-editor.org/rfc/rfc8414.html. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// ClientRegistrationMetadata represents the client metadata fields for the DCR POST request (RFC 7591). +// +// Note: URL fields in this struct are validated by validateClientRegistrationURLs +// to prevent XSS attacks. If you add a new URL field, you must also add it to +// that function. +type ClientRegistrationMetadata struct { + // RedirectURIs is a REQUIRED JSON array of redirection URI strings for use in + // redirect-based flows (such as the authorization code grant). 
+ RedirectURIs []string `json:"redirect_uris"` + + // TokenEndpointAuthMethod is an OPTIONAL string indicator of the requested + // authentication method for the token endpoint. + // If omitted, the default is "client_secret_basic". + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + + // GrantTypes is an OPTIONAL JSON array of OAuth 2.0 grant type strings + // that the client will restrict itself to using. + // If omitted, the default is ["authorization_code"]. + GrantTypes []string `json:"grant_types,omitempty"` + + // ResponseTypes is an OPTIONAL JSON array of OAuth 2.0 response type strings + // that the client will restrict itself to using. + // If omitted, the default is ["code"]. + ResponseTypes []string `json:"response_types,omitempty"` + + // ClientName is a RECOMMENDED human-readable name of the client to be presented + // to the end-user. + ClientName string `json:"client_name,omitempty"` + + // ClientURI is a RECOMMENDED URL of a web page providing information about the client. + ClientURI string `json:"client_uri,omitempty"` + + // LogoURI is an OPTIONAL URL of a logo for the client, which may be displayed + // to the end-user. + LogoURI string `json:"logo_uri,omitempty"` + + // Scope is an OPTIONAL string containing a space-separated list of scope values + // that the client will restrict itself to using. + Scope string `json:"scope,omitempty"` + + // Contacts is an OPTIONAL JSON array of strings representing ways to contact + // people responsible for this client (e.g., email addresses). + Contacts []string `json:"contacts,omitempty"` + + // TOSURI is an OPTIONAL URL that the client provides to the end-user + // to read about the client's terms of service. + TOSURI string `json:"tos_uri,omitempty"` + + // PolicyURI is an OPTIONAL URL that the client provides to the end-user + // to read about the client's privacy policy. 
+ PolicyURI string `json:"policy_uri,omitempty"` + + // JWKSURI is an OPTIONAL URL for the client's JSON Web Key Set [JWK] document. + // This is preferred over the 'jwks' parameter. + JWKSURI string `json:"jwks_uri,omitempty"` + + // JWKS is an OPTIONAL client's JSON Web Key Set [JWK] document, passed by value. + // This is an alternative to providing a JWKSURI. + JWKS string `json:"jwks,omitempty"` + + // SoftwareID is an OPTIONAL unique identifier string for the client software, + // constant across all instances and versions. + SoftwareID string `json:"software_id,omitempty"` + + // SoftwareVersion is an OPTIONAL version identifier string for the client software. + SoftwareVersion string `json:"software_version,omitempty"` + + // SoftwareStatement is an OPTIONAL JWT that asserts client metadata values. + // Values in the software statement take precedence over other metadata values. + SoftwareStatement string `json:"software_statement,omitempty"` +} + +// ClientRegistrationResponse represents the fields returned by the Authorization Server +// (RFC 7591, Section 3.2.1 and 3.2.2). +type ClientRegistrationResponse struct { + // ClientRegistrationMetadata contains all registered client metadata, returned by the + // server on success, potentially with modified or defaulted values. + ClientRegistrationMetadata + + // ClientID is the REQUIRED newly issued OAuth 2.0 client identifier. + ClientID string `json:"client_id"` + + // ClientSecret is an OPTIONAL client secret string. + ClientSecret string `json:"client_secret,omitempty"` + + // ClientIDIssuedAt is an OPTIONAL Unix timestamp when the ClientID was issued. + ClientIDIssuedAt time.Time `json:"client_id_issued_at,omitempty"` + + // ClientSecretExpiresAt is the REQUIRED (if client_secret is issued) Unix + // timestamp when the secret expires, or 0 if it never expires. 
+ ClientSecretExpiresAt time.Time `json:"client_secret_expires_at,omitempty"` +} + +func (r *ClientRegistrationResponse) MarshalJSON() ([]byte, error) { + type alias ClientRegistrationResponse + var clientIDIssuedAt int64 + var clientSecretExpiresAt int64 + + if !r.ClientIDIssuedAt.IsZero() { + clientIDIssuedAt = r.ClientIDIssuedAt.Unix() + } + if !r.ClientSecretExpiresAt.IsZero() { + clientSecretExpiresAt = r.ClientSecretExpiresAt.Unix() + } + + return json.Marshal(&struct { + ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + *alias + }{ + ClientIDIssuedAt: clientIDIssuedAt, + ClientSecretExpiresAt: clientSecretExpiresAt, + alias: (*alias)(r), + }) +} + +func (r *ClientRegistrationResponse) UnmarshalJSON(data []byte) error { + type alias ClientRegistrationResponse + aux := &struct { + ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + *alias + }{ + alias: (*alias)(r), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + if aux.ClientIDIssuedAt != 0 { + r.ClientIDIssuedAt = time.Unix(aux.ClientIDIssuedAt, 0) + } + if aux.ClientSecretExpiresAt != 0 { + r.ClientSecretExpiresAt = time.Unix(aux.ClientSecretExpiresAt, 0) + } + return nil +} + +// ClientRegistrationError is the error response from the Authorization Server +// for a failed registration attempt (RFC 7591, Section 3.2.2). +type ClientRegistrationError struct { + // ErrorCode is the REQUIRED error code if registration failed (RFC 7591, 3.2.2). + ErrorCode string `json:"error"` + + // ErrorDescription is an OPTIONAL human-readable error message. 
+ ErrorDescription string `json:"error_description,omitempty"` +} + +func (e *ClientRegistrationError) Error() string { + return fmt.Sprintf("registration failed: %s (%s)", e.ErrorCode, e.ErrorDescription) +} + +// RegisterClient performs Dynamic Client Registration according to RFC 7591. +func RegisterClient(ctx context.Context, registrationEndpoint string, clientMeta *ClientRegistrationMetadata, c *http.Client) (*ClientRegistrationResponse, error) { + if registrationEndpoint == "" { + return nil, fmt.Errorf("registration_endpoint is required") + } + + if c == nil { + c = http.DefaultClient + } + + payload, err := json.Marshal(clientMeta) + if err != nil { + return nil, fmt.Errorf("failed to marshal client metadata: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", registrationEndpoint, bytes.NewBuffer(payload)) + if err != nil { + return nil, fmt.Errorf("failed to create registration request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.Do(req) + if err != nil { + return nil, fmt.Errorf("registration request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read registration response body: %w", err) + } + + if resp.StatusCode == http.StatusCreated { + var regResponse ClientRegistrationResponse + if err := json.Unmarshal(body, ®Response); err != nil { + return nil, fmt.Errorf("failed to decode successful registration response: %w (%s)", err, string(body)) + } + if regResponse.ClientID == "" { + return nil, fmt.Errorf("registration response is missing required 'client_id' field") + } + // Validate URL fields to prevent XSS attacks (see #526). 
+ if err := validateClientRegistrationURLs(®Response.ClientRegistrationMetadata); err != nil { + return nil, err + } + return ®Response, nil + } + + if resp.StatusCode == http.StatusBadRequest { + var regError ClientRegistrationError + if err := json.Unmarshal(body, ®Error); err != nil { + return nil, fmt.Errorf("failed to decode registration error response: %w (%s)", err, string(body)) + } + return nil, ®Error + } + + return nil, fmt.Errorf("registration failed with status %s: %s", resp.Status, string(body)) +} + +// validateClientRegistrationURLs validates all URL fields in ClientRegistrationMetadata +// to ensure they don't use dangerous schemes that could enable XSS attacks. +func validateClientRegistrationURLs(meta *ClientRegistrationMetadata) error { + // Validate redirect URIs + for i, uri := range meta.RedirectURIs { + if err := checkURLScheme(uri); err != nil { + return fmt.Errorf("redirect_uris[%d]: %w", i, err) + } + } + + // Validate other URL fields + urls := []struct { + name string + value string + }{ + {"client_uri", meta.ClientURI}, + {"logo_uri", meta.LogoURI}, + {"tos_uri", meta.TOSURI}, + {"policy_uri", meta.PolicyURI}, + {"jwks_uri", meta.JWKSURI}, + } + + for _, u := range urls { + if err := checkURLScheme(u.value); err != nil { + return fmt.Errorf("%s: %w", u.name, err) + } + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go new file mode 100644 index 000000000..cdda695b7 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go @@ -0,0 +1,91 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package oauthex implements extensions to OAuth2. 
+ +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strings" +) + +// prependToPath prepends pre to the path of urlStr. +// When pre is the well-known path, this is the algorithm specified in both RFC 9728 +// section 3.1 and RFC 8414 section 3.1. +func prependToPath(urlStr, pre string) (string, error) { + u, err := url.Parse(urlStr) + if err != nil { + return "", err + } + p := "/" + strings.Trim(pre, "/") + if u.Path != "" { + p += "/" + } + + u.Path = p + strings.TrimLeft(u.Path, "/") + return u.String(), nil +} + +// getJSON retrieves JSON and unmarshals JSON from the URL, as specified in both +// RFC 9728 and RFC 8414. +// It will not read more than limit bytes from the body. +func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64) (*T, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + if c == nil { + c = http.DefaultClient + } + res, err := c.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // Specs require a 200. + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bad status %s", res.Status) + } + // Specs require application/json. + ct := res.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(ct) + if err != nil || mediaType != "application/json" { + return nil, fmt.Errorf("bad content type %q", ct) + } + + var t T + dec := json.NewDecoder(io.LimitReader(res.Body, limit)) + if err := dec.Decode(&t); err != nil { + return nil, err + } + return &t, nil +} + +// checkURLScheme ensures that its argument is a valid URL with a scheme +// that prevents XSS attacks. +// See #526. 
+func checkURLScheme(u string) error { + if u == "" { + return nil + } + uu, err := url.Parse(u) + if err != nil { + return err + } + scheme := strings.ToLower(uu.Scheme) + if scheme == "javascript" || scheme == "data" || scheme == "vbscript" { + return fmt.Errorf("URL has disallowed scheme %q", scheme) + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go new file mode 100644 index 000000000..34ed55b59 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go @@ -0,0 +1,92 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package oauthex implements extensions to OAuth2. +package oauthex + +// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, +// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. +// +// The following features are not supported: +// - additional keys (§2, last sentence) +// - human-readable metadata (§2.1) +// - signed metadata (§2.2) +type ProtectedResourceMetadata struct { + // GENERATED BY GEMINI 2.5. + + // Resource (resource) is the protected resource's resource identifier. + // Required. + Resource string `json:"resource"` + + // AuthorizationServers (authorization_servers) is an optional slice containing a list of + // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be + // used with this protected resource. + AuthorizationServers []string `json:"authorization_servers,omitempty"` + + // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set + // document. This contains public keys belonging to the protected resource, such as + // signing key(s) that the resource server uses to sign resource responses. 
+ JWKSURI string `json:"jwks_uri,omitempty"` + + // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope + // values (as defined in RFC 6749) used in authorization requests to request access + // to this protected resource. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing + // a list of the supported methods of sending an OAuth 2.0 bearer token to the + // protected resource. Defined values are "header", "body", and "query". + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + + // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms (alg values) supported by the protected + // resource for signing resource responses. + ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` + + // ResourceName (resource_name) is a human-readable name of the protected resource + // intended for display to the end user. It is RECOMMENDED that this field be included. + // This value may be internationalized. + ResourceName string `json:"resource_name,omitempty"` + + // ResourceDocumentation (resource_documentation) is an optional URL of a page containing + // human-readable information for developers using the protected resource. + // This value may be internationalized. + ResourceDocumentation string `json:"resource_documentation,omitempty"` + + // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing + // human-readable policy information on how a client can use the data provided. + // This value may be internationalized. + ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` + + // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected + // resource's human-readable terms of service. This value may be internationalized. 
+ ResourceTOSURI string `json:"resource_tos_uri,omitempty"` + + // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an + // optional boolean indicating support for mutual-TLS client certificate-bound + // access tokens (RFC 8705). Defaults to false if omitted. + TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` + + // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional + // slice of 'type' values supported by the resource server for the + // 'authorization_details' parameter (RFC 9396). + AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` + + // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms supported by the resource server for validating + // DPoP proof JWTs (RFC 9449). + DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` + + // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean + // specifying whether the protected resource always requires the use of DPoP-bound + // access tokens (RFC 9449). Defaults to false if omitted. + DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` + + // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters + // about the protected resource as claims. If present, these values take precedence + // over values conveyed in plain JSON. + // TODO:implement. + // Note that §2.2 says it's okay to ignore this. 
+ // SignedMetadata string `json:"signed_metadata,omitempty"` +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go new file mode 100644 index 000000000..bb61f7974 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go @@ -0,0 +1,281 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Protected Resource Metadata. +// See https://www.rfc-editor.org/rfc/rfc9728.html. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "path" + "strings" + "unicode" + + "github.com/modelcontextprotocol/go-sdk/internal/util" +) + +const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resource" + +// GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource +// metadata from a resource server by its ID. +// The resource ID is an HTTPS URL, typically with a host:port and possibly a path. +// For example: +// +// https://example.com/server +// +// This function, following the spec (§3), inserts the default well-known path into the +// URL. In our example, the result would be +// +// https://example.com/.well-known/oauth-protected-resource/server +// +// It then retrieves the metadata at that location using the given client (or the +// default client if nil) and validates its resource field against resourceID. +func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID) + + u, err := url.Parse(resourceID) + if err != nil { + return nil, err + } + // Insert well-known URI into URL. 
+ u.Path = path.Join(defaultProtectedResourceMetadataURI, u.Path) + return getPRM(ctx, u.String(), c, resourceID) +} + +// GetProtectedResourceMetadataFromHeader retrieves protected resource metadata +// using information in the given header, using the given client (or the default +// client if nil). +// It issues a GET request to a URL discovered by parsing the WWW-Authenticate headers in the given request. +// Per RFC 9728 section 3.3, it validates that the resource field of the resulting metadata +// matches the serverURL (the URL that the client used to make the original request to the resource server). +// If there is no metadata URL in the header, it returns nil, nil. +func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL string, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader") + headers := header[http.CanonicalHeaderKey("WWW-Authenticate")] + if len(headers) == 0 { + return nil, nil + } + cs, err := ParseWWWAuthenticate(headers) + if err != nil { + return nil, err + } + metadataURL := ResourceMetadataURL(cs) + if metadataURL == "" { + return nil, nil + } + return getPRM(ctx, metadataURL, c, serverURL) +} + +// getPRM makes a GET request to the given URL, and validates the response. +// As part of the validation, it compares the returned resource field to wantResource. +func getPRM(ctx context.Context, purl string, c *http.Client, wantResource string) (*ProtectedResourceMetadata, error) { + if !strings.HasPrefix(strings.ToUpper(purl), "HTTPS://") { + return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl) + } + prm, err := getJSON[ProtectedResourceMetadata](ctx, c, purl, 1<<20) + if err != nil { + return nil, err + } + // Validate the Resource field (see RFC 9728, section 3.3). 
+ if prm.Resource != wantResource { + return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource) + } + // Validate the authorization server URLs to prevent XSS attacks (see #526). + for _, u := range prm.AuthorizationServers { + if err := checkURLScheme(u); err != nil { + return nil, err + } + } + return prm, nil +} + +// challenge represents a single authentication challenge from a WWW-Authenticate header. +// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. +type challenge struct { + // GENERATED BY GEMINI 2.5. + // + // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). + // It is case-insensitive. A parsed value will always be lower-case. + Scheme string + // Params is a map of authentication parameters. + // Keys are case-insensitive. Parsed keys are always lower-case. + Params map[string]string +} + +// ResourceMetadataURL returns a resource metadata URL from the given challenges, +// or the empty string if there is none. +func ResourceMetadataURL(cs []challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// ParseWWWAuthenticate parses a WWW-Authenticate header string. +// The header format is defined in RFC 9110, Section 11.6.1, and can contain +// one or more challenges, separated by commas. +// It returns a slice of challenges or an error if one of the headers is malformed. 
+func ParseWWWAuthenticate(headers []string) ([]challenge, error) { + // GENERATED BY GEMINI 2.5 (human-tweaked) + var challenges []challenge + for _, h := range headers { + challengeStrings, err := splitChallenges(h) + if err != nil { + return nil, err + } + for _, cs := range challengeStrings { + if strings.TrimSpace(cs) == "" { + continue + } + challenge, err := parseSingleChallenge(cs) + if err != nil { + return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) + } + challenges = append(challenges, challenge) + } + } + return challenges, nil +} + +// splitChallenges splits a header value containing one or more challenges. +// It correctly handles commas within quoted strings and distinguishes between +// commas separating auth-params and commas separating challenges. +func splitChallenges(header string) ([]string, error) { + // GENERATED BY GEMINI 2.5. + var challenges []string + inQuotes := false + start := 0 + for i, r := range header { + if r == '"' { + if i > 0 && header[i-1] != '\\' { + inQuotes = !inQuotes + } else if i == 0 { + // A challenge begins with an auth-scheme, which is a token, which cannot contain + // a quote. + return nil, errors.New(`challenge begins with '"'`) + } + } else if r == ',' && !inQuotes { + // This is a potential challenge separator. + // A new challenge does not start with `key=value`. + // We check if the part after the comma looks like a parameter. + lookahead := strings.TrimSpace(header[i+1:]) + eqPos := strings.Index(lookahead, "=") + + isParam := false + if eqPos > 0 { + // Check if the part before '=' is a single token (no spaces). + token := lookahead[:eqPos] + if strings.IndexFunc(token, unicode.IsSpace) == -1 { + isParam = true + } + } + + if !isParam { + // The part after the comma does not look like a parameter, + // so this comma separates challenges. + challenges = append(challenges, header[start:i]) + start = i + 1 + } + } + } + // Add the last (or only) challenge to the list. 
+ challenges = append(challenges, header[start:]) + return challenges, nil +} + +// parseSingleChallenge parses a string containing exactly one challenge. +// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] +func parseSingleChallenge(s string) (challenge, error) { + // GENERATED BY GEMINI 2.5, human-tweaked. + s = strings.TrimSpace(s) + if s == "" { + return challenge{}, errors.New("empty challenge string") + } + + scheme, paramsStr, found := strings.Cut(s, " ") + c := challenge{Scheme: strings.ToLower(scheme)} + if !found { + return c, nil + } + + params := make(map[string]string) + + // Parse the key-value parameters. + for paramsStr != "" { + // Find the end of the parameter key. + keyEnd := strings.Index(paramsStr, "=") + if keyEnd <= 0 { + return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + } + key := strings.TrimSpace(paramsStr[:keyEnd]) + + // Move the string past the key and the '='. + paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) + + var value string + if strings.HasPrefix(paramsStr, "\"") { + // The value is a quoted string. + paramsStr = paramsStr[1:] // Consume the opening quote. + var valBuilder strings.Builder + i := 0 + for ; i < len(paramsStr); i++ { + // Handle escaped characters. + if paramsStr[i] == '\\' && i+1 < len(paramsStr) { + valBuilder.WriteByte(paramsStr[i+1]) + i++ // We've consumed two characters. + } else if paramsStr[i] == '"' { + // End of the quoted string. + break + } else { + valBuilder.WriteByte(paramsStr[i]) + } + } + + // A quoted string must be terminated. + if i == len(paramsStr) { + return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + } + + value = valBuilder.String() + // Move the string past the value and the closing quote. + paramsStr = strings.TrimSpace(paramsStr[i+1:]) + } else { + // The value is a token. It ends at the next comma or the end of the string. 
+ commaPos := strings.Index(paramsStr, ",") + if commaPos == -1 { + value = paramsStr + paramsStr = "" + } else { + value = strings.TrimSpace(paramsStr[:commaPos]) + paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check + } + } + if value == "" { + return challenge{}, fmt.Errorf("no value for auth param %q", key) + } + + // Per RFC 9110, parameter keys are case-insensitive. + params[strings.ToLower(key)] = value + + // If there is a comma, consume it and continue to the next parameter. + if strings.HasPrefix(paramsStr, ",") { + paramsStr = strings.TrimSpace(paramsStr[1:]) + } else if paramsStr != "" { + // If there's content but it's not a new parameter, the format is wrong. + return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + } + } + + // Per RFC 9110, the scheme is case-insensitive. + return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/LICENSE b/vendor/github.com/yosida95/uritemplate/v3/LICENSE new file mode 100644 index 000000000..79e8f8757 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/LICENSE @@ -0,0 +1,25 @@ +Copyright (C) 2016, Kohei YOSHIDA . All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/yosida95/uritemplate/v3/README.rst b/vendor/github.com/yosida95/uritemplate/v3/README.rst new file mode 100644 index 000000000..6815d0a46 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/README.rst @@ -0,0 +1,46 @@ +uritemplate +=========== + +`uritemplate`_ is a Go implementation of `URI Template`_ [RFC6570] with +full functionality of URI Template Level 4. + +uritemplate can also generate a regexp that matches expansion of the +URI Template from a URI Template. + +Getting Started +--------------- + +Installation +~~~~~~~~~~~~ + +.. code-block:: sh + + $ go get -u github.com/yosida95/uritemplate/v3 + +Documentation +~~~~~~~~~~~~~ + +The documentation is available on GoDoc_. + +Examples +-------- + +See `examples on GoDoc`_. + +License +------- + +`uritemplate`_ is distributed under the BSD 3-Clause license. +PLEASE READ ./LICENSE carefully and follow its clauses to use this software. + +Author +------ + +yosida95_ + + +.. _`URI Template`: https://tools.ietf.org/html/rfc6570 +.. _Godoc: https://godoc.org/github.com/yosida95/uritemplate +.. _`examples on GoDoc`: https://godoc.org/github.com/yosida95/uritemplate#pkg-examples +.. _yosida95: https://yosida95.com/ +.. 
_uritemplate: https://github.com/yosida95/uritemplate diff --git a/vendor/github.com/yosida95/uritemplate/v3/compile.go b/vendor/github.com/yosida95/uritemplate/v3/compile.go new file mode 100644 index 000000000..bd774d15d --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/compile.go @@ -0,0 +1,224 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" + "unicode/utf8" +) + +type compiler struct { + prog *prog +} + +func (c *compiler) init() { + c.prog = &prog{} +} + +func (c *compiler) op(opcode progOpcode) uint32 { + i := len(c.prog.op) + c.prog.op = append(c.prog.op, progOp{code: opcode}) + return uint32(i) +} + +func (c *compiler) opWithRune(opcode progOpcode, r rune) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).r = r + return addr +} + +func (c *compiler) opWithRuneClass(opcode progOpcode, rc runeClass) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).rc = rc + return addr +} + +func (c *compiler) opWithAddr(opcode progOpcode, absaddr uint32) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).i = absaddr + return addr +} + +func (c *compiler) opWithAddrDelta(opcode progOpcode, delta uint32) uint32 { + return c.opWithAddr(opcode, uint32(len(c.prog.op))+delta) +} + +func (c *compiler) opWithName(opcode progOpcode, name string) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).name = name + return addr +} + +func (c *compiler) compileString(str string) { + for i := 0; i < len(str); { + // NOTE(yosida95): It is confirmed at parse time that literals + // consist of only valid-UTF8 runes. 
+ r, size := utf8.DecodeRuneInString(str[i:]) + c.opWithRune(opRune, r) + i += size + } +} + +func (c *compiler) compileRuneClass(rc runeClass, maxlen int) { + for i := 0; i < maxlen; i++ { + if i > 0 { + c.opWithAddrDelta(opSplit, 7) + } + c.opWithAddrDelta(opSplit, 3) // raw rune or pct-encoded + c.opWithRuneClass(opRuneClass, rc) // raw rune + c.opWithAddrDelta(opJmp, 4) // + c.opWithRune(opRune, '%') // pct-encoded + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithRuneClass(opRuneClass, runeClassPctE) // + } +} + +func (c *compiler) compileRuneClassInfinite(rc runeClass) { + start := c.opWithAddrDelta(opSplit, 3) // raw rune or pct-encoded + c.opWithRuneClass(opRuneClass, rc) // raw rune + c.opWithAddrDelta(opJmp, 4) // + c.opWithRune(opRune, '%') // pct-encoded + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithAddrDelta(opSplit, 2) // loop + c.opWithAddr(opJmp, start) // +} + +func (c *compiler) compileVarspecValue(spec varspec, expr *expression) { + var specname string + if spec.maxlen > 0 { + specname = fmt.Sprintf("%s:%d", spec.name, spec.maxlen) + } else { + specname = spec.name + } + + c.prog.numCap++ + + c.opWithName(opCapStart, specname) + + split := c.op(opSplit) + if spec.maxlen > 0 { + c.compileRuneClass(expr.allow, spec.maxlen) + } else { + c.compileRuneClassInfinite(expr.allow) + } + + capEnd := c.opWithName(opCapEnd, specname) + c.prog.op[split].i = capEnd +} + +func (c *compiler) compileVarspec(spec varspec, expr *expression) { + switch { + case expr.named && spec.explode: + split1 := c.op(opSplit) + noop := c.op(opNoop) + c.compileString(spec.name) + + split2 := c.op(opSplit) + c.opWithRune(opRune, '=') + c.compileVarspecValue(spec, expr) + + split3 := c.op(opSplit) + c.compileString(expr.sep) + c.opWithAddr(opJmp, noop) + + c.prog.op[split2].i = uint32(len(c.prog.op)) + c.compileString(expr.ifemp) + c.opWithAddr(opJmp, split3) + + c.prog.op[split1].i = 
uint32(len(c.prog.op)) + c.prog.op[split3].i = uint32(len(c.prog.op)) + + case expr.named && !spec.explode: + c.compileString(spec.name) + + split2 := c.op(opSplit) + c.opWithRune(opRune, '=') + + split3 := c.op(opSplit) + + split4 := c.op(opSplit) + c.compileVarspecValue(spec, expr) + + split5 := c.op(opSplit) + c.prog.op[split4].i = split5 + c.compileString(",") + c.opWithAddr(opJmp, split4) + + c.prog.op[split3].i = uint32(len(c.prog.op)) + c.compileString(",") + jmp1 := c.op(opJmp) + + c.prog.op[split2].i = uint32(len(c.prog.op)) + c.compileString(expr.ifemp) + + c.prog.op[split5].i = uint32(len(c.prog.op)) + c.prog.op[jmp1].i = uint32(len(c.prog.op)) + + case !expr.named: + start := uint32(len(c.prog.op)) + c.compileVarspecValue(spec, expr) + + split1 := c.op(opSplit) + jmp := c.op(opJmp) + + c.prog.op[split1].i = uint32(len(c.prog.op)) + if spec.explode { + c.compileString(expr.sep) + } else { + c.opWithRune(opRune, ',') + } + c.opWithAddr(opJmp, start) + + c.prog.op[jmp].i = uint32(len(c.prog.op)) + } +} + +func (c *compiler) compileExpression(expr *expression) { + if len(expr.vars) < 1 { + return + } + + split1 := c.op(opSplit) + c.compileString(expr.first) + + for i, size := 0, len(expr.vars); i < size; i++ { + spec := expr.vars[i] + + split2 := c.op(opSplit) + if i > 0 { + split3 := c.op(opSplit) + c.compileString(expr.sep) + c.prog.op[split3].i = uint32(len(c.prog.op)) + } + c.compileVarspec(spec, expr) + c.prog.op[split2].i = uint32(len(c.prog.op)) + } + + c.prog.op[split1].i = uint32(len(c.prog.op)) +} + +func (c *compiler) compileLiterals(lt literals) { + c.compileString(string(lt)) +} + +func (c *compiler) compile(tmpl *Template) { + c.op(opLineBegin) + for i := range tmpl.exprs { + expr := tmpl.exprs[i] + switch expr := expr.(type) { + default: + panic("unhandled expression") + case *expression: + c.compileExpression(expr) + case literals: + c.compileLiterals(expr) + } + } + c.op(opLineEnd) + c.op(opEnd) +} diff --git 
a/vendor/github.com/yosida95/uritemplate/v3/equals.go b/vendor/github.com/yosida95/uritemplate/v3/equals.go new file mode 100644 index 000000000..aa59a5c03 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/equals.go @@ -0,0 +1,53 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +type CompareFlags uint8 + +const ( + CompareVarname CompareFlags = 1 << iota +) + +// Equals reports whether or not two URI Templates t1 and t2 are equivalent. +func Equals(t1 *Template, t2 *Template, flags CompareFlags) bool { + if len(t1.exprs) != len(t2.exprs) { + return false + } + for i := 0; i < len(t1.exprs); i++ { + switch t1 := t1.exprs[i].(type) { + case literals: + t2, ok := t2.exprs[i].(literals) + if !ok { + return false + } + if t1 != t2 { + return false + } + case *expression: + t2, ok := t2.exprs[i].(*expression) + if !ok { + return false + } + if t1.op != t2.op || len(t1.vars) != len(t2.vars) { + return false + } + for n := 0; n < len(t1.vars); n++ { + v1 := t1.vars[n] + v2 := t2.vars[n] + if flags&CompareVarname == CompareVarname && v1.name != v2.name { + return false + } + if v1.maxlen != v2.maxlen || v1.explode != v2.explode { + return false + } + } + default: + panic("unhandled case") + } + } + return true +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/error.go b/vendor/github.com/yosida95/uritemplate/v3/error.go new file mode 100644 index 000000000..2fd34a808 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/error.go @@ -0,0 +1,16 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. 
+ +package uritemplate + +import ( + "fmt" +) + +func errorf(pos int, format string, a ...interface{}) error { + msg := fmt.Sprintf(format, a...) + return fmt.Errorf("uritemplate:%d:%s", pos, msg) +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/escape.go b/vendor/github.com/yosida95/uritemplate/v3/escape.go new file mode 100644 index 000000000..6d27e693a --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/escape.go @@ -0,0 +1,190 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +var ( + hex = []byte("0123456789ABCDEF") + // reserved = gen-delims / sub-delims + // gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" + // sub-delims = "!" / "$" / "&" / "’" / "(" / ")" + // / "*" / "+" / "," / ";" / "=" + rangeReserved = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x21, Hi: 0x21, Stride: 1}, // '!' + {Lo: 0x23, Hi: 0x24, Stride: 1}, // '#' - '$' + {Lo: 0x26, Hi: 0x2C, Stride: 1}, // '&' - ',' + {Lo: 0x2F, Hi: 0x2F, Stride: 1}, // '/' + {Lo: 0x3A, Hi: 0x3B, Stride: 1}, // ':' - ';' + {Lo: 0x3D, Hi: 0x3D, Stride: 1}, // '=' + {Lo: 0x3F, Hi: 0x40, Stride: 1}, // '?' - '@' + {Lo: 0x5B, Hi: 0x5B, Stride: 1}, // '[' + {Lo: 0x5D, Hi: 0x5D, Stride: 1}, // ']' + }, + LatinOffset: 9, + } + reReserved = `\x21\x23\x24\x26-\x2c\x2f\x3a\x3b\x3d\x3f\x40\x5b\x5d` + // ALPHA = %x41-5A / %x61-7A + // DIGIT = %x30-39 + // unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + rangeUnreserved = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x2D, Hi: 0x2E, Stride: 1}, // '-' - '.' 
+ {Lo: 0x30, Hi: 0x39, Stride: 1}, // '0' - '9' + {Lo: 0x41, Hi: 0x5A, Stride: 1}, // 'A' - 'Z' + {Lo: 0x5F, Hi: 0x5F, Stride: 1}, // '_' + {Lo: 0x61, Hi: 0x7A, Stride: 1}, // 'a' - 'z' + {Lo: 0x7E, Hi: 0x7E, Stride: 1}, // '~' + }, + } + reUnreserved = `\x2d\x2e\x30-\x39\x41-\x5a\x5f\x61-\x7a\x7e` +) + +type runeClass uint8 + +const ( + runeClassU runeClass = 1 << iota + runeClassR + runeClassPctE + runeClassLast + + runeClassUR = runeClassU | runeClassR +) + +var runeClassNames = []string{ + "U", + "R", + "pct-encoded", +} + +func (rc runeClass) String() string { + ret := make([]string, 0, len(runeClassNames)) + for i, j := 0, runeClass(1); j < runeClassLast; j <<= 1 { + if rc&j == j { + ret = append(ret, runeClassNames[i]) + } + i++ + } + return strings.Join(ret, "+") +} + +func pctEncode(w *strings.Builder, r rune) { + if s := r >> 24 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r >> 16 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r >> 8 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + default: + return false + } +} + +func pctDecode(s string) string { + size := len(s) + for i := 0; i < len(s); { + switch s[i] { + case '%': + size -= 2 + i += 3 + default: + i++ + } + } + if size == len(s) { + return s + } + + buf := make([]byte, size) + j := 0 + for i := 0; i < len(s); { + switch c := s[i]; c { + case '%': + buf[j] = unhex(s[i+1])<<4 | unhex(s[i+2]) + i += 3 + j++ + default: + buf[j] = c + i++ + j++ + } + } + return string(buf) +} + +type 
escapeFunc func(*strings.Builder, string) error + +func escapeLiteral(w *strings.Builder, v string) error { + w.WriteString(v) + return nil +} + +func escapeExceptU(w *strings.Builder, v string) error { + for i := 0; i < len(v); { + r, size := utf8.DecodeRuneInString(v[i:]) + if r == utf8.RuneError { + return errorf(i, "invalid encoding") + } + if unicode.Is(rangeUnreserved, r) { + w.WriteRune(r) + } else { + pctEncode(w, r) + } + i += size + } + return nil +} + +func escapeExceptUR(w *strings.Builder, v string) error { + for i := 0; i < len(v); { + r, size := utf8.DecodeRuneInString(v[i:]) + if r == utf8.RuneError { + return errorf(i, "invalid encoding") + } + // TODO(yosida95): is pct-encoded triplets allowed here? + if unicode.In(r, rangeUnreserved, rangeReserved) { + w.WriteRune(r) + } else { + pctEncode(w, r) + } + i += size + } + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/expression.go b/vendor/github.com/yosida95/uritemplate/v3/expression.go new file mode 100644 index 000000000..4858c2dde --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/expression.go @@ -0,0 +1,173 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. 
+ +package uritemplate + +import ( + "regexp" + "strconv" + "strings" +) + +type template interface { + expand(*strings.Builder, Values) error + regexp(*strings.Builder) +} + +type literals string + +func (l literals) expand(b *strings.Builder, _ Values) error { + b.WriteString(string(l)) + return nil +} + +func (l literals) regexp(b *strings.Builder) { + b.WriteString("(?:") + b.WriteString(regexp.QuoteMeta(string(l))) + b.WriteByte(')') +} + +type varspec struct { + name string + maxlen int + explode bool +} + +type expression struct { + vars []varspec + op parseOp + first string + sep string + named bool + ifemp string + escape escapeFunc + allow runeClass +} + +func (e *expression) init() { + switch e.op { + case parseOpSimple: + e.sep = "," + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpPlus: + e.sep = "," + e.escape = escapeExceptUR + e.allow = runeClassUR + case parseOpCrosshatch: + e.first = "#" + e.sep = "," + e.escape = escapeExceptUR + e.allow = runeClassUR + case parseOpDot: + e.first = "." + e.sep = "." + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpSlash: + e.first = "/" + e.sep = "/" + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpSemicolon: + e.first = ";" + e.sep = ";" + e.named = true + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpQuestion: + e.first = "?" 
+ e.sep = "&" + e.named = true + e.ifemp = "=" + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpAmpersand: + e.first = "&" + e.sep = "&" + e.named = true + e.ifemp = "=" + e.escape = escapeExceptU + e.allow = runeClassU + } +} + +func (e *expression) expand(w *strings.Builder, values Values) error { + first := true + for _, varspec := range e.vars { + value := values.Get(varspec.name) + if !value.Valid() { + continue + } + + if first { + w.WriteString(e.first) + first = false + } else { + w.WriteString(e.sep) + } + + if err := value.expand(w, varspec, e); err != nil { + return err + } + + } + return nil +} + +func (e *expression) regexp(b *strings.Builder) { + if e.first != "" { + b.WriteString("(?:") // $1 + b.WriteString(regexp.QuoteMeta(e.first)) + } + b.WriteByte('(') // $2 + runeClassToRegexp(b, e.allow, e.named || e.vars[0].explode) + if len(e.vars) > 1 || e.vars[0].explode { + max := len(e.vars) - 1 + for i := 0; i < len(e.vars); i++ { + if e.vars[i].explode { + max = -1 + break + } + } + + b.WriteString("(?:") // $3 + b.WriteString(regexp.QuoteMeta(e.sep)) + runeClassToRegexp(b, e.allow, e.named || max < 0) + b.WriteByte(')') // $3 + if max > 0 { + b.WriteString("{0,") + b.WriteString(strconv.Itoa(max)) + b.WriteByte('}') + } else { + b.WriteByte('*') + } + } + b.WriteByte(')') // $2 + if e.first != "" { + b.WriteByte(')') // $1 + } + b.WriteByte('?') +} + +func runeClassToRegexp(b *strings.Builder, class runeClass, named bool) { + b.WriteString("(?:(?:[") + if class&runeClassR == 0 { + b.WriteString(`\x2c`) + if named { + b.WriteString(`\x3d`) + } + } + if class&runeClassU == runeClassU { + b.WriteString(reUnreserved) + } + if class&runeClassR == runeClassR { + b.WriteString(reReserved) + } + b.WriteString("]") + b.WriteString("|%[[:xdigit:]][[:xdigit:]]") + b.WriteString(")*)") +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/machine.go b/vendor/github.com/yosida95/uritemplate/v3/machine.go new file mode 100644 index 
000000000..7b1d0b518 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/machine.go @@ -0,0 +1,23 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +// threadList implements https://research.swtch.com/sparse. +type threadList struct { + dense []threadEntry + sparse []uint32 +} + +type threadEntry struct { + pc uint32 + t *thread +} + +type thread struct { + op *progOp + cap map[string][]int +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/match.go b/vendor/github.com/yosida95/uritemplate/v3/match.go new file mode 100644 index 000000000..02fe6385a --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/match.go @@ -0,0 +1,213 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. 
+ +package uritemplate + +import ( + "bytes" + "unicode" + "unicode/utf8" +) + +type matcher struct { + prog *prog + + list1 threadList + list2 threadList + matched bool + cap map[string][]int + + input string +} + +func (m *matcher) at(pos int) (rune, int, bool) { + if l := len(m.input); pos < l { + c := m.input[pos] + if c < utf8.RuneSelf { + return rune(c), 1, pos+1 < l + } + r, size := utf8.DecodeRuneInString(m.input[pos:]) + return r, size, pos+size < l + } + return -1, 0, false +} + +func (m *matcher) add(list *threadList, pc uint32, pos int, next bool, cap map[string][]int) { + if i := list.sparse[pc]; i < uint32(len(list.dense)) && list.dense[i].pc == pc { + return + } + + n := len(list.dense) + list.dense = list.dense[:n+1] + list.sparse[pc] = uint32(n) + + e := &list.dense[n] + e.pc = pc + e.t = nil + + op := &m.prog.op[pc] + switch op.code { + default: + panic("unhandled opcode") + case opRune, opRuneClass, opEnd: + e.t = &thread{ + op: &m.prog.op[pc], + cap: make(map[string][]int, len(m.cap)), + } + for k, v := range cap { + e.t.cap[k] = make([]int, len(v)) + copy(e.t.cap[k], v) + } + case opLineBegin: + if pos == 0 { + m.add(list, pc+1, pos, next, cap) + } + case opLineEnd: + if !next { + m.add(list, pc+1, pos, next, cap) + } + case opCapStart, opCapEnd: + ocap := make(map[string][]int, len(m.cap)) + for k, v := range cap { + ocap[k] = make([]int, len(v)) + copy(ocap[k], v) + } + ocap[op.name] = append(ocap[op.name], pos) + m.add(list, pc+1, pos, next, ocap) + case opSplit: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmp: + m.add(list, op.i, pos, next, cap) + case opJmpIfNotDefined: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmpIfNotFirst: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmpIfNotEmpty: + m.add(list, op.i, pos, next, cap) + m.add(list, pc+1, pos, next, cap) + case opNoop: + m.add(list, pc+1, pos, next, cap) + } +} + +func (m 
*matcher) step(clist *threadList, nlist *threadList, r rune, pos int, nextPos int, next bool) { + debug.Printf("===== %q =====", string(r)) + for i := 0; i < len(clist.dense); i++ { + e := clist.dense[i] + if debug { + var buf bytes.Buffer + dumpProg(&buf, m.prog, e.pc) + debug.Printf("\n%s", buf.String()) + } + if e.t == nil { + continue + } + + t := e.t + op := t.op + switch op.code { + default: + panic("unhandled opcode") + case opRune: + if op.r == r { + m.add(nlist, e.pc+1, nextPos, next, t.cap) + } + case opRuneClass: + ret := false + if !ret && op.rc&runeClassU == runeClassU { + ret = ret || unicode.Is(rangeUnreserved, r) + } + if !ret && op.rc&runeClassR == runeClassR { + ret = ret || unicode.Is(rangeReserved, r) + } + if !ret && op.rc&runeClassPctE == runeClassPctE { + ret = ret || unicode.Is(unicode.ASCII_Hex_Digit, r) + } + if ret { + m.add(nlist, e.pc+1, nextPos, next, t.cap) + } + case opEnd: + m.matched = true + for k, v := range t.cap { + m.cap[k] = make([]int, len(v)) + copy(m.cap[k], v) + } + clist.dense = clist.dense[:0] + } + } + clist.dense = clist.dense[:0] +} + +func (m *matcher) match() bool { + pos := 0 + clist, nlist := &m.list1, &m.list2 + for { + if len(clist.dense) == 0 && m.matched { + break + } + r, width, next := m.at(pos) + if !m.matched { + m.add(clist, 0, pos, next, m.cap) + } + m.step(clist, nlist, r, pos, pos+width, next) + + if width < 1 { + break + } + pos += width + + clist, nlist = nlist, clist + } + return m.matched +} + +func (tmpl *Template) Match(expansion string) Values { + tmpl.mu.Lock() + if tmpl.prog == nil { + c := compiler{} + c.init() + c.compile(tmpl) + tmpl.prog = c.prog + } + prog := tmpl.prog + tmpl.mu.Unlock() + + n := len(prog.op) + m := matcher{ + prog: prog, + list1: threadList{ + dense: make([]threadEntry, 0, n), + sparse: make([]uint32, n), + }, + list2: threadList{ + dense: make([]threadEntry, 0, n), + sparse: make([]uint32, n), + }, + cap: make(map[string][]int, prog.numCap), + input: expansion, + } + 
if !m.match() { + return nil + } + + match := make(Values, len(m.cap)) + for name, indices := range m.cap { + v := Value{V: make([]string, len(indices)/2)} + for i := range v.V { + v.V[i] = pctDecode(expansion[indices[2*i]:indices[2*i+1]]) + } + if len(v.V) == 1 { + v.T = ValueTypeString + } else { + v.T = ValueTypeList + } + match[name] = v + } + return match +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/parse.go b/vendor/github.com/yosida95/uritemplate/v3/parse.go new file mode 100644 index 000000000..fd38a682f --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/parse.go @@ -0,0 +1,277 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" + "unicode" + "unicode/utf8" +) + +type parseOp int + +const ( + parseOpSimple parseOp = iota + parseOpPlus + parseOpCrosshatch + parseOpDot + parseOpSlash + parseOpSemicolon + parseOpQuestion + parseOpAmpersand +) + +var ( + rangeVarchar = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x0030, Hi: 0x0039, Stride: 1}, // '0' - '9' + {Lo: 0x0041, Hi: 0x005A, Stride: 1}, // 'A' - 'Z' + {Lo: 0x005F, Hi: 0x005F, Stride: 1}, // '_' + {Lo: 0x0061, Hi: 0x007A, Stride: 1}, // 'a' - 'z' + }, + LatinOffset: 4, + } + rangeLiterals = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x0021, Hi: 0x0021, Stride: 1}, // '!' + {Lo: 0x0023, Hi: 0x0024, Stride: 1}, // '#' - '$' + {Lo: 0x0026, Hi: 0x003B, Stride: 1}, // '&' ''' '(' - ';'. '''/27 used to be excluded but an errata is in the review process https://www.rfc-editor.org/errata/eid6937 + {Lo: 0x003D, Hi: 0x003D, Stride: 1}, // '=' + {Lo: 0x003F, Hi: 0x005B, Stride: 1}, // '?' 
- '[' + {Lo: 0x005D, Hi: 0x005D, Stride: 1}, // ']' + {Lo: 0x005F, Hi: 0x005F, Stride: 1}, // '_' + {Lo: 0x0061, Hi: 0x007A, Stride: 1}, // 'a' - 'z' + {Lo: 0x007E, Hi: 0x007E, Stride: 1}, // '~' + {Lo: 0x00A0, Hi: 0xD7FF, Stride: 1}, // ucschar + {Lo: 0xE000, Hi: 0xF8FF, Stride: 1}, // iprivate + {Lo: 0xF900, Hi: 0xFDCF, Stride: 1}, // ucschar + {Lo: 0xFDF0, Hi: 0xFFEF, Stride: 1}, // ucschar + }, + R32: []unicode.Range32{ + {Lo: 0x00010000, Hi: 0x0001FFFD, Stride: 1}, // ucschar + {Lo: 0x00020000, Hi: 0x0002FFFD, Stride: 1}, // ucschar + {Lo: 0x00030000, Hi: 0x0003FFFD, Stride: 1}, // ucschar + {Lo: 0x00040000, Hi: 0x0004FFFD, Stride: 1}, // ucschar + {Lo: 0x00050000, Hi: 0x0005FFFD, Stride: 1}, // ucschar + {Lo: 0x00060000, Hi: 0x0006FFFD, Stride: 1}, // ucschar + {Lo: 0x00070000, Hi: 0x0007FFFD, Stride: 1}, // ucschar + {Lo: 0x00080000, Hi: 0x0008FFFD, Stride: 1}, // ucschar + {Lo: 0x00090000, Hi: 0x0009FFFD, Stride: 1}, // ucschar + {Lo: 0x000A0000, Hi: 0x000AFFFD, Stride: 1}, // ucschar + {Lo: 0x000B0000, Hi: 0x000BFFFD, Stride: 1}, // ucschar + {Lo: 0x000C0000, Hi: 0x000CFFFD, Stride: 1}, // ucschar + {Lo: 0x000D0000, Hi: 0x000DFFFD, Stride: 1}, // ucschar + {Lo: 0x000E1000, Hi: 0x000EFFFD, Stride: 1}, // ucschar + {Lo: 0x000F0000, Hi: 0x000FFFFD, Stride: 1}, // iprivate + {Lo: 0x00100000, Hi: 0x0010FFFD, Stride: 1}, // iprivate + }, + LatinOffset: 10, + } +) + +type parser struct { + r string + start int + stop int + state parseState +} + +func (p *parser) errorf(i rune, format string, a ...interface{}) error { + return fmt.Errorf("%s: %s%s", fmt.Sprintf(format, a...), p.r[0:p.stop], string(i)) +} + +func (p *parser) rune() (rune, int) { + r, size := utf8.DecodeRuneInString(p.r[p.stop:]) + if r != utf8.RuneError { + p.stop += size + } + return r, size +} + +func (p *parser) unread(r rune) { + p.stop -= utf8.RuneLen(r) +} + +type parseState int + +const ( + parseStateDefault = parseState(iota) + parseStateOperator + parseStateVarList + parseStateVarName + 
parseStatePrefix +) + +func (p *parser) setState(state parseState) { + p.state = state + p.start = p.stop +} + +func (p *parser) parseURITemplate() (*Template, error) { + tmpl := Template{ + raw: p.r, + exprs: []template{}, + } + + var exp *expression + for { + r, size := p.rune() + if r == utf8.RuneError { + if size == 0 { + if p.state != parseStateDefault { + return nil, p.errorf('_', "incomplete expression") + } + if p.start < p.stop { + tmpl.exprs = append(tmpl.exprs, literals(p.r[p.start:p.stop])) + } + return &tmpl, nil + } + return nil, p.errorf('_', "invalid UTF-8 sequence") + } + + switch p.state { + case parseStateDefault: + switch r { + case '{': + if stop := p.stop - size; stop > p.start { + tmpl.exprs = append(tmpl.exprs, literals(p.r[p.start:stop])) + } + exp = &expression{} + tmpl.exprs = append(tmpl.exprs, exp) + p.setState(parseStateOperator) + case '%': + p.unread(r) + if err := p.consumeTriplet(); err != nil { + return nil, err + } + default: + if !unicode.Is(rangeLiterals, r) { + p.unread(r) + return nil, p.errorf('_', "unacceptable character (hint: use %%XX encoding)") + } + } + case parseStateOperator: + switch r { + default: + p.unread(r) + exp.op = parseOpSimple + case '+': + exp.op = parseOpPlus + case '#': + exp.op = parseOpCrosshatch + case '.': + exp.op = parseOpDot + case '/': + exp.op = parseOpSlash + case ';': + exp.op = parseOpSemicolon + case '?': + exp.op = parseOpQuestion + case '&': + exp.op = parseOpAmpersand + case '=', ',', '!', '@', '|': // op-reserved + return nil, p.errorf('|', "unimplemented operator (op-reserved)") + } + p.setState(parseStateVarName) + case parseStateVarList: + switch r { + case ',': + p.setState(parseStateVarName) + case '}': + exp.init() + p.setState(parseStateDefault) + default: + p.unread(r) + return nil, p.errorf('_', "unrecognized value modifier") + } + case parseStateVarName: + switch r { + case ':', '*': + name := p.r[p.start : p.stop-size] + if !isValidVarname(name) { + return nil, p.errorf('|', 
"unacceptable variable name") + } + explode := r == '*' + exp.vars = append(exp.vars, varspec{ + name: name, + explode: explode, + }) + if explode { + p.setState(parseStateVarList) + } else { + p.setState(parseStatePrefix) + } + case ',', '}': + p.unread(r) + name := p.r[p.start:p.stop] + if !isValidVarname(name) { + return nil, p.errorf('|', "unacceptable variable name") + } + exp.vars = append(exp.vars, varspec{ + name: name, + }) + p.setState(parseStateVarList) + case '%': + p.unread(r) + if err := p.consumeTriplet(); err != nil { + return nil, err + } + case '.': + if dot := p.stop - size; dot == p.start || p.r[dot-1] == '.' { + return nil, p.errorf('|', "unacceptable variable name") + } + default: + if !unicode.Is(rangeVarchar, r) { + p.unread(r) + return nil, p.errorf('_', "unacceptable variable name") + } + } + case parseStatePrefix: + spec := &(exp.vars[len(exp.vars)-1]) + switch { + case '0' <= r && r <= '9': + spec.maxlen *= 10 + spec.maxlen += int(r - '0') + if spec.maxlen == 0 || spec.maxlen > 9999 { + return nil, p.errorf('|', "max-length must be (0, 9999]") + } + default: + p.unread(r) + if spec.maxlen == 0 { + return nil, p.errorf('_', "max-length must be (0, 9999]") + } + p.setState(parseStateVarList) + } + default: + p.unread(r) + panic(p.errorf('_', "unhandled parseState(%d)", p.state)) + } + } +} + +func isValidVarname(name string) bool { + if l := len(name); l == 0 || name[0] == '.' || name[l-1] == '.' { + return false + } + for i := 1; i < len(name)-1; i++ { + switch c := name[i]; c { + case '.': + if name[i-1] == '.' 
{ + return false + } + } + } + return true +} + +func (p *parser) consumeTriplet() error { + if len(p.r)-p.stop < 3 || p.r[p.stop] != '%' || !ishex(p.r[p.stop+1]) || !ishex(p.r[p.stop+2]) { + return p.errorf('_', "incomplete pct-encodeed") + } + p.stop += 3 + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/prog.go b/vendor/github.com/yosida95/uritemplate/v3/prog.go new file mode 100644 index 000000000..97af4f0ea --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/prog.go @@ -0,0 +1,130 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "bytes" + "strconv" +) + +type progOpcode uint16 + +const ( + // match + opRune progOpcode = iota + opRuneClass + opLineBegin + opLineEnd + // capture + opCapStart + opCapEnd + // stack + opSplit + opJmp + opJmpIfNotDefined + opJmpIfNotEmpty + opJmpIfNotFirst + // result + opEnd + // fake + opNoop + opcodeMax +) + +var opcodeNames = []string{ + // match + "opRune", + "opRuneClass", + "opLineBegin", + "opLineEnd", + // capture + "opCapStart", + "opCapEnd", + // stack + "opSplit", + "opJmp", + "opJmpIfNotDefined", + "opJmpIfNotEmpty", + "opJmpIfNotFirst", + // result + "opEnd", +} + +func (code progOpcode) String() string { + if code >= opcodeMax { + return "" + } + return opcodeNames[code] +} + +type progOp struct { + code progOpcode + r rune + rc runeClass + i uint32 + + name string +} + +func dumpProgOp(b *bytes.Buffer, op *progOp) { + b.WriteString(op.code.String()) + switch op.code { + case opRune: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(string(op.r))) + b.WriteString(")") + case opRuneClass: + b.WriteString("(") + b.WriteString(op.rc.String()) + b.WriteString(")") + case opCapStart, opCapEnd: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(op.name)) + 
b.WriteString(")") + case opSplit: + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + case opJmp, opJmpIfNotFirst: + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + case opJmpIfNotDefined, opJmpIfNotEmpty: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(op.name)) + b.WriteString(")") + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + } +} + +type prog struct { + op []progOp + numCap int +} + +func dumpProg(b *bytes.Buffer, prog *prog, pc uint32) { + for i := range prog.op { + op := prog.op[i] + + pos := strconv.Itoa(i) + if uint32(i) == pc { + pos = "*" + pos + } + b.WriteString(" "[len(pos):]) + b.WriteString(pos) + + b.WriteByte('\t') + dumpProgOp(b, &op) + + b.WriteByte('\n') + } +} + +func (p *prog) String() string { + b := bytes.Buffer{} + dumpProg(&b, p, 0) + return b.String() +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go b/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go new file mode 100644 index 000000000..dbd267375 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go @@ -0,0 +1,116 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "log" + "regexp" + "strings" + "sync" +) + +var ( + debug = debugT(false) +) + +type debugT bool + +func (t debugT) Printf(format string, v ...interface{}) { + if t { + log.Printf(format, v...) + } +} + +// Template represents a URI Template. +type Template struct { + raw string + exprs []template + + // protects the rest of fields + mu sync.Mutex + varnames []string + re *regexp.Regexp + prog *prog +} + +// New parses and constructs a new Template instance based on the template. +// New returns an error if the template cannot be recognized. 
+func New(template string) (*Template, error) { + return (&parser{r: template}).parseURITemplate() +} + +// MustNew panics if the template cannot be recognized. +func MustNew(template string) *Template { + ret, err := New(template) + if err != nil { + panic(err) + } + return ret +} + +// Raw returns a raw URI template passed to New in string. +func (t *Template) Raw() string { + return t.raw +} + +// Varnames returns variable names used in the template. +func (t *Template) Varnames() []string { + t.mu.Lock() + defer t.mu.Unlock() + if t.varnames != nil { + return t.varnames + } + + reg := map[string]struct{}{} + t.varnames = []string{} + for i := range t.exprs { + expr, ok := t.exprs[i].(*expression) + if !ok { + continue + } + for _, spec := range expr.vars { + if _, ok := reg[spec.name]; ok { + continue + } + reg[spec.name] = struct{}{} + t.varnames = append(t.varnames, spec.name) + } + } + + return t.varnames +} + +// Expand returns a URI reference corresponding to the template expanded using the passed variables. +func (t *Template) Expand(vars Values) (string, error) { + var w strings.Builder + for i := range t.exprs { + expr := t.exprs[i] + if err := expr.expand(&w, vars); err != nil { + return w.String(), err + } + } + return w.String(), nil +} + +// Regexp converts the template to regexp and returns compiled *regexp.Regexp. +func (t *Template) Regexp() *regexp.Regexp { + t.mu.Lock() + defer t.mu.Unlock() + if t.re != nil { + return t.re + } + + var b strings.Builder + b.WriteByte('^') + for _, expr := range t.exprs { + expr.regexp(&b) + } + b.WriteByte('$') + t.re = regexp.MustCompile(b.String()) + + return t.re +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/value.go b/vendor/github.com/yosida95/uritemplate/v3/value.go new file mode 100644 index 000000000..0550eabdb --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/value.go @@ -0,0 +1,216 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. 
+// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import "strings" + +// A varname containing pct-encoded characters is not the same variable as +// a varname with those same characters decoded. +// +// -- https://tools.ietf.org/html/rfc6570#section-2.3 +type Values map[string]Value + +func (v Values) Set(name string, value Value) { + v[name] = value +} + +func (v Values) Get(name string) Value { + if v == nil { + return Value{} + } + return v[name] +} + +type ValueType uint8 + +const ( + ValueTypeString = iota + ValueTypeList + ValueTypeKV + valueTypeLast +) + +var valueTypeNames = []string{ + "String", + "List", + "KV", +} + +func (vt ValueType) String() string { + if vt < valueTypeLast { + return valueTypeNames[vt] + } + return "" +} + +type Value struct { + T ValueType + V []string +} + +func (v Value) String() string { + if v.Valid() && v.T == ValueTypeString { + return v.V[0] + } + return "" +} + +func (v Value) List() []string { + if v.Valid() && v.T == ValueTypeList { + return v.V + } + return nil +} + +func (v Value) KV() []string { + if v.Valid() && v.T == ValueTypeKV { + return v.V + } + return nil +} + +func (v Value) Valid() bool { + switch v.T { + default: + return false + case ValueTypeString: + return len(v.V) > 0 + case ValueTypeList: + return len(v.V) > 0 + case ValueTypeKV: + return len(v.V) > 0 && len(v.V)%2 == 0 + } +} + +func (v Value) expand(w *strings.Builder, spec varspec, exp *expression) error { + switch v.T { + case ValueTypeString: + val := v.V[0] + var maxlen int + if max := len(val); spec.maxlen < 1 || spec.maxlen > max { + maxlen = max + } else { + maxlen = spec.maxlen + } + + if exp.named { + w.WriteString(spec.name) + if val == "" { + w.WriteString(exp.ifemp) + return nil + } + w.WriteByte('=') + } + return exp.escape(w, val[:maxlen]) + case ValueTypeList: + var sep string + if 
spec.explode { + sep = exp.sep + } else { + sep = "," + } + + var pre string + var preifemp string + if spec.explode && exp.named { + pre = spec.name + "=" + preifemp = spec.name + exp.ifemp + } + + if !spec.explode && exp.named { + w.WriteString(spec.name) + w.WriteByte('=') + } + for i := range v.V { + val := v.V[i] + if i > 0 { + w.WriteString(sep) + } + if val == "" { + w.WriteString(preifemp) + continue + } + w.WriteString(pre) + + if err := exp.escape(w, val); err != nil { + return err + } + } + case ValueTypeKV: + var sep string + var kvsep string + if spec.explode { + sep = exp.sep + kvsep = "=" + } else { + sep = "," + kvsep = "," + } + + var ifemp string + var kescape escapeFunc + if spec.explode && exp.named { + ifemp = exp.ifemp + kescape = escapeLiteral + } else { + ifemp = "," + kescape = exp.escape + } + + if !spec.explode && exp.named { + w.WriteString(spec.name) + w.WriteByte('=') + } + + for i := 0; i < len(v.V); i += 2 { + if i > 0 { + w.WriteString(sep) + } + if err := kescape(w, v.V[i]); err != nil { + return err + } + if v.V[i+1] == "" { + w.WriteString(ifemp) + continue + } + w.WriteString(kvsep) + + if err := exp.escape(w, v.V[i+1]); err != nil { + return err + } + } + } + return nil +} + +// String returns Value that represents string. +func String(v string) Value { + return Value{ + T: ValueTypeString, + V: []string{v}, + } +} + +// List returns Value that represents list. +func List(v ...string) Value { + return Value{ + T: ValueTypeList, + V: v, + } +} + +// KV returns Value that represents associative list. +// KV panics if len(kv) is not even. 
+func KV(kv ...string) Value { + if len(kv)%2 != 0 { + panic("uritemplate.go: count of the kv must be even number") + } + return Value{ + T: ValueTypeKV, + V: kv, + } +} diff --git a/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go b/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go index 51121a3d5..e86346e8b 100644 --- a/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go +++ b/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go @@ -55,7 +55,7 @@ type Config struct { // Token uses client credentials to retrieve a token. // -// The provided context optionally controls which HTTP client is used. See the oauth2.HTTPClient variable. +// The provided context optionally controls which HTTP client is used. See the [oauth2.HTTPClient] variable. func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { return c.TokenSource(ctx).Token() } @@ -64,18 +64,18 @@ func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { // The token will auto-refresh as necessary. // // The provided context optionally controls which HTTP client -// is returned. See the oauth2.HTTPClient variable. +// is returned. See the [oauth2.HTTPClient] variable. // -// The returned Client and its Transport should not be modified. +// The returned [http.Client] and its Transport should not be modified. func (c *Config) Client(ctx context.Context) *http.Client { return oauth2.NewClient(ctx, c.TokenSource(ctx)) } -// TokenSource returns a TokenSource that returns t until t expires, +// TokenSource returns a [oauth2.TokenSource] that returns t until t expires, // automatically refreshing it as necessary using the provided context and the // client ID and client secret. // -// Most users will use Config.Client instead. +// Most users will use [Config.Client] instead. 
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { source := &tokenSource{ ctx: ctx, diff --git a/vendor/golang.org/x/oauth2/internal/doc.go b/vendor/golang.org/x/oauth2/internal/doc.go index 03265e888..8c7c475f2 100644 --- a/vendor/golang.org/x/oauth2/internal/doc.go +++ b/vendor/golang.org/x/oauth2/internal/doc.go @@ -2,5 +2,5 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package internal contains support packages for oauth2 package. +// Package internal contains support packages for [golang.org/x/oauth2]. package internal diff --git a/vendor/golang.org/x/oauth2/internal/oauth2.go b/vendor/golang.org/x/oauth2/internal/oauth2.go index 14989beaf..71ea6ad1f 100644 --- a/vendor/golang.org/x/oauth2/internal/oauth2.go +++ b/vendor/golang.org/x/oauth2/internal/oauth2.go @@ -13,7 +13,7 @@ import ( ) // ParseKey converts the binary contents of a private key file -// to an *rsa.PrivateKey. It detects whether the private key is in a +// to an [*rsa.PrivateKey]. It detects whether the private key is in a // PEM container or not. If so, it extracts the private key // from PEM container before conversion. It only supports PEM // containers with no passphrase. diff --git a/vendor/golang.org/x/oauth2/internal/token.go b/vendor/golang.org/x/oauth2/internal/token.go index e83ddeef0..8389f2462 100644 --- a/vendor/golang.org/x/oauth2/internal/token.go +++ b/vendor/golang.org/x/oauth2/internal/token.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "math" "mime" "net/http" @@ -26,9 +25,9 @@ import ( // the requests to access protected resources on the OAuth 2.0 // provider's backend. // -// This type is a mirror of oauth2.Token and exists to break +// This type is a mirror of [golang.org/x/oauth2.Token] and exists to break // an otherwise-circular dependency. Other internal packages -// should convert this Token into an oauth2.Token before use. 
+// should convert this Token into an [golang.org/x/oauth2.Token] before use. type Token struct { // AccessToken is the token that authorizes and authenticates // the requests. @@ -50,9 +49,16 @@ type Token struct { // mechanisms for that TokenSource will not be used. Expiry time.Time + // ExpiresIn is the OAuth2 wire format "expires_in" field, + // which specifies how many seconds later the token expires, + // relative to an unknown time base approximately around "now". + // It is the application's responsibility to populate + // `Expiry` from `ExpiresIn` when required. + ExpiresIn int64 `json:"expires_in,omitempty"` + // Raw optionally contains extra metadata from the server // when updating a token. - Raw interface{} + Raw any } // tokenJSON is the struct representing the HTTP response from OAuth2 @@ -99,14 +105,6 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error { return nil } -// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. -// -// Deprecated: this function no longer does anything. Caller code that -// wants to avoid potential extra HTTP requests made during -// auto-probing of the provider's auth style should set -// Endpoint.AuthStyle. -func RegisterBrokenAuthHeaderProvider(tokenURL string) {} - // AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type. type AuthStyle int @@ -143,6 +141,11 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache { return c } +type authStyleCacheKey struct { + url string + clientID string +} + // AuthStyleCache is the set of tokenURLs we've successfully used via // RetrieveToken and which style auth we ended up using. // It's called a cache, but it doesn't (yet?) shrink. It's expected that @@ -150,26 +153,26 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache { // small. 
type AuthStyleCache struct { mu sync.Mutex - m map[string]AuthStyle // keyed by tokenURL + m map[authStyleCacheKey]AuthStyle } // lookupAuthStyle reports which auth style we last used with tokenURL // when calling RetrieveToken and whether we have ever done so. -func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { +func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) { c.mu.Lock() defer c.mu.Unlock() - style, ok = c.m[tokenURL] + style, ok = c.m[authStyleCacheKey{tokenURL, clientID}] return } // setAuthStyle adds an entry to authStyleCache, documented above. -func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) { +func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) { c.mu.Lock() defer c.mu.Unlock() if c.m == nil { - c.m = make(map[string]AuthStyle) + c.m = make(map[authStyleCacheKey]AuthStyle) } - c.m[tokenURL] = v + c.m[authStyleCacheKey{tokenURL, clientID}] = v } // newTokenRequest returns a new *http.Request to retrieve a new token @@ -210,9 +213,9 @@ func cloneURLValues(v url.Values) url.Values { } func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) { - needsAuthStyleProbe := authStyle == 0 + needsAuthStyleProbe := authStyle == AuthStyleUnknown if needsAuthStyleProbe { - if style, ok := styleCache.lookupAuthStyle(tokenURL); ok { + if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok { authStyle = style needsAuthStyleProbe = false } else { @@ -242,7 +245,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, token, err = doTokenRoundTrip(ctx, req) } if needsAuthStyleProbe && err == nil { - styleCache.setAuthStyle(tokenURL, authStyle) + styleCache.setAuthStyle(tokenURL, clientID, authStyle) } // Don't overwrite `RefreshToken` with an empty value // if this was a token refreshing request. 
@@ -257,7 +260,7 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { if err != nil { return nil, err } - body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) r.Body.Close() if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) @@ -312,7 +315,8 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { TokenType: tj.TokenType, RefreshToken: tj.RefreshToken, Expiry: tj.expiry(), - Raw: make(map[string]interface{}), + ExpiresIn: int64(tj.ExpiresIn), + Raw: make(map[string]any), } json.Unmarshal(body, &token.Raw) // no error checks for optional fields } diff --git a/vendor/golang.org/x/oauth2/internal/transport.go b/vendor/golang.org/x/oauth2/internal/transport.go index b9db01ddf..afc0aeb27 100644 --- a/vendor/golang.org/x/oauth2/internal/transport.go +++ b/vendor/golang.org/x/oauth2/internal/transport.go @@ -9,8 +9,8 @@ import ( "net/http" ) -// HTTPClient is the context key to use with golang.org/x/net/context's -// WithValue function to associate an *http.Client value with a context. +// HTTPClient is the context key to use with [context.WithValue] +// to associate an [*http.Client] value with a context. var HTTPClient ContextKey // ContextKey is just an empty struct. It exists so HTTPClient can be diff --git a/vendor/golang.org/x/oauth2/jws/jws.go b/vendor/golang.org/x/oauth2/jws/jws.go index 95015648b..9bc484406 100644 --- a/vendor/golang.org/x/oauth2/jws/jws.go +++ b/vendor/golang.org/x/oauth2/jws/jws.go @@ -4,7 +4,7 @@ // Package jws provides a partial implementation // of JSON Web Signature encoding and decoding. -// It exists to support the golang.org/x/oauth2 package. +// It exists to support the [golang.org/x/oauth2] package. // // See RFC 7515. 
// @@ -48,7 +48,7 @@ type ClaimSet struct { // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 // This array is marshalled using custom code (see (c *ClaimSet) encode()). - PrivateClaims map[string]interface{} `json:"-"` + PrivateClaims map[string]any `json:"-"` } func (c *ClaimSet) encode() (string, error) { @@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) { // Decode decodes a claim set from a JWS payload. func Decode(payload string) (*ClaimSet, error) { // decode returned id token to get expiry - s := strings.Split(payload, ".") - if len(s) < 2 { + _, claims, _, ok := parseToken(payload) + if !ok { // TODO(jbd): Provide more context about the error. return nil, errors.New("jws: invalid token received") } - decoded, err := base64.RawURLEncoding.DecodeString(s[1]) + decoded, err := base64.RawURLEncoding.DecodeString(claims) if err != nil { return nil, err } @@ -152,7 +152,7 @@ func EncodeWithSigner(header *Header, c *ClaimSet, sg Signer) (string, error) { } // Encode encodes a signed JWS with provided header and claim set. -// This invokes EncodeWithSigner using crypto/rsa.SignPKCS1v15 with the given RSA private key. +// This invokes [EncodeWithSigner] using [crypto/rsa.SignPKCS1v15] with the given RSA private key. func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) { sg := func(data []byte) (sig []byte, err error) { h := sha256.New() @@ -165,18 +165,34 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) { // Verify tests whether the provided JWT token's signature was produced by the private key // associated with the supplied public key. func Verify(token string, key *rsa.PublicKey) error { - parts := strings.Split(token, ".") - if len(parts) != 3 { + header, claims, sig, ok := parseToken(token) + if !ok { return errors.New("jws: invalid token received, token must have 3 parts") } - - signedContent := parts[0] + "." 
+ parts[1] - signatureString, err := base64.RawURLEncoding.DecodeString(parts[2]) + signatureString, err := base64.RawURLEncoding.DecodeString(sig) if err != nil { return err } h := sha256.New() - h.Write([]byte(signedContent)) + h.Write([]byte(header + tokenDelim + claims)) return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString) } + +func parseToken(s string) (header, claims, sig string, ok bool) { + header, s, ok = strings.Cut(s, tokenDelim) + if !ok { // no period found + return "", "", "", false + } + claims, s, ok = strings.Cut(s, tokenDelim) + if !ok { // only one period found + return "", "", "", false + } + sig, _, ok = strings.Cut(s, tokenDelim) + if ok { // three periods found + return "", "", "", false + } + return header, claims, sig, true +} + +const tokenDelim = "." diff --git a/vendor/golang.org/x/oauth2/jwt/jwt.go b/vendor/golang.org/x/oauth2/jwt/jwt.go index b2bf18298..38a92daca 100644 --- a/vendor/golang.org/x/oauth2/jwt/jwt.go +++ b/vendor/golang.org/x/oauth2/jwt/jwt.go @@ -13,7 +13,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/url" "strings" @@ -69,7 +68,7 @@ type Config struct { // PrivateClaims optionally specifies custom private claims in the JWT. // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 - PrivateClaims map[string]interface{} + PrivateClaims map[string]any // UseIDToken optionally specifies whether ID token should be used instead // of access token when the server returns both. @@ -136,7 +135,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } defer resp.Body.Close() - body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } @@ -148,10 +147,8 @@ func (js jwtSource) Token() (*oauth2.Token, error) { } // tokenRes is the JSON response body. 
var tokenRes struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - IDToken string `json:"id_token"` - ExpiresIn int64 `json:"expires_in"` // relative seconds from now + oauth2.Token + IDToken string `json:"id_token"` } if err := json.Unmarshal(body, &tokenRes); err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) @@ -160,7 +157,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) { AccessToken: tokenRes.AccessToken, TokenType: tokenRes.TokenType, } - raw := make(map[string]interface{}) + raw := make(map[string]any) json.Unmarshal(body, &raw) // no error checks for optional fields token = token.WithExtra(raw) diff --git a/vendor/golang.org/x/oauth2/oauth2.go b/vendor/golang.org/x/oauth2/oauth2.go index 74f052aa9..de34feb84 100644 --- a/vendor/golang.org/x/oauth2/oauth2.go +++ b/vendor/golang.org/x/oauth2/oauth2.go @@ -22,9 +22,9 @@ import ( ) // NoContext is the default context you should supply if not using -// your own context.Context (see https://golang.org/x/net/context). +// your own [context.Context]. // -// Deprecated: Use context.Background() or context.TODO() instead. +// Deprecated: Use [context.Background] or [context.TODO] instead. var NoContext = context.TODO() // RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. @@ -37,8 +37,8 @@ func RegisterBrokenAuthHeaderProvider(tokenURL string) {} // Config describes a typical 3-legged OAuth2 flow, with both the // client application information and the server's endpoint URLs. -// For the client credentials 2-legged OAuth2 flow, see the clientcredentials -// package (https://golang.org/x/oauth2/clientcredentials). +// For the client credentials 2-legged OAuth2 flow, see the +// [golang.org/x/oauth2/clientcredentials] package. type Config struct { // ClientID is the application's ID. ClientID string @@ -46,7 +46,7 @@ type Config struct { // ClientSecret is the application's secret. 
ClientSecret string - // Endpoint contains the resource server's token endpoint + // Endpoint contains the authorization server's token endpoint // URLs. These are constants specific to each server and are // often available via site-specific packages, such as // google.Endpoint or github.Endpoint. @@ -135,7 +135,7 @@ type setParam struct{ k, v string } func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) } -// SetAuthURLParam builds an AuthCodeOption which passes key/value parameters +// SetAuthURLParam builds an [AuthCodeOption] which passes key/value parameters // to a provider's authorization endpoint. func SetAuthURLParam(key, value string) AuthCodeOption { return setParam{key, value} @@ -148,8 +148,8 @@ func SetAuthURLParam(key, value string) AuthCodeOption { // request and callback. The authorization server includes this value when // redirecting the user agent back to the client. // -// Opts may include AccessTypeOnline or AccessTypeOffline, as well -// as ApprovalForce. +// Opts may include [AccessTypeOnline] or [AccessTypeOffline], as well +// as [ApprovalForce]. // // To protect against CSRF attacks, opts should include a PKCE challenge // (S256ChallengeOption). Not all servers support PKCE. An alternative is to @@ -194,7 +194,7 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { // and when other authorization grant types are not available." // See https://tools.ietf.org/html/rfc6749#section-4.3 for more info. // -// The provided context optionally controls which HTTP client is used. See the HTTPClient variable. +// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable. 
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) { v := url.Values{ "grant_type": {"password"}, @@ -212,10 +212,10 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor // It is used after a resource provider redirects the user back // to the Redirect URI (the URL obtained from AuthCodeURL). // -// The provided context optionally controls which HTTP client is used. See the HTTPClient variable. +// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable. // -// The code will be in the *http.Request.FormValue("code"). Before -// calling Exchange, be sure to validate FormValue("state") if you are +// The code will be in the [http.Request.FormValue]("code"). Before +// calling Exchange, be sure to validate [http.Request.FormValue]("state") if you are // using it to protect against CSRF attacks. // // If using PKCE to protect against CSRF attacks, opts should include a @@ -242,10 +242,10 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client { return NewClient(ctx, c.TokenSource(ctx, t)) } -// TokenSource returns a TokenSource that returns t until t expires, +// TokenSource returns a [TokenSource] that returns t until t expires, // automatically refreshing it as necessary using the provided context. // -// Most users will use Config.Client instead. +// Most users will use [Config.Client] instead. func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { tkr := &tokenRefresher{ ctx: ctx, @@ -260,7 +260,7 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { } } -// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token" +// tokenRefresher is a TokenSource that makes "grant_type=refresh_token" // HTTP requests to renew a token using a RefreshToken. 
type tokenRefresher struct { ctx context.Context // used to get HTTP requests @@ -288,7 +288,7 @@ func (tf *tokenRefresher) Token() (*Token, error) { if tf.refreshToken != tk.RefreshToken { tf.refreshToken = tk.RefreshToken } - return tk, err + return tk, nil } // reuseTokenSource is a TokenSource that holds a single token in memory @@ -305,8 +305,7 @@ type reuseTokenSource struct { } // Token returns the current token if it's still valid, else will -// refresh the current token (using r.Context for HTTP client -// information) and return the new one. +// refresh the current token and return the new one. func (s *reuseTokenSource) Token() (*Token, error) { s.mu.Lock() defer s.mu.Unlock() @@ -322,7 +321,7 @@ func (s *reuseTokenSource) Token() (*Token, error) { return t, nil } -// StaticTokenSource returns a TokenSource that always returns the same token. +// StaticTokenSource returns a [TokenSource] that always returns the same token. // Because the provided token t is never refreshed, StaticTokenSource is only // useful for tokens that never expire. func StaticTokenSource(t *Token) TokenSource { @@ -338,16 +337,16 @@ func (s staticTokenSource) Token() (*Token, error) { return s.t, nil } -// HTTPClient is the context key to use with golang.org/x/net/context's -// WithValue function to associate an *http.Client value with a context. +// HTTPClient is the context key to use with [context.WithValue] +// to associate a [*http.Client] value with a context. var HTTPClient internal.ContextKey -// NewClient creates an *http.Client from a Context and TokenSource. +// NewClient creates an [*http.Client] from a [context.Context] and [TokenSource]. // The returned client is not valid beyond the lifetime of the context. 
// -// Note that if a custom *http.Client is provided via the Context it +// Note that if a custom [*http.Client] is provided via the [context.Context] it // is used only for token acquisition and is not used to configure the -// *http.Client returned from NewClient. +// [*http.Client] returned from NewClient. // // As a special case, if src is nil, a non-OAuth2 client is returned // using the provided context. This exists to support related OAuth2 @@ -356,15 +355,19 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client { if src == nil { return internal.ContextClient(ctx) } + cc := internal.ContextClient(ctx) return &http.Client{ Transport: &Transport{ - Base: internal.ContextClient(ctx).Transport, + Base: cc.Transport, Source: ReuseTokenSource(nil, src), }, + CheckRedirect: cc.CheckRedirect, + Jar: cc.Jar, + Timeout: cc.Timeout, } } -// ReuseTokenSource returns a TokenSource which repeatedly returns the +// ReuseTokenSource returns a [TokenSource] which repeatedly returns the // same token as long as it's valid, starting with t. // When its cached token is invalid, a new token is obtained from src. // @@ -372,10 +375,10 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client { // (such as a file on disk) between runs of a program, rather than // obtaining new tokens unnecessarily. // -// The initial token t may be nil, in which case the TokenSource is +// The initial token t may be nil, in which case the [TokenSource] is // wrapped in a caching version if it isn't one already. This also // means it's always safe to wrap ReuseTokenSource around any other -// TokenSource without adverse effects. +// [TokenSource] without adverse effects. func ReuseTokenSource(t *Token, src TokenSource) TokenSource { // Don't wrap a reuseTokenSource in itself. That would work, // but cause an unnecessary number of mutex operations. 
@@ -393,8 +396,8 @@ func ReuseTokenSource(t *Token, src TokenSource) TokenSource { } } -// ReuseTokenSourceWithExpiry returns a TokenSource that acts in the same manner as the -// TokenSource returned by ReuseTokenSource, except the expiry buffer is +// ReuseTokenSourceWithExpiry returns a [TokenSource] that acts in the same manner as the +// [TokenSource] returned by [ReuseTokenSource], except the expiry buffer is // configurable. The expiration time of a token is calculated as // t.Expiry.Add(-earlyExpiry). func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource { diff --git a/vendor/golang.org/x/oauth2/pkce.go b/vendor/golang.org/x/oauth2/pkce.go index 50593b6df..cea8374d5 100644 --- a/vendor/golang.org/x/oauth2/pkce.go +++ b/vendor/golang.org/x/oauth2/pkce.go @@ -1,6 +1,7 @@ // Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. + package oauth2 import ( @@ -20,9 +21,9 @@ const ( // This follows recommendations in RFC 7636. // // A fresh verifier should be generated for each authorization. -// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL -// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange -// (or Config.DeviceAccessToken). +// The resulting verifier should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth] +// with [S256ChallengeOption], and to [Config.Exchange] or [Config.DeviceAccessToken] +// with [VerifierOption]. func GenerateVerifier() string { // "RECOMMENDED that the output of a suitable random number generator be // used to create a 32-octet sequence. The octet sequence is then @@ -36,22 +37,22 @@ func GenerateVerifier() string { return base64.RawURLEncoding.EncodeToString(data) } -// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be -// passed to Config.Exchange or Config.DeviceAccessToken only. 
+// VerifierOption returns a PKCE code verifier [AuthCodeOption]. It should only be +// passed to [Config.Exchange] or [Config.DeviceAccessToken]. func VerifierOption(verifier string) AuthCodeOption { return setParam{k: codeVerifierKey, v: verifier} } // S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256. // -// Prefer to use S256ChallengeOption where possible. +// Prefer to use [S256ChallengeOption] where possible. func S256ChallengeFromVerifier(verifier string) string { sha := sha256.Sum256([]byte(verifier)) return base64.RawURLEncoding.EncodeToString(sha[:]) } // S256ChallengeOption derives a PKCE code challenge derived from verifier with -// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess +// method S256. It should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth] // only. func S256ChallengeOption(verifier string) AuthCodeOption { return challengeOption{ diff --git a/vendor/golang.org/x/oauth2/token.go b/vendor/golang.org/x/oauth2/token.go index 109997d77..239ec3296 100644 --- a/vendor/golang.org/x/oauth2/token.go +++ b/vendor/golang.org/x/oauth2/token.go @@ -44,7 +44,7 @@ type Token struct { // Expiry is the optional expiration time of the access token. // - // If zero, TokenSource implementations will reuse the same + // If zero, [TokenSource] implementations will reuse the same // token forever and RefreshToken or equivalent // mechanisms for that TokenSource will not be used. Expiry time.Time `json:"expiry,omitempty"` @@ -58,7 +58,7 @@ type Token struct { // raw optionally contains extra metadata from the server // when updating a token. - raw interface{} + raw any // expiryDelta is used to calculate when a token is considered // expired, by subtracting from Expiry. If zero, defaultExpiryDelta @@ -86,16 +86,16 @@ func (t *Token) Type() string { // SetAuthHeader sets the Authorization header to r using the access // token in t. 
// -// This method is unnecessary when using Transport or an HTTP Client +// This method is unnecessary when using [Transport] or an HTTP Client // returned by this package. func (t *Token) SetAuthHeader(r *http.Request) { r.Header.Set("Authorization", t.Type()+" "+t.AccessToken) } -// WithExtra returns a new Token that's a clone of t, but using the +// WithExtra returns a new [Token] that's a clone of t, but using the // provided raw extra map. This is only intended for use by packages // implementing derivative OAuth2 flows. -func (t *Token) WithExtra(extra interface{}) *Token { +func (t *Token) WithExtra(extra any) *Token { t2 := new(Token) *t2 = *t t2.raw = extra @@ -105,8 +105,8 @@ func (t *Token) WithExtra(extra interface{}) *Token { // Extra returns an extra field. // Extra fields are key-value pairs returned by the server as a // part of the token retrieval response. -func (t *Token) Extra(key string) interface{} { - if raw, ok := t.raw.(map[string]interface{}); ok { +func (t *Token) Extra(key string) any { + if raw, ok := t.raw.(map[string]any); ok { return raw[key] } @@ -163,13 +163,14 @@ func tokenFromInternal(t *internal.Token) *Token { TokenType: t.TokenType, RefreshToken: t.RefreshToken, Expiry: t.Expiry, + ExpiresIn: t.ExpiresIn, raw: t.Raw, } } // retrieveToken takes a *Config and uses that to retrieve an *internal.Token. // This token is then mapped from *internal.Token into an *oauth2.Token which is returned along -// with an error.. +// with an error. 
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()) if err != nil { diff --git a/vendor/golang.org/x/oauth2/transport.go b/vendor/golang.org/x/oauth2/transport.go index 90657915f..8bbebbac9 100644 --- a/vendor/golang.org/x/oauth2/transport.go +++ b/vendor/golang.org/x/oauth2/transport.go @@ -11,12 +11,12 @@ import ( "sync" ) -// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests, -// wrapping a base RoundTripper and adding an Authorization header -// with a token from the supplied Sources. +// Transport is an [http.RoundTripper] that makes OAuth 2.0 HTTP requests, +// wrapping a base [http.RoundTripper] and adding an Authorization header +// with a token from the supplied [TokenSource]. // // Transport is a low-level mechanism. Most code will use the -// higher-level Config.Client method instead. +// higher-level [Config.Client] method instead. type Transport struct { // Source supplies the token to add to outgoing requests' // Authorization headers. @@ -47,7 +47,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } - req2 := cloneRequest(req) // per RoundTripper contract + req2 := req.Clone(req.Context()) token.SetAuthHeader(req2) // req.Body is assumed to be closed by the base RoundTripper. @@ -73,17 +73,3 @@ func (t *Transport) base() http.RoundTripper { } return http.DefaultTransport } - -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map. -func cloneRequest(r *http.Request) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - // deep copy of the Header - r2.Header = make(http.Header, len(r.Header)) - for k, s := range r.Header { - r2.Header[k] = append([]string(nil), s...) 
- } - return r2 -} diff --git a/vendor/modules.txt b/vendor/modules.txt index cb35bbe31..0068ffeec 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -233,6 +233,9 @@ github.com/google/go-cmp/cmp/internal/diff github.com/google/go-cmp/cmp/internal/flags github.com/google/go-cmp/cmp/internal/function github.com/google/go-cmp/cmp/internal/value +# github.com/google/jsonschema-go v0.3.0 +## explicit; go 1.23.0 +github.com/google/jsonschema-go/jsonschema # github.com/google/uuid v1.6.0 ## explicit github.com/google/uuid @@ -307,6 +310,15 @@ github.com/maypok86/otter/v2/stats # github.com/mitchellh/mapstructure v1.5.0 ## explicit; go 1.14 github.com/mitchellh/mapstructure +# github.com/modelcontextprotocol/go-sdk v1.2.0 +## explicit; go 1.23.0 +github.com/modelcontextprotocol/go-sdk/auth +github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2 +github.com/modelcontextprotocol/go-sdk/internal/util +github.com/modelcontextprotocol/go-sdk/internal/xcontext +github.com/modelcontextprotocol/go-sdk/jsonrpc +github.com/modelcontextprotocol/go-sdk/mcp +github.com/modelcontextprotocol/go-sdk/oauthex # github.com/ncruces/go-strftime v0.1.9 ## explicit; go 1.17 github.com/ncruces/go-strftime @@ -401,6 +413,9 @@ github.com/tklauser/go-sysconf # github.com/tklauser/numcpus v0.11.0 ## explicit; go 1.24.0 github.com/tklauser/numcpus +# github.com/yosida95/uritemplate/v3 v3.0.2 +## explicit; go 1.14 +github.com/yosida95/uritemplate/v3 # github.com/yusufpapurcu/wmi v1.2.4 ## explicit; go 1.16 github.com/yusufpapurcu/wmi @@ -542,8 +557,8 @@ golang.org/x/net/idna golang.org/x/net/internal/httpcommon golang.org/x/net/internal/timeseries golang.org/x/net/trace -# golang.org/x/oauth2 v0.26.0 -## explicit; go 1.18 +# golang.org/x/oauth2 v0.30.0 +## explicit; go 1.23.0 golang.org/x/oauth2 golang.org/x/oauth2/clientcredentials golang.org/x/oauth2/internal