From 2c91ced0ff834a0160b44f5ce783bb25ffd4ac13 Mon Sep 17 00:00:00 2001 From: Russell Haering Date: Fri, 9 Jan 2026 16:04:08 -0800 Subject: [PATCH 1/2] Add MCP server support for connectors Add Model Context Protocol (MCP) server support to the baton-sdk, allowing connectors to expose their functionality via MCP as an alternative to gRPC. New pkg/mcp package provides: - MCPServer that wraps ConnectorServer and exposes tools via stdio transport - Dynamic tool registration based on connector capabilities - Handlers for read operations (list_resource_types, list_resources, etc.) - Handlers for provisioning (grant, revoke, create_resource, delete_resource) - Handlers for ticketing (list_ticket_schemas, create_ticket, get_ticket) Usage: connectors now support an `mcp` subcommand that starts the MCP server on stdio, enabling AI assistants to interact with the connector. --- go.mod | 7 + go.sum | 15 + pkg/cli/commands.go | 69 + pkg/config/config.go | 11 + pkg/mcp/convert.go | 74 + pkg/mcp/handlers.go | 605 +++++++ pkg/mcp/server.go | 338 ++++ .../github.com/bahlo/generic-list-go/LICENSE | 27 + .../bahlo/generic-list-go/README.md | 5 + .../github.com/bahlo/generic-list-go/list.go | 235 +++ vendor/github.com/buger/jsonparser/.gitignore | 12 + .../github.com/buger/jsonparser/.travis.yml | 11 + vendor/github.com/buger/jsonparser/Dockerfile | 12 + vendor/github.com/buger/jsonparser/LICENSE | 21 + vendor/github.com/buger/jsonparser/Makefile | 36 + vendor/github.com/buger/jsonparser/README.md | 365 +++++ vendor/github.com/buger/jsonparser/bytes.go | 47 + .../github.com/buger/jsonparser/bytes_safe.go | 25 + .../buger/jsonparser/bytes_unsafe.go | 44 + vendor/github.com/buger/jsonparser/escape.go | 173 ++ vendor/github.com/buger/jsonparser/fuzz.go | 117 ++ .../buger/jsonparser/oss-fuzz-build.sh | 47 + vendor/github.com/buger/jsonparser/parser.go | 1283 +++++++++++++++ .../github.com/invopop/jsonschema/.gitignore | 2 + .../invopop/jsonschema/.golangci.yml | 69 + 
vendor/github.com/invopop/jsonschema/COPYING | 19 + .../github.com/invopop/jsonschema/README.md | 374 +++++ vendor/github.com/invopop/jsonschema/id.go | 76 + .../github.com/invopop/jsonschema/reflect.go | 1148 +++++++++++++ .../invopop/jsonschema/reflect_comments.go | 146 ++ .../github.com/invopop/jsonschema/schema.go | 94 ++ vendor/github.com/invopop/jsonschema/utils.go | 26 + vendor/github.com/mailru/easyjson/LICENSE | 7 + .../github.com/mailru/easyjson/buffer/pool.go | 278 ++++ .../mailru/easyjson/jwriter/writer.go | 405 +++++ vendor/github.com/mark3labs/mcp-go/LICENSE | 21 + .../github.com/mark3labs/mcp-go/mcp/consts.go | 9 + .../github.com/mark3labs/mcp-go/mcp/errors.go | 85 + .../mark3labs/mcp-go/mcp/prompts.go | 176 ++ .../mark3labs/mcp-go/mcp/resources.go | 99 ++ .../github.com/mark3labs/mcp-go/mcp/tools.go | 1331 +++++++++++++++ .../mark3labs/mcp-go/mcp/typed_tools.go | 42 + .../github.com/mark3labs/mcp-go/mcp/types.go | 1252 ++++++++++++++ .../github.com/mark3labs/mcp-go/mcp/utils.go | 979 +++++++++++ .../mark3labs/mcp-go/server/constants.go | 7 + .../github.com/mark3labs/mcp-go/server/ctx.go | 8 + .../mark3labs/mcp-go/server/elicitation.go | 32 + .../mark3labs/mcp-go/server/errors.go | 36 + .../mark3labs/mcp-go/server/hooks.go | 532 ++++++ .../mcp-go/server/http_transport_options.go | 11 + .../mcp-go/server/inprocess_session.go | 165 ++ .../mcp-go/server/request_handler.go | 339 ++++ .../mark3labs/mcp-go/server/roots.go | 32 + .../mark3labs/mcp-go/server/sampling.go | 61 + .../mark3labs/mcp-go/server/server.go | 1337 +++++++++++++++ .../mark3labs/mcp-go/server/session.go | 770 +++++++++ .../github.com/mark3labs/mcp-go/server/sse.go | 797 +++++++++ .../mark3labs/mcp-go/server/stdio.go | 877 ++++++++++ .../mcp-go/server/streamable_http.go | 1434 +++++++++++++++++ .../mark3labs/mcp-go/util/logger.go | 33 + .../wk8/go-ordered-map/v2/.gitignore | 1 + .../wk8/go-ordered-map/v2/.golangci.yml | 80 + .../wk8/go-ordered-map/v2/CHANGELOG.md | 38 + 
.../github.com/wk8/go-ordered-map/v2/LICENSE | 201 +++ .../github.com/wk8/go-ordered-map/v2/Makefile | 32 + .../wk8/go-ordered-map/v2/README.md | 154 ++ .../github.com/wk8/go-ordered-map/v2/json.go | 182 +++ .../wk8/go-ordered-map/v2/orderedmap.go | 296 ++++ .../github.com/wk8/go-ordered-map/v2/yaml.go | 71 + .../yosida95/uritemplate/v3/LICENSE | 25 + .../yosida95/uritemplate/v3/README.rst | 46 + .../yosida95/uritemplate/v3/compile.go | 224 +++ .../yosida95/uritemplate/v3/equals.go | 53 + .../yosida95/uritemplate/v3/error.go | 16 + .../yosida95/uritemplate/v3/escape.go | 190 +++ .../yosida95/uritemplate/v3/expression.go | 173 ++ .../yosida95/uritemplate/v3/machine.go | 23 + .../yosida95/uritemplate/v3/match.go | 213 +++ .../yosida95/uritemplate/v3/parse.go | 277 ++++ .../yosida95/uritemplate/v3/prog.go | 130 ++ .../yosida95/uritemplate/v3/uritemplate.go | 116 ++ .../yosida95/uritemplate/v3/value.go | 216 +++ vendor/modules.txt | 24 + 83 files changed, 19469 insertions(+) create mode 100644 pkg/mcp/convert.go create mode 100644 pkg/mcp/handlers.go create mode 100644 pkg/mcp/server.go create mode 100644 vendor/github.com/bahlo/generic-list-go/LICENSE create mode 100644 vendor/github.com/bahlo/generic-list-go/README.md create mode 100644 vendor/github.com/bahlo/generic-list-go/list.go create mode 100644 vendor/github.com/buger/jsonparser/.gitignore create mode 100644 vendor/github.com/buger/jsonparser/.travis.yml create mode 100644 vendor/github.com/buger/jsonparser/Dockerfile create mode 100644 vendor/github.com/buger/jsonparser/LICENSE create mode 100644 vendor/github.com/buger/jsonparser/Makefile create mode 100644 vendor/github.com/buger/jsonparser/README.md create mode 100644 vendor/github.com/buger/jsonparser/bytes.go create mode 100644 vendor/github.com/buger/jsonparser/bytes_safe.go create mode 100644 vendor/github.com/buger/jsonparser/bytes_unsafe.go create mode 100644 vendor/github.com/buger/jsonparser/escape.go create mode 100644 
vendor/github.com/buger/jsonparser/fuzz.go create mode 100644 vendor/github.com/buger/jsonparser/oss-fuzz-build.sh create mode 100644 vendor/github.com/buger/jsonparser/parser.go create mode 100644 vendor/github.com/invopop/jsonschema/.gitignore create mode 100644 vendor/github.com/invopop/jsonschema/.golangci.yml create mode 100644 vendor/github.com/invopop/jsonschema/COPYING create mode 100644 vendor/github.com/invopop/jsonschema/README.md create mode 100644 vendor/github.com/invopop/jsonschema/id.go create mode 100644 vendor/github.com/invopop/jsonschema/reflect.go create mode 100644 vendor/github.com/invopop/jsonschema/reflect_comments.go create mode 100644 vendor/github.com/invopop/jsonschema/schema.go create mode 100644 vendor/github.com/invopop/jsonschema/utils.go create mode 100644 vendor/github.com/mailru/easyjson/LICENSE create mode 100644 vendor/github.com/mailru/easyjson/buffer/pool.go create mode 100644 vendor/github.com/mailru/easyjson/jwriter/writer.go create mode 100644 vendor/github.com/mark3labs/mcp-go/LICENSE create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/consts.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/errors.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/prompts.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/resources.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/tools.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/types.go create mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/utils.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/constants.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/ctx.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/elicitation.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/errors.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/hooks.go create mode 100644 
vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/request_handler.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/roots.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/sampling.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/server.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/session.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/sse.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/stdio.go create mode 100644 vendor/github.com/mark3labs/mcp-go/server/streamable_http.go create mode 100644 vendor/github.com/mark3labs/mcp-go/util/logger.go create mode 100644 vendor/github.com/wk8/go-ordered-map/v2/.gitignore create mode 100644 vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml create mode 100644 vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md create mode 100644 vendor/github.com/wk8/go-ordered-map/v2/LICENSE create mode 100644 vendor/github.com/wk8/go-ordered-map/v2/Makefile create mode 100644 vendor/github.com/wk8/go-ordered-map/v2/README.md create mode 100644 vendor/github.com/wk8/go-ordered-map/v2/json.go create mode 100644 vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go create mode 100644 vendor/github.com/wk8/go-ordered-map/v2/yaml.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/LICENSE create mode 100644 vendor/github.com/yosida95/uritemplate/v3/README.rst create mode 100644 vendor/github.com/yosida95/uritemplate/v3/compile.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/equals.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/error.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/escape.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/expression.go create mode 100644 
vendor/github.com/yosida95/uritemplate/v3/machine.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/match.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/parse.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/prog.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/uritemplate.go create mode 100644 vendor/github.com/yosida95/uritemplate/v3/value.go diff --git a/go.mod b/go.mod index 0a4ca6fbf..2e200c647 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/google/uuid v1.6.0 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/klauspost/compress v1.18.0 + github.com/mark3labs/mcp-go v0.43.2 github.com/maypok86/otter/v2 v2.2.1 github.com/mitchellh/mapstructure v1.5.0 github.com/pquerna/xjwt v0.3.0 @@ -76,7 +77,9 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.10 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.12 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.11 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/benbjohnson/clock v1.3.5 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -89,9 +92,11 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/jellydator/ttlcache/v3 v3.3.0 // indirect github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect github.com/magiconair/properties v1.8.9 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect @@ -108,6 +113,8 @@ require ( 
github.com/subosito/gotenv v1.6.0 // indirect github.com/tklauser/go-sysconf v0.3.16 // indirect github.com/tklauser/numcpus v0.11.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 // indirect diff --git a/go.sum b/go.sum index 4a1c84dad..1160f9dd6 100644 --- a/go.sum +++ b/go.sum @@ -52,9 +52,13 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.33.10 h1:g9d+TOsu3ac7SgmY2dUf1qMgu/u github.com/aws/aws-sdk-go-v2/service/sts v1.33.10/go.mod h1:WZfNmntu92HO44MVZAubQaz3qCuIdeOdog2sADfU6hU= github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -129,8 +133,11 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= 
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jellydator/ttlcache/v3 v3.3.0 h1:BdoC9cE81qXfrxeb9eoJi9dWrdhSuwXMAnHTbnBm4Wc= github.com/jellydator/ttlcache/v3 v3.3.0/go.mod h1:bj2/e0l4jRnQdrnSTaGTsh4GSXvMjQcy41i7th0GVGw= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= @@ -149,6 +156,10 @@ github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 h1:7UMa6KCCMjZEMD github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/magiconair/properties v1.8.9 h1:nWcCbLq1N2v/cpNsy5WvQ37Fb+YElfq20WJ/a8RkpQM= github.com/magiconair/properties v1.8.9/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= @@ -219,6 +230,10 @@ github.com/tklauser/go-sysconf v0.3.16 
h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYI github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI= github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw= github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= diff --git a/pkg/cli/commands.go b/pkg/cli/commands.go index ae3f324e2..aceb553d2 100644 --- a/pkg/cli/commands.go +++ b/pkg/cli/commands.go @@ -26,6 +26,7 @@ import ( v1 "github.com/conductorone/baton-sdk/pb/c1/connector_wrapper/v1" baton_v1 "github.com/conductorone/baton-sdk/pb/c1/connectorapi/baton/v1" "github.com/conductorone/baton-sdk/pkg/connectorrunner" + mcpPkg "github.com/conductorone/baton-sdk/pkg/mcp" "github.com/conductorone/baton-sdk/pkg/crypto" "github.com/conductorone/baton-sdk/pkg/field" "github.com/conductorone/baton-sdk/pkg/logging" @@ -599,6 +600,74 @@ func MakeGRPCServerCommand[T field.Configurable]( } } +func MakeMCPServerCommand[T field.Configurable]( + ctx context.Context, + name string, + v *viper.Viper, + confschema field.Configuration, + getconnector GetConnectorFunc2[T], +) func(*cobra.Command, []string) error { + return func(cmd *cobra.Command, args []string) error { + err := v.BindPFlags(cmd.Flags()) + if err != nil { + return err + } + + runCtx, err := initLogger( + ctx, + name, + 
logging.WithLogFormat(v.GetString("log-format")), + logging.WithLogLevel(v.GetString("log-level")), + ) + if err != nil { + return err + } + + runCtx, otelShutdown, err := initOtel(runCtx, name, v, nil) + if err != nil { + return err + } + defer func() { + if otelShutdown == nil { + return + } + shutdownCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(otelShutdownTimeout)) + defer cancel() + err := otelShutdown(shutdownCtx) + if err != nil { + zap.L().Error("error shutting down otel", zap.Error(err)) + } + }() + + l := ctxzap.Extract(runCtx) + l.Debug("starting MCP server") + + readFromPath := true + decodeOpts := field.WithAdditionalDecodeHooks(field.FileUploadDecodeHook(readFromPath)) + t, err := MakeGenericConfiguration[T](v, decodeOpts) + if err != nil { + return fmt.Errorf("failed to make configuration: %w", err) + } + + if err := field.Validate(confschema, t, field.WithAuthMethod(v.GetString("auth-method"))); err != nil { + return err + } + + c, err := getconnector(runCtx, t, RunTimeOpts{}) + if err != nil { + return err + } + + mcpServer, err := mcpPkg.NewMCPServer(runCtx, name, c) + if err != nil { + return fmt.Errorf("failed to create MCP server: %w", err) + } + + l.Info("MCP server starting on stdio") + return mcpServer.Serve() + } +} + func MakeCapabilitiesCommand[T field.Configurable]( ctx context.Context, name string, diff --git a/pkg/config/config.go b/pkg/config/config.go index d93045243..c4343f05a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -207,6 +207,17 @@ func DefineConfigurationV2[T field.Configurable]( return nil, nil, err } + _, err = cli.AddCommand(mainCMD, v, &schema, &cobra.Command{ + Use: "mcp", + Short: "Run as MCP server (stdio transport)", + Long: "Run the connector as an MCP (Model Context Protocol) server using stdio transport. 
This allows AI assistants to interact with the connector.", + RunE: cli.MakeMCPServerCommand(ctx, connectorName, v, confschema, connector), + }) + + if err != nil { + return nil, nil, err + } + _, err = cli.AddCommand(mainCMD, v, &schema, &cobra.Command{ Use: "capabilities", Short: "Get connector capabilities", diff --git a/pkg/mcp/convert.go b/pkg/mcp/convert.go new file mode 100644 index 000000000..b8c496d90 --- /dev/null +++ b/pkg/mcp/convert.go @@ -0,0 +1,74 @@ +package mcp + +import ( + "encoding/json" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +var ( + marshaler = protojson.MarshalOptions{ + EmitUnpopulated: false, + UseProtoNames: true, + } + + unmarshaler = protojson.UnmarshalOptions{ + DiscardUnknown: true, + } +) + +// protoToMap converts a proto message to a map[string]any. +func protoToMap(msg proto.Message) (map[string]any, error) { + if msg == nil { + return nil, nil + } + jsonBytes, err := marshaler.Marshal(msg) + if err != nil { + return nil, err + } + var result map[string]any + if err := json.Unmarshal(jsonBytes, &result); err != nil { + return nil, err + } + return result, nil +} + +// protoToJSON converts a proto message to a JSON string. +func protoToJSON(msg proto.Message) (string, error) { + if msg == nil { + return "{}", nil + } + jsonBytes, err := marshaler.Marshal(msg) + if err != nil { + return "", err + } + return string(jsonBytes), nil +} + +// protoListToMaps converts a slice of proto messages to a slice of maps. +func protoListToMaps[T proto.Message](list []T) ([]map[string]any, error) { + result := make([]map[string]any, 0, len(list)) + for _, item := range list { + m, err := protoToMap(item) + if err != nil { + return nil, err + } + result = append(result, m) + } + return result, nil +} + +// jsonToProto unmarshals a JSON string into a proto message. 
+func jsonToProto(jsonStr string, msg proto.Message) error { + return unmarshaler.Unmarshal([]byte(jsonStr), msg) +} + +// mapToProto converts a map to a proto message by going through JSON. +func mapToProto(m map[string]any, msg proto.Message) error { + jsonBytes, err := json.Marshal(m) + if err != nil { + return err + } + return unmarshaler.Unmarshal(jsonBytes, msg) +} diff --git a/pkg/mcp/handlers.go b/pkg/mcp/handlers.go new file mode 100644 index 000000000..37c2c6525 --- /dev/null +++ b/pkg/mcp/handlers.go @@ -0,0 +1,605 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + + v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" +) + +const defaultPageSize = 50 + +// handleGetMetadata handles the get_metadata tool. +func (m *MCPServer) handleGetMetadata(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resp, err := m.connector.GetMetadata(ctx, &v2.ConnectorServiceGetMetadataRequest{}) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get metadata: %v", err)), nil + } + + result, err := protoToMap(resp.GetMetadata()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize metadata: %v", err)), nil + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleValidate handles the validate tool. 
+func (m *MCPServer) handleValidate(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resp, err := m.connector.Validate(ctx, &v2.ConnectorServiceValidateRequest{}) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("validation failed: %v", err)), nil + } + + result := map[string]any{ + "valid": true, + "annotations": resp.GetAnnotations(), + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleListResourceTypes handles the list_resource_types tool. +func (m *MCPServer) handleListResourceTypes(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + pageSize := getPageSize(req) + pageToken := getStringArg(req, "page_token") + + resp, err := m.connector.ListResourceTypes(ctx, &v2.ResourceTypesServiceListResourceTypesRequest{ + PageSize: pageSize, + PageToken: pageToken, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list resource types: %v", err)), nil + } + + resourceTypes, err := protoListToMaps(resp.GetList()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize resource types: %v", err)), nil + } + + result := map[string]any{ + "resource_types": resourceTypes, + "next_page_token": resp.GetNextPageToken(), + "has_more": resp.GetNextPageToken() != "", + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleListResources handles the list_resources tool. 
+func (m *MCPServer) handleListResources(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resourceTypeID, err := req.RequireString("resource_type_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("resource_type_id is required: %v", err)), nil + } + + pageSize := getPageSize(req) + pageToken := getStringArg(req, "page_token") + + // Build parent resource ID if specified. + var parentResourceID *v2.ResourceId + parentType := getStringArg(req, "parent_resource_type") + parentID := getStringArg(req, "parent_resource_id") + if parentType != "" && parentID != "" { + parentResourceID = &v2.ResourceId{ + ResourceType: parentType, + Resource: parentID, + } + } + + resp, err := m.connector.ListResources(ctx, &v2.ResourcesServiceListResourcesRequest{ + ResourceTypeId: resourceTypeID, + ParentResourceId: parentResourceID, + PageSize: pageSize, + PageToken: pageToken, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list resources: %v", err)), nil + } + + resources, err := protoListToMaps(resp.GetList()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize resources: %v", err)), nil + } + + result := map[string]any{ + "resources": resources, + "next_page_token": resp.GetNextPageToken(), + "has_more": resp.GetNextPageToken() != "", + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleGetResource handles the get_resource tool. 
+func (m *MCPServer) handleGetResource(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resourceType, err := req.RequireString("resource_type") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("resource_type is required: %v", err)), nil + } + + resourceID, err := req.RequireString("resource_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("resource_id is required: %v", err)), nil + } + + resp, err := m.connector.GetResource(ctx, &v2.ResourceGetterServiceGetResourceRequest{ + ResourceId: &v2.ResourceId{ + ResourceType: resourceType, + Resource: resourceID, + }, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get resource: %v", err)), nil + } + + resource, err := protoToMap(resp.GetResource()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize resource: %v", err)), nil + } + + jsonBytes, err := json.Marshal(resource) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleListEntitlements handles the list_entitlements tool. 
+func (m *MCPServer) handleListEntitlements(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resourceType, err := req.RequireString("resource_type") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("resource_type is required: %v", err)), nil + } + + resourceID, err := req.RequireString("resource_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("resource_id is required: %v", err)), nil + } + + pageSize := getPageSize(req) + pageToken := getStringArg(req, "page_token") + + resp, err := m.connector.ListEntitlements(ctx, &v2.EntitlementsServiceListEntitlementsRequest{ + Resource: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: resourceType, + Resource: resourceID, + }, + }, + PageSize: pageSize, + PageToken: pageToken, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list entitlements: %v", err)), nil + } + + entitlements, err := protoListToMaps(resp.GetList()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize entitlements: %v", err)), nil + } + + result := map[string]any{ + "entitlements": entitlements, + "next_page_token": resp.GetNextPageToken(), + "has_more": resp.GetNextPageToken() != "", + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleListGrants handles the list_grants tool. 
+func (m *MCPServer) handleListGrants(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resourceType, err := req.RequireString("resource_type") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("resource_type is required: %v", err)), nil + } + + resourceID, err := req.RequireString("resource_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("resource_id is required: %v", err)), nil + } + + pageSize := getPageSize(req) + pageToken := getStringArg(req, "page_token") + + resp, err := m.connector.ListGrants(ctx, &v2.GrantsServiceListGrantsRequest{ + Resource: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: resourceType, + Resource: resourceID, + }, + }, + PageSize: pageSize, + PageToken: pageToken, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list grants: %v", err)), nil + } + + grants, err := protoListToMaps(resp.GetList()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize grants: %v", err)), nil + } + + result := map[string]any{ + "grants": grants, + "next_page_token": resp.GetNextPageToken(), + "has_more": resp.GetNextPageToken() != "", + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleGrant handles the grant tool. 
+func (m *MCPServer) handleGrant(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + entResourceType, err := req.RequireString("entitlement_resource_type") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("entitlement_resource_type is required: %v", err)), nil + } + + entResourceID, err := req.RequireString("entitlement_resource_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("entitlement_resource_id is required: %v", err)), nil + } + + entID, err := req.RequireString("entitlement_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("entitlement_id is required: %v", err)), nil + } + + principalType, err := req.RequireString("principal_resource_type") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("principal_resource_type is required: %v", err)), nil + } + + principalID, err := req.RequireString("principal_resource_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("principal_resource_id is required: %v", err)), nil + } + + resp, err := m.connector.Grant(ctx, &v2.GrantManagerServiceGrantRequest{ + Entitlement: &v2.Entitlement{ + Resource: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: entResourceType, + Resource: entResourceID, + }, + }, + Id: entID, + }, + Principal: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: principalType, + Resource: principalID, + }, + }, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("grant failed: %v", err)), nil + } + + grants, err := protoListToMaps(resp.GetGrants()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize grants: %v", err)), nil + } + + result := map[string]any{ + "grants": grants, + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleRevoke handles the revoke tool. 
+func (m *MCPServer) handleRevoke(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + grantID, err := req.RequireString("grant_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("grant_id is required: %v", err)), nil + } + + entResourceType, err := req.RequireString("entitlement_resource_type") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("entitlement_resource_type is required: %v", err)), nil + } + + entResourceID, err := req.RequireString("entitlement_resource_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("entitlement_resource_id is required: %v", err)), nil + } + + entID, err := req.RequireString("entitlement_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("entitlement_id is required: %v", err)), nil + } + + principalType, err := req.RequireString("principal_resource_type") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("principal_resource_type is required: %v", err)), nil + } + + principalID, err := req.RequireString("principal_resource_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("principal_resource_id is required: %v", err)), nil + } + + _, err = m.connector.Revoke(ctx, &v2.GrantManagerServiceRevokeRequest{ + Grant: &v2.Grant{ + Id: grantID, + Entitlement: &v2.Entitlement{ + Resource: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: entResourceType, + Resource: entResourceID, + }, + }, + Id: entID, + }, + Principal: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: principalType, + Resource: principalID, + }, + }, + }, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("revoke failed: %v", err)), nil + } + + result := map[string]any{ + "success": true, + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleCreateResource handles the 
create_resource tool. +func (m *MCPServer) handleCreateResource(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resourceType, err := req.RequireString("resource_type") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("resource_type is required: %v", err)), nil + } + + displayName, err := req.RequireString("display_name") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("display_name is required: %v", err)), nil + } + + // Build parent resource if specified. + var parentResource *v2.Resource + parentType := getStringArg(req, "parent_resource_type") + parentID := getStringArg(req, "parent_resource_id") + if parentType != "" && parentID != "" { + parentResource = &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: parentType, + Resource: parentID, + }, + } + } + + resp, err := m.connector.CreateResource(ctx, &v2.CreateResourceRequest{ + Resource: &v2.Resource{ + Id: &v2.ResourceId{ + ResourceType: resourceType, + }, + DisplayName: displayName, + ParentResourceId: parentResource.GetId(), + }, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("create resource failed: %v", err)), nil + } + + resource, err := protoToMap(resp.GetCreated()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize resource: %v", err)), nil + } + + jsonBytes, err := json.Marshal(resource) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleDeleteResource handles the delete_resource tool. 
+func (m *MCPServer) handleDeleteResource(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resourceType, err := req.RequireString("resource_type") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("resource_type is required: %v", err)), nil + } + + resourceID, err := req.RequireString("resource_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("resource_id is required: %v", err)), nil + } + + _, err = m.connector.DeleteResource(ctx, &v2.DeleteResourceRequest{ + ResourceId: &v2.ResourceId{ + ResourceType: resourceType, + Resource: resourceID, + }, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("delete resource failed: %v", err)), nil + } + + result := map[string]any{ + "success": true, + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleListTicketSchemas handles the list_ticket_schemas tool. +func (m *MCPServer) handleListTicketSchemas(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resp, err := m.connector.ListTicketSchemas(ctx, &v2.TicketsServiceListTicketSchemasRequest{}) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to list ticket schemas: %v", err)), nil + } + + schemas, err := protoListToMaps(resp.GetList()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize ticket schemas: %v", err)), nil + } + + result := map[string]any{ + "schemas": schemas, + "next_page_token": resp.GetNextPageToken(), + "has_more": resp.GetNextPageToken() != "", + } + + jsonBytes, err := json.Marshal(result) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleCreateTicket handles the create_ticket tool. 
+func (m *MCPServer) handleCreateTicket(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + schemaID, err := req.RequireString("schema_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("schema_id is required: %v", err)), nil + } + + displayName, err := req.RequireString("display_name") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("display_name is required: %v", err)), nil + } + + description := getStringArg(req, "description") + + resp, err := m.connector.CreateTicket(ctx, &v2.TicketsServiceCreateTicketRequest{ + Schema: &v2.TicketSchema{ + Id: schemaID, + }, + Request: &v2.TicketRequest{ + DisplayName: displayName, + Description: description, + }, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("create ticket failed: %v", err)), nil + } + + ticket, err := protoToMap(resp.GetTicket()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize ticket: %v", err)), nil + } + + jsonBytes, err := json.Marshal(ticket) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// handleGetTicket handles the get_ticket tool. 
+func (m *MCPServer) handleGetTicket(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ticketID, err := req.RequireString("ticket_id") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("ticket_id is required: %v", err)), nil + } + + resp, err := m.connector.GetTicket(ctx, &v2.TicketsServiceGetTicketRequest{ + Id: ticketID, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("get ticket failed: %v", err)), nil + } + + ticket, err := protoToMap(resp.GetTicket()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to serialize ticket: %v", err)), nil + } + + jsonBytes, err := json.Marshal(ticket) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + return mcp.NewToolResultText(string(jsonBytes)), nil +} + +// Helper functions. + +func getPageSize(req mcp.CallToolRequest) uint32 { + args := req.GetArguments() + if args == nil { + return defaultPageSize + } + if ps, ok := args["page_size"]; ok { + if psFloat, ok := ps.(float64); ok && psFloat > 0 && psFloat <= 4294967295 { + return uint32(psFloat) + } + } + return defaultPageSize +} + +func getStringArg(req mcp.CallToolRequest, name string) string { + args := req.GetArguments() + if args == nil { + return "" + } + if v, ok := args[name]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} diff --git a/pkg/mcp/server.go b/pkg/mcp/server.go new file mode 100644 index 000000000..42a11fb52 --- /dev/null +++ b/pkg/mcp/server.go @@ -0,0 +1,338 @@ +package mcp + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" + "github.com/conductorone/baton-sdk/pkg/types" +) + +// MCPServer wraps a ConnectorServer and exposes its functionality via MCP. 
+type MCPServer struct { + connector types.ConnectorServer + server *server.MCPServer + caps *v2.ConnectorCapabilities +} + +// NewMCPServer creates a new MCP server that wraps the given ConnectorServer. +func NewMCPServer(ctx context.Context, name string, connector types.ConnectorServer) (*MCPServer, error) { + // Get connector metadata to determine capabilities. + metaResp, err := connector.GetMetadata(ctx, &v2.ConnectorServiceGetMetadataRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to get connector metadata: %w", err) + } + + s := server.NewMCPServer( + name, + "1.0.0", + server.WithToolCapabilities(false), + server.WithRecovery(), + ) + + m := &MCPServer{ + connector: connector, + server: s, + caps: metaResp.GetMetadata().GetCapabilities(), + } + + m.registerTools() + return m, nil +} + +// Serve starts the MCP server on stdio. +func (m *MCPServer) Serve() error { + return server.ServeStdio(m.server) +} + +// registerTools registers all MCP tools based on connector capabilities. +func (m *MCPServer) registerTools() { + // Always register read-only tools. + m.registerReadTools() + + // Register provisioning tools if the connector supports provisioning. + if m.hasCapability(v2.Capability_CAPABILITY_PROVISION) { + m.registerProvisioningTools() + } + + // Register ticketing tools if the connector supports ticketing. + if m.hasCapability(v2.Capability_CAPABILITY_TICKETING) { + m.registerTicketingTools() + } +} + +// hasCapability checks if the connector has the given capability. +func (m *MCPServer) hasCapability(cap v2.Capability) bool { + if m.caps == nil { + return false + } + for _, c := range m.caps.GetConnectorCapabilities() { + if c == cap { + return true + } + } + return false +} + +// registerReadTools registers read-only tools that are always available. +func (m *MCPServer) registerReadTools() { + // get_metadata - Get connector metadata and capabilities. 
+ m.server.AddTool( + mcp.NewTool("get_metadata", + mcp.WithDescription("Get connector metadata including display name, description, and capabilities"), + ), + m.handleGetMetadata, + ) + + // validate - Validate connector configuration. + m.server.AddTool( + mcp.NewTool("validate", + mcp.WithDescription("Validate the connector configuration and connectivity"), + ), + m.handleValidate, + ) + + // list_resource_types - List available resource types. + m.server.AddTool( + mcp.NewTool("list_resource_types", + mcp.WithDescription("List all resource types supported by this connector"), + mcp.WithNumber("page_size", + mcp.Description("Number of items per page (default 50)"), + ), + mcp.WithString("page_token", + mcp.Description("Pagination token from previous response"), + ), + ), + m.handleListResourceTypes, + ) + + // list_resources - List resources of a specific type. + m.server.AddTool( + mcp.NewTool("list_resources", + mcp.WithDescription("List resources of a specific type"), + mcp.WithString("resource_type_id", + mcp.Required(), + mcp.Description("The resource type ID to list (e.g., 'user', 'group')"), + ), + mcp.WithString("parent_resource_type", + mcp.Description("Parent resource type (optional, for hierarchical resources)"), + ), + mcp.WithString("parent_resource_id", + mcp.Description("Parent resource ID (optional, for hierarchical resources)"), + ), + mcp.WithNumber("page_size", + mcp.Description("Number of items per page (default 50)"), + ), + mcp.WithString("page_token", + mcp.Description("Pagination token from previous response"), + ), + ), + m.handleListResources, + ) + + // get_resource - Get a specific resource by ID. 
+ m.server.AddTool( + mcp.NewTool("get_resource", + mcp.WithDescription("Get a specific resource by its type and ID"), + mcp.WithString("resource_type", + mcp.Required(), + mcp.Description("The resource type (e.g., 'user', 'group')"), + ), + mcp.WithString("resource_id", + mcp.Required(), + mcp.Description("The resource ID"), + ), + ), + m.handleGetResource, + ) + + // list_entitlements - List entitlements for a resource. + m.server.AddTool( + mcp.NewTool("list_entitlements", + mcp.WithDescription("List entitlements (permissions, roles, memberships) for a resource"), + mcp.WithString("resource_type", + mcp.Required(), + mcp.Description("The resource type"), + ), + mcp.WithString("resource_id", + mcp.Required(), + mcp.Description("The resource ID"), + ), + mcp.WithNumber("page_size", + mcp.Description("Number of items per page (default 50)"), + ), + mcp.WithString("page_token", + mcp.Description("Pagination token from previous response"), + ), + ), + m.handleListEntitlements, + ) + + // list_grants - List grants for a resource. + m.server.AddTool( + mcp.NewTool("list_grants", + mcp.WithDescription("List grants (who has what access) for a resource"), + mcp.WithString("resource_type", + mcp.Required(), + mcp.Description("The resource type"), + ), + mcp.WithString("resource_id", + mcp.Required(), + mcp.Description("The resource ID"), + ), + mcp.WithNumber("page_size", + mcp.Description("Number of items per page (default 50)"), + ), + mcp.WithString("page_token", + mcp.Description("Pagination token from previous response"), + ), + ), + m.handleListGrants, + ) +} + +// registerProvisioningTools registers tools for provisioning operations. +func (m *MCPServer) registerProvisioningTools() { + // grant - Grant an entitlement to a principal. 
+ m.server.AddTool( + mcp.NewTool("grant", + mcp.WithDescription("Grant an entitlement to a principal (user or group)"), + mcp.WithString("entitlement_resource_type", + mcp.Required(), + mcp.Description("Resource type of the entitlement"), + ), + mcp.WithString("entitlement_resource_id", + mcp.Required(), + mcp.Description("Resource ID of the entitlement"), + ), + mcp.WithString("entitlement_id", + mcp.Required(), + mcp.Description("The entitlement ID"), + ), + mcp.WithString("principal_resource_type", + mcp.Required(), + mcp.Description("Resource type of the principal (e.g., 'user', 'group')"), + ), + mcp.WithString("principal_resource_id", + mcp.Required(), + mcp.Description("Resource ID of the principal"), + ), + ), + m.handleGrant, + ) + + // revoke - Revoke a grant. + m.server.AddTool( + mcp.NewTool("revoke", + mcp.WithDescription("Revoke a grant from a principal"), + mcp.WithString("grant_id", + mcp.Required(), + mcp.Description("The grant ID to revoke"), + ), + mcp.WithString("entitlement_resource_type", + mcp.Required(), + mcp.Description("Resource type of the entitlement"), + ), + mcp.WithString("entitlement_resource_id", + mcp.Required(), + mcp.Description("Resource ID of the entitlement"), + ), + mcp.WithString("entitlement_id", + mcp.Required(), + mcp.Description("The entitlement ID"), + ), + mcp.WithString("principal_resource_type", + mcp.Required(), + mcp.Description("Resource type of the principal"), + ), + mcp.WithString("principal_resource_id", + mcp.Required(), + mcp.Description("Resource ID of the principal"), + ), + ), + m.handleRevoke, + ) + + // create_resource - Create a new resource. 
+ m.server.AddTool( + mcp.NewTool("create_resource", + mcp.WithDescription("Create a new resource"), + mcp.WithString("resource_type", + mcp.Required(), + mcp.Description("The resource type to create"), + ), + mcp.WithString("display_name", + mcp.Required(), + mcp.Description("Display name for the new resource"), + ), + mcp.WithString("parent_resource_type", + mcp.Description("Parent resource type (optional)"), + ), + mcp.WithString("parent_resource_id", + mcp.Description("Parent resource ID (optional)"), + ), + ), + m.handleCreateResource, + ) + + // delete_resource - Delete a resource. + m.server.AddTool( + mcp.NewTool("delete_resource", + mcp.WithDescription("Delete a resource"), + mcp.WithString("resource_type", + mcp.Required(), + mcp.Description("The resource type"), + ), + mcp.WithString("resource_id", + mcp.Required(), + mcp.Description("The resource ID to delete"), + ), + ), + m.handleDeleteResource, + ) +} + +// registerTicketingTools registers tools for ticketing operations. +func (m *MCPServer) registerTicketingTools() { + // list_ticket_schemas - List available ticket schemas. + m.server.AddTool( + mcp.NewTool("list_ticket_schemas", + mcp.WithDescription("List available ticket schemas"), + ), + m.handleListTicketSchemas, + ) + + // create_ticket - Create a new ticket. + m.server.AddTool( + mcp.NewTool("create_ticket", + mcp.WithDescription("Create a new ticket"), + mcp.WithString("schema_id", + mcp.Required(), + mcp.Description("The ticket schema ID"), + ), + mcp.WithString("display_name", + mcp.Required(), + mcp.Description("Display name for the ticket"), + ), + mcp.WithString("description", + mcp.Description("Description of the ticket"), + ), + ), + m.handleCreateTicket, + ) + + // get_ticket - Get a ticket by ID. 
+ m.server.AddTool( + mcp.NewTool("get_ticket", + mcp.WithDescription("Get a ticket by ID"), + mcp.WithString("ticket_id", + mcp.Required(), + mcp.Description("The ticket ID"), + ), + ), + m.handleGetTicket, + ) +} diff --git a/vendor/github.com/bahlo/generic-list-go/LICENSE b/vendor/github.com/bahlo/generic-list-go/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/vendor/github.com/bahlo/generic-list-go/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/vendor/github.com/bahlo/generic-list-go/README.md b/vendor/github.com/bahlo/generic-list-go/README.md new file mode 100644 index 000000000..68bbce9fb --- /dev/null +++ b/vendor/github.com/bahlo/generic-list-go/README.md @@ -0,0 +1,5 @@ +# generic-list-go [![CI](https://github.com/bahlo/generic-list-go/actions/workflows/ci.yml/badge.svg)](https://github.com/bahlo/generic-list-go/actions/workflows/ci.yml) + +Go [container/list](https://pkg.go.dev/container/list) but with generics. + +The code is based on `container/list` in `go1.18beta2`. diff --git a/vendor/github.com/bahlo/generic-list-go/list.go b/vendor/github.com/bahlo/generic-list-go/list.go new file mode 100644 index 000000000..a06a7c612 --- /dev/null +++ b/vendor/github.com/bahlo/generic-list-go/list.go @@ -0,0 +1,235 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package list implements a doubly linked list. +// +// To iterate over a list (where l is a *List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e.Value +// } +// +package list + +// Element is an element of a linked list. +type Element[T any] struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *Element[T] + + // The list to which this element belongs. + list *List[T] + + // The value stored with this element. + Value T +} + +// Next returns the next list element or nil. +func (e *Element[T]) Next() *Element[T] { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. 
+func (e *Element[T]) Prev() *Element[T] { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// List represents a doubly linked list. +// The zero value for List is an empty list ready to use. +type List[T any] struct { + root Element[T] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *List[T]) Init() *List[T] { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// New returns an initialized list. +func New[T any]() *List[T] { return new(List[T]).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *List[T]) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *List[T]) Front() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *List[T]) Back() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *List[T]) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *List[T]) insert(e, at *Element[T]) *Element[T] { + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { + return l.insert(&Element[T]{Value: v}, at) +} + +// remove removes e from its list, decrements l.len +func (l *List[T]) remove(e *Element[T]) { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- +} + +// move moves e to next to at. 
+func (l *List[T]) move(e, at *Element[T]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *List[T]) Remove(e *Element[T]) T { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *List[T]) PushFront(v T) *Element[T] { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *List[T]) PushBack(v T) *Element[T] { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. 
+func (l *List[T]) MoveToFront(e *Element[T]) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *List[T]) MoveToBack(e *Element[T]) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[T]) MoveBefore(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } + l.move(e, mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[T]) MoveAfter(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } + l.move(e, mark) +} + +// PushBackList inserts a copy of another list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *List[T]) PushBackList(other *List[T]) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of another list at the front of list l. +// The lists l and other may be the same. They must not be nil. 
+func (l *List[T]) PushFrontList(other *List[T]) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/vendor/github.com/buger/jsonparser/.gitignore b/vendor/github.com/buger/jsonparser/.gitignore new file mode 100644 index 000000000..5598d8a56 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/.gitignore @@ -0,0 +1,12 @@ + +*.test + +*.out + +*.mprof + +.idea + +vendor/github.com/buger/goterm/ +prof.cpu +prof.mem diff --git a/vendor/github.com/buger/jsonparser/.travis.yml b/vendor/github.com/buger/jsonparser/.travis.yml new file mode 100644 index 000000000..dbfb7cf98 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/.travis.yml @@ -0,0 +1,11 @@ +language: go +arch: + - amd64 + - ppc64le +go: + - 1.7.x + - 1.8.x + - 1.9.x + - 1.10.x + - 1.11.x +script: go test -v ./. diff --git a/vendor/github.com/buger/jsonparser/Dockerfile b/vendor/github.com/buger/jsonparser/Dockerfile new file mode 100644 index 000000000..37fc9fd0b --- /dev/null +++ b/vendor/github.com/buger/jsonparser/Dockerfile @@ -0,0 +1,12 @@ +FROM golang:1.6 + +RUN go get github.com/Jeffail/gabs +RUN go get github.com/bitly/go-simplejson +RUN go get github.com/pquerna/ffjson +RUN go get github.com/antonholmquist/jason +RUN go get github.com/mreiferson/go-ujson +RUN go get -tags=unsafe -u github.com/ugorji/go/codec +RUN go get github.com/mailru/easyjson + +WORKDIR /go/src/github.com/buger/jsonparser +ADD . 
/go/src/github.com/buger/jsonparser \ No newline at end of file diff --git a/vendor/github.com/buger/jsonparser/LICENSE b/vendor/github.com/buger/jsonparser/LICENSE new file mode 100644 index 000000000..ac25aeb7d --- /dev/null +++ b/vendor/github.com/buger/jsonparser/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2016 Leonid Bugaev + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/buger/jsonparser/Makefile b/vendor/github.com/buger/jsonparser/Makefile new file mode 100644 index 000000000..e843368cf --- /dev/null +++ b/vendor/github.com/buger/jsonparser/Makefile @@ -0,0 +1,36 @@ +SOURCE = parser.go +CONTAINER = jsonparser +SOURCE_PATH = /go/src/github.com/buger/jsonparser +BENCHMARK = JsonParser +BENCHTIME = 5s +TEST = . +DRUN = docker run -v `pwd`:$(SOURCE_PATH) -i -t $(CONTAINER) + +build: + docker build -t $(CONTAINER) . + +race: + $(DRUN) --env GORACE="halt_on_error=1" go test ./. 
$(ARGS) -v -race -timeout 15s + +bench: + $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -benchtime $(BENCHTIME) -v + +bench_local: + $(DRUN) go test $(LDFLAGS) -test.benchmem -bench . $(ARGS) -benchtime $(BENCHTIME) -v + +profile: + $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -memprofile mem.mprof -v + $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -cpuprofile cpu.out -v + $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -c + +test: + $(DRUN) go test $(LDFLAGS) ./ -run $(TEST) -timeout 10s $(ARGS) -v + +fmt: + $(DRUN) go fmt ./... + +vet: + $(DRUN) go vet ./. + +bash: + $(DRUN) /bin/bash \ No newline at end of file diff --git a/vendor/github.com/buger/jsonparser/README.md b/vendor/github.com/buger/jsonparser/README.md new file mode 100644 index 000000000..d7e0ec397 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/README.md @@ -0,0 +1,365 @@ +[![Go Report Card](https://goreportcard.com/badge/github.com/buger/jsonparser)](https://goreportcard.com/report/github.com/buger/jsonparser) ![License](https://img.shields.io/dub/l/vibe-d.svg) +# Alternative JSON parser for Go (10x times faster standard library) + +It does not require you to know the structure of the payload (eg. create structs), and allows accessing fields by providing the path to them. It is up to **10 times faster** than standard `encoding/json` package (depending on payload size and usage), **allocates no memory**. See benchmarks below. + +## Rationale +Originally I made this for a project that relies on a lot of 3rd party APIs that can be unpredictable and complex. +I love simplicity and prefer to avoid external dependecies. `encoding/json` requires you to know exactly your data structures, or if you prefer to use `map[string]interface{}` instead, it will be very slow and hard to manage. 
+I investigated what's on the market and found that most libraries are just wrappers around `encoding/json`, there is few options with own parsers (`ffjson`, `easyjson`), but they still requires you to create data structures. + + +Goal of this project is to push JSON parser to the performance limits and not sacrifice with compliance and developer user experience. + +## Example +For the given JSON our goal is to extract the user's full name, number of github followers and avatar. + +```go +import "github.com/buger/jsonparser" + +... + +data := []byte(`{ + "person": { + "name": { + "first": "Leonid", + "last": "Bugaev", + "fullName": "Leonid Bugaev" + }, + "github": { + "handle": "buger", + "followers": 109 + }, + "avatars": [ + { "url": "https://avatars1.githubusercontent.com/u/14009?v=3&s=460", "type": "thumbnail" } + ] + }, + "company": { + "name": "Acme" + } +}`) + +// You can specify key path by providing arguments to Get function +jsonparser.Get(data, "person", "name", "fullName") + +// There is `GetInt` and `GetBoolean` helpers if you exactly know key data type +jsonparser.GetInt(data, "person", "github", "followers") + +// When you try to get object, it will return you []byte slice pointer to data containing it +// In `company` it will be `{"name": "Acme"}` +jsonparser.Get(data, "company") + +// If the key doesn't exist it will throw an error +var size int64 +if value, err := jsonparser.GetInt(data, "company", "size"); err == nil { + size = value +} + +// You can use `ArrayEach` helper to iterate items [item1, item2 .... itemN] +jsonparser.ArrayEach(data, func(value []byte, dataType jsonparser.ValueType, offset int, err error) { + fmt.Println(jsonparser.Get(value, "url")) +}, "person", "avatars") + +// Or use can access fields by index! +jsonparser.GetString(data, "person", "avatars", "[0]", "url") + +// You can use `ObjectEach` helper to iterate objects { "key1":object1, "key2":object2, .... 
"keyN":objectN } +jsonparser.ObjectEach(data, func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error { + fmt.Printf("Key: '%s'\n Value: '%s'\n Type: %s\n", string(key), string(value), dataType) + return nil +}, "person", "name") + +// The most efficient way to extract multiple keys is `EachKey` + +paths := [][]string{ + []string{"person", "name", "fullName"}, + []string{"person", "avatars", "[0]", "url"}, + []string{"company", "url"}, +} +jsonparser.EachKey(data, func(idx int, value []byte, vt jsonparser.ValueType, err error){ + switch idx { + case 0: // []string{"person", "name", "fullName"} + ... + case 1: // []string{"person", "avatars", "[0]", "url"} + ... + case 2: // []string{"company", "url"}, + ... + } +}, paths...) + +// For more information see docs below +``` + +## Need to speedup your app? + +I'm available for consulting and can help you push your app performance to the limits. Ping me at: leonsbox@gmail.com. + +## Reference + +Library API is really simple. You just need the `Get` method to perform any operation. The rest is just helpers around it. + +You also can view API at [godoc.org](https://godoc.org/github.com/buger/jsonparser) + + +### **`Get`** +```go +func Get(data []byte, keys ...string) (value []byte, dataType jsonparser.ValueType, offset int, err error) +``` +Receives data structure, and key path to extract value from. + +Returns: +* `value` - Pointer to original data structure containing key value, or just empty slice if nothing found or error +* `dataType` - Can be: `NotExist`, `String`, `Number`, `Object`, `Array`, `Boolean` or `Null` +* `offset` - Offset from provided data structure where key value ends. Used mostly internally, for example for `ArrayEach` helper. +* `err` - If the key is not found or any other parsing issue, it should return error. If key not found it also sets `dataType` to `NotExist` + +Accepts multiple keys to specify path to JSON value (in case of quering nested structures). 
+If no keys are provided it will try to extract the closest JSON value (simple ones or object/array), useful for reading streams or arrays, see `ArrayEach` implementation. + +Note that keys can be an array indexes: `jsonparser.GetInt("person", "avatars", "[0]", "url")`, pretty cool, yeah? + +### **`GetString`** +```go +func GetString(data []byte, keys ...string) (val string, err error) +``` +Returns strings properly handing escaped and unicode characters. Note that this will cause additional memory allocations. + +### **`GetUnsafeString`** +If you need string in your app, and ready to sacrifice with support of escaped symbols in favor of speed. It returns string mapped to existing byte slice memory, without any allocations: +```go +s, _, := jsonparser.GetUnsafeString(data, "person", "name", "title") +switch s { + case 'CEO': + ... + case 'Engineer' + ... + ... +} +``` +Note that `unsafe` here means that your string will exist until GC will free underlying byte slice, for most of cases it means that you can use this string only in current context, and should not pass it anywhere externally: through channels or any other way. + + +### **`GetBoolean`**, **`GetInt`** and **`GetFloat`** +```go +func GetBoolean(data []byte, keys ...string) (val bool, err error) + +func GetFloat(data []byte, keys ...string) (val float64, err error) + +func GetInt(data []byte, keys ...string) (val int64, err error) +``` +If you know the key type, you can use the helpers above. +If key data type do not match, it will return error. + +### **`ArrayEach`** +```go +func ArrayEach(data []byte, cb func(value []byte, dataType jsonparser.ValueType, offset int, err error), keys ...string) +``` +Needed for iterating arrays, accepts a callback function with the same return arguments as `Get`. 
+ +### **`ObjectEach`** +```go +func ObjectEach(data []byte, callback func(key []byte, value []byte, dataType ValueType, offset int) error, keys ...string) (err error) +``` +Needed for iterating object, accepts a callback function. Example: +```go +var handler func([]byte, []byte, jsonparser.ValueType, int) error +handler = func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error { + //do stuff here +} +jsonparser.ObjectEach(myJson, handler) +``` + + +### **`EachKey`** +```go +func EachKey(data []byte, cb func(idx int, value []byte, dataType jsonparser.ValueType, err error), paths ...[]string) +``` +When you need to read multiple keys, and you do not afraid of low-level API `EachKey` is your friend. It read payload only single time, and calls callback function once path is found. For example when you call multiple times `Get`, it has to process payload multiple times, each time you call it. Depending on payload `EachKey` can be multiple times faster than `Get`. Path can use nested keys as well! + +```go +paths := [][]string{ + []string{"uuid"}, + []string{"tz"}, + []string{"ua"}, + []string{"st"}, +} +var data SmallPayload + +jsonparser.EachKey(smallFixture, func(idx int, value []byte, vt jsonparser.ValueType, err error){ + switch idx { + case 0: + data.Uuid, _ = value + case 1: + v, _ := jsonparser.ParseInt(value) + data.Tz = int(v) + case 2: + data.Ua, _ = value + case 3: + v, _ := jsonparser.ParseInt(value) + data.St = int(v) + } +}, paths...) +``` + +### **`Set`** +```go +func Set(data []byte, setValue []byte, keys ...string) (value []byte, err error) +``` +Receives existing data structure, key path to set, and value to set at that key. *This functionality is experimental.* + +Returns: +* `value` - Pointer to original data structure with updated or added key value. +* `err` - If any parsing issue, it should return error. + +Accepts multiple keys to specify path to JSON value (in case of updating or creating nested structures). 
+ +Note that keys can be an array indexes: `jsonparser.Set(data, []byte("http://github.com"), "person", "avatars", "[0]", "url")` + +### **`Delete`** +```go +func Delete(data []byte, keys ...string) value []byte +``` +Receives existing data structure, and key path to delete. *This functionality is experimental.* + +Returns: +* `value` - Pointer to original data structure with key path deleted if it can be found. If there is no key path, then the whole data structure is deleted. + +Accepts multiple keys to specify path to JSON value (in case of updating or creating nested structures). + +Note that keys can be an array indexes: `jsonparser.Delete(data, "person", "avatars", "[0]", "url")` + + +## What makes it so fast? +* It does not rely on `encoding/json`, `reflection` or `interface{}`, the only real package dependency is `bytes`. +* Operates with JSON payload on byte level, providing you pointers to the original data structure: no memory allocation. +* No automatic type conversions, by default everything is a []byte, but it provides you value type, so you can convert by yourself (there is few helpers included). +* Does not parse full record, only keys you specified + + +## Benchmarks + +There are 3 benchmark types, trying to simulate real-life usage for small, medium and large JSON payloads. +For each metric, the lower value is better. Time/op is in nanoseconds. Values better than standard encoding/json marked as bold text. +Benchmarks run on standard Linode 1024 box. + +Compared libraries: +* https://golang.org/pkg/encoding/json +* https://github.com/Jeffail/gabs +* https://github.com/a8m/djson +* https://github.com/bitly/go-simplejson +* https://github.com/antonholmquist/jason +* https://github.com/mreiferson/go-ujson +* https://github.com/ugorji/go/codec +* https://github.com/pquerna/ffjson +* https://github.com/mailru/easyjson +* https://github.com/buger/jsonparser + +#### TLDR +If you want to skip next sections we have 2 winner: `jsonparser` and `easyjson`. 
+`jsonparser` is up to 10 times faster than standard `encoding/json` package (depending on payload size and usage), and almost infinitely (literally) better in memory consumption because it operates with data on byte level, and provide direct slice pointers. +`easyjson` wins in CPU in medium tests and frankly i'm impressed with this package: it is remarkable results considering that it is almost drop-in replacement for `encoding/json` (require some code generation). + +It's hard to fully compare `jsonparser` and `easyjson` (or `ffson`), they a true parsers and fully process record, unlike `jsonparser` which parse only keys you specified. + +If you searching for replacement of `encoding/json` while keeping structs, `easyjson` is an amazing choice. If you want to process dynamic JSON, have memory constrains, or more control over your data you should try `jsonparser`. + +`jsonparser` performance heavily depends on usage, and it works best when you do not need to process full record, only some keys. The more calls you need to make, the slower it will be, in contrast `easyjson` (or `ffjson`, `encoding/json`) parser record only 1 time, and then you can make as many calls as you want. + +With great power comes great responsibility! :) + + +#### Small payload + +Each test processes 190 bytes of http log as a JSON record. +It should read multiple fields. 
+https://github.com/buger/jsonparser/blob/master/benchmark/benchmark_small_payload_test.go + +Library | time/op | bytes/op | allocs/op + ------ | ------- | -------- | ------- +encoding/json struct | 7879 | 880 | 18 +encoding/json interface{} | 8946 | 1521 | 38 +Jeffail/gabs | 10053 | 1649 | 46 +bitly/go-simplejson | 10128 | 2241 | 36 +antonholmquist/jason | 27152 | 7237 | 101 +github.com/ugorji/go/codec | 8806 | 2176 | 31 +mreiferson/go-ujson | **7008** | **1409** | 37 +a8m/djson | 3862 | 1249 | 30 +pquerna/ffjson | **3769** | **624** | **15** +mailru/easyjson | **2002** | **192** | **9** +buger/jsonparser | **1367** | **0** | **0** +buger/jsonparser (EachKey API) | **809** | **0** | **0** + +Winners are ffjson, easyjson and jsonparser, where jsonparser is up to 9.8x faster than encoding/json and 4.6x faster than ffjson, and slightly faster than easyjson. +If you look at memory allocation, jsonparser has no rivals, as it makes no data copy and operates with raw []byte structures and pointers to it. + +#### Medium payload + +Each test processes a 2.4kb JSON record (based on Clearbit API). +It should read multiple nested fields and 1 array. 
+ +https://github.com/buger/jsonparser/blob/master/benchmark/benchmark_medium_payload_test.go + +| Library | time/op | bytes/op | allocs/op | +| ------- | ------- | -------- | --------- | +| encoding/json struct | 57749 | 1336 | 29 | +| encoding/json interface{} | 79297 | 10627 | 215 | +| Jeffail/gabs | 83807 | 11202 | 235 | +| bitly/go-simplejson | 88187 | 17187 | 220 | +| antonholmquist/jason | 94099 | 19013 | 247 | +| github.com/ugorji/go/codec | 114719 | 6712 | 152 | +| mreiferson/go-ujson | **56972** | 11547 | 270 | +| a8m/djson | 28525 | 10196 | 198 | +| pquerna/ffjson | **20298** | **856** | **20** | +| mailru/easyjson | **10512** | **336** | **12** | +| buger/jsonparser | **15955** | **0** | **0** | +| buger/jsonparser (EachKey API) | **8916** | **0** | **0** | + +The difference between ffjson and jsonparser in CPU usage is smaller, while the memory consumption difference is growing. On the other hand `easyjson` shows remarkable performance for medium payload. + +`gabs`, `go-simplejson` and `jason` are based on encoding/json and map[string]interface{} and actually only helpers for unstructured JSON, their performance correlate with `encoding/json interface{}`, and they will skip next round. +`go-ujson` while have its own parser, shows same performance as `encoding/json`, also skips next round. Same situation with `ugorji/go/codec`, but it showed unexpectedly bad performance for complex payloads. + + +#### Large payload + +Each test processes a 24kb JSON record (based on Discourse API) +It should read 2 arrays, and for each item in array get a few fields. +Basically it means processing a full JSON file. 
+ +https://github.com/buger/jsonparser/blob/master/benchmark/benchmark_large_payload_test.go + +| Library | time/op | bytes/op | allocs/op | +| --- | --- | --- | --- | +| encoding/json struct | 748336 | 8272 | 307 | +| encoding/json interface{} | 1224271 | 215425 | 3395 | +| a8m/djson | 510082 | 213682 | 2845 | +| pquerna/ffjson | **312271** | **7792** | **298** | +| mailru/easyjson | **154186** | **6992** | **288** | +| buger/jsonparser | **85308** | **0** | **0** | + +`jsonparser` now is a winner, but do not forget that it is way more lightweight parser than `ffson` or `easyjson`, and they have to parser all the data, while `jsonparser` parse only what you need. All `ffjson`, `easysjon` and `jsonparser` have their own parsing code, and does not depend on `encoding/json` or `interface{}`, thats one of the reasons why they are so fast. `easyjson` also use a bit of `unsafe` package to reduce memory consuption (in theory it can lead to some unexpected GC issue, but i did not tested enough) + +Also last benchmark did not included `EachKey` test, because in this particular case we need to read lot of Array values, and using `ArrayEach` is more efficient. + +## Questions and support + +All bug-reports and suggestions should go though Github Issues. + +## Contributing + +1. Fork it +2. Create your feature branch (git checkout -b my-new-feature) +3. Commit your changes (git commit -am 'Added some feature') +4. Push to the branch (git push origin my-new-feature) +5. Create new Pull Request + +## Development + +All my development happens using Docker, and repo include some Make tasks to simplify development. 
+ +* `make build` - builds docker image, usually can be called only once +* `make test` - run tests +* `make fmt` - run go fmt +* `make bench` - run benchmarks (if you need to run only single benchmark modify `BENCHMARK` variable in make file) +* `make profile` - runs benchmark and generate 3 files- `cpu.out`, `mem.mprof` and `benchmark.test` binary, which can be used for `go tool pprof` +* `make bash` - enter container (i use it for running `go tool pprof` above) diff --git a/vendor/github.com/buger/jsonparser/bytes.go b/vendor/github.com/buger/jsonparser/bytes.go new file mode 100644 index 000000000..0bb0ff395 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/bytes.go @@ -0,0 +1,47 @@ +package jsonparser + +import ( + bio "bytes" +) + +// minInt64 '-9223372036854775808' is the smallest representable number in int64 +const minInt64 = `9223372036854775808` + +// About 2x faster then strconv.ParseInt because it only supports base 10, which is enough for JSON +func parseInt(bytes []byte) (v int64, ok bool, overflow bool) { + if len(bytes) == 0 { + return 0, false, false + } + + var neg bool = false + if bytes[0] == '-' { + neg = true + bytes = bytes[1:] + } + + var b int64 = 0 + for _, c := range bytes { + if c >= '0' && c <= '9' { + b = (10 * v) + int64(c-'0') + } else { + return 0, false, false + } + if overflow = (b < v); overflow { + break + } + v = b + } + + if overflow { + if neg && bio.Equal(bytes, []byte(minInt64)) { + return b, true, false + } + return 0, false, true + } + + if neg { + return -v, true, false + } else { + return v, true, false + } +} diff --git a/vendor/github.com/buger/jsonparser/bytes_safe.go b/vendor/github.com/buger/jsonparser/bytes_safe.go new file mode 100644 index 000000000..ff16a4a19 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/bytes_safe.go @@ -0,0 +1,25 @@ +// +build appengine appenginevm + +package jsonparser + +import ( + "strconv" +) + +// See fastbytes_unsafe.go for explanation on why *[]byte is used (signatures 
must be consistent with those in that file) + +func equalStr(b *[]byte, s string) bool { + return string(*b) == s +} + +func parseFloat(b *[]byte) (float64, error) { + return strconv.ParseFloat(string(*b), 64) +} + +func bytesToString(b *[]byte) string { + return string(*b) +} + +func StringToBytes(s string) []byte { + return []byte(s) +} diff --git a/vendor/github.com/buger/jsonparser/bytes_unsafe.go b/vendor/github.com/buger/jsonparser/bytes_unsafe.go new file mode 100644 index 000000000..589fea87e --- /dev/null +++ b/vendor/github.com/buger/jsonparser/bytes_unsafe.go @@ -0,0 +1,44 @@ +// +build !appengine,!appenginevm + +package jsonparser + +import ( + "reflect" + "strconv" + "unsafe" + "runtime" +) + +// +// The reason for using *[]byte rather than []byte in parameters is an optimization. As of Go 1.6, +// the compiler cannot perfectly inline the function when using a non-pointer slice. That is, +// the non-pointer []byte parameter version is slower than if its function body is manually +// inlined, whereas the pointer []byte version is equally fast to the manually inlined +// version. Instruction count in assembly taken from "go tool compile" confirms this difference. +// +// TODO: Remove hack after Go 1.7 release +// +func equalStr(b *[]byte, s string) bool { + return *(*string)(unsafe.Pointer(b)) == s +} + +func parseFloat(b *[]byte) (float64, error) { + return strconv.ParseFloat(*(*string)(unsafe.Pointer(b)), 64) +} + +// A hack until issue golang/go#2632 is fixed. 
+// See: https://github.com/golang/go/issues/2632 +func bytesToString(b *[]byte) string { + return *(*string)(unsafe.Pointer(b)) +} + +func StringToBytes(s string) []byte { + b := make([]byte, 0, 0) + bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) + bh.Data = sh.Data + bh.Cap = sh.Len + bh.Len = sh.Len + runtime.KeepAlive(s) + return b +} diff --git a/vendor/github.com/buger/jsonparser/escape.go b/vendor/github.com/buger/jsonparser/escape.go new file mode 100644 index 000000000..49669b942 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/escape.go @@ -0,0 +1,173 @@ +package jsonparser + +import ( + "bytes" + "unicode/utf8" +) + +// JSON Unicode stuff: see https://tools.ietf.org/html/rfc7159#section-7 + +const supplementalPlanesOffset = 0x10000 +const highSurrogateOffset = 0xD800 +const lowSurrogateOffset = 0xDC00 + +const basicMultilingualPlaneReservedOffset = 0xDFFF +const basicMultilingualPlaneOffset = 0xFFFF + +func combineUTF16Surrogates(high, low rune) rune { + return supplementalPlanesOffset + (high-highSurrogateOffset)<<10 + (low - lowSurrogateOffset) +} + +const badHex = -1 + +func h2I(c byte) int { + switch { + case c >= '0' && c <= '9': + return int(c - '0') + case c >= 'A' && c <= 'F': + return int(c - 'A' + 10) + case c >= 'a' && c <= 'f': + return int(c - 'a' + 10) + } + return badHex +} + +// decodeSingleUnicodeEscape decodes a single \uXXXX escape sequence. The prefix \u is assumed to be present and +// is not checked. +// In JSON, these escapes can either come alone or as part of "UTF16 surrogate pairs" that must be handled together. +// This function only handles one; decodeUnicodeEscape handles this more complex case. 
+func decodeSingleUnicodeEscape(in []byte) (rune, bool) { + // We need at least 6 characters total + if len(in) < 6 { + return utf8.RuneError, false + } + + // Convert hex to decimal + h1, h2, h3, h4 := h2I(in[2]), h2I(in[3]), h2I(in[4]), h2I(in[5]) + if h1 == badHex || h2 == badHex || h3 == badHex || h4 == badHex { + return utf8.RuneError, false + } + + // Compose the hex digits + return rune(h1<<12 + h2<<8 + h3<<4 + h4), true +} + +// isUTF16EncodedRune checks if a rune is in the range for non-BMP characters, +// which is used to describe UTF16 chars. +// Source: https://en.wikipedia.org/wiki/Plane_(Unicode)#Basic_Multilingual_Plane +func isUTF16EncodedRune(r rune) bool { + return highSurrogateOffset <= r && r <= basicMultilingualPlaneReservedOffset +} + +func decodeUnicodeEscape(in []byte) (rune, int) { + if r, ok := decodeSingleUnicodeEscape(in); !ok { + // Invalid Unicode escape + return utf8.RuneError, -1 + } else if r <= basicMultilingualPlaneOffset && !isUTF16EncodedRune(r) { + // Valid Unicode escape in Basic Multilingual Plane + return r, 6 + } else if r2, ok := decodeSingleUnicodeEscape(in[6:]); !ok { // Note: previous decodeSingleUnicodeEscape success guarantees at least 6 bytes remain + // UTF16 "high surrogate" without manditory valid following Unicode escape for the "low surrogate" + return utf8.RuneError, -1 + } else if r2 < lowSurrogateOffset { + // Invalid UTF16 "low surrogate" + return utf8.RuneError, -1 + } else { + // Valid UTF16 surrogate pair + return combineUTF16Surrogates(r, r2), 12 + } +} + +// backslashCharEscapeTable: when '\X' is found for some byte X, it is to be replaced with backslashCharEscapeTable[X] +var backslashCharEscapeTable = [...]byte{ + '"': '"', + '\\': '\\', + '/': '/', + 'b': '\b', + 'f': '\f', + 'n': '\n', + 'r': '\r', + 't': '\t', +} + +// unescapeToUTF8 unescapes the single escape sequence starting at 'in' into 'out' and returns +// how many characters were consumed from 'in' and emitted into 'out'. 
+// If a valid escape sequence does not appear as a prefix of 'in', (-1, -1) to signal the error. +func unescapeToUTF8(in, out []byte) (inLen int, outLen int) { + if len(in) < 2 || in[0] != '\\' { + // Invalid escape due to insufficient characters for any escape or no initial backslash + return -1, -1 + } + + // https://tools.ietf.org/html/rfc7159#section-7 + switch e := in[1]; e { + case '"', '\\', '/', 'b', 'f', 'n', 'r', 't': + // Valid basic 2-character escapes (use lookup table) + out[0] = backslashCharEscapeTable[e] + return 2, 1 + case 'u': + // Unicode escape + if r, inLen := decodeUnicodeEscape(in); inLen == -1 { + // Invalid Unicode escape + return -1, -1 + } else { + // Valid Unicode escape; re-encode as UTF8 + outLen := utf8.EncodeRune(out, r) + return inLen, outLen + } + } + + return -1, -1 +} + +// unescape unescapes the string contained in 'in' and returns it as a slice. +// If 'in' contains no escaped characters: +// Returns 'in'. +// Else, if 'out' is of sufficient capacity (guaranteed if cap(out) >= len(in)): +// 'out' is used to build the unescaped string and is returned with no extra allocation +// Else: +// A new slice is allocated and returned. 
+func Unescape(in, out []byte) ([]byte, error) { + firstBackslash := bytes.IndexByte(in, '\\') + if firstBackslash == -1 { + return in, nil + } + + // Get a buffer of sufficient size (allocate if needed) + if cap(out) < len(in) { + out = make([]byte, len(in)) + } else { + out = out[0:len(in)] + } + + // Copy the first sequence of unescaped bytes to the output and obtain a buffer pointer (subslice) + copy(out, in[:firstBackslash]) + in = in[firstBackslash:] + buf := out[firstBackslash:] + + for len(in) > 0 { + // Unescape the next escaped character + inLen, bufLen := unescapeToUTF8(in, buf) + if inLen == -1 { + return nil, MalformedStringEscapeError + } + + in = in[inLen:] + buf = buf[bufLen:] + + // Copy everything up until the next backslash + nextBackslash := bytes.IndexByte(in, '\\') + if nextBackslash == -1 { + copy(buf, in) + buf = buf[len(in):] + break + } else { + copy(buf, in[:nextBackslash]) + buf = buf[nextBackslash:] + in = in[nextBackslash:] + } + } + + // Trim the out buffer to the amount that was actually emitted + return out[:len(out)-len(buf)], nil +} diff --git a/vendor/github.com/buger/jsonparser/fuzz.go b/vendor/github.com/buger/jsonparser/fuzz.go new file mode 100644 index 000000000..854bd11b2 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/fuzz.go @@ -0,0 +1,117 @@ +package jsonparser + +func FuzzParseString(data []byte) int { + r, err := ParseString(data) + if err != nil || r == "" { + return 0 + } + return 1 +} + +func FuzzEachKey(data []byte) int { + paths := [][]string{ + {"name"}, + {"order"}, + {"nested", "a"}, + {"nested", "b"}, + {"nested2", "a"}, + {"nested", "nested3", "b"}, + {"arr", "[1]", "b"}, + {"arrInt", "[3]"}, + {"arrInt", "[5]"}, + {"nested"}, + {"arr", "["}, + {"a\n", "b\n"}, + } + EachKey(data, func(idx int, value []byte, vt ValueType, err error) {}, paths...) 
+ return 1 +} + +func FuzzDelete(data []byte) int { + Delete(data, "test") + return 1 +} + +func FuzzSet(data []byte) int { + _, err := Set(data, []byte(`"new value"`), "test") + if err != nil { + return 0 + } + return 1 +} + +func FuzzObjectEach(data []byte) int { + _ = ObjectEach(data, func(key, value []byte, valueType ValueType, off int) error { + return nil + }) + return 1 +} + +func FuzzParseFloat(data []byte) int { + _, err := ParseFloat(data) + if err != nil { + return 0 + } + return 1 +} + +func FuzzParseInt(data []byte) int { + _, err := ParseInt(data) + if err != nil { + return 0 + } + return 1 +} + +func FuzzParseBool(data []byte) int { + _, err := ParseBoolean(data) + if err != nil { + return 0 + } + return 1 +} + +func FuzzTokenStart(data []byte) int { + _ = tokenStart(data) + return 1 +} + +func FuzzGetString(data []byte) int { + _, err := GetString(data, "test") + if err != nil { + return 0 + } + return 1 +} + +func FuzzGetFloat(data []byte) int { + _, err := GetFloat(data, "test") + if err != nil { + return 0 + } + return 1 +} + +func FuzzGetInt(data []byte) int { + _, err := GetInt(data, "test") + if err != nil { + return 0 + } + return 1 +} + +func FuzzGetBoolean(data []byte) int { + _, err := GetBoolean(data, "test") + if err != nil { + return 0 + } + return 1 +} + +func FuzzGetUnsafeString(data []byte) int { + _, err := GetUnsafeString(data, "test") + if err != nil { + return 0 + } + return 1 +} diff --git a/vendor/github.com/buger/jsonparser/oss-fuzz-build.sh b/vendor/github.com/buger/jsonparser/oss-fuzz-build.sh new file mode 100644 index 000000000..c573b0e2d --- /dev/null +++ b/vendor/github.com/buger/jsonparser/oss-fuzz-build.sh @@ -0,0 +1,47 @@ +#!/bin/bash -eu + +git clone https://github.com/dvyukov/go-fuzz-corpus +zip corpus.zip go-fuzz-corpus/json/corpus/* + +cp corpus.zip $OUT/fuzzparsestring_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzParseString fuzzparsestring + +cp corpus.zip $OUT/fuzzeachkey_seed_corpus.zip 
+compile_go_fuzzer github.com/buger/jsonparser FuzzEachKey fuzzeachkey + +cp corpus.zip $OUT/fuzzdelete_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzDelete fuzzdelete + +cp corpus.zip $OUT/fuzzset_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzSet fuzzset + +cp corpus.zip $OUT/fuzzobjecteach_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzObjectEach fuzzobjecteach + +cp corpus.zip $OUT/fuzzparsefloat_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzParseFloat fuzzparsefloat + +cp corpus.zip $OUT/fuzzparseint_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzParseInt fuzzparseint + +cp corpus.zip $OUT/fuzzparsebool_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzParseBool fuzzparsebool + +cp corpus.zip $OUT/fuzztokenstart_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzTokenStart fuzztokenstart + +cp corpus.zip $OUT/fuzzgetstring_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzGetString fuzzgetstring + +cp corpus.zip $OUT/fuzzgetfloat_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzGetFloat fuzzgetfloat + +cp corpus.zip $OUT/fuzzgetint_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzGetInt fuzzgetint + +cp corpus.zip $OUT/fuzzgetboolean_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzGetBoolean fuzzgetboolean + +cp corpus.zip $OUT/fuzzgetunsafestring_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzGetUnsafeString fuzzgetunsafestring + diff --git a/vendor/github.com/buger/jsonparser/parser.go b/vendor/github.com/buger/jsonparser/parser.go new file mode 100644 index 000000000..14b80bc48 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/parser.go @@ -0,0 +1,1283 @@ +package jsonparser + +import ( + "bytes" + "errors" + "fmt" + "strconv" +) + +// Errors +var ( + KeyPathNotFoundError = errors.New("Key path not found") + UnknownValueTypeError 
= errors.New("Unknown value type") + MalformedJsonError = errors.New("Malformed JSON error") + MalformedStringError = errors.New("Value is string, but can't find closing '\"' symbol") + MalformedArrayError = errors.New("Value is array, but can't find closing ']' symbol") + MalformedObjectError = errors.New("Value looks like object, but can't find closing '}' symbol") + MalformedValueError = errors.New("Value looks like Number/Boolean/None, but can't find its end: ',' or '}' symbol") + OverflowIntegerError = errors.New("Value is number, but overflowed while parsing") + MalformedStringEscapeError = errors.New("Encountered an invalid escape sequence in a string") +) + +// How much stack space to allocate for unescaping JSON strings; if a string longer +// than this needs to be escaped, it will result in a heap allocation +const unescapeStackBufSize = 64 + +func tokenEnd(data []byte) int { + for i, c := range data { + switch c { + case ' ', '\n', '\r', '\t', ',', '}', ']': + return i + } + } + + return len(data) +} + +func findTokenStart(data []byte, token byte) int { + for i := len(data) - 1; i >= 0; i-- { + switch data[i] { + case token: + return i + case '[', '{': + return 0 + } + } + + return 0 +} + +func findKeyStart(data []byte, key string) (int, error) { + i := 0 + ln := len(data) + if ln > 0 && (data[0] == '{' || data[0] == '[') { + i = 1 + } + var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings + + if ku, err := Unescape(StringToBytes(key), stackbuf[:]); err == nil { + key = bytesToString(&ku) + } + + for i < ln { + switch data[i] { + case '"': + i++ + keyBegin := i + + strEnd, keyEscaped := stringEnd(data[i:]) + if strEnd == -1 { + break + } + i += strEnd + keyEnd := i - 1 + + valueOffset := nextToken(data[i:]) + if valueOffset == -1 { + break + } + + i += valueOffset + + // if string is a key, and key level match + k := data[keyBegin:keyEnd] + // for unescape: if there are no escape sequences, this 
is cheap; if there are, it is a + // bit more expensive, but causes no allocations unless len(key) > unescapeStackBufSize + if keyEscaped { + if ku, err := Unescape(k, stackbuf[:]); err != nil { + break + } else { + k = ku + } + } + + if data[i] == ':' && len(key) == len(k) && bytesToString(&k) == key { + return keyBegin - 1, nil + } + + case '[': + end := blockEnd(data[i:], data[i], ']') + if end != -1 { + i = i + end + } + case '{': + end := blockEnd(data[i:], data[i], '}') + if end != -1 { + i = i + end + } + } + i++ + } + + return -1, KeyPathNotFoundError +} + +func tokenStart(data []byte) int { + for i := len(data) - 1; i >= 0; i-- { + switch data[i] { + case '\n', '\r', '\t', ',', '{', '[': + return i + } + } + + return 0 +} + +// Find position of next character which is not whitespace +func nextToken(data []byte) int { + for i, c := range data { + switch c { + case ' ', '\n', '\r', '\t': + continue + default: + return i + } + } + + return -1 +} + +// Find position of last character which is not whitespace +func lastToken(data []byte) int { + for i := len(data) - 1; i >= 0; i-- { + switch data[i] { + case ' ', '\n', '\r', '\t': + continue + default: + return i + } + } + + return -1 +} + +// Tries to find the end of string +// Support if string contains escaped quote symbols. +func stringEnd(data []byte) (int, bool) { + escaped := false + for i, c := range data { + if c == '"' { + if !escaped { + return i + 1, false + } else { + j := i - 1 + for { + if j < 0 || data[j] != '\\' { + return i + 1, true // even number of backslashes + } + j-- + if j < 0 || data[j] != '\\' { + break // odd number of backslashes + } + j-- + + } + } + } else if c == '\\' { + escaped = true + } + } + + return -1, escaped +} + +// Find end of the data structure, array or object. 
// For array openSym and closeSym will be '[' and ']', for object '{' and '}'
// Returns the offset just past the matching close symbol, or -1 if unbalanced.
func blockEnd(data []byte, openSym byte, closeSym byte) int {
	level := 0
	i := 0
	ln := len(data)

	for i < ln {
		switch data[i] {
		case '"': // If inside string, skip it
			se, _ := stringEnd(data[i+1:])
			if se == -1 {
				return -1
			}
			i += se
		case openSym: // If open symbol, increase level
			level++
		case closeSym: // If close symbol, decrease level
			level--

			// If we have returned to the original level, we're done
			if level == 0 {
				return i + 1
			}
		}
		i++
	}

	return -1
}

// searchKeys locates the value addressed by the key path and returns the
// offset just past the ':' that follows the final matched key (or just past
// the matched array element), or -1 if the path is not found or the data is
// malformed. Path elements of the form "[N]" index into arrays.
func searchKeys(data []byte, keys ...string) int {
	keyLevel := 0
	level := 0
	i := 0
	ln := len(data)
	lk := len(keys)
	lastMatched := true

	if lk == 0 {
		return 0
	}

	var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings

	for i < ln {
		switch data[i] {
		case '"':
			i++
			keyBegin := i

			strEnd, keyEscaped := stringEnd(data[i:])
			if strEnd == -1 {
				return -1
			}
			i += strEnd
			keyEnd := i - 1

			valueOffset := nextToken(data[i:])
			if valueOffset == -1 {
				return -1
			}

			i += valueOffset

			// if string is a key
			if data[i] == ':' {
				if level < 1 {
					return -1
				}

				key := data[keyBegin:keyEnd]

				// for unescape: if there are no escape sequences, this is cheap; if there are, it is a
				// bit more expensive, but causes no allocations unless len(key) > unescapeStackBufSize
				var keyUnesc []byte
				if !keyEscaped {
					keyUnesc = key
				} else if ku, err := Unescape(key, stackbuf[:]); err != nil {
					return -1
				} else {
					keyUnesc = ku
				}

				if level <= len(keys) {
					if equalStr(&keyUnesc, keys[level-1]) {
						lastMatched = true

						// if key level match
						if keyLevel == level-1 {
							keyLevel++
							// If we found all keys in path
							if keyLevel == lk {
								return i + 1
							}
						}
					} else {
						lastMatched = false
					}
				} else {
					return -1
				}
			} else {
				// Quoted string was a value, not a key; re-examine this byte.
				i--
			}
		case '{':

			// in case parent key is matched then only we will increase the level otherwise can directly
			// can move to the end of this block
			if !lastMatched {
				end := blockEnd(data[i:], '{', '}')
				if end == -1 {
					return -1
				}
				i += end - 1
			} else {
				level++
			}
		case '}':
			level--
			if level == keyLevel {
				keyLevel--
			}
		case '[':
			// If we want to get array element by index
			if keyLevel == level && keys[level][0] == '[' {
				var keyLen = len(keys[level])
				if keyLen < 3 || keys[level][0] != '[' || keys[level][keyLen-1] != ']' {
					return -1
				}
				aIdx, err := strconv.Atoi(keys[level][1 : keyLen-1])
				if err != nil {
					return -1
				}
				var curIdx int
				var valueFound []byte
				var valueOffset int
				var curI = i
				ArrayEach(data[i:], func(value []byte, dataType ValueType, offset int, err error) {
					if curIdx == aIdx {
						valueFound = value
						valueOffset = offset
						if dataType == String {
							// Re-include the surrounding quotes so the
							// recursive search sees the raw slice.
							valueOffset = valueOffset - 2
							valueFound = data[curI+valueOffset : curI+valueOffset+len(value)+2]
						}
					}
					curIdx += 1
				})

				if valueFound == nil {
					return -1
				} else {
					subIndex := searchKeys(valueFound, keys[level+1:]...)
					if subIndex < 0 {
						return -1
					}
					return i + valueOffset + subIndex
				}
			} else {
				// Do not search for keys inside arrays
				if arraySkip := blockEnd(data[i:], '[', ']'); arraySkip == -1 {
					return -1
				} else {
					i += arraySkip - 1
				}
			}
		case ':': // If encountered, JSON data is malformed
			return -1
		}

		i++
	}

	return -1
}

// sameTree reports whether the shorter of p1/p2 is a prefix of the other.
func sameTree(p1, p2 []string) bool {
	minLen := len(p1)
	if len(p2) < minLen {
		minLen = len(p2)
	}

	for pi_1, p_1 := range p1[:minLen] {
		if p2[pi_1] != p_1 {
			return false
		}
	}

	return true
}

// EachKey locates multiple key paths in a single pass over data, invoking cb
// with the path index, value bytes, and value type for each path found.
// Returns the offset at which all paths had been matched, or -1 otherwise.
func EachKey(data []byte, cb func(int, []byte, ValueType, error), paths ...[]string) int {
	var x struct{}
	pathFlags := make([]bool, len(paths))
	var level, pathsMatched, i int
	ln := len(data)

	var maxPath int
	for _, p := range paths {
		if len(p) > maxPath {
			maxPath = len(p)
		}
	}

	pathsBuf := make([]string, maxPath)

	for i < ln {
		switch data[i] {
		case '"':
			i++
			keyBegin := i

			strEnd, keyEscaped := stringEnd(data[i:])
			if strEnd == -1 {
				return -1
			}
			i += strEnd

			keyEnd := i - 1

			valueOffset := nextToken(data[i:])
			if valueOffset == -1 {
				return -1
			}

			i += valueOffset

			// if string is a key, and key level match
			if data[i] == ':' {
				match := -1
				key := data[keyBegin:keyEnd]

				// for unescape: if there are no escape sequences, this is cheap; if there are, it is a
				// bit more expensive, but causes no allocations unless len(key) > unescapeStackBufSize
				var keyUnesc []byte
				if !keyEscaped {
					keyUnesc = key
				} else {
					var stackbuf [unescapeStackBufSize]byte
					if ku, err := Unescape(key, stackbuf[:]); err != nil {
						return -1
					} else {
						keyUnesc = ku
					}
				}

				if maxPath >= level {
					if level < 1 {
						cb(-1, nil, Unknown, MalformedJsonError)
						return -1
					}

					pathsBuf[level-1] = bytesToString(&keyUnesc)
					for pi, p := range paths {
						if len(p) != level || pathFlags[pi] || !equalStr(&keyUnesc, p[level-1]) || !sameTree(p, pathsBuf[:level]) {
							continue
						}

						match = pi

						pathsMatched++
						pathFlags[pi] = true

						v, dt, _, e := Get(data[i+1:])
						cb(pi, v, dt, e)

						if pathsMatched == len(paths) {
							break
						}
					}
					if pathsMatched == len(paths) {
						return i
					}
				}

				if match == -1 {
					// Key matched no path; skip its value entirely if it is
					// an object, so we don't descend into it.
					tokenOffset := nextToken(data[i+1:])
					i += tokenOffset

					if data[i] == '{' {
						blockSkip := blockEnd(data[i:], '{', '}')
						i += blockSkip + 1
					}
				}

				if i < ln {
					switch data[i] {
					case '{', '}', '[', '"':
						i--
					}
				}
			} else {
				i--
			}
		case '{':
			level++
		case '}':
			level--
		case '[':
			var ok bool
			arrIdxFlags := make(map[int]struct{})
			pIdxFlags := make([]bool, len(paths))

			if level < 0 {
				cb(-1, nil, Unknown, MalformedJsonError)
				return -1
			}

			// Collect the array indexes ("[N]") requested at this level.
			for pi, p := range paths {
				if len(p) < level+1 || pathFlags[pi] || p[level][0] != '[' || !sameTree(p, pathsBuf[:level]) {
					continue
				}
				if len(p[level]) >= 2 {
					aIdx, _ := strconv.Atoi(p[level][1 : len(p[level])-1])
					arrIdxFlags[aIdx] = x
					pIdxFlags[pi] = true
				}
			}

			if len(arrIdxFlags) > 0 {
				level++

				var curIdx int
				arrOff, _ := ArrayEach(data[i:], func(value []byte, dataType ValueType, offset int, err error) {
					if _, ok = arrIdxFlags[curIdx]; ok {
						for pi, p := range paths {
							if pIdxFlags[pi] {
								aIdx, _ := strconv.Atoi(p[level-1][1 : len(p[level-1])-1])

								if curIdx == aIdx {
									of := searchKeys(value, p[level:]...)

									pathsMatched++
									pathFlags[pi] = true

									if of != -1 {
										v, dt, _, e := Get(value[of:])
										cb(pi, v, dt, e)
									}
								}
							}
						}
					}

					curIdx += 1
				})

				if pathsMatched == len(paths) {
					return i
				}

				i += arrOff - 1
			} else {
				// Do not search for keys inside arrays
				if arraySkip := blockEnd(data[i:], '[', ']'); arraySkip == -1 {
					return -1
				} else {
					i += arraySkip - 1
				}
			}
		case ']':
			level--
		}

		i++
	}

	return -1
}

// Data types available in valid JSON data.
// ValueType identifies the JSON type of a parsed value.
type ValueType int

const (
	NotExist = ValueType(iota)
	String
	Number
	Object
	Array
	Boolean
	Null
	Unknown
)

// String implements fmt.Stringer for ValueType.
func (vt ValueType) String() string {
	switch vt {
	case NotExist:
		return "non-existent"
	case String:
		return "string"
	case Number:
		return "number"
	case Object:
		return "object"
	case Array:
		return "array"
	case Boolean:
		return "boolean"
	case Null:
		return "null"
	default:
		return "unknown"
	}
}

var (
	trueLiteral  = []byte("true")
	falseLiteral = []byte("false")
	nullLiteral  = []byte("null")
)

// createInsertComponent builds the JSON fragment needed to insert setValue at
// the (missing) path keys, e.g. `{"a":{"b":value}}`. comma and object control
// the surrounding punctuation; the buffer size is computed up front by
// calcAllocateSpace so no reallocation happens while writing.
func createInsertComponent(keys []string, setValue []byte, comma, object bool) []byte {
	isIndex := string(keys[0][0]) == "["
	offset := 0
	lk := calcAllocateSpace(keys, setValue, comma, object)
	buffer := make([]byte, lk, lk)
	if comma {
		offset += WriteToBuffer(buffer[offset:], ",")
	}
	if isIndex && !comma {
		offset += WriteToBuffer(buffer[offset:], "[")
	} else {
		if object {
			offset += WriteToBuffer(buffer[offset:], "{")
		}
		if !isIndex {
			offset += WriteToBuffer(buffer[offset:], "\"")
			offset += WriteToBuffer(buffer[offset:], keys[0])
			offset += WriteToBuffer(buffer[offset:], "\":")
		}
	}

	for i := 1; i < len(keys); i++ {
		if string(keys[i][0]) == "[" {
			offset += WriteToBuffer(buffer[offset:], "[")
		} else {
			offset += WriteToBuffer(buffer[offset:], "{\"")
			offset += WriteToBuffer(buffer[offset:], keys[i])
			offset += WriteToBuffer(buffer[offset:], "\":")
		}
	}
	offset += WriteToBuffer(buffer[offset:], string(setValue))
	for i := len(keys) - 1; i > 0; i-- {
		if string(keys[i][0]) == "[" {
			offset += WriteToBuffer(buffer[offset:], "]")
		} else {
			offset += WriteToBuffer(buffer[offset:], "}")
		}
	}
	if isIndex && !comma {
		offset += WriteToBuffer(buffer[offset:], "]")
	}
	if object && !isIndex {
		offset += WriteToBuffer(buffer[offset:], "}")
	}
	return buffer
}

// calcAllocateSpace returns the number of bytes createInsertComponent will
// write for the same arguments, mirroring its branch structure.
func calcAllocateSpace(keys []string, setValue []byte, comma, object bool) int {
	isIndex := string(keys[0][0]) == "["
	lk := 0
	if comma {
		// ,
		lk += 1
	}
	if isIndex && !comma {
		// []
		lk += 2
	} else {
		if object {
			// {
			lk += 1
		}
		if !isIndex {
			// "keys[0]"
			lk += len(keys[0]) + 3
		}
	}

	lk += len(setValue)
	for i := 1; i < len(keys); i++ {
		if string(keys[i][0]) == "[" {
			// []
			lk += 2
		} else {
			// {"keys[i]":setValue}
			lk += len(keys[i]) + 5
		}
	}

	if object && !isIndex {
		// }
		lk += 1
	}

	return lk
}

// WriteToBuffer copies str into buffer and returns len(str).
// NOTE(review): it returns len(str) even if buffer is shorter and the copy
// was truncated — callers must size the buffer correctly in advance.
func WriteToBuffer(buffer []byte, str string) int {
	copy(buffer, str)
	return len(str)
}

/*

Delete - Receives existing data structure, path to delete.

Returns:
`data` - return modified data

*/
func Delete(data []byte, keys ...string) []byte {
	lk := len(keys)
	if lk == 0 {
		return data[:0]
	}

	array := false
	if len(keys[lk-1]) > 0 && string(keys[lk-1][0]) == "[" {
		array = true
	}

	var startOffset, keyOffset int
	endOffset := len(data)
	var err error
	if !array {
		if len(keys) > 1 {
			_, _, startOffset, endOffset, err = internalGet(data, keys[:lk-1]...)
			if err == KeyPathNotFoundError {
				// problem parsing the data
				return data
			}
		}

		keyOffset, err = findKeyStart(data[startOffset:endOffset], keys[lk-1])
		if err == KeyPathNotFoundError {
			// problem parsing the data
			return data
		}
		keyOffset += startOffset
		_, _, _, subEndOffset, _ := internalGet(data[startOffset:endOffset], keys[lk-1])
		endOffset = startOffset + subEndOffset
		tokEnd := tokenEnd(data[endOffset:])
		tokStart := findTokenStart(data[:keyOffset], ","[0])

		if data[endOffset+tokEnd] == ","[0] {
			endOffset += tokEnd + 1
		} else if data[endOffset+tokEnd] == " "[0] && len(data) > endOffset+tokEnd+1 && data[endOffset+tokEnd+1] == ","[0] {
			endOffset += tokEnd + 2
		} else if data[endOffset+tokEnd] == "}"[0] && data[tokStart] == ","[0] {
			keyOffset = tokStart
		}
	} else {
		_, _, keyOffset, endOffset, err = internalGet(data, keys...)
		if err == KeyPathNotFoundError {
			// problem parsing the data
			return data
		}

		tokEnd := tokenEnd(data[endOffset:])
		tokStart := findTokenStart(data[:keyOffset], ","[0])

		if data[endOffset+tokEnd] == ","[0] {
			endOffset += tokEnd + 1
		} else if data[endOffset+tokEnd] == "]"[0] && data[tokStart] == ","[0] {
			keyOffset = tokStart
		}
	}

	// We need to remove remaining trailing comma if we delete the last element in the object
	prevTok := lastToken(data[:keyOffset])
	remainedValue := data[endOffset:]

	var newOffset int
	if nextToken(remainedValue) > -1 && remainedValue[nextToken(remainedValue)] == '}' && data[prevTok] == ',' {
		newOffset = prevTok
	} else {
		newOffset = prevTok + 1
	}

	// We have to make a copy here if we don't want to mangle the original data, because byte slices are
	// accessed by reference and not by value
	dataCopy := make([]byte, len(data))
	copy(dataCopy, data)
	data = append(dataCopy[:newOffset], dataCopy[endOffset:]...)

	return data
}

/*

Set - Receives existing data structure, path to set, and data to set at that key.

Returns:
`value` - modified byte array
`err` - On any parsing error

*/
func Set(data []byte, setValue []byte, keys ...string) (value []byte, err error) {
	// ensure keys are set
	if len(keys) == 0 {
		return nil, KeyPathNotFoundError
	}

	_, _, startOffset, endOffset, err := internalGet(data, keys...)
	if err != nil {
		if err != KeyPathNotFoundError {
			// problem parsing the data
			return nil, err
		}
		// full path doesn't exist
		// does any subpath exist?
		var depth int
		for i := range keys {
			_, _, start, end, sErr := internalGet(data, keys[:i+1]...)
			if sErr != nil {
				break
			} else {
				endOffset = end
				startOffset = start
				depth++
			}
		}
		comma := true
		object := false
		if endOffset == -1 {
			firstToken := nextToken(data)
			// We can't set a top-level key if data isn't an object
			if firstToken < 0 || data[firstToken] != '{' {
				return nil, KeyPathNotFoundError
			}
			// Don't need a comma if the input is an empty object
			secondToken := firstToken + 1 + nextToken(data[firstToken+1:])
			if data[secondToken] == '}' {
				comma = false
			}
			// Set the top level key at the end (accounting for any trailing whitespace)
			// This assumes last token is valid like '}', could check and return error
			endOffset = lastToken(data)
		}
		depthOffset := endOffset
		if depth != 0 {
			// if subpath is a non-empty object, add to it
			// or if subpath is a non-empty array, add to it
			// NOTE(review): '&&' binds tighter than '||', so the
			// keys[depth:][0][0] == 91 guard applies only to the array
			// branch — confirm this matches upstream intent. (91 == '[')
			if (data[startOffset] == '{' && data[startOffset+1+nextToken(data[startOffset+1:])] != '}') ||
				(data[startOffset] == '[' && data[startOffset+1+nextToken(data[startOffset+1:])] == '{') && keys[depth:][0][0] == 91 {
				depthOffset--
				startOffset = depthOffset
				// otherwise, over-write it with a new object
			} else {
				comma = false
				object = true
			}
		} else {
			startOffset = depthOffset
		}
		value = append(data[:startOffset], append(createInsertComponent(keys[depth:], setValue, comma, object), data[depthOffset:]...)...)
	} else {
		// path currently exists
		startComponent := data[:startOffset]
		endComponent := data[endOffset:]

		value = make([]byte, len(startComponent)+len(endComponent)+len(setValue))
		newEndOffset := startOffset + len(setValue)
		copy(value[0:startOffset], startComponent)
		copy(value[startOffset:newEndOffset], setValue)
		copy(value[newEndOffset:], endComponent)
	}
	return value, nil
}

// getType determines the JSON type of the value starting at offset and
// returns the raw value bytes (including quotes for strings), its type, and
// the offset just past its end.
func getType(data []byte, offset int) ([]byte, ValueType, int, error) {
	var dataType ValueType
	endOffset := offset

	// if string value
	if data[offset] == '"' {
		dataType = String
		if idx, _ := stringEnd(data[offset+1:]); idx != -1 {
			endOffset += idx + 1
		} else {
			return nil, dataType, offset, MalformedStringError
		}
	} else if data[offset] == '[' { // if array value
		dataType = Array
		// blockEnd skips to the matching close bracket
		endOffset = blockEnd(data[offset:], '[', ']')

		if endOffset == -1 {
			return nil, dataType, offset, MalformedArrayError
		}

		endOffset += offset
	} else if data[offset] == '{' { // if object value
		dataType = Object
		// blockEnd skips to the matching close brace
		endOffset = blockEnd(data[offset:], '{', '}')

		if endOffset == -1 {
			return nil, dataType, offset, MalformedObjectError
		}

		endOffset += offset
	} else {
		// Number, Boolean or None
		end := tokenEnd(data[endOffset:])

		if end == -1 {
			return nil, dataType, offset, MalformedValueError
		}

		value := data[offset : endOffset+end]

		switch data[offset] {
		case 't', 'f': // true or false
			if bytes.Equal(value, trueLiteral) || bytes.Equal(value, falseLiteral) {
				dataType = Boolean
			} else {
				return nil, Unknown, offset, UnknownValueTypeError
			}
		case 'u', 'n': // undefined or null
			if bytes.Equal(value, nullLiteral) {
				dataType = Null
			} else {
				return nil, Unknown, offset, UnknownValueTypeError
			}
		case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-':
			dataType = Number
		default:
			return nil, Unknown, offset, UnknownValueTypeError
		}

		endOffset += end
	}
	return data[offset:endOffset], dataType, endOffset, nil
}

/*
Get - Receives data structure, and key path to extract value from.

Returns:
`value` - Pointer to original data structure containing key value, or just empty slice if nothing found or error
`dataType` - Can be: `NotExist`, `String`, `Number`, `Object`, `Array`, `Boolean` or `Null`
`offset` - Offset from provided data structure where key value ends. Used mostly internally, for example for `ArrayEach` helper.
`err` - If key not found or any other parsing issue it should return error. If key not found it also sets `dataType` to `NotExist`

Accept multiple keys to specify path to JSON value (in case of querying nested structures).
If no keys provided it will try to extract closest JSON value (simple ones or object/array), useful for reading streams or arrays, see `ArrayEach` implementation.
*/
func Get(data []byte, keys ...string) (value []byte, dataType ValueType, offset int, err error) {
	a, b, _, d, e := internalGet(data, keys...)
	return a, b, d, e
}

// internalGet is the shared implementation behind Get/Delete/Set; unlike Get
// it also returns the offset where the value starts, not just where it ends.
func internalGet(data []byte, keys ...string) (value []byte, dataType ValueType, offset, endOffset int, err error) {
	if len(keys) > 0 {
		if offset = searchKeys(data, keys...); offset == -1 {
			return nil, NotExist, -1, -1, KeyPathNotFoundError
		}
	}

	// Go to closest value
	nO := nextToken(data[offset:])
	if nO == -1 {
		return nil, NotExist, offset, -1, MalformedJsonError
	}

	offset += nO
	value, dataType, endOffset, err = getType(data, offset)
	if err != nil {
		return value, dataType, offset, endOffset, err
	}

	// Strip quotes from string values
	if dataType == String {
		value = value[1 : len(value)-1]
	}

	// Full-slice expression caps capacity so callers appending to the result
	// cannot clobber the original data.
	return value[:len(value):len(value)], dataType, offset, endOffset, nil
}

// ArrayEach is used when iterating arrays, accepts a callback function with the same return arguments as `Get`.
// ArrayEach iterates over the elements of a JSON array (optionally located by
// a key path), invoking cb with each element's bytes, type, and offset.
// Returns the offset at which iteration stopped, and an error on malformed input.
func ArrayEach(data []byte, cb func(value []byte, dataType ValueType, offset int, err error), keys ...string) (offset int, err error) {
	if len(data) == 0 {
		return -1, MalformedObjectError
	}

	nT := nextToken(data)
	if nT == -1 {
		return -1, MalformedJsonError
	}

	// Without keys, data is assumed to start at the array itself.
	offset = nT + 1

	if len(keys) > 0 {
		if offset = searchKeys(data, keys...); offset == -1 {
			return offset, KeyPathNotFoundError
		}

		// Go to closest value
		nO := nextToken(data[offset:])
		if nO == -1 {
			return offset, MalformedJsonError
		}

		offset += nO

		if data[offset] != '[' {
			return offset, MalformedArrayError
		}

		offset++
	}

	nO := nextToken(data[offset:])
	if nO == -1 {
		return offset, MalformedJsonError
	}

	offset += nO

	// Empty array: nothing to iterate.
	if data[offset] == ']' {
		return offset, nil
	}

	for true {
		v, t, o, e := Get(data[offset:])

		if e != nil {
			return offset, e
		}

		if o == 0 {
			break
		}

		if t != NotExist {
			// offset+o-len(v) is where the element's bytes begin.
			cb(v, t, offset+o-len(v), e)
		}

		if e != nil {
			break
		}

		offset += o

		skipToToken := nextToken(data[offset:])
		if skipToToken == -1 {
			return offset, MalformedArrayError
		}
		offset += skipToToken

		if data[offset] == ']' {
			break
		}

		if data[offset] != ',' {
			return offset, MalformedArrayError
		}

		offset++
	}

	return offset, nil
}

// ObjectEach iterates over the key-value pairs of a JSON object, invoking a given callback for each such entry
func ObjectEach(data []byte, callback func(key []byte, value []byte, dataType ValueType, offset int) error, keys ...string) (err error) {
	offset := 0

	// Descend to the desired key, if requested
	if len(keys) > 0 {
		if off := searchKeys(data, keys...); off == -1 {
			return KeyPathNotFoundError
		} else {
			offset = off
		}
	}

	// Validate and skip past opening brace
	if off := nextToken(data[offset:]); off == -1 {
		return MalformedObjectError
	} else if offset += off; data[offset] != '{' {
		return MalformedObjectError
	} else {
		offset++
	}

	// Skip to the first token inside the object, or stop if we find the ending brace
	if off := nextToken(data[offset:]); off == -1 {
		return MalformedJsonError
	} else if offset += off; data[offset] == '}' {
		return nil
	}

	// Loop pre-condition: data[offset] points to what should be either the next entry's key, or the closing brace (if it's anything else, the JSON is malformed)
	for offset < len(data) {
		// Step 1: find the next key
		var key []byte

		// Check what the next token is: start of string, end of object, or something else (error)
		switch data[offset] {
		case '"':
			offset++ // accept as string and skip opening quote
		case '}':
			return nil // we found the end of the object; stop and return success
		default:
			return MalformedObjectError
		}

		// Find the end of the key string
		var keyEscaped bool
		if off, esc := stringEnd(data[offset:]); off == -1 {
			return MalformedJsonError
		} else {
			key, keyEscaped = data[offset:offset+off-1], esc
			offset += off
		}

		// Unescape the string if needed
		if keyEscaped {
			var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings
			if keyUnescaped, err := Unescape(key, stackbuf[:]); err != nil {
				return MalformedStringEscapeError
			} else {
				key = keyUnescaped
			}
		}

		// Step 2: skip the colon
		if off := nextToken(data[offset:]); off == -1 {
			return MalformedJsonError
		} else if offset += off; data[offset] != ':' {
			return MalformedJsonError
		} else {
			offset++
		}

		// Step 3: find the associated value, then invoke the callback
		if value, valueType, off, err := Get(data[offset:]); err != nil {
			return err
		} else if err := callback(key, value, valueType, offset+off); err != nil { // Invoke the callback here!
			return err
		} else {
			offset += off
		}

		// Step 4: skip over the next comma to the following token, or stop if we hit the ending brace
		if off := nextToken(data[offset:]); off == -1 {
			return MalformedArrayError
		} else {
			offset += off
			switch data[offset] {
			case '}':
				return nil // Stop if we hit the close brace
			case ',':
				offset++ // Ignore the comma
			default:
				return MalformedObjectError
			}
		}

		// Skip to the next token after the comma
		if off := nextToken(data[offset:]); off == -1 {
			return MalformedArrayError
		} else {
			offset += off
		}
	}

	return MalformedObjectError // we shouldn't get here; it's expected that we will return via finding the ending brace
}

// GetUnsafeString returns the value retrieved by `Get`, use creates string without memory allocation by mapping string to slice memory. It does not handle escape symbols.
func GetUnsafeString(data []byte, keys ...string) (val string, err error) {
	v, _, _, e := Get(data, keys...)

	if e != nil {
		return "", e
	}

	return bytesToString(&v), nil
}

// GetString returns the value retrieved by `Get`, cast to a string if possible, trying to properly handle escape and utf8 symbols
// If key data type do not match, it will return an error.
func GetString(data []byte, keys ...string) (val string, err error) {
	v, t, _, e := Get(data, keys...)

	if e != nil {
		return "", e
	}

	if t != String {
		return "", fmt.Errorf("Value is not a string: %s", string(v))
	}

	// If no escapes return raw content
	if bytes.IndexByte(v, '\\') == -1 {
		return string(v), nil
	}

	return ParseString(v)
}

// GetFloat returns the value retrieved by `Get`, cast to a float64 if possible.
// The offset is the same as in `Get`.
// If key data type do not match, it will return an error.
func GetFloat(data []byte, keys ...string) (val float64, err error) {
	v, t, _, e := Get(data, keys...)

	if e != nil {
		return 0, e
	}

	if t != Number {
		return 0, fmt.Errorf("Value is not a number: %s", string(v))
	}

	return ParseFloat(v)
}

// GetInt returns the value retrieved by `Get`, cast to a int64 if possible.
// If key data type do not match, it will return an error.
func GetInt(data []byte, keys ...string) (val int64, err error) {
	v, t, _, e := Get(data, keys...)

	if e != nil {
		return 0, e
	}

	if t != Number {
		return 0, fmt.Errorf("Value is not a number: %s", string(v))
	}

	return ParseInt(v)
}

// GetBoolean returns the value retrieved by `Get`, cast to a bool if possible.
// The offset is the same as in `Get`.
// If key data type do not match, it will return error.
func GetBoolean(data []byte, keys ...string) (val bool, err error) {
	v, t, _, e := Get(data, keys...)

	if e != nil {
		return false, e
	}

	if t != Boolean {
		return false, fmt.Errorf("Value is not a boolean: %s", string(v))
	}

	return ParseBoolean(v)
}

// ParseBoolean parses a Boolean ValueType into a Go bool (not particularly useful, but here for completeness)
func ParseBoolean(b []byte) (bool, error) {
	switch {
	case bytes.Equal(b, trueLiteral):
		return true, nil
	case bytes.Equal(b, falseLiteral):
		return false, nil
	default:
		return false, MalformedValueError
	}
}

// ParseString parses a String ValueType into a Go string (the main parsing work is unescaping the JSON string)
func ParseString(b []byte) (string, error) {
	var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings
	if bU, err := Unescape(b, stackbuf[:]); err != nil {
		return "", MalformedValueError
	} else {
		return string(bU), nil
	}
}

// ParseFloat parses a Number ValueType into a Go float64
func ParseFloat(b []byte) (float64, error) {
	if v, err := parseFloat(&b); err != nil {
		return 0, MalformedValueError
	} else {
		return v, nil
	}
}

// ParseInt parses a Number ValueType into
a Go int64 +func ParseInt(b []byte) (int64, error) { + if v, ok, overflow := parseInt(b); !ok { + if overflow { + return 0, OverflowIntegerError + } + return 0, MalformedValueError + } else { + return v, nil + } +} diff --git a/vendor/github.com/invopop/jsonschema/.gitignore b/vendor/github.com/invopop/jsonschema/.gitignore new file mode 100644 index 000000000..8ef0e14fc --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/.gitignore @@ -0,0 +1,2 @@ +vendor/ +.idea/ diff --git a/vendor/github.com/invopop/jsonschema/.golangci.yml b/vendor/github.com/invopop/jsonschema/.golangci.yml new file mode 100644 index 000000000..b89b2e124 --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/.golangci.yml @@ -0,0 +1,69 @@ +run: + tests: true + max-same-issues: 50 + +output: + print-issued-lines: false + +linters: + enable: + - gocyclo + - gocritic + - goconst + - dupl + - unconvert + - goimports + - unused + - govet + - nakedret + - errcheck + - revive + - ineffassign + - goconst + - unparam + - gofmt + +linters-settings: + vet: + check-shadowing: true + use-installed-packages: true + dupl: + threshold: 100 + goconst: + min-len: 8 + min-occurrences: 3 + gocyclo: + min-complexity: 20 + gocritic: + disabled-checks: + - ifElseChain + gofmt: + rewrite-rules: + - pattern: "interface{}" + replacement: "any" + - pattern: "a[b:len(a)]" + replacement: "a[b:]" + +issues: + max-per-linter: 0 + max-same: 0 + exclude-dirs: + - resources + - old + exclude-files: + - cmd/protopkg/main.go + exclude-use-default: false + exclude: + # Captured by errcheck. + - "^(G104|G204):" + # Very commonly not checked. + - 'Error return value of .(.*\.Help|.*\.MarkFlagRequired|(os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*Print(f|ln|)|os\.(Un)?Setenv). is not checked' + # Weird error only seen on Kochiku... 
+ - "internal error: no range for" + - 'exported method `.*\.(MarshalJSON|UnmarshalJSON|URN|Payload|GoString|Close|Provides|Requires|ExcludeFromHash|MarshalText|UnmarshalText|Description|Check|Poll|Severity)` should have comment or be unexported' + - "composite literal uses unkeyed fields" + - 'declaration of "err" shadows declaration' + - "by other packages, and that stutters" + - "Potential file inclusion via variable" + - "at least one file in a package should have a package comment" + - "bad syntax for struct tag pair" diff --git a/vendor/github.com/invopop/jsonschema/COPYING b/vendor/github.com/invopop/jsonschema/COPYING new file mode 100644 index 000000000..2993ec085 --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/COPYING @@ -0,0 +1,19 @@ +Copyright (C) 2014 Alec Thomas + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/vendor/github.com/invopop/jsonschema/README.md b/vendor/github.com/invopop/jsonschema/README.md new file mode 100644 index 000000000..27b362e1d --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/README.md @@ -0,0 +1,374 @@ +# Go JSON Schema Reflection + +[![Lint](https://github.com/invopop/jsonschema/actions/workflows/lint.yaml/badge.svg)](https://github.com/invopop/jsonschema/actions/workflows/lint.yaml) +[![Test Go](https://github.com/invopop/jsonschema/actions/workflows/test.yaml/badge.svg)](https://github.com/invopop/jsonschema/actions/workflows/test.yaml) +[![Go Report Card](https://goreportcard.com/badge/github.com/invopop/jsonschema)](https://goreportcard.com/report/github.com/invopop/jsonschema) +[![GoDoc](https://godoc.org/github.com/invopop/jsonschema?status.svg)](https://godoc.org/github.com/invopop/jsonschema) +[![codecov](https://codecov.io/gh/invopop/jsonschema/graph/badge.svg?token=JMEB8W8GNZ)](https://codecov.io/gh/invopop/jsonschema) +![Latest Tag](https://img.shields.io/github/v/tag/invopop/jsonschema) + +This package can be used to generate [JSON Schemas](http://json-schema.org/latest/json-schema-validation.html) from Go types through reflection. + +- Supports arbitrarily complex types, including `interface{}`, maps, slices, etc. +- Supports json-schema features such as minLength, maxLength, pattern, format, etc. +- Supports simple string and numeric enums. +- Supports custom property fields via the `jsonschema_extras` struct tag. + +This repository is a fork of the original [jsonschema](https://github.com/alecthomas/jsonschema) by [@alecthomas](https://github.com/alecthomas). At [Invopop](https://invopop.com) we use jsonschema as a cornerstone in our [GOBL library](https://github.com/invopop/gobl), and wanted to be able to continue building and adding features without taking up Alec's time. 
There have been a few significant changes that probably mean this version is not compatible with Alec's: + +- The original was stuck on the draft-04 version of JSON Schema, we've now moved to the latest JSON Schema Draft 2020-12. +- Schema IDs are added automatically from the current Go package's URL in order to be unique, and can be disabled with the `Anonymous` option. +- Support for the `FullyQualifyTypeName` option has been removed. If you have conflicts, you should use multiple schema files with different IDs, set the `DoNotReference` option to true to hide definitions completely, or add your own naming strategy using the `Namer` property. +- Support for `yaml` tags and related options has been dropped for the sake of simplification. There were a [few inconsistencies](https://github.com/invopop/jsonschema/pull/21) around this that have now been fixed. + +## Versions + +This project is still under v0 scheme, as per Go convention, breaking changes are likely. Please pin go modules to version tags or branches, and reach out if you think something can be improved. + +Go version >= 1.18 is required as generics are now being used.
+ +## Example + +The following Go type: + +```go +type TestUser struct { + ID int `json:"id"` + Name string `json:"name" jsonschema:"title=the name,description=The name of a friend,example=joe,example=lucy,default=alex"` + Friends []int `json:"friends,omitempty" jsonschema_description:"The list of IDs, omitted when empty"` + Tags map[string]interface{} `json:"tags,omitempty" jsonschema_extras:"a=b,foo=bar,foo=bar1"` + BirthDate time.Time `json:"birth_date,omitempty" jsonschema:"oneof_required=date"` + YearOfBirth string `json:"year_of_birth,omitempty" jsonschema:"oneof_required=year"` + Metadata interface{} `json:"metadata,omitempty" jsonschema:"oneof_type=string;array"` + FavColor string `json:"fav_color,omitempty" jsonschema:"enum=red,enum=green,enum=blue"` +} +``` + +Results in following JSON Schema: + +```go +jsonschema.Reflect(&TestUser{}) +``` + +```json +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://github.com/invopop/jsonschema_test/test-user", + "$ref": "#/$defs/TestUser", + "$defs": { + "TestUser": { + "oneOf": [ + { + "required": ["birth_date"], + "title": "date" + }, + { + "required": ["year_of_birth"], + "title": "year" + } + ], + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string", + "title": "the name", + "description": "The name of a friend", + "default": "alex", + "examples": ["joe", "lucy"] + }, + "friends": { + "items": { + "type": "integer" + }, + "type": "array", + "description": "The list of IDs, omitted when empty" + }, + "tags": { + "type": "object", + "a": "b", + "foo": ["bar", "bar1"] + }, + "birth_date": { + "type": "string", + "format": "date-time" + }, + "year_of_birth": { + "type": "string" + }, + "metadata": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array" + } + ] + }, + "fav_color": { + "type": "string", + "enum": ["red", "green", "blue"] + } + }, + "additionalProperties": false, + "type": "object", + "required": ["id", "name"] + } + } +} +``` + +## YAML 
+ +Support for `yaml` tags has now been removed. If you feel very strongly about this, we've opened a discussion to hear your comments: https://github.com/invopop/jsonschema/discussions/28 + +The recommended approach if you need to deal with YAML data is to first convert to JSON. The [invopop/yaml](https://github.com/invopop/yaml) library will make this trivial. + +## Configurable behaviour + +The behaviour of the schema generator can be altered with parameters when a `jsonschema.Reflector` +instance is created. + +### ExpandedStruct + +If set to `true`, makes the top level struct not to reference itself in the definitions. But type passed should be a struct type. + +eg. + +```go +type GrandfatherType struct { + FamilyName string `json:"family_name" jsonschema:"required"` +} + +type SomeBaseType struct { + SomeBaseProperty int `json:"some_base_property"` + // The jsonschema required tag is nonsensical for private and ignored properties. + // Their presence here tests that the fields *will not* be required in the output + // schema, even if they are tagged required. 
+ somePrivateBaseProperty string `json:"i_am_private" jsonschema:"required"` + SomeIgnoredBaseProperty string `json:"-" jsonschema:"required"` + SomeSchemaIgnoredProperty string `jsonschema:"-,required"` + SomeUntaggedBaseProperty bool `jsonschema:"required"` + someUnexportedUntaggedBaseProperty bool + Grandfather GrandfatherType `json:"grand"` +} +``` + +will output: + +```json +{ + "$schema": "http://json-schema.org/draft/2020-12/schema", + "required": ["some_base_property", "grand", "SomeUntaggedBaseProperty"], + "properties": { + "SomeUntaggedBaseProperty": { + "type": "boolean" + }, + "grand": { + "$schema": "http://json-schema.org/draft/2020-12/schema", + "$ref": "#/definitions/GrandfatherType" + }, + "some_base_property": { + "type": "integer" + } + }, + "type": "object", + "$defs": { + "GrandfatherType": { + "required": ["family_name"], + "properties": { + "family_name": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + } + } +} +``` + +### Using Go Comments + +Writing a good schema with descriptions inside tags can become cumbersome and tedious, especially if you already have some Go comments around your types and field definitions. If you'd like to take advantage of these existing comments, you can use the `AddGoComments(base, path string)` method that forms part of the reflector to parse your go files and automatically generate a dictionary of Go import paths, types, and fields, to individual comments. These will then be used automatically as description fields, and can be overridden with a manual definition if needed. + +Take a simplified example of a User struct which for the sake of simplicity we assume is defined inside this package: + +```go +package main + +// User is used as a base to provide tests for comments. +type User struct { + // Unique sequential identifier. 
+ ID int `json:"id" jsonschema:"required"` + // Name of the user + Name string `json:"name"` +} +``` + +To get the comments provided into your JSON schema, use a regular `Reflector` and add the go code using an import module URL and path. Fully qualified go module paths cannot be determined reliably by the `go/parser` library, so we need to introduce this manually: + +```go +r := new(Reflector) +if err := r.AddGoComments("github.com/invopop/jsonschema", "./"); err != nil { + // deal with error +} +s := r.Reflect(&User{}) +// output +``` + +Expect the results to be similar to: + +```json +{ + "$schema": "http://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/User", + "$defs": { + "User": { + "required": ["id"], + "properties": { + "id": { + "type": "integer", + "description": "Unique sequential identifier." + }, + "name": { + "type": "string", + "description": "Name of the user" + } + }, + "additionalProperties": false, + "type": "object", + "description": "User is used as a base to provide tests for comments." + } + } +} +``` + +### Custom Key Naming + +In some situations, the keys actually used to write files are different from Go structs'. + +This is often the case when writing a configuration file to YAML or JSON from a Go struct, or when returning a JSON response for a Web API: APIs typically use snake_case, while Go uses PascalCase. + +You can pass a `func(string) string` function to `Reflector`'s `KeyNamer` option to map Go field names to JSON key names and reflect the aforementioned transformations, without having to specify `json:"..."` on every struct field. 
+ +For example, consider the following struct + +```go +type User struct { + GivenName string + PasswordSalted []byte `json:"salted_password"` +} +``` + +We can transform field names to snake_case in the generated JSON schema: + +```go +r := new(jsonschema.Reflector) +r.KeyNamer = strcase.SnakeCase // from package github.com/stoewer/go-strcase + +r.Reflect(&User{}) +``` + +Will yield + +```diff + { + "$schema": "http://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/User", + "$defs": { + "User": { + "properties": { +- "GivenName": { ++ "given_name": { + "type": "string" + }, + "salted_password": { + "type": "string", + "contentEncoding": "base64" + } + }, + "additionalProperties": false, + "type": "object", +- "required": ["GivenName", "salted_password"] ++ "required": ["given_name", "salted_password"] + } + } + } +``` + +As you can see, if a field name has a `json:""` tag set, the `key` argument to `KeyNamer` will have the value of that tag. + +### Custom Type Definitions + +Sometimes it can be useful to have custom JSON Marshal and Unmarshal methods in your structs that automatically convert for example a string into an object. + +This library will recognize and attempt to call four different methods that help you adjust schemas to your specific needs: + +- `JSONSchema() *Schema` - will prevent auto-generation of the schema so that you can provide your own definition. +- `JSONSchemaExtend(schema *jsonschema.Schema)` - will be called _after_ the schema has been generated, allowing you to add or manipulate the fields easily. +- `JSONSchemaAlias() any` - is called when reflecting the type of object and allows for an alternative to be used instead. +- `JSONSchemaProperty(prop string) any` - will be called for every property inside a struct giving you the chance to provide an alternative object to convert into a schema. + +Note that all of these methods **must** be defined on a non-pointer object for them to be called. 
+ +Take the following simplified example of a `CompactDate` that only includes the Year and Month: + +```go +type CompactDate struct { + Year int + Month int +} + +func (d *CompactDate) UnmarshalJSON(data []byte) error { + if len(data) != 9 { + return errors.New("invalid compact date length") + } + var err error + d.Year, err = strconv.Atoi(string(data[1:5])) + if err != nil { + return err + } + d.Month, err = strconv.Atoi(string(data[7:8])) + if err != nil { + return err + } + return nil +} + +func (d *CompactDate) MarshalJSON() ([]byte, error) { + buf := new(bytes.Buffer) + buf.WriteByte('"') + buf.WriteString(fmt.Sprintf("%d-%02d", d.Year, d.Month)) + buf.WriteByte('"') + return buf.Bytes(), nil +} + +func (CompactDate) JSONSchema() *Schema { + return &Schema{ + Type: "string", + Title: "Compact Date", + Description: "Short date that only includes year and month", + Pattern: "^[0-9]{4}-[0-1][0-9]$", + } +} +``` + +The resulting schema generated for this struct would look like: + +```json +{ + "$schema": "http://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/CompactDate", + "$defs": { + "CompactDate": { + "pattern": "^[0-9]{4}-[0-1][0-9]$", + "type": "string", + "title": "Compact Date", + "description": "Short date that only includes year and month" + } + } +} +``` diff --git a/vendor/github.com/invopop/jsonschema/id.go b/vendor/github.com/invopop/jsonschema/id.go new file mode 100644 index 000000000..73fafb38d --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/id.go @@ -0,0 +1,76 @@ +package jsonschema + +import ( + "errors" + "fmt" + "net/url" + "strings" +) + +// ID represents a Schema ID type which should always be a URI. +// See draft-bhutton-json-schema-00 section 8.2.1 +type ID string + +// EmptyID is used to explicitly define an ID with no value. +const EmptyID ID = "" + +// Validate is used to check if the ID looks like a proper schema. +// This is done by parsing the ID as a URL and checking it has all the +// relevant parts. 
+func (id ID) Validate() error { + u, err := url.Parse(id.String()) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + if u.Hostname() == "" { + return errors.New("missing hostname") + } + if !strings.Contains(u.Hostname(), ".") { + return errors.New("hostname does not look valid") + } + if u.Path == "" { + return errors.New("path is expected") + } + if u.Scheme != "https" && u.Scheme != "http" { + return errors.New("unexpected schema") + } + return nil +} + +// Anchor sets the anchor part of the schema URI. +func (id ID) Anchor(name string) ID { + b := id.Base() + return ID(b.String() + "#" + name) +} + +// Def adds or replaces a definition identifier. +func (id ID) Def(name string) ID { + b := id.Base() + return ID(b.String() + "#/$defs/" + name) +} + +// Add appends the provided path to the id, and removes any +// anchor data that might be there. +func (id ID) Add(path string) ID { + b := id.Base() + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + return ID(b.String() + path) +} + +// Base removes any anchor information from the schema +func (id ID) Base() ID { + s := id.String() + i := strings.LastIndex(s, "#") + if i != -1 { + s = s[0:i] + } + s = strings.TrimRight(s, "/") + return ID(s) +} + +// String provides string version of ID +func (id ID) String() string { + return string(id) +} diff --git a/vendor/github.com/invopop/jsonschema/reflect.go b/vendor/github.com/invopop/jsonschema/reflect.go new file mode 100644 index 000000000..73ce7e465 --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/reflect.go @@ -0,0 +1,1148 @@ +// Package jsonschema uses reflection to generate JSON Schemas from Go types [1]. +// +// If json tags are present on struct fields, they will be used to infer +// property names and if a property is required (omitempty is present). 
+// +// [1] http://json-schema.org/latest/json-schema-validation.html +package jsonschema + +import ( + "bytes" + "encoding/json" + "net" + "net/url" + "reflect" + "strconv" + "strings" + "time" +) + +// customSchemaImpl is used to detect if the type provides it's own +// custom Schema Type definition to use instead. Very useful for situations +// where there are custom JSON Marshal and Unmarshal methods. +type customSchemaImpl interface { + JSONSchema() *Schema +} + +// Function to be run after the schema has been generated. +// this will let you modify a schema afterwards +type extendSchemaImpl interface { + JSONSchemaExtend(*Schema) +} + +// If the object to be reflected defines a `JSONSchemaAlias` method, its type will +// be used instead of the original type. +type aliasSchemaImpl interface { + JSONSchemaAlias() any +} + +// If an object to be reflected defines a `JSONSchemaPropertyAlias` method, +// it will be called for each property to determine if another object +// should be used for the contents. 
+type propertyAliasSchemaImpl interface { + JSONSchemaProperty(prop string) any +} + +var customAliasSchema = reflect.TypeOf((*aliasSchemaImpl)(nil)).Elem() +var customPropertyAliasSchema = reflect.TypeOf((*propertyAliasSchemaImpl)(nil)).Elem() + +var customType = reflect.TypeOf((*customSchemaImpl)(nil)).Elem() +var extendType = reflect.TypeOf((*extendSchemaImpl)(nil)).Elem() + +// customSchemaGetFieldDocString +type customSchemaGetFieldDocString interface { + GetFieldDocString(fieldName string) string +} + +type customGetFieldDocString func(fieldName string) string + +var customStructGetFieldDocString = reflect.TypeOf((*customSchemaGetFieldDocString)(nil)).Elem() + +// Reflect reflects to Schema from a value using the default Reflector +func Reflect(v any) *Schema { + return ReflectFromType(reflect.TypeOf(v)) +} + +// ReflectFromType generates root schema using the default Reflector +func ReflectFromType(t reflect.Type) *Schema { + r := &Reflector{} + return r.ReflectFromType(t) +} + +// A Reflector reflects values into a Schema. +type Reflector struct { + // BaseSchemaID defines the URI that will be used as a base to determine Schema + // IDs for models. For example, a base Schema ID of `https://invopop.com/schemas` + // when defined with a struct called `User{}`, will result in a schema with an + // ID set to `https://invopop.com/schemas/user`. + // + // If no `BaseSchemaID` is provided, we'll take the type's complete package path + // and use that as a base instead. Set `Anonymous` to try if you do not want to + // include a schema ID. + BaseSchemaID ID + + // Anonymous when true will hide the auto-generated Schema ID and provide what is + // known as an "anonymous schema". As a rule, this is not recommended. + Anonymous bool + + // AssignAnchor when true will use the original struct's name as an anchor inside + // every definition, including the root schema. 
These can be useful for having a + // reference to the original struct's name in CamelCase instead of the snake-case used + // by default for URI compatibility. + // + // Anchors do not appear to be widely used out in the wild, so at this time the + // anchors themselves will not be used inside generated schema. + AssignAnchor bool + + // AllowAdditionalProperties will cause the Reflector to generate a schema + // without additionalProperties set to 'false' for all struct types. This means + // the presence of additional keys in JSON objects will not cause validation + // to fail. Note said additional keys will simply be dropped when the + // validated JSON is unmarshaled. + AllowAdditionalProperties bool + + // RequiredFromJSONSchemaTags will cause the Reflector to generate a schema + // that requires any key tagged with `jsonschema:required`, overriding the + // default of requiring any key *not* tagged with `json:,omitempty`. + RequiredFromJSONSchemaTags bool + + // Do not reference definitions. This will remove the top-level $defs map and + // instead cause the entire structure of types to be output in one tree. The + // list of type definitions (`$defs`) will not be included. + DoNotReference bool + + // ExpandedStruct when true will include the reflected type's definition in the + // root as opposed to a definition with a reference. + ExpandedStruct bool + + // FieldNameTag will change the tag used to get field names. json tags are used by default. + FieldNameTag string + + // IgnoredTypes defines a slice of types that should be ignored in the schema, + // switching to just allowing additional properties instead. + IgnoredTypes []any + + // Lookup allows a function to be defined that will provide a custom mapping of + // types to Schema IDs. This allows existing schema documents to be referenced + // by their ID instead of being embedded into the current schema definitions. + // Reflected types will never be pointers, only underlying elements. 
+ Lookup func(reflect.Type) ID + + // Mapper is a function that can be used to map custom Go types to jsonschema schemas. + Mapper func(reflect.Type) *Schema + + // Namer allows customizing of type names. The default is to use the type's name + // provided by the reflect package. + Namer func(reflect.Type) string + + // KeyNamer allows customizing of key names. + // The default is to use the key's name as is, or the json tag if present. + // If a json tag is present, KeyNamer will receive the tag's name as an argument, not the original key name. + KeyNamer func(string) string + + // AdditionalFields allows adding structfields for a given type + AdditionalFields func(reflect.Type) []reflect.StructField + + // LookupComment allows customizing comment lookup. Given a reflect.Type and optionally + // a field name, it should return the comment string associated with this type or field. + // + // If the field name is empty, it should return the type's comment; otherwise, the field's + // comment should be returned. If no comment is found, an empty string should be returned. + // + // When set, this function is called before the below CommentMap lookup mechanism. However, + // if it returns an empty string, the CommentMap is still consulted. + LookupComment func(reflect.Type, string) string + + // CommentMap is a dictionary of fully qualified go types and fields to comment + // strings that will be used if a description has not already been provided in + // the tags. Types and fields are added to the package path using "." as a + // separator. 
+ // + // Type descriptions should be defined like: + // + // map[string]string{"github.com/invopop/jsonschema.Reflector": "A Reflector reflects values into a Schema."} + // + // And Fields defined as: + // + // map[string]string{"github.com/invopop/jsonschema.Reflector.DoNotReference": "Do not reference definitions."} + // + // See also: AddGoComments, LookupComment + CommentMap map[string]string +} + +// Reflect reflects to Schema from a value. +func (r *Reflector) Reflect(v any) *Schema { + return r.ReflectFromType(reflect.TypeOf(v)) +} + +// ReflectFromType generates root schema +func (r *Reflector) ReflectFromType(t reflect.Type) *Schema { + if t.Kind() == reflect.Ptr { + t = t.Elem() // re-assign from pointer + } + + name := r.typeName(t) + + s := new(Schema) + definitions := Definitions{} + s.Definitions = definitions + bs := r.reflectTypeToSchemaWithID(definitions, t) + if r.ExpandedStruct { + *s = *definitions[name] + delete(definitions, name) + } else { + *s = *bs + } + + // Attempt to set the schema ID + if !r.Anonymous && s.ID == EmptyID { + baseSchemaID := r.BaseSchemaID + if baseSchemaID == EmptyID { + id := ID("https://" + t.PkgPath()) + if err := id.Validate(); err == nil { + // it's okay to silently ignore URL errors + baseSchemaID = id + } + } + if baseSchemaID != EmptyID { + s.ID = baseSchemaID.Add(ToSnakeCase(name)) + } + } + + s.Version = Version + if !r.DoNotReference { + s.Definitions = definitions + } + + return s +} + +// Available Go defined types for JSON Schema Validation. 
+// RFC draft-wright-json-schema-validation-00, section 7.3 +var ( + timeType = reflect.TypeOf(time.Time{}) // date-time RFC section 7.3.1 + ipType = reflect.TypeOf(net.IP{}) // ipv4 and ipv6 RFC section 7.3.4, 7.3.5 + uriType = reflect.TypeOf(url.URL{}) // uri RFC section 7.3.6 +) + +// Byte slices will be encoded as base64 +var byteSliceType = reflect.TypeOf([]byte(nil)) + +// Except for json.RawMessage +var rawMessageType = reflect.TypeOf(json.RawMessage{}) + +// Go code generated from protobuf enum types should fulfil this interface. +type protoEnum interface { + EnumDescriptor() ([]byte, []int) +} + +var protoEnumType = reflect.TypeOf((*protoEnum)(nil)).Elem() + +// SetBaseSchemaID is a helper use to be able to set the reflectors base +// schema ID from a string as opposed to then ID instance. +func (r *Reflector) SetBaseSchemaID(id string) { + r.BaseSchemaID = ID(id) +} + +func (r *Reflector) refOrReflectTypeToSchema(definitions Definitions, t reflect.Type) *Schema { + id := r.lookupID(t) + if id != EmptyID { + return &Schema{ + Ref: id.String(), + } + } + + // Already added to definitions? + if def := r.refDefinition(definitions, t); def != nil { + return def + } + + return r.reflectTypeToSchemaWithID(definitions, t) +} + +func (r *Reflector) reflectTypeToSchemaWithID(defs Definitions, t reflect.Type) *Schema { + s := r.reflectTypeToSchema(defs, t) + if s != nil { + if r.Lookup != nil { + id := r.Lookup(t) + if id != EmptyID { + s.ID = id + } + } + } + return s +} + +func (r *Reflector) reflectTypeToSchema(definitions Definitions, t reflect.Type) *Schema { + // only try to reflect non-pointers + if t.Kind() == reflect.Ptr { + return r.refOrReflectTypeToSchema(definitions, t.Elem()) + } + + // Check if the there is an alias method that provides an object + // that we should use instead of this one. 
+ if t.Implements(customAliasSchema) { + v := reflect.New(t) + o := v.Interface().(aliasSchemaImpl) + t = reflect.TypeOf(o.JSONSchemaAlias()) + return r.refOrReflectTypeToSchema(definitions, t) + } + + // Do any pre-definitions exist? + if r.Mapper != nil { + if t := r.Mapper(t); t != nil { + return t + } + } + if rt := r.reflectCustomSchema(definitions, t); rt != nil { + return rt + } + + // Prepare a base to which details can be added + st := new(Schema) + + // jsonpb will marshal protobuf enum options as either strings or integers. + // It will unmarshal either. + if t.Implements(protoEnumType) { + st.OneOf = []*Schema{ + {Type: "string"}, + {Type: "integer"}, + } + return st + } + + // Defined format types for JSON Schema Validation + // RFC draft-wright-json-schema-validation-00, section 7.3 + // TODO email RFC section 7.3.2, hostname RFC section 7.3.3, uriref RFC section 7.3.7 + if t == ipType { + // TODO differentiate ipv4 and ipv6 RFC section 7.3.4, 7.3.5 + st.Type = "string" + st.Format = "ipv4" + return st + } + + switch t.Kind() { + case reflect.Struct: + r.reflectStruct(definitions, t, st) + + case reflect.Slice, reflect.Array: + r.reflectSliceOrArray(definitions, t, st) + + case reflect.Map: + r.reflectMap(definitions, t, st) + + case reflect.Interface: + // empty + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + st.Type = "integer" + + case reflect.Float32, reflect.Float64: + st.Type = "number" + + case reflect.Bool: + st.Type = "boolean" + + case reflect.String: + st.Type = "string" + + default: + panic("unsupported type " + t.String()) + } + + r.reflectSchemaExtend(definitions, t, st) + + // Always try to reference the definition which may have just been created + if def := r.refDefinition(definitions, t); def != nil { + return def + } + + return st +} + +func (r *Reflector) reflectCustomSchema(definitions Definitions, t reflect.Type) 
*Schema { + if t.Kind() == reflect.Ptr { + return r.reflectCustomSchema(definitions, t.Elem()) + } + + if t.Implements(customType) { + v := reflect.New(t) + o := v.Interface().(customSchemaImpl) + st := o.JSONSchema() + r.addDefinition(definitions, t, st) + if ref := r.refDefinition(definitions, t); ref != nil { + return ref + } + return st + } + + return nil +} + +func (r *Reflector) reflectSchemaExtend(definitions Definitions, t reflect.Type, s *Schema) *Schema { + if t.Implements(extendType) { + v := reflect.New(t) + o := v.Interface().(extendSchemaImpl) + o.JSONSchemaExtend(s) + if ref := r.refDefinition(definitions, t); ref != nil { + return ref + } + } + + return s +} + +func (r *Reflector) reflectSliceOrArray(definitions Definitions, t reflect.Type, st *Schema) { + if t == rawMessageType { + return + } + + r.addDefinition(definitions, t, st) + + if st.Description == "" { + st.Description = r.lookupComment(t, "") + } + + if t.Kind() == reflect.Array { + l := uint64(t.Len()) + st.MinItems = &l + st.MaxItems = &l + } + if t.Kind() == reflect.Slice && t.Elem() == byteSliceType.Elem() { + st.Type = "string" + // NOTE: ContentMediaType is not set here + st.ContentEncoding = "base64" + } else { + st.Type = "array" + st.Items = r.refOrReflectTypeToSchema(definitions, t.Elem()) + } +} + +func (r *Reflector) reflectMap(definitions Definitions, t reflect.Type, st *Schema) { + r.addDefinition(definitions, t, st) + + st.Type = "object" + if st.Description == "" { + st.Description = r.lookupComment(t, "") + } + + switch t.Key().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + st.PatternProperties = map[string]*Schema{ + "^[0-9]+$": r.refOrReflectTypeToSchema(definitions, t.Elem()), + } + st.AdditionalProperties = FalseSchema + return + } + if t.Elem().Kind() != reflect.Interface { + st.AdditionalProperties = r.refOrReflectTypeToSchema(definitions, t.Elem()) + } +} + +// Reflects a struct to a JSON Schema type. 
+func (r *Reflector) reflectStruct(definitions Definitions, t reflect.Type, s *Schema) { + // Handle special types + switch t { + case timeType: // date-time RFC section 7.3.1 + s.Type = "string" + s.Format = "date-time" + return + case uriType: // uri RFC section 7.3.6 + s.Type = "string" + s.Format = "uri" + return + } + + r.addDefinition(definitions, t, s) + s.Type = "object" + s.Properties = NewProperties() + s.Description = r.lookupComment(t, "") + if r.AssignAnchor { + s.Anchor = t.Name() + } + if !r.AllowAdditionalProperties && s.AdditionalProperties == nil { + s.AdditionalProperties = FalseSchema + } + + ignored := false + for _, it := range r.IgnoredTypes { + if reflect.TypeOf(it) == t { + ignored = true + break + } + } + if !ignored { + r.reflectStructFields(s, definitions, t) + } +} + +func (r *Reflector) reflectStructFields(st *Schema, definitions Definitions, t reflect.Type) { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return + } + + var getFieldDocString customGetFieldDocString + if t.Implements(customStructGetFieldDocString) { + v := reflect.New(t) + o := v.Interface().(customSchemaGetFieldDocString) + getFieldDocString = o.GetFieldDocString + } + + customPropertyMethod := func(string) any { + return nil + } + if t.Implements(customPropertyAliasSchema) { + v := reflect.New(t) + o := v.Interface().(propertyAliasSchemaImpl) + customPropertyMethod = o.JSONSchemaProperty + } + + handleField := func(f reflect.StructField) { + name, shouldEmbed, required, nullable := r.reflectFieldName(f) + // if anonymous and exported type should be processed recursively + // current type should inherit properties of anonymous one + if name == "" { + if shouldEmbed { + r.reflectStructFields(st, definitions, f.Type) + } + return + } + + // If a JSONSchemaAlias(prop string) method is defined, attempt to use + // the provided object's type instead of the field's type. 
+ var property *Schema + if alias := customPropertyMethod(name); alias != nil { + property = r.refOrReflectTypeToSchema(definitions, reflect.TypeOf(alias)) + } else { + property = r.refOrReflectTypeToSchema(definitions, f.Type) + } + + property.structKeywordsFromTags(f, st, name) + if property.Description == "" { + property.Description = r.lookupComment(t, f.Name) + } + if getFieldDocString != nil { + property.Description = getFieldDocString(f.Name) + } + + if nullable { + property = &Schema{ + OneOf: []*Schema{ + property, + { + Type: "null", + }, + }, + } + } + + st.Properties.Set(name, property) + if required { + st.Required = appendUniqueString(st.Required, name) + } + } + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + handleField(f) + } + if r.AdditionalFields != nil { + if af := r.AdditionalFields(t); af != nil { + for _, sf := range af { + handleField(sf) + } + } + } +} + +func appendUniqueString(base []string, value string) []string { + for _, v := range base { + if v == value { + return base + } + } + return append(base, value) +} + +// addDefinition will append the provided schema. If needed, an ID and anchor will also be added. +func (r *Reflector) addDefinition(definitions Definitions, t reflect.Type, s *Schema) { + name := r.typeName(t) + if name == "" { + return + } + definitions[name] = s +} + +// refDefinition will provide a schema with a reference to an existing definition. 
+func (r *Reflector) refDefinition(definitions Definitions, t reflect.Type) *Schema { + if r.DoNotReference { + return nil + } + name := r.typeName(t) + if name == "" { + return nil + } + if _, ok := definitions[name]; !ok { + return nil + } + return &Schema{ + Ref: "#/$defs/" + name, + } +} + +func (r *Reflector) lookupID(t reflect.Type) ID { + if r.Lookup != nil { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return r.Lookup(t) + + } + return EmptyID +} + +func (t *Schema) structKeywordsFromTags(f reflect.StructField, parent *Schema, propertyName string) { + t.Description = f.Tag.Get("jsonschema_description") + + tags := splitOnUnescapedCommas(f.Tag.Get("jsonschema")) + tags = t.genericKeywords(tags, parent, propertyName) + + switch t.Type { + case "string": + t.stringKeywords(tags) + case "number": + t.numericalKeywords(tags) + case "integer": + t.numericalKeywords(tags) + case "array": + t.arrayKeywords(tags) + case "boolean": + t.booleanKeywords(tags) + } + extras := strings.Split(f.Tag.Get("jsonschema_extras"), ",") + t.extraKeywords(extras) +} + +// read struct tags for generic keywords +func (t *Schema) genericKeywords(tags []string, parent *Schema, propertyName string) []string { //nolint:gocyclo + unprocessed := make([]string, 0, len(tags)) + for _, tag := range tags { + nameValue := strings.SplitN(tag, "=", 2) + if len(nameValue) == 2 { + name, val := nameValue[0], nameValue[1] + switch name { + case "title": + t.Title = val + case "description": + t.Description = val + case "type": + t.Type = val + case "anchor": + t.Anchor = val + case "oneof_required": + var typeFound *Schema + for i := range parent.OneOf { + if parent.OneOf[i].Title == nameValue[1] { + typeFound = parent.OneOf[i] + } + } + if typeFound == nil { + typeFound = &Schema{ + Title: nameValue[1], + Required: []string{}, + } + parent.OneOf = append(parent.OneOf, typeFound) + } + typeFound.Required = append(typeFound.Required, propertyName) + case "anyof_required": + var typeFound 
*Schema + for i := range parent.AnyOf { + if parent.AnyOf[i].Title == nameValue[1] { + typeFound = parent.AnyOf[i] + } + } + if typeFound == nil { + typeFound = &Schema{ + Title: nameValue[1], + Required: []string{}, + } + parent.AnyOf = append(parent.AnyOf, typeFound) + } + typeFound.Required = append(typeFound.Required, propertyName) + case "oneof_ref": + subSchema := t + if t.Items != nil { + subSchema = t.Items + } + if subSchema.OneOf == nil { + subSchema.OneOf = make([]*Schema, 0, 1) + } + subSchema.Ref = "" + refs := strings.Split(nameValue[1], ";") + for _, r := range refs { + subSchema.OneOf = append(subSchema.OneOf, &Schema{ + Ref: r, + }) + } + case "oneof_type": + if t.OneOf == nil { + t.OneOf = make([]*Schema, 0, 1) + } + t.Type = "" + types := strings.Split(nameValue[1], ";") + for _, ty := range types { + t.OneOf = append(t.OneOf, &Schema{ + Type: ty, + }) + } + case "anyof_ref": + subSchema := t + if t.Items != nil { + subSchema = t.Items + } + if subSchema.AnyOf == nil { + subSchema.AnyOf = make([]*Schema, 0, 1) + } + subSchema.Ref = "" + refs := strings.Split(nameValue[1], ";") + for _, r := range refs { + subSchema.AnyOf = append(subSchema.AnyOf, &Schema{ + Ref: r, + }) + } + case "anyof_type": + if t.AnyOf == nil { + t.AnyOf = make([]*Schema, 0, 1) + } + t.Type = "" + types := strings.Split(nameValue[1], ";") + for _, ty := range types { + t.AnyOf = append(t.AnyOf, &Schema{ + Type: ty, + }) + } + default: + unprocessed = append(unprocessed, tag) + } + } + } + return unprocessed +} + +// read struct tags for boolean type keywords +func (t *Schema) booleanKeywords(tags []string) { + for _, tag := range tags { + nameValue := strings.Split(tag, "=") + if len(nameValue) != 2 { + continue + } + name, val := nameValue[0], nameValue[1] + if name == "default" { + if val == "true" { + t.Default = true + } else if val == "false" { + t.Default = false + } + } + } +} + +// read struct tags for string type keywords +func (t *Schema) stringKeywords(tags 
[]string) { + for _, tag := range tags { + nameValue := strings.SplitN(tag, "=", 2) + if len(nameValue) == 2 { + name, val := nameValue[0], nameValue[1] + switch name { + case "minLength": + t.MinLength = parseUint(val) + case "maxLength": + t.MaxLength = parseUint(val) + case "pattern": + t.Pattern = val + case "format": + t.Format = val + case "readOnly": + i, _ := strconv.ParseBool(val) + t.ReadOnly = i + case "writeOnly": + i, _ := strconv.ParseBool(val) + t.WriteOnly = i + case "default": + t.Default = val + case "example": + t.Examples = append(t.Examples, val) + case "enum": + t.Enum = append(t.Enum, val) + } + } + } +} + +// read struct tags for numerical type keywords +func (t *Schema) numericalKeywords(tags []string) { + for _, tag := range tags { + nameValue := strings.Split(tag, "=") + if len(nameValue) == 2 { + name, val := nameValue[0], nameValue[1] + switch name { + case "multipleOf": + t.MultipleOf, _ = toJSONNumber(val) + case "minimum": + t.Minimum, _ = toJSONNumber(val) + case "maximum": + t.Maximum, _ = toJSONNumber(val) + case "exclusiveMaximum": + t.ExclusiveMaximum, _ = toJSONNumber(val) + case "exclusiveMinimum": + t.ExclusiveMinimum, _ = toJSONNumber(val) + case "default": + if num, ok := toJSONNumber(val); ok { + t.Default = num + } + case "example": + if num, ok := toJSONNumber(val); ok { + t.Examples = append(t.Examples, num) + } + case "enum": + if num, ok := toJSONNumber(val); ok { + t.Enum = append(t.Enum, num) + } + } + } + } +} + +// read struct tags for object type keywords +// func (t *Type) objectKeywords(tags []string) { +// for _, tag := range tags{ +// nameValue := strings.Split(tag, "=") +// name, val := nameValue[0], nameValue[1] +// switch name{ +// case "dependencies": +// t.Dependencies = val +// break; +// case "patternProperties": +// t.PatternProperties = val +// break; +// } +// } +// } + +// read struct tags for array type keywords +func (t *Schema) arrayKeywords(tags []string) { + var defaultValues []any + + 
unprocessed := make([]string, 0, len(tags)) + for _, tag := range tags { + nameValue := strings.Split(tag, "=") + if len(nameValue) == 2 { + name, val := nameValue[0], nameValue[1] + switch name { + case "minItems": + t.MinItems = parseUint(val) + case "maxItems": + t.MaxItems = parseUint(val) + case "uniqueItems": + t.UniqueItems = true + case "default": + defaultValues = append(defaultValues, val) + case "format": + t.Items.Format = val + case "pattern": + t.Items.Pattern = val + default: + unprocessed = append(unprocessed, tag) // left for further processing by underlying type + } + } + } + if len(defaultValues) > 0 { + t.Default = defaultValues + } + + if len(unprocessed) == 0 { + // we don't have anything else to process + return + } + + switch t.Items.Type { + case "string": + t.Items.stringKeywords(unprocessed) + case "number": + t.Items.numericalKeywords(unprocessed) + case "integer": + t.Items.numericalKeywords(unprocessed) + case "array": + // explicitly don't support traversal for the [][]..., as it's unclear where the array tags belong + case "boolean": + t.Items.booleanKeywords(unprocessed) + } +} + +func (t *Schema) extraKeywords(tags []string) { + for _, tag := range tags { + nameValue := strings.SplitN(tag, "=", 2) + if len(nameValue) == 2 { + t.setExtra(nameValue[0], nameValue[1]) + } + } +} + +func (t *Schema) setExtra(key, val string) { + if t.Extras == nil { + t.Extras = map[string]any{} + } + if existingVal, ok := t.Extras[key]; ok { + switch existingVal := existingVal.(type) { + case string: + t.Extras[key] = []string{existingVal, val} + case []string: + t.Extras[key] = append(existingVal, val) + case int: + t.Extras[key], _ = strconv.Atoi(val) + case bool: + t.Extras[key] = (val == "true" || val == "t") + } + } else { + switch key { + case "minimum": + t.Extras[key], _ = strconv.Atoi(val) + default: + var x any + if val == "true" { + x = true + } else if val == "false" { + x = false + } else { + x = val + } + t.Extras[key] = x + } + } +} + 
+func requiredFromJSONTags(tags []string, val *bool) { + if ignoredByJSONTags(tags) { + return + } + + for _, tag := range tags[1:] { + if tag == "omitempty" { + *val = false + return + } + } + *val = true +} + +func requiredFromJSONSchemaTags(tags []string, val *bool) { + if ignoredByJSONSchemaTags(tags) { + return + } + for _, tag := range tags { + if tag == "required" { + *val = true + } + } +} + +func nullableFromJSONSchemaTags(tags []string) bool { + if ignoredByJSONSchemaTags(tags) { + return false + } + for _, tag := range tags { + if tag == "nullable" { + return true + } + } + return false +} + +func ignoredByJSONTags(tags []string) bool { + return tags[0] == "-" +} + +func ignoredByJSONSchemaTags(tags []string) bool { + return tags[0] == "-" +} + +func inlinedByJSONTags(tags []string) bool { + for _, tag := range tags[1:] { + if tag == "inline" { + return true + } + } + return false +} + +// toJSONNumber converts string to *json.Number. +// It'll aso return whether the number is valid. 
+func toJSONNumber(s string) (json.Number, bool) { + num := json.Number(s) + if _, err := num.Int64(); err == nil { + return num, true + } + if _, err := num.Float64(); err == nil { + return num, true + } + return json.Number(""), false +} + +func parseUint(num string) *uint64 { + val, err := strconv.ParseUint(num, 10, 64) + if err != nil { + return nil + } + return &val +} + +func (r *Reflector) fieldNameTag() string { + if r.FieldNameTag != "" { + return r.FieldNameTag + } + return "json" +} + +func (r *Reflector) reflectFieldName(f reflect.StructField) (string, bool, bool, bool) { + jsonTagString := f.Tag.Get(r.fieldNameTag()) + jsonTags := strings.Split(jsonTagString, ",") + + if ignoredByJSONTags(jsonTags) { + return "", false, false, false + } + + schemaTags := strings.Split(f.Tag.Get("jsonschema"), ",") + if ignoredByJSONSchemaTags(schemaTags) { + return "", false, false, false + } + + var required bool + if !r.RequiredFromJSONSchemaTags { + requiredFromJSONTags(jsonTags, &required) + } + requiredFromJSONSchemaTags(schemaTags, &required) + + nullable := nullableFromJSONSchemaTags(schemaTags) + + if f.Anonymous && jsonTags[0] == "" { + // As per JSON Marshal rules, anonymous structs are inherited + if f.Type.Kind() == reflect.Struct { + return "", true, false, false + } + + // As per JSON Marshal rules, anonymous pointer to structs are inherited + if f.Type.Kind() == reflect.Ptr && f.Type.Elem().Kind() == reflect.Struct { + return "", true, false, false + } + } + + // As per JSON Marshal rules, inline nested structs that have `inline` tag. 
+ if inlinedByJSONTags(jsonTags) { + return "", true, false, false + } + + // Try to determine the name from the different combos + name := f.Name + if jsonTags[0] != "" { + name = jsonTags[0] + } + if !f.Anonymous && f.PkgPath != "" { + // field not anonymous and not export has no export name + name = "" + } else if r.KeyNamer != nil { + name = r.KeyNamer(name) + } + + return name, false, required, nullable +} + +// UnmarshalJSON is used to parse a schema object or boolean. +func (t *Schema) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, []byte("true")) { + *t = *TrueSchema + return nil + } else if bytes.Equal(data, []byte("false")) { + *t = *FalseSchema + return nil + } + type SchemaAlt Schema + aux := &struct { + *SchemaAlt + }{ + SchemaAlt: (*SchemaAlt)(t), + } + return json.Unmarshal(data, aux) +} + +// MarshalJSON is used to serialize a schema object or boolean. +func (t *Schema) MarshalJSON() ([]byte, error) { + if t.boolean != nil { + if *t.boolean { + return []byte("true"), nil + } + return []byte("false"), nil + } + if reflect.DeepEqual(&Schema{}, t) { + // Don't bother returning empty schemas + return []byte("true"), nil + } + type SchemaAlt Schema + b, err := json.Marshal((*SchemaAlt)(t)) + if err != nil { + return nil, err + } + if len(t.Extras) == 0 { + return b, nil + } + m, err := json.Marshal(t.Extras) + if err != nil { + return nil, err + } + if len(b) == 2 { + return m, nil + } + b[len(b)-1] = ',' + return append(b, m[1:]...), nil +} + +func (r *Reflector) typeName(t reflect.Type) string { + if r.Namer != nil { + if name := r.Namer(t); name != "" { + return name + } + } + return t.Name() +} + +// Split on commas that are not preceded by `\`. 
+// This way, we prevent splitting regexes +func splitOnUnescapedCommas(tagString string) []string { + ret := make([]string, 0) + separated := strings.Split(tagString, ",") + ret = append(ret, separated[0]) + i := 0 + for _, nextTag := range separated[1:] { + if len(ret[i]) == 0 { + ret = append(ret, nextTag) + i++ + continue + } + + if ret[i][len(ret[i])-1] == '\\' { + ret[i] = ret[i][:len(ret[i])-1] + "," + nextTag + } else { + ret = append(ret, nextTag) + i++ + } + } + + return ret +} + +func fullyQualifiedTypeName(t reflect.Type) string { + return t.PkgPath() + "." + t.Name() +} diff --git a/vendor/github.com/invopop/jsonschema/reflect_comments.go b/vendor/github.com/invopop/jsonschema/reflect_comments.go new file mode 100644 index 000000000..ff374c75c --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/reflect_comments.go @@ -0,0 +1,146 @@ +package jsonschema + +import ( + "fmt" + "io/fs" + gopath "path" + "path/filepath" + "reflect" + "strings" + + "go/ast" + "go/doc" + "go/parser" + "go/token" +) + +type commentOptions struct { + fullObjectText bool // use the first sentence only? +} + +// CommentOption allows for special configuration options when preparing Go +// source files for comment extraction. +type CommentOption func(*commentOptions) + +// WithFullComment will configure the comment extraction to process to use an +// object type's full comment text instead of just the synopsis. +func WithFullComment() CommentOption { + return func(o *commentOptions) { + o.fullObjectText = true + } +} + +// AddGoComments will update the reflectors comment map with all the comments +// found in the provided source directories including sub-directories, in order to +// generate a dictionary of comments associated with Types and Fields. The results +// will be added to the `Reflect.CommentMap` ready to use with Schema "description" +// fields. 
+// +// The `go/parser` library is used to extract all the comments and unfortunately doesn't +// have a built-in way to determine the fully qualified name of a package. The `base` +// parameter, the URL used to import that package, is thus required to be able to match +// reflected types. +// +// When parsing type comments, by default we use the `go/doc`'s Synopsis method to extract +// the first phrase only. Field comments, which tend to be much shorter, will include everything. +// This behavior can be changed by using the `WithFullComment` option. +func (r *Reflector) AddGoComments(base, path string, opts ...CommentOption) error { + if r.CommentMap == nil { + r.CommentMap = make(map[string]string) + } + co := new(commentOptions) + for _, opt := range opts { + opt(co) + } + + return r.extractGoComments(base, path, r.CommentMap, co) +} + +func (r *Reflector) extractGoComments(base, path string, commentMap map[string]string, opts *commentOptions) error { + fset := token.NewFileSet() + dict := make(map[string][]*ast.Package) + err := filepath.Walk(path, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + d, err := parser.ParseDir(fset, path, nil, parser.ParseComments) + if err != nil { + return err + } + for _, v := range d { + // paths may have multiple packages, like for tests + k := gopath.Join(base, path) + dict[k] = append(dict[k], v) + } + } + return nil + }) + if err != nil { + return err + } + + for pkg, p := range dict { + for _, f := range p { + gtxt := "" + typ := "" + ast.Inspect(f, func(n ast.Node) bool { + switch x := n.(type) { + case *ast.TypeSpec: + typ = x.Name.String() + if !ast.IsExported(typ) { + typ = "" + } else { + txt := x.Doc.Text() + if txt == "" && gtxt != "" { + txt = gtxt + gtxt = "" + } + if !opts.fullObjectText { + txt = doc.Synopsis(txt) + } + commentMap[fmt.Sprintf("%s.%s", pkg, typ)] = strings.TrimSpace(txt) + } + case *ast.Field: + txt := x.Doc.Text() + if txt == "" { + 
txt = x.Comment.Text() + } + if typ != "" && txt != "" { + for _, n := range x.Names { + if ast.IsExported(n.String()) { + k := fmt.Sprintf("%s.%s.%s", pkg, typ, n) + commentMap[k] = strings.TrimSpace(txt) + } + } + } + case *ast.GenDecl: + // remember for the next type + gtxt = x.Doc.Text() + } + return true + }) + } + } + + return nil +} + +func (r *Reflector) lookupComment(t reflect.Type, name string) string { + if r.LookupComment != nil { + if comment := r.LookupComment(t, name); comment != "" { + return comment + } + } + + if r.CommentMap == nil { + return "" + } + + n := fullyQualifiedTypeName(t) + if name != "" { + n = n + "." + name + } + + return r.CommentMap[n] +} diff --git a/vendor/github.com/invopop/jsonschema/schema.go b/vendor/github.com/invopop/jsonschema/schema.go new file mode 100644 index 000000000..2d914b8c8 --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/schema.go @@ -0,0 +1,94 @@ +package jsonschema + +import ( + "encoding/json" + + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +// Version is the JSON Schema version. +var Version = "https://json-schema.org/draft/2020-12/schema" + +// Schema represents a JSON Schema object type. 
+// RFC draft-bhutton-json-schema-00 section 4.3 +type Schema struct { + // RFC draft-bhutton-json-schema-00 + Version string `json:"$schema,omitempty"` // section 8.1.1 + ID ID `json:"$id,omitempty"` // section 8.2.1 + Anchor string `json:"$anchor,omitempty"` // section 8.2.2 + Ref string `json:"$ref,omitempty"` // section 8.2.3.1 + DynamicRef string `json:"$dynamicRef,omitempty"` // section 8.2.3.2 + Definitions Definitions `json:"$defs,omitempty"` // section 8.2.4 + Comments string `json:"$comment,omitempty"` // section 8.3 + // RFC draft-bhutton-json-schema-00 section 10.2.1 (Sub-schemas with logic) + AllOf []*Schema `json:"allOf,omitempty"` // section 10.2.1.1 + AnyOf []*Schema `json:"anyOf,omitempty"` // section 10.2.1.2 + OneOf []*Schema `json:"oneOf,omitempty"` // section 10.2.1.3 + Not *Schema `json:"not,omitempty"` // section 10.2.1.4 + // RFC draft-bhutton-json-schema-00 section 10.2.2 (Apply sub-schemas conditionally) + If *Schema `json:"if,omitempty"` // section 10.2.2.1 + Then *Schema `json:"then,omitempty"` // section 10.2.2.2 + Else *Schema `json:"else,omitempty"` // section 10.2.2.3 + DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` // section 10.2.2.4 + // RFC draft-bhutton-json-schema-00 section 10.3.1 (arrays) + PrefixItems []*Schema `json:"prefixItems,omitempty"` // section 10.3.1.1 + Items *Schema `json:"items,omitempty"` // section 10.3.1.2 (replaces additionalItems) + Contains *Schema `json:"contains,omitempty"` // section 10.3.1.3 + // RFC draft-bhutton-json-schema-00 section 10.3.2 (sub-schemas) + Properties *orderedmap.OrderedMap[string, *Schema] `json:"properties,omitempty"` // section 10.3.2.1 + PatternProperties map[string]*Schema `json:"patternProperties,omitempty"` // section 10.3.2.2 + AdditionalProperties *Schema `json:"additionalProperties,omitempty"` // section 10.3.2.3 + PropertyNames *Schema `json:"propertyNames,omitempty"` // section 10.3.2.4 + // RFC draft-bhutton-json-schema-validation-00, section 6 + 
Type string `json:"type,omitempty"` // section 6.1.1 + Enum []any `json:"enum,omitempty"` // section 6.1.2 + Const any `json:"const,omitempty"` // section 6.1.3 + MultipleOf json.Number `json:"multipleOf,omitempty"` // section 6.2.1 + Maximum json.Number `json:"maximum,omitempty"` // section 6.2.2 + ExclusiveMaximum json.Number `json:"exclusiveMaximum,omitempty"` // section 6.2.3 + Minimum json.Number `json:"minimum,omitempty"` // section 6.2.4 + ExclusiveMinimum json.Number `json:"exclusiveMinimum,omitempty"` // section 6.2.5 + MaxLength *uint64 `json:"maxLength,omitempty"` // section 6.3.1 + MinLength *uint64 `json:"minLength,omitempty"` // section 6.3.2 + Pattern string `json:"pattern,omitempty"` // section 6.3.3 + MaxItems *uint64 `json:"maxItems,omitempty"` // section 6.4.1 + MinItems *uint64 `json:"minItems,omitempty"` // section 6.4.2 + UniqueItems bool `json:"uniqueItems,omitempty"` // section 6.4.3 + MaxContains *uint64 `json:"maxContains,omitempty"` // section 6.4.4 + MinContains *uint64 `json:"minContains,omitempty"` // section 6.4.5 + MaxProperties *uint64 `json:"maxProperties,omitempty"` // section 6.5.1 + MinProperties *uint64 `json:"minProperties,omitempty"` // section 6.5.2 + Required []string `json:"required,omitempty"` // section 6.5.3 + DependentRequired map[string][]string `json:"dependentRequired,omitempty"` // section 6.5.4 + // RFC draft-bhutton-json-schema-validation-00, section 7 + Format string `json:"format,omitempty"` + // RFC draft-bhutton-json-schema-validation-00, section 8 + ContentEncoding string `json:"contentEncoding,omitempty"` // section 8.3 + ContentMediaType string `json:"contentMediaType,omitempty"` // section 8.4 + ContentSchema *Schema `json:"contentSchema,omitempty"` // section 8.5 + // RFC draft-bhutton-json-schema-validation-00, section 9 + Title string `json:"title,omitempty"` // section 9.1 + Description string `json:"description,omitempty"` // section 9.1 + Default any `json:"default,omitempty"` // section 9.2 + 
Deprecated bool `json:"deprecated,omitempty"` // section 9.3 + ReadOnly bool `json:"readOnly,omitempty"` // section 9.4 + WriteOnly bool `json:"writeOnly,omitempty"` // section 9.4 + Examples []any `json:"examples,omitempty"` // section 9.5 + + Extras map[string]any `json:"-"` + + // Special boolean representation of the Schema - section 4.3.2 + boolean *bool +} + +var ( + // TrueSchema defines a schema with a true value + TrueSchema = &Schema{boolean: &[]bool{true}[0]} + // FalseSchema defines a schema with a false value + FalseSchema = &Schema{boolean: &[]bool{false}[0]} +) + +// Definitions hold schema definitions. +// http://json-schema.org/latest/json-schema-validation.html#rfc.section.5.26 +// RFC draft-wright-json-schema-validation-00, section 5.26 +type Definitions map[string]*Schema diff --git a/vendor/github.com/invopop/jsonschema/utils.go b/vendor/github.com/invopop/jsonschema/utils.go new file mode 100644 index 000000000..ed8edf741 --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/utils.go @@ -0,0 +1,26 @@ +package jsonschema + +import ( + "regexp" + "strings" + + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") +var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") + +// ToSnakeCase converts the provided string into snake case using dashes. +// This is useful for Schema IDs and definitions to be coherent with +// common JSON Schema examples. +func ToSnakeCase(str string) string { + snake := matchFirstCap.ReplaceAllString(str, "${1}-${2}") + snake = matchAllCap.ReplaceAllString(snake, "${1}-${2}") + return strings.ToLower(snake) +} + +// NewProperties is a helper method to instantiate a new properties ordered +// map. 
+func NewProperties() *orderedmap.OrderedMap[string, *Schema] { + return orderedmap.New[string, *Schema]() +} diff --git a/vendor/github.com/mailru/easyjson/LICENSE b/vendor/github.com/mailru/easyjson/LICENSE new file mode 100644 index 000000000..fbff658f7 --- /dev/null +++ b/vendor/github.com/mailru/easyjson/LICENSE @@ -0,0 +1,7 @@ +Copyright (c) 2016 Mail.Ru Group + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/mailru/easyjson/buffer/pool.go b/vendor/github.com/mailru/easyjson/buffer/pool.go new file mode 100644 index 000000000..598a54af9 --- /dev/null +++ b/vendor/github.com/mailru/easyjson/buffer/pool.go @@ -0,0 +1,278 @@ +// Package buffer implements a buffer for serialization, consisting of a chain of []byte-s to +// reduce copying and to allow reuse of individual chunks. +package buffer + +import ( + "io" + "net" + "sync" +) + +// PoolConfig contains configuration for the allocation and reuse strategy. 
+type PoolConfig struct { + StartSize int // Minimum chunk size that is allocated. + PooledSize int // Minimum chunk size that is reused, reusing chunks too small will result in overhead. + MaxSize int // Maximum chunk size that will be allocated. +} + +var config = PoolConfig{ + StartSize: 128, + PooledSize: 512, + MaxSize: 32768, +} + +// Reuse pool: chunk size -> pool. +var buffers = map[int]*sync.Pool{} + +func initBuffers() { + for l := config.PooledSize; l <= config.MaxSize; l *= 2 { + buffers[l] = new(sync.Pool) + } +} + +func init() { + initBuffers() +} + +// Init sets up a non-default pooling and allocation strategy. Should be run before serialization is done. +func Init(cfg PoolConfig) { + config = cfg + initBuffers() +} + +// putBuf puts a chunk to reuse pool if it can be reused. +func putBuf(buf []byte) { + size := cap(buf) + if size < config.PooledSize { + return + } + if c := buffers[size]; c != nil { + c.Put(buf[:0]) + } +} + +// getBuf gets a chunk from reuse pool or creates a new one if reuse failed. +func getBuf(size int) []byte { + if size >= config.PooledSize { + if c := buffers[size]; c != nil { + v := c.Get() + if v != nil { + return v.([]byte) + } + } + } + return make([]byte, 0, size) +} + +// Buffer is a buffer optimized for serialization without extra copying. +type Buffer struct { + + // Buf is the current chunk that can be used for serialization. + Buf []byte + + toPool []byte + bufs [][]byte +} + +// EnsureSpace makes sure that the current chunk contains at least s free bytes, +// possibly creating a new chunk. +func (b *Buffer) EnsureSpace(s int) { + if cap(b.Buf)-len(b.Buf) < s { + b.ensureSpaceSlow(s) + } +} + +func (b *Buffer) ensureSpaceSlow(s int) { + l := len(b.Buf) + if l > 0 { + if cap(b.toPool) != cap(b.Buf) { + // Chunk was reallocated, toPool can be pooled. 
+ putBuf(b.toPool) + } + if cap(b.bufs) == 0 { + b.bufs = make([][]byte, 0, 8) + } + b.bufs = append(b.bufs, b.Buf) + l = cap(b.toPool) * 2 + } else { + l = config.StartSize + } + + if l > config.MaxSize { + l = config.MaxSize + } + b.Buf = getBuf(l) + b.toPool = b.Buf +} + +// AppendByte appends a single byte to buffer. +func (b *Buffer) AppendByte(data byte) { + b.EnsureSpace(1) + b.Buf = append(b.Buf, data) +} + +// AppendBytes appends a byte slice to buffer. +func (b *Buffer) AppendBytes(data []byte) { + if len(data) <= cap(b.Buf)-len(b.Buf) { + b.Buf = append(b.Buf, data...) // fast path + } else { + b.appendBytesSlow(data) + } +} + +func (b *Buffer) appendBytesSlow(data []byte) { + for len(data) > 0 { + b.EnsureSpace(1) + + sz := cap(b.Buf) - len(b.Buf) + if sz > len(data) { + sz = len(data) + } + + b.Buf = append(b.Buf, data[:sz]...) + data = data[sz:] + } +} + +// AppendString appends a string to buffer. +func (b *Buffer) AppendString(data string) { + if len(data) <= cap(b.Buf)-len(b.Buf) { + b.Buf = append(b.Buf, data...) // fast path + } else { + b.appendStringSlow(data) + } +} + +func (b *Buffer) appendStringSlow(data string) { + for len(data) > 0 { + b.EnsureSpace(1) + + sz := cap(b.Buf) - len(b.Buf) + if sz > len(data) { + sz = len(data) + } + + b.Buf = append(b.Buf, data[:sz]...) + data = data[sz:] + } +} + +// Size computes the size of a buffer by adding sizes of every chunk. +func (b *Buffer) Size() int { + size := len(b.Buf) + for _, buf := range b.bufs { + size += len(buf) + } + return size +} + +// DumpTo outputs the contents of a buffer to a writer and resets the buffer. 
+func (b *Buffer) DumpTo(w io.Writer) (written int, err error) { + bufs := net.Buffers(b.bufs) + if len(b.Buf) > 0 { + bufs = append(bufs, b.Buf) + } + n, err := bufs.WriteTo(w) + + for _, buf := range b.bufs { + putBuf(buf) + } + putBuf(b.toPool) + + b.bufs = nil + b.Buf = nil + b.toPool = nil + + return int(n), err +} + +// BuildBytes creates a single byte slice with all the contents of the buffer. Data is +// copied if it does not fit in a single chunk. You can optionally provide one byte +// slice as argument that it will try to reuse. +func (b *Buffer) BuildBytes(reuse ...[]byte) []byte { + if len(b.bufs) == 0 { + ret := b.Buf + b.toPool = nil + b.Buf = nil + return ret + } + + var ret []byte + size := b.Size() + + // If we got a buffer as argument and it is big enough, reuse it. + if len(reuse) == 1 && cap(reuse[0]) >= size { + ret = reuse[0][:0] + } else { + ret = make([]byte, 0, size) + } + for _, buf := range b.bufs { + ret = append(ret, buf...) + putBuf(buf) + } + + ret = append(ret, b.Buf...) + putBuf(b.toPool) + + b.bufs = nil + b.toPool = nil + b.Buf = nil + + return ret +} + +type readCloser struct { + offset int + bufs [][]byte +} + +func (r *readCloser) Read(p []byte) (n int, err error) { + for _, buf := range r.bufs { + // Copy as much as we can. + x := copy(p[n:], buf[r.offset:]) + n += x // Increment how much we filled. + + // Did we empty the whole buffer? + if r.offset+x == len(buf) { + // On to the next buffer. + r.offset = 0 + r.bufs = r.bufs[1:] + + // We can release this buffer. + putBuf(buf) + } else { + r.offset += x + } + + if n == len(p) { + break + } + } + // No buffers left or nothing read? + if len(r.bufs) == 0 { + err = io.EOF + } + return +} + +func (r *readCloser) Close() error { + // Release all remaining buffers. + for _, buf := range r.bufs { + putBuf(buf) + } + // In case Close gets called multiple times. + r.bufs = nil + + return nil +} + +// ReadCloser creates an io.ReadCloser with all the contents of the buffer. 
+func (b *Buffer) ReadCloser() io.ReadCloser { + ret := &readCloser{0, append(b.bufs, b.Buf)} + + b.bufs = nil + b.toPool = nil + b.Buf = nil + + return ret +} diff --git a/vendor/github.com/mailru/easyjson/jwriter/writer.go b/vendor/github.com/mailru/easyjson/jwriter/writer.go new file mode 100644 index 000000000..2c5b20105 --- /dev/null +++ b/vendor/github.com/mailru/easyjson/jwriter/writer.go @@ -0,0 +1,405 @@ +// Package jwriter contains a JSON writer. +package jwriter + +import ( + "io" + "strconv" + "unicode/utf8" + + "github.com/mailru/easyjson/buffer" +) + +// Flags describe various encoding options. The behavior may be actually implemented in the encoder, but +// Flags field in Writer is used to set and pass them around. +type Flags int + +const ( + NilMapAsEmpty Flags = 1 << iota // Encode nil map as '{}' rather than 'null'. + NilSliceAsEmpty // Encode nil slice as '[]' rather than 'null'. +) + +// Writer is a JSON writer. +type Writer struct { + Flags Flags + + Error error + Buffer buffer.Buffer + NoEscapeHTML bool +} + +// Size returns the size of the data that was written out. +func (w *Writer) Size() int { + return w.Buffer.Size() +} + +// DumpTo outputs the data to given io.Writer, resetting the buffer. +func (w *Writer) DumpTo(out io.Writer) (written int, err error) { + return w.Buffer.DumpTo(out) +} + +// BuildBytes returns writer data as a single byte slice. You can optionally provide one byte slice +// as argument that it will try to reuse. +func (w *Writer) BuildBytes(reuse ...[]byte) ([]byte, error) { + if w.Error != nil { + return nil, w.Error + } + + return w.Buffer.BuildBytes(reuse...), nil +} + +// ReadCloser returns an io.ReadCloser that can be used to read the data. +// ReadCloser also resets the buffer. +func (w *Writer) ReadCloser() (io.ReadCloser, error) { + if w.Error != nil { + return nil, w.Error + } + + return w.Buffer.ReadCloser(), nil +} + +// RawByte appends raw binary data to the buffer. 
+func (w *Writer) RawByte(c byte) { + w.Buffer.AppendByte(c) +} + +// RawByte appends raw binary data to the buffer. +func (w *Writer) RawString(s string) { + w.Buffer.AppendString(s) +} + +// Raw appends raw binary data to the buffer or sets the error if it is given. Useful for +// calling with results of MarshalJSON-like functions. +func (w *Writer) Raw(data []byte, err error) { + switch { + case w.Error != nil: + return + case err != nil: + w.Error = err + case len(data) > 0: + w.Buffer.AppendBytes(data) + default: + w.RawString("null") + } +} + +// RawText encloses raw binary data in quotes and appends in to the buffer. +// Useful for calling with results of MarshalText-like functions. +func (w *Writer) RawText(data []byte, err error) { + switch { + case w.Error != nil: + return + case err != nil: + w.Error = err + case len(data) > 0: + w.String(string(data)) + default: + w.RawString("null") + } +} + +// Base64Bytes appends data to the buffer after base64 encoding it +func (w *Writer) Base64Bytes(data []byte) { + if data == nil { + w.Buffer.AppendString("null") + return + } + w.Buffer.AppendByte('"') + w.base64(data) + w.Buffer.AppendByte('"') +} + +func (w *Writer) Uint8(n uint8) { + w.Buffer.EnsureSpace(3) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +func (w *Writer) Uint16(n uint16) { + w.Buffer.EnsureSpace(5) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +func (w *Writer) Uint32(n uint32) { + w.Buffer.EnsureSpace(10) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +func (w *Writer) Uint(n uint) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +func (w *Writer) Uint64(n uint64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, n, 10) +} + +func (w *Writer) Int8(n int8) { + w.Buffer.EnsureSpace(4) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +func (w *Writer) Int16(n int16) { + 
w.Buffer.EnsureSpace(6) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +func (w *Writer) Int32(n int32) { + w.Buffer.EnsureSpace(11) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +func (w *Writer) Int(n int) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +func (w *Writer) Int64(n int64) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, n, 10) +} + +func (w *Writer) Uint8Str(n uint8) { + w.Buffer.EnsureSpace(3) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Uint16Str(n uint16) { + w.Buffer.EnsureSpace(5) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Uint32Str(n uint32) { + w.Buffer.EnsureSpace(10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) UintStr(n uint) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Uint64Str(n uint64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, n, 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) UintptrStr(n uintptr) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Int8Str(n int8) { + w.Buffer.EnsureSpace(4) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = 
append(w.Buffer.Buf, '"') +} + +func (w *Writer) Int16Str(n int16) { + w.Buffer.EnsureSpace(6) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Int32Str(n int32) { + w.Buffer.EnsureSpace(11) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) IntStr(n int) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Int64Str(n int64) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, n, 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Float32(n float32) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 32) +} + +func (w *Writer) Float32Str(n float32) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 32) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Float64(n float64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, n, 'g', -1, 64) +} + +func (w *Writer) Float64Str(n float64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 64) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Bool(v bool) { + w.Buffer.EnsureSpace(5) + if v { + w.Buffer.Buf = append(w.Buffer.Buf, "true"...) + } else { + w.Buffer.Buf = append(w.Buffer.Buf, "false"...) 
+ } +} + +const chars = "0123456789abcdef" + +func getTable(falseValues ...int) [128]bool { + table := [128]bool{} + + for i := 0; i < 128; i++ { + table[i] = true + } + + for _, v := range falseValues { + table[v] = false + } + + return table +} + +var ( + htmlEscapeTable = getTable(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, '"', '&', '<', '>', '\\') + htmlNoEscapeTable = getTable(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, '"', '\\') +) + +func (w *Writer) String(s string) { + w.Buffer.AppendByte('"') + + // Portions of the string that contain no escapes are appended as + // byte slices. + + p := 0 // last non-escape symbol + + escapeTable := &htmlEscapeTable + if w.NoEscapeHTML { + escapeTable = &htmlNoEscapeTable + } + + for i := 0; i < len(s); { + c := s[i] + + if c < utf8.RuneSelf { + if escapeTable[c] { + // single-width character, no escaping is required + i++ + continue + } + + w.Buffer.AppendString(s[p:i]) + switch c { + case '\t': + w.Buffer.AppendString(`\t`) + case '\r': + w.Buffer.AppendString(`\r`) + case '\n': + w.Buffer.AppendString(`\n`) + case '\\': + w.Buffer.AppendString(`\\`) + case '"': + w.Buffer.AppendString(`\"`) + default: + w.Buffer.AppendString(`\u00`) + w.Buffer.AppendByte(chars[c>>4]) + w.Buffer.AppendByte(chars[c&0xf]) + } + + i++ + p = i + continue + } + + // broken utf + runeValue, runeWidth := utf8.DecodeRuneInString(s[i:]) + if runeValue == utf8.RuneError && runeWidth == 1 { + w.Buffer.AppendString(s[p:i]) + w.Buffer.AppendString(`\ufffd`) + i++ + p = i + continue + } + + // jsonp stuff - tab separator and line separator + if runeValue == '\u2028' || runeValue == '\u2029' { + w.Buffer.AppendString(s[p:i]) + w.Buffer.AppendString(`\u202`) + w.Buffer.AppendByte(chars[runeValue&0xf]) + i += runeWidth + p = i + continue + } + i += runeWidth + } + w.Buffer.AppendString(s[p:]) + 
w.Buffer.AppendByte('"') +} + +const encode = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" +const padChar = '=' + +func (w *Writer) base64(in []byte) { + + if len(in) == 0 { + return + } + + w.Buffer.EnsureSpace(((len(in)-1)/3 + 1) * 4) + + si := 0 + n := (len(in) / 3) * 3 + + for si < n { + // Convert 3x 8bit source bytes into 4 bytes + val := uint(in[si+0])<<16 | uint(in[si+1])<<8 | uint(in[si+2]) + + w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>18&0x3F], encode[val>>12&0x3F], encode[val>>6&0x3F], encode[val&0x3F]) + + si += 3 + } + + remain := len(in) - si + if remain == 0 { + return + } + + // Add the remaining small block + val := uint(in[si+0]) << 16 + if remain == 2 { + val |= uint(in[si+1]) << 8 + } + + w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>18&0x3F], encode[val>>12&0x3F]) + + switch remain { + case 2: + w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>6&0x3F], byte(padChar)) + case 1: + w.Buffer.Buf = append(w.Buffer.Buf, byte(padChar), byte(padChar)) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/LICENSE b/vendor/github.com/mark3labs/mcp-go/LICENSE new file mode 100644 index 000000000..3d4843545 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Anthropic, PBC + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/consts.go b/vendor/github.com/mark3labs/mcp-go/mcp/consts.go new file mode 100644 index 000000000..66eb3803b --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/consts.go @@ -0,0 +1,9 @@ +package mcp + +const ( + ContentTypeText = "text" + ContentTypeImage = "image" + ContentTypeAudio = "audio" + ContentTypeLink = "resource_link" + ContentTypeResource = "resource" +) diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/errors.go b/vendor/github.com/mark3labs/mcp-go/mcp/errors.go new file mode 100644 index 000000000..aead24744 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/errors.go @@ -0,0 +1,85 @@ +package mcp + +import ( + "errors" + "fmt" +) + +// Sentinel errors for common JSON-RPC error codes. +var ( + // ErrParseError indicates a JSON parsing error (code: PARSE_ERROR). + ErrParseError = errors.New("parse error") + + // ErrInvalidRequest indicates an invalid JSON-RPC request (code: INVALID_REQUEST). + ErrInvalidRequest = errors.New("invalid request") + + // ErrMethodNotFound indicates the requested method does not exist (code: METHOD_NOT_FOUND). + ErrMethodNotFound = errors.New("method not found") + + // ErrInvalidParams indicates invalid method parameters (code: INVALID_PARAMS). + ErrInvalidParams = errors.New("invalid params") + + // ErrInternalError indicates an internal JSON-RPC error (code: INTERNAL_ERROR). 
+ ErrInternalError = errors.New("internal error") + + // ErrRequestInterrupted indicates a request was cancelled or timed out (code: REQUEST_INTERRUPTED). + ErrRequestInterrupted = errors.New("request interrupted") + + // ErrResourceNotFound indicates a requested resource was not found (code: RESOURCE_NOT_FOUND). + ErrResourceNotFound = errors.New("resource not found") +) + +// UnsupportedProtocolVersionError is returned when the server responds with +// a protocol version that the client doesn't support. +type UnsupportedProtocolVersionError struct { + Version string +} + +func (e UnsupportedProtocolVersionError) Error() string { + return fmt.Sprintf("unsupported protocol version: %q", e.Version) +} + +// Is implements the errors.Is interface for better error handling +func (e UnsupportedProtocolVersionError) Is(target error) bool { + _, ok := target.(UnsupportedProtocolVersionError) + return ok +} + +// IsUnsupportedProtocolVersion checks if an error is an UnsupportedProtocolVersionError +func IsUnsupportedProtocolVersion(err error) bool { + _, ok := err.(UnsupportedProtocolVersionError) + return ok +} + +// AsError maps JSONRPCErrorDetails to a Go error. +// Returns sentinel errors wrapped with custom messages for known codes. +// Defaults to a generic error with the original message when the code is not mapped. +func (e *JSONRPCErrorDetails) AsError() error { + var err error + + switch e.Code { + case PARSE_ERROR: + err = ErrParseError + case INVALID_REQUEST: + err = ErrInvalidRequest + case METHOD_NOT_FOUND: + err = ErrMethodNotFound + case INVALID_PARAMS: + err = ErrInvalidParams + case INTERNAL_ERROR: + err = ErrInternalError + case REQUEST_INTERRUPTED: + err = ErrRequestInterrupted + case RESOURCE_NOT_FOUND: + err = ErrResourceNotFound + default: + return errors.New(e.Message) + } + + // Wrap the sentinel error with the custom message if it differs from the sentinel. 
+ if e.Message != "" && e.Message != err.Error() { + return fmt.Errorf("%w: %s", err, e.Message) + } + + return err +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go new file mode 100644 index 000000000..9b0b48ed2 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go @@ -0,0 +1,176 @@ +package mcp + +import "net/http" + +/* Prompts */ + +// ListPromptsRequest is sent from the client to request a list of prompts and +// prompt templates the server has. +type ListPromptsRequest struct { + PaginatedRequest + Header http.Header `json:"-"` +} + +// ListPromptsResult is the server's response to a prompts/list request from +// the client. +type ListPromptsResult struct { + PaginatedResult + Prompts []Prompt `json:"prompts"` +} + +// GetPromptRequest is used by the client to get a prompt provided by the +// server. +type GetPromptRequest struct { + Request + Params GetPromptParams `json:"params"` + Header http.Header `json:"-"` +} + +type GetPromptParams struct { + // The name of the prompt or prompt template. + Name string `json:"name"` + // Arguments to use for templating the prompt. + Arguments map[string]string `json:"arguments,omitempty"` +} + +// GetPromptResult is the server's response to a prompts/get request from the +// client. +type GetPromptResult struct { + Result + // An optional description for the prompt. + Description string `json:"description,omitempty"` + Messages []PromptMessage `json:"messages"` +} + +// Prompt represents a prompt or prompt template that the server offers. +// If Arguments is non-nil and non-empty, this indicates the prompt is a template +// that requires argument values to be provided when calling prompts/get. +// If Arguments is nil or empty, this is a static prompt that takes no arguments. +type Prompt struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. 
+ Meta *Meta `json:"_meta,omitempty"` + // The name of the prompt or prompt template. + Name string `json:"name"` + // An optional description of what this prompt provides + Description string `json:"description,omitempty"` + // A list of arguments to use for templating the prompt. + // The presence of arguments indicates this is a template prompt. + Arguments []PromptArgument `json:"arguments,omitempty"` +} + +// GetName returns the name of the prompt. +func (p Prompt) GetName() string { + return p.Name +} + +// PromptArgument describes an argument that a prompt template can accept. +// When a prompt includes arguments, clients must provide values for all +// required arguments when making a prompts/get request. +type PromptArgument struct { + // The name of the argument. + Name string `json:"name"` + // A human-readable description of the argument. + Description string `json:"description,omitempty"` + // Whether this argument must be provided. + // If true, clients must include this argument when calling prompts/get. + Required bool `json:"required,omitempty"` +} + +// Role represents the sender or recipient of messages and data in a +// conversation. +type Role string + +const ( + RoleUser Role = "user" + RoleAssistant Role = "assistant" +) + +// PromptMessage describes a message returned as part of a prompt. +// +// This is similar to `SamplingMessage`, but also supports the embedding of +// resources from the MCP server. +type PromptMessage struct { + Role Role `json:"role"` + Content Content `json:"content"` // Can be TextContent, ImageContent, AudioContent or EmbeddedResource +} + +// PromptListChangedNotification is an optional notification from the server +// to the client, informing it that the list of prompts it offers has changed. This +// may be issued by servers without any previous subscription from the client. +type PromptListChangedNotification struct { + Notification +} + +// PromptOption is a function that configures a Prompt. 
+// It provides a flexible way to set various properties of a Prompt using the functional options pattern. +type PromptOption func(*Prompt) + +// ArgumentOption is a function that configures a PromptArgument. +// It allows for flexible configuration of prompt arguments using the functional options pattern. +type ArgumentOption func(*PromptArgument) + +// +// Core Prompt Functions +// + +// NewPrompt creates a new Prompt with the given name and options. +// The prompt will be configured based on the provided options. +// Options are applied in order, allowing for flexible prompt configuration. +func NewPrompt(name string, opts ...PromptOption) Prompt { + prompt := Prompt{ + Name: name, + } + + for _, opt := range opts { + opt(&prompt) + } + + return prompt +} + +// WithPromptDescription adds a description to the Prompt. +// The description should provide a clear, human-readable explanation of what the prompt does. +func WithPromptDescription(description string) PromptOption { + return func(p *Prompt) { + p.Description = description + } +} + +// WithArgument adds an argument to the prompt's argument list. +// The argument will be configured based on the provided options. +func WithArgument(name string, opts ...ArgumentOption) PromptOption { + return func(p *Prompt) { + arg := PromptArgument{ + Name: name, + } + + for _, opt := range opts { + opt(&arg) + } + + if p.Arguments == nil { + p.Arguments = make([]PromptArgument, 0) + } + p.Arguments = append(p.Arguments, arg) + } +} + +// +// Argument Options +// + +// ArgumentDescription adds a description to a prompt argument. +// The description should explain the purpose and expected values of the argument. +func ArgumentDescription(desc string) ArgumentOption { + return func(arg *PromptArgument) { + arg.Description = desc + } +} + +// RequiredArgument marks an argument as required in the prompt. +// Required arguments must be provided when getting the prompt. 
+func RequiredArgument() ArgumentOption { + return func(arg *PromptArgument) { + arg.Required = true + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/resources.go b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go new file mode 100644 index 000000000..07a59a322 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go @@ -0,0 +1,99 @@ +package mcp + +import "github.com/yosida95/uritemplate/v3" + +// ResourceOption is a function that configures a Resource. +// It provides a flexible way to set various properties of a Resource using the functional options pattern. +type ResourceOption func(*Resource) + +// NewResource creates a new Resource with the given URI, name and options. +// The resource will be configured based on the provided options. +// Options are applied in order, allowing for flexible resource configuration. +func NewResource(uri string, name string, opts ...ResourceOption) Resource { + resource := Resource{ + URI: uri, + Name: name, + } + + for _, opt := range opts { + opt(&resource) + } + + return resource +} + +// WithResourceDescription adds a description to the Resource. +// The description should provide a clear, human-readable explanation of what the resource represents. +func WithResourceDescription(description string) ResourceOption { + return func(r *Resource) { + r.Description = description + } +} + +// WithMIMEType sets the MIME type for the Resource. +// This should indicate the format of the resource's contents. +func WithMIMEType(mimeType string) ResourceOption { + return func(r *Resource) { + r.MIMEType = mimeType + } +} + +// WithAnnotations adds annotations to the Resource. +// Annotations can provide additional metadata about the resource's intended use. 
+func WithAnnotations(audience []Role, priority float64) ResourceOption { + return func(r *Resource) { + if r.Annotations == nil { + r.Annotations = &Annotations{} + } + r.Annotations.Audience = audience + r.Annotations.Priority = priority + } +} + +// ResourceTemplateOption is a function that configures a ResourceTemplate. +// It provides a flexible way to set various properties of a ResourceTemplate using the functional options pattern. +type ResourceTemplateOption func(*ResourceTemplate) + +// NewResourceTemplate creates a new ResourceTemplate with the given URI template, name and options. +// The template will be configured based on the provided options. +// Options are applied in order, allowing for flexible template configuration. +func NewResourceTemplate(uriTemplate string, name string, opts ...ResourceTemplateOption) ResourceTemplate { + template := ResourceTemplate{ + URITemplate: &URITemplate{Template: uritemplate.MustNew(uriTemplate)}, + Name: name, + } + + for _, opt := range opts { + opt(&template) + } + + return template +} + +// WithTemplateDescription adds a description to the ResourceTemplate. +// The description should provide a clear, human-readable explanation of what resources this template represents. +func WithTemplateDescription(description string) ResourceTemplateOption { + return func(t *ResourceTemplate) { + t.Description = description + } +} + +// WithTemplateMIMEType sets the MIME type for the ResourceTemplate. +// This should only be set if all resources matching this template will have the same type. +func WithTemplateMIMEType(mimeType string) ResourceTemplateOption { + return func(t *ResourceTemplate) { + t.MIMEType = mimeType + } +} + +// WithTemplateAnnotations adds annotations to the ResourceTemplate. +// Annotations can provide additional metadata about the template's intended use. 
+func WithTemplateAnnotations(audience []Role, priority float64) ResourceTemplateOption { + return func(t *ResourceTemplate) { + if t.Annotations == nil { + t.Annotations = &Annotations{} + } + t.Annotations.Audience = audience + t.Annotations.Priority = priority + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go new file mode 100644 index 000000000..42e888d52 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go @@ -0,0 +1,1331 @@ +package mcp + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "reflect" + "strconv" + + "github.com/invopop/jsonschema" +) + +var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSchema, not both") + +// ListToolsRequest is sent from the client to request a list of tools the +// server has. +type ListToolsRequest struct { + PaginatedRequest + Header http.Header `json:"-"` +} + +// ListToolsResult is the server's response to a tools/list request from the +// client. +type ListToolsResult struct { + PaginatedResult + Tools []Tool `json:"tools"` +} + +// CallToolResult is the server's response to a tool call. +// +// Any errors that originate from the tool SHOULD be reported inside the result +// object, with `isError` set to true, _not_ as an MCP protocol-level error +// response. Otherwise, the LLM would not be able to see that an error occurred +// and self-correct. +// +// However, any errors in _finding_ the tool, an error indicating that the +// server does not support tool calls, or any other exceptional conditions, +// should be reported as an MCP error response. +type CallToolResult struct { + Result + Content []Content `json:"content"` // Can be TextContent, ImageContent, AudioContent, or EmbeddedResource + // Structured content returned as a JSON object in the structuredContent field of a result. 
+ // For backwards compatibility, a tool that returns structured content SHOULD also return + // functionally equivalent unstructured content. + StructuredContent any `json:"structuredContent,omitempty"` + // Whether the tool call ended in an error. + // + // If not set, this is assumed to be false (the call was successful). + IsError bool `json:"isError,omitempty"` +} + +// CallToolRequest is used by the client to invoke a tool provided by the server. +type CallToolRequest struct { + Request + Header http.Header `json:"-"` // HTTP headers from the original request + Params CallToolParams `json:"params"` +} + +type CallToolParams struct { + Name string `json:"name"` + Arguments any `json:"arguments,omitempty"` + Meta *Meta `json:"_meta,omitempty"` +} + +// GetArguments returns the Arguments as map[string]any for backward compatibility +// If Arguments is not a map, it returns an empty map +func (r CallToolRequest) GetArguments() map[string]any { + if args, ok := r.Params.Arguments.(map[string]any); ok { + return args + } + return nil +} + +// GetRawArguments returns the Arguments as-is without type conversion +// This allows users to access the raw arguments in any format +func (r CallToolRequest) GetRawArguments() any { + return r.Params.Arguments +} + +// BindArguments unmarshals the Arguments into the provided struct +// This is useful for working with strongly-typed arguments +func (r CallToolRequest) BindArguments(target any) error { + if target == nil || reflect.ValueOf(target).Kind() != reflect.Ptr { + return fmt.Errorf("target must be a non-nil pointer") + } + + // Fast-path: already raw JSON + if raw, ok := r.Params.Arguments.(json.RawMessage); ok { + return json.Unmarshal(raw, target) + } + + data, err := json.Marshal(r.Params.Arguments) + if err != nil { + return fmt.Errorf("failed to marshal arguments: %w", err) + } + + return json.Unmarshal(data, target) +} + +// GetString returns a string argument by key, or the default value if not found +func (r 
CallToolRequest) GetString(key string, defaultValue string) string { + args := r.GetArguments() + if val, ok := args[key]; ok { + if str, ok := val.(string); ok { + return str + } + } + return defaultValue +} + +// RequireString returns a string argument by key, or an error if not found or not a string +func (r CallToolRequest) RequireString(key string) (string, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + if str, ok := val.(string); ok { + return str, nil + } + return "", fmt.Errorf("argument %q is not a string", key) + } + return "", fmt.Errorf("required argument %q not found", key) +} + +// GetInt returns an int argument by key, or the default value if not found +func (r CallToolRequest) GetInt(key string, defaultValue int) int { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case int: + return v + case float64: + return int(v) + case string: + if i, err := strconv.Atoi(v); err == nil { + return i + } + } + } + return defaultValue +} + +// RequireInt returns an int argument by key, or an error if not found or not convertible to int +func (r CallToolRequest) RequireInt(key string) (int, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case int: + return v, nil + case float64: + return int(v), nil + case string: + if i, err := strconv.Atoi(v); err == nil { + return i, nil + } + return 0, fmt.Errorf("argument %q cannot be converted to int", key) + default: + return 0, fmt.Errorf("argument %q is not an int", key) + } + } + return 0, fmt.Errorf("required argument %q not found", key) +} + +// GetFloat returns a float64 argument by key, or the default value if not found +func (r CallToolRequest) GetFloat(key string, defaultValue float64) float64 { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case float64: + return v + case int: + return float64(v) + case string: + if f, err := strconv.ParseFloat(v, 64); err == nil 
{ + return f + } + } + } + return defaultValue +} + +// RequireFloat returns a float64 argument by key, or an error if not found or not convertible to float64 +func (r CallToolRequest) RequireFloat(key string) (float64, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case float64: + return v, nil + case int: + return float64(v), nil + case string: + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f, nil + } + return 0, fmt.Errorf("argument %q cannot be converted to float64", key) + default: + return 0, fmt.Errorf("argument %q is not a float64", key) + } + } + return 0, fmt.Errorf("required argument %q not found", key) +} + +// GetBool returns a bool argument by key, or the default value if not found +func (r CallToolRequest) GetBool(key string, defaultValue bool) bool { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case bool: + return v + case string: + if b, err := strconv.ParseBool(v); err == nil { + return b + } + case int: + return v != 0 + case float64: + return v != 0 + } + } + return defaultValue +} + +// RequireBool returns a bool argument by key, or an error if not found or not convertible to bool +func (r CallToolRequest) RequireBool(key string) (bool, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case bool: + return v, nil + case string: + if b, err := strconv.ParseBool(v); err == nil { + return b, nil + } + return false, fmt.Errorf("argument %q cannot be converted to bool", key) + case int: + return v != 0, nil + case float64: + return v != 0, nil + default: + return false, fmt.Errorf("argument %q is not a bool", key) + } + } + return false, fmt.Errorf("required argument %q not found", key) +} + +// GetStringSlice returns a string slice argument by key, or the default value if not found +func (r CallToolRequest) GetStringSlice(key string, defaultValue []string) []string { + args := r.GetArguments() + 
if val, ok := args[key]; ok { + switch v := val.(type) { + case []string: + return v + case []any: + result := make([]string, 0, len(v)) + for _, item := range v { + if str, ok := item.(string); ok { + result = append(result, str) + } + } + return result + } + } + return defaultValue +} + +// RequireStringSlice returns a string slice argument by key, or an error if not found or not convertible to string slice +func (r CallToolRequest) RequireStringSlice(key string) ([]string, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []string: + return v, nil + case []any: + result := make([]string, 0, len(v)) + for i, item := range v { + if str, ok := item.(string); ok { + result = append(result, str) + } else { + return nil, fmt.Errorf("item %d in argument %q is not a string", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a string slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetIntSlice returns an int slice argument by key, or the default value if not found +func (r CallToolRequest) GetIntSlice(key string, defaultValue []int) []int { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []int: + return v + case []any: + result := make([]int, 0, len(v)) + for _, item := range v { + switch num := item.(type) { + case int: + result = append(result, num) + case float64: + result = append(result, int(num)) + case string: + if i, err := strconv.Atoi(num); err == nil { + result = append(result, i) + } + } + } + return result + } + } + return defaultValue +} + +// RequireIntSlice returns an int slice argument by key, or an error if not found or not convertible to int slice +func (r CallToolRequest) RequireIntSlice(key string) ([]int, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []int: + return v, nil + case []any: + result := make([]int, 0, 
len(v)) + for i, item := range v { + switch num := item.(type) { + case int: + result = append(result, num) + case float64: + result = append(result, int(num)) + case string: + if i, err := strconv.Atoi(num); err == nil { + result = append(result, i) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to int", i, key) + } + default: + return nil, fmt.Errorf("item %d in argument %q is not an int", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not an int slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetFloatSlice returns a float64 slice argument by key, or the default value if not found +func (r CallToolRequest) GetFloatSlice(key string, defaultValue []float64) []float64 { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []float64: + return v + case []any: + result := make([]float64, 0, len(v)) + for _, item := range v { + switch num := item.(type) { + case float64: + result = append(result, num) + case int: + result = append(result, float64(num)) + case string: + if f, err := strconv.ParseFloat(num, 64); err == nil { + result = append(result, f) + } + } + } + return result + } + } + return defaultValue +} + +// RequireFloatSlice returns a float64 slice argument by key, or an error if not found or not convertible to float64 slice +func (r CallToolRequest) RequireFloatSlice(key string) ([]float64, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []float64: + return v, nil + case []any: + result := make([]float64, 0, len(v)) + for i, item := range v { + switch num := item.(type) { + case float64: + result = append(result, num) + case int: + result = append(result, float64(num)) + case string: + if f, err := strconv.ParseFloat(num, 64); err == nil { + result = append(result, f) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted 
to float64", i, key) + } + default: + return nil, fmt.Errorf("item %d in argument %q is not a float64", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a float64 slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetBoolSlice returns a bool slice argument by key, or the default value if not found +func (r CallToolRequest) GetBoolSlice(key string, defaultValue []bool) []bool { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []bool: + return v + case []any: + result := make([]bool, 0, len(v)) + for _, item := range v { + switch b := item.(type) { + case bool: + result = append(result, b) + case string: + if parsed, err := strconv.ParseBool(b); err == nil { + result = append(result, parsed) + } + case int: + result = append(result, b != 0) + case float64: + result = append(result, b != 0) + } + } + return result + } + } + return defaultValue +} + +// RequireBoolSlice returns a bool slice argument by key, or an error if not found or not convertible to bool slice +func (r CallToolRequest) RequireBoolSlice(key string) ([]bool, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []bool: + return v, nil + case []any: + result := make([]bool, 0, len(v)) + for i, item := range v { + switch b := item.(type) { + case bool: + result = append(result, b) + case string: + if parsed, err := strconv.ParseBool(b); err == nil { + result = append(result, parsed) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to bool", i, key) + } + case int: + result = append(result, b != 0) + case float64: + result = append(result, b != 0) + default: + return nil, fmt.Errorf("item %d in argument %q is not a bool", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a bool slice", key) + } + } + return nil, fmt.Errorf("required argument %q not 
found", key) +} + +// MarshalJSON implements custom JSON marshaling for CallToolResult +func (r CallToolResult) MarshalJSON() ([]byte, error) { + m := make(map[string]any) + + // Marshal Meta if present + if r.Meta != nil { + m["_meta"] = r.Meta + } + + // Marshal Content array + content := make([]any, len(r.Content)) + for i, c := range r.Content { + content[i] = c + } + m["content"] = content + + // Marshal StructuredContent if present + if r.StructuredContent != nil { + m["structuredContent"] = r.StructuredContent + } + + // Marshal IsError if true + if r.IsError { + m["isError"] = r.IsError + } + + return json.Marshal(m) +} + +// UnmarshalJSON implements custom JSON unmarshaling for CallToolResult +func (r *CallToolResult) UnmarshalJSON(data []byte) error { + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Unmarshal Meta + if meta, ok := raw["_meta"]; ok { + if metaMap, ok := meta.(map[string]any); ok { + r.Meta = NewMetaFromMap(metaMap) + } + } + + // Unmarshal Content array + if contentRaw, ok := raw["content"]; ok { + if contentArray, ok := contentRaw.([]any); ok { + r.Content = make([]Content, len(contentArray)) + for i, item := range contentArray { + itemBytes, err := json.Marshal(item) + if err != nil { + return err + } + content, err := UnmarshalContent(itemBytes) + if err != nil { + return err + } + r.Content[i] = content + } + } + } + + // Unmarshal StructuredContent if present + if structured, ok := raw["structuredContent"]; ok { + r.StructuredContent = structured + } + + // Unmarshal IsError + if isError, ok := raw["isError"]; ok { + if isErrorBool, ok := isError.(bool); ok { + r.IsError = isErrorBool + } + } + + return nil +} + +// ToolListChangedNotification is an optional notification from the server to +// the client, informing it that the list of tools it offers has changed. This may +// be issued by servers without any previous subscription from the client. 
+type ToolListChangedNotification struct { + Notification +} + +// Tool represents the definition for a tool the client can call. +type Tool struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // The name of the tool. + Name string `json:"name"` + // A human-readable description of the tool. + Description string `json:"description,omitempty"` + // A JSON Schema object defining the expected parameters for the tool. + InputSchema ToolInputSchema `json:"inputSchema"` + // Alternative to InputSchema - allows arbitrary JSON Schema to be provided + RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling + // A JSON Schema object defining the expected output returned by the tool . + OutputSchema ToolOutputSchema `json:"outputSchema,omitempty"` + // Optional JSON Schema defining expected output structure + RawOutputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling + // Optional properties describing tool behavior + Annotations ToolAnnotation `json:"annotations"` +} + +// GetName returns the name of the tool. +func (t Tool) GetName() string { + return t.Name +} + +// MarshalJSON implements the json.Marshaler interface for Tool. +// It handles marshaling either InputSchema or RawInputSchema based on which is set. 
+func (t Tool) MarshalJSON() ([]byte, error) { + // Create a map to build the JSON structure + m := make(map[string]any, 5) + + // Add the name and description + m["name"] = t.Name + if t.Description != "" { + m["description"] = t.Description + } + + // Determine which input schema to use + if t.RawInputSchema != nil { + if t.InputSchema.Type != "" { + return nil, fmt.Errorf("tool %s has both InputSchema and RawInputSchema set: %w", t.Name, errToolSchemaConflict) + } + m["inputSchema"] = t.RawInputSchema + } else { + // Use the structured InputSchema + m["inputSchema"] = t.InputSchema + } + + // Add output schema if present + if t.RawOutputSchema != nil { + if t.OutputSchema.Type != "" { + return nil, fmt.Errorf("tool %s has both OutputSchema and RawOutputSchema set: %w", t.Name, errToolSchemaConflict) + } + m["outputSchema"] = t.RawOutputSchema + } else if t.OutputSchema.Type != "" { // If no output schema is specified, do not return anything + m["outputSchema"] = t.OutputSchema + } + + m["annotations"] = t.Annotations + + // Marshal Meta if present + if t.Meta != nil { + m["_meta"] = t.Meta + } + + return json.Marshal(m) +} + +// ToolArgumentsSchema represents a JSON Schema for tool arguments. +type ToolArgumentsSchema struct { + Defs map[string]any `json:"$defs,omitempty"` + Type string `json:"type"` + Properties map[string]any `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type ToolInputSchema ToolArgumentsSchema // For retro-compatibility +type ToolOutputSchema ToolArgumentsSchema + +// MarshalJSON implements the json.Marshaler interface for ToolInputSchema. 
+func (tis ToolArgumentsSchema) MarshalJSON() ([]byte, error) { + m := make(map[string]any) + m["type"] = tis.Type + + if tis.Defs != nil { + m["$defs"] = tis.Defs + } + + // Marshal Properties to '{}' rather than `nil` when its length equals zero + if tis.Properties != nil { + m["properties"] = tis.Properties + } + + if len(tis.Required) > 0 { + m["required"] = tis.Required + } + + return json.Marshal(m) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for ToolArgumentsSchema. +// It handles both "$defs" (JSON Schema 2019-09+) and "definitions" (JSON Schema draft-07) +// by reading either field and storing it in the Defs field. +func (tis *ToolArgumentsSchema) UnmarshalJSON(data []byte) error { + // Use a temporary type to avoid infinite recursion + type Alias ToolArgumentsSchema + aux := &struct { + Definitions map[string]any `json:"definitions,omitempty"` + *Alias + }{ + Alias: (*Alias)(tis), + } + + if err := json.Unmarshal(data, aux); err != nil { + return err + } + + // If $defs wasn't provided but definitions was, use definitions + if tis.Defs == nil && aux.Definitions != nil { + tis.Defs = aux.Definitions + } + + return nil +} + +type ToolAnnotation struct { + // Human-readable title for the tool + Title string `json:"title,omitempty"` + // If true, the tool does not modify its environment + ReadOnlyHint *bool `json:"readOnlyHint,omitempty"` + // If true, the tool may perform destructive updates + DestructiveHint *bool `json:"destructiveHint,omitempty"` + // If true, repeated calls with same args have no additional effect + IdempotentHint *bool `json:"idempotentHint,omitempty"` + // If true, tool interacts with external entities + OpenWorldHint *bool `json:"openWorldHint,omitempty"` +} + +// ToolOption is a function that configures a Tool. +// It provides a flexible way to set various properties of a Tool using the functional options pattern. 
+type ToolOption func(*Tool) + +// PropertyOption is a function that configures a property in a Tool's input schema. +// It allows for flexible configuration of JSON Schema properties using the functional options pattern. +type PropertyOption func(map[string]any) + +// +// Core Tool Functions +// + +// NewTool creates a new Tool with the given name and options. +// The tool will have an object-type input schema with configurable properties. +// Options are applied in order, allowing for flexible tool configuration. +func NewTool(name string, opts ...ToolOption) Tool { + tool := Tool{ + Name: name, + InputSchema: ToolInputSchema{ + Type: "object", + Properties: make(map[string]any), + Required: nil, // Will be omitted from JSON if empty + }, + Annotations: ToolAnnotation{ + Title: "", + ReadOnlyHint: ToBoolPtr(false), + DestructiveHint: ToBoolPtr(true), + IdempotentHint: ToBoolPtr(false), + OpenWorldHint: ToBoolPtr(true), + }, + } + + for _, opt := range opts { + opt(&tool) + } + + return tool +} + +// NewToolWithRawSchema creates a new Tool with the given name and a raw JSON +// Schema. This allows for arbitrary JSON Schema to be used for the tool's input +// schema. +// +// NOTE a [Tool] built in such a way is incompatible with the [ToolOption] and +// runtime errors will result from supplying a [ToolOption] to a [Tool] built +// with this function. +func NewToolWithRawSchema(name, description string, schema json.RawMessage) Tool { + tool := Tool{ + Name: name, + Description: description, + RawInputSchema: schema, + } + + return tool +} + +// WithDescription adds a description to the Tool. +// The description should provide a clear, human-readable explanation of what the tool does. +func WithDescription(description string) ToolOption { + return func(t *Tool) { + t.Description = description + } +} + +// WithInputSchema creates a ToolOption that sets the input schema for a tool. 
+// It accepts any Go type, usually a struct, and automatically generates a JSON schema from it. +func WithInputSchema[T any]() ToolOption { + return func(t *Tool) { + var zero T + + // Generate schema using invopop/jsonschema library + // Configure reflector to generate clean, MCP-compatible schemas + reflector := jsonschema.Reflector{ + DoNotReference: true, // Removes $defs map, outputs entire structure inline + Anonymous: true, // Hides auto-generated Schema IDs + AllowAdditionalProperties: true, // Removes additionalProperties: false + } + schema := reflector.Reflect(zero) + + // Clean up schema for MCP compliance + schema.Version = "" // Remove $schema field + + // Convert to raw JSON for MCP + mcpSchema, err := json.Marshal(schema) + if err != nil { + // Skip and maintain backward compatibility + return + } + + t.InputSchema.Type = "" + t.RawInputSchema = json.RawMessage(mcpSchema) + } +} + +// WithRawInputSchema sets a raw JSON schema for the tool's input. +// Use this when you need full control over the schema or when working with +// complex schemas that can't be generated from Go types. The jsonschema library +// can handle complex schemas and provides nice extension points, so be sure to +// check that out before using this. +func WithRawInputSchema(schema json.RawMessage) ToolOption { + return func(t *Tool) { + t.RawInputSchema = schema + } +} + +// WithOutputSchema creates a ToolOption that sets the output schema for a tool. +// It accepts any Go type, usually a struct, and automatically generates a JSON schema from it. 
+func WithOutputSchema[T any]() ToolOption { + return func(t *Tool) { + var zero T + + // Generate schema using invopop/jsonschema library + // Configure reflector to generate clean, MCP-compatible schemas + reflector := jsonschema.Reflector{ + DoNotReference: true, // Removes $defs map, outputs entire structure inline + Anonymous: true, // Hides auto-generated Schema IDs + AllowAdditionalProperties: true, // Removes additionalProperties: false + } + schema := reflector.Reflect(zero) + + // Clean up schema for MCP compliance + schema.Version = "" // Remove $schema field + + // Convert to raw JSON for MCP + mcpSchema, err := json.Marshal(schema) + if err != nil { + // Skip and maintain backward compatibility + return + } + + // Retrieve the schema from raw JSON + if err := json.Unmarshal(mcpSchema, &t.OutputSchema); err != nil { + // Skip and maintain backward compatibility + return + } + + // Always set the type to "object" as of the current MCP spec + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema + t.OutputSchema.Type = "object" + } +} + +// WithRawOutputSchema sets a raw JSON schema for the tool's output. +// Use this when you need full control over the schema or when working with +// complex schemas that can't be generated from Go types. The jsonschema library +// can handle complex schemas and provides nice extension points, so be sure to +// check that out before using this. +func WithRawOutputSchema(schema json.RawMessage) ToolOption { + return func(t *Tool) { + t.RawOutputSchema = schema + } +} + +// WithToolAnnotation adds optional hints about the Tool. +func WithToolAnnotation(annotation ToolAnnotation) ToolOption { + return func(t *Tool) { + t.Annotations = annotation + } +} + +// WithTitleAnnotation sets the Title field of the Tool's Annotations. +// It provides a human-readable title for the tool. 
+func WithTitleAnnotation(title string) ToolOption { + return func(t *Tool) { + t.Annotations.Title = title + } +} + +// WithReadOnlyHintAnnotation sets the ReadOnlyHint field of the Tool's Annotations. +// If true, it indicates the tool does not modify its environment. +func WithReadOnlyHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.ReadOnlyHint = &value + } +} + +// WithDestructiveHintAnnotation sets the DestructiveHint field of the Tool's Annotations. +// If true, it indicates the tool may perform destructive updates. +func WithDestructiveHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.DestructiveHint = &value + } +} + +// WithIdempotentHintAnnotation sets the IdempotentHint field of the Tool's Annotations. +// If true, it indicates repeated calls with the same arguments have no additional effect. +func WithIdempotentHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.IdempotentHint = &value + } +} + +// WithOpenWorldHintAnnotation sets the OpenWorldHint field of the Tool's Annotations. +// If true, it indicates the tool interacts with external entities. +func WithOpenWorldHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.OpenWorldHint = &value + } +} + +// +// Common Property Options +// + +// Description adds a description to a property in the JSON Schema. +// The description should explain the purpose and expected values of the property. +func Description(desc string) PropertyOption { + return func(schema map[string]any) { + schema["description"] = desc + } +} + +// Required marks a property as required in the tool's input schema. +// Required properties must be provided when using the tool. +func Required() PropertyOption { + return func(schema map[string]any) { + schema["required"] = true + } +} + +// Title adds a display-friendly title to a property in the JSON Schema. 
+// This title can be used by UI components to show a more readable property name. +func Title(title string) PropertyOption { + return func(schema map[string]any) { + schema["title"] = title + } +} + +// +// String Property Options +// + +// DefaultString sets the default value for a string property. +// This value will be used if the property is not explicitly provided. +func DefaultString(value string) PropertyOption { + return func(schema map[string]any) { + schema["default"] = value + } +} + +// Enum specifies a list of allowed values for a string property. +// The property value must be one of the specified enum values. +func Enum(values ...string) PropertyOption { + return func(schema map[string]any) { + schema["enum"] = values + } +} + +// MaxLength sets the maximum length for a string property. +// The string value must not exceed this length. +func MaxLength(max int) PropertyOption { + return func(schema map[string]any) { + schema["maxLength"] = max + } +} + +// MinLength sets the minimum length for a string property. +// The string value must be at least this length. +func MinLength(min int) PropertyOption { + return func(schema map[string]any) { + schema["minLength"] = min + } +} + +// Pattern sets a regex pattern that a string property must match. +// The string value must conform to the specified regular expression. +func Pattern(pattern string) PropertyOption { + return func(schema map[string]any) { + schema["pattern"] = pattern + } +} + +// +// Number Property Options +// + +// DefaultNumber sets the default value for a number property. +// This value will be used if the property is not explicitly provided. +func DefaultNumber(value float64) PropertyOption { + return func(schema map[string]any) { + schema["default"] = value + } +} + +// Max sets the maximum value for a number property. +// The number value must not exceed this maximum. 
+func Max(max float64) PropertyOption { + return func(schema map[string]any) { + schema["maximum"] = max + } +} + +// Min sets the minimum value for a number property. +// The number value must not be less than this minimum. +func Min(min float64) PropertyOption { + return func(schema map[string]any) { + schema["minimum"] = min + } +} + +// MultipleOf specifies that a number must be a multiple of the given value. +// The number value must be divisible by this value. +func MultipleOf(value float64) PropertyOption { + return func(schema map[string]any) { + schema["multipleOf"] = value + } +} + +// +// Boolean Property Options +// + +// DefaultBool sets the default value for a boolean property. +// This value will be used if the property is not explicitly provided. +func DefaultBool(value bool) PropertyOption { + return func(schema map[string]any) { + schema["default"] = value + } +} + +// +// Array Property Options +// + +// DefaultArray sets the default value for an array property. +// This value will be used if the property is not explicitly provided. +func DefaultArray[T any](value []T) PropertyOption { + return func(schema map[string]any) { + schema["default"] = value + } +} + +// +// Property Type Helpers +// + +// WithBoolean adds a boolean property to the tool schema. +// It accepts property options to configure the boolean property's behavior and constraints. +func WithBoolean(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{ + "type": "boolean", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithNumber adds a number property to the tool schema. 
+// It accepts property options to configure the number property's behavior and constraints. +func WithNumber(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{ + "type": "number", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithString adds a string property to the tool schema. +// It accepts property options to configure the string property's behavior and constraints. +func WithString(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{ + "type": "string", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithObject adds an object property to the tool schema. +// It accepts property options to configure the object property's behavior and constraints. +func WithObject(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithArray returns a ToolOption that adds an array-typed property with the given name to a Tool's input schema. 
+// It applies provided PropertyOption functions to configure the property's schema, moves a `required` flag +// from the property schema into the Tool's InputSchema.Required slice when present, and registers the resulting +// schema under InputSchema.Properties[name]. +func WithArray(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{ + "type": "array", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithAny adds an input property named name with no predefined JSON Schema type to the Tool's input schema. +// The returned ToolOption applies the provided PropertyOption functions to the property's schema, moves a property-level +// `required` flag into the Tool's InputSchema.Required list if present, and stores the resulting schema under InputSchema.Properties[name]. +func WithAny(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{} + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// Properties sets the "properties" map for an object schema. +// The returned PropertyOption stores the provided map under the schema's "properties" key. 
+func Properties(props map[string]any) PropertyOption { + return func(schema map[string]any) { + schema["properties"] = props + } +} + +// AdditionalProperties specifies whether additional properties are allowed in the object +// or defines a schema for additional properties +func AdditionalProperties(schema any) PropertyOption { + return func(schemaMap map[string]any) { + schemaMap["additionalProperties"] = schema + } +} + +// MinProperties sets the minimum number of properties for an object +func MinProperties(min int) PropertyOption { + return func(schema map[string]any) { + schema["minProperties"] = min + } +} + +// MaxProperties sets the maximum number of properties for an object +func MaxProperties(max int) PropertyOption { + return func(schema map[string]any) { + schema["maxProperties"] = max + } +} + +// PropertyNames defines a schema for property names in an object +func PropertyNames(schema map[string]any) PropertyOption { + return func(schemaMap map[string]any) { + schemaMap["propertyNames"] = schema + } +} + +// Items defines the schema for array items. +// Accepts any schema definition for maximum flexibility. +// +// Example: +// +// Items(map[string]any{ +// "type": "object", +// "properties": map[string]any{ +// "name": map[string]any{"type": "string"}, +// "age": map[string]any{"type": "number"}, +// }, +// }) +// +// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead. 
+func Items(schema any) PropertyOption { + return func(schemaMap map[string]any) { + schemaMap["items"] = schema + } +} + +// MinItems sets the minimum number of items for an array +func MinItems(min int) PropertyOption { + return func(schema map[string]any) { + schema["minItems"] = min + } +} + +// MaxItems sets the maximum number of items for an array +func MaxItems(max int) PropertyOption { + return func(schema map[string]any) { + schema["maxItems"] = max + } +} + +// UniqueItems specifies whether array items must be unique +func UniqueItems(unique bool) PropertyOption { + return func(schema map[string]any) { + schema["uniqueItems"] = unique + } +} + +// WithStringItems configures an array's items to be of type string. +// +// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("tags", mcp.WithStringItems()) +// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue"))) +// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50))) +// +// Limitations: Only supports simple string arrays. Use Items() for complex objects. +func WithStringItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "string", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithStringEnumItems configures an array's items to be of type string with a specified enum. +// Example: +// +// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"})) +// +// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility. 
+func WithStringEnumItems(values []string) PropertyOption { + return func(schema map[string]any) { + schema["items"] = map[string]any{ + "type": "string", + "enum": values, + } + } +} + +// WithNumberItems configures an array's items to be of type number. +// +// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100))) +// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0))) +// +// Limitations: Only supports simple number arrays. Use Items() for complex objects. +func WithNumberItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "number", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithBooleanItems configures an array's items to be of type boolean. +// +// Supported options: Description(), DefaultBool() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("flags", mcp.WithBooleanItems()) +// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions"))) +// +// Limitations: Only supports simple boolean arrays. Use Items() for complex objects. 
+func WithBooleanItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "boolean", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go new file mode 100644 index 000000000..a03a19dd7 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go @@ -0,0 +1,42 @@ +package mcp + +import ( + "context" + "fmt" +) + +// TypedToolHandlerFunc is a function that handles a tool call with typed arguments +type TypedToolHandlerFunc[T any] func(ctx context.Context, request CallToolRequest, args T) (*CallToolResult, error) + +// StructuredToolHandlerFunc is a function that handles a tool call with typed arguments and returns structured output +type StructuredToolHandlerFunc[TArgs any, TResult any] func(ctx context.Context, request CallToolRequest, args TArgs) (TResult, error) + +// NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct +func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + var args T + if err := request.BindArguments(&args); err != nil { + return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil + } + return handler(ctx, request, args) + } +} + +// NewStructuredToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct +// and returns structured output. It automatically creates both structured and +// text content (from the structured output) for backwards compatibility. 
+func NewStructuredToolHandler[TArgs any, TResult any](handler StructuredToolHandlerFunc[TArgs, TResult]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + var args TArgs + if err := request.BindArguments(&args); err != nil { + return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil + } + + result, err := handler(ctx, request, args) + if err != nil { + return NewToolResultError(fmt.Sprintf("tool execution failed: %v", err)), nil + } + + return NewToolResultStructuredOnly(result), nil + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/types.go b/vendor/github.com/mark3labs/mcp-go/mcp/types.go new file mode 100644 index 000000000..6e447c61c --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/types.go @@ -0,0 +1,1252 @@ +// Package mcp defines the core types and interfaces for the Model Context Protocol (MCP). +// MCP is a protocol for communication between LLM-powered applications and their supporting services. +package mcp + +import ( + "encoding/json" + "fmt" + "maps" + "net/http" + "strconv" + + "github.com/yosida95/uritemplate/v3" +) + +type MCPMethod string + +const ( + // MethodInitialize initiates connection and negotiates protocol capabilities. + // https://modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization + MethodInitialize MCPMethod = "initialize" + + // MethodPing verifies connection liveness between client and server. + // https://modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ + MethodPing MCPMethod = "ping" + + // MethodResourcesList lists all available server resources. + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesList MCPMethod = "resources/list" + + // MethodResourcesTemplatesList provides URI templates for constructing resource URIs. 
+ // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesTemplatesList MCPMethod = "resources/templates/list" + + // MethodResourcesRead retrieves content of a specific resource by URI. + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesRead MCPMethod = "resources/read" + + // MethodPromptsList lists all available prompt templates. + // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + MethodPromptsList MCPMethod = "prompts/list" + + // MethodPromptsGet retrieves a specific prompt template with filled parameters. + // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + MethodPromptsGet MCPMethod = "prompts/get" + + // MethodToolsList lists all available executable tools. + // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ + MethodToolsList MCPMethod = "tools/list" + + // MethodToolsCall invokes a specific tool with provided parameters. + // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ + MethodToolsCall MCPMethod = "tools/call" + + // MethodSetLogLevel configures the minimum log level for client + // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging + MethodSetLogLevel MCPMethod = "logging/setLevel" + + // MethodElicitationCreate requests additional information from the user during interactions. + // https://modelcontextprotocol.io/docs/concepts/elicitation + MethodElicitationCreate MCPMethod = "elicitation/create" + + // MethodListRoots requests roots list from the client during interactions. + // https://modelcontextprotocol.io/specification/2025-06-18/client/roots + MethodListRoots MCPMethod = "roots/list" + + // MethodNotificationResourcesListChanged notifies when the list of available resources changes. 
+ // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification + MethodNotificationResourcesListChanged = "notifications/resources/list_changed" + + MethodNotificationResourceUpdated = "notifications/resources/updated" + + // MethodNotificationPromptsListChanged notifies when the list of available prompt templates changes. + // https://modelcontextprotocol.io/specification/2025-03-26/server/prompts#list-changed-notification + MethodNotificationPromptsListChanged = "notifications/prompts/list_changed" + + // MethodNotificationToolsListChanged notifies when the list of available tools changes. + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#list-changed-notification + MethodNotificationToolsListChanged = "notifications/tools/list_changed" + + // MethodNotificationRootsListChanged notifies when the list of available roots changes. + // https://modelcontextprotocol.io/specification/2025-06-18/client/roots#root-list-changes + MethodNotificationRootsListChanged = "notifications/roots/list_changed" +) + +type URITemplate struct { + *uritemplate.Template +} + +func (t *URITemplate) MarshalJSON() ([]byte, error) { + return json.Marshal(t.Raw()) +} + +func (t *URITemplate) UnmarshalJSON(data []byte) error { + var raw string + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + template, err := uritemplate.New(raw) + if err != nil { + return err + } + t.Template = template + return nil +} + +/* JSON-RPC types */ + +// JSONRPCMessage represents either a JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, or JSONRPCError +type JSONRPCMessage any + +// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol. +const LATEST_PROTOCOL_VERSION = "2025-06-18" + +// ValidProtocolVersions lists all known valid MCP protocol versions. 
+var ValidProtocolVersions = []string{ + LATEST_PROTOCOL_VERSION, + "2025-03-26", + "2024-11-05", +} + +// JSONRPC_VERSION is the version of JSON-RPC used by MCP. +const JSONRPC_VERSION = "2.0" + +// ProgressToken is used to associate progress notifications with the original request. +type ProgressToken any + +// Cursor is an opaque token used to represent a cursor for pagination. +type Cursor string + +// Meta is metadata attached to a request's parameters. This can include fields +// formally defined by the protocol or other arbitrary data. +type Meta struct { + // If specified, the caller is requesting out-of-band progress + // notifications for this request (as represented by + // notifications/progress). The value of this parameter is an + // opaque token that will be attached to any subsequent + // notifications. The receiver is not obligated to provide these + // notifications. + ProgressToken ProgressToken + + // AdditionalFields are any fields present in the Meta that are not + // otherwise defined in the protocol. 
+ AdditionalFields map[string]any +} + +func (m *Meta) MarshalJSON() ([]byte, error) { + raw := make(map[string]any) + if m.ProgressToken != nil { + raw["progressToken"] = m.ProgressToken + } + maps.Copy(raw, m.AdditionalFields) + + return json.Marshal(raw) +} + +func (m *Meta) UnmarshalJSON(data []byte) error { + raw := make(map[string]any) + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + m.ProgressToken = raw["progressToken"] + delete(raw, "progressToken") + m.AdditionalFields = raw + return nil +} + +func NewMetaFromMap(m map[string]any) *Meta { + progressToken := m["progressToken"] + if progressToken != nil { + delete(m, "progressToken") + } + + return &Meta{ + ProgressToken: progressToken, + AdditionalFields: m, + } +} + +type Request struct { + Method string `json:"method"` + Params RequestParams `json:"params,omitempty"` +} + +type RequestParams struct { + Meta *Meta `json:"_meta,omitempty"` +} + +type Params map[string]any + +type Notification struct { + Method string `json:"method"` + Params NotificationParams `json:"params,omitempty"` +} + +type NotificationParams struct { + // This parameter name is reserved by MCP to allow clients and + // servers to attach additional metadata to their notifications. 
+ Meta map[string]any `json:"_meta,omitempty"` + + // Additional fields can be added to this map + AdditionalFields map[string]any `json:"-"` +} + +// MarshalJSON implements custom JSON marshaling +func (p NotificationParams) MarshalJSON() ([]byte, error) { + // Create a map to hold all fields + m := make(map[string]any) + + // Add Meta if it exists + if p.Meta != nil { + m["_meta"] = p.Meta + } + + // Add all additional fields + for k, v := range p.AdditionalFields { + // Ensure we don't override the _meta field + if k != "_meta" { + m[k] = v + } + } + + return json.Marshal(m) +} + +// UnmarshalJSON implements custom JSON unmarshaling +func (p *NotificationParams) UnmarshalJSON(data []byte) error { + // Create a map to hold all fields + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return err + } + + // Initialize maps if they're nil + if p.Meta == nil { + p.Meta = make(map[string]any) + } + if p.AdditionalFields == nil { + p.AdditionalFields = make(map[string]any) + } + + // Process all fields + for k, v := range m { + if k == "_meta" { + // Handle Meta field + if meta, ok := v.(map[string]any); ok { + p.Meta = meta + } + } else { + // Handle additional fields + p.AdditionalFields[k] = v + } + } + + return nil +} + +type Result struct { + // This result property is reserved by the protocol to allow clients and + // servers to attach additional metadata to their responses. + Meta *Meta `json:"_meta,omitempty"` +} + +// RequestId is a uniquely identifying ID for a request in JSON-RPC. +// It can be any JSON-serializable value, typically a number or string. 
+type RequestId struct { + value any +} + +// NewRequestId creates a new RequestId with the given value +func NewRequestId(value any) RequestId { + return RequestId{value: value} +} + +// Value returns the underlying value of the RequestId +func (r RequestId) Value() any { + return r.value +} + +// String returns a string representation of the RequestId +func (r RequestId) String() string { + switch v := r.value.(type) { + case string: + return "string:" + v + case int64: + return "int64:" + strconv.FormatInt(v, 10) + case float64: + if v == float64(int64(v)) { + return "int64:" + strconv.FormatInt(int64(v), 10) + } + return "float64:" + strconv.FormatFloat(v, 'f', -1, 64) + case nil: + return "" + default: + return "unknown:" + fmt.Sprintf("%v", v) + } +} + +// IsNil returns true if the RequestId is nil +func (r RequestId) IsNil() bool { + return r.value == nil +} + +func (r RequestId) MarshalJSON() ([]byte, error) { + return json.Marshal(r.value) +} + +func (r *RequestId) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + r.value = nil + return nil + } + + // Try unmarshaling as string first + var s string + if err := json.Unmarshal(data, &s); err == nil { + r.value = s + return nil + } + + // JSON numbers are unmarshaled as float64 in Go + var f float64 + if err := json.Unmarshal(data, &f); err == nil { + if f == float64(int64(f)) { + r.value = int64(f) + } else { + r.value = f + } + return nil + } + + return fmt.Errorf("invalid request id: %s", string(data)) +} + +// JSONRPCRequest represents a request that expects a response. +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Params any `json:"params,omitempty"` + Request +} + +// JSONRPCNotification represents a notification which does not expect a response. +type JSONRPCNotification struct { + JSONRPC string `json:"jsonrpc"` + Notification +} + +// JSONRPCResponse represents a successful (non-error) response to a request. 
+type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Result any `json:"result"` +} + +// JSONRPCError represents a non-successful (error) response to a request. +type JSONRPCError struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Error JSONRPCErrorDetails `json:"error"` +} + +// JSONRPCErrorDetails represents a JSON-RPC error for Go error handling. +// This is separate from the JSONRPCError type which represents the full JSON-RPC error response structure. +type JSONRPCErrorDetails struct { + // The error type that occurred. + Code int `json:"code"` + // A short description of the error. The message SHOULD be limited + // to a concise single sentence. + Message string `json:"message"` + // Additional information about the error. The value of this member + // is defined by the sender (e.g. detailed error information, nested errors etc.). + Data any `json:"data,omitempty"` +} + +// Standard JSON-RPC error codes +const ( + // PARSE_ERROR indicates invalid JSON was received by the server. + PARSE_ERROR = -32700 + + // INVALID_REQUEST indicates the JSON sent is not a valid Request object. + INVALID_REQUEST = -32600 + + // METHOD_NOT_FOUND indicates the method does not exist/is not available. + METHOD_NOT_FOUND = -32601 + + // INVALID_PARAMS indicates invalid method parameter(s). + INVALID_PARAMS = -32602 + + // INTERNAL_ERROR indicates internal JSON-RPC error. + INTERNAL_ERROR = -32603 + + // REQUEST_INTERRUPTED indicates a request was cancelled or timed out. + REQUEST_INTERRUPTED = -32800 +) + +// MCP error codes +const ( + // RESOURCE_NOT_FOUND indicates a requested resource was not found. + RESOURCE_NOT_FOUND = -32002 +) + +/* Empty result */ + +// EmptyResult represents a response that indicates success but carries no data. +type EmptyResult Result + +/* Cancellation */ + +// CancelledNotification can be sent by either side to indicate that it is +// cancelling a previously-issued request. 
+// +// The request SHOULD still be in-flight, but due to communication latency, it +// is always possible that this notification MAY arrive after the request has +// already finished. +// +// This notification indicates that the result will be unused, so any +// associated processing SHOULD cease. +// +// A client MUST NOT attempt to cancel its `initialize` request. +type CancelledNotification struct { + Notification + Params CancelledNotificationParams `json:"params"` +} + +type CancelledNotificationParams struct { + // The ID of the request to cancel. + // + // This MUST correspond to the ID of a request previously issued + // in the same direction. + RequestId RequestId `json:"requestId"` + + // An optional string describing the reason for the cancellation. This MAY + // be logged or presented to the user. + Reason string `json:"reason,omitempty"` +} + +/* Initialization */ + +// InitializeRequest is sent from the client to the server when it first +// connects, asking it to begin initialization. +type InitializeRequest struct { + Request + Params InitializeParams `json:"params"` + Header http.Header `json:"-"` +} + +type InitializeParams struct { + // The latest version of the Model Context Protocol that the client supports. + // The client MAY decide to support older versions as well. + ProtocolVersion string `json:"protocolVersion"` + Capabilities ClientCapabilities `json:"capabilities"` + ClientInfo Implementation `json:"clientInfo"` +} + +// InitializeResult is sent after receiving an initialize request from the +// client. +type InitializeResult struct { + Result + // The version of the Model Context Protocol that the server wants to use. + // This may not match the version that the client requested. If the client cannot + // support this version, it MUST disconnect. 
+ ProtocolVersion string `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo Implementation `json:"serverInfo"` + // Instructions describing how to use the server and its features. + // + // This can be used by clients to improve the LLM's understanding of + // available tools, resources, etc. It can be thought of like a "hint" to the model. + // For example, this information MAY be added to the system prompt. + Instructions string `json:"instructions,omitempty"` +} + +// InitializedNotification is sent from the client to the server after +// initialization has finished. +type InitializedNotification struct { + Notification +} + +// ClientCapabilities represents capabilities a client may support. Known +// capabilities are defined here, in this schema, but this is not a closed set: any +// client can define its own, additional capabilities. +type ClientCapabilities struct { + // Experimental, non-standard capabilities that the client supports. + Experimental map[string]any `json:"experimental,omitempty"` + // Present if the client supports listing roots. + Roots *struct { + // Whether the client supports notifications for changes to the roots list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"roots,omitempty"` + // Present if the client supports sampling from an LLM. + Sampling *struct{} `json:"sampling,omitempty"` + // Present if the client supports elicitation requests from the server. + Elicitation *struct{} `json:"elicitation,omitempty"` +} + +// ServerCapabilities represents capabilities that a server may support. Known +// capabilities are defined here, in this schema, but this is not a closed set: any +// server can define its own, additional capabilities. +type ServerCapabilities struct { + // Experimental, non-standard capabilities that the server supports. + Experimental map[string]any `json:"experimental,omitempty"` + // Present if the server supports sending log messages to the client. 
+ Logging *struct{} `json:"logging,omitempty"` + // Present if the server offers any prompt templates. + Prompts *struct { + // Whether this server supports notifications for changes to the prompt list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"prompts,omitempty"` + // Present if the server offers any resources to read. + Resources *struct { + // Whether this server supports subscribing to resource updates. + Subscribe bool `json:"subscribe,omitempty"` + // Whether this server supports notifications for changes to the resource + // list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"resources,omitempty"` + // Present if the server supports sending sampling requests to clients. + Sampling *struct{} `json:"sampling,omitempty"` + // Present if the server offers any tools to call. + Tools *struct { + // Whether this server supports notifications for changes to the tool list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"tools,omitempty"` + // Present if the server supports elicitation requests to the client. + Elicitation *struct{} `json:"elicitation,omitempty"` + // Present if the server supports roots requests to the client. + Roots *struct{} `json:"roots,omitempty"` +} + +// Implementation describes the name and version of an MCP implementation. +type Implementation struct { + Name string `json:"name"` + Version string `json:"version"` + Title string `json:"title,omitempty"` +} + +/* Ping */ + +// PingRequest represents a ping, issued by either the server or the client, +// to check that the other party is still alive. The receiver must promptly respond, +// or else may be disconnected. +type PingRequest struct { + Request + Header http.Header `json:"-"` +} + +/* Progress notifications */ + +// ProgressNotification is an out-of-band notification used to inform the +// receiver of a progress update for a long-running request. 
+type ProgressNotification struct { + Notification + Params ProgressNotificationParams `json:"params"` +} + +type ProgressNotificationParams struct { + // The progress token which was given in the initial request, used to + // associate this notification with the request that is proceeding. + ProgressToken ProgressToken `json:"progressToken"` + // The progress thus far. This should increase every time progress is made, + // even if the total is unknown. + Progress float64 `json:"progress"` + // Total number of items to process (or total progress required), if known. + Total float64 `json:"total,omitempty"` + // Message related to progress. This should provide relevant human-readable + // progress information. + Message string `json:"message,omitempty"` +} + +/* Pagination */ + +type PaginatedRequest struct { + Request + Params PaginatedParams `json:"params,omitempty"` +} + +type PaginatedParams struct { + // An opaque token representing the current pagination position. + // If provided, the server should return results starting after this cursor. + Cursor Cursor `json:"cursor,omitempty"` +} + +type PaginatedResult struct { + Result + // An opaque token representing the pagination position after the last + // returned result. + // If present, there may be more results available. + NextCursor Cursor `json:"nextCursor,omitempty"` +} + +/* Resources */ + +// ListResourcesRequest is sent from the client to request a list of resources +// the server has. +type ListResourcesRequest struct { + PaginatedRequest + Header http.Header `json:"-"` +} + +// ListResourcesResult is the server's response to a resources/list request +// from the client. +type ListResourcesResult struct { + PaginatedResult + Resources []Resource `json:"resources"` +} + +// ListResourceTemplatesRequest is sent from the client to request a list of +// resource templates the server has. 
+type ListResourceTemplatesRequest struct { + PaginatedRequest + Header http.Header `json:"-"` +} + +// ListResourceTemplatesResult is the server's response to a +// resources/templates/list request from the client. +type ListResourceTemplatesResult struct { + PaginatedResult + ResourceTemplates []ResourceTemplate `json:"resourceTemplates"` +} + +// ReadResourceRequest is sent from the client to the server, to read a +// specific resource URI. +type ReadResourceRequest struct { + Request + Header http.Header `json:"-"` + Params ReadResourceParams `json:"params"` +} + +type ReadResourceParams struct { + // The URI of the resource to read. The URI can use any protocol; it is up + // to the server how to interpret it. + URI string `json:"uri"` + // Arguments to pass to the resource handler + Arguments map[string]any `json:"arguments,omitempty"` +} + +// ReadResourceResult is the server's response to a resources/read request +// from the client. +type ReadResourceResult struct { + Result + Contents []ResourceContents `json:"contents"` // Can be TextResourceContents or BlobResourceContents +} + +// ResourceListChangedNotification is an optional notification from the server +// to the client, informing it that the list of resources it can read from has +// changed. This may be issued by servers without any previous subscription from +// the client. +type ResourceListChangedNotification struct { + Notification +} + +// SubscribeRequest is sent from the client to request resources/updated +// notifications from the server whenever a particular resource changes. +type SubscribeRequest struct { + Request + Params SubscribeParams `json:"params"` + Header http.Header `json:"-"` +} + +type SubscribeParams struct { + // The URI of the resource to subscribe to. The URI can use any protocol; it + // is up to the server how to interpret it. 
+ URI string `json:"uri"` +} + +// UnsubscribeRequest is sent from the client to request cancellation of +// resources/updated notifications from the server. This should follow a previous +// resources/subscribe request. +type UnsubscribeRequest struct { + Request + Params UnsubscribeParams `json:"params"` + Header http.Header `json:"-"` +} + +type UnsubscribeParams struct { + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` +} + +// ResourceUpdatedNotification is a notification from the server to the client, +// informing it that a resource has changed and may need to be read again. This +// should only be sent if the client previously sent a resources/subscribe request. +type ResourceUpdatedNotification struct { + Notification + Params ResourceUpdatedNotificationParams `json:"params"` +} +type ResourceUpdatedNotificationParams struct { + // The URI of the resource that has been updated. This might be a sub- + // resource of the one that the client actually subscribed to. + URI string `json:"uri"` +} + +// Resource represents a known resource that the server is capable of reading. +type Resource struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // The URI of this resource. + URI string `json:"uri"` + // A human-readable name for this resource. + // + // This can be used by clients to populate UI elements. + Name string `json:"name"` + // A description of what this resource represents. + // + // This can be used by clients to improve the LLM's understanding of + // available resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` +} + +// GetName returns the name of the resource. 
+func (r Resource) GetName() string { + return r.Name +} + +// ResourceTemplate represents a template description for resources available +// on the server. +type ResourceTemplate struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // A URI template (according to RFC 6570) that can be used to construct + // resource URIs. + URITemplate *URITemplate `json:"uriTemplate"` + // A human-readable name for the type of resource this template refers to. + // + // This can be used by clients to populate UI elements. + Name string `json:"name"` + // A description of what this template is for. + // + // This can be used by clients to improve the LLM's understanding of + // available resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type for all resources that match this template. This should only + // be included if all resources matching this template have the same type. + MIMEType string `json:"mimeType,omitempty"` +} + +// GetName returns the name of the resourceTemplate. +func (rt ResourceTemplate) GetName() string { + return rt.Name +} + +// ResourceContents represents the contents of a specific resource or sub- +// resource. +type ResourceContents interface { + isResourceContents() +} + +type TextResourceContents struct { + // Raw per‑resource metadata; pass‑through as defined by MCP. Not the same as mcp.Meta. + // Allows _meta to be used for MCP-UI features for example. Does not assume any specific format. + Meta map[string]any `json:"_meta,omitempty"` + // The URI of this resource. + URI string `json:"uri"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // The text of the item. This must only be set if the item can actually be + // represented as text (not binary data). 
+ Text string `json:"text"` +} + +func (TextResourceContents) isResourceContents() {} + +type BlobResourceContents struct { + // Raw per‑resource metadata; pass‑through as defined by MCP. Not the same as mcp.Meta. + // Allows _meta to be used for MCP-UI features for example. Does not assume any specific format. + Meta map[string]any `json:"_meta,omitempty"` + // The URI of this resource. + URI string `json:"uri"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // A base64-encoded string representing the binary data of the item. + Blob string `json:"blob"` +} + +func (BlobResourceContents) isResourceContents() {} + +/* Logging */ + +// SetLevelRequest is a request from the client to the server, to enable or +// adjust logging. +type SetLevelRequest struct { + Request + Params SetLevelParams `json:"params"` + Header http.Header `json:"-"` +} + +type SetLevelParams struct { + // The level of logging that the client wants to receive from the server. + // The server should send all logs at this level and higher (i.e., more severe) to + // the client as notifications/logging/message. + Level LoggingLevel `json:"level"` +} + +// LoggingMessageNotification is a notification of a log message passed from +// server to client. If no logging/setLevel request has been sent from the client, +// the server MAY decide which messages to send automatically. +type LoggingMessageNotification struct { + Notification + Params LoggingMessageNotificationParams `json:"params"` +} + +type LoggingMessageNotificationParams struct { + // The severity of this log message. + Level LoggingLevel `json:"level"` + // An optional name of the logger issuing this message. + Logger string `json:"logger,omitempty"` + // The data to be logged, such as a string message or an object. Any JSON + // serializable type is allowed here. + Data any `json:"data"` +} + +// LoggingLevel represents the severity of a log message. 
+// +// These map to syslog message severities, as specified in RFC-5424: +// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 +type LoggingLevel string + +const ( + LoggingLevelDebug LoggingLevel = "debug" + LoggingLevelInfo LoggingLevel = "info" + LoggingLevelNotice LoggingLevel = "notice" + LoggingLevelWarning LoggingLevel = "warning" + LoggingLevelError LoggingLevel = "error" + LoggingLevelCritical LoggingLevel = "critical" + LoggingLevelAlert LoggingLevel = "alert" + LoggingLevelEmergency LoggingLevel = "emergency" +) + +var levelToInt = map[LoggingLevel]int{ + LoggingLevelDebug: 0, + LoggingLevelInfo: 1, + LoggingLevelNotice: 2, + LoggingLevelWarning: 3, + LoggingLevelError: 4, + LoggingLevelCritical: 5, + LoggingLevelAlert: 6, + LoggingLevelEmergency: 7, +} + +func (l LoggingLevel) ShouldSendTo(minLevel LoggingLevel) bool { + ia, oka := levelToInt[l] + ib, okb := levelToInt[minLevel] + if !oka || !okb { + return false + } + return ia >= ib +} + +/* Elicitation */ + +// ElicitationRequest is a request from the server to the client to request additional +// information from the user during an interaction. +type ElicitationRequest struct { + Request + Params ElicitationParams `json:"params"` +} + +// ElicitationParams contains the parameters for an elicitation request. +type ElicitationParams struct { + // A human-readable message explaining what information is being requested and why. + Message string `json:"message"` + // A JSON Schema defining the expected structure of the user's response. + RequestedSchema any `json:"requestedSchema"` +} + +// ElicitationResult represents the result of an elicitation request. +type ElicitationResult struct { + Result + ElicitationResponse +} + +// ElicitationResponse represents the user's response to an elicitation request. +type ElicitationResponse struct { + // Action indicates whether the user accepted, declined, or cancelled. 
+ Action ElicitationResponseAction `json:"action"` + // Content contains the user's response data if they accepted. + // Should conform to the requestedSchema from the ElicitationRequest. + Content any `json:"content,omitempty"` +} + +// ElicitationResponseAction indicates how the user responded to an elicitation request. +type ElicitationResponseAction string + +const ( + // ElicitationResponseActionAccept indicates the user provided the requested information. + ElicitationResponseActionAccept ElicitationResponseAction = "accept" + // ElicitationResponseActionDecline indicates the user explicitly declined to provide information. + ElicitationResponseActionDecline ElicitationResponseAction = "decline" + // ElicitationResponseActionCancel indicates the user cancelled without making a choice. + ElicitationResponseActionCancel ElicitationResponseAction = "cancel" +) + +/* Sampling */ + +const ( + // MethodSamplingCreateMessage allows servers to request LLM completions from clients + MethodSamplingCreateMessage MCPMethod = "sampling/createMessage" +) + +// CreateMessageRequest is a request from the server to sample an LLM via the +// client. The client has full discretion over which model to select. The client +// should also inform the user before beginning sampling, to allow them to inspect +// the request (human in the loop) and decide whether to approve it. 
+type CreateMessageRequest struct { + Request + CreateMessageParams `json:"params"` +} + +type CreateMessageParams struct { + Messages []SamplingMessage `json:"messages"` + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + SystemPrompt string `json:"systemPrompt,omitempty"` + IncludeContext string `json:"includeContext,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"maxTokens"` + StopSequences []string `json:"stopSequences,omitempty"` + Metadata any `json:"metadata,omitempty"` +} + +// CreateMessageResult is the client's response to a sampling/create_message +// request from the server. The client should inform the user before returning the +// sampled message, to allow them to inspect the response (human in the loop) and +// decide whether to allow the server to see it. +type CreateMessageResult struct { + Result + SamplingMessage + // The name of the model that generated the message. + Model string `json:"model"` + // The reason why sampling stopped, if known. + StopReason string `json:"stopReason,omitempty"` +} + +// SamplingMessage describes a message issued to or received from an LLM API. +type SamplingMessage struct { + Role Role `json:"role"` + Content any `json:"content"` // Can be TextContent, ImageContent or AudioContent +} + +type Annotations struct { + // Describes who the intended customer of this object or data is. + // + // It can include multiple entries to indicate content useful for multiple + // audiences (e.g., `["user", "assistant"]`). + Audience []Role `json:"audience,omitempty"` + + // Describes how important this data is for operating the server. + // + // A value of 1 means "most important," and indicates that the data is + // effectively required, while 0 means "least important," and indicates that + // the data is entirely optional. + Priority float64 `json:"priority,omitempty"` +} + +// Annotated is the base for objects that include optional annotations for the +// client. 
The client can use annotations to inform how objects are used or +// displayed +type Annotated struct { + Annotations *Annotations `json:"annotations,omitempty"` +} + +type Content interface { + isContent() +} + +// TextContent represents text provided to or from an LLM. +// It must have Type set to "text". +type TextContent struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + Type string `json:"type"` // Must be "text" + // The text content of the message. + Text string `json:"text"` +} + +func (TextContent) isContent() {} + +// ImageContent represents an image provided to or from an LLM. +// It must have Type set to "image". +type ImageContent struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + Type string `json:"type"` // Must be "image" + // The base64-encoded image data. + Data string `json:"data"` + // The MIME type of the image. Different providers may support different image types. + MIMEType string `json:"mimeType"` +} + +func (ImageContent) isContent() {} + +// AudioContent represents the contents of audio, embedded into a prompt or tool call result. +// It must have Type set to "audio". +type AudioContent struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + Type string `json:"type"` // Must be "audio" + // The base64-encoded audio data. + Data string `json:"data"` + // The MIME type of the audio. Different providers may support different audio types. + MIMEType string `json:"mimeType"` +} + +func (AudioContent) isContent() {} + +// ResourceLink represents a link to a resource that the client can access. +type ResourceLink struct { + Annotated + Type string `json:"type"` // Must be "resource_link" + // The URI of the resource. 
+ URI string `json:"uri"`
+ // The name of the resource.
+ Name string `json:"name"`
+ // The description of the resource.
+ Description string `json:"description"`
+ // The MIME type of the resource.
+ MIMEType string `json:"mimeType"`
+}
+
+func (ResourceLink) isContent() {}
+
+// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result.
+//
+// It is up to the client how best to render embedded resources for the
+// benefit of the LLM and/or the user.
+type EmbeddedResource struct {
+ Annotated
+ // Meta is a metadata object that is reserved by MCP for storing additional information.
+ Meta *Meta `json:"_meta,omitempty"`
+ Type string `json:"type"`
+ Resource ResourceContents `json:"resource"`
+}
+
+func (EmbeddedResource) isContent() {}
+
+// ModelPreferences represents the server's preferences for model selection,
+// requested of the client during sampling.
+//
+// Because LLMs can vary along multiple dimensions, choosing the "best" model is
+// rarely straightforward. Different models excel in different areas—some are
+// faster but less capable, others are more capable but more expensive, and so
+// on. This interface allows servers to express their priorities across multiple
+// dimensions to help clients make an appropriate selection for their use case.
+//
+// These preferences are always advisory. The client MAY ignore them. It is also
+// up to the client to decide how to interpret these preferences and how to
+// balance them against other considerations.
+type ModelPreferences struct {
+ // Optional hints to use for model selection.
+ //
+ // If multiple hints are specified, the client MUST evaluate them in order
+ // (such that the first match is taken).
+ //
+ // The client SHOULD prioritize these hints over the numeric priorities, but
+ // MAY still use the priorities to select from ambiguous matches.
+ Hints []ModelHint `json:"hints,omitempty"`
+
+ // How much to prioritize cost when selecting a model. 
A value of 0 means cost + // is not important, while a value of 1 means cost is the most important + // factor. + CostPriority float64 `json:"costPriority,omitempty"` + + // How much to prioritize sampling speed (latency) when selecting a model. A + // value of 0 means speed is not important, while a value of 1 means speed is + // the most important factor. + SpeedPriority float64 `json:"speedPriority,omitempty"` + + // How much to prioritize intelligence and capabilities when selecting a + // model. A value of 0 means intelligence is not important, while a value of 1 + // means intelligence is the most important factor. + IntelligencePriority float64 `json:"intelligencePriority,omitempty"` +} + +// ModelHint represents hints to use for model selection. +// +// Keys not declared here are currently left unspecified by the spec and are up +// to the client to interpret. +type ModelHint struct { + // A hint for a model name. + // + // The client SHOULD treat this as a substring of a model name; for example: + // - `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022` + // - `sonnet` should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc. + // - `claude` should match any Claude model + // + // The client MAY also map the string to a different provider's model name or + // a different model family, as long as it fills a similar niche; for example: + // - `gemini-1.5-flash` could match `claude-3-haiku-20240307` + Name string `json:"name,omitempty"` +} + +/* Autocomplete */ + +// CompleteRequest is a request from the client to the server, to ask for completion options. +type CompleteRequest struct { + Request + Params CompleteParams `json:"params"` + Header http.Header `json:"-"` +} + +type CompleteParams struct { + Ref any `json:"ref"` // Can be PromptReference or ResourceReference + Argument struct { + // The name of the argument + Name string `json:"name"` + // The value of the argument to use for completion matching. 
+ Value string `json:"value"` + } `json:"argument"` +} + +// CompleteResult is the server's response to a completion/complete request +type CompleteResult struct { + Result + Completion struct { + // An array of completion values. Must not exceed 100 items. + Values []string `json:"values"` + // The total number of completion options available. This can exceed the + // number of values actually sent in the response. + Total int `json:"total,omitempty"` + // Indicates whether there are additional completion options beyond those + // provided in the current response, even if the exact total is unknown. + HasMore bool `json:"hasMore,omitempty"` + } `json:"completion"` +} + +// ResourceReference is a reference to a resource or resource template definition. +type ResourceReference struct { + Type string `json:"type"` + // The URI or URI template of the resource. + URI string `json:"uri"` +} + +// PromptReference identifies a prompt. +type PromptReference struct { + Type string `json:"type"` + // The name of the prompt or prompt template + Name string `json:"name"` +} + +/* Roots */ + +// ListRootsRequest is sent from the server to request a list of root URIs from the client. Roots allow +// servers to ask for specific directories or files to operate on. A common example +// for roots is providing a set of repositories or directories a server should operate +// on. +// +// This request is typically used when the server needs to understand the file system +// structure or access specific locations that the client has permission to read from. +type ListRootsRequest struct { + Request +} + +// ListRootsResult is the client's response to a roots/list request from the server. +// This result contains an array of Root objects, each representing a root directory +// or file that the server can operate on. +type ListRootsResult struct { + Result + Roots []Root `json:"roots"` +} + +// Root represents a root directory or file that the server can operate on. 
+type Root struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // The URI identifying the root. This *must* start with file:// for now. + // This restriction may be relaxed in future versions of the protocol to allow + // other URI schemes. + URI string `json:"uri"` + // An optional name for the root. This can be used to provide a human-readable + // identifier for the root, which may be useful for display purposes or for + // referencing the root in other parts of the application. + Name string `json:"name,omitempty"` +} + +// RootsListChangedNotification is a notification from the client to the +// server, informing it that the list of roots has changed. +// This notification should be sent whenever the client adds, removes, or modifies any root. +// The server should then request an updated list of roots using the ListRootsRequest. +type RootsListChangedNotification struct { + Notification +} + +// ClientRequest represents any request that can be sent from client to server. +type ClientRequest any + +// ClientNotification represents any notification that can be sent from client to server. +type ClientNotification any + +// ClientResult represents any result that can be sent from client to server. +type ClientResult any + +// ServerRequest represents any request that can be sent from server to client. +type ServerRequest any + +// ServerNotification represents any notification that can be sent from server to client. +type ServerNotification any + +// ServerResult represents any result that can be sent from server to client. 
+type ServerResult any + +type Named interface { + GetName() string +} + +// MarshalJSON implements custom JSON marshaling for Content interface +func MarshalContent(content Content) ([]byte, error) { + return json.Marshal(content) +} + +// UnmarshalContent implements custom JSON unmarshaling for Content interface +func UnmarshalContent(data []byte) (Content, error) { + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + contentType, ok := raw["type"].(string) + if !ok { + return nil, fmt.Errorf("missing or invalid type field") + } + + switch contentType { + case ContentTypeText: + var content TextContent + err := json.Unmarshal(data, &content) + return content, err + case ContentTypeImage: + var content ImageContent + err := json.Unmarshal(data, &content) + return content, err + case ContentTypeAudio: + var content AudioContent + err := json.Unmarshal(data, &content) + return content, err + case ContentTypeLink: + var content ResourceLink + err := json.Unmarshal(data, &content) + return content, err + case ContentTypeResource: + var content EmbeddedResource + err := json.Unmarshal(data, &content) + return content, err + default: + return nil, fmt.Errorf("unknown content type: %s", contentType) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go new file mode 100644 index 000000000..904a3dd6b --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go @@ -0,0 +1,979 @@ +package mcp + +import ( + "encoding/json" + "fmt" + + "github.com/spf13/cast" +) + +// ClientRequest types +var ( + _ ClientRequest = (*PingRequest)(nil) + _ ClientRequest = (*InitializeRequest)(nil) + _ ClientRequest = (*CompleteRequest)(nil) + _ ClientRequest = (*SetLevelRequest)(nil) + _ ClientRequest = (*GetPromptRequest)(nil) + _ ClientRequest = (*ListPromptsRequest)(nil) + _ ClientRequest = (*ListResourcesRequest)(nil) + _ ClientRequest = (*ReadResourceRequest)(nil) + _ 
ClientRequest = (*SubscribeRequest)(nil) + _ ClientRequest = (*UnsubscribeRequest)(nil) + _ ClientRequest = (*CallToolRequest)(nil) + _ ClientRequest = (*ListToolsRequest)(nil) +) + +// ClientNotification types +var ( + _ ClientNotification = (*CancelledNotification)(nil) + _ ClientNotification = (*ProgressNotification)(nil) + _ ClientNotification = (*InitializedNotification)(nil) + _ ClientNotification = (*RootsListChangedNotification)(nil) +) + +// ClientResult types +var ( + _ ClientResult = (*EmptyResult)(nil) + _ ClientResult = (*CreateMessageResult)(nil) + _ ClientResult = (*ListRootsResult)(nil) +) + +// ServerRequest types +var ( + _ ServerRequest = (*PingRequest)(nil) + _ ServerRequest = (*CreateMessageRequest)(nil) + _ ServerRequest = (*ListRootsRequest)(nil) +) + +// ServerNotification types +var ( + _ ServerNotification = (*CancelledNotification)(nil) + _ ServerNotification = (*ProgressNotification)(nil) + _ ServerNotification = (*LoggingMessageNotification)(nil) + _ ServerNotification = (*ResourceUpdatedNotification)(nil) + _ ServerNotification = (*ResourceListChangedNotification)(nil) + _ ServerNotification = (*ToolListChangedNotification)(nil) + _ ServerNotification = (*PromptListChangedNotification)(nil) +) + +// ServerResult types +var ( + _ ServerResult = (*EmptyResult)(nil) + _ ServerResult = (*InitializeResult)(nil) + _ ServerResult = (*CompleteResult)(nil) + _ ServerResult = (*GetPromptResult)(nil) + _ ServerResult = (*ListPromptsResult)(nil) + _ ServerResult = (*ListResourcesResult)(nil) + _ ServerResult = (*ReadResourceResult)(nil) + _ ServerResult = (*CallToolResult)(nil) + _ ServerResult = (*ListToolsResult)(nil) +) + +// Helper functions for type assertions + +// asType attempts to cast the given interface to the given type +func asType[T any](content any) (*T, bool) { + tc, ok := content.(T) + if !ok { + return nil, false + } + return &tc, true +} + +// AsTextContent attempts to cast the given interface to TextContent +func 
AsTextContent(content any) (*TextContent, bool) { + return asType[TextContent](content) +} + +// AsImageContent attempts to cast the given interface to ImageContent +func AsImageContent(content any) (*ImageContent, bool) { + return asType[ImageContent](content) +} + +// AsAudioContent attempts to cast the given interface to AudioContent +func AsAudioContent(content any) (*AudioContent, bool) { + return asType[AudioContent](content) +} + +// AsEmbeddedResource attempts to cast the given interface to EmbeddedResource +func AsEmbeddedResource(content any) (*EmbeddedResource, bool) { + return asType[EmbeddedResource](content) +} + +// AsTextResourceContents attempts to cast the given interface to TextResourceContents +func AsTextResourceContents(content any) (*TextResourceContents, bool) { + return asType[TextResourceContents](content) +} + +// AsBlobResourceContents attempts to cast the given interface to BlobResourceContents +func AsBlobResourceContents(content any) (*BlobResourceContents, bool) { + return asType[BlobResourceContents](content) +} + +// Helper function for JSON-RPC + +// NewJSONRPCResponse creates a new JSONRPCResponse with the given id and result. +// NOTE: This function expects a Result struct, but JSONRPCResponse.Result is typed as `any`. +// The Result struct wraps the actual result data with optional metadata. +// For direct result assignment, use NewJSONRPCResultResponse instead. +func NewJSONRPCResponse(id RequestId, result Result) JSONRPCResponse { + return JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Result: result, + } +} + +// NewJSONRPCResultResponse creates a new JSONRPCResponse with the given id and result. +// This function accepts any type for the result, matching the JSONRPCResponse.Result field type. 
+func NewJSONRPCResultResponse(id RequestId, result any) JSONRPCResponse { + return JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Result: result, + } +} + +// NewJSONRPCErrorDetails creates a new JSONRPCErrorDetails with the given code, message, and data. +func NewJSONRPCErrorDetails(code int, message string, data any) JSONRPCErrorDetails { + return JSONRPCErrorDetails{ + Code: code, + Message: message, + Data: data, + } +} + +// NewJSONRPCError creates a new JSONRPCResponse with the given id, code, and message +func NewJSONRPCError( + id RequestId, + code int, + message string, + data any, +) JSONRPCError { + return JSONRPCError{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Error: NewJSONRPCErrorDetails(code, message, data), + } +} + +// NewProgressNotification +// Helper function for creating a progress notification +func NewProgressNotification( + token ProgressToken, + progress float64, + total *float64, + message *string, +) ProgressNotification { + notification := ProgressNotification{ + Notification: Notification{ + Method: "notifications/progress", + }, + Params: struct { + ProgressToken ProgressToken `json:"progressToken"` + Progress float64 `json:"progress"` + Total float64 `json:"total,omitempty"` + Message string `json:"message,omitempty"` + }{ + ProgressToken: token, + Progress: progress, + }, + } + if total != nil { + notification.Params.Total = *total + } + if message != nil { + notification.Params.Message = *message + } + return notification +} + +// NewLoggingMessageNotification +// Helper function for creating a logging message notification +func NewLoggingMessageNotification( + level LoggingLevel, + logger string, + data any, +) LoggingMessageNotification { + return LoggingMessageNotification{ + Notification: Notification{ + Method: "notifications/message", + }, + Params: struct { + Level LoggingLevel `json:"level"` + Logger string `json:"logger,omitempty"` + Data any `json:"data"` + }{ + Level: level, + Logger: logger, + Data: data, + }, + 
} +} + +// NewPromptMessage +// Helper function to create a new PromptMessage +func NewPromptMessage(role Role, content Content) PromptMessage { + return PromptMessage{ + Role: role, + Content: content, + } +} + +// NewTextContent +// Helper function to create a new TextContent +func NewTextContent(text string) TextContent { + return TextContent{ + Type: ContentTypeText, + Text: text, + } +} + +// NewImageContent +// Helper function to create a new ImageContent +func NewImageContent(data, mimeType string) ImageContent { + return ImageContent{ + Type: ContentTypeImage, + Data: data, + MIMEType: mimeType, + } +} + +// Helper function to create a new AudioContent +func NewAudioContent(data, mimeType string) AudioContent { + return AudioContent{ + Type: ContentTypeAudio, + Data: data, + MIMEType: mimeType, + } +} + +// Helper function to create a new ResourceLink +func NewResourceLink(uri, name, description, mimeType string) ResourceLink { + return ResourceLink{ + Type: ContentTypeLink, + URI: uri, + Name: name, + Description: description, + MIMEType: mimeType, + } +} + +// Helper function to create a new EmbeddedResource +func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { + return EmbeddedResource{ + Type: ContentTypeResource, + Resource: resource, + } +} + +// NewToolResultText creates a new CallToolResult with a text content +func NewToolResultText(text string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + }, + } +} + +// NewToolResultJSON creates a new CallToolResult with a JSON content. 
+func NewToolResultJSON[T any](data T) (*CallToolResult, error) { + b, err := json.Marshal(data) + if err != nil { + return nil, fmt.Errorf("unable to marshal JSON: %w", err) + } + + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: string(b), + }, + }, + StructuredContent: data, + }, nil +} + +// NewToolResultStructured creates a new CallToolResult with structured content. +// It includes both the structured content and a text representation for backward compatibility. +func NewToolResultStructured(structured any, fallbackText string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: fallbackText, + }, + }, + StructuredContent: structured, + } +} + +// NewToolResultStructuredOnly creates a new CallToolResult with structured +// content and creates a JSON string fallback for backwards compatibility. +// This is useful when you want to provide structured data without any specific text fallback. 
+func NewToolResultStructuredOnly(structured any) *CallToolResult { + var fallbackText string + // Convert to JSON string for backward compatibility + jsonBytes, err := json.Marshal(structured) + if err != nil { + fallbackText = fmt.Sprintf("Error serializing structured content: %v", err) + } else { + fallbackText = string(jsonBytes) + } + + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: fallbackText, + }, + }, + StructuredContent: structured, + } +} + +// NewToolResultImage creates a new CallToolResult with both text and image content +func NewToolResultImage(text, imageData, mimeType string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + ImageContent{ + Type: ContentTypeImage, + Data: imageData, + MIMEType: mimeType, + }, + }, + } +} + +// NewToolResultAudio creates a new CallToolResult with both text and audio content +func NewToolResultAudio(text, imageData, mimeType string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + AudioContent{ + Type: ContentTypeAudio, + Data: imageData, + MIMEType: mimeType, + }, + }, + } +} + +// NewToolResultResource creates a new CallToolResult with an embedded resource +func NewToolResultResource( + text string, + resource ResourceContents, +) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + EmbeddedResource{ + Type: ContentTypeResource, + Resource: resource, + }, + }, + } +} + +// NewToolResultError creates a new CallToolResult with an error message. +// Any errors that originate from the tool SHOULD be reported inside the result object. 
+func NewToolResultError(text string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + }, + IsError: true, + } +} + +// NewToolResultErrorFromErr creates a new CallToolResult with an error message. +// If an error is provided, its details will be appended to the text message. +// Any errors that originate from the tool SHOULD be reported inside the result object. +func NewToolResultErrorFromErr(text string, err error) *CallToolResult { + if err != nil { + text = fmt.Sprintf("%s: %v", text, err) + } + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + }, + IsError: true, + } +} + +// NewToolResultErrorf creates a new CallToolResult with an error message. +// The error message is formatted using the fmt package. +// Any errors that originate from the tool SHOULD be reported inside the result object. +func NewToolResultErrorf(format string, a ...any) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: fmt.Sprintf(format, a...), + }, + }, + IsError: true, + } +} + +// NewListResourcesResult creates a new ListResourcesResult +func NewListResourcesResult( + resources []Resource, + nextCursor Cursor, +) *ListResourcesResult { + return &ListResourcesResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Resources: resources, + } +} + +// NewListResourceTemplatesResult creates a new ListResourceTemplatesResult +func NewListResourceTemplatesResult( + templates []ResourceTemplate, + nextCursor Cursor, +) *ListResourceTemplatesResult { + return &ListResourceTemplatesResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + ResourceTemplates: templates, + } +} + +// NewReadResourceResult creates a new ReadResourceResult with text content +func NewReadResourceResult(text string) *ReadResourceResult { + return &ReadResourceResult{ + Contents: 
[]ResourceContents{ + TextResourceContents{ + Text: text, + }, + }, + } +} + +// NewListPromptsResult creates a new ListPromptsResult +func NewListPromptsResult( + prompts []Prompt, + nextCursor Cursor, +) *ListPromptsResult { + return &ListPromptsResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Prompts: prompts, + } +} + +// NewGetPromptResult creates a new GetPromptResult +func NewGetPromptResult( + description string, + messages []PromptMessage, +) *GetPromptResult { + return &GetPromptResult{ + Description: description, + Messages: messages, + } +} + +// NewListToolsResult creates a new ListToolsResult +func NewListToolsResult(tools []Tool, nextCursor Cursor) *ListToolsResult { + return &ListToolsResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Tools: tools, + } +} + +// NewInitializeResult creates a new InitializeResult +func NewInitializeResult( + protocolVersion string, + capabilities ServerCapabilities, + serverInfo Implementation, + instructions string, +) *InitializeResult { + return &InitializeResult{ + ProtocolVersion: protocolVersion, + Capabilities: capabilities, + ServerInfo: serverInfo, + Instructions: instructions, + } +} + +// FormatNumberResult +// Helper for formatting numbers in tool results +func FormatNumberResult(value float64) *CallToolResult { + return NewToolResultText(fmt.Sprintf("%.2f", value)) +} + +func ExtractString(data map[string]any, key string) string { + if value, ok := data[key]; ok { + if str, ok := value.(string); ok { + return str + } + } + return "" +} + +func ParseAnnotations(data map[string]any) *Annotations { + if data == nil { + return nil + } + annotations := &Annotations{} + if value, ok := data["priority"]; ok { + annotations.Priority = cast.ToFloat64(value) + } + + if value, ok := data["audience"]; ok { + for _, a := range cast.ToStringSlice(value) { + a := Role(a) + if a == RoleUser || a == RoleAssistant { + annotations.Audience = append(annotations.Audience, 
a) + } + } + } + return annotations + +} + +func ExtractMap(data map[string]any, key string) map[string]any { + if value, ok := data[key]; ok { + if m, ok := value.(map[string]any); ok { + return m + } + } + return nil +} + +func ParseContent(contentMap map[string]any) (Content, error) { + contentType := ExtractString(contentMap, "type") + + var annotations *Annotations + if annotationsMap := ExtractMap(contentMap, "annotations"); annotationsMap != nil { + annotations = ParseAnnotations(annotationsMap) + } + + switch contentType { + case ContentTypeText: + text := ExtractString(contentMap, "text") + c := NewTextContent(text) + c.Annotations = annotations + return c, nil + + case ContentTypeImage: + data := ExtractString(contentMap, "data") + mimeType := ExtractString(contentMap, "mimeType") + if data == "" || mimeType == "" { + return nil, fmt.Errorf("image data or mimeType is missing") + } + c := NewImageContent(data, mimeType) + c.Annotations = annotations + return c, nil + + case ContentTypeAudio: + data := ExtractString(contentMap, "data") + mimeType := ExtractString(contentMap, "mimeType") + if data == "" || mimeType == "" { + return nil, fmt.Errorf("audio data or mimeType is missing") + } + c := NewAudioContent(data, mimeType) + c.Annotations = annotations + return c, nil + + case ContentTypeLink: + uri := ExtractString(contentMap, "uri") + name := ExtractString(contentMap, "name") + description := ExtractString(contentMap, "description") + mimeType := ExtractString(contentMap, "mimeType") + if uri == "" || name == "" { + return nil, fmt.Errorf("resource_link uri or name is missing") + } + c := NewResourceLink(uri, name, description, mimeType) + c.Annotations = annotations + return c, nil + + case ContentTypeResource: + resourceMap := ExtractMap(contentMap, "resource") + if resourceMap == nil { + return nil, fmt.Errorf("resource is missing") + } + + resourceContents, err := ParseResourceContents(resourceMap) + if err != nil { + return nil, err + } + + c := 
NewEmbeddedResource(resourceContents) + c.Annotations = annotations + return c, nil + } + + return nil, fmt.Errorf("unsupported content type: %s", contentType) +} + +func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) { + if rawMessage == nil { + return nil, fmt.Errorf("response is nil") + } + + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + result := GetPromptResult{} + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = NewMetaFromMap(metaMap) + } + } + + description, ok := jsonContent["description"] + if ok { + if descriptionStr, ok := description.(string); ok { + result.Description = descriptionStr + } + } + + messages, ok := jsonContent["messages"] + if ok { + messagesArr, ok := messages.([]any) + if !ok { + return nil, fmt.Errorf("messages is not an array") + } + + for _, message := range messagesArr { + messageMap, ok := message.(map[string]any) + if !ok { + return nil, fmt.Errorf("message is not an object") + } + + // Extract role + roleStr := ExtractString(messageMap, "role") + if roleStr == "" || (roleStr != string(RoleAssistant) && roleStr != string(RoleUser)) { + return nil, fmt.Errorf("unsupported role: %s", roleStr) + } + + // Extract content + contentMap, ok := messageMap["content"].(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseContent(contentMap) + if err != nil { + return nil, err + } + + // Append processed message + result.Messages = append(result.Messages, NewPromptMessage(Role(roleStr), content)) + + } + } + + return &result, nil +} + +func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) { + if rawMessage == nil { + return nil, fmt.Errorf("response is nil") + } + + var jsonContent map[string]any + if err := 
json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + var result CallToolResult + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = NewMetaFromMap(metaMap) + } + } + + isError, ok := jsonContent["isError"] + if ok { + if isErrorBool, ok := isError.(bool); ok { + result.IsError = isErrorBool + } + } + + contents, ok := jsonContent["content"] + if !ok { + return nil, fmt.Errorf("content is missing") + } + + contentArr, ok := contents.([]any) + if !ok { + return nil, fmt.Errorf("content is not an array") + } + + for _, content := range contentArr { + // Extract content + contentMap, ok := content.(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseContent(contentMap) + if err != nil { + return nil, err + } + + result.Content = append(result.Content, content) + } + + // Handle structured content + structuredContent, ok := jsonContent["structuredContent"] + if ok { + result.StructuredContent = structuredContent + } + + return &result, nil +} + +func ParseResourceContents(contentMap map[string]any) (ResourceContents, error) { + uri := ExtractString(contentMap, "uri") + if uri == "" { + return nil, fmt.Errorf("resource uri is missing") + } + + mimeType := ExtractString(contentMap, "mimeType") + + meta := ExtractMap(contentMap, "_meta") + + if _, present := contentMap["_meta"]; present && meta == nil { + return nil, fmt.Errorf("_meta must be an object") + } + + if text := ExtractString(contentMap, "text"); text != "" { + return TextResourceContents{ + Meta: meta, + URI: uri, + MIMEType: mimeType, + Text: text, + }, nil + } + + if blob := ExtractString(contentMap, "blob"); blob != "" { + return BlobResourceContents{ + Meta: meta, + URI: uri, + MIMEType: mimeType, + Blob: blob, + }, nil + } + + return nil, fmt.Errorf("unsupported resource type") +} + +func 
ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, error) { + if rawMessage == nil { + return nil, fmt.Errorf("response is nil") + } + + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + var result ReadResourceResult + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = NewMetaFromMap(metaMap) + } + } + + contents, ok := jsonContent["contents"] + if !ok { + return nil, fmt.Errorf("contents is missing") + } + + contentArr, ok := contents.([]any) + if !ok { + return nil, fmt.Errorf("contents is not an array") + } + + for _, content := range contentArr { + // Extract content + contentMap, ok := content.(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseResourceContents(contentMap) + if err != nil { + return nil, err + } + + result.Contents = append(result.Contents, content) + } + + return &result, nil +} + +func ParseArgument(request CallToolRequest, key string, defaultVal any) any { + args := request.GetArguments() + if _, ok := args[key]; !ok { + return defaultVal + } else { + return args[key] + } +} + +// ParseBoolean extracts and converts a boolean parameter from a CallToolRequest. +// If the key is not found in the Arguments map, the defaultValue is returned. +// The function uses cast.ToBool for conversion which handles various string representations +// such as "true", "yes", "1", etc. +func ParseBoolean(request CallToolRequest, key string, defaultValue bool) bool { + v := ParseArgument(request, key, defaultValue) + return cast.ToBool(v) +} + +// ParseInt64 extracts and converts an int64 parameter from a CallToolRequest. +// If the key is not found in the Arguments map, the defaultValue is returned. 
+func ParseInt64(request CallToolRequest, key string, defaultValue int64) int64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt64(v) +} + +// ParseInt32 extracts and converts an int32 parameter from a CallToolRequest. +func ParseInt32(request CallToolRequest, key string, defaultValue int32) int32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt32(v) +} + +// ParseInt16 extracts and converts an int16 parameter from a CallToolRequest. +func ParseInt16(request CallToolRequest, key string, defaultValue int16) int16 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt16(v) +} + +// ParseInt8 extracts and converts an int8 parameter from a CallToolRequest. +func ParseInt8(request CallToolRequest, key string, defaultValue int8) int8 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt8(v) +} + +// ParseInt extracts and converts an int parameter from a CallToolRequest. +func ParseInt(request CallToolRequest, key string, defaultValue int) int { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt(v) +} + +// ParseUInt extracts and converts an uint parameter from a CallToolRequest. +func ParseUInt(request CallToolRequest, key string, defaultValue uint) uint { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint(v) +} + +// ParseUInt64 extracts and converts an uint64 parameter from a CallToolRequest. +func ParseUInt64(request CallToolRequest, key string, defaultValue uint64) uint64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint64(v) +} + +// ParseUInt32 extracts and converts an uint32 parameter from a CallToolRequest. +func ParseUInt32(request CallToolRequest, key string, defaultValue uint32) uint32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint32(v) +} + +// ParseUInt16 extracts and converts an uint16 parameter from a CallToolRequest. 
+func ParseUInt16(request CallToolRequest, key string, defaultValue uint16) uint16 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint16(v) +} + +// ParseUInt8 extracts and converts an uint8 parameter from a CallToolRequest. +func ParseUInt8(request CallToolRequest, key string, defaultValue uint8) uint8 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint8(v) +} + +// ParseFloat32 extracts and converts a float32 parameter from a CallToolRequest. +func ParseFloat32(request CallToolRequest, key string, defaultValue float32) float32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToFloat32(v) +} + +// ParseFloat64 extracts and converts a float64 parameter from a CallToolRequest. +func ParseFloat64(request CallToolRequest, key string, defaultValue float64) float64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToFloat64(v) +} + +// ParseString extracts and converts a string parameter from a CallToolRequest. +func ParseString(request CallToolRequest, key string, defaultValue string) string { + v := ParseArgument(request, key, defaultValue) + return cast.ToString(v) +} + +// ParseStringMap extracts and converts a string map parameter from a CallToolRequest. +func ParseStringMap(request CallToolRequest, key string, defaultValue map[string]any) map[string]any { + v := ParseArgument(request, key, defaultValue) + return cast.ToStringMap(v) +} + +// ToBoolPtr returns a pointer to the given boolean value +func ToBoolPtr(b bool) *bool { + return &b +} + +// GetTextFromContent extracts text from a Content interface that might be a TextContent struct +// or a map[string]any that was unmarshaled from JSON. This is useful when dealing with content +// that comes from different transport layers that may handle JSON differently. +// +// This function uses fallback behavior for non-text content - it returns a string representation +// via fmt.Sprintf for any content that cannot be extracted as text. 
This is a lossy operation +// intended for convenience in logging and display scenarios. +// +// For strict type validation, use ParseContent() instead, which returns an error for invalid content. +func GetTextFromContent(content any) string { + switch c := content.(type) { + case TextContent: + return c.Text + case map[string]any: + // Handle JSON unmarshaled content + if contentType, exists := c["type"]; exists && contentType == "text" { + if text, exists := c["text"].(string); exists { + return text + } + } + return fmt.Sprintf("%v", content) + case string: + return c + default: + return fmt.Sprintf("%v", content) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/constants.go b/vendor/github.com/mark3labs/mcp-go/server/constants.go new file mode 100644 index 000000000..e071b2ef4 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/constants.go @@ -0,0 +1,7 @@ +package server + +// Common HTTP header constants used across server transports +const ( + HeaderKeySessionID = "Mcp-Session-Id" + HeaderKeyProtocolVersion = "Mcp-Protocol-Version" +) diff --git a/vendor/github.com/mark3labs/mcp-go/server/ctx.go b/vendor/github.com/mark3labs/mcp-go/server/ctx.go new file mode 100644 index 000000000..43f01bb68 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/ctx.go @@ -0,0 +1,8 @@ +package server + +type contextKey int + +const ( + // This const is used as key for context value lookup + requestHeader contextKey = iota +) diff --git a/vendor/github.com/mark3labs/mcp-go/server/elicitation.go b/vendor/github.com/mark3labs/mcp-go/server/elicitation.go new file mode 100644 index 000000000..d3e6d3d4c --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/elicitation.go @@ -0,0 +1,32 @@ +package server + +import ( + "context" + "errors" + + "github.com/mark3labs/mcp-go/mcp" +) + +var ( + // ErrNoActiveSession is returned when there is no active session in the context + ErrNoActiveSession = errors.New("no active session") + // 
ErrElicitationNotSupported is returned when the session does not support elicitation + ErrElicitationNotSupported = errors.New("session does not support elicitation") +) + +// RequestElicitation sends an elicitation request to the client. +// The client must have declared elicitation capability during initialization. +// The session must implement SessionWithElicitation to support this operation. +func (s *MCPServer) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, ErrNoActiveSession + } + + // Check if the session supports elicitation requests + if elicitationSession, ok := session.(SessionWithElicitation); ok { + return elicitationSession.RequestElicitation(ctx, request) + } + + return nil, ErrElicitationNotSupported +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/errors.go b/vendor/github.com/mark3labs/mcp-go/server/errors.go new file mode 100644 index 000000000..5e65f0760 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/errors.go @@ -0,0 +1,36 @@ +package server + +import ( + "errors" + "fmt" +) + +var ( + // Common server errors + ErrUnsupported = errors.New("not supported") + ErrResourceNotFound = errors.New("resource not found") + ErrPromptNotFound = errors.New("prompt not found") + ErrToolNotFound = errors.New("tool not found") + + // Session-related errors + ErrSessionNotFound = errors.New("session not found") + ErrSessionExists = errors.New("session already exists") + ErrSessionNotInitialized = errors.New("session not properly initialized") + ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") + ErrSessionDoesNotSupportResources = errors.New("session does not support per-session resources") + ErrSessionDoesNotSupportResourceTemplates = errors.New("session does not support resource templates") + ErrSessionDoesNotSupportLogging = errors.New("session does not 
support setting logging level") + + // Notification-related errors + ErrNotificationNotInitialized = errors.New("notification channel not initialized") + ErrNotificationChannelBlocked = errors.New("notification channel queue is full - client may not be processing notifications fast enough") +) + +// ErrDynamicPathConfig is returned when attempting to use static path methods with dynamic path configuration +type ErrDynamicPathConfig struct { + Method string +} + +func (e *ErrDynamicPathConfig) Error() string { + return fmt.Sprintf("%s cannot be used with WithDynamicBasePath. Use dynamic path logic in your router.", e.Method) +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/hooks.go b/vendor/github.com/mark3labs/mcp-go/server/hooks.go new file mode 100644 index 000000000..4baa1c4e0 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/hooks.go @@ -0,0 +1,532 @@ +// Code generated by `go generate`. DO NOT EDIT. +// source: server/internal/gen/hooks.go.tmpl +package server + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. +type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) + +// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered. +type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession) + +// BeforeAnyHookFunc is a function that is called after the request is +// parsed but before the method is called. +type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any) + +// OnSuccessHookFunc is a hook that will be called after the request +// successfully generates a result, but before the result is sent to the client. 
+type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) + +// OnErrorHookFunc is a hook that will be called when an error occurs, +// either during the request parsing or the method execution. +// +// Example usage: +// ``` +// +// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { +// // Check for specific error types using errors.Is +// if errors.Is(err, ErrUnsupported) { +// // Handle capability not supported errors +// log.Printf("Capability not supported: %v", err) +// } +// +// // Use errors.As to get specific error types +// var parseErr = &UnparsableMessageError{} +// if errors.As(err, &parseErr) { +// // Access specific methods/fields of the error type +// log.Printf("Failed to parse message for method %s: %v", +// parseErr.GetMethod(), parseErr.Unwrap()) +// // Access the raw message that failed to parse +// rawMsg := parseErr.GetMessage() +// } +// +// // Check for specific resource/prompt/tool errors +// switch { +// case errors.Is(err, ErrResourceNotFound): +// log.Printf("Resource not found: %v", err) +// case errors.Is(err, ErrPromptNotFound): +// log.Printf("Prompt not found: %v", err) +// case errors.Is(err, ErrToolNotFound): +// log.Printf("Tool not found: %v", err) +// } +// }) +type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) + +// OnRequestInitializationFunc is a function that called before handle diff request method +// Should any errors arise during func execution, the service will promptly return the corresponding error message. 
+type OnRequestInitializationFunc func(ctx context.Context, id any, message any) error + +type OnBeforeInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest) +type OnAfterInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) + +type OnBeforePingFunc func(ctx context.Context, id any, message *mcp.PingRequest) +type OnAfterPingFunc func(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) + +type OnBeforeSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest) +type OnAfterSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) + +type OnBeforeListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest) +type OnAfterListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) + +type OnBeforeListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) +type OnAfterListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) + +type OnBeforeReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest) +type OnAfterReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) + +type OnBeforeListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest) +type OnAfterListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) + +type OnBeforeGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest) +type OnAfterGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) + +type OnBeforeListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest) +type 
OnAfterListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) + +type OnBeforeCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest) +type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) + +type Hooks struct { + OnRegisterSession []OnRegisterSessionHookFunc + OnUnregisterSession []OnUnregisterSessionHookFunc + OnBeforeAny []BeforeAnyHookFunc + OnSuccess []OnSuccessHookFunc + OnError []OnErrorHookFunc + OnRequestInitialization []OnRequestInitializationFunc + OnBeforeInitialize []OnBeforeInitializeFunc + OnAfterInitialize []OnAfterInitializeFunc + OnBeforePing []OnBeforePingFunc + OnAfterPing []OnAfterPingFunc + OnBeforeSetLevel []OnBeforeSetLevelFunc + OnAfterSetLevel []OnAfterSetLevelFunc + OnBeforeListResources []OnBeforeListResourcesFunc + OnAfterListResources []OnAfterListResourcesFunc + OnBeforeListResourceTemplates []OnBeforeListResourceTemplatesFunc + OnAfterListResourceTemplates []OnAfterListResourceTemplatesFunc + OnBeforeReadResource []OnBeforeReadResourceFunc + OnAfterReadResource []OnAfterReadResourceFunc + OnBeforeListPrompts []OnBeforeListPromptsFunc + OnAfterListPrompts []OnAfterListPromptsFunc + OnBeforeGetPrompt []OnBeforeGetPromptFunc + OnAfterGetPrompt []OnAfterGetPromptFunc + OnBeforeListTools []OnBeforeListToolsFunc + OnAfterListTools []OnAfterListToolsFunc + OnBeforeCallTool []OnBeforeCallToolFunc + OnAfterCallTool []OnAfterCallToolFunc +} + +func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) { + c.OnBeforeAny = append(c.OnBeforeAny, hook) +} + +func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { + c.OnSuccess = append(c.OnSuccess, hook) +} + +// AddOnError registers a hook function that will be called when an error occurs. +// The error parameter contains the actual error object, which can be interrogated +// using Go's error handling patterns like errors.Is and errors.As. 
+// +// Example: +// ``` +// // Create a channel to receive errors for testing +// errChan := make(chan error, 1) +// +// // Register hook to capture and inspect errors +// hooks := &Hooks{} +// +// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { +// // For capability-related errors +// if errors.Is(err, ErrUnsupported) { +// // Handle capability not supported +// errChan <- err +// return +// } +// +// // For parsing errors +// var parseErr = &UnparsableMessageError{} +// if errors.As(err, &parseErr) { +// // Handle unparsable message errors +// fmt.Printf("Failed to parse %s request: %v\n", +// parseErr.GetMethod(), parseErr.Unwrap()) +// errChan <- parseErr +// return +// } +// +// // For resource/prompt/tool not found errors +// if errors.Is(err, ErrResourceNotFound) || +// errors.Is(err, ErrPromptNotFound) || +// errors.Is(err, ErrToolNotFound) { +// // Handle not found errors +// errChan <- err +// return +// } +// +// // For other errors +// errChan <- err +// }) +// +// server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) +// ``` +func (c *Hooks) AddOnError(hook OnErrorHookFunc) { + c.OnError = append(c.OnError, hook) +} + +func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) { + if c == nil { + return + } + for _, hook := range c.OnBeforeAny { + hook(ctx, id, method, message) + } +} + +func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) { + if c == nil { + return + } + for _, hook := range c.OnSuccess { + hook(ctx, id, method, message, result) + } +} + +// onError calls all registered error hooks with the error object. +// The err parameter contains the actual error that occurred, which implements +// the standard error interface and may be a wrapped error or custom error type. 
+// +// This allows consumer code to use Go's error handling patterns: +// - errors.Is(err, ErrUnsupported) to check for specific sentinel errors +// - errors.As(err, &customErr) to extract custom error types +// +// Common error types include: +// - ErrUnsupported: When a capability is not enabled +// - UnparsableMessageError: When request parsing fails +// - ErrResourceNotFound: When a resource is not found +// - ErrPromptNotFound: When a prompt is not found +// - ErrToolNotFound: When a tool is not found +func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + if c == nil { + return + } + for _, hook := range c.OnError { + hook(ctx, id, method, message, err) + } +} + +func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) { + c.OnRegisterSession = append(c.OnRegisterSession, hook) +} + +func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnRegisterSession { + hook(ctx, session) + } +} + +func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) { + c.OnUnregisterSession = append(c.OnUnregisterSession, hook) +} + +func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnUnregisterSession { + hook(ctx, session) + } +} + +func (c *Hooks) AddOnRequestInitialization(hook OnRequestInitializationFunc) { + c.OnRequestInitialization = append(c.OnRequestInitialization, hook) +} + +func (c *Hooks) onRequestInitialization(ctx context.Context, id any, message any) error { + if c == nil { + return nil + } + for _, hook := range c.OnRequestInitialization { + err := hook(ctx, id, message) + if err != nil { + return err + } + } + return nil +} +func (c *Hooks) AddBeforeInitialize(hook OnBeforeInitializeFunc) { + c.OnBeforeInitialize = append(c.OnBeforeInitialize, hook) +} + +func (c *Hooks) AddAfterInitialize(hook OnAfterInitializeFunc) { + 
c.OnAfterInitialize = append(c.OnAfterInitialize, hook) +} + +func (c *Hooks) beforeInitialize(ctx context.Context, id any, message *mcp.InitializeRequest) { + c.beforeAny(ctx, id, mcp.MethodInitialize, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeInitialize { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterInitialize(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { + c.onSuccess(ctx, id, mcp.MethodInitialize, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterInitialize { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforePing(hook OnBeforePingFunc) { + c.OnBeforePing = append(c.OnBeforePing, hook) +} + +func (c *Hooks) AddAfterPing(hook OnAfterPingFunc) { + c.OnAfterPing = append(c.OnAfterPing, hook) +} + +func (c *Hooks) beforePing(ctx context.Context, id any, message *mcp.PingRequest) { + c.beforeAny(ctx, id, mcp.MethodPing, message) + if c == nil { + return + } + for _, hook := range c.OnBeforePing { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterPing(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) { + c.onSuccess(ctx, id, mcp.MethodPing, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterPing { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeSetLevel(hook OnBeforeSetLevelFunc) { + c.OnBeforeSetLevel = append(c.OnBeforeSetLevel, hook) +} + +func (c *Hooks) AddAfterSetLevel(hook OnAfterSetLevelFunc) { + c.OnAfterSetLevel = append(c.OnAfterSetLevel, hook) +} + +func (c *Hooks) beforeSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest) { + c.beforeAny(ctx, id, mcp.MethodSetLogLevel, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeSetLevel { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) { + c.onSuccess(ctx, id, 
mcp.MethodSetLogLevel, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterSetLevel { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListResources(hook OnBeforeListResourcesFunc) { + c.OnBeforeListResources = append(c.OnBeforeListResources, hook) +} + +func (c *Hooks) AddAfterListResources(hook OnAfterListResourcesFunc) { + c.OnAfterListResources = append(c.OnAfterListResources, hook) +} + +func (c *Hooks) beforeListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListResources { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListResources { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListResourceTemplates(hook OnBeforeListResourceTemplatesFunc) { + c.OnBeforeListResourceTemplates = append(c.OnBeforeListResourceTemplates, hook) +} + +func (c *Hooks) AddAfterListResourceTemplates(hook OnAfterListResourceTemplatesFunc) { + c.OnAfterListResourceTemplates = append(c.OnAfterListResourceTemplates, hook) +} + +func (c *Hooks) beforeListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesTemplatesList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListResourceTemplates { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesTemplatesList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListResourceTemplates { 
+ hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeReadResource(hook OnBeforeReadResourceFunc) { + c.OnBeforeReadResource = append(c.OnBeforeReadResource, hook) +} + +func (c *Hooks) AddAfterReadResource(hook OnAfterReadResourceFunc) { + c.OnAfterReadResource = append(c.OnAfterReadResource, hook) +} + +func (c *Hooks) beforeReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesRead, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeReadResource { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesRead, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterReadResource { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListPrompts(hook OnBeforeListPromptsFunc) { + c.OnBeforeListPrompts = append(c.OnBeforeListPrompts, hook) +} + +func (c *Hooks) AddAfterListPrompts(hook OnAfterListPromptsFunc) { + c.OnAfterListPrompts = append(c.OnAfterListPrompts, hook) +} + +func (c *Hooks) beforeListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest) { + c.beforeAny(ctx, id, mcp.MethodPromptsList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListPrompts { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) { + c.onSuccess(ctx, id, mcp.MethodPromptsList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListPrompts { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeGetPrompt(hook OnBeforeGetPromptFunc) { + c.OnBeforeGetPrompt = append(c.OnBeforeGetPrompt, hook) +} + +func (c *Hooks) AddAfterGetPrompt(hook OnAfterGetPromptFunc) { + c.OnAfterGetPrompt = append(c.OnAfterGetPrompt, hook) +} + 
+func (c *Hooks) beforeGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest) { + c.beforeAny(ctx, id, mcp.MethodPromptsGet, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeGetPrompt { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) { + c.onSuccess(ctx, id, mcp.MethodPromptsGet, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterGetPrompt { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListTools(hook OnBeforeListToolsFunc) { + c.OnBeforeListTools = append(c.OnBeforeListTools, hook) +} + +func (c *Hooks) AddAfterListTools(hook OnAfterListToolsFunc) { + c.OnAfterListTools = append(c.OnAfterListTools, hook) +} + +func (c *Hooks) beforeListTools(ctx context.Context, id any, message *mcp.ListToolsRequest) { + c.beforeAny(ctx, id, mcp.MethodToolsList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListTools { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListTools(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) { + c.onSuccess(ctx, id, mcp.MethodToolsList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListTools { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeCallTool(hook OnBeforeCallToolFunc) { + c.OnBeforeCallTool = append(c.OnBeforeCallTool, hook) +} + +func (c *Hooks) AddAfterCallTool(hook OnAfterCallToolFunc) { + c.OnAfterCallTool = append(c.OnAfterCallTool, hook) +} + +func (c *Hooks) beforeCallTool(ctx context.Context, id any, message *mcp.CallToolRequest) { + c.beforeAny(ctx, id, mcp.MethodToolsCall, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeCallTool { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterCallTool(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { + c.onSuccess(ctx, 
id, mcp.MethodToolsCall, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterCallTool { + hook(ctx, id, message, result) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go b/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go new file mode 100644 index 000000000..4f5ad53d0 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go @@ -0,0 +1,11 @@ +package server + +import ( + "context" + "net/http" +) + +// HTTPContextFunc is a function that takes an existing context and the current +// request and returns a potentially modified context based on the request +// content. This can be used to inject context values from headers, for example. +type HTTPContextFunc func(ctx context.Context, r *http.Request) context.Context diff --git a/vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go b/vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go new file mode 100644 index 000000000..59ab0f366 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go @@ -0,0 +1,165 @@ +package server + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SamplingHandler defines the interface for handling sampling requests from servers. +type SamplingHandler interface { + CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} + +// ElicitationHandler defines the interface for handling elicitation requests from servers. +type ElicitationHandler interface { + Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) +} + +// RootsHandler defines the interface for handling roots list requests from servers. 
+type RootsHandler interface { + ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) +} + +type InProcessSession struct { + sessionID string + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value + clientCapabilities atomic.Value + samplingHandler SamplingHandler + elicitationHandler ElicitationHandler + rootsHandler RootsHandler + mu sync.RWMutex +} + +func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InProcessSession { + return &InProcessSession{ + sessionID: sessionID, + notifications: make(chan mcp.JSONRPCNotification, 100), + samplingHandler: samplingHandler, + } +} + +func NewInProcessSessionWithHandlers(sessionID string, samplingHandler SamplingHandler, elicitationHandler ElicitationHandler, rootsHandler RootsHandler) *InProcessSession { + return &InProcessSession{ + sessionID: sessionID, + notifications: make(chan mcp.JSONRPCNotification, 100), + samplingHandler: samplingHandler, + elicitationHandler: elicitationHandler, + rootsHandler: rootsHandler, + } +} + +func (s *InProcessSession) SessionID() string { + return s.sessionID +} + +func (s *InProcessSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notifications +} + +func (s *InProcessSession) Initialize() { + s.loggingLevel.Store(mcp.LoggingLevelError) + s.initialized.Store(true) +} + +func (s *InProcessSession) Initialized() bool { + return s.initialized.Load() +} + +func (s *InProcessSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *InProcessSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *InProcessSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if 
clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *InProcessSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + +func (s *InProcessSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func (s *InProcessSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +func (s *InProcessSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + s.mu.RLock() + handler := s.samplingHandler + s.mu.RUnlock() + + if handler == nil { + return nil, fmt.Errorf("no sampling handler available") + } + + return handler.CreateMessage(ctx, request) +} + +func (s *InProcessSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + s.mu.RLock() + handler := s.elicitationHandler + s.mu.RUnlock() + + if handler == nil { + return nil, fmt.Errorf("no elicitation handler available") + } + + return handler.Elicit(ctx, request) +} + +// ListRoots sends a list roots request to the client and waits for the response. +// Returns an error if no roots handler is available. 
+func (s *InProcessSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + s.mu.RLock() + handler := s.rootsHandler + s.mu.RUnlock() + + if handler == nil { + return nil, fmt.Errorf("no roots handler available") + } + + return handler.ListRoots(ctx, request) +} + +// GenerateInProcessSessionID generates a unique session ID for inprocess clients +func GenerateInProcessSessionID() string { + return fmt.Sprintf("inprocess-%d", time.Now().UnixNano()) +} + +// Ensure interface compliance +var ( + _ ClientSession = (*InProcessSession)(nil) + _ SessionWithLogging = (*InProcessSession)(nil) + _ SessionWithClientInfo = (*InProcessSession)(nil) + _ SessionWithSampling = (*InProcessSession)(nil) + _ SessionWithElicitation = (*InProcessSession)(nil) + _ SessionWithRoots = (*InProcessSession)(nil) +) diff --git a/vendor/github.com/mark3labs/mcp-go/server/request_handler.go b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go new file mode 100644 index 000000000..b9175dc4e --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go @@ -0,0 +1,339 @@ +// Code generated by `go generate`. DO NOT EDIT. 
+// source: server/internal/gen/request_handler.go.tmpl +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/mark3labs/mcp-go/mcp" +) + +// HandleMessage processes an incoming JSON-RPC message and returns an appropriate response +func (s *MCPServer) HandleMessage( + ctx context.Context, + message json.RawMessage, +) mcp.JSONRPCMessage { + // Add server to context + ctx = context.WithValue(ctx, serverKey{}, s) + var err *requestError + + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + Method mcp.MCPMethod `json:"method"` + ID any `json:"id,omitempty"` + Result any `json:"result,omitempty"` + } + + if err := json.Unmarshal(message, &baseMessage); err != nil { + return createErrorResponse( + nil, + mcp.PARSE_ERROR, + "Failed to parse message", + ) + } + + // Check for valid JSONRPC version + if baseMessage.JSONRPC != mcp.JSONRPC_VERSION { + return createErrorResponse( + baseMessage.ID, + mcp.INVALID_REQUEST, + "Invalid JSON-RPC version", + ) + } + + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal(message, ¬ification); err != nil { + return createErrorResponse( + nil, + mcp.PARSE_ERROR, + "Failed to parse notification", + ) + } + s.handleNotification(ctx, notification) + return nil // Return nil for notifications + } + + if baseMessage.Result != nil { + // this is a response to a request sent by the server (e.g. 
from a ping + // sent due to WithKeepAlive option) + return nil + } + + handleErr := s.hooks.onRequestInitialization(ctx, baseMessage.ID, message) + if handleErr != nil { + return createErrorResponse( + baseMessage.ID, + mcp.INVALID_REQUEST, + handleErr.Error(), + ) + } + + // Get request header from ctx + h := ctx.Value(requestHeader) + headers, ok := h.(http.Header) + + if headers == nil || !ok { + headers = make(http.Header) + } + + switch baseMessage.Method { + case mcp.MethodInitialize: + var request mcp.InitializeRequest + var result *mcp.InitializeResult + if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) + result, err = s.handleInitialize(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterInitialize(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPing: + var request mcp.PingRequest + var result *mcp.EmptyResult + if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforePing(ctx, baseMessage.ID, &request) + result, err = s.handlePing(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterPing(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodSetLogLevel: + var request 
mcp.SetLevelRequest + var result *mcp.EmptyResult + if s.capabilities.logging == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("logging %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeSetLevel(ctx, baseMessage.ID, &request) + result, err = s.handleSetLevel(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterSetLevel(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesList: + var request mcp.ListResourcesRequest + var result *mcp.ListResourcesResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeListResources(ctx, baseMessage.ID, &request) + result, err = s.handleListResources(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListResources(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesTemplatesList: + var request mcp.ListResourceTemplatesRequest + var result *mcp.ListResourceTemplatesResult + if s.capabilities.resources == nil { 
+ err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) + result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListResourceTemplates(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesRead: + var request mcp.ReadResourceRequest + var result *mcp.ReadResourceResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) + result, err = s.handleReadResource(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterReadResource(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPromptsList: + var request mcp.ListPromptsRequest + var result *mcp.ListPromptsResult + if s.capabilities.prompts == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: 
fmt.Errorf("prompts %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) + result, err = s.handleListPrompts(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListPrompts(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPromptsGet: + var request mcp.GetPromptRequest + var result *mcp.GetPromptResult + if s.capabilities.prompts == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("prompts %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) + result, err = s.handleGetPrompt(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterGetPrompt(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodToolsList: + var request mcp.ListToolsRequest + var result *mcp.ListToolsResult + if s.capabilities.tools == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("tools %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = 
&requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeListTools(ctx, baseMessage.ID, &request) + result, err = s.handleListTools(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListTools(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodToolsCall: + var request mcp.CallToolRequest + var result *mcp.CallToolResult + if s.capabilities.tools == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("tools %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) + result, err = s.handleToolCall(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterCallTool(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + default: + return createErrorResponse( + baseMessage.ID, + mcp.METHOD_NOT_FOUND, + fmt.Sprintf("Method %s not found", baseMessage.Method), + ) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/roots.go b/vendor/github.com/mark3labs/mcp-go/server/roots.go new file mode 100644 index 000000000..29e0b94d1 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/roots.go @@ -0,0 +1,32 @@ +package server + +import ( + "context" + "errors" + + "github.com/mark3labs/mcp-go/mcp" +) + +var ( + // 
ErrNoClientSession is returned when there is no active client session in the context + ErrNoClientSession = errors.New("no active client session") + // ErrRootsNotSupported is returned when the session does not support roots + ErrRootsNotSupported = errors.New("session does not support roots") +) + +// RequestRoots sends an list roots request to the client. +// The client must have declared roots capability during initialization. +// The session must implement SessionWithRoots to support this operation. +func (s *MCPServer) RequestRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, ErrNoClientSession + } + + // Check if the session supports roots requests + if rootsSession, ok := session.(SessionWithRoots); ok { + return rootsSession.ListRoots(ctx, request) + } + + return nil, ErrRootsNotSupported +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/sampling.go b/vendor/github.com/mark3labs/mcp-go/server/sampling.go new file mode 100644 index 000000000..2118db155 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/sampling.go @@ -0,0 +1,61 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// EnableSampling enables sampling capabilities for the server. +// This allows the server to send sampling requests to clients that support it. +func (s *MCPServer) EnableSampling() { + s.capabilitiesMu.Lock() + defer s.capabilitiesMu.Unlock() + + enabled := true + s.capabilities.sampling = &enabled +} + +// RequestSampling sends a sampling request to the client. +// The client must have declared sampling capability during initialization. 
+func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Check if the session supports sampling requests + if samplingSession, ok := session.(SessionWithSampling); ok { + return samplingSession.RequestSampling(ctx, request) + } + + // Check for inprocess sampling handler in context + if handler := InProcessSamplingHandlerFromContext(ctx); handler != nil { + return handler.CreateMessage(ctx, request) + } + + return nil, fmt.Errorf("session does not support sampling") +} + +// SessionWithSampling extends ClientSession to support sampling requests. +type SessionWithSampling interface { + ClientSession + RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} + +// inProcessSamplingHandlerKey is the context key for storing inprocess sampling handler +type inProcessSamplingHandlerKey struct{} + +// WithInProcessSamplingHandler adds a sampling handler to the context for inprocess clients +func WithInProcessSamplingHandler(ctx context.Context, handler SamplingHandler) context.Context { + return context.WithValue(ctx, inProcessSamplingHandlerKey{}, handler) +} + +// InProcessSamplingHandlerFromContext retrieves the inprocess sampling handler from context +func InProcessSamplingHandlerFromContext(ctx context.Context) SamplingHandler { + if handler, ok := ctx.Value(inProcessSamplingHandlerKey{}).(SamplingHandler); ok { + return handler + } + return nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/server.go b/vendor/github.com/mark3labs/mcp-go/server/server.go new file mode 100644 index 000000000..d46fc868d --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/server.go @@ -0,0 +1,1337 @@ +// Package server provides MCP (Model Context Protocol) server implementations. 
+package server + +import ( + "cmp" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "maps" + "slices" + "sort" + "sync" + + "github.com/mark3labs/mcp-go/mcp" +) + +// resourceEntry holds both a resource and its handler +type resourceEntry struct { + resource mcp.Resource + handler ResourceHandlerFunc +} + +// resourceTemplateEntry holds both a template and its handler +type resourceTemplateEntry struct { + template mcp.ResourceTemplate + handler ResourceTemplateHandlerFunc +} + +// ServerOption is a function that configures an MCPServer. +type ServerOption func(*MCPServer) + +// ResourceHandlerFunc is a function that returns resource contents. +type ResourceHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) + +// ResourceTemplateHandlerFunc is a function that returns a resource template. +type ResourceTemplateHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) + +// PromptHandlerFunc handles prompt requests with given arguments. +type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) + +// ToolHandlerFunc handles tool calls with given arguments. +type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) + +// ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. +type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc + +// ResourceHandlerMiddleware is a middleware function that wraps a ResourceHandlerFunc. +type ResourceHandlerMiddleware func(ResourceHandlerFunc) ResourceHandlerFunc + +// ToolFilterFunc is a function that filters tools based on context, typically using session information. +type ToolFilterFunc func(ctx context.Context, tools []mcp.Tool) []mcp.Tool + +// ServerTool combines a Tool with its ToolHandlerFunc. 
+type ServerTool struct { + Tool mcp.Tool + Handler ToolHandlerFunc +} + +// ServerPrompt combines a Prompt with its handler function. +type ServerPrompt struct { + Prompt mcp.Prompt + Handler PromptHandlerFunc +} + +// ServerResource combines a Resource with its handler function. +type ServerResource struct { + Resource mcp.Resource + Handler ResourceHandlerFunc +} + +// ServerResourceTemplate combines a ResourceTemplate with its handler function. +type ServerResourceTemplate struct { + Template mcp.ResourceTemplate + Handler ResourceTemplateHandlerFunc +} + +// serverKey is the context key for storing the server instance +type serverKey struct{} + +// ServerFromContext retrieves the MCPServer instance from a context +func ServerFromContext(ctx context.Context) *MCPServer { + if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { + return srv + } + return nil +} + +// UnparsableMessageError is attached to the RequestError when json.Unmarshal +// fails on the request. +type UnparsableMessageError struct { + message json.RawMessage + method mcp.MCPMethod + err error +} + +func (e *UnparsableMessageError) Error() string { + return fmt.Sprintf("unparsable %s request: %s", e.method, e.err) +} + +func (e *UnparsableMessageError) Unwrap() error { + return e.err +} + +func (e *UnparsableMessageError) GetMessage() json.RawMessage { + return e.message +} + +func (e *UnparsableMessageError) GetMethod() mcp.MCPMethod { + return e.method +} + +// RequestError is an error that can be converted to a JSON-RPC error. +// Implements Unwrap() to allow inspecting the error chain. 
+type requestError struct { + id any + code int + err error +} + +func (e *requestError) Error() string { + return fmt.Sprintf("request error: %s", e.err) +} + +func (e *requestError) ToJSONRPCError() mcp.JSONRPCError { + return mcp.JSONRPCError{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(e.id), + Error: mcp.NewJSONRPCErrorDetails(e.code, e.err.Error(), nil), + } +} + +func (e *requestError) Unwrap() error { + return e.err +} + +// NotificationHandlerFunc handles incoming notifications. +type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCNotification) + +// MCPServer implements a Model Context Protocol server that can handle various types of requests +// including resources, prompts, and tools. +type MCPServer struct { + // Separate mutexes for different resource types + resourcesMu sync.RWMutex + resourceMiddlewareMu sync.RWMutex + promptsMu sync.RWMutex + toolsMu sync.RWMutex + toolMiddlewareMu sync.RWMutex + notificationHandlersMu sync.RWMutex + capabilitiesMu sync.RWMutex + toolFiltersMu sync.RWMutex + + name string + version string + instructions string + resources map[string]resourceEntry + resourceTemplates map[string]resourceTemplateEntry + prompts map[string]mcp.Prompt + promptHandlers map[string]PromptHandlerFunc + tools map[string]ServerTool + toolHandlerMiddlewares []ToolHandlerMiddleware + resourceHandlerMiddlewares []ResourceHandlerMiddleware + toolFilters []ToolFilterFunc + notificationHandlers map[string]NotificationHandlerFunc + capabilities serverCapabilities + paginationLimit *int + sessions sync.Map + hooks *Hooks +} + +// WithPaginationLimit sets the pagination limit for the server. 
+func WithPaginationLimit(limit int) ServerOption { + return func(s *MCPServer) { + s.paginationLimit = &limit + } +} + +// serverCapabilities defines the supported features of the MCP server +type serverCapabilities struct { + tools *toolCapabilities + resources *resourceCapabilities + prompts *promptCapabilities + logging *bool + sampling *bool + elicitation *bool + roots *bool +} + +// resourceCapabilities defines the supported resource-related features +type resourceCapabilities struct { + subscribe bool + listChanged bool +} + +// promptCapabilities defines the supported prompt-related features +type promptCapabilities struct { + listChanged bool +} + +// toolCapabilities defines the supported tool-related features +type toolCapabilities struct { + listChanged bool +} + +// WithResourceCapabilities configures resource-related server capabilities +func WithResourceCapabilities(subscribe, listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.resources = &resourceCapabilities{ + subscribe: subscribe, + listChanged: listChanged, + } + } +} + +// WithToolHandlerMiddleware allows adding a middleware for the +// tool handler call chain. +func WithToolHandlerMiddleware( + toolHandlerMiddleware ToolHandlerMiddleware, +) ServerOption { + return func(s *MCPServer) { + s.toolMiddlewareMu.Lock() + s.toolHandlerMiddlewares = append(s.toolHandlerMiddlewares, toolHandlerMiddleware) + s.toolMiddlewareMu.Unlock() + } +} + +// WithResourceHandlerMiddleware allows adding a middleware for the +// resource handler call chain. 
+func WithResourceHandlerMiddleware( + resourceHandlerMiddleware ResourceHandlerMiddleware, +) ServerOption { + return func(s *MCPServer) { + s.resourceMiddlewareMu.Lock() + s.resourceHandlerMiddlewares = append(s.resourceHandlerMiddlewares, resourceHandlerMiddleware) + s.resourceMiddlewareMu.Unlock() + } +} + +// WithResourceRecovery adds a middleware that recovers from panics in resource handlers. +func WithResourceRecovery() ServerOption { + return WithResourceHandlerMiddleware(func(next ResourceHandlerFunc) ResourceHandlerFunc { + return func(ctx context.Context, request mcp.ReadResourceRequest) (result []mcp.ResourceContents, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf( + "panic recovered in %s resource handler: %v", + request.Params.URI, + r, + ) + } + }() + return next(ctx, request) + } + }) +} + +// WithToolFilter adds a filter function that will be applied to tools before they are returned in list_tools +func WithToolFilter( + toolFilter ToolFilterFunc, +) ServerOption { + return func(s *MCPServer) { + s.toolFiltersMu.Lock() + s.toolFilters = append(s.toolFilters, toolFilter) + s.toolFiltersMu.Unlock() + } +} + +// WithRecovery adds a middleware that recovers from panics in tool handlers. +func WithRecovery() ServerOption { + return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf( + "panic recovered in %s tool handler: %v", + request.Params.Name, + r, + ) + } + }() + return next(ctx, request) + } + }) +} + +// WithHooks allows adding hooks that will be called before or after +// either [all] requests or before / after specific request methods, or else +// prior to returning an error to the client. 
+func WithHooks(hooks *Hooks) ServerOption { + return func(s *MCPServer) { + s.hooks = hooks + } +} + +// WithPromptCapabilities configures prompt-related server capabilities +func WithPromptCapabilities(listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.prompts = &promptCapabilities{ + listChanged: listChanged, + } + } +} + +// WithToolCapabilities configures tool-related server capabilities +func WithToolCapabilities(listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.tools = &toolCapabilities{ + listChanged: listChanged, + } + } +} + +// WithLogging enables logging capabilities for the server +func WithLogging() ServerOption { + return func(s *MCPServer) { + s.capabilities.logging = mcp.ToBoolPtr(true) + } +} + +// WithElicitation enables elicitation capabilities for the server +func WithElicitation() ServerOption { + return func(s *MCPServer) { + s.capabilities.elicitation = mcp.ToBoolPtr(true) + } +} + +// WithRoots returns a ServerOption that enables the roots capability on the MCPServer +func WithRoots() ServerOption { + return func(s *MCPServer) { + s.capabilities.roots = mcp.ToBoolPtr(true) + } +} + +// WithInstructions sets the server instructions for the client returned in the initialize response +func WithInstructions(instructions string) ServerOption { + return func(s *MCPServer) { + s.instructions = instructions + } +} + +// NewMCPServer creates a new MCP server instance with the given name, version and options +func NewMCPServer( + name, version string, + opts ...ServerOption, +) *MCPServer { + s := &MCPServer{ + resources: make(map[string]resourceEntry), + resourceTemplates: make(map[string]resourceTemplateEntry), + prompts: make(map[string]mcp.Prompt), + promptHandlers: make(map[string]PromptHandlerFunc), + tools: make(map[string]ServerTool), + toolHandlerMiddlewares: make([]ToolHandlerMiddleware, 
0), + resourceHandlerMiddlewares: make([]ResourceHandlerMiddleware, 0), + name: name, + version: version, + notificationHandlers: make(map[string]NotificationHandlerFunc), + capabilities: serverCapabilities{ + tools: nil, + resources: nil, + prompts: nil, + logging: nil, + }, + } + + for _, opt := range opts { + opt(s) + } + + return s +} + +// GenerateInProcessSessionID generates a unique session ID for inprocess clients +func (s *MCPServer) GenerateInProcessSessionID() string { + return GenerateInProcessSessionID() +} + +// AddResources registers multiple resources at once +func (s *MCPServer) AddResources(resources ...ServerResource) { + s.implicitlyRegisterResourceCapabilities() + + s.resourcesMu.Lock() + for _, entry := range resources { + s.resources[entry.Resource.URI] = resourceEntry{ + resource: entry.Resource, + handler: entry.Handler, + } + } + s.resourcesMu.Unlock() + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification + if s.capabilities.resources.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + +// SetResources replaces all existing resources with the provided list +func (s *MCPServer) SetResources(resources ...ServerResource) { + s.resourcesMu.Lock() + s.resources = make(map[string]resourceEntry, len(resources)) + s.resourcesMu.Unlock() + s.AddResources(resources...) 
+} + +// AddResource registers a new resource and its handler +func (s *MCPServer) AddResource( + resource mcp.Resource, + handler ResourceHandlerFunc, +) { + s.AddResources(ServerResource{Resource: resource, Handler: handler}) +} + +// DeleteResources removes resources from the server +func (s *MCPServer) DeleteResources(uris ...string) { + s.resourcesMu.Lock() + var exists bool + for _, uri := range uris { + if _, ok := s.resources[uri]; ok { + delete(s.resources, uri) + exists = true + } + } + s.resourcesMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource + if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + +// RemoveResource removes a resource from the server +func (s *MCPServer) RemoveResource(uri string) { + s.resourcesMu.Lock() + _, exists := s.resources[uri] + if exists { + delete(s.resources, uri) + } + s.resourcesMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource + if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + +// AddResourceTemplates registers multiple resource templates at once +func (s *MCPServer) AddResourceTemplates(resourceTemplates ...ServerResourceTemplate) { + s.implicitlyRegisterResourceCapabilities() + + s.resourcesMu.Lock() + for _, entry := range resourceTemplates { + s.resourceTemplates[entry.Template.URITemplate.Raw()] = resourceTemplateEntry{ + template: entry.Template, + handler: entry.Handler, + } + } + s.resourcesMu.Unlock() + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification + if s.capabilities.resources.listChanged { + // Send notification 
to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + +// SetResourceTemplates replaces all existing resource templates with the provided list +func (s *MCPServer) SetResourceTemplates(templates ...ServerResourceTemplate) { + s.resourcesMu.Lock() + s.resourceTemplates = make(map[string]resourceTemplateEntry, len(templates)) + s.resourcesMu.Unlock() + s.AddResourceTemplates(templates...) +} + +// AddResourceTemplate registers a new resource template and its handler +func (s *MCPServer) AddResourceTemplate( + template mcp.ResourceTemplate, + handler ResourceTemplateHandlerFunc, +) { + s.AddResourceTemplates(ServerResourceTemplate{Template: template, Handler: handler}) +} + +// AddPrompts registers multiple prompts at once +func (s *MCPServer) AddPrompts(prompts ...ServerPrompt) { + s.implicitlyRegisterPromptCapabilities() + + s.promptsMu.Lock() + for _, entry := range prompts { + s.prompts[entry.Prompt.Name] = entry.Prompt + s.promptHandlers[entry.Prompt.Name] = entry.Handler + } + s.promptsMu.Unlock() + + // When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.prompts.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + } +} + +// AddPrompt registers a new prompt handler with the given name +func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { + s.AddPrompts(ServerPrompt{Prompt: prompt, Handler: handler}) +} + +// SetPrompts replaces all existing prompts with the provided list +func (s *MCPServer) SetPrompts(prompts ...ServerPrompt) { + s.promptsMu.Lock() + s.prompts = make(map[string]mcp.Prompt, len(prompts)) + s.promptHandlers = make(map[string]PromptHandlerFunc, len(prompts)) + s.promptsMu.Unlock() + s.AddPrompts(prompts...) 
+} + +// DeletePrompts removes prompts from the server +func (s *MCPServer) DeletePrompts(names ...string) { + s.promptsMu.Lock() + var exists bool + for _, name := range names { + if _, ok := s.prompts[name]; ok { + delete(s.prompts, name) + delete(s.promptHandlers, name) + exists = true + } + } + s.promptsMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled, and we actually remove a prompt + if exists && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + } +} + +// AddTool registers a new tool and its handler +func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { + s.AddTools(ServerTool{Tool: tool, Handler: handler}) +} + +// Register tool capabilities due to a tool being added. Default to +// listChanged: true, but don't change the value if we've already explicitly +// registered tools.listChanged false. 
+func (s *MCPServer) implicitlyRegisterToolCapabilities() { + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.tools != nil }, + func() { s.capabilities.tools = &toolCapabilities{listChanged: true} }, + ) +} + +func (s *MCPServer) implicitlyRegisterResourceCapabilities() { + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.resources != nil }, + func() { s.capabilities.resources = &resourceCapabilities{} }, + ) +} + +func (s *MCPServer) implicitlyRegisterPromptCapabilities() { + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.prompts != nil }, + func() { s.capabilities.prompts = &promptCapabilities{} }, + ) +} + +func (s *MCPServer) implicitlyRegisterCapabilities(check func() bool, register func()) { + s.capabilitiesMu.RLock() + if check() { + s.capabilitiesMu.RUnlock() + return + } + s.capabilitiesMu.RUnlock() + + s.capabilitiesMu.Lock() + if !check() { + register() + } + s.capabilitiesMu.Unlock() +} + +// AddTools registers multiple tools at once +func (s *MCPServer) AddTools(tools ...ServerTool) { + s.implicitlyRegisterToolCapabilities() + + s.toolsMu.Lock() + for _, entry := range tools { + s.tools[entry.Tool.Name] = entry + } + s.toolsMu.Unlock() + + // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.tools.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + } +} + +// SetTools replaces all existing tools with the provided list +func (s *MCPServer) SetTools(tools ...ServerTool) { + s.toolsMu.Lock() + s.tools = make(map[string]ServerTool, len(tools)) + s.toolsMu.Unlock() + s.AddTools(tools...) 
+} + +// GetTool retrieves the specified tool +func (s *MCPServer) GetTool(toolName string) *ServerTool { + s.toolsMu.RLock() + defer s.toolsMu.RUnlock() + if tool, ok := s.tools[toolName]; ok { + return &tool + } + return nil +} + +func (s *MCPServer) ListTools() map[string]*ServerTool { + s.toolsMu.RLock() + defer s.toolsMu.RUnlock() + if len(s.tools) == 0 { + return nil + } + // Create a copy to prevent external modification + toolsCopy := make(map[string]*ServerTool, len(s.tools)) + for name, tool := range s.tools { + toolsCopy[name] = &tool + } + return toolsCopy +} + +// DeleteTools removes tools from the server +func (s *MCPServer) DeleteTools(names ...string) { + s.toolsMu.Lock() + var exists bool + for _, name := range names { + if _, ok := s.tools[name]; ok { + delete(s.tools, name) + exists = true + } + } + s.toolsMu.Unlock() + + // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. + if exists && s.capabilities.tools != nil && s.capabilities.tools.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + } +} + +// AddNotificationHandler registers a new handler for incoming notifications +func (s *MCPServer) AddNotificationHandler( + method string, + handler NotificationHandlerFunc, +) { + s.notificationHandlersMu.Lock() + defer s.notificationHandlersMu.Unlock() + s.notificationHandlers[method] = handler +} + +func (s *MCPServer) handleInitialize( + ctx context.Context, + _ any, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, *requestError) { + capabilities := mcp.ServerCapabilities{} + + // Only add resource capabilities if they're configured + if s.capabilities.resources != nil { + capabilities.Resources = &struct { + Subscribe bool `json:"subscribe,omitempty"` + ListChanged bool `json:"listChanged,omitempty"` + }{ + Subscribe: s.capabilities.resources.subscribe, + ListChanged: 
s.capabilities.resources.listChanged, + } + } + + // Only add prompt capabilities if they're configured + if s.capabilities.prompts != nil { + capabilities.Prompts = &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: s.capabilities.prompts.listChanged, + } + } + + // Only add tool capabilities if they're configured + if s.capabilities.tools != nil { + capabilities.Tools = &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: s.capabilities.tools.listChanged, + } + } + + if s.capabilities.logging != nil && *s.capabilities.logging { + capabilities.Logging = &struct{}{} + } + + if s.capabilities.sampling != nil && *s.capabilities.sampling { + capabilities.Sampling = &struct{}{} + } + + if s.capabilities.elicitation != nil && *s.capabilities.elicitation { + capabilities.Elicitation = &struct{}{} + } + + if s.capabilities.roots != nil && *s.capabilities.roots { + capabilities.Roots = &struct{}{} + } + + result := mcp.InitializeResult{ + ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), + ServerInfo: mcp.Implementation{ + Name: s.name, + Version: s.version, + }, + Capabilities: capabilities, + Instructions: s.instructions, + } + + if session := ClientSessionFromContext(ctx); session != nil { + session.Initialize() + + // Store client info if the session supports it + if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok { + sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo) + sessionWithClientInfo.SetClientCapabilities(request.Params.Capabilities) + } + } + + return &result, nil +} + +func (s *MCPServer) protocolVersion(clientVersion string) string { + // For backwards compatibility, if the server does not receive an MCP-Protocol-Version header, + // and has no other way to identify the version - for example, by relying on the protocol version negotiated + // during initialization - the server SHOULD assume protocol version 2025-03-26 + // 
https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + if len(clientVersion) == 0 { + clientVersion = "2025-03-26" + } + + if slices.Contains(mcp.ValidProtocolVersions, clientVersion) { + return clientVersion + } + + return mcp.LATEST_PROTOCOL_VERSION +} + +func (s *MCPServer) handlePing( + _ context.Context, + _ any, + _ mcp.PingRequest, +) (*mcp.EmptyResult, *requestError) { + return &mcp.EmptyResult{}, nil +} + +func (s *MCPServer) handleSetLevel( + ctx context.Context, + id any, + request mcp.SetLevelRequest, +) (*mcp.EmptyResult, *requestError) { + clientSession := ClientSessionFromContext(ctx) + if clientSession == nil || !clientSession.Initialized() { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: ErrSessionNotInitialized, + } + } + + sessionLogging, ok := clientSession.(SessionWithLogging) + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: ErrSessionDoesNotSupportLogging, + } + } + + level := request.Params.Level + // Validate logging level + switch level { + case mcp.LoggingLevelDebug, mcp.LoggingLevelInfo, mcp.LoggingLevelNotice, + mcp.LoggingLevelWarning, mcp.LoggingLevelError, mcp.LoggingLevelCritical, + mcp.LoggingLevelAlert, mcp.LoggingLevelEmergency: + // Valid level + default: + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("invalid logging level '%s'", level), + } + } + + sessionLogging.SetLogLevel(level) + + return &mcp.EmptyResult{}, nil +} + +func listByPagination[T mcp.Named]( + _ context.Context, + s *MCPServer, + cursor mcp.Cursor, + allElements []T, +) ([]T, mcp.Cursor, error) { + startPos := 0 + if cursor != "" { + c, err := base64.StdEncoding.DecodeString(string(cursor)) + if err != nil { + return nil, "", err + } + cString := string(c) + startPos = sort.Search(len(allElements), func(i int) bool { + return allElements[i].GetName() > cString + }) + } + endPos := len(allElements) + if 
s.paginationLimit != nil { + if len(allElements) > startPos+*s.paginationLimit { + endPos = startPos + *s.paginationLimit + } + } + elementsToReturn := allElements[startPos:endPos] + // set the next cursor + nextCursor := func() mcp.Cursor { + if s.paginationLimit != nil && len(elementsToReturn) >= *s.paginationLimit { + nc := elementsToReturn[len(elementsToReturn)-1].GetName() + toString := base64.StdEncoding.EncodeToString([]byte(nc)) + return mcp.Cursor(toString) + } + return "" + }() + return elementsToReturn, nextCursor, nil +} + +func (s *MCPServer) handleListResources( + ctx context.Context, + id any, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, *requestError) { + s.resourcesMu.RLock() + resourceMap := make(map[string]mcp.Resource, len(s.resources)) + for uri, entry := range s.resources { + resourceMap[uri] = entry.resource + } + s.resourcesMu.RUnlock() + + // Check if there are session-specific resources + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithResources, ok := session.(SessionWithResources); ok { + if sessionResources := sessionWithResources.GetSessionResources(); sessionResources != nil { + // Merge session-specific resources with global resources + for uri, serverResource := range sessionResources { + resourceMap[uri] = serverResource.Resource + } + } + } + } + + // Sort the resources by name + resourcesList := slices.SortedFunc(maps.Values(resourceMap), func(a, b mcp.Resource) int { + return cmp.Compare(a.Name, b.Name) + }) + + // Apply pagination + resourcesToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + resourcesList, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } + } + result := mcp.ListResourcesResult{ + Resources: resourcesToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, + } + return &result, nil +} + +func (s *MCPServer) handleListResourceTemplates( + ctx 
context.Context, + id any, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, *requestError) { + // Get global templates + s.resourcesMu.RLock() + templateMap := make(map[string]mcp.ResourceTemplate, len(s.resourceTemplates)) + for uri, entry := range s.resourceTemplates { + templateMap[uri] = entry.template + } + s.resourcesMu.RUnlock() + + // Check if there are session-specific resource templates + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTemplates, ok := session.(SessionWithResourceTemplates); ok { + if sessionTemplates := sessionWithTemplates.GetSessionResourceTemplates(); sessionTemplates != nil { + // Merge session-specific templates with global templates + // Session templates override global ones + for uriTemplate, serverTemplate := range sessionTemplates { + templateMap[uriTemplate] = serverTemplate.Template + } + } + } + } + + // Convert map to slice for sorting and pagination + templates := make([]mcp.ResourceTemplate, 0, len(templateMap)) + for _, template := range templateMap { + templates = append(templates, template) + } + + sort.Slice(templates, func(i, j int) bool { + return templates[i].Name < templates[j].Name + }) + templatesToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + templates, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } + } + result := mcp.ListResourceTemplatesResult{ + ResourceTemplates: templatesToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, + } + return &result, nil +} + +func (s *MCPServer) handleReadResource( + ctx context.Context, + id any, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, *requestError) { + s.resourcesMu.RLock() + + // First check session-specific resources + var handler ResourceHandlerFunc + var ok bool + + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithResources, 
typeAssertOk := session.(SessionWithResources); typeAssertOk { + if sessionResources := sessionWithResources.GetSessionResources(); sessionResources != nil { + resource, sessionOk := sessionResources[request.Params.URI] + if sessionOk { + handler = resource.Handler + ok = true + } + } + } + } + + // If not found in session tools, check global tools + if !ok { + globalResource, rok := s.resources[request.Params.URI] + if rok { + handler = globalResource.handler + ok = true + } + } + + // First try direct resource handlers + if ok { + s.resourcesMu.RUnlock() + + finalHandler := handler + s.resourceMiddlewareMu.RLock() + mw := s.resourceHandlerMiddlewares + // Apply middlewares in reverse order + for i := len(mw) - 1; i >= 0; i-- { + finalHandler = mw[i](finalHandler) + } + s.resourceMiddlewareMu.RUnlock() + + contents, err := finalHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return &mcp.ReadResourceResult{Contents: contents}, nil + } + + // If no direct handler found, try matching against templates + var matchedHandler ResourceTemplateHandlerFunc + var matched bool + + // First check session templates if available + if session != nil { + if sessionWithTemplates, ok := session.(SessionWithResourceTemplates); ok { + sessionTemplates := sessionWithTemplates.GetSessionResourceTemplates() + for _, serverTemplate := range sessionTemplates { + if serverTemplate.Template.URITemplate == nil { + continue + } + if matchesTemplate(request.Params.URI, serverTemplate.Template.URITemplate) { + matchedHandler = serverTemplate.Handler + matched = true + matchedVars := serverTemplate.Template.URITemplate.Match(request.Params.URI) + // Convert matched variables to a map + request.Params.Arguments = make(map[string]any, len(matchedVars)) + for name, value := range matchedVars { + request.Params.Arguments[name] = value.V + } + break + } + } + } + } + + // If not found in session templates, check global 
templates + if !matched { + for _, entry := range s.resourceTemplates { + template := entry.template + if template.URITemplate == nil { + continue + } + if matchesTemplate(request.Params.URI, template.URITemplate) { + matchedHandler = entry.handler + matched = true + matchedVars := template.URITemplate.Match(request.Params.URI) + // Convert matched variables to a map + request.Params.Arguments = make(map[string]any, len(matchedVars)) + for name, value := range matchedVars { + request.Params.Arguments[name] = value.V + } + break + } + } + } + s.resourcesMu.RUnlock() + + if matched { + // If a match is found, then we have a final handler and can + // apply middlewares. + s.resourceMiddlewareMu.RLock() + finalHandler := ResourceHandlerFunc(matchedHandler) + mw := s.resourceHandlerMiddlewares + // Apply middlewares in reverse order + for i := len(mw) - 1; i >= 0; i-- { + finalHandler = mw[i](finalHandler) + } + s.resourceMiddlewareMu.RUnlock() + contents, err := finalHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return &mcp.ReadResourceResult{Contents: contents}, nil + } + + return nil, &requestError{ + id: id, + code: mcp.RESOURCE_NOT_FOUND, + err: fmt.Errorf( + "handler not found for resource URI '%s': %w", + request.Params.URI, + ErrResourceNotFound, + ), + } +} + +// matchesTemplate checks if a URI matches a URI template pattern +func matchesTemplate(uri string, template *mcp.URITemplate) bool { + return template.Regexp().MatchString(uri) +} + +func (s *MCPServer) handleListPrompts( + ctx context.Context, + id any, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, *requestError) { + s.promptsMu.RLock() + prompts := make([]mcp.Prompt, 0, len(s.prompts)) + for _, prompt := range s.prompts { + prompts = append(prompts, prompt) + } + s.promptsMu.RUnlock() + + // sort prompts by name + sort.Slice(prompts, func(i, j int) bool { + return prompts[i].Name < prompts[j].Name + }) + 
promptsToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + prompts, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } + } + result := mcp.ListPromptsResult{ + Prompts: promptsToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, + } + return &result, nil +} + +func (s *MCPServer) handleGetPrompt( + ctx context.Context, + id any, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, *requestError) { + s.promptsMu.RLock() + handler, ok := s.promptHandlers[request.Params.Name] + s.promptsMu.RUnlock() + + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("prompt '%s' not found: %w", request.Params.Name, ErrPromptNotFound), + } + } + + result, err := handler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + + return result, nil +} + +func (s *MCPServer) handleListTools( + ctx context.Context, + id any, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, *requestError) { + // Get the base tools from the server + s.toolsMu.RLock() + tools := make([]mcp.Tool, 0, len(s.tools)) + + // Get all tool names for consistent ordering + toolNames := make([]string, 0, len(s.tools)) + for name := range s.tools { + toolNames = append(toolNames, name) + } + + // Sort the tool names for consistent ordering + sort.Strings(toolNames) + + // Add tools in sorted order + for _, name := range toolNames { + tools = append(tools, s.tools[name].Tool) + } + s.toolsMu.RUnlock() + + // Check if there are session-specific tools + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTools, ok := session.(SessionWithTools); ok { + if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { + // Override or add session-specific tools + // We need to create a map first to merge the tools properly + toolMap := 
make(map[string]mcp.Tool) + + // Add global tools first + for _, tool := range tools { + toolMap[tool.Name] = tool + } + + // Then override with session-specific tools + for name, serverTool := range sessionTools { + toolMap[name] = serverTool.Tool + } + + // Convert back to slice + tools = make([]mcp.Tool, 0, len(toolMap)) + for _, tool := range toolMap { + tools = append(tools, tool) + } + + // Sort again to maintain consistent ordering + sort.Slice(tools, func(i, j int) bool { + return tools[i].Name < tools[j].Name + }) + } + } + } + + // Apply tool filters if any are defined + s.toolFiltersMu.RLock() + if len(s.toolFilters) > 0 { + for _, filter := range s.toolFilters { + tools = filter(ctx, tools) + } + } + s.toolFiltersMu.RUnlock() + + // Apply pagination + toolsToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + tools, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } + } + + result := mcp.ListToolsResult{ + Tools: toolsToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, + } + return &result, nil +} + +func (s *MCPServer) handleToolCall( + ctx context.Context, + id any, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, *requestError) { + // First check session-specific tools + var tool ServerTool + var ok bool + + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTools, typeAssertOk := session.(SessionWithTools); typeAssertOk { + if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { + var sessionOk bool + tool, sessionOk = sessionTools[request.Params.Name] + if sessionOk { + ok = true + } + } + } + } + + // If not found in session tools, check global tools + if !ok { + s.toolsMu.RLock() + tool, ok = s.tools[request.Params.Name] + s.toolsMu.RUnlock() + } + + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("tool '%s' not found: %w", 
request.Params.Name, ErrToolNotFound), + } + } + + finalHandler := tool.Handler + + s.toolMiddlewareMu.RLock() + mw := s.toolHandlerMiddlewares + + // Apply middlewares in reverse order + for i := len(mw) - 1; i >= 0; i-- { + finalHandler = mw[i](finalHandler) + } + s.toolMiddlewareMu.RUnlock() + + result, err := finalHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + + return result, nil +} + +func (s *MCPServer) handleNotification( + ctx context.Context, + notification mcp.JSONRPCNotification, +) mcp.JSONRPCMessage { + s.notificationHandlersMu.RLock() + handler, ok := s.notificationHandlers[notification.Method] + s.notificationHandlersMu.RUnlock() + + if ok { + handler(ctx, notification) + } + return nil +} + +func createResponse(id any, result any) mcp.JSONRPCMessage { + return mcp.NewJSONRPCResultResponse(mcp.NewRequestId(id), result) +} + +func createErrorResponse( + id any, + code int, + message string, +) mcp.JSONRPCMessage { + return mcp.JSONRPCError{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(id), + Error: mcp.NewJSONRPCErrorDetails(code, message, nil), + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/session.go b/vendor/github.com/mark3labs/mcp-go/server/session.go new file mode 100644 index 000000000..48fd52d75 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/session.go @@ -0,0 +1,770 @@ +package server + +import ( + "context" + "fmt" + "net/url" + + "github.com/mark3labs/mcp-go/mcp" +) + +// ClientSession represents an active session that can be used by MCPServer to interact with client. +type ClientSession interface { + // Initialize marks session as fully initialized and ready for notifications + Initialize() + // Initialized returns if session is ready to accept notifications + Initialized() bool + // NotificationChannel provides a channel suitable for sending notifications to client. 
+ NotificationChannel() chan<- mcp.JSONRPCNotification + // SessionID is a unique identifier used to track user session. + SessionID() string +} + +// SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level +type SessionWithLogging interface { + ClientSession + // SetLogLevel sets the minimum log level + SetLogLevel(level mcp.LoggingLevel) + // GetLogLevel retrieves the minimum log level + GetLogLevel() mcp.LoggingLevel +} + +// SessionWithTools is an extension of ClientSession that can store session-specific tool data +type SessionWithTools interface { + ClientSession + // GetSessionTools returns the tools specific to this session, if any + // This method must be thread-safe for concurrent access + GetSessionTools() map[string]ServerTool + // SetSessionTools sets tools specific to this session + // This method must be thread-safe for concurrent access + SetSessionTools(tools map[string]ServerTool) +} + +// SessionWithResources is an extension of ClientSession that can store session-specific resource data +type SessionWithResources interface { + ClientSession + // GetSessionResources returns the resources specific to this session, if any + // This method must be thread-safe for concurrent access + GetSessionResources() map[string]ServerResource + // SetSessionResources sets resources specific to this session + // This method must be thread-safe for concurrent access + SetSessionResources(resources map[string]ServerResource) +} + +// SessionWithResourceTemplates is an extension of ClientSession that can store session-specific resource template data +type SessionWithResourceTemplates interface { + ClientSession + // GetSessionResourceTemplates returns the resource templates specific to this session, if any + // This method must be thread-safe for concurrent access + GetSessionResourceTemplates() map[string]ServerResourceTemplate + // SetSessionResourceTemplates sets resource templates specific to this session + 
// This method must be thread-safe for concurrent access + SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) +} + +// SessionWithClientInfo is an extension of ClientSession that can store client info +type SessionWithClientInfo interface { + ClientSession + // GetClientInfo returns the client information for this session + GetClientInfo() mcp.Implementation + // SetClientInfo sets the client information for this session + SetClientInfo(clientInfo mcp.Implementation) + // GetClientCapabilities returns the client capabilities for this session + GetClientCapabilities() mcp.ClientCapabilities + // SetClientCapabilities sets the client capabilities for this session + SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) +} + +// SessionWithElicitation is an extension of ClientSession that can send elicitation requests +type SessionWithElicitation interface { + ClientSession + // RequestElicitation sends an elicitation request to the client and waits for response + RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) +} + +// SessionWithRoots is an extension of ClientSession that can send list roots requests +type SessionWithRoots interface { + ClientSession + // ListRoots sends an list roots request to the client and waits for response + ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) +} + +// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations +type SessionWithStreamableHTTPConfig interface { + ClientSession + // UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server + // sends notifications to the client + // + // The protocol specification: + // - If the server response contains any JSON-RPC notifications, it MUST either: + // - Return Content-Type: text/event-stream to initiate an SSE stream, OR + // - Return Content-Type: 
application/json for a single JSON object + // - The client MUST support both response types. + // + // Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server + UpgradeToSSEWhenReceiveNotification() +} + +// clientSessionKey is the context key for storing current client notification channel. +type clientSessionKey struct{} + +// ClientSessionFromContext retrieves current client notification context from context. +func ClientSessionFromContext(ctx context.Context) ClientSession { + if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { + return session + } + return nil +} + +// WithContext sets the current client session and returns the provided context +func (s *MCPServer) WithContext( + ctx context.Context, + session ClientSession, +) context.Context { + return context.WithValue(ctx, clientSessionKey{}, session) +} + +// RegisterSession saves session that should be notified in case if some server attributes changed. 
+func (s *MCPServer) RegisterSession( + ctx context.Context, + session ClientSession, +) error { + sessionID := session.SessionID() + if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { + return ErrSessionExists + } + s.hooks.RegisterSession(ctx, session) + return nil +} + +func (s *MCPServer) buildLogNotification(notification mcp.LoggingMessageNotification) mcp.JSONRPCNotification { + return mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: notification.Method, + Params: mcp.NotificationParams{ + AdditionalFields: map[string]any{ + "level": notification.Params.Level, + "logger": notification.Params.Logger, + "data": notification.Params.Data, + }, + }, + }, + } +} + +func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp.LoggingMessageNotification) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + sessionLogging, ok := session.(SessionWithLogging) + if !ok { + return ErrSessionDoesNotSupportLogging + } + if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) { + return nil + } + return s.sendNotificationCore(ctx, session, s.buildLogNotification(notification)) +} + +func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) { + s.sessions.Range(func(k, v any) bool { + if session, ok := v.(ClientSession); ok && session.Initialized() { + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + select { + case session.NotificationChannel() <- notification: + // Successfully sent notification + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks 
+ go func(sessionID string, hooks *Hooks) { + ctx := context.Background() + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": notification.Method, + "sessionID": sessionID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) + }(session.SessionID(), hooks) + } + } + } + return true + }) +} + +func (s *MCPServer) sendNotificationToSpecificClient(session ClientSession, notification mcp.JSONRPCNotification) error { + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + select { + case session.NotificationChannel() <- notification: + return nil + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + ctx := context.Background() + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sID string, hooks *Hooks) { + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": notification.Method, + "sessionID": sID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) + }(session.SessionID(), hooks) + } + return ErrNotificationChannelBlocked + } +} + +func (s *MCPServer) SendLogMessageToSpecificClient(sessionID string, notification mcp.LoggingMessageNotification) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + session, ok := sessionValue.(ClientSession) + if !ok || !session.Initialized() { + return ErrSessionNotInitialized + } + sessionLogging, ok := session.(SessionWithLogging) + if !ok { + return ErrSessionDoesNotSupportLogging + } + if 
!notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) { + return nil + } + return s.sendNotificationToSpecificClient(session, s.buildLogNotification(notification)) +} + +// UnregisterSession removes from storage session that is shut down. +func (s *MCPServer) UnregisterSession( + ctx context.Context, + sessionID string, +) { + sessionValue, ok := s.sessions.LoadAndDelete(sessionID) + if !ok { + return + } + if session, ok := sessionValue.(ClientSession); ok { + s.hooks.UnregisterSession(ctx, session) + } +} + +// SendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) SendNotificationToAllClients( + method string, + params map[string]any, +) { + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + s.sendNotificationToAllClients(notification) +} + +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) sendNotificationCore( + ctx context.Context, + session ClientSession, + notification mcp.JSONRPCNotification, +) error { + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + select { + case session.NotificationChannel() <- notification: + return nil + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + method := notification.Method + err := ErrNotificationChannelBlocked + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sessionID string, hooks *Hooks) { + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": method, + 
"sessionID": sessionID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) + }(session.SessionID(), hooks) + } + return ErrNotificationChannelBlocked + } +} + +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) SendNotificationToClient( + ctx context.Context, + method string, + params map[string]any, +) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + return s.sendNotificationCore(ctx, session, notification) +} + +// SendNotificationToSpecificClient sends a notification to a specific client by session ID +func (s *MCPServer) SendNotificationToSpecificClient( + sessionID string, + method string, + params map[string]any, +) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + session, ok := sessionValue.(ClientSession) + if !ok || !session.Initialized() { + return ErrSessionNotInitialized + } + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + return s.sendNotificationToSpecificClient(session, notification) +} + +// AddSessionTool adds a tool for a specific session +func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error { + return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler}) +} + +// AddSessionTools adds tools for a specific session +func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := 
sessionValue.(SessionWithTools) + if !ok { + return ErrSessionDoesNotSupportTools + } + + s.implicitlyRegisterToolCapabilities() + + // Get existing tools (this should return a thread-safe copy) + sessionTools := session.GetSessionTools() + + // Create a new map to avoid concurrent modification issues + newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools)) + + // Copy existing tools + for k, v := range sessionTools { + newSessionTools[k] = v + } + + // Add new tools + for _, tool := range tools { + newSessionTools[tool.Tool.Name] = tool + } + + // Set the tools (this should be thread-safe) + session.SetSessionTools(newSessionTools) + + // It only makes sense to send tool notifications to initialized sessions -- + // if we're not initialized yet the client can't possibly have sent their + // initial tools/list message. + // + // For initialized sessions, honor tools.listChanged, which is specifically + // about whether notifications will be sent or not. + // see + if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged { + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The tools were successfully added, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/tools/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after adding tools: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} + +// DeleteSessionTools removes tools from a specific session +func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + 
session, ok := sessionValue.(SessionWithTools) + if !ok { + return ErrSessionDoesNotSupportTools + } + + // Get existing tools (this should return a thread-safe copy) + sessionTools := session.GetSessionTools() + if sessionTools == nil { + return nil + } + + // Create a new map to avoid concurrent modification issues + newSessionTools := make(map[string]ServerTool, len(sessionTools)) + + // Copy existing tools except those being deleted + for k, v := range sessionTools { + newSessionTools[k] = v + } + + // Remove specified tools + for _, name := range names { + delete(newSessionTools, name) + } + + // Set the tools (this should be thread-safe) + session.SetSessionTools(newSessionTools) + + // It only makes sense to send tool notifications to initialized sessions -- + // if we're not initialized yet the client can't possibly have sent their + // initial tools/list message. + // + // For initialized sessions, honor tools.listChanged, which is specifically + // about whether notifications will be sent or not. 
+ // see + if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged { + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The tools were successfully deleted, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/tools/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after deleting tools: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} + +// AddSessionResource adds a resource for a specific session +func (s *MCPServer) AddSessionResource(sessionID string, resource mcp.Resource, handler ResourceHandlerFunc) error { + return s.AddSessionResources(sessionID, ServerResource{Resource: resource, Handler: handler}) +} + +// AddSessionResources adds resources for a specific session +func (s *MCPServer) AddSessionResources(sessionID string, resources ...ServerResource) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithResources) + if !ok { + return ErrSessionDoesNotSupportResources + } + + // For session resources, we want listChanged enabled by default + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.resources != nil }, + func() { s.capabilities.resources = &resourceCapabilities{listChanged: true} }, + ) + + // Get existing resources (this should return a thread-safe copy) + sessionResources := session.GetSessionResources() + + // Create a new map to avoid concurrent modification issues + newSessionResources := make(map[string]ServerResource, len(sessionResources)+len(resources)) + + // Copy existing resources + for k, v := range 
sessionResources { + newSessionResources[k] = v + } + + // Add new resources with validation + for _, resource := range resources { + // Validate that URI is non-empty + if resource.Resource.URI == "" { + return fmt.Errorf("resource URI cannot be empty") + } + + // Validate that URI conforms to RFC 3986 + if _, err := url.ParseRequestURI(resource.Resource.URI); err != nil { + return fmt.Errorf("invalid resource URI: %w", err) + } + + newSessionResources[resource.Resource.URI] = resource + } + + // Set the resources (this should be thread-safe) + session.SetSessionResources(newSessionResources) + + // It only makes sense to send resource notifications to initialized sessions -- + // if we're not initialized yet the client can't possibly have sent their + // initial resources/list message. + // + // For initialized sessions, honor resources.listChanged, which is specifically + // about whether notifications will be sent or not. + // see + if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The resources were successfully added, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/resources/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after adding resources: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} + +// DeleteSessionResources removes resources from a specific session +func (s *MCPServer) DeleteSessionResources(sessionID string, uris ...string) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := 
sessionValue.(SessionWithResources) + if !ok { + return ErrSessionDoesNotSupportResources + } + + // Get existing resources (this should return a thread-safe copy) + sessionResources := session.GetSessionResources() + if sessionResources == nil { + return nil + } + + // Create a new map to avoid concurrent modification issues + newSessionResources := make(map[string]ServerResource, len(sessionResources)) + + // Copy existing resources except those being deleted + for k, v := range sessionResources { + newSessionResources[k] = v + } + + // Remove specified resources and track if anything was actually deleted + actuallyDeleted := false + for _, uri := range uris { + if _, exists := newSessionResources[uri]; exists { + delete(newSessionResources, uri) + actuallyDeleted = true + } + } + + // Skip no-op write if nothing was actually deleted + if !actuallyDeleted { + return nil + } + + // Set the resources (this should be thread-safe) + session.SetSessionResources(newSessionResources) + + // It only makes sense to send resource notifications to initialized sessions -- + // if we're not initialized yet the client can't possibly have sent their + // initial resources/list message. + // + // For initialized sessions, honor resources.listChanged, which is specifically + // about whether notifications will be sent or not. 
+ // see + // Only send notification if something was actually deleted + if actuallyDeleted && session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The resources were successfully deleted, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/resources/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after deleting resources: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} + +// AddSessionResourceTemplate adds a resource template for a specific session +func (s *MCPServer) AddSessionResourceTemplate(sessionID string, template mcp.ResourceTemplate, handler ResourceTemplateHandlerFunc) error { + return s.AddSessionResourceTemplates(sessionID, ServerResourceTemplate{ + Template: template, + Handler: handler, + }) +} + +// AddSessionResourceTemplates adds resource templates for a specific session +func (s *MCPServer) AddSessionResourceTemplates(sessionID string, templates ...ServerResourceTemplate) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithResourceTemplates) + if !ok { + return ErrSessionDoesNotSupportResourceTemplates + } + + // For session resource templates, enable listChanged by default + // This is the same behavior as session resources + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.resources != nil }, + func() { s.capabilities.resources = &resourceCapabilities{listChanged: true} }, + ) + + // Get existing templates (this returns a thread-safe copy) 
+ sessionTemplates := session.GetSessionResourceTemplates() + + // Create a new map to avoid modifying the returned copy + newTemplates := make(map[string]ServerResourceTemplate, len(sessionTemplates)+len(templates)) + + // Copy existing templates + for k, v := range sessionTemplates { + newTemplates[k] = v + } + + // Validate and add new templates + for _, t := range templates { + if t.Template.URITemplate == nil { + return fmt.Errorf("resource template URITemplate cannot be nil") + } + raw := t.Template.URITemplate.Raw() + if raw == "" { + return fmt.Errorf("resource template URITemplate cannot be empty") + } + if t.Template.Name == "" { + return fmt.Errorf("resource template name cannot be empty") + } + newTemplates[raw] = t + } + + // Set the new templates (this method must handle thread-safety) + session.SetSessionResourceTemplates(newTemplates) + + // Send notification if the session is initialized and listChanged is enabled + if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil { + // Log the error but don't fail the operation + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/resources/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after adding resource templates: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} + +// DeleteSessionResourceTemplates removes resource templates from a specific session +func (s *MCPServer) DeleteSessionResourceTemplates(sessionID string, uriTemplates ...string) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithResourceTemplates) + if !ok { + return 
ErrSessionDoesNotSupportResourceTemplates + } + + // Get existing templates (this returns a thread-safe copy) + sessionTemplates := session.GetSessionResourceTemplates() + + // Track if any were actually deleted + deletedAny := false + + // Create a new map without the deleted templates + newTemplates := make(map[string]ServerResourceTemplate, len(sessionTemplates)) + for k, v := range sessionTemplates { + newTemplates[k] = v + } + + // Delete specified templates + for _, uriTemplate := range uriTemplates { + if _, exists := newTemplates[uriTemplate]; exists { + delete(newTemplates, uriTemplate) + deletedAny = true + } + } + + // Only update if something was actually deleted + if deletedAny { + // Set the new templates (this method must handle thread-safety) + session.SetSessionResourceTemplates(newTemplates) + + // Send notification if the session is initialized and listChanged is enabled + if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil { + // Log the error but don't fail the operation + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/resources/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after deleting resource templates: %w", err)) + }(sessionID, hooks) + } + } + } + } + + return nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/sse.go b/vendor/github.com/mark3labs/mcp-go/server/sse.go new file mode 100644 index 000000000..97c765cc7 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/sse.go @@ -0,0 +1,797 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "net/http/httptest" + "net/url" + "path" + "strings" + "sync" + 
"sync/atomic" + "time" + + "github.com/google/uuid" + + "github.com/mark3labs/mcp-go/mcp" +) + +// sseSession represents an active SSE connection. +type sseSession struct { + done chan struct{} + eventQueue chan string // Channel for queuing events + sessionID string + requestID atomic.Int64 + notificationChannel chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + tools sync.Map // stores session-specific tools + resources sync.Map // stores session-specific resources + resourceTemplates sync.Map // stores session-specific resource templates + clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities +} + +// SSEContextFunc is a function that takes an existing context and the current +// request and returns a potentially modified context based on the request +// content. This can be used to inject context values from headers, for example. +type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context + +// DynamicBasePathFunc allows the user to provide a function to generate the +// base path for a given request and sessionID. This is useful for cases where +// the base path is not known at the time of SSE server creation, such as when +// using a reverse proxy or when the base path is dynamically generated. The +// function should return the base path (e.g., "/mcp/tenant123"). 
+type DynamicBasePathFunc func(r *http.Request, sessionID string) string + +func (s *sseSession) SessionID() string { + return s.sessionID +} + +func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notificationChannel +} + +func (s *sseSession) Initialize() { + // set default logging level + s.loggingLevel.Store(mcp.LoggingLevelError) + s.initialized.Store(true) +} + +func (s *sseSession) Initialized() bool { + return s.initialized.Load() +} + +func (s *sseSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func (s *sseSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +func (s *sseSession) GetSessionResources() map[string]ServerResource { + resources := make(map[string]ServerResource) + s.resources.Range(func(key, value any) bool { + if resource, ok := value.(ServerResource); ok { + resources[key.(string)] = resource + } + return true + }) + return resources +} + +func (s *sseSession) SetSessionResources(resources map[string]ServerResource) { + // Clear existing resources + s.resources.Clear() + + // Set new resources + for name, resource := range resources { + s.resources.Store(name, resource) + } +} + +func (s *sseSession) GetSessionResourceTemplates() map[string]ServerResourceTemplate { + templates := make(map[string]ServerResourceTemplate) + s.resourceTemplates.Range(func(key, value any) bool { + if template, ok := value.(ServerResourceTemplate); ok { + templates[key.(string)] = template + } + return true + }) + return templates +} + +func (s *sseSession) SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) { + // Clear existing templates + s.resourceTemplates.Clear() + + // Set new templates + for uriTemplate, template := range templates { + s.resourceTemplates.Store(uriTemplate, template) + } +} + +func (s *sseSession) GetSessionTools() map[string]ServerTool 
{ + tools := make(map[string]ServerTool) + s.tools.Range(func(key, value any) bool { + if tool, ok := value.(ServerTool); ok { + tools[key.(string)] = tool + } + return true + }) + return tools +} + +func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { + // Clear existing tools + s.tools.Clear() + + // Set new tools + for name, tool := range tools { + s.tools.Store(name, tool) + } +} + +func (s *sseSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *sseSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + +func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +var ( + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) + _ SessionWithResources = (*sseSession)(nil) + _ SessionWithResourceTemplates = (*sseSession)(nil) + _ SessionWithLogging = (*sseSession)(nil) + _ SessionWithClientInfo = (*sseSession)(nil) +) + +// SSEServer implements a Server-Sent Events (SSE) based MCP server. +// It provides real-time communication capabilities over HTTP using the SSE protocol. 
+type SSEServer struct { + server *MCPServer + baseURL string + basePath string + appendQueryToMessageEndpoint bool + useFullURLForMessageEndpoint bool + messageEndpoint string + sseEndpoint string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc + dynamicBasePathFunc DynamicBasePathFunc + + keepAlive bool + keepAliveInterval time.Duration + + mu sync.RWMutex +} + +// SSEOption defines a function type for configuring SSEServer +type SSEOption func(*SSEServer) + +// WithBaseURL sets the base URL for the SSE server +func WithBaseURL(baseURL string) SSEOption { + return func(s *SSEServer) { + if baseURL != "" { + u, err := url.Parse(baseURL) + if err != nil { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + // Check if the host is empty or only contains a port + if u.Host == "" || strings.HasPrefix(u.Host, ":") { + return + } + if len(u.Query()) > 0 { + return + } + } + s.baseURL = strings.TrimSuffix(baseURL, "/") + } +} + +// WithStaticBasePath adds a new option for setting a static base path +func WithStaticBasePath(basePath string) SSEOption { + return func(s *SSEServer) { + s.basePath = normalizeURLPath(basePath) + } +} + +// WithBasePath adds a new option for setting a static base path. +// +// Deprecated: Use WithStaticBasePath instead. This will be removed in a future version. +// +//go:deprecated +func WithBasePath(basePath string) SSEOption { + return WithStaticBasePath(basePath) +} + +// WithDynamicBasePath accepts a function for generating the base path. This is +// useful for cases where the base path is not known at the time of SSE server +// creation, such as when using a reverse proxy or when the server is mounted +// at a dynamic path. 
+func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption { + return func(s *SSEServer) { + if fn != nil { + s.dynamicBasePathFunc = func(r *http.Request, sid string) string { + bp := fn(r, sid) + return normalizeURLPath(bp) + } + } + } +} + +// WithMessageEndpoint sets the message endpoint path +func WithMessageEndpoint(endpoint string) SSEOption { + return func(s *SSEServer) { + s.messageEndpoint = endpoint + } +} + +// WithAppendQueryToMessageEndpoint configures the SSE server to append the original request's +// query parameters to the message endpoint URL that is sent to clients during the SSE connection +// initialization. This is useful when you need to preserve query parameters from the initial +// SSE connection request and carry them over to subsequent message requests, maintaining +// context or authentication details across the communication channel. +func WithAppendQueryToMessageEndpoint() SSEOption { + return func(s *SSEServer) { + s.appendQueryToMessageEndpoint = true + } +} + +// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL) +// or just the path portion for the message endpoint. Set to false when clients will concatenate +// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message". +func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption { + return func(s *SSEServer) { + s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint + } +} + +// WithSSEEndpoint sets the SSE endpoint path +func WithSSEEndpoint(endpoint string) SSEOption { + return func(s *SSEServer) { + s.sseEndpoint = endpoint + } +} + +// WithHTTPServer sets the HTTP server instance. +// NOTE: When providing a custom HTTP server, you must handle routing yourself +// If routing is not set up, the server will start but won't handle any MCP requests. 
+func WithHTTPServer(srv *http.Server) SSEOption { + return func(s *SSEServer) { + s.srv = srv + } +} + +func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption { + return func(s *SSEServer) { + s.keepAlive = true + s.keepAliveInterval = keepAliveInterval + } +} + +func WithKeepAlive(keepAlive bool) SSEOption { + return func(s *SSEServer) { + s.keepAlive = keepAlive + } +} + +// WithSSEContextFunc sets a function that will be called to customise the context +// to the server using the incoming request. +func WithSSEContextFunc(fn SSEContextFunc) SSEOption { + return func(s *SSEServer) { + s.contextFunc = fn + } +} + +// NewSSEServer creates a new SSE server instance with the given MCP server and options. +func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { + s := &SSEServer{ + server: server, + sseEndpoint: "/sse", + messageEndpoint: "/message", + useFullURLForMessageEndpoint: true, + keepAlive: false, + keepAliveInterval: 10 * time.Second, + } + + // Apply all options + for _, opt := range opts { + opt(s) + } + + return s +} + +// NewTestServer creates a test server for testing purposes +func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { + sseServer := NewSSEServer(server, opts...) + + testServer := httptest.NewServer(sseServer) + sseServer.baseURL = testServer.URL + return testServer +} + +// Start begins serving SSE connections on the specified address. +// It sets up HTTP handlers for SSE and message endpoints. 
+func (s *SSEServer) Start(addr string) error { + s.mu.Lock() + if s.srv == nil { + s.srv = &http.Server{ + Addr: addr, + Handler: s, + } + } else { + if s.srv.Addr == "" { + s.srv.Addr = addr + } else if s.srv.Addr != addr { + return fmt.Errorf("conflicting listen address: WithHTTPServer(%q) vs Start(%q)", s.srv.Addr, addr) + } + } + srv := s.srv + s.mu.Unlock() + + return srv.ListenAndServe() +} + +// Shutdown gracefully stops the SSE server, closing all active sessions +// and shutting down the HTTP server. +func (s *SSEServer) Shutdown(ctx context.Context) error { + s.mu.RLock() + srv := s.srv + s.mu.RUnlock() + + if srv != nil { + s.sessions.Range(func(key, value any) bool { + if session, ok := value.(*sseSession); ok { + close(session.done) + } + s.sessions.Delete(key) + return true + }) + + return srv.Shutdown(ctx) + } + return nil +} + +// handleSSE handles incoming SSE connection requests. +// It sets up appropriate headers and creates a new session for the client. +func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + sessionID := uuid.New().String() + session := &sseSession{ + done: make(chan struct{}), + eventQueue: make(chan string, 100), // Buffer for events + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + } + + s.sessions.Store(sessionID, session) + defer s.sessions.Delete(sessionID) + + if err := s.server.RegisterSession(r.Context(), session); err != nil { + http.Error( + w, + fmt.Sprintf("Session registration failed: %v", err), + 
http.StatusInternalServerError, + ) + return + } + defer s.server.UnregisterSession(r.Context(), sessionID) + + // Start notification handler for this session + go func() { + for { + select { + case notification := <-session.notificationChannel: + eventData, err := json.Marshal(notification) + if err == nil { + select { + case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): + // Event queued successfully + case <-session.done: + return + } + } + case <-session.done: + return + case <-r.Context().Done(): + return + } + } + }() + + // Start keep alive : ping + if s.keepAlive { + go func() { + ticker := time.NewTicker(s.keepAliveInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + message := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(session.requestID.Add(1)), + Request: mcp.Request{ + Method: "ping", + }, + } + messageBytes, _ := json.Marshal(message) + pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes) + select { + case session.eventQueue <- pingMsg: + // Message sent successfully + case <-session.done: + return + } + case <-session.done: + return + case <-r.Context().Done(): + return + } + } + }() + } + + // Send the initial endpoint event + endpoint := s.GetMessageEndpointForClient(r, sessionID) + if s.appendQueryToMessageEndpoint && len(r.URL.RawQuery) > 0 { + endpoint += "&" + r.URL.RawQuery + } + fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", endpoint) + flusher.Flush() + + // Main event loop - this runs in the HTTP handler goroutine + for { + select { + case event := <-session.eventQueue: + // Write the event to the response + fmt.Fprint(w, event) + flusher.Flush() + case <-r.Context().Done(): + close(session.done) + return + case <-session.done: + return + } + } +} + +// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID +// for the given request. This is the canonical way to compute the message endpoint for a client. 
+// It handles both dynamic and static path modes, and honors the WithUseFullURLForMessageEndpoint flag. +func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID string) string { + basePath := s.basePath + if s.dynamicBasePathFunc != nil { + basePath = s.dynamicBasePathFunc(r, sessionID) + } + + endpointPath := normalizeURLPath(basePath, s.messageEndpoint) + if s.useFullURLForMessageEndpoint && s.baseURL != "" { + endpointPath = s.baseURL + endpointPath + } + + return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID) +} + +// handleMessage processes incoming JSON-RPC messages from clients and sends responses +// back through the SSE connection and 202 code to HTTP response. +func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed") + return + } + + sessionID := r.URL.Query().Get("sessionId") + if sessionID == "" { + s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") + return + } + sessionI, ok := s.sessions.Load(sessionID) + if !ok { + s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") + return + } + session := sessionI.(*sseSession) + + // Set the client context before handling the message + ctx := s.server.WithContext(r.Context(), session) + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + // Parse message as raw JSON + var rawMessage json.RawMessage + if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil { + s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error") + return + } + + // Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled. 
+ // this is required because the http ctx will be canceled when the client disconnects + detachedCtx := context.WithoutCancel(ctx) + + // quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE + w.WriteHeader(http.StatusAccepted) + + // Create a new context for handling the message that will be canceled when the message handling is done + messageCtx := context.WithValue(detachedCtx, requestHeader, r.Header) + messageCtx, cancel := context.WithCancel(messageCtx) + + go func(ctx context.Context) { + defer cancel() + // Use the context that will be canceled when session is done + // Process message through MCPServer + response := s.server.HandleMessage(ctx, rawMessage) + // Only send response if there is one (not for notifications) + if response != nil { + var message string + if eventData, err := json.Marshal(response); err != nil { + // If there is an error marshalling the response, send a generic error response + log.Printf("failed to marshal response: %v", err) + message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n" + } else { + message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData) + } + + // Queue the event for sending via SSE + select { + case session.eventQueue <- message: + // Event queued successfully + case <-session.done: + // Session is closed, don't try to queue + default: + // Queue is full, log this situation + log.Printf("Event queue full for session %s", sessionID) + } + } + }(messageCtx) +} + +// writeJSONRPCError writes a JSON-RPC error response with the given error details. 
func (s *SSEServer) writeJSONRPCError(
	w http.ResponseWriter,
	id any,
	code int,
	message string,
) {
	// NOTE(review): the HTTP status is always 400 here regardless of the
	// JSON-RPC error code; the detailed code travels in the JSON body.
	response := createErrorResponse(id, code, message)
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	if err := json.NewEncoder(w).Encode(response); err != nil {
		// The 400 status line is already on the wire, so this http.Error is
		// best-effort; the client may see a truncated/mixed body.
		http.Error(
			w,
			fmt.Sprintf("Failed to encode response: %v", err),
			http.StatusInternalServerError,
		)
		return
	}
}

// SendEventToSession sends an event to a specific SSE session identified by sessionID.
// Returns an error if the session is not found or closed.
func (s *SSEServer) SendEventToSession(
	sessionID string,
	event any,
) error {
	sessionI, ok := s.sessions.Load(sessionID)
	if !ok {
		return fmt.Errorf("session not found: %s", sessionID)
	}
	session := sessionI.(*sseSession)

	eventData, err := json.Marshal(event)
	if err != nil {
		return err
	}

	// Queue the event for sending via SSE without ever blocking the caller:
	// a full queue or a closed session surfaces as an error instead.
	select {
	case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
		return nil
	case <-session.done:
		return fmt.Errorf("session closed")
	default:
		return fmt.Errorf("event queue full")
	}
}

// GetUrlPath parses input as a URL and returns only its path component.
func (s *SSEServer) GetUrlPath(input string) (string, error) {
	parse, err := url.Parse(input)
	if err != nil {
		return "", fmt.Errorf("failed to parse URL %s: %w", input, err)
	}
	return parse.Path, nil
}

// CompleteSseEndpoint returns the absolute SSE URL (baseURL + basePath + sseEndpoint).
// It fails when a dynamic base path is configured, because no single static URL exists.
func (s *SSEServer) CompleteSseEndpoint() (string, error) {
	if s.dynamicBasePathFunc != nil {
		return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"}
	}

	path := normalizeURLPath(s.basePath, s.sseEndpoint)
	return s.baseURL + path, nil
}

// CompleteSsePath returns the path portion of the SSE endpoint, falling back to
// the normalized basePath+sseEndpoint when the absolute URL cannot be built or parsed.
func (s *SSEServer) CompleteSsePath() string {
	path, err := s.CompleteSseEndpoint()
	if err != nil {
		return normalizeURLPath(s.basePath, s.sseEndpoint)
	}
	urlPath, err := s.GetUrlPath(path)
	if err != nil {
		return normalizeURLPath(s.basePath, s.sseEndpoint)
	}
	return urlPath
}

// CompleteMessageEndpoint returns the absolute message-post URL; see CompleteSseEndpoint.
func (s *SSEServer) CompleteMessageEndpoint() (string, error) {
	if s.dynamicBasePathFunc != nil {
		return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"}
	}
	path := normalizeURLPath(s.basePath, s.messageEndpoint)
	return s.baseURL + path, nil
}

// CompleteMessagePath mirrors CompleteSsePath for the message endpoint.
func (s *SSEServer) CompleteMessagePath() string {
	path, err := s.CompleteMessageEndpoint()
	if err != nil {
		return normalizeURLPath(s.basePath, s.messageEndpoint)
	}
	urlPath, err := s.GetUrlPath(path)
	if err != nil {
		return normalizeURLPath(s.basePath, s.messageEndpoint)
	}
	return urlPath
}

// SSEHandler returns an http.Handler for the SSE endpoint.
//
// This method allows you to mount the SSE handler at any arbitrary path
// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
// intended for advanced scenarios where you want to control the routing or
// support dynamic segments.
//
// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
// you must use the WithDynamicBasePath option to ensure the correct base path
// is communicated to clients.
//
// Example usage:
//
//	// Advanced/dynamic:
//	sseServer := NewSSEServer(mcpServer,
//		WithDynamicBasePath(func(r *http.Request, sessionID string) string {
//			tenant := r.PathValue("tenant")
//			return "/mcp/" + tenant
//		}),
//		WithBaseURL("http://localhost:8080")
//	)
//	mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
//	mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
//
// For non-dynamic cases, use ServeHTTP method instead.
func (s *SSEServer) SSEHandler() http.Handler {
	return http.HandlerFunc(s.handleSSE)
}

// MessageHandler returns an http.Handler for the message endpoint.
//
// This method allows you to mount the message handler at any arbitrary path
// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
// intended for advanced scenarios where you want to control the routing or
// support dynamic segments.
//
// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
// you must use the WithDynamicBasePath option to ensure the correct base path
// is communicated to clients.
//
// Example usage:
//
//	// Advanced/dynamic:
//	sseServer := NewSSEServer(mcpServer,
//		WithDynamicBasePath(func(r *http.Request, sessionID string) string {
//			tenant := r.PathValue("tenant")
//			return "/mcp/" + tenant
//		}),
//		WithBaseURL("http://localhost:8080")
//	)
//	mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
//	mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
//
// For non-dynamic cases, use ServeHTTP method instead.
func (s *SSEServer) MessageHandler() http.Handler {
	return http.HandlerFunc(s.handleMessage)
}

// ServeHTTP implements the http.Handler interface.
// It routes requests to the SSE or message handler by exact path match and
// refuses to serve when a dynamic base path is configured (use SSEHandler /
// MessageHandler with your own router in that case).
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if s.dynamicBasePathFunc != nil {
		http.Error(
			w,
			(&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(),
			http.StatusInternalServerError,
		)
		return
	}
	path := r.URL.Path
	// Use exact path matching rather than Contains
	ssePath := s.CompleteSsePath()
	if ssePath != "" && path == ssePath {
		s.handleSSE(w, r)
		return
	}
	messagePath := s.CompleteMessagePath()
	if messagePath != "" && path == messagePath {
		s.handleMessage(w, r)
		return
	}

	http.NotFound(w, r)
}

// normalizeURLPath joins path elements like path.Join but ensures the
// result always starts with a leading slash and never ends with a slash
// (except for the bare root "/").
func normalizeURLPath(elem ...string) string {
	joined := path.Join(elem...)

	// Ensure leading slash
	if !strings.HasPrefix(joined, "/") {
		joined = "/" + joined
	}

	// Remove trailing slash if not just "/"
	if len(joined) > 1 && strings.HasSuffix(joined, "/") {
		joined = joined[:len(joined)-1]
	}

	return joined
}
diff --git a/vendor/github.com/mark3labs/mcp-go/server/stdio.go b/vendor/github.com/mark3labs/mcp-go/server/stdio.go
new file mode 100644
index 000000000..f5c8ddfd2
--- /dev/null
+++ b/vendor/github.com/mark3labs/mcp-go/server/stdio.go
@@ -0,0 +1,877 @@
package server

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"os"
	"os/signal"
	"sync"
	"sync/atomic"
	"syscall"

	"github.com/mark3labs/mcp-go/mcp"
)

// StdioContextFunc is a function that takes an existing context and returns
// a potentially modified context.
// This can be used to inject context values from environment variables,
// for example.
type StdioContextFunc func(ctx context.Context) context.Context

// StdioServer wraps a MCPServer and handles stdio communication.
// It provides a simple way to create command-line MCP servers that
// communicate via standard input/output streams using JSON-RPC messages.
+type StdioServer struct { + server *MCPServer + errLogger *log.Logger + contextFunc StdioContextFunc + + // Thread-safe tool call processing + toolCallQueue chan *toolCallWork + workerWg sync.WaitGroup + workerPoolSize int + queueSize int + writeMu sync.Mutex // Protects concurrent writes +} + +// toolCallWork represents a queued tool call request +type toolCallWork struct { + ctx context.Context + message json.RawMessage + writer io.Writer +} + +// StdioOption defines a function type for configuring StdioServer +type StdioOption func(*StdioServer) + +// WithErrorLogger sets the error logger for the server +func WithErrorLogger(logger *log.Logger) StdioOption { + return func(s *StdioServer) { + s.errLogger = logger + } +} + +// WithStdioContextFunc sets a function that will be called to customise the context +// to the server. Note that the stdio server uses the same context for all requests, +// so this function will only be called once per server instance. +func WithStdioContextFunc(fn StdioContextFunc) StdioOption { + return func(s *StdioServer) { + s.contextFunc = fn + } +} + +// WithWorkerPoolSize sets the number of workers for processing tool calls +func WithWorkerPoolSize(size int) StdioOption { + return func(s *StdioServer) { + const maxWorkerPoolSize = 100 + if size > 0 && size <= maxWorkerPoolSize { + s.workerPoolSize = size + } else if size > maxWorkerPoolSize { + s.errLogger.Printf("Worker pool size %d exceeds maximum (%d), using maximum", size, maxWorkerPoolSize) + s.workerPoolSize = maxWorkerPoolSize + } + } +} + +// WithQueueSize sets the size of the tool call queue +func WithQueueSize(size int) StdioOption { + return func(s *StdioServer) { + const maxQueueSize = 10000 + if size > 0 && size <= maxQueueSize { + s.queueSize = size + } else if size > maxQueueSize { + s.errLogger.Printf("Queue size %d exceeds maximum (%d), using maximum", size, maxQueueSize) + s.queueSize = maxQueueSize + } + } +} + +// stdioSession is a static client session, since 
stdio has only one client.
type stdioSession struct {
	notifications       chan mcp.JSONRPCNotification
	initialized         atomic.Bool
	loggingLevel        atomic.Value
	clientInfo          atomic.Value // stores session-specific client info
	clientCapabilities  atomic.Value // stores session-specific client capabilities
	writer              io.Writer    // for sending requests to client
	requestID           atomic.Int64 // for generating unique request IDs
	mu                  sync.RWMutex // protects writer
	pendingRequests     map[int64]chan *samplingResponse    // for tracking pending sampling requests
	pendingElicitations map[int64]chan *elicitationResponse // for tracking pending elicitation requests
	pendingRoots        map[int64]chan *rootsResponse       // for tracking pending list roots requests
	pendingMu           sync.RWMutex // protects pendingRequests, pendingElicitations and pendingRoots
}

// samplingResponse represents a response to a sampling request
type samplingResponse struct {
	result *mcp.CreateMessageResult
	err    error
}

// elicitationResponse represents a response to an elicitation request
type elicitationResponse struct {
	result *mcp.ElicitationResult
	err    error
}

// rootsResponse represents a response to an list root request
type rootsResponse struct {
	result *mcp.ListRootsResult
	err    error
}

// SessionID returns a fixed ID: stdio transports carry exactly one session.
func (s *stdioSession) SessionID() string {
	return "stdio"
}

// NotificationChannel exposes the channel the server pushes notifications into.
func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
	return s.notifications
}

func (s *stdioSession) Initialize() {
	// set default logging level
	s.loggingLevel.Store(mcp.LoggingLevelError)
	s.initialized.Store(true)
}

func (s *stdioSession) Initialized() bool {
	return s.initialized.Load()
}

func (s *stdioSession) GetClientInfo() mcp.Implementation {
	if value := s.clientInfo.Load(); value != nil {
		if clientInfo, ok := value.(mcp.Implementation); ok {
			return clientInfo
		}
	}
	// Zero value when no client info has been stored yet.
	return mcp.Implementation{}
}

func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
	s.clientInfo.Store(clientInfo)
}

func (s *stdioSession) GetClientCapabilities() mcp.ClientCapabilities {
	if value := s.clientCapabilities.Load(); value != nil {
		if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok {
			return clientCapabilities
		}
	}
	// Zero value when no capabilities have been stored yet.
	return mcp.ClientCapabilities{}
}

func (s *stdioSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
	s.clientCapabilities.Store(clientCapabilities)
}

func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
	s.loggingLevel.Store(level)
}

func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
	level := s.loggingLevel.Load()
	if level == nil {
		return mcp.LoggingLevelError
	}
	return level.(mcp.LoggingLevel)
}

// RequestSampling sends a sampling request to the client and waits for the response.
// The response is routed back through pendingRequests by handleSamplingResponse.
// NOTE(review): this writes to the session writer directly, without taking
// StdioServer.writeMu — a concurrent server response could interleave on the
// same stream; confirm against upstream before relying on it under load.
func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
	s.mu.RLock()
	writer := s.writer
	s.mu.RUnlock()

	if writer == nil {
		return nil, fmt.Errorf("no writer available for sending requests")
	}

	// Generate a unique request ID (shared counter with elicitation and
	// list-roots requests, so IDs never collide across request kinds)
	id := s.requestID.Add(1)

	// Create a response channel for this request
	responseChan := make(chan *samplingResponse, 1)
	s.pendingMu.Lock()
	s.pendingRequests[id] = responseChan
	s.pendingMu.Unlock()

	// Cleanup function to remove the pending request
	cleanup := func() {
		s.pendingMu.Lock()
		delete(s.pendingRequests, id)
		s.pendingMu.Unlock()
	}
	defer cleanup()

	// Create the JSON-RPC request
	jsonRPCRequest := struct {
		JSONRPC string                  `json:"jsonrpc"`
		ID      int64                   `json:"id"`
		Method  string                  `json:"method"`
		Params  mcp.CreateMessageParams `json:"params"`
	}{
		JSONRPC: mcp.JSONRPC_VERSION,
		ID:      id,
		Method:  string(mcp.MethodSamplingCreateMessage),
		Params:  request.CreateMessageParams,
	}

	// Marshal and send the request
	requestBytes, err := json.Marshal(jsonRPCRequest)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal sampling request: %w", err)
	}
	requestBytes = append(requestBytes, '\n')

	if _, err := writer.Write(requestBytes); err != nil {
		return nil, fmt.Errorf("failed to write sampling request: %w", err)
	}

	// Wait for the response or context cancellation
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case response := <-responseChan:
		if response.err != nil {
			return nil, response.err
		}
		return response.result, nil
	}
}

// ListRoots sends an list roots request to the client and waits for the response.
// Same pending-map pattern as RequestSampling, routed via pendingRoots.
func (s *stdioSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) {
	s.mu.RLock()
	writer := s.writer
	s.mu.RUnlock()

	if writer == nil {
		return nil, fmt.Errorf("no writer available for sending requests")
	}

	// Generate a unique request ID
	id := s.requestID.Add(1)

	// Create a response channel for this request
	responseChan := make(chan *rootsResponse, 1)
	s.pendingMu.Lock()
	s.pendingRoots[id] = responseChan
	s.pendingMu.Unlock()

	// Cleanup function to remove the pending request
	cleanup := func() {
		s.pendingMu.Lock()
		delete(s.pendingRoots, id)
		s.pendingMu.Unlock()
	}
	defer cleanup()

	// Create the JSON-RPC request (list roots carries no params)
	jsonRPCRequest := struct {
		JSONRPC string `json:"jsonrpc"`
		ID      int64  `json:"id"`
		Method  string `json:"method"`
	}{
		JSONRPC: mcp.JSONRPC_VERSION,
		ID:      id,
		Method:  string(mcp.MethodListRoots),
	}

	// Marshal and send the request
	requestBytes, err := json.Marshal(jsonRPCRequest)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal list roots request: %w", err)
	}
	requestBytes = append(requestBytes, '\n')

	if _, err := writer.Write(requestBytes); err != nil {
		return nil, fmt.Errorf("failed to write list roots request: %w", err)
	}

	// Wait for the response or context cancellation
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case response := <-responseChan:
		if response.err != nil {
			return nil, response.err
		}
		return response.result, nil
	}
}

// RequestElicitation sends an elicitation request to the client and waits for the response.
// Same pending-map pattern as RequestSampling, routed via pendingElicitations.
func (s *stdioSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
	s.mu.RLock()
	writer := s.writer
	s.mu.RUnlock()

	if writer == nil {
		return nil, fmt.Errorf("no writer available for sending requests")
	}

	// Generate a unique request ID
	id := s.requestID.Add(1)

	// Create a response channel for this request
	responseChan := make(chan *elicitationResponse, 1)
	s.pendingMu.Lock()
	s.pendingElicitations[id] = responseChan
	s.pendingMu.Unlock()

	// Cleanup function to remove the pending request
	cleanup := func() {
		s.pendingMu.Lock()
		delete(s.pendingElicitations, id)
		s.pendingMu.Unlock()
	}
	defer cleanup()

	// Create the JSON-RPC request
	jsonRPCRequest := struct {
		JSONRPC string                `json:"jsonrpc"`
		ID      int64                 `json:"id"`
		Method  string                `json:"method"`
		Params  mcp.ElicitationParams `json:"params"`
	}{
		JSONRPC: mcp.JSONRPC_VERSION,
		ID:      id,
		Method:  string(mcp.MethodElicitationCreate),
		Params:  request.Params,
	}

	// Marshal and send the request
	requestBytes, err := json.Marshal(jsonRPCRequest)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal elicitation request: %w", err)
	}
	requestBytes = append(requestBytes, '\n')

	if _, err := writer.Write(requestBytes); err != nil {
		return nil, fmt.Errorf("failed to write elicitation request: %w", err)
	}

	// Wait for the response or context cancellation
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case response := <-responseChan:
		if response.err != nil {
			return nil, response.err
		}
		return response.result, nil
	}
}

// SetWriter sets the writer for sending requests to the client.
func (s *stdioSession) SetWriter(writer io.Writer) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.writer = writer
}

// Compile-time checks that stdioSession implements the optional session
// capabilities the server can take advantage of.
var (
	_ ClientSession          = (*stdioSession)(nil)
	_ SessionWithLogging     = (*stdioSession)(nil)
	_ SessionWithClientInfo  = (*stdioSession)(nil)
	_ SessionWithSampling    = (*stdioSession)(nil)
	_ SessionWithElicitation = (*stdioSession)(nil)
	_ SessionWithRoots       = (*stdioSession)(nil)
)

// stdioSessionInstance is the process-wide singleton session: stdio carries
// exactly one client, so all state lives here.
var stdioSessionInstance = stdioSession{
	notifications:       make(chan mcp.JSONRPCNotification, 100),
	pendingRequests:     make(map[int64]chan *samplingResponse),
	pendingElicitations: make(map[int64]chan *elicitationResponse),
	pendingRoots:        make(map[int64]chan *rootsResponse),
}

// NewStdioServer creates a new stdio server wrapper around an MCPServer.
// It initializes the server with a default error logger that writes to stderr.
func NewStdioServer(server *MCPServer) *StdioServer {
	return &StdioServer{
		server: server,
		errLogger: log.New(
			os.Stderr,
			"",
			log.LstdFlags,
		), // Default error log destination is stderr; override with SetErrorLogger/WithErrorLogger
		workerPoolSize: 5,   // Default worker pool size
		queueSize:      100, // Default queue size
	}
}

// SetErrorLogger configures where error messages from the StdioServer are logged.
// The provided logger will receive all error messages generated during server operation.
func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
	s.errLogger = logger
}

// SetContextFunc sets a function that will be called to customise the context
// to the server. Note that the stdio server uses the same context for all requests,
// so this function will only be called once per server instance.
func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
	s.contextFunc = fn
}

// handleNotifications continuously processes notifications from the session's notification channel
// and writes them to the provided output. It runs until the context is cancelled.
// Any errors encountered while writing notifications are logged but do not stop the handler.
func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
	for {
		select {
		case notification := <-stdioSessionInstance.notifications:
			if err := s.writeResponse(notification, stdout); err != nil {
				s.errLogger.Printf("Error writing notification: %v", err)
			}
		case <-ctx.Done():
			return
		}
	}
}

// processInputStream continuously reads and processes messages from the input stream.
// It handles EOF gracefully as a normal termination condition.
// The function returns when either:
// - The context is cancelled (returns context.Err())
// - EOF is encountered (returns nil)
// - An error occurs while reading or processing messages (returns the error)
func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
	for {
		if err := ctx.Err(); err != nil {
			return err
		}

		line, err := s.readNextLine(ctx, reader)
		if err != nil {
			if err == io.EOF {
				return nil
			}
			s.errLogger.Printf("Error reading input: %v", err)
			return err
		}

		if err := s.processMessage(ctx, line, stdout); err != nil {
			if err == io.EOF {
				return nil
			}
			s.errLogger.Printf("Error handling message: %v", err)
			return err
		}
	}
}

// toolCallWorker processes tool calls from the queue.
// It exits when the queue is closed (graceful shutdown) or ctx is cancelled.
func (s *StdioServer) toolCallWorker(ctx context.Context) {
	defer s.workerWg.Done()

	for {
		select {
		case work, ok := <-s.toolCallQueue:
			if !ok {
				// Channel closed, exit worker
				return
			}
			// Process the tool call using the context captured at enqueue
			// time (work.ctx), not the worker's own ctx.
			response := s.server.HandleMessage(work.ctx, work.message)
			if response != nil {
				if err := s.writeResponse(response, work.writer); err != nil {
					s.errLogger.Printf("Error writing tool response: %v", err)
				}
			}
		case <-ctx.Done():
			return
		}
	}
}

// readNextLine reads a single line from the input reader in a context-aware manner.
// It uses channels to make the read operation cancellable via context.
// Returns the read line and any error encountered.
If the context is cancelled,
// returns an empty string and the context's error. EOF is returned when the input
// stream is closed.
func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) {
	type result struct {
		line string
		err  error
	}

	resultCh := make(chan result, 1)

	go func() {
		line, err := reader.ReadString('\n')
		resultCh <- result{line: line, err: err}
	}()

	select {
	case <-ctx.Done():
		// Fix: previously this returned ("", nil), contradicting the
		// documented contract above and making callers re-derive the
		// cancellation from ctx. Return the context error directly.
		// NOTE: the reader goroutine stays blocked in ReadString until the
		// next line or EOF arrives; its buffered result channel lets it
		// exit without leaking a send.
		return "", ctx.Err()
	case res := <-resultCh:
		return res.line, res.err
	}
}

// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output.
// It runs until the context is cancelled or an error occurs.
// Returns an error if there are issues with reading input or writing output.
func (s *StdioServer) Listen(
	ctx context.Context,
	stdin io.Reader,
	stdout io.Writer,
) error {
	// Initialize the tool call queue (sized via WithQueueSize, default 100).
	s.toolCallQueue = make(chan *toolCallWork, s.queueSize)

	// Set a static client context since stdio only has one client
	if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
		return fmt.Errorf("register session: %w", err)
	}
	defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
	ctx = s.server.WithContext(ctx, &stdioSessionInstance)

	// Set the writer for sending requests to the client
	stdioSessionInstance.SetWriter(stdout)

	// Add in any custom context.
	if s.contextFunc != nil {
		ctx = s.contextFunc(ctx)
	}

	reader := bufio.NewReader(stdin)

	// Start worker pool for tool calls
	for i := 0; i < s.workerPoolSize; i++ {
		s.workerWg.Add(1)
		go s.toolCallWorker(ctx)
	}

	// Start notification handler
	go s.handleNotifications(ctx, stdout)

	// Process input stream until EOF, cancellation, or error
	err := s.processInputStream(ctx, reader, stdout)

	// Shutdown workers gracefully: closing the queue lets in-flight tool
	// calls drain before Listen returns.
	close(s.toolCallQueue)
	s.workerWg.Wait()

	return err
}

// processMessage handles a single JSON-RPC message and writes the response.
// It parses the message, processes it through the wrapped MCPServer, and writes any response.
// Responses to server-initiated requests (sampling/elicitation/roots) are
// routed to their pending waiters instead; tool calls are queued for the
// worker pool. Returns an error if message processing or response writing fails.
func (s *StdioServer) processMessage(
	ctx context.Context,
	line string,
	writer io.Writer,
) error {
	// Empty lines carry no message; skip them.
	if len(line) == 0 {
		return nil
	}

	// Parse the message as raw JSON
	var rawMessage json.RawMessage
	if err := json.Unmarshal([]byte(line), &rawMessage); err != nil {
		response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error")
		return s.writeResponse(response, writer)
	}

	// Route responses to server-initiated requests back to their waiters.
	if s.handleSamplingResponse(rawMessage) {
		return nil
	}
	if s.handleElicitationResponse(rawMessage) {
		return nil
	}
	if s.handleListRootsResponse(rawMessage) {
		return nil
	}

	// Tool calls may themselves trigger sampling, so they must not block the
	// input loop; hand them to the worker pool.
	var baseMessage struct {
		Method string `json:"method"`
	}
	if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
		select {
		case s.toolCallQueue <- &toolCallWork{
			ctx:     ctx,
			message: rawMessage,
			writer:  writer,
		}:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		default:
			// Queue is full, process synchronously as fallback
			s.errLogger.Printf("Tool call queue full, processing synchronously")
			response := s.server.HandleMessage(ctx, rawMessage)
			if response != nil {
				return s.writeResponse(response, writer)
			}
			return nil
		}
	}

	// Handle other messages synchronously
	response := s.server.HandleMessage(ctx, rawMessage)

	// Only write response if there is one (not for notifications)
	if response != nil {
		if err := s.writeResponse(response, writer); err != nil {
			return fmt.Errorf("failed to write response: %w", err)
		}
	}

	return nil
}

// clientResponse is the JSON-RPC response envelope the client sends back for
// server-initiated requests (sampling, elicitation, list roots). It was
// previously declared inline, verbatim, in all three handlers.
type clientResponse struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      json.Number     `json:"id"`
	Result  json.RawMessage `json:"result,omitempty"`
	Error   *struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	} `json:"error,omitempty"`
}

// decodeClientResponse parses rawMessage as a client response envelope and
// returns its integer ID. ok is false when the message is not such a response
// (malformed JSON, non-integer id, or neither result nor error present), in
// which case callers fall through to normal request handling.
func decodeClientResponse(rawMessage json.RawMessage) (clientResponse, int64, bool) {
	var resp clientResponse
	if err := json.Unmarshal(rawMessage, &resp); err != nil {
		return resp, 0, false
	}
	id, err := resp.ID.Int64()
	if err != nil || (resp.Result == nil && resp.Error == nil) {
		return resp, 0, false
	}
	return resp, id, true
}

// handleSamplingResponse checks if the message is a response to a sampling request
// and routes it to the appropriate pending request channel.
func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool {
	return stdioSessionInstance.handleSamplingResponse(rawMessage)
}

// handleSamplingResponse handles incoming sampling responses for this session.
// Returns true when the message was consumed as a sampling response.
func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
	resp, id, ok := decodeClientResponse(rawMessage)
	if !ok {
		return false
	}

	// Look for a pending sampling request with this ID
	s.pendingMu.RLock()
	responseChan, exists := s.pendingRequests[id]
	s.pendingMu.RUnlock()
	if !exists {
		return false
	}

	samplingResp := &samplingResponse{}
	if resp.Error != nil {
		samplingResp.err = fmt.Errorf("sampling request failed: %s", resp.Error.Message)
	} else {
		var result mcp.CreateMessageResult
		if err := json.Unmarshal(resp.Result, &result); err != nil {
			samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
		} else if contentMap, ok := result.Content.(map[string]any); ok {
			// Content arrives as a generic map; convert it to the concrete
			// Content type (TextContent, ImageContent, AudioContent).
			content, err := mcp.ParseContent(contentMap)
			if err != nil {
				samplingResp.err = fmt.Errorf("failed to parse sampling response content: %w", err)
			} else {
				result.Content = content
				samplingResp.result = &result
			}
		} else {
			samplingResp.result = &result
		}
	}

	// Deliver without blocking; the waiter's channel is buffered (cap 1),
	// and a vanished waiter (timed out and cleaned up) is simply ignored.
	select {
	case responseChan <- samplingResp:
	default:
	}
	return true
}

// handleElicitationResponse checks if the message is a response to an elicitation request
// and routes it to the appropriate pending request channel.
func (s *StdioServer) handleElicitationResponse(rawMessage json.RawMessage) bool {
	return stdioSessionInstance.handleElicitationResponse(rawMessage)
}

// handleElicitationResponse handles incoming elicitation responses for this session.
// Returns true when the message was consumed as an elicitation response.
func (s *stdioSession) handleElicitationResponse(rawMessage json.RawMessage) bool {
	resp, id, ok := decodeClientResponse(rawMessage)
	if !ok {
		return false
	}

	// Check if we have a pending elicitation request with this ID
	s.pendingMu.RLock()
	responseChan, exists := s.pendingElicitations[id]
	s.pendingMu.RUnlock()
	if !exists {
		return false
	}

	elicitationResp := &elicitationResponse{}
	if resp.Error != nil {
		elicitationResp.err = fmt.Errorf("elicitation request failed: %s", resp.Error.Message)
	} else {
		var result mcp.ElicitationResult
		if err := json.Unmarshal(resp.Result, &result); err != nil {
			elicitationResp.err = fmt.Errorf("failed to unmarshal elicitation response: %w", err)
		} else {
			elicitationResp.result = &result
		}
	}

	// Deliver without blocking (see handleSamplingResponse).
	select {
	case responseChan <- elicitationResp:
	default:
	}
	return true
}

// handleListRootsResponse checks if the message is a response to an list roots request
// and routes it to the appropriate pending request channel.
func (s *StdioServer) handleListRootsResponse(rawMessage json.RawMessage) bool {
	return stdioSessionInstance.handleListRootsResponse(rawMessage)
}

// handleListRootsResponse handles incoming list root responses for this session.
// Returns true when the message was consumed as a list-roots response.
func (s *stdioSession) handleListRootsResponse(rawMessage json.RawMessage) bool {
	resp, id, ok := decodeClientResponse(rawMessage)
	if !ok {
		return false
	}

	// Check if we have a pending list root request with this ID
	s.pendingMu.RLock()
	responseChan, exists := s.pendingRoots[id]
	s.pendingMu.RUnlock()
	if !exists {
		return false
	}

	rootsResp := &rootsResponse{}
	if resp.Error != nil {
		rootsResp.err = fmt.Errorf("list root request failed: %s", resp.Error.Message)
	} else {
		var result mcp.ListRootsResult
		if err := json.Unmarshal(resp.Result, &result); err != nil {
			rootsResp.err = fmt.Errorf("failed to unmarshal list root response: %w", err)
		} else {
			rootsResp.result = &result
		}
	}

	// Deliver without blocking (see handleSamplingResponse).
	select {
	case responseChan <- rootsResp:
	default:
	}
	return true
}

// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
// Returns an error if marshaling or writing fails.
func (s *StdioServer) writeResponse(
	response mcp.JSONRPCMessage,
	writer io.Writer,
) error {
	responseBytes, err := json.Marshal(response)
	if err != nil {
		return err
	}

	// writeMu serializes writes from the worker pool, the notification pump
	// and the input loop so interleaved lines cannot corrupt the stream.
	s.writeMu.Lock()
	defer s.writeMu.Unlock()

	// Write response followed by newline (one JSON document per line)
	if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil {
		return err
	}

	return nil
}

// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout.
// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT.
// Returns an error if the server encounters any issues during operation.
func ServeStdio(server *MCPServer, opts ...StdioOption) error {
	s := NewStdioServer(server)

	// Apply options after construction so they can override the defaults
	// (error logger, worker pool size, queue size, context func).
	for _, opt := range opts {
		opt(s)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Set up signal handling: SIGTERM/SIGINT cancel the context, which makes
	// Listen unwind and drain its workers.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)

	go func() {
		<-sigChan
		cancel()
	}()

	return s.Listen(ctx, os.Stdin, os.Stdout)
}
diff --git a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go
new file mode 100644
index 000000000..4535943da
--- /dev/null
+++ b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go
@@ -0,0 +1,1434 @@
package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"maps"
	"mime"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"sync"
	"sync/atomic"
	"time"
	"unicode"

	"github.com/google/uuid"
	"github.com/mark3labs/mcp-go/mcp"
	"github.com/mark3labs/mcp-go/util"
)

// StreamableHTTPOption defines a function type for configuring StreamableHTTPServer
type StreamableHTTPOption func(*StreamableHTTPServer)

// WithEndpointPath sets the endpoint path for the server.
// The default is "/mcp".
// It only works for the `Start` method. When used as a http.Handler, it has no effect.
func WithEndpointPath(endpointPath string) StreamableHTTPOption {
	return func(s *StreamableHTTPServer) {
		// Normalize the endpoint path to ensure it starts with a slash and doesn't end with one
		// (an empty input normalizes to "/").
		normalizedPath := "/" + strings.Trim(endpointPath, "/")
		s.endpointPath = normalizedPath
	}
}

// WithStateLess sets the server to stateless mode.
// If true, the server will manage no session information. Every request will be treated
// as a new session. No session id returned to the client.
// The default is false.
//
// Note: This is a convenience method.
It's identical to set WithSessionIdManager option
// to StatelessSessionIdManager.
func WithStateLess(stateLess bool) StreamableHTTPOption {
	return func(s *StreamableHTTPServer) {
		if !stateLess {
			// Leave whatever manager/resolver is already configured.
			return
		}
		s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{})
	}
}

// WithSessionIdManager sets a custom session id generator for the server.
// By default, the server uses StatelessGeneratingSessionIdManager (generates IDs but no local validation).
// A nil manager falls back to the stateless implementation.
// Note: Options are applied in order; the last one wins. If combined with
// WithStateLess or WithSessionIdManagerResolver, whichever is applied last takes effect.
func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
	return func(s *StreamableHTTPServer) {
		chosen := manager
		if chosen == nil {
			chosen = &StatelessSessionIdManager{}
		}
		s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(chosen)
	}
}

// WithSessionIdManagerResolver sets a custom session id manager resolver for the server.
// This allows for request-based session id management strategies.
// A nil resolver falls back to a stateless default resolver.
// Note: Options are applied in order; the last one wins. If combined with
// WithStateLess or WithSessionIdManager, whichever is applied last takes effect.
func WithSessionIdManagerResolver(resolver SessionIdManagerResolver) StreamableHTTPOption {
	return func(s *StreamableHTTPServer) {
		if resolver != nil {
			s.sessionIdManagerResolver = resolver
			return
		}
		s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{})
	}
}

// WithStateful enables stateful session management using InsecureStatefulSessionIdManager.
// This requires sticky sessions in multi-instance deployments.
+func WithStateful(stateful bool) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + if stateful { + s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}) + } + } +} + +// WithHeartbeatInterval sets the heartbeat interval. Positive interval means the +// server will send a heartbeat to the client through the GET connection, to keep +// the connection alive from being closed by the network infrastructure (e.g. +// gateways). If the client does not establish a GET connection, it has no +// effect. The default is not to send heartbeats. +func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.listenHeartbeatInterval = interval + } +} + +// WithDisableStreaming prevents the server from responding to GET requests with +// a streaming response. Instead, it will respond with a 405 Method Not Allowed status. +// This can be useful in scenarios where streaming is not desired or supported. +// The default is false, meaning streaming is enabled. +func WithDisableStreaming(disable bool) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.disableStreaming = disable + } +} + +// WithHTTPContextFunc sets a function that will be called to customise the context +// to the server using the incoming request. +// This can be used to inject context values from headers, for example. +func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.contextFunc = fn + } +} + +// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer. +// NOTE: When providing a custom HTTP server, you must handle routing yourself +// If routing is not set up, the server will start but won't handle any MCP requests. 
+func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.httpServer = srv + } +} + +// WithLogger sets the logger for the server +func WithLogger(logger util.Logger) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.logger = logger + } +} + +// WithTLSCert sets the TLS certificate and key files for HTTPS support. +// Both certFile and keyFile must be provided to enable TLS. +func WithTLSCert(certFile, keyFile string) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.tlsCertFile = certFile + s.tlsKeyFile = keyFile + } +} + +// StreamableHTTPServer implements a Streamable-http based MCP server. +// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams. +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http +// +// Usage: +// +// server := NewStreamableHTTPServer(mcpServer) +// server.Start(":8080") // The final url for client is http://xxxx:8080/mcp by default +// +// or the server itself can be used as a http.Handler, which is convenient to +// integrate with existing http servers, or advanced usage: +// +// handler := NewStreamableHTTPServer(mcpServer) +// http.Handle("/streamable-http", handler) +// http.ListenAndServe(":8080", nil) +// +// Notice: +// Except for the GET handlers(listening), the POST handlers(request/notification) will +// not trigger the session registration. So the methods like `SendNotificationToSpecificClient` +// or `hooks.onRegisterSession` will not be triggered for POST messages. 
+// +// The current implementation does not support the following features from the specification: +// - Stream Resumability +type StreamableHTTPServer struct { + server *MCPServer + sessionTools *sessionToolsStore + sessionResources *sessionResourcesStore + sessionResourceTemplates *sessionResourceTemplatesStore + sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64) + activeSessions sync.Map // sessionId --> *streamableHttpSession (for sampling responses) + + httpServer *http.Server + mu sync.RWMutex + + endpointPath string + contextFunc HTTPContextFunc + sessionIdManagerResolver SessionIdManagerResolver + listenHeartbeatInterval time.Duration + logger util.Logger + sessionLogLevels *sessionLogLevelsStore + disableStreaming bool + + tlsCertFile string + tlsKeyFile string +} + +// NewStreamableHTTPServer creates a new streamable-http server instance +func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer { + s := &StreamableHTTPServer{ + server: server, + sessionTools: newSessionToolsStore(), + sessionLogLevels: newSessionLogLevelsStore(), + endpointPath: "/mcp", + sessionIdManagerResolver: NewDefaultSessionIdManagerResolver(&StatelessGeneratingSessionIdManager{}), + logger: util.DefaultLogger(), + sessionResources: newSessionResourcesStore(), + sessionResourceTemplates: newSessionResourceTemplatesStore(), + } + + // Apply all options + for _, opt := range opts { + opt(s) + } + return s +} + +// ServeHTTP implements the http.Handler interface. +func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + s.handlePost(w, r) + case http.MethodGet: + s.handleGet(w, r) + case http.MethodDelete: + s.handleDelete(w, r) + default: + http.NotFound(w, r) + } +} + +// Start begins serving the http server on the specified address and path +// (endpointPath). 
like: +// +// s.Start(":8080") +func (s *StreamableHTTPServer) Start(addr string) error { + s.mu.Lock() + if s.httpServer == nil { + mux := http.NewServeMux() + mux.Handle(s.endpointPath, s) + s.httpServer = &http.Server{ + Addr: addr, + Handler: mux, + } + } else { + if s.httpServer.Addr == "" { + s.httpServer.Addr = addr + } else if s.httpServer.Addr != addr { + return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr) + } + } + srv := s.httpServer + s.mu.Unlock() + + if s.tlsCertFile != "" || s.tlsKeyFile != "" { + if s.tlsCertFile == "" || s.tlsKeyFile == "" { + return fmt.Errorf("both TLS cert and key must be provided") + } + if _, err := os.Stat(s.tlsCertFile); err != nil { + return fmt.Errorf("failed to find TLS certificate file: %w", err) + } + if _, err := os.Stat(s.tlsKeyFile); err != nil { + return fmt.Errorf("failed to find TLS key file: %w", err) + } + return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile) + } + + return srv.ListenAndServe() +} + +// Shutdown gracefully stops the server, closing all active sessions +// and shutting down the HTTP server. 
+func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { + + // shutdown the server if needed (may use as a http.Handler) + s.mu.RLock() + srv := s.httpServer + s.mu.RUnlock() + if srv != nil { + return srv.Shutdown(ctx) + } + return nil +} + +// --- internal methods --- + +func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { + // post request carry request/notification message + + // Check content type + contentType := r.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil || mediaType != "application/json" { + http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest) + return + } + + // Check the request body is valid json, meanwhile, get the request Method + rawData, err := io.ReadAll(r.Body) + if err != nil { + s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err)) + return + } + // First, try to parse as a response (sampling responses don't have a method field) + var jsonMessage struct { + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Method mcp.MCPMethod `json:"method,omitempty"` + } + if err := json.Unmarshal(rawData, &jsonMessage); err != nil { + s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json") + return + } + + // detect empty ping response, skip session ID validation + isPingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && + (isJSONEmpty(jsonMessage.Result) && isJSONEmpty(jsonMessage.Error)) + + if isPingResponse { + return + } + + // Check if this is a sampling response (has result/error but no method) + isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && + (jsonMessage.Result != nil || jsonMessage.Error != nil) + + isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize + + // Handle sampling responses separately + if 
isSamplingResponse { + if err := s.handleSamplingResponse(w, r, jsonMessage); err != nil { + s.logger.Errorf("Failed to handle sampling response: %v", err) + http.Error(w, "Failed to handle sampling response", http.StatusInternalServerError) + } + return + } + + // Prepare the session for the mcp server + // The session is ephemeral. Its life is the same as the request. It's only created + // for interaction with the mcp server. + var sessionID string + sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r) + if isInitializeRequest { + // generate a new one for initialize request + sessionID = sessionIdManager.Generate() + } else { + // Get session ID from header. + // Stateful servers need the client to carry the session ID. + sessionID = r.Header.Get(HeaderKeySessionID) + isTerminated, err := sessionIdManager.Validate(sessionID) + if err != nil { + http.Error(w, "Invalid session ID", http.StatusBadRequest) + return + } + if isTerminated { + http.Error(w, "Session terminated", http.StatusNotFound) + return + } + } + + // For non-initialize requests, try to reuse existing registered session + var session *streamableHttpSession + if !isInitializeRequest { + if sessionValue, ok := s.server.sessions.Load(sessionID); ok { + if existingSession, ok := sessionValue.(*streamableHttpSession); ok { + session = existingSession + } + } + } + + // Check if a persistent session exists (for sampling support), otherwise create ephemeral session + // Persistent sessions are created by GET (continuous listening) connections + if session == nil { + if sessionInterface, exists := s.activeSessions.Load(sessionID); exists { + if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok { + session = persistentSession + } + } + } + + // Create ephemeral session if no persistent session exists + if session == nil { + session = newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionResourceTemplates, s.sessionLogLevels) + } + + // 
Set the client context before handling the message + ctx := s.server.WithContext(r.Context(), session) + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + // handle potential notifications + mu := sync.Mutex{} + upgradedHeader := false + done := make(chan struct{}) + + ctx = context.WithValue(ctx, requestHeader, r.Header) + go func() { + for { + select { + case nt := <-session.notificationChannel: + func() { + mu.Lock() + defer mu.Unlock() + // if the done chan is closed, as the request is terminated, just return + select { + case <-done: + return + default: + } + defer func() { + flusher, ok := w.(http.Flusher) + if ok { + flusher.Flush() + } + }() + + // if there's notifications, upgradedHeader to SSE response + if !upgradedHeader { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + upgradedHeader = true + } + err := writeSSEEvent(w, nt) + if err != nil { + s.logger.Errorf("Failed to write SSE event: %v", err) + return + } + }() + case <-done: + return + case <-ctx.Done(): + return + } + } + }() + + // Process message through MCPServer + response := s.server.HandleMessage(ctx, rawData) + if response == nil { + // For notifications, just send 202 Accepted with no body + w.WriteHeader(http.StatusAccepted) + return + } + + // Write response + mu.Lock() + defer mu.Unlock() + // close the done chan before unlock + defer close(done) + if ctx.Err() != nil { + return + } + // If client-server communication already upgraded to SSE stream + if session.upgradeToSSE.Load() { + if !upgradedHeader { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + upgradedHeader = true + } + if err := writeSSEEvent(w, response); err != nil { + s.logger.Errorf("Failed to write final SSE response event: %v", err) + } + } else { + 
w.Header().Set("Content-Type", "application/json") + if isInitializeRequest && sessionID != "" { + // send the session ID back to the client + w.Header().Set(HeaderKeySessionID, sessionID) + } + w.WriteHeader(http.StatusOK) + err := json.NewEncoder(w).Encode(response) + if err != nil { + s.logger.Errorf("Failed to write response: %v", err) + } + } + + // Register session after successful initialization + // Only register if not already registered (e.g., by a GET connection) + if isInitializeRequest && sessionID != "" { + if _, exists := s.server.sessions.Load(sessionID); !exists { + // Store in activeSessions to prevent duplicate registration from GET + s.activeSessions.Store(sessionID, session) + // Register the session with the MCPServer for notification support + if err := s.server.RegisterSession(ctx, session); err != nil { + s.logger.Errorf("Failed to register POST session: %v", err) + s.activeSessions.Delete(sessionID) + // Don't fail the request, just log the error + } + } + } +} + +func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) { + // get request is for listening to notifications + // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server + if s.disableStreaming { + s.logger.Infof("Rejected GET request: streaming is disabled (session: %s)", r.Header.Get(HeaderKeySessionID)) + http.Error(w, "Streaming is disabled on this server", http.StatusMethodNotAllowed) + return + } + + sessionID := r.Header.Get(HeaderKeySessionID) + // the specification didn't say we should validate the session id + + if sessionID == "" { + // It's a stateless server, + // but the MCP server requires a unique ID for registering, so we use a random one + sessionID = uuid.New().String() + } + + // Get or create session atomically to prevent TOCTOU races + // where concurrent GETs could both create and register duplicate sessions + var session *streamableHttpSession + newSession := 
newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionResourceTemplates, s.sessionLogLevels) + actual, loaded := s.activeSessions.LoadOrStore(sessionID, newSession) + session = actual.(*streamableHttpSession) + + if !loaded { + // We created a new session, need to register it + if err := s.server.RegisterSession(r.Context(), session); err != nil { + s.activeSessions.Delete(sessionID) + http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) + return + } + defer s.server.UnregisterSession(r.Context(), sessionID) + defer s.activeSessions.Delete(sessionID) + } + + // Set the client context before handling the message + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + flusher.Flush() + + // Start notification handler for this session + done := make(chan struct{}) + defer close(done) + writeChan := make(chan any, 16) + + go func() { + for { + select { + case nt := <-session.notificationChannel: + select { + case writeChan <- &nt: + case <-done: + return + } + case samplingReq := <-session.samplingRequestChan: + // Send sampling request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(samplingReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + Params: samplingReq.request.CreateMessageParams, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } + case elicitationReq := <-session.elicitationRequestChan: + // Send elicitation request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(elicitationReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodElicitationCreate), + }, + 
Params: elicitationReq.request.Params, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } + case rootsReq := <-session.rootsRequestChan: + // Send list roots request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(rootsReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodListRoots), + }, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } + case <-done: + return + } + } + }() + + if s.listenHeartbeatInterval > 0 { + // heartbeat to keep the connection alive + go func() { + ticker := time.NewTicker(s.listenHeartbeatInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + message := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(s.nextRequestID(sessionID)), + Request: mcp.Request{ + Method: "ping", + }, + } + select { + case writeChan <- message: + case <-done: + return + } + case <-done: + return + } + } + }() + } + + // Keep the connection open until the client disconnects + // + // There's will a Available() check when handler ends, and it maybe race with Flush(), + // so we use a separate channel to send the data, inteading of flushing directly in other goroutine. 
+ for { + select { + case data := <-writeChan: + if data == nil { + continue + } + if err := writeSSEEvent(w, data); err != nil { + s.logger.Errorf("Failed to write SSE event: %v", err) + return + } + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} + +func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) { + // delete request terminate the session + sessionID := r.Header.Get(HeaderKeySessionID) + sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r) + notAllowed, err := sessionIdManager.Terminate(sessionID) + if err != nil { + http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError) + return + } + if notAllowed { + http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed) + return + } + + // remove the session relateddata from the sessionToolsStore + s.sessionTools.delete(sessionID) + s.sessionResources.delete(sessionID) + s.sessionResourceTemplates.delete(sessionID) + s.sessionLogLevels.delete(sessionID) + // remove current session's requstID information + s.sessionRequestIDs.Delete(sessionID) + + w.WriteHeader(http.StatusOK) +} + +func writeSSEEvent(w io.Writer, data any) error { + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + _, err = fmt.Fprintf(w, "event: message\ndata: %s\n\n", jsonData) + if err != nil { + return fmt.Errorf("failed to write SSE event: %w", err) + } + return nil +} + +// handleSamplingResponse processes incoming sampling responses from clients +func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *http.Request, responseMessage struct { + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Method mcp.MCPMethod `json:"method,omitempty"` +}) error { + // Get session ID from header + sessionID := r.Header.Get(HeaderKeySessionID) + if sessionID == 
"" { + http.Error(w, "Missing session ID for sampling response", http.StatusBadRequest) + return fmt.Errorf("missing session ID") + } + + // Validate session + sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r) + isTerminated, err := sessionIdManager.Validate(sessionID) + if err != nil { + http.Error(w, "Invalid session ID", http.StatusBadRequest) + return err + } + if isTerminated { + http.Error(w, "Session terminated", http.StatusNotFound) + return fmt.Errorf("session terminated") + } + + // Parse the request ID + var requestID int64 + if err := json.Unmarshal(responseMessage.ID, &requestID); err != nil { + http.Error(w, "Invalid request ID in sampling response", http.StatusBadRequest) + return err + } + + // Create the sampling response item + response := samplingResponseItem{ + requestID: requestID, + } + + // Parse result or error + if responseMessage.Error != nil { + // Parse error + var jsonrpcError struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(responseMessage.Error, &jsonrpcError); err != nil { + response.err = fmt.Errorf("failed to parse error: %v", err) + } else { + response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message) + } + } else if responseMessage.Result != nil { + // Store the result to be unmarshaled later + response.result = responseMessage.Result + } else { + response.err = fmt.Errorf("sampling response has neither result nor error") + } + + // Find the corresponding session and deliver the response + // The response is delivered to the specific session identified by sessionID + if err := s.deliverSamplingResponse(sessionID, response); err != nil { + s.logger.Errorf("Failed to deliver sampling response: %v", err) + http.Error(w, "Failed to deliver response", http.StatusInternalServerError) + return err + } + + // Acknowledge receipt + w.WriteHeader(http.StatusOK) + return nil +} + +// deliverSamplingResponse delivers a sampling response 
to the appropriate session +func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) error { + // Look up the active session + sessionInterface, ok := s.activeSessions.Load(sessionID) + if !ok { + return fmt.Errorf("no active session found for session %s", sessionID) + } + + session, ok := sessionInterface.(*streamableHttpSession) + if !ok { + return fmt.Errorf("invalid session type for session %s", sessionID) + } + + // Look up the dedicated response channel for this specific request + responseChannelInterface, exists := session.samplingRequests.Load(response.requestID) + if !exists { + return fmt.Errorf("no pending request found for session %s, request %d", sessionID, response.requestID) + } + + responseChan, ok := responseChannelInterface.(chan samplingResponseItem) + if !ok { + return fmt.Errorf("invalid response channel type for session %s, request %d", sessionID, response.requestID) + } + + // Attempt to deliver the response with timeout to prevent indefinite blocking + select { + case responseChan <- response: + s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID) + return nil + default: + return fmt.Errorf("failed to deliver sampling response for session %s, request %d: channel full or blocked", sessionID, response.requestID) + } +} + +// writeJSONRPCError writes a JSON-RPC error response with the given error details. 
+func (s *StreamableHTTPServer) writeJSONRPCError( + w http.ResponseWriter, + id any, + code int, + message string, +) { + response := createErrorResponse(id, code, message) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + err := json.NewEncoder(w).Encode(response) + if err != nil { + s.logger.Errorf("Failed to write JSONRPCError: %v", err) + } +} + +// nextRequestID gets the next incrementing requestID for the current session +func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 { + actual, _ := s.sessionRequestIDs.LoadOrStore(sessionID, new(atomic.Int64)) + counter := actual.(*atomic.Int64) + return counter.Add(1) +} + +// --- session --- +type sessionLogLevelsStore struct { + mu sync.RWMutex + logs map[string]mcp.LoggingLevel +} + +func newSessionLogLevelsStore() *sessionLogLevelsStore { + return &sessionLogLevelsStore{ + logs: make(map[string]mcp.LoggingLevel), + } +} + +func (s *sessionLogLevelsStore) get(sessionID string) mcp.LoggingLevel { + s.mu.RLock() + defer s.mu.RUnlock() + val, ok := s.logs[sessionID] + if !ok { + return mcp.LoggingLevelError + } + return val +} + +func (s *sessionLogLevelsStore) set(sessionID string, level mcp.LoggingLevel) { + s.mu.Lock() + defer s.mu.Unlock() + s.logs[sessionID] = level +} + +func (s *sessionLogLevelsStore) delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.logs, sessionID) +} + +type sessionResourcesStore struct { + mu sync.RWMutex + resources map[string]map[string]ServerResource // sessionID -> resourceURI -> resource +} + +func newSessionResourcesStore() *sessionResourcesStore { + return &sessionResourcesStore{ + resources: make(map[string]map[string]ServerResource), + } +} + +func (s *sessionResourcesStore) get(sessionID string) map[string]ServerResource { + s.mu.RLock() + defer s.mu.RUnlock() + cloned := make(map[string]ServerResource, len(s.resources[sessionID])) + maps.Copy(cloned, s.resources[sessionID]) + return cloned +} + 
+func (s *sessionResourcesStore) set(sessionID string, resources map[string]ServerResource) { + s.mu.Lock() + defer s.mu.Unlock() + cloned := make(map[string]ServerResource, len(resources)) + maps.Copy(cloned, resources) + s.resources[sessionID] = cloned +} + +func (s *sessionResourcesStore) delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.resources, sessionID) +} + +type sessionResourceTemplatesStore struct { + mu sync.RWMutex + templates map[string]map[string]ServerResourceTemplate // sessionID -> uriTemplate -> template +} + +func newSessionResourceTemplatesStore() *sessionResourceTemplatesStore { + return &sessionResourceTemplatesStore{ + templates: make(map[string]map[string]ServerResourceTemplate), + } +} + +func (s *sessionResourceTemplatesStore) get(sessionID string) map[string]ServerResourceTemplate { + s.mu.RLock() + defer s.mu.RUnlock() + cloned := make(map[string]ServerResourceTemplate, len(s.templates[sessionID])) + maps.Copy(cloned, s.templates[sessionID]) + return cloned +} + +func (s *sessionResourceTemplatesStore) set(sessionID string, templates map[string]ServerResourceTemplate) { + s.mu.Lock() + defer s.mu.Unlock() + cloned := make(map[string]ServerResourceTemplate, len(templates)) + maps.Copy(cloned, templates) + s.templates[sessionID] = cloned +} + +func (s *sessionResourceTemplatesStore) delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.templates, sessionID) +} + +type sessionToolsStore struct { + mu sync.RWMutex + tools map[string]map[string]ServerTool // sessionID -> toolName -> tool +} + +func newSessionToolsStore() *sessionToolsStore { + return &sessionToolsStore{ + tools: make(map[string]map[string]ServerTool), + } +} + +func (s *sessionToolsStore) get(sessionID string) map[string]ServerTool { + s.mu.RLock() + defer s.mu.RUnlock() + cloned := make(map[string]ServerTool, len(s.tools[sessionID])) + maps.Copy(cloned, s.tools[sessionID]) + return cloned +} + +func (s *sessionToolsStore) 
set(sessionID string, tools map[string]ServerTool) { + s.mu.Lock() + defer s.mu.Unlock() + cloned := make(map[string]ServerTool, len(tools)) + maps.Copy(cloned, tools) + s.tools[sessionID] = cloned +} + +func (s *sessionToolsStore) delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.tools, sessionID) +} + +// Sampling support types for HTTP transport +type samplingRequestItem struct { + requestID int64 + request mcp.CreateMessageRequest + response chan samplingResponseItem +} + +type samplingResponseItem struct { + requestID int64 + result json.RawMessage + err error +} + +// Elicitation support types for HTTP transport +type elicitationRequestItem struct { + requestID int64 + request mcp.ElicitationRequest + response chan samplingResponseItem +} + +// Roots support types for HTTP transport +type rootsRequestItem struct { + requestID int64 + request mcp.ListRootsRequest + response chan samplingResponseItem +} + +// streamableHttpSession is a session for streamable-http transport +// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. +// When in GET handlers(listening), it's a real session, and will be registered in the MCP server. 
+type streamableHttpSession struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification // server -> client notifications + tools *sessionToolsStore + resources *sessionResourcesStore + resourceTemplates *sessionResourceTemplatesStore + upgradeToSSE atomic.Bool + logLevels *sessionLogLevelsStore + clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities + + // Sampling support for bidirectional communication + samplingRequestChan chan samplingRequestItem // server -> client sampling requests + elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests + rootsRequestChan chan rootsRequestItem // server -> client list roots requests + + samplingRequests sync.Map // requestID -> pending sampling request context + requestIDCounter atomic.Int64 // for generating unique request IDs +} + +func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, resourcesStore *sessionResourcesStore, templatesStore *sessionResourceTemplatesStore, levels *sessionLogLevelsStore) *streamableHttpSession { + s := &streamableHttpSession{ + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + resources: resourcesStore, + resourceTemplates: templatesStore, + logLevels: levels, + samplingRequestChan: make(chan samplingRequestItem, 10), + elicitationRequestChan: make(chan elicitationRequestItem, 10), + rootsRequestChan: make(chan rootsRequestItem, 10), + } + return s +} + +func (s *streamableHttpSession) SessionID() string { + return s.sessionID +} + +func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notificationChannel +} + +func (s *streamableHttpSession) Initialize() { + // do nothing + // the session is ephemeral, no real initialized action needed +} + +func (s *streamableHttpSession) Initialized() bool { + // the session is ephemeral, 
no real initialized action needed + return true +} + +func (s *streamableHttpSession) SetLogLevel(level mcp.LoggingLevel) { + s.logLevels.set(s.sessionID, level) +} + +func (s *streamableHttpSession) GetLogLevel() mcp.LoggingLevel { + return s.logLevels.get(s.sessionID) +} + +var _ ClientSession = (*streamableHttpSession)(nil) + +func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool { + return s.tools.get(s.sessionID) +} + +func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) { + s.tools.set(s.sessionID, tools) +} + +func (s *streamableHttpSession) GetSessionResources() map[string]ServerResource { + return s.resources.get(s.sessionID) +} + +func (s *streamableHttpSession) SetSessionResources(resources map[string]ServerResource) { + s.resources.set(s.sessionID, resources) +} + +func (s *streamableHttpSession) GetSessionResourceTemplates() map[string]ServerResourceTemplate { + return s.resourceTemplates.get(s.sessionID) +} + +func (s *streamableHttpSession) SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) { + s.resourceTemplates.set(s.sessionID, templates) +} + +func (s *streamableHttpSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *streamableHttpSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *streamableHttpSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *streamableHttpSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + +var ( + _ SessionWithTools = (*streamableHttpSession)(nil) + _ 
SessionWithResources = (*streamableHttpSession)(nil) + _ SessionWithResourceTemplates = (*streamableHttpSession)(nil) + _ SessionWithLogging = (*streamableHttpSession)(nil) + _ SessionWithClientInfo = (*streamableHttpSession)(nil) +) + +func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { + s.upgradeToSSE.Store(true) +} + +var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) + +// RequestSampling implements SessionWithSampling interface for HTTP transport +func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Generate unique request ID + requestID := s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan samplingResponseItem, 1) + + // Create the sampling request item + samplingRequest := samplingRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.samplingRequests.Store(requestID, responseChan) + defer s.samplingRequests.Delete(requestID) + + // Send the sampling request via the channel (non-blocking) + select { + case s.samplingRequestChan <- samplingRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, fmt.Errorf("sampling request queue is full - server overloaded") + } + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + var result mcp.CreateMessageResult + if err := json.Unmarshal(response.result, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal sampling response: %v", err) + } + + // Parse content from map[string]any to proper Content type (TextContent, ImageContent, AudioContent) + // HTTP transport unmarshals Content as map[string]any, we need to convert it to the proper type + if contentMap, ok := result.Content.(map[string]any); 
ok { + content, err := mcp.ParseContent(contentMap) + if err != nil { + return nil, fmt.Errorf("failed to parse sampling response content: %w", err) + } + result.Content = content + } + + return &result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// ListRoots implements SessionWithRoots interface for HTTP transport. +// It sends a list roots request to the client via SSE and waits for the response. +func (s *streamableHttpSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { + // Generate unique request ID + requestID := s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan samplingResponseItem, 1) + + // Create the roots request item + rootsRequest := rootsRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.samplingRequests.Store(requestID, responseChan) + defer s.samplingRequests.Delete(requestID) + + // Send the list roots request via the channel (non-blocking) + select { + case s.rootsRequestChan <- rootsRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, fmt.Errorf("list roots request queue is full - server overloaded") + } + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + var result mcp.ListRootsResult + if err := json.Unmarshal(response.result, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal list roots response: %v", err) + } + return &result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// RequestElicitation implements SessionWithElicitation interface for HTTP transport +func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + // Generate unique request ID + requestID := 
s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan samplingResponseItem, 1) + + // Create the sampling request item + elicitationRequest := elicitationRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.samplingRequests.Store(requestID, responseChan) + defer s.samplingRequests.Delete(requestID) + + // Send the sampling request via the channel (non-blocking) + select { + case s.elicitationRequestChan <- elicitationRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, fmt.Errorf("elicitation request queue is full - server overloaded") + } + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + var result mcp.ElicitationResult + if err := json.Unmarshal(response.result, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal elicitation response: %v", err) + } + return &result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +var _ SessionWithSampling = (*streamableHttpSession)(nil) +var _ SessionWithElicitation = (*streamableHttpSession)(nil) +var _ SessionWithRoots = (*streamableHttpSession)(nil) + +// --- session id manager --- + +// SessionIdManagerResolver resolves a SessionIdManager based on the HTTP request +type SessionIdManagerResolver interface { + ResolveSessionIdManager(r *http.Request) SessionIdManager +} + +type SessionIdManager interface { + Generate() string + // Validate checks if a session ID is valid and not terminated. + // Returns isTerminated=true if the ID is valid but belongs to a terminated session. + // Returns err!=nil if the ID format is invalid or lookup failed. + Validate(sessionID string) (isTerminated bool, err error) + // Terminate marks a session ID as terminated. 
+ // Returns isNotAllowed=true if the server policy prevents client termination. + // Returns err!=nil if the ID is invalid or termination failed. + Terminate(sessionID string) (isNotAllowed bool, err error) +} + +// DefaultSessionIdManagerResolver is a simple resolver that returns the same SessionIdManager for all requests +type DefaultSessionIdManagerResolver struct { + manager SessionIdManager +} + +// NewDefaultSessionIdManagerResolver creates a new DefaultSessionIdManagerResolver with the given SessionIdManager +func NewDefaultSessionIdManagerResolver(manager SessionIdManager) *DefaultSessionIdManagerResolver { + if manager == nil { + manager = &StatelessSessionIdManager{} + } + return &DefaultSessionIdManagerResolver{manager: manager} +} + +// ResolveSessionIdManager returns the configured SessionIdManager for all requests +func (r *DefaultSessionIdManagerResolver) ResolveSessionIdManager(_ *http.Request) SessionIdManager { + return r.manager +} + +// StatelessSessionIdManager does nothing, which means it has no session management, which is stateless. +type StatelessSessionIdManager struct{} + +func (s *StatelessSessionIdManager) Generate() string { + return "" +} + +func (s *StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { + // In stateless mode, ignore session IDs completely - don't validate or reject them + return false, nil +} + +func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { + return false, nil +} + +// StatelessGeneratingSessionIdManager generates session IDs but doesn't validate them locally. +// This allows session IDs to be generated for clients while working across multiple instances. 
+type StatelessGeneratingSessionIdManager struct{} + +func (s *StatelessGeneratingSessionIdManager) Generate() string { + return idPrefix + uuid.New().String() +} + +func (s *StatelessGeneratingSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { + // Only validate format, not existence - allows cross-instance operation + if !strings.HasPrefix(sessionID, idPrefix) { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + return false, nil +} + +func (s *StatelessGeneratingSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { + // No-op termination since we don't track sessions + return false, nil +} + +// InsecureStatefulSessionIdManager generate id with uuid and tracks active sessions. +// It validates both format and existence of session IDs. +// For more secure session id, use a more complex generator, like a JWT. 
+type InsecureStatefulSessionIdManager struct { + sessions sync.Map + terminated sync.Map +} + +const idPrefix = "mcp-session-" + +func (s *InsecureStatefulSessionIdManager) Generate() string { + sessionID := idPrefix + uuid.New().String() + s.sessions.Store(sessionID, true) + return sessionID +} + +func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { + if !strings.HasPrefix(sessionID, idPrefix) { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + if _, exists := s.terminated.Load(sessionID); exists { + return true, nil + } + if _, exists := s.sessions.Load(sessionID); !exists { + return false, fmt.Errorf("session not found: %s", sessionID) + } + return false, nil +} + +func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { + if _, exists := s.terminated.Load(sessionID); exists { + return false, nil + } + if _, exists := s.sessions.Load(sessionID); !exists { + return false, nil + } + s.terminated.Store(sessionID, true) + s.sessions.Delete(sessionID) + return false, nil +} + +// NewTestStreamableHTTPServer creates a test server for testing purposes +func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server { + sseServer := NewStreamableHTTPServer(server, opts...) + testServer := httptest.NewServer(sseServer) + return testServer +} + +// isJSONEmpty reports whether the provided JSON value is "empty": +// - null +// - empty object: {} +// - empty array: [] +// +// It also treats nil/whitespace-only input as empty. +// It does NOT treat 0, false, "" or non-empty composites as empty. 
+func isJSONEmpty(data json.RawMessage) bool { + if len(data) == 0 { + return true + } + + trimmed := bytes.TrimSpace(data) + if len(trimmed) == 0 { + return true + } + + switch trimmed[0] { + case '{': + if len(trimmed) == 2 && trimmed[1] == '}' { + return true + } + for i := 1; i < len(trimmed); i++ { + if !unicode.IsSpace(rune(trimmed[i])) { + return trimmed[i] == '}' + } + } + case '[': + if len(trimmed) == 2 && trimmed[1] == ']' { + return true + } + for i := 1; i < len(trimmed); i++ { + if !unicode.IsSpace(rune(trimmed[i])) { + return trimmed[i] == ']' + } + } + + case '"': // treat "" as not empty + return false + + case 'n': // null + return len(trimmed) == 4 && + trimmed[1] == 'u' && + trimmed[2] == 'l' && + trimmed[3] == 'l' + } + return false +} diff --git a/vendor/github.com/mark3labs/mcp-go/util/logger.go b/vendor/github.com/mark3labs/mcp-go/util/logger.go new file mode 100644 index 000000000..8d7555ce3 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/util/logger.go @@ -0,0 +1,33 @@ +package util + +import ( + "log" +) + +// Logger defines a minimal logging interface +type Logger interface { + Infof(format string, v ...any) + Errorf(format string, v ...any) +} + +// --- Standard Library Logger Wrapper --- + +// DefaultStdLogger implements Logger using the standard library's log.Logger. +func DefaultLogger() Logger { + return &stdLogger{ + logger: log.Default(), + } +} + +// stdLogger wraps the standard library's log.Logger. +type stdLogger struct { + logger *log.Logger +} + +func (l *stdLogger) Infof(format string, v ...any) { + l.logger.Printf("INFO: "+format, v...) +} + +func (l *stdLogger) Errorf(format string, v ...any) { + l.logger.Printf("ERROR: "+format, v...) 
+} diff --git a/vendor/github.com/wk8/go-ordered-map/v2/.gitignore b/vendor/github.com/wk8/go-ordered-map/v2/.gitignore new file mode 100644 index 000000000..57872d0f1 --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/.gitignore @@ -0,0 +1 @@ +/vendor/ diff --git a/vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml b/vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml new file mode 100644 index 000000000..2417df10d --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml @@ -0,0 +1,80 @@ +run: + tests: false + +linters: + disable-all: true + enable: + - asciicheck + - bidichk + - bodyclose + - containedctx + - contextcheck + - decorder + - depguard + - dogsled + - dupl + - durationcheck + - errcheck + - errchkjson + # FIXME: commented out as it crashes with 1.18 for now + # - errname + - errorlint + - exportloopref + - forbidigo + - funlen + - gci + - gochecknoglobals + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godox + - gofmt + - gofumpt + - goheader + - goimports + - gomnd + - gomoddirectives + - gomodguard + - goprintffuncname + - gosec + - gosimple + - govet + - grouper + - ifshort + - importas + - ineffassign + - lll + - maintidx + - makezero + - misspell + - nakedret + - nilerr + - nilnil + - noctx + - nolintlint + - paralleltest + - prealloc + - predeclared + - promlinter + # FIXME: doesn't support 1.18 yet + # - revive + - rowserrcheck + - sqlclosecheck + - staticcheck + - structcheck + - stylecheck + - tagliatelle + - tenv + - testpackage + - thelper + - tparallel + - typecheck + - unconvert + - unparam + - unused + - varcheck + - varnamelen + - wastedassign + - whitespace diff --git a/vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md b/vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md new file mode 100644 index 000000000..f27126f84 --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md @@ -0,0 +1,38 @@ +# Changelog + +[comment]: # (Changes since last release go here) + 
+## 2.1.8 - Jun 27th 2023 + +* Added support for YAML serialization/deserialization + +## 2.1.7 - Apr 13th 2023 + +* Renamed test_utils.go to utils_test.go + +## 2.1.6 - Feb 15th 2023 + +* Added `GetAndMoveToBack()` and `GetAndMoveToFront()` methods + +## 2.1.5 - Dec 13th 2022 + +* Added `Value()` method + +## 2.1.4 - Dec 12th 2022 + +* Fixed a bug with UTF-8 special characters in JSON keys + +## 2.1.3 - Dec 11th 2022 + +* Added support for JSON marshalling/unmarshalling of wrapper of primitive types + +## 2.1.2 - Dec 10th 2022 +* Allowing to pass options to `New`, to give a capacity hint, or initial data +* Allowing to deserialize nested ordered maps from JSON without having to explicitly instantiate them +* Added the `AddPairs` method + +## 2.1.1 - Dec 9th 2022 +* Fixing a bug with JSON marshalling + +## 2.1.0 - Dec 7th 2022 +* Added support for JSON serialization/deserialization diff --git a/vendor/github.com/wk8/go-ordered-map/v2/LICENSE b/vendor/github.com/wk8/go-ordered-map/v2/LICENSE new file mode 100644 index 000000000..8dada3eda --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/wk8/go-ordered-map/v2/Makefile b/vendor/github.com/wk8/go-ordered-map/v2/Makefile new file mode 100644 index 000000000..6e0e18a1b --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/Makefile @@ -0,0 +1,32 @@ +.DEFAULT_GOAL := all + +.PHONY: all +all: test_with_fuzz lint + +# the TEST_FLAGS env var can be set to eg run only specific tests +TEST_COMMAND = go test -v -count=1 -race -cover $(TEST_FLAGS) + +.PHONY: test +test: + $(TEST_COMMAND) + +.PHONY: bench +bench: + go test -bench=. 
+ +FUZZ_TIME ?= 10s + +# see https://github.com/golang/go/issues/46312 +# and https://stackoverflow.com/a/72673487/4867444 +# if we end up having more fuzz tests +.PHONY: test_with_fuzz +test_with_fuzz: + $(TEST_COMMAND) -fuzz=FuzzRoundTripJSON -fuzztime=$(FUZZ_TIME) + $(TEST_COMMAND) -fuzz=FuzzRoundTripYAML -fuzztime=$(FUZZ_TIME) + +.PHONY: fuzz +fuzz: test_with_fuzz + +.PHONY: lint +lint: + golangci-lint run diff --git a/vendor/github.com/wk8/go-ordered-map/v2/README.md b/vendor/github.com/wk8/go-ordered-map/v2/README.md new file mode 100644 index 000000000..b02894443 --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/README.md @@ -0,0 +1,154 @@ +[![Go Reference](https://pkg.go.dev/badge/github.com/wk8/go-ordered-map/v2.svg)](https://pkg.go.dev/github.com/wk8/go-ordered-map/v2) +[![Build Status](https://circleci.com/gh/wk8/go-ordered-map.svg?style=svg)](https://app.circleci.com/pipelines/github/wk8/go-ordered-map) + +# Golang Ordered Maps + +Same as regular maps, but also remembers the order in which keys were inserted, akin to [Python's `collections.OrderedDict`s](https://docs.python.org/3.7/library/collections.html#ordereddict-objects). + +It offers the following features: +* optimal runtime performance (all operations are constant time) +* optimal memory usage (only one copy of values, no unnecessary memory allocation) +* allows iterating from newest or oldest keys indifferently, without memory copy, allowing to `break` the iteration, and in time linear to the number of keys iterated over rather than the total length of the ordered map +* supports any generic types for both keys and values. 
If you're running go < 1.18, you can use [version 1](https://github.com/wk8/go-ordered-map/tree/v1) that takes and returns generic `interface{}`s instead of using generics +* idiomatic API, akin to that of [`container/list`](https://golang.org/pkg/container/list) +* support for JSON and YAML marshalling + +## Documentation + +[The full documentation is available on pkg.go.dev](https://pkg.go.dev/github.com/wk8/go-ordered-map/v2). + +## Installation +```bash +go get -u github.com/wk8/go-ordered-map/v2 +``` + +Or use your favorite golang vendoring tool! + +## Supported go versions + +Go >= 1.18 is required to use version >= 2 of this library, as it uses generics. + +If you're running go < 1.18, you can use [version 1](https://github.com/wk8/go-ordered-map/tree/v1) instead. + +## Example / usage + +```go +package main + +import ( + "fmt" + + "github.com/wk8/go-ordered-map/v2" +) + +func main() { + om := orderedmap.New[string, string]() + + om.Set("foo", "bar") + om.Set("bar", "baz") + om.Set("coucou", "toi") + + fmt.Println(om.Get("foo")) // => "bar", true + fmt.Println(om.Get("i dont exist")) // => "", false + + // iterating pairs from oldest to newest: + for pair := om.Oldest(); pair != nil; pair = pair.Next() { + fmt.Printf("%s => %s\n", pair.Key, pair.Value) + } // prints: + // foo => bar + // bar => baz + // coucou => toi + + // iterating over the 2 newest pairs: + i := 0 + for pair := om.Newest(); pair != nil; pair = pair.Prev() { + fmt.Printf("%s => %s\n", pair.Key, pair.Value) + i++ + if i >= 2 { + break + } + } // prints: + // coucou => toi + // bar => baz +} +``` + +An `OrderedMap`'s keys must implement `comparable`, and its values can be anything, for example: + +```go +type myStruct struct { + payload string +} + +func main() { + om := orderedmap.New[int, *myStruct]() + + om.Set(12, &myStruct{"foo"}) + om.Set(1, &myStruct{"bar"}) + + value, present := om.Get(12) + if !present { + panic("should be there!") + } + fmt.Println(value.payload) // => foo + + for 
pair := om.Oldest(); pair != nil; pair = pair.Next() { + fmt.Printf("%d => %s\n", pair.Key, pair.Value.payload) + } // prints: + // 12 => foo + // 1 => bar +} +``` + +Also worth noting that you can provision ordered maps with a capacity hint, as you would do by passing an optional hint to `make(map[K]V, capacity`): +```go +om := orderedmap.New[int, *myStruct](28) +``` + +You can also pass in some initial data to store in the map: +```go +om := orderedmap.New[int, string](orderedmap.WithInitialData[int, string]( + orderedmap.Pair[int, string]{ + Key: 12, + Value: "foo", + }, + orderedmap.Pair[int, string]{ + Key: 28, + Value: "bar", + }, +)) +``` + +`OrderedMap`s also support JSON serialization/deserialization, and preserves order: + +```go +// serialization +data, err := json.Marshal(om) +... + +// deserialization +om := orderedmap.New[string, string]() // or orderedmap.New[int, any](), or any type you expect +err := json.Unmarshal(data, &om) +... +``` + +Similarly, it also supports YAML serialization/deserialization using the yaml.v3 package, which also preserves order: + +```go +// serialization +data, err := yaml.Marshal(om) +... + +// deserialization +om := orderedmap.New[string, string]() // or orderedmap.New[int, any](), or any type you expect +err := yaml.Unmarshal(data, &om) +... 
+``` + +## Alternatives + +There are several other ordered map golang implementations out there, but I believe that at the time of writing none of them offer the same functionality as this library; more specifically: +* [iancoleman/orderedmap](https://github.com/iancoleman/orderedmap) only accepts `string` keys, its `Delete` operations are linear +* [cevaris/ordered_map](https://github.com/cevaris/ordered_map) uses a channel for iterations, and leaks goroutines if the iteration is interrupted before fully traversing the map +* [mantyr/iterator](https://github.com/mantyr/iterator) also uses a channel for iterations, and its `Delete` operations are linear +* [samdolan/go-ordered-map](https://github.com/samdolan/go-ordered-map) adds unnecessary locking (users should add their own locking instead if they need it), its `Delete` and `Get` operations are linear, iterations trigger a linear memory allocation diff --git a/vendor/github.com/wk8/go-ordered-map/v2/json.go b/vendor/github.com/wk8/go-ordered-map/v2/json.go new file mode 100644 index 000000000..a545b536b --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/json.go @@ -0,0 +1,182 @@ +package orderedmap + +import ( + "bytes" + "encoding" + "encoding/json" + "fmt" + "reflect" + "unicode/utf8" + + "github.com/buger/jsonparser" + "github.com/mailru/easyjson/jwriter" +) + +var ( + _ json.Marshaler = &OrderedMap[int, any]{} + _ json.Unmarshaler = &OrderedMap[int, any]{} +) + +// MarshalJSON implements the json.Marshaler interface. 
+func (om *OrderedMap[K, V]) MarshalJSON() ([]byte, error) { //nolint:funlen + if om == nil || om.list == nil { + return []byte("null"), nil + } + + writer := jwriter.Writer{} + writer.RawByte('{') + + for pair, firstIteration := om.Oldest(), true; pair != nil; pair = pair.Next() { + if firstIteration { + firstIteration = false + } else { + writer.RawByte(',') + } + + switch key := any(pair.Key).(type) { + case string: + writer.String(key) + case encoding.TextMarshaler: + writer.RawByte('"') + writer.Raw(key.MarshalText()) + writer.RawByte('"') + case int: + writer.IntStr(key) + case int8: + writer.Int8Str(key) + case int16: + writer.Int16Str(key) + case int32: + writer.Int32Str(key) + case int64: + writer.Int64Str(key) + case uint: + writer.UintStr(key) + case uint8: + writer.Uint8Str(key) + case uint16: + writer.Uint16Str(key) + case uint32: + writer.Uint32Str(key) + case uint64: + writer.Uint64Str(key) + default: + + // this switch takes care of wrapper types around primitive types, such as + // type myType string + switch keyValue := reflect.ValueOf(key); keyValue.Type().Kind() { + case reflect.String: + writer.String(keyValue.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + writer.Int64Str(keyValue.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + writer.Uint64Str(keyValue.Uint()) + default: + return nil, fmt.Errorf("unsupported key type: %T", key) + } + } + + writer.RawByte(':') + // the error is checked at the end of the function + writer.Raw(json.Marshal(pair.Value)) //nolint:errchkjson + } + + writer.RawByte('}') + + return dumpWriter(&writer) +} + +func dumpWriter(writer *jwriter.Writer) ([]byte, error) { + if writer.Error != nil { + return nil, writer.Error + } + + var buf bytes.Buffer + buf.Grow(writer.Size()) + if _, err := writer.DumpTo(&buf); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// UnmarshalJSON implements the json.Unmarshaler 
interface. +func (om *OrderedMap[K, V]) UnmarshalJSON(data []byte) error { + if om.list == nil { + om.initialize(0) + } + + return jsonparser.ObjectEach( + data, + func(keyData []byte, valueData []byte, dataType jsonparser.ValueType, offset int) error { + if dataType == jsonparser.String { + // jsonparser removes the enclosing quotes; we need to restore them to make a valid JSON + valueData = data[offset-len(valueData)-2 : offset] + } + + var key K + var value V + + switch typedKey := any(&key).(type) { + case *string: + s, err := decodeUTF8(keyData) + if err != nil { + return err + } + *typedKey = s + case encoding.TextUnmarshaler: + if err := typedKey.UnmarshalText(keyData); err != nil { + return err + } + case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64: + if err := json.Unmarshal(keyData, typedKey); err != nil { + return err + } + default: + // this switch takes care of wrapper types around primitive types, such as + // type myType string + switch reflect.TypeOf(key).Kind() { + case reflect.String: + s, err := decodeUTF8(keyData) + if err != nil { + return err + } + + convertedKeyData := reflect.ValueOf(s).Convert(reflect.TypeOf(key)) + reflect.ValueOf(&key).Elem().Set(convertedKeyData) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := json.Unmarshal(keyData, &key); err != nil { + return err + } + default: + return fmt.Errorf("unsupported key type: %T", key) + } + } + + if err := json.Unmarshal(valueData, &value); err != nil { + return err + } + + om.Set(key, value) + return nil + }) +} + +func decodeUTF8(input []byte) (string, error) { + remaining, offset := input, 0 + runes := make([]rune, 0, len(remaining)) + + for len(remaining) > 0 { + r, size := utf8.DecodeRune(remaining) + if r == utf8.RuneError && size <= 1 { + return "", fmt.Errorf("not a valid UTF-8 string (at position %d): %s", offset, 
string(input)) + } + + runes = append(runes, r) + remaining = remaining[size:] + offset += size + } + + return string(runes), nil +} diff --git a/vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go b/vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go new file mode 100644 index 000000000..064714191 --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go @@ -0,0 +1,296 @@ +// Package orderedmap implements an ordered map, i.e. a map that also keeps track of +// the order in which keys were inserted. +// +// All operations are constant-time. +// +// Github repo: https://github.com/wk8/go-ordered-map +// +package orderedmap + +import ( + "fmt" + + list "github.com/bahlo/generic-list-go" +) + +type Pair[K comparable, V any] struct { + Key K + Value V + + element *list.Element[*Pair[K, V]] +} + +type OrderedMap[K comparable, V any] struct { + pairs map[K]*Pair[K, V] + list *list.List[*Pair[K, V]] +} + +type initConfig[K comparable, V any] struct { + capacity int + initialData []Pair[K, V] +} + +type InitOption[K comparable, V any] func(config *initConfig[K, V]) + +// WithCapacity allows giving a capacity hint for the map, akin to the standard make(map[K]V, capacity). +func WithCapacity[K comparable, V any](capacity int) InitOption[K, V] { + return func(c *initConfig[K, V]) { + c.capacity = capacity + } +} + +// WithInitialData allows passing in initial data for the map. +func WithInitialData[K comparable, V any](initialData ...Pair[K, V]) InitOption[K, V] { + return func(c *initConfig[K, V]) { + c.initialData = initialData + if c.capacity < len(initialData) { + c.capacity = len(initialData) + } + } +} + +// New creates a new OrderedMap. +// options can either be one or several InitOption[K, V], or a single integer, +// which is then interpreted as a capacity hint, à la make(map[K]V, capacity). 
+func New[K comparable, V any](options ...any) *OrderedMap[K, V] { //nolint:varnamelen + orderedMap := &OrderedMap[K, V]{} + + var config initConfig[K, V] + for _, untypedOption := range options { + switch option := untypedOption.(type) { + case int: + if len(options) != 1 { + invalidOption() + } + config.capacity = option + + case InitOption[K, V]: + option(&config) + + default: + invalidOption() + } + } + + orderedMap.initialize(config.capacity) + orderedMap.AddPairs(config.initialData...) + + return orderedMap +} + +const invalidOptionMessage = `when using orderedmap.New[K,V]() with options, either provide one or several InitOption[K, V]; or a single integer which is then interpreted as a capacity hint, à la make(map[K]V, capacity).` //nolint:lll + +func invalidOption() { panic(invalidOptionMessage) } + +func (om *OrderedMap[K, V]) initialize(capacity int) { + om.pairs = make(map[K]*Pair[K, V], capacity) + om.list = list.New[*Pair[K, V]]() +} + +// Get looks for the given key, and returns the value associated with it, +// or V's nil value if not found. The boolean it returns says whether the key is present in the map. +func (om *OrderedMap[K, V]) Get(key K) (val V, present bool) { + if pair, present := om.pairs[key]; present { + return pair.Value, true + } + + return +} + +// Load is an alias for Get, mostly to present an API similar to `sync.Map`'s. +func (om *OrderedMap[K, V]) Load(key K) (V, bool) { + return om.Get(key) +} + +// Value returns the value associated with the given key or the zero value. +func (om *OrderedMap[K, V]) Value(key K) (val V) { + if pair, present := om.pairs[key]; present { + val = pair.Value + } + return +} + +// GetPair looks for the given key, and returns the pair associated with it, +// or nil if not found. The Pair struct can then be used to iterate over the ordered map +// from that point, either forward or backward. 
+func (om *OrderedMap[K, V]) GetPair(key K) *Pair[K, V] { + return om.pairs[key] +} + +// Set sets the key-value pair, and returns what `Get` would have returned +// on that key prior to the call to `Set`. +func (om *OrderedMap[K, V]) Set(key K, value V) (val V, present bool) { + if pair, present := om.pairs[key]; present { + oldValue := pair.Value + pair.Value = value + return oldValue, true + } + + pair := &Pair[K, V]{ + Key: key, + Value: value, + } + pair.element = om.list.PushBack(pair) + om.pairs[key] = pair + + return +} + +// AddPairs allows setting multiple pairs at a time. It's equivalent to calling +// Set on each pair sequentially. +func (om *OrderedMap[K, V]) AddPairs(pairs ...Pair[K, V]) { + for _, pair := range pairs { + om.Set(pair.Key, pair.Value) + } +} + +// Store is an alias for Set, mostly to present an API similar to `sync.Map`'s. +func (om *OrderedMap[K, V]) Store(key K, value V) (V, bool) { + return om.Set(key, value) +} + +// Delete removes the key-value pair, and returns what `Get` would have returned +// on that key prior to the call to `Delete`. +func (om *OrderedMap[K, V]) Delete(key K) (val V, present bool) { + if pair, present := om.pairs[key]; present { + om.list.Remove(pair.element) + delete(om.pairs, key) + return pair.Value, true + } + return +} + +// Len returns the length of the ordered map. +func (om *OrderedMap[K, V]) Len() int { + if om == nil || om.pairs == nil { + return 0 + } + return len(om.pairs) +} + +// Oldest returns a pointer to the oldest pair. It's meant to be used to iterate on the ordered map's +// pairs from the oldest to the newest, e.g.: +// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) } +func (om *OrderedMap[K, V]) Oldest() *Pair[K, V] { + if om == nil || om.list == nil { + return nil + } + return listElementToPair(om.list.Front()) +} + +// Newest returns a pointer to the newest pair. 
It's meant to be used to iterate on the ordered map's +// pairs from the newest to the oldest, e.g.: +// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) } +func (om *OrderedMap[K, V]) Newest() *Pair[K, V] { + if om == nil || om.list == nil { + return nil + } + return listElementToPair(om.list.Back()) +} + +// Next returns a pointer to the next pair. +func (p *Pair[K, V]) Next() *Pair[K, V] { + return listElementToPair(p.element.Next()) +} + +// Prev returns a pointer to the previous pair. +func (p *Pair[K, V]) Prev() *Pair[K, V] { + return listElementToPair(p.element.Prev()) +} + +func listElementToPair[K comparable, V any](element *list.Element[*Pair[K, V]]) *Pair[K, V] { + if element == nil { + return nil + } + return element.Value +} + +// KeyNotFoundError may be returned by functions in this package when they're called with keys that are not present +// in the map. +type KeyNotFoundError[K comparable] struct { + MissingKey K +} + +func (e *KeyNotFoundError[K]) Error() string { + return fmt.Sprintf("missing key: %v", e.MissingKey) +} + +// MoveAfter moves the value associated with key to its new position after the one associated with markKey. +// Returns an error iff key or markKey are not present in the map. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) MoveAfter(key, markKey K) error { + elements, err := om.getElements(key, markKey) + if err != nil { + return err + } + om.list.MoveAfter(elements[0], elements[1]) + return nil +} + +// MoveBefore moves the value associated with key to its new position before the one associated with markKey. +// Returns an error iff key or markKey are not present in the map. If an error is returned, +// it will be a KeyNotFoundError. 
+func (om *OrderedMap[K, V]) MoveBefore(key, markKey K) error { + elements, err := om.getElements(key, markKey) + if err != nil { + return err + } + om.list.MoveBefore(elements[0], elements[1]) + return nil +} + +func (om *OrderedMap[K, V]) getElements(keys ...K) ([]*list.Element[*Pair[K, V]], error) { + elements := make([]*list.Element[*Pair[K, V]], len(keys)) + for i, k := range keys { + pair, present := om.pairs[k] + if !present { + return nil, &KeyNotFoundError[K]{k} + } + elements[i] = pair.element + } + return elements, nil +} + +// MoveToBack moves the value associated with key to the back of the ordered map, +// i.e. makes it the newest pair in the map. +// Returns an error iff key is not present in the map. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) MoveToBack(key K) error { + _, err := om.GetAndMoveToBack(key) + return err +} + +// MoveToFront moves the value associated with key to the front of the ordered map, +// i.e. makes it the oldest pair in the map. +// Returns an error iff key is not present in the map. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) MoveToFront(key K) error { + _, err := om.GetAndMoveToFront(key) + return err +} + +// GetAndMoveToBack combines Get and MoveToBack in the same call. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) GetAndMoveToBack(key K) (val V, err error) { + if pair, present := om.pairs[key]; present { + val = pair.Value + om.list.MoveToBack(pair.element) + } else { + err = &KeyNotFoundError[K]{key} + } + + return +} + +// GetAndMoveToFront combines Get and MoveToFront in the same call. If an error is returned, +// it will be a KeyNotFoundError. 
+func (om *OrderedMap[K, V]) GetAndMoveToFront(key K) (val V, err error) { + if pair, present := om.pairs[key]; present { + val = pair.Value + om.list.MoveToFront(pair.element) + } else { + err = &KeyNotFoundError[K]{key} + } + + return +} diff --git a/vendor/github.com/wk8/go-ordered-map/v2/yaml.go b/vendor/github.com/wk8/go-ordered-map/v2/yaml.go new file mode 100644 index 000000000..602247128 --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/yaml.go @@ -0,0 +1,71 @@ +package orderedmap + +import ( + "fmt" + + "gopkg.in/yaml.v3" +) + +var ( + _ yaml.Marshaler = &OrderedMap[int, any]{} + _ yaml.Unmarshaler = &OrderedMap[int, any]{} +) + +// MarshalYAML implements the yaml.Marshaler interface. +func (om *OrderedMap[K, V]) MarshalYAML() (interface{}, error) { + if om == nil { + return []byte("null"), nil + } + + node := yaml.Node{ + Kind: yaml.MappingNode, + } + + for pair := om.Oldest(); pair != nil; pair = pair.Next() { + key, value := pair.Key, pair.Value + + keyNode := &yaml.Node{} + + // serialize key to yaml, then deserialize it back into the node + // this is a hack to get the correct tag for the key + if err := keyNode.Encode(key); err != nil { + return nil, err + } + + valueNode := &yaml.Node{} + if err := valueNode.Encode(value); err != nil { + return nil, err + } + + node.Content = append(node.Content, keyNode, valueNode) + } + + return &node, nil +} + +// UnmarshalYAML implements the yaml.Unmarshaler interface. 
+func (om *OrderedMap[K, V]) UnmarshalYAML(value *yaml.Node) error { + if value.Kind != yaml.MappingNode { + return fmt.Errorf("pipeline must contain YAML mapping, has %v", value.Kind) + } + + if om.list == nil { + om.initialize(0) + } + + for index := 0; index < len(value.Content); index += 2 { + var key K + var val V + + if err := value.Content[index].Decode(&key); err != nil { + return err + } + if err := value.Content[index+1].Decode(&val); err != nil { + return err + } + + om.Set(key, val) + } + + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/LICENSE b/vendor/github.com/yosida95/uritemplate/v3/LICENSE new file mode 100644 index 000000000..79e8f8757 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/LICENSE @@ -0,0 +1,25 @@ +Copyright (C) 2016, Kohei YOSHIDA . All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/yosida95/uritemplate/v3/README.rst b/vendor/github.com/yosida95/uritemplate/v3/README.rst new file mode 100644 index 000000000..6815d0a46 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/README.rst @@ -0,0 +1,46 @@ +uritemplate +=========== + +`uritemplate`_ is a Go implementation of `URI Template`_ [RFC6570] with +full functionality of URI Template Level 4. + +uritemplate can also generate a regexp that matches expansion of the +URI Template from a URI Template. + +Getting Started +--------------- + +Installation +~~~~~~~~~~~~ + +.. code-block:: sh + + $ go get -u github.com/yosida95/uritemplate/v3 + +Documentation +~~~~~~~~~~~~~ + +The documentation is available on GoDoc_. + +Examples +-------- + +See `examples on GoDoc`_. + +License +------- + +`uritemplate`_ is distributed under the BSD 3-Clause license. +PLEASE READ ./LICENSE carefully and follow its clauses to use this software. + +Author +------ + +yosida95_ + + +.. _`URI Template`: https://tools.ietf.org/html/rfc6570 +.. _Godoc: https://godoc.org/github.com/yosida95/uritemplate +.. _`examples on GoDoc`: https://godoc.org/github.com/yosida95/uritemplate#pkg-examples +.. _yosida95: https://yosida95.com/ +.. 
_uritemplate: https://github.com/yosida95/uritemplate diff --git a/vendor/github.com/yosida95/uritemplate/v3/compile.go b/vendor/github.com/yosida95/uritemplate/v3/compile.go new file mode 100644 index 000000000..bd774d15d --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/compile.go @@ -0,0 +1,224 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" + "unicode/utf8" +) + +type compiler struct { + prog *prog +} + +func (c *compiler) init() { + c.prog = &prog{} +} + +func (c *compiler) op(opcode progOpcode) uint32 { + i := len(c.prog.op) + c.prog.op = append(c.prog.op, progOp{code: opcode}) + return uint32(i) +} + +func (c *compiler) opWithRune(opcode progOpcode, r rune) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).r = r + return addr +} + +func (c *compiler) opWithRuneClass(opcode progOpcode, rc runeClass) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).rc = rc + return addr +} + +func (c *compiler) opWithAddr(opcode progOpcode, absaddr uint32) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).i = absaddr + return addr +} + +func (c *compiler) opWithAddrDelta(opcode progOpcode, delta uint32) uint32 { + return c.opWithAddr(opcode, uint32(len(c.prog.op))+delta) +} + +func (c *compiler) opWithName(opcode progOpcode, name string) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).name = name + return addr +} + +func (c *compiler) compileString(str string) { + for i := 0; i < len(str); { + // NOTE(yosida95): It is confirmed at parse time that literals + // consist of only valid-UTF8 runes. 
+ r, size := utf8.DecodeRuneInString(str[i:]) + c.opWithRune(opRune, r) + i += size + } +} + +func (c *compiler) compileRuneClass(rc runeClass, maxlen int) { + for i := 0; i < maxlen; i++ { + if i > 0 { + c.opWithAddrDelta(opSplit, 7) + } + c.opWithAddrDelta(opSplit, 3) // raw rune or pct-encoded + c.opWithRuneClass(opRuneClass, rc) // raw rune + c.opWithAddrDelta(opJmp, 4) // + c.opWithRune(opRune, '%') // pct-encoded + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithRuneClass(opRuneClass, runeClassPctE) // + } +} + +func (c *compiler) compileRuneClassInfinite(rc runeClass) { + start := c.opWithAddrDelta(opSplit, 3) // raw rune or pct-encoded + c.opWithRuneClass(opRuneClass, rc) // raw rune + c.opWithAddrDelta(opJmp, 4) // + c.opWithRune(opRune, '%') // pct-encoded + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithAddrDelta(opSplit, 2) // loop + c.opWithAddr(opJmp, start) // +} + +func (c *compiler) compileVarspecValue(spec varspec, expr *expression) { + var specname string + if spec.maxlen > 0 { + specname = fmt.Sprintf("%s:%d", spec.name, spec.maxlen) + } else { + specname = spec.name + } + + c.prog.numCap++ + + c.opWithName(opCapStart, specname) + + split := c.op(opSplit) + if spec.maxlen > 0 { + c.compileRuneClass(expr.allow, spec.maxlen) + } else { + c.compileRuneClassInfinite(expr.allow) + } + + capEnd := c.opWithName(opCapEnd, specname) + c.prog.op[split].i = capEnd +} + +func (c *compiler) compileVarspec(spec varspec, expr *expression) { + switch { + case expr.named && spec.explode: + split1 := c.op(opSplit) + noop := c.op(opNoop) + c.compileString(spec.name) + + split2 := c.op(opSplit) + c.opWithRune(opRune, '=') + c.compileVarspecValue(spec, expr) + + split3 := c.op(opSplit) + c.compileString(expr.sep) + c.opWithAddr(opJmp, noop) + + c.prog.op[split2].i = uint32(len(c.prog.op)) + c.compileString(expr.ifemp) + c.opWithAddr(opJmp, split3) + + c.prog.op[split1].i = 
uint32(len(c.prog.op)) + c.prog.op[split3].i = uint32(len(c.prog.op)) + + case expr.named && !spec.explode: + c.compileString(spec.name) + + split2 := c.op(opSplit) + c.opWithRune(opRune, '=') + + split3 := c.op(opSplit) + + split4 := c.op(opSplit) + c.compileVarspecValue(spec, expr) + + split5 := c.op(opSplit) + c.prog.op[split4].i = split5 + c.compileString(",") + c.opWithAddr(opJmp, split4) + + c.prog.op[split3].i = uint32(len(c.prog.op)) + c.compileString(",") + jmp1 := c.op(opJmp) + + c.prog.op[split2].i = uint32(len(c.prog.op)) + c.compileString(expr.ifemp) + + c.prog.op[split5].i = uint32(len(c.prog.op)) + c.prog.op[jmp1].i = uint32(len(c.prog.op)) + + case !expr.named: + start := uint32(len(c.prog.op)) + c.compileVarspecValue(spec, expr) + + split1 := c.op(opSplit) + jmp := c.op(opJmp) + + c.prog.op[split1].i = uint32(len(c.prog.op)) + if spec.explode { + c.compileString(expr.sep) + } else { + c.opWithRune(opRune, ',') + } + c.opWithAddr(opJmp, start) + + c.prog.op[jmp].i = uint32(len(c.prog.op)) + } +} + +func (c *compiler) compileExpression(expr *expression) { + if len(expr.vars) < 1 { + return + } + + split1 := c.op(opSplit) + c.compileString(expr.first) + + for i, size := 0, len(expr.vars); i < size; i++ { + spec := expr.vars[i] + + split2 := c.op(opSplit) + if i > 0 { + split3 := c.op(opSplit) + c.compileString(expr.sep) + c.prog.op[split3].i = uint32(len(c.prog.op)) + } + c.compileVarspec(spec, expr) + c.prog.op[split2].i = uint32(len(c.prog.op)) + } + + c.prog.op[split1].i = uint32(len(c.prog.op)) +} + +func (c *compiler) compileLiterals(lt literals) { + c.compileString(string(lt)) +} + +func (c *compiler) compile(tmpl *Template) { + c.op(opLineBegin) + for i := range tmpl.exprs { + expr := tmpl.exprs[i] + switch expr := expr.(type) { + default: + panic("unhandled expression") + case *expression: + c.compileExpression(expr) + case literals: + c.compileLiterals(expr) + } + } + c.op(opLineEnd) + c.op(opEnd) +} diff --git 
a/vendor/github.com/yosida95/uritemplate/v3/equals.go b/vendor/github.com/yosida95/uritemplate/v3/equals.go new file mode 100644 index 000000000..aa59a5c03 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/equals.go @@ -0,0 +1,53 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +type CompareFlags uint8 + +const ( + CompareVarname CompareFlags = 1 << iota +) + +// Equals reports whether or not two URI Templates t1 and t2 are equivalent. +func Equals(t1 *Template, t2 *Template, flags CompareFlags) bool { + if len(t1.exprs) != len(t2.exprs) { + return false + } + for i := 0; i < len(t1.exprs); i++ { + switch t1 := t1.exprs[i].(type) { + case literals: + t2, ok := t2.exprs[i].(literals) + if !ok { + return false + } + if t1 != t2 { + return false + } + case *expression: + t2, ok := t2.exprs[i].(*expression) + if !ok { + return false + } + if t1.op != t2.op || len(t1.vars) != len(t2.vars) { + return false + } + for n := 0; n < len(t1.vars); n++ { + v1 := t1.vars[n] + v2 := t2.vars[n] + if flags&CompareVarname == CompareVarname && v1.name != v2.name { + return false + } + if v1.maxlen != v2.maxlen || v1.explode != v2.explode { + return false + } + } + default: + panic("unhandled case") + } + } + return true +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/error.go b/vendor/github.com/yosida95/uritemplate/v3/error.go new file mode 100644 index 000000000..2fd34a808 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/error.go @@ -0,0 +1,16 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. 
+ +package uritemplate + +import ( + "fmt" +) + +func errorf(pos int, format string, a ...interface{}) error { + msg := fmt.Sprintf(format, a...) + return fmt.Errorf("uritemplate:%d:%s", pos, msg) +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/escape.go b/vendor/github.com/yosida95/uritemplate/v3/escape.go new file mode 100644 index 000000000..6d27e693a --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/escape.go @@ -0,0 +1,190 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +var ( + hex = []byte("0123456789ABCDEF") + // reserved = gen-delims / sub-delims + // gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" + // sub-delims = "!" / "$" / "&" / "’" / "(" / ")" + // / "*" / "+" / "," / ";" / "=" + rangeReserved = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x21, Hi: 0x21, Stride: 1}, // '!' + {Lo: 0x23, Hi: 0x24, Stride: 1}, // '#' - '$' + {Lo: 0x26, Hi: 0x2C, Stride: 1}, // '&' - ',' + {Lo: 0x2F, Hi: 0x2F, Stride: 1}, // '/' + {Lo: 0x3A, Hi: 0x3B, Stride: 1}, // ':' - ';' + {Lo: 0x3D, Hi: 0x3D, Stride: 1}, // '=' + {Lo: 0x3F, Hi: 0x40, Stride: 1}, // '?' - '@' + {Lo: 0x5B, Hi: 0x5B, Stride: 1}, // '[' + {Lo: 0x5D, Hi: 0x5D, Stride: 1}, // ']' + }, + LatinOffset: 9, + } + reReserved = `\x21\x23\x24\x26-\x2c\x2f\x3a\x3b\x3d\x3f\x40\x5b\x5d` + // ALPHA = %x41-5A / %x61-7A + // DIGIT = %x30-39 + // unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + rangeUnreserved = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x2D, Hi: 0x2E, Stride: 1}, // '-' - '.' 
+ {Lo: 0x30, Hi: 0x39, Stride: 1}, // '0' - '9' + {Lo: 0x41, Hi: 0x5A, Stride: 1}, // 'A' - 'Z' + {Lo: 0x5F, Hi: 0x5F, Stride: 1}, // '_' + {Lo: 0x61, Hi: 0x7A, Stride: 1}, // 'a' - 'z' + {Lo: 0x7E, Hi: 0x7E, Stride: 1}, // '~' + }, + } + reUnreserved = `\x2d\x2e\x30-\x39\x41-\x5a\x5f\x61-\x7a\x7e` +) + +type runeClass uint8 + +const ( + runeClassU runeClass = 1 << iota + runeClassR + runeClassPctE + runeClassLast + + runeClassUR = runeClassU | runeClassR +) + +var runeClassNames = []string{ + "U", + "R", + "pct-encoded", +} + +func (rc runeClass) String() string { + ret := make([]string, 0, len(runeClassNames)) + for i, j := 0, runeClass(1); j < runeClassLast; j <<= 1 { + if rc&j == j { + ret = append(ret, runeClassNames[i]) + } + i++ + } + return strings.Join(ret, "+") +} + +func pctEncode(w *strings.Builder, r rune) { + if s := r >> 24 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r >> 16 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r >> 8 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + default: + return false + } +} + +func pctDecode(s string) string { + size := len(s) + for i := 0; i < len(s); { + switch s[i] { + case '%': + size -= 2 + i += 3 + default: + i++ + } + } + if size == len(s) { + return s + } + + buf := make([]byte, size) + j := 0 + for i := 0; i < len(s); { + switch c := s[i]; c { + case '%': + buf[j] = unhex(s[i+1])<<4 | unhex(s[i+2]) + i += 3 + j++ + default: + buf[j] = c + i++ + j++ + } + } + return string(buf) +} + +type 
escapeFunc func(*strings.Builder, string) error + +func escapeLiteral(w *strings.Builder, v string) error { + w.WriteString(v) + return nil +} + +func escapeExceptU(w *strings.Builder, v string) error { + for i := 0; i < len(v); { + r, size := utf8.DecodeRuneInString(v[i:]) + if r == utf8.RuneError { + return errorf(i, "invalid encoding") + } + if unicode.Is(rangeUnreserved, r) { + w.WriteRune(r) + } else { + pctEncode(w, r) + } + i += size + } + return nil +} + +func escapeExceptUR(w *strings.Builder, v string) error { + for i := 0; i < len(v); { + r, size := utf8.DecodeRuneInString(v[i:]) + if r == utf8.RuneError { + return errorf(i, "invalid encoding") + } + // TODO(yosida95): is pct-encoded triplets allowed here? + if unicode.In(r, rangeUnreserved, rangeReserved) { + w.WriteRune(r) + } else { + pctEncode(w, r) + } + i += size + } + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/expression.go b/vendor/github.com/yosida95/uritemplate/v3/expression.go new file mode 100644 index 000000000..4858c2dde --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/expression.go @@ -0,0 +1,173 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. 
+ +package uritemplate + +import ( + "regexp" + "strconv" + "strings" +) + +type template interface { + expand(*strings.Builder, Values) error + regexp(*strings.Builder) +} + +type literals string + +func (l literals) expand(b *strings.Builder, _ Values) error { + b.WriteString(string(l)) + return nil +} + +func (l literals) regexp(b *strings.Builder) { + b.WriteString("(?:") + b.WriteString(regexp.QuoteMeta(string(l))) + b.WriteByte(')') +} + +type varspec struct { + name string + maxlen int + explode bool +} + +type expression struct { + vars []varspec + op parseOp + first string + sep string + named bool + ifemp string + escape escapeFunc + allow runeClass +} + +func (e *expression) init() { + switch e.op { + case parseOpSimple: + e.sep = "," + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpPlus: + e.sep = "," + e.escape = escapeExceptUR + e.allow = runeClassUR + case parseOpCrosshatch: + e.first = "#" + e.sep = "," + e.escape = escapeExceptUR + e.allow = runeClassUR + case parseOpDot: + e.first = "." + e.sep = "." + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpSlash: + e.first = "/" + e.sep = "/" + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpSemicolon: + e.first = ";" + e.sep = ";" + e.named = true + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpQuestion: + e.first = "?" 
+ e.sep = "&" + e.named = true + e.ifemp = "=" + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpAmpersand: + e.first = "&" + e.sep = "&" + e.named = true + e.ifemp = "=" + e.escape = escapeExceptU + e.allow = runeClassU + } +} + +func (e *expression) expand(w *strings.Builder, values Values) error { + first := true + for _, varspec := range e.vars { + value := values.Get(varspec.name) + if !value.Valid() { + continue + } + + if first { + w.WriteString(e.first) + first = false + } else { + w.WriteString(e.sep) + } + + if err := value.expand(w, varspec, e); err != nil { + return err + } + + } + return nil +} + +func (e *expression) regexp(b *strings.Builder) { + if e.first != "" { + b.WriteString("(?:") // $1 + b.WriteString(regexp.QuoteMeta(e.first)) + } + b.WriteByte('(') // $2 + runeClassToRegexp(b, e.allow, e.named || e.vars[0].explode) + if len(e.vars) > 1 || e.vars[0].explode { + max := len(e.vars) - 1 + for i := 0; i < len(e.vars); i++ { + if e.vars[i].explode { + max = -1 + break + } + } + + b.WriteString("(?:") // $3 + b.WriteString(regexp.QuoteMeta(e.sep)) + runeClassToRegexp(b, e.allow, e.named || max < 0) + b.WriteByte(')') // $3 + if max > 0 { + b.WriteString("{0,") + b.WriteString(strconv.Itoa(max)) + b.WriteByte('}') + } else { + b.WriteByte('*') + } + } + b.WriteByte(')') // $2 + if e.first != "" { + b.WriteByte(')') // $1 + } + b.WriteByte('?') +} + +func runeClassToRegexp(b *strings.Builder, class runeClass, named bool) { + b.WriteString("(?:(?:[") + if class&runeClassR == 0 { + b.WriteString(`\x2c`) + if named { + b.WriteString(`\x3d`) + } + } + if class&runeClassU == runeClassU { + b.WriteString(reUnreserved) + } + if class&runeClassR == runeClassR { + b.WriteString(reReserved) + } + b.WriteString("]") + b.WriteString("|%[[:xdigit:]][[:xdigit:]]") + b.WriteString(")*)") +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/machine.go b/vendor/github.com/yosida95/uritemplate/v3/machine.go new file mode 100644 index 
000000000..7b1d0b518 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/machine.go @@ -0,0 +1,23 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +// threadList implements https://research.swtch.com/sparse. +type threadList struct { + dense []threadEntry + sparse []uint32 +} + +type threadEntry struct { + pc uint32 + t *thread +} + +type thread struct { + op *progOp + cap map[string][]int +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/match.go b/vendor/github.com/yosida95/uritemplate/v3/match.go new file mode 100644 index 000000000..02fe6385a --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/match.go @@ -0,0 +1,213 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. 
+ +package uritemplate + +import ( + "bytes" + "unicode" + "unicode/utf8" +) + +type matcher struct { + prog *prog + + list1 threadList + list2 threadList + matched bool + cap map[string][]int + + input string +} + +func (m *matcher) at(pos int) (rune, int, bool) { + if l := len(m.input); pos < l { + c := m.input[pos] + if c < utf8.RuneSelf { + return rune(c), 1, pos+1 < l + } + r, size := utf8.DecodeRuneInString(m.input[pos:]) + return r, size, pos+size < l + } + return -1, 0, false +} + +func (m *matcher) add(list *threadList, pc uint32, pos int, next bool, cap map[string][]int) { + if i := list.sparse[pc]; i < uint32(len(list.dense)) && list.dense[i].pc == pc { + return + } + + n := len(list.dense) + list.dense = list.dense[:n+1] + list.sparse[pc] = uint32(n) + + e := &list.dense[n] + e.pc = pc + e.t = nil + + op := &m.prog.op[pc] + switch op.code { + default: + panic("unhandled opcode") + case opRune, opRuneClass, opEnd: + e.t = &thread{ + op: &m.prog.op[pc], + cap: make(map[string][]int, len(m.cap)), + } + for k, v := range cap { + e.t.cap[k] = make([]int, len(v)) + copy(e.t.cap[k], v) + } + case opLineBegin: + if pos == 0 { + m.add(list, pc+1, pos, next, cap) + } + case opLineEnd: + if !next { + m.add(list, pc+1, pos, next, cap) + } + case opCapStart, opCapEnd: + ocap := make(map[string][]int, len(m.cap)) + for k, v := range cap { + ocap[k] = make([]int, len(v)) + copy(ocap[k], v) + } + ocap[op.name] = append(ocap[op.name], pos) + m.add(list, pc+1, pos, next, ocap) + case opSplit: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmp: + m.add(list, op.i, pos, next, cap) + case opJmpIfNotDefined: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmpIfNotFirst: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmpIfNotEmpty: + m.add(list, op.i, pos, next, cap) + m.add(list, pc+1, pos, next, cap) + case opNoop: + m.add(list, pc+1, pos, next, cap) + } +} + +func (m 
*matcher) step(clist *threadList, nlist *threadList, r rune, pos int, nextPos int, next bool) { + debug.Printf("===== %q =====", string(r)) + for i := 0; i < len(clist.dense); i++ { + e := clist.dense[i] + if debug { + var buf bytes.Buffer + dumpProg(&buf, m.prog, e.pc) + debug.Printf("\n%s", buf.String()) + } + if e.t == nil { + continue + } + + t := e.t + op := t.op + switch op.code { + default: + panic("unhandled opcode") + case opRune: + if op.r == r { + m.add(nlist, e.pc+1, nextPos, next, t.cap) + } + case opRuneClass: + ret := false + if !ret && op.rc&runeClassU == runeClassU { + ret = ret || unicode.Is(rangeUnreserved, r) + } + if !ret && op.rc&runeClassR == runeClassR { + ret = ret || unicode.Is(rangeReserved, r) + } + if !ret && op.rc&runeClassPctE == runeClassPctE { + ret = ret || unicode.Is(unicode.ASCII_Hex_Digit, r) + } + if ret { + m.add(nlist, e.pc+1, nextPos, next, t.cap) + } + case opEnd: + m.matched = true + for k, v := range t.cap { + m.cap[k] = make([]int, len(v)) + copy(m.cap[k], v) + } + clist.dense = clist.dense[:0] + } + } + clist.dense = clist.dense[:0] +} + +func (m *matcher) match() bool { + pos := 0 + clist, nlist := &m.list1, &m.list2 + for { + if len(clist.dense) == 0 && m.matched { + break + } + r, width, next := m.at(pos) + if !m.matched { + m.add(clist, 0, pos, next, m.cap) + } + m.step(clist, nlist, r, pos, pos+width, next) + + if width < 1 { + break + } + pos += width + + clist, nlist = nlist, clist + } + return m.matched +} + +func (tmpl *Template) Match(expansion string) Values { + tmpl.mu.Lock() + if tmpl.prog == nil { + c := compiler{} + c.init() + c.compile(tmpl) + tmpl.prog = c.prog + } + prog := tmpl.prog + tmpl.mu.Unlock() + + n := len(prog.op) + m := matcher{ + prog: prog, + list1: threadList{ + dense: make([]threadEntry, 0, n), + sparse: make([]uint32, n), + }, + list2: threadList{ + dense: make([]threadEntry, 0, n), + sparse: make([]uint32, n), + }, + cap: make(map[string][]int, prog.numCap), + input: expansion, + } + 
if !m.match() { + return nil + } + + match := make(Values, len(m.cap)) + for name, indices := range m.cap { + v := Value{V: make([]string, len(indices)/2)} + for i := range v.V { + v.V[i] = pctDecode(expansion[indices[2*i]:indices[2*i+1]]) + } + if len(v.V) == 1 { + v.T = ValueTypeString + } else { + v.T = ValueTypeList + } + match[name] = v + } + return match +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/parse.go b/vendor/github.com/yosida95/uritemplate/v3/parse.go new file mode 100644 index 000000000..fd38a682f --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/parse.go @@ -0,0 +1,277 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" + "unicode" + "unicode/utf8" +) + +type parseOp int + +const ( + parseOpSimple parseOp = iota + parseOpPlus + parseOpCrosshatch + parseOpDot + parseOpSlash + parseOpSemicolon + parseOpQuestion + parseOpAmpersand +) + +var ( + rangeVarchar = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x0030, Hi: 0x0039, Stride: 1}, // '0' - '9' + {Lo: 0x0041, Hi: 0x005A, Stride: 1}, // 'A' - 'Z' + {Lo: 0x005F, Hi: 0x005F, Stride: 1}, // '_' + {Lo: 0x0061, Hi: 0x007A, Stride: 1}, // 'a' - 'z' + }, + LatinOffset: 4, + } + rangeLiterals = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x0021, Hi: 0x0021, Stride: 1}, // '!' + {Lo: 0x0023, Hi: 0x0024, Stride: 1}, // '#' - '$' + {Lo: 0x0026, Hi: 0x003B, Stride: 1}, // '&' ''' '(' - ';'. '''/27 used to be excluded but an errata is in the review process https://www.rfc-editor.org/errata/eid6937 + {Lo: 0x003D, Hi: 0x003D, Stride: 1}, // '=' + {Lo: 0x003F, Hi: 0x005B, Stride: 1}, // '?' 
- '[' + {Lo: 0x005D, Hi: 0x005D, Stride: 1}, // ']' + {Lo: 0x005F, Hi: 0x005F, Stride: 1}, // '_' + {Lo: 0x0061, Hi: 0x007A, Stride: 1}, // 'a' - 'z' + {Lo: 0x007E, Hi: 0x007E, Stride: 1}, // '~' + {Lo: 0x00A0, Hi: 0xD7FF, Stride: 1}, // ucschar + {Lo: 0xE000, Hi: 0xF8FF, Stride: 1}, // iprivate + {Lo: 0xF900, Hi: 0xFDCF, Stride: 1}, // ucschar + {Lo: 0xFDF0, Hi: 0xFFEF, Stride: 1}, // ucschar + }, + R32: []unicode.Range32{ + {Lo: 0x00010000, Hi: 0x0001FFFD, Stride: 1}, // ucschar + {Lo: 0x00020000, Hi: 0x0002FFFD, Stride: 1}, // ucschar + {Lo: 0x00030000, Hi: 0x0003FFFD, Stride: 1}, // ucschar + {Lo: 0x00040000, Hi: 0x0004FFFD, Stride: 1}, // ucschar + {Lo: 0x00050000, Hi: 0x0005FFFD, Stride: 1}, // ucschar + {Lo: 0x00060000, Hi: 0x0006FFFD, Stride: 1}, // ucschar + {Lo: 0x00070000, Hi: 0x0007FFFD, Stride: 1}, // ucschar + {Lo: 0x00080000, Hi: 0x0008FFFD, Stride: 1}, // ucschar + {Lo: 0x00090000, Hi: 0x0009FFFD, Stride: 1}, // ucschar + {Lo: 0x000A0000, Hi: 0x000AFFFD, Stride: 1}, // ucschar + {Lo: 0x000B0000, Hi: 0x000BFFFD, Stride: 1}, // ucschar + {Lo: 0x000C0000, Hi: 0x000CFFFD, Stride: 1}, // ucschar + {Lo: 0x000D0000, Hi: 0x000DFFFD, Stride: 1}, // ucschar + {Lo: 0x000E1000, Hi: 0x000EFFFD, Stride: 1}, // ucschar + {Lo: 0x000F0000, Hi: 0x000FFFFD, Stride: 1}, // iprivate + {Lo: 0x00100000, Hi: 0x0010FFFD, Stride: 1}, // iprivate + }, + LatinOffset: 10, + } +) + +type parser struct { + r string + start int + stop int + state parseState +} + +func (p *parser) errorf(i rune, format string, a ...interface{}) error { + return fmt.Errorf("%s: %s%s", fmt.Sprintf(format, a...), p.r[0:p.stop], string(i)) +} + +func (p *parser) rune() (rune, int) { + r, size := utf8.DecodeRuneInString(p.r[p.stop:]) + if r != utf8.RuneError { + p.stop += size + } + return r, size +} + +func (p *parser) unread(r rune) { + p.stop -= utf8.RuneLen(r) +} + +type parseState int + +const ( + parseStateDefault = parseState(iota) + parseStateOperator + parseStateVarList + parseStateVarName + 
parseStatePrefix +) + +func (p *parser) setState(state parseState) { + p.state = state + p.start = p.stop +} + +func (p *parser) parseURITemplate() (*Template, error) { + tmpl := Template{ + raw: p.r, + exprs: []template{}, + } + + var exp *expression + for { + r, size := p.rune() + if r == utf8.RuneError { + if size == 0 { + if p.state != parseStateDefault { + return nil, p.errorf('_', "incomplete expression") + } + if p.start < p.stop { + tmpl.exprs = append(tmpl.exprs, literals(p.r[p.start:p.stop])) + } + return &tmpl, nil + } + return nil, p.errorf('_', "invalid UTF-8 sequence") + } + + switch p.state { + case parseStateDefault: + switch r { + case '{': + if stop := p.stop - size; stop > p.start { + tmpl.exprs = append(tmpl.exprs, literals(p.r[p.start:stop])) + } + exp = &expression{} + tmpl.exprs = append(tmpl.exprs, exp) + p.setState(parseStateOperator) + case '%': + p.unread(r) + if err := p.consumeTriplet(); err != nil { + return nil, err + } + default: + if !unicode.Is(rangeLiterals, r) { + p.unread(r) + return nil, p.errorf('_', "unacceptable character (hint: use %%XX encoding)") + } + } + case parseStateOperator: + switch r { + default: + p.unread(r) + exp.op = parseOpSimple + case '+': + exp.op = parseOpPlus + case '#': + exp.op = parseOpCrosshatch + case '.': + exp.op = parseOpDot + case '/': + exp.op = parseOpSlash + case ';': + exp.op = parseOpSemicolon + case '?': + exp.op = parseOpQuestion + case '&': + exp.op = parseOpAmpersand + case '=', ',', '!', '@', '|': // op-reserved + return nil, p.errorf('|', "unimplemented operator (op-reserved)") + } + p.setState(parseStateVarName) + case parseStateVarList: + switch r { + case ',': + p.setState(parseStateVarName) + case '}': + exp.init() + p.setState(parseStateDefault) + default: + p.unread(r) + return nil, p.errorf('_', "unrecognized value modifier") + } + case parseStateVarName: + switch r { + case ':', '*': + name := p.r[p.start : p.stop-size] + if !isValidVarname(name) { + return nil, p.errorf('|', 
"unacceptable variable name") + } + explode := r == '*' + exp.vars = append(exp.vars, varspec{ + name: name, + explode: explode, + }) + if explode { + p.setState(parseStateVarList) + } else { + p.setState(parseStatePrefix) + } + case ',', '}': + p.unread(r) + name := p.r[p.start:p.stop] + if !isValidVarname(name) { + return nil, p.errorf('|', "unacceptable variable name") + } + exp.vars = append(exp.vars, varspec{ + name: name, + }) + p.setState(parseStateVarList) + case '%': + p.unread(r) + if err := p.consumeTriplet(); err != nil { + return nil, err + } + case '.': + if dot := p.stop - size; dot == p.start || p.r[dot-1] == '.' { + return nil, p.errorf('|', "unacceptable variable name") + } + default: + if !unicode.Is(rangeVarchar, r) { + p.unread(r) + return nil, p.errorf('_', "unacceptable variable name") + } + } + case parseStatePrefix: + spec := &(exp.vars[len(exp.vars)-1]) + switch { + case '0' <= r && r <= '9': + spec.maxlen *= 10 + spec.maxlen += int(r - '0') + if spec.maxlen == 0 || spec.maxlen > 9999 { + return nil, p.errorf('|', "max-length must be (0, 9999]") + } + default: + p.unread(r) + if spec.maxlen == 0 { + return nil, p.errorf('_', "max-length must be (0, 9999]") + } + p.setState(parseStateVarList) + } + default: + p.unread(r) + panic(p.errorf('_', "unhandled parseState(%d)", p.state)) + } + } +} + +func isValidVarname(name string) bool { + if l := len(name); l == 0 || name[0] == '.' || name[l-1] == '.' { + return false + } + for i := 1; i < len(name)-1; i++ { + switch c := name[i]; c { + case '.': + if name[i-1] == '.' 
{ + return false + } + } + } + return true +} + +func (p *parser) consumeTriplet() error { + if len(p.r)-p.stop < 3 || p.r[p.stop] != '%' || !ishex(p.r[p.stop+1]) || !ishex(p.r[p.stop+2]) { + return p.errorf('_', "incomplete pct-encodeed") + } + p.stop += 3 + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/prog.go b/vendor/github.com/yosida95/uritemplate/v3/prog.go new file mode 100644 index 000000000..97af4f0ea --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/prog.go @@ -0,0 +1,130 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "bytes" + "strconv" +) + +type progOpcode uint16 + +const ( + // match + opRune progOpcode = iota + opRuneClass + opLineBegin + opLineEnd + // capture + opCapStart + opCapEnd + // stack + opSplit + opJmp + opJmpIfNotDefined + opJmpIfNotEmpty + opJmpIfNotFirst + // result + opEnd + // fake + opNoop + opcodeMax +) + +var opcodeNames = []string{ + // match + "opRune", + "opRuneClass", + "opLineBegin", + "opLineEnd", + // capture + "opCapStart", + "opCapEnd", + // stack + "opSplit", + "opJmp", + "opJmpIfNotDefined", + "opJmpIfNotEmpty", + "opJmpIfNotFirst", + // result + "opEnd", +} + +func (code progOpcode) String() string { + if code >= opcodeMax { + return "" + } + return opcodeNames[code] +} + +type progOp struct { + code progOpcode + r rune + rc runeClass + i uint32 + + name string +} + +func dumpProgOp(b *bytes.Buffer, op *progOp) { + b.WriteString(op.code.String()) + switch op.code { + case opRune: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(string(op.r))) + b.WriteString(")") + case opRuneClass: + b.WriteString("(") + b.WriteString(op.rc.String()) + b.WriteString(")") + case opCapStart, opCapEnd: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(op.name)) + 
b.WriteString(")") + case opSplit: + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + case opJmp, opJmpIfNotFirst: + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + case opJmpIfNotDefined, opJmpIfNotEmpty: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(op.name)) + b.WriteString(")") + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + } +} + +type prog struct { + op []progOp + numCap int +} + +func dumpProg(b *bytes.Buffer, prog *prog, pc uint32) { + for i := range prog.op { + op := prog.op[i] + + pos := strconv.Itoa(i) + if uint32(i) == pc { + pos = "*" + pos + } + b.WriteString(" "[len(pos):]) + b.WriteString(pos) + + b.WriteByte('\t') + dumpProgOp(b, &op) + + b.WriteByte('\n') + } +} + +func (p *prog) String() string { + b := bytes.Buffer{} + dumpProg(&b, p, 0) + return b.String() +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go b/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go new file mode 100644 index 000000000..dbd267375 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go @@ -0,0 +1,116 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "log" + "regexp" + "strings" + "sync" +) + +var ( + debug = debugT(false) +) + +type debugT bool + +func (t debugT) Printf(format string, v ...interface{}) { + if t { + log.Printf(format, v...) + } +} + +// Template represents a URI Template. +type Template struct { + raw string + exprs []template + + // protects the rest of fields + mu sync.Mutex + varnames []string + re *regexp.Regexp + prog *prog +} + +// New parses and constructs a new Template instance based on the template. +// New returns an error if the template cannot be recognized. 
+func New(template string) (*Template, error) { + return (&parser{r: template}).parseURITemplate() +} + +// MustNew panics if the template cannot be recognized. +func MustNew(template string) *Template { + ret, err := New(template) + if err != nil { + panic(err) + } + return ret +} + +// Raw returns a raw URI template passed to New in string. +func (t *Template) Raw() string { + return t.raw +} + +// Varnames returns variable names used in the template. +func (t *Template) Varnames() []string { + t.mu.Lock() + defer t.mu.Unlock() + if t.varnames != nil { + return t.varnames + } + + reg := map[string]struct{}{} + t.varnames = []string{} + for i := range t.exprs { + expr, ok := t.exprs[i].(*expression) + if !ok { + continue + } + for _, spec := range expr.vars { + if _, ok := reg[spec.name]; ok { + continue + } + reg[spec.name] = struct{}{} + t.varnames = append(t.varnames, spec.name) + } + } + + return t.varnames +} + +// Expand returns a URI reference corresponding to the template expanded using the passed variables. +func (t *Template) Expand(vars Values) (string, error) { + var w strings.Builder + for i := range t.exprs { + expr := t.exprs[i] + if err := expr.expand(&w, vars); err != nil { + return w.String(), err + } + } + return w.String(), nil +} + +// Regexp converts the template to regexp and returns compiled *regexp.Regexp. +func (t *Template) Regexp() *regexp.Regexp { + t.mu.Lock() + defer t.mu.Unlock() + if t.re != nil { + return t.re + } + + var b strings.Builder + b.WriteByte('^') + for _, expr := range t.exprs { + expr.regexp(&b) + } + b.WriteByte('$') + t.re = regexp.MustCompile(b.String()) + + return t.re +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/value.go b/vendor/github.com/yosida95/uritemplate/v3/value.go new file mode 100644 index 000000000..0550eabdb --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/value.go @@ -0,0 +1,216 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. 
+// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import "strings" + +// A varname containing pct-encoded characters is not the same variable as +// a varname with those same characters decoded. +// +// -- https://tools.ietf.org/html/rfc6570#section-2.3 +type Values map[string]Value + +func (v Values) Set(name string, value Value) { + v[name] = value +} + +func (v Values) Get(name string) Value { + if v == nil { + return Value{} + } + return v[name] +} + +type ValueType uint8 + +const ( + ValueTypeString = iota + ValueTypeList + ValueTypeKV + valueTypeLast +) + +var valueTypeNames = []string{ + "String", + "List", + "KV", +} + +func (vt ValueType) String() string { + if vt < valueTypeLast { + return valueTypeNames[vt] + } + return "" +} + +type Value struct { + T ValueType + V []string +} + +func (v Value) String() string { + if v.Valid() && v.T == ValueTypeString { + return v.V[0] + } + return "" +} + +func (v Value) List() []string { + if v.Valid() && v.T == ValueTypeList { + return v.V + } + return nil +} + +func (v Value) KV() []string { + if v.Valid() && v.T == ValueTypeKV { + return v.V + } + return nil +} + +func (v Value) Valid() bool { + switch v.T { + default: + return false + case ValueTypeString: + return len(v.V) > 0 + case ValueTypeList: + return len(v.V) > 0 + case ValueTypeKV: + return len(v.V) > 0 && len(v.V)%2 == 0 + } +} + +func (v Value) expand(w *strings.Builder, spec varspec, exp *expression) error { + switch v.T { + case ValueTypeString: + val := v.V[0] + var maxlen int + if max := len(val); spec.maxlen < 1 || spec.maxlen > max { + maxlen = max + } else { + maxlen = spec.maxlen + } + + if exp.named { + w.WriteString(spec.name) + if val == "" { + w.WriteString(exp.ifemp) + return nil + } + w.WriteByte('=') + } + return exp.escape(w, val[:maxlen]) + case ValueTypeList: + var sep string + if 
spec.explode { + sep = exp.sep + } else { + sep = "," + } + + var pre string + var preifemp string + if spec.explode && exp.named { + pre = spec.name + "=" + preifemp = spec.name + exp.ifemp + } + + if !spec.explode && exp.named { + w.WriteString(spec.name) + w.WriteByte('=') + } + for i := range v.V { + val := v.V[i] + if i > 0 { + w.WriteString(sep) + } + if val == "" { + w.WriteString(preifemp) + continue + } + w.WriteString(pre) + + if err := exp.escape(w, val); err != nil { + return err + } + } + case ValueTypeKV: + var sep string + var kvsep string + if spec.explode { + sep = exp.sep + kvsep = "=" + } else { + sep = "," + kvsep = "," + } + + var ifemp string + var kescape escapeFunc + if spec.explode && exp.named { + ifemp = exp.ifemp + kescape = escapeLiteral + } else { + ifemp = "," + kescape = exp.escape + } + + if !spec.explode && exp.named { + w.WriteString(spec.name) + w.WriteByte('=') + } + + for i := 0; i < len(v.V); i += 2 { + if i > 0 { + w.WriteString(sep) + } + if err := kescape(w, v.V[i]); err != nil { + return err + } + if v.V[i+1] == "" { + w.WriteString(ifemp) + continue + } + w.WriteString(kvsep) + + if err := exp.escape(w, v.V[i+1]); err != nil { + return err + } + } + } + return nil +} + +// String returns Value that represents string. +func String(v string) Value { + return Value{ + T: ValueTypeString, + V: []string{v}, + } +} + +// List returns Value that represents list. +func List(v ...string) Value { + return Value{ + T: ValueTypeList, + V: v, + } +} + +// KV returns Value that represents associative list. +// KV panics if len(kv) is not even. 
+func KV(kv ...string) Value { + if len(kv)%2 != 0 { + panic("uritemplate.go: count of the kv must be even number") + } + return Value{ + T: ValueTypeKV, + V: kv, + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index cb35bbe31..5a9a01d33 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -153,9 +153,15 @@ github.com/aws/smithy-go/tracing github.com/aws/smithy-go/transport/http github.com/aws/smithy-go/transport/http/internal/io github.com/aws/smithy-go/waiter +# github.com/bahlo/generic-list-go v0.2.0 +## explicit; go 1.18 +github.com/bahlo/generic-list-go # github.com/benbjohnson/clock v1.3.5 ## explicit; go 1.15 github.com/benbjohnson/clock +# github.com/buger/jsonparser v1.1.1 +## explicit; go 1.13 +github.com/buger/jsonparser # github.com/cenkalti/backoff/v4 v4.3.0 ## explicit; go 1.18 github.com/cenkalti/backoff/v4 @@ -266,6 +272,9 @@ github.com/hashicorp/hcl/json/token # github.com/inconshreveable/mousetrap v1.1.0 ## explicit; go 1.18 github.com/inconshreveable/mousetrap +# github.com/invopop/jsonschema v0.13.0 +## explicit; go 1.18 +github.com/invopop/jsonschema # github.com/jellydator/ttlcache/v3 v3.3.0 ## explicit; go 1.18 github.com/jellydator/ttlcache/v3 @@ -285,6 +294,15 @@ github.com/lufia/plan9stats # github.com/magiconair/properties v1.8.9 ## explicit; go 1.19 github.com/magiconair/properties +# github.com/mailru/easyjson v0.7.7 +## explicit; go 1.12 +github.com/mailru/easyjson/buffer +github.com/mailru/easyjson/jwriter +# github.com/mark3labs/mcp-go v0.43.2 +## explicit; go 1.23.0 +github.com/mark3labs/mcp-go/mcp +github.com/mark3labs/mcp-go/server +github.com/mark3labs/mcp-go/util # github.com/mattn/go-isatty v0.0.20 ## explicit; go 1.15 github.com/mattn/go-isatty @@ -401,6 +419,12 @@ github.com/tklauser/go-sysconf # github.com/tklauser/numcpus v0.11.0 ## explicit; go 1.24.0 github.com/tklauser/numcpus +# github.com/wk8/go-ordered-map/v2 v2.1.8 +## explicit; go 1.18 +github.com/wk8/go-ordered-map/v2 +# 
github.com/yosida95/uritemplate/v3 v3.0.2 +## explicit; go 1.14 +github.com/yosida95/uritemplate/v3 # github.com/yusufpapurcu/wmi v1.2.4 ## explicit; go 1.16 github.com/yusufpapurcu/wmi From e50a1b01d520e0fbb6f6c395de971b7dce4fb5dc Mon Sep 17 00:00:00 2001 From: Russell Haering Date: Tue, 13 Jan 2026 17:12:09 -0800 Subject: [PATCH 2/2] Switch MCP SDK from mark3labs/mcp-go to modelcontextprotocol/go-sdk Replace the community MCP SDK with the official modelcontextprotocol/go-sdk. This provides: - Typed input/output structs with jsonschema tags for tool definitions - Cleaner handler signatures: func(ctx, req, Input) (Result, Output, error) - Official protocol support Updated all handlers to use typed structs instead of manual argument parsing. --- go.mod | 10 +- go.sum | 27 +- pkg/cli/commands.go | 2 +- pkg/mcp/handlers.go | 645 +++--- pkg/mcp/server.go | 349 +-- .../github.com/bahlo/generic-list-go/LICENSE | 27 - .../bahlo/generic-list-go/README.md | 5 - .../github.com/bahlo/generic-list-go/list.go | 235 -- vendor/github.com/buger/jsonparser/.gitignore | 12 - .../github.com/buger/jsonparser/.travis.yml | 11 - vendor/github.com/buger/jsonparser/Dockerfile | 12 - vendor/github.com/buger/jsonparser/Makefile | 36 - vendor/github.com/buger/jsonparser/README.md | 365 --- vendor/github.com/buger/jsonparser/bytes.go | 47 - .../github.com/buger/jsonparser/bytes_safe.go | 25 - .../buger/jsonparser/bytes_unsafe.go | 44 - vendor/github.com/buger/jsonparser/escape.go | 173 -- vendor/github.com/buger/jsonparser/fuzz.go | 117 - .../buger/jsonparser/oss-fuzz-build.sh | 47 - vendor/github.com/buger/jsonparser/parser.go | 1283 ----------- .../mcp-go => google/jsonschema-go}/LICENSE | 2 +- .../jsonschema-go/jsonschema/annotations.go | 76 + .../google/jsonschema-go/jsonschema/doc.go | 101 + .../google/jsonschema-go/jsonschema/infer.go | 248 ++ .../jsonschema-go/jsonschema/json_pointer.go | 146 ++ .../jsonschema-go/jsonschema/resolve.go | 548 +++++ 
.../google/jsonschema-go/jsonschema/schema.go | 436 ++++ .../google/jsonschema-go/jsonschema/util.go | 463 ++++ .../jsonschema-go/jsonschema/validate.go | 789 +++++++ .../github.com/invopop/jsonschema/.gitignore | 2 - .../invopop/jsonschema/.golangci.yml | 69 - vendor/github.com/invopop/jsonschema/COPYING | 19 - .../github.com/invopop/jsonschema/README.md | 374 --- vendor/github.com/invopop/jsonschema/id.go | 76 - .../github.com/invopop/jsonschema/reflect.go | 1148 ---------- .../invopop/jsonschema/reflect_comments.go | 146 -- .../github.com/invopop/jsonschema/schema.go | 94 - vendor/github.com/invopop/jsonschema/utils.go | 26 - vendor/github.com/mailru/easyjson/LICENSE | 7 - .../github.com/mailru/easyjson/buffer/pool.go | 278 --- .../mailru/easyjson/jwriter/writer.go | 405 ---- .../github.com/mark3labs/mcp-go/mcp/consts.go | 9 - .../github.com/mark3labs/mcp-go/mcp/errors.go | 85 - .../mark3labs/mcp-go/mcp/prompts.go | 176 -- .../mark3labs/mcp-go/mcp/resources.go | 99 - .../github.com/mark3labs/mcp-go/mcp/tools.go | 1331 ----------- .../mark3labs/mcp-go/mcp/typed_tools.go | 42 - .../github.com/mark3labs/mcp-go/mcp/types.go | 1252 ---------- .../github.com/mark3labs/mcp-go/mcp/utils.go | 979 -------- .../mark3labs/mcp-go/server/constants.go | 7 - .../github.com/mark3labs/mcp-go/server/ctx.go | 8 - .../mark3labs/mcp-go/server/elicitation.go | 32 - .../mark3labs/mcp-go/server/errors.go | 36 - .../mark3labs/mcp-go/server/hooks.go | 532 ----- .../mcp-go/server/http_transport_options.go | 11 - .../mcp-go/server/inprocess_session.go | 165 -- .../mcp-go/server/request_handler.go | 339 --- .../mark3labs/mcp-go/server/roots.go | 32 - .../mark3labs/mcp-go/server/sampling.go | 61 - .../mark3labs/mcp-go/server/server.go | 1337 ----------- .../mark3labs/mcp-go/server/session.go | 770 ------- .../github.com/mark3labs/mcp-go/server/sse.go | 797 ------- .../mark3labs/mcp-go/server/stdio.go | 877 ------- .../mcp-go/server/streamable_http.go | 1434 ------------ 
.../mark3labs/mcp-go/util/logger.go | 33 - .../go-sdk}/LICENSE | 2 +- .../modelcontextprotocol/go-sdk/auth/auth.go | 168 ++ .../go-sdk/auth/client.go | 123 + .../go-sdk/internal/jsonrpc2/conn.go | 841 +++++++ .../go-sdk/internal/jsonrpc2/frame.go | 208 ++ .../go-sdk/internal/jsonrpc2/jsonrpc2.go | 121 + .../go-sdk/internal/jsonrpc2/messages.go | 212 ++ .../go-sdk/internal/jsonrpc2/net.go | 138 ++ .../go-sdk/internal/jsonrpc2/serve.go | 330 +++ .../go-sdk/internal/jsonrpc2/wire.go | 97 + .../go-sdk/internal/util/util.go | 44 + .../go-sdk/internal/xcontext/xcontext.go | 23 + .../go-sdk/jsonrpc/jsonrpc.go | 56 + .../modelcontextprotocol/go-sdk/mcp/client.go | 1075 +++++++++ .../modelcontextprotocol/go-sdk/mcp/cmd.go | 108 + .../go-sdk/mcp/content.go | 289 +++ .../modelcontextprotocol/go-sdk/mcp/event.go | 429 ++++ .../go-sdk/mcp/features.go | 114 + .../go-sdk/mcp/logging.go | 207 ++ .../modelcontextprotocol/go-sdk/mcp/mcp.go | 88 + .../modelcontextprotocol/go-sdk/mcp/prompt.go | 17 + .../go-sdk/mcp/protocol.go | 1357 +++++++++++ .../go-sdk/mcp/requests.go | 38 + .../go-sdk/mcp/resource.go | 164 ++ .../go-sdk/mcp/resource_go124.go | 29 + .../go-sdk/mcp/resource_pre_go124.go | 25 + .../modelcontextprotocol/go-sdk/mcp/server.go | 1497 ++++++++++++ .../go-sdk/mcp/session.go | 29 + .../modelcontextprotocol/go-sdk/mcp/shared.go | 610 +++++ .../modelcontextprotocol/go-sdk/mcp/sse.go | 479 ++++ .../go-sdk/mcp/streamable.go | 2040 +++++++++++++++++ .../go-sdk/mcp/streamable_client.go | 226 ++ .../go-sdk/mcp/streamable_server.go | 160 ++ .../modelcontextprotocol/go-sdk/mcp/tool.go | 139 ++ .../go-sdk/mcp/transport.go | 655 ++++++ .../modelcontextprotocol/go-sdk/mcp/util.go | 43 + .../go-sdk/oauthex/auth_meta.go | 187 ++ .../go-sdk/oauthex/dcr.go | 261 +++ .../go-sdk/oauthex/oauth2.go | 91 + .../go-sdk/oauthex/oauthex.go | 92 + .../go-sdk/oauthex/resource_meta.go | 281 +++ .../wk8/go-ordered-map/v2/.gitignore | 1 - .../wk8/go-ordered-map/v2/.golangci.yml | 80 - 
.../wk8/go-ordered-map/v2/CHANGELOG.md | 38 - .../github.com/wk8/go-ordered-map/v2/LICENSE | 201 -- .../github.com/wk8/go-ordered-map/v2/Makefile | 32 - .../wk8/go-ordered-map/v2/README.md | 154 -- .../github.com/wk8/go-ordered-map/v2/json.go | 182 -- .../wk8/go-ordered-map/v2/orderedmap.go | 296 --- .../github.com/wk8/go-ordered-map/v2/yaml.go | 71 - .../clientcredentials/clientcredentials.go | 10 +- vendor/golang.org/x/oauth2/internal/doc.go | 2 +- vendor/golang.org/x/oauth2/internal/oauth2.go | 2 +- vendor/golang.org/x/oauth2/internal/token.go | 50 +- .../golang.org/x/oauth2/internal/transport.go | 4 +- vendor/golang.org/x/oauth2/jws/jws.go | 40 +- vendor/golang.org/x/oauth2/jwt/jwt.go | 13 +- vendor/golang.org/x/oauth2/oauth2.go | 63 +- vendor/golang.org/x/oauth2/pkce.go | 15 +- vendor/golang.org/x/oauth2/token.go | 17 +- vendor/golang.org/x/oauth2/transport.go | 24 +- vendor/modules.txt | 37 +- 127 files changed, 16398 insertions(+), 17396 deletions(-) delete mode 100644 vendor/github.com/bahlo/generic-list-go/LICENSE delete mode 100644 vendor/github.com/bahlo/generic-list-go/README.md delete mode 100644 vendor/github.com/bahlo/generic-list-go/list.go delete mode 100644 vendor/github.com/buger/jsonparser/.gitignore delete mode 100644 vendor/github.com/buger/jsonparser/.travis.yml delete mode 100644 vendor/github.com/buger/jsonparser/Dockerfile delete mode 100644 vendor/github.com/buger/jsonparser/Makefile delete mode 100644 vendor/github.com/buger/jsonparser/README.md delete mode 100644 vendor/github.com/buger/jsonparser/bytes.go delete mode 100644 vendor/github.com/buger/jsonparser/bytes_safe.go delete mode 100644 vendor/github.com/buger/jsonparser/bytes_unsafe.go delete mode 100644 vendor/github.com/buger/jsonparser/escape.go delete mode 100644 vendor/github.com/buger/jsonparser/fuzz.go delete mode 100644 vendor/github.com/buger/jsonparser/oss-fuzz-build.sh delete mode 100644 vendor/github.com/buger/jsonparser/parser.go rename 
vendor/github.com/{mark3labs/mcp-go => google/jsonschema-go}/LICENSE (95%) create mode 100644 vendor/github.com/google/jsonschema-go/jsonschema/annotations.go create mode 100644 vendor/github.com/google/jsonschema-go/jsonschema/doc.go create mode 100644 vendor/github.com/google/jsonschema-go/jsonschema/infer.go create mode 100644 vendor/github.com/google/jsonschema-go/jsonschema/json_pointer.go create mode 100644 vendor/github.com/google/jsonschema-go/jsonschema/resolve.go create mode 100644 vendor/github.com/google/jsonschema-go/jsonschema/schema.go create mode 100644 vendor/github.com/google/jsonschema-go/jsonschema/util.go create mode 100644 vendor/github.com/google/jsonschema-go/jsonschema/validate.go delete mode 100644 vendor/github.com/invopop/jsonschema/.gitignore delete mode 100644 vendor/github.com/invopop/jsonschema/.golangci.yml delete mode 100644 vendor/github.com/invopop/jsonschema/COPYING delete mode 100644 vendor/github.com/invopop/jsonschema/README.md delete mode 100644 vendor/github.com/invopop/jsonschema/id.go delete mode 100644 vendor/github.com/invopop/jsonschema/reflect.go delete mode 100644 vendor/github.com/invopop/jsonschema/reflect_comments.go delete mode 100644 vendor/github.com/invopop/jsonschema/schema.go delete mode 100644 vendor/github.com/invopop/jsonschema/utils.go delete mode 100644 vendor/github.com/mailru/easyjson/LICENSE delete mode 100644 vendor/github.com/mailru/easyjson/buffer/pool.go delete mode 100644 vendor/github.com/mailru/easyjson/jwriter/writer.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/consts.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/errors.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/prompts.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/resources.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/tools.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go delete mode 100644 
vendor/github.com/mark3labs/mcp-go/mcp/types.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/mcp/utils.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/constants.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/ctx.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/elicitation.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/errors.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/hooks.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/request_handler.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/roots.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/sampling.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/server.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/session.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/sse.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/stdio.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/server/streamable_http.go delete mode 100644 vendor/github.com/mark3labs/mcp-go/util/logger.go rename vendor/github.com/{buger/jsonparser => modelcontextprotocol/go-sdk}/LICENSE (96%) create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/jsonrpc2.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go create mode 100644 
vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/net.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/serve.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/internal/util/util.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/internal/xcontext/xcontext.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/cmd.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/features.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/mcp.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/prompt.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/session.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go create mode 100644 
vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_server.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/oauthex/dcr.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go create mode 100644 vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go delete mode 100644 vendor/github.com/wk8/go-ordered-map/v2/.gitignore delete mode 100644 vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml delete mode 100644 vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md delete mode 100644 vendor/github.com/wk8/go-ordered-map/v2/LICENSE delete mode 100644 vendor/github.com/wk8/go-ordered-map/v2/Makefile delete mode 100644 vendor/github.com/wk8/go-ordered-map/v2/README.md delete mode 100644 vendor/github.com/wk8/go-ordered-map/v2/json.go delete mode 100644 vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go delete mode 100644 vendor/github.com/wk8/go-ordered-map/v2/yaml.go diff --git a/go.mod b/go.mod index 2e200c647..9c6edea1c 100644 --- a/go.mod +++ b/go.mod @@ -26,9 +26,9 @@ require ( github.com/google/uuid v1.6.0 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/klauspost/compress v1.18.0 - github.com/mark3labs/mcp-go v0.43.2 github.com/maypok86/otter/v2 v2.2.1 github.com/mitchellh/mapstructure v1.5.0 + github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/pquerna/xjwt v0.3.0 github.com/pquerna/xjwt/xkeyset v0.0.0-20241217022915-10fc997b2a9f github.com/segmentio/ksuid v1.0.4 @@ -52,7 +52,7 @@ 
require ( go.uber.org/zap v1.27.0 golang.org/x/crypto v0.34.0 golang.org/x/net v0.35.0 - golang.org/x/oauth2 v0.26.0 + golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.11.0 golang.org/x/sys v0.38.0 golang.org/x/term v0.29.0 @@ -77,9 +77,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.10 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.12 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.11 // indirect - github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/benbjohnson/clock v1.3.5 // indirect - github.com/buger/jsonparser v1.1.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -89,14 +87,13 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/golang/protobuf v1.5.4 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/invopop/jsonschema v0.13.0 // indirect github.com/jellydator/ttlcache/v3 v3.3.0 // indirect github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect github.com/magiconair/properties v1.8.9 // indirect - github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect @@ -113,7 +110,6 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect github.com/tklauser/go-sysconf v0.3.16 // indirect github.com/tklauser/numcpus v0.11.0 // indirect - github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect diff --git a/go.sum b/go.sum index 
1160f9dd6..ae97267be 100644 --- a/go.sum +++ b/go.sum @@ -52,13 +52,9 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.33.10 h1:g9d+TOsu3ac7SgmY2dUf1qMgu/u github.com/aws/aws-sdk-go-v2/service/sts v1.33.10/go.mod h1:WZfNmntu92HO44MVZAubQaz3qCuIdeOdog2sADfU6hU= github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= -github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= -github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= -github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= -github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -110,6 +106,8 @@ github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= 
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -121,6 +119,8 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -133,11 +133,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= -github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jellydator/ttlcache/v3 v3.3.0 h1:BdoC9cE81qXfrxeb9eoJi9dWrdhSuwXMAnHTbnBm4Wc= github.com/jellydator/ttlcache/v3 v3.3.0/go.mod h1:bj2/e0l4jRnQdrnSTaGTsh4GSXvMjQcy41i7th0GVGw= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod 
h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= @@ -156,10 +153,6 @@ github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 h1:7UMa6KCCMjZEMD github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/magiconair/properties v1.8.9 h1:nWcCbLq1N2v/cpNsy5WvQ37Fb+YElfq20WJ/a8RkpQM= github.com/magiconair/properties v1.8.9/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= -github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= -github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= @@ -168,6 +161,8 @@ github.com/maypok86/otter/v2 v2.2.1 h1:hnGssisMFkdisYcvQ8L019zpYQcdtPse+g0ps2i7c github.com/maypok86/otter/v2 v2.2.1/go.mod h1:1NKY9bY+kB5jwCXBJfE59u+zAwOt6C7ni1FTlFFMqVs= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= +github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= 
@@ -230,8 +225,6 @@ github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYI github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI= github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw= github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ= -github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= -github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -309,8 +302,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE= -golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -345,8 +338,8 @@ golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools 
v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE= -golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/cli/commands.go b/pkg/cli/commands.go index aceb553d2..40ed588c1 100644 --- a/pkg/cli/commands.go +++ b/pkg/cli/commands.go @@ -664,7 +664,7 @@ func MakeMCPServerCommand[T field.Configurable]( } l.Info("MCP server starting on stdio") - return mcpServer.Serve() + return mcpServer.Serve(runCtx) } } diff --git a/pkg/mcp/handlers.go b/pkg/mcp/handlers.go index 37c2c6525..1913b6744 100644 --- a/pkg/mcp/handlers.go +++ b/pkg/mcp/handlers.go @@ -2,423 +2,378 @@ package mcp import ( "context" - "encoding/json" "fmt" - "github.com/mark3labs/mcp-go/mcp" + "github.com/modelcontextprotocol/go-sdk/mcp" v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" ) const defaultPageSize = 50 -// handleGetMetadata handles the get_metadata tool. -func (m *MCPServer) handleGetMetadata(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { +// Input/Output types for handlers. 
+ +type EmptyInput struct{} + +type PaginationInput struct { + PageSize int `json:"page_size,omitempty" jsonschema:"description=Number of items per page (default 50)"` + PageToken string `json:"page_token,omitempty" jsonschema:"description=Pagination token from previous response"` +} + +type ResourceInput struct { + ResourceType string `json:"resource_type" jsonschema:"required,description=The resource type (e.g. user or group)"` + ResourceID string `json:"resource_id" jsonschema:"required,description=The resource ID"` +} + +type ResourcePaginationInput struct { + ResourceType string `json:"resource_type" jsonschema:"required,description=The resource type"` + ResourceID string `json:"resource_id" jsonschema:"required,description=The resource ID"` + PageSize int `json:"page_size,omitempty" jsonschema:"description=Number of items per page (default 50)"` + PageToken string `json:"page_token,omitempty" jsonschema:"description=Pagination token from previous response"` +} + +type ListResourcesInput struct { + ResourceTypeID string `json:"resource_type_id" jsonschema:"required,description=The resource type ID to list (e.g. 
user or group)"` + ParentResourceType string `json:"parent_resource_type,omitempty" jsonschema:"description=Parent resource type (optional)"` + ParentResourceID string `json:"parent_resource_id,omitempty" jsonschema:"description=Parent resource ID (optional)"` + PageSize int `json:"page_size,omitempty" jsonschema:"description=Number of items per page (default 50)"` + PageToken string `json:"page_token,omitempty" jsonschema:"description=Pagination token from previous response"` +} + +type GrantInput struct { + EntitlementResourceType string `json:"entitlement_resource_type" jsonschema:"required,description=Resource type of the entitlement"` + EntitlementResourceID string `json:"entitlement_resource_id" jsonschema:"required,description=Resource ID of the entitlement"` + EntitlementID string `json:"entitlement_id" jsonschema:"required,description=The entitlement ID"` + PrincipalResourceType string `json:"principal_resource_type" jsonschema:"required,description=Resource type of the principal (e.g. 
user or group)"` + PrincipalResourceID string `json:"principal_resource_id" jsonschema:"required,description=Resource ID of the principal"` +} + +type RevokeInput struct { + GrantID string `json:"grant_id" jsonschema:"required,description=The grant ID to revoke"` + EntitlementResourceType string `json:"entitlement_resource_type" jsonschema:"required,description=Resource type of the entitlement"` + EntitlementResourceID string `json:"entitlement_resource_id" jsonschema:"required,description=Resource ID of the entitlement"` + EntitlementID string `json:"entitlement_id" jsonschema:"required,description=The entitlement ID"` + PrincipalResourceType string `json:"principal_resource_type" jsonschema:"required,description=Resource type of the principal"` + PrincipalResourceID string `json:"principal_resource_id" jsonschema:"required,description=Resource ID of the principal"` +} + +type CreateResourceInput struct { + ResourceType string `json:"resource_type" jsonschema:"required,description=The resource type to create"` + DisplayName string `json:"display_name" jsonschema:"required,description=Display name for the new resource"` + ParentResourceType string `json:"parent_resource_type,omitempty" jsonschema:"description=Parent resource type (optional)"` + ParentResourceID string `json:"parent_resource_id,omitempty" jsonschema:"description=Parent resource ID (optional)"` +} + +type DeleteResourceInput struct { + ResourceType string `json:"resource_type" jsonschema:"required,description=The resource type"` + ResourceID string `json:"resource_id" jsonschema:"required,description=The resource ID to delete"` +} + +type CreateTicketInput struct { + SchemaID string `json:"schema_id" jsonschema:"required,description=The ticket schema ID"` + DisplayName string `json:"display_name" jsonschema:"required,description=Display name for the ticket"` + Description string `json:"description,omitempty" jsonschema:"description=Description of the ticket"` +} + +type GetTicketInput struct { + 
TicketID string `json:"ticket_id" jsonschema:"required,description=The ticket ID"` +} + +// Output types. + +type MetadataOutput struct { + Metadata map[string]any `json:"metadata"` +} + +type ValidateOutput struct { + Valid bool `json:"valid"` + Annotations any `json:"annotations,omitempty"` +} + +type ListResourceTypesOutput struct { + ResourceTypes []map[string]any `json:"resource_types"` + NextPageToken string `json:"next_page_token,omitempty"` + HasMore bool `json:"has_more"` +} + +type ListResourcesOutput struct { + Resources []map[string]any `json:"resources"` + NextPageToken string `json:"next_page_token,omitempty"` + HasMore bool `json:"has_more"` +} + +type ResourceOutput struct { + Resource map[string]any `json:"resource"` +} + +type ListEntitlementsOutput struct { + Entitlements []map[string]any `json:"entitlements"` + NextPageToken string `json:"next_page_token,omitempty"` + HasMore bool `json:"has_more"` +} + +type ListGrantsOutput struct { + Grants []map[string]any `json:"grants"` + NextPageToken string `json:"next_page_token,omitempty"` + HasMore bool `json:"has_more"` +} + +type GrantOutput struct { + Grants []map[string]any `json:"grants"` +} + +type SuccessOutput struct { + Success bool `json:"success"` +} + +type ListTicketSchemasOutput struct { + Schemas []map[string]any `json:"schemas"` + NextPageToken string `json:"next_page_token,omitempty"` + HasMore bool `json:"has_more"` +} + +type TicketOutput struct { + Ticket map[string]any `json:"ticket"` +} + +// Handler implementations. 
+ +func (m *MCPServer) handleGetMetadata(ctx context.Context, req *mcp.CallToolRequest, input EmptyInput) (*mcp.CallToolResult, MetadataOutput, error) { resp, err := m.connector.GetMetadata(ctx, &v2.ConnectorServiceGetMetadataRequest{}) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to get metadata: %v", err)), nil + return nil, MetadataOutput{}, fmt.Errorf("failed to get metadata: %w", err) } result, err := protoToMap(resp.GetMetadata()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize metadata: %v", err)), nil - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, MetadataOutput{}, fmt.Errorf("failed to serialize metadata: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, MetadataOutput{Metadata: result}, nil } -// handleValidate handles the validate tool. -func (m *MCPServer) handleValidate(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (m *MCPServer) handleValidate(ctx context.Context, req *mcp.CallToolRequest, input EmptyInput) (*mcp.CallToolResult, ValidateOutput, error) { resp, err := m.connector.Validate(ctx, &v2.ConnectorServiceValidateRequest{}) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("validation failed: %v", err)), nil - } - - result := map[string]any{ - "valid": true, - "annotations": resp.GetAnnotations(), - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, ValidateOutput{}, fmt.Errorf("validation failed: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, ValidateOutput{ + Valid: true, + Annotations: resp.GetAnnotations(), + }, nil } -// handleListResourceTypes handles the list_resource_types tool. 
-func (m *MCPServer) handleListResourceTypes(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - pageSize := getPageSize(req) - pageToken := getStringArg(req, "page_token") +func (m *MCPServer) handleListResourceTypes(ctx context.Context, req *mcp.CallToolRequest, input PaginationInput) (*mcp.CallToolResult, ListResourceTypesOutput, error) { + pageSize := uint32(defaultPageSize) + if input.PageSize > 0 { + pageSize = uint32(input.PageSize) + } resp, err := m.connector.ListResourceTypes(ctx, &v2.ResourceTypesServiceListResourceTypesRequest{ PageSize: pageSize, - PageToken: pageToken, + PageToken: input.PageToken, }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list resource types: %v", err)), nil + return nil, ListResourceTypesOutput{}, fmt.Errorf("failed to list resource types: %w", err) } resourceTypes, err := protoListToMaps(resp.GetList()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize resource types: %v", err)), nil - } - - result := map[string]any{ - "resource_types": resourceTypes, - "next_page_token": resp.GetNextPageToken(), - "has_more": resp.GetNextPageToken() != "", - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, ListResourceTypesOutput{}, fmt.Errorf("failed to serialize resource types: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, ListResourceTypesOutput{ + ResourceTypes: resourceTypes, + NextPageToken: resp.GetNextPageToken(), + HasMore: resp.GetNextPageToken() != "", + }, nil } -// handleListResources handles the list_resources tool. 
-func (m *MCPServer) handleListResources(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceTypeID, err := req.RequireString("resource_type_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("resource_type_id is required: %v", err)), nil +func (m *MCPServer) handleListResources(ctx context.Context, req *mcp.CallToolRequest, input ListResourcesInput) (*mcp.CallToolResult, ListResourcesOutput, error) { + pageSize := uint32(defaultPageSize) + if input.PageSize > 0 { + pageSize = uint32(input.PageSize) } - pageSize := getPageSize(req) - pageToken := getStringArg(req, "page_token") - - // Build parent resource ID if specified. var parentResourceID *v2.ResourceId - parentType := getStringArg(req, "parent_resource_type") - parentID := getStringArg(req, "parent_resource_id") - if parentType != "" && parentID != "" { + if input.ParentResourceType != "" && input.ParentResourceID != "" { parentResourceID = &v2.ResourceId{ - ResourceType: parentType, - Resource: parentID, + ResourceType: input.ParentResourceType, + Resource: input.ParentResourceID, } } resp, err := m.connector.ListResources(ctx, &v2.ResourcesServiceListResourcesRequest{ - ResourceTypeId: resourceTypeID, + ResourceTypeId: input.ResourceTypeID, ParentResourceId: parentResourceID, PageSize: pageSize, - PageToken: pageToken, + PageToken: input.PageToken, }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list resources: %v", err)), nil + return nil, ListResourcesOutput{}, fmt.Errorf("failed to list resources: %w", err) } resources, err := protoListToMaps(resp.GetList()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize resources: %v", err)), nil - } - - result := map[string]any{ - "resources": resources, - "next_page_token": resp.GetNextPageToken(), - "has_more": resp.GetNextPageToken() != "", - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed 
to marshal result: %v", err)), nil + return nil, ListResourcesOutput{}, fmt.Errorf("failed to serialize resources: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, ListResourcesOutput{ + Resources: resources, + NextPageToken: resp.GetNextPageToken(), + HasMore: resp.GetNextPageToken() != "", + }, nil } -// handleGetResource handles the get_resource tool. -func (m *MCPServer) handleGetResource(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType, err := req.RequireString("resource_type") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("resource_type is required: %v", err)), nil - } - - resourceID, err := req.RequireString("resource_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("resource_id is required: %v", err)), nil - } - +func (m *MCPServer) handleGetResource(ctx context.Context, req *mcp.CallToolRequest, input ResourceInput) (*mcp.CallToolResult, ResourceOutput, error) { resp, err := m.connector.GetResource(ctx, &v2.ResourceGetterServiceGetResourceRequest{ ResourceId: &v2.ResourceId{ - ResourceType: resourceType, - Resource: resourceID, + ResourceType: input.ResourceType, + Resource: input.ResourceID, }, }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to get resource: %v", err)), nil + return nil, ResourceOutput{}, fmt.Errorf("failed to get resource: %w", err) } resource, err := protoToMap(resp.GetResource()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize resource: %v", err)), nil - } - - jsonBytes, err := json.Marshal(resource) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, ResourceOutput{}, fmt.Errorf("failed to serialize resource: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, ResourceOutput{Resource: resource}, nil } -// handleListEntitlements handles the list_entitlements tool. 
-func (m *MCPServer) handleListEntitlements(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType, err := req.RequireString("resource_type") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("resource_type is required: %v", err)), nil - } - - resourceID, err := req.RequireString("resource_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("resource_id is required: %v", err)), nil +func (m *MCPServer) handleListEntitlements(ctx context.Context, req *mcp.CallToolRequest, input ResourcePaginationInput) (*mcp.CallToolResult, ListEntitlementsOutput, error) { + pageSize := uint32(defaultPageSize) + if input.PageSize > 0 { + pageSize = uint32(input.PageSize) } - pageSize := getPageSize(req) - pageToken := getStringArg(req, "page_token") - resp, err := m.connector.ListEntitlements(ctx, &v2.EntitlementsServiceListEntitlementsRequest{ Resource: &v2.Resource{ Id: &v2.ResourceId{ - ResourceType: resourceType, - Resource: resourceID, + ResourceType: input.ResourceType, + Resource: input.ResourceID, }, }, PageSize: pageSize, - PageToken: pageToken, + PageToken: input.PageToken, }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list entitlements: %v", err)), nil + return nil, ListEntitlementsOutput{}, fmt.Errorf("failed to list entitlements: %w", err) } entitlements, err := protoListToMaps(resp.GetList()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize entitlements: %v", err)), nil - } - - result := map[string]any{ - "entitlements": entitlements, - "next_page_token": resp.GetNextPageToken(), - "has_more": resp.GetNextPageToken() != "", - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, ListEntitlementsOutput{}, fmt.Errorf("failed to serialize entitlements: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, 
ListEntitlementsOutput{ + Entitlements: entitlements, + NextPageToken: resp.GetNextPageToken(), + HasMore: resp.GetNextPageToken() != "", + }, nil } -// handleListGrants handles the list_grants tool. -func (m *MCPServer) handleListGrants(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType, err := req.RequireString("resource_type") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("resource_type is required: %v", err)), nil - } - - resourceID, err := req.RequireString("resource_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("resource_id is required: %v", err)), nil +func (m *MCPServer) handleListGrants(ctx context.Context, req *mcp.CallToolRequest, input ResourcePaginationInput) (*mcp.CallToolResult, ListGrantsOutput, error) { + pageSize := uint32(defaultPageSize) + if input.PageSize > 0 { + pageSize = uint32(input.PageSize) } - pageSize := getPageSize(req) - pageToken := getStringArg(req, "page_token") - resp, err := m.connector.ListGrants(ctx, &v2.GrantsServiceListGrantsRequest{ Resource: &v2.Resource{ Id: &v2.ResourceId{ - ResourceType: resourceType, - Resource: resourceID, + ResourceType: input.ResourceType, + Resource: input.ResourceID, }, }, PageSize: pageSize, - PageToken: pageToken, + PageToken: input.PageToken, }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list grants: %v", err)), nil + return nil, ListGrantsOutput{}, fmt.Errorf("failed to list grants: %w", err) } grants, err := protoListToMaps(resp.GetList()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize grants: %v", err)), nil - } - - result := map[string]any{ - "grants": grants, - "next_page_token": resp.GetNextPageToken(), - "has_more": resp.GetNextPageToken() != "", - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, ListGrantsOutput{}, fmt.Errorf("failed to 
serialize grants: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, ListGrantsOutput{ + Grants: grants, + NextPageToken: resp.GetNextPageToken(), + HasMore: resp.GetNextPageToken() != "", + }, nil } -// handleGrant handles the grant tool. -func (m *MCPServer) handleGrant(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - entResourceType, err := req.RequireString("entitlement_resource_type") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("entitlement_resource_type is required: %v", err)), nil - } - - entResourceID, err := req.RequireString("entitlement_resource_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("entitlement_resource_id is required: %v", err)), nil - } - - entID, err := req.RequireString("entitlement_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("entitlement_id is required: %v", err)), nil - } - - principalType, err := req.RequireString("principal_resource_type") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("principal_resource_type is required: %v", err)), nil - } - - principalID, err := req.RequireString("principal_resource_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("principal_resource_id is required: %v", err)), nil - } - +func (m *MCPServer) handleGrant(ctx context.Context, req *mcp.CallToolRequest, input GrantInput) (*mcp.CallToolResult, GrantOutput, error) { resp, err := m.connector.Grant(ctx, &v2.GrantManagerServiceGrantRequest{ Entitlement: &v2.Entitlement{ Resource: &v2.Resource{ Id: &v2.ResourceId{ - ResourceType: entResourceType, - Resource: entResourceID, + ResourceType: input.EntitlementResourceType, + Resource: input.EntitlementResourceID, }, }, - Id: entID, + Id: input.EntitlementID, }, Principal: &v2.Resource{ Id: &v2.ResourceId{ - ResourceType: principalType, - Resource: principalID, + ResourceType: input.PrincipalResourceType, + Resource: input.PrincipalResourceID, }, }, }) if err != 
nil { - return mcp.NewToolResultError(fmt.Sprintf("grant failed: %v", err)), nil + return nil, GrantOutput{}, fmt.Errorf("grant failed: %w", err) } grants, err := protoListToMaps(resp.GetGrants()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize grants: %v", err)), nil - } - - result := map[string]any{ - "grants": grants, - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, GrantOutput{}, fmt.Errorf("failed to serialize grants: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, GrantOutput{Grants: grants}, nil } -// handleRevoke handles the revoke tool. -func (m *MCPServer) handleRevoke(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - grantID, err := req.RequireString("grant_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("grant_id is required: %v", err)), nil - } - - entResourceType, err := req.RequireString("entitlement_resource_type") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("entitlement_resource_type is required: %v", err)), nil - } - - entResourceID, err := req.RequireString("entitlement_resource_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("entitlement_resource_id is required: %v", err)), nil - } - - entID, err := req.RequireString("entitlement_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("entitlement_id is required: %v", err)), nil - } - - principalType, err := req.RequireString("principal_resource_type") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("principal_resource_type is required: %v", err)), nil - } - - principalID, err := req.RequireString("principal_resource_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("principal_resource_id is required: %v", err)), nil - } - - _, err = m.connector.Revoke(ctx, 
&v2.GrantManagerServiceRevokeRequest{ +func (m *MCPServer) handleRevoke(ctx context.Context, req *mcp.CallToolRequest, input RevokeInput) (*mcp.CallToolResult, SuccessOutput, error) { + _, err := m.connector.Revoke(ctx, &v2.GrantManagerServiceRevokeRequest{ Grant: &v2.Grant{ - Id: grantID, + Id: input.GrantID, Entitlement: &v2.Entitlement{ Resource: &v2.Resource{ Id: &v2.ResourceId{ - ResourceType: entResourceType, - Resource: entResourceID, + ResourceType: input.EntitlementResourceType, + Resource: input.EntitlementResourceID, }, }, - Id: entID, + Id: input.EntitlementID, }, Principal: &v2.Resource{ Id: &v2.ResourceId{ - ResourceType: principalType, - Resource: principalID, + ResourceType: input.PrincipalResourceType, + Resource: input.PrincipalResourceID, }, }, }, }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("revoke failed: %v", err)), nil - } - - result := map[string]any{ - "success": true, - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, SuccessOutput{}, fmt.Errorf("revoke failed: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, SuccessOutput{Success: true}, nil } -// handleCreateResource handles the create_resource tool. -func (m *MCPServer) handleCreateResource(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType, err := req.RequireString("resource_type") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("resource_type is required: %v", err)), nil - } - - displayName, err := req.RequireString("display_name") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("display_name is required: %v", err)), nil - } - - // Build parent resource if specified. 
+func (m *MCPServer) handleCreateResource(ctx context.Context, req *mcp.CallToolRequest, input CreateResourceInput) (*mcp.CallToolResult, ResourceOutput, error) { var parentResource *v2.Resource - parentType := getStringArg(req, "parent_resource_type") - parentID := getStringArg(req, "parent_resource_id") - if parentType != "" && parentID != "" { + if input.ParentResourceType != "" && input.ParentResourceID != "" { parentResource = &v2.Resource{ Id: &v2.ResourceId{ - ResourceType: parentType, - Resource: parentID, + ResourceType: input.ParentResourceType, + Resource: input.ParentResourceID, }, } } @@ -426,180 +381,90 @@ func (m *MCPServer) handleCreateResource(ctx context.Context, req mcp.CallToolRe resp, err := m.connector.CreateResource(ctx, &v2.CreateResourceRequest{ Resource: &v2.Resource{ Id: &v2.ResourceId{ - ResourceType: resourceType, + ResourceType: input.ResourceType, }, - DisplayName: displayName, + DisplayName: input.DisplayName, ParentResourceId: parentResource.GetId(), }, }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("create resource failed: %v", err)), nil + return nil, ResourceOutput{}, fmt.Errorf("create resource failed: %w", err) } resource, err := protoToMap(resp.GetCreated()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize resource: %v", err)), nil - } - - jsonBytes, err := json.Marshal(resource) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, ResourceOutput{}, fmt.Errorf("failed to serialize resource: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, ResourceOutput{Resource: resource}, nil } -// handleDeleteResource handles the delete_resource tool. 
-func (m *MCPServer) handleDeleteResource(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType, err := req.RequireString("resource_type") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("resource_type is required: %v", err)), nil - } - - resourceID, err := req.RequireString("resource_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("resource_id is required: %v", err)), nil - } - - _, err = m.connector.DeleteResource(ctx, &v2.DeleteResourceRequest{ +func (m *MCPServer) handleDeleteResource(ctx context.Context, req *mcp.CallToolRequest, input DeleteResourceInput) (*mcp.CallToolResult, SuccessOutput, error) { + _, err := m.connector.DeleteResource(ctx, &v2.DeleteResourceRequest{ ResourceId: &v2.ResourceId{ - ResourceType: resourceType, - Resource: resourceID, + ResourceType: input.ResourceType, + Resource: input.ResourceID, }, }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("delete resource failed: %v", err)), nil - } - - result := map[string]any{ - "success": true, - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, SuccessOutput{}, fmt.Errorf("delete resource failed: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, SuccessOutput{Success: true}, nil } -// handleListTicketSchemas handles the list_ticket_schemas tool. 
-func (m *MCPServer) handleListTicketSchemas(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (m *MCPServer) handleListTicketSchemas(ctx context.Context, req *mcp.CallToolRequest, input EmptyInput) (*mcp.CallToolResult, ListTicketSchemasOutput, error) { resp, err := m.connector.ListTicketSchemas(ctx, &v2.TicketsServiceListTicketSchemasRequest{}) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list ticket schemas: %v", err)), nil + return nil, ListTicketSchemasOutput{}, fmt.Errorf("failed to list ticket schemas: %w", err) } schemas, err := protoListToMaps(resp.GetList()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize ticket schemas: %v", err)), nil - } - - result := map[string]any{ - "schemas": schemas, - "next_page_token": resp.GetNextPageToken(), - "has_more": resp.GetNextPageToken() != "", - } - - jsonBytes, err := json.Marshal(result) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, ListTicketSchemasOutput{}, fmt.Errorf("failed to serialize ticket schemas: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, ListTicketSchemasOutput{ + Schemas: schemas, + NextPageToken: resp.GetNextPageToken(), + HasMore: resp.GetNextPageToken() != "", + }, nil } -// handleCreateTicket handles the create_ticket tool. 
-func (m *MCPServer) handleCreateTicket(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - schemaID, err := req.RequireString("schema_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("schema_id is required: %v", err)), nil - } - - displayName, err := req.RequireString("display_name") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("display_name is required: %v", err)), nil - } - - description := getStringArg(req, "description") - +func (m *MCPServer) handleCreateTicket(ctx context.Context, req *mcp.CallToolRequest, input CreateTicketInput) (*mcp.CallToolResult, TicketOutput, error) { resp, err := m.connector.CreateTicket(ctx, &v2.TicketsServiceCreateTicketRequest{ Schema: &v2.TicketSchema{ - Id: schemaID, + Id: input.SchemaID, }, Request: &v2.TicketRequest{ - DisplayName: displayName, - Description: description, + DisplayName: input.DisplayName, + Description: input.Description, }, }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("create ticket failed: %v", err)), nil + return nil, TicketOutput{}, fmt.Errorf("create ticket failed: %w", err) } ticket, err := protoToMap(resp.GetTicket()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize ticket: %v", err)), nil - } - - jsonBytes, err := json.Marshal(ticket) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, TicketOutput{}, fmt.Errorf("failed to serialize ticket: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil + return nil, TicketOutput{Ticket: ticket}, nil } -// handleGetTicket handles the get_ticket tool. 
-func (m *MCPServer) handleGetTicket(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - ticketID, err := req.RequireString("ticket_id") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("ticket_id is required: %v", err)), nil - } - +func (m *MCPServer) handleGetTicket(ctx context.Context, req *mcp.CallToolRequest, input GetTicketInput) (*mcp.CallToolResult, TicketOutput, error) { resp, err := m.connector.GetTicket(ctx, &v2.TicketsServiceGetTicketRequest{ - Id: ticketID, + Id: input.TicketID, }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("get ticket failed: %v", err)), nil + return nil, TicketOutput{}, fmt.Errorf("get ticket failed: %w", err) } ticket, err := protoToMap(resp.GetTicket()) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to serialize ticket: %v", err)), nil - } - - jsonBytes, err := json.Marshal(ticket) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return nil, TicketOutput{}, fmt.Errorf("failed to serialize ticket: %w", err) } - return mcp.NewToolResultText(string(jsonBytes)), nil -} - -// Helper functions. 
- -func getPageSize(req mcp.CallToolRequest) uint32 { - args := req.GetArguments() - if args == nil { - return defaultPageSize - } - if ps, ok := args["page_size"]; ok { - if psFloat, ok := ps.(float64); ok { - return uint32(psFloat) - } - } - return defaultPageSize -} - -func getStringArg(req mcp.CallToolRequest, name string) string { - args := req.GetArguments() - if args == nil { - return "" - } - if v, ok := args[name]; ok { - if s, ok := v.(string); ok { - return s - } - } - return "" + return nil, TicketOutput{Ticket: ticket}, nil } diff --git a/pkg/mcp/server.go b/pkg/mcp/server.go index 42a11fb52..cfc7b4d63 100644 --- a/pkg/mcp/server.go +++ b/pkg/mcp/server.go @@ -4,8 +4,7 @@ import ( "context" "fmt" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/modelcontextprotocol/go-sdk/mcp" v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" "github.com/conductorone/baton-sdk/pkg/types" @@ -14,7 +13,7 @@ import ( // MCPServer wraps a ConnectorServer and exposes its functionality via MCP. type MCPServer struct { connector types.ConnectorServer - server *server.MCPServer + server *mcp.Server caps *v2.ConnectorCapabilities } @@ -26,11 +25,12 @@ func NewMCPServer(ctx context.Context, name string, connector types.ConnectorSer return nil, fmt.Errorf("failed to get connector metadata: %w", err) } - s := server.NewMCPServer( - name, - "1.0.0", - server.WithToolCapabilities(false), - server.WithRecovery(), + s := mcp.NewServer( + &mcp.Implementation{ + Name: name, + Version: "1.0.0", + }, + nil, ) m := &MCPServer{ @@ -44,8 +44,8 @@ func NewMCPServer(ctx context.Context, name string, connector types.ConnectorSer } // Serve starts the MCP server on stdio. -func (m *MCPServer) Serve() error { - return server.ServeStdio(m.server) +func (m *MCPServer) Serve(ctx context.Context) error { + return m.server.Run(ctx, &mcp.StdioTransport{}) } // registerTools registers all MCP tools based on connector capabilities. 
@@ -79,260 +79,93 @@ func (m *MCPServer) hasCapability(cap v2.Capability) bool { // registerReadTools registers read-only tools that are always available. func (m *MCPServer) registerReadTools() { - // get_metadata - Get connector metadata and capabilities. - m.server.AddTool( - mcp.NewTool("get_metadata", - mcp.WithDescription("Get connector metadata including display name, description, and capabilities"), - ), - m.handleGetMetadata, - ) - - // validate - Validate connector configuration. - m.server.AddTool( - mcp.NewTool("validate", - mcp.WithDescription("Validate the connector configuration and connectivity"), - ), - m.handleValidate, - ) - - // list_resource_types - List available resource types. - m.server.AddTool( - mcp.NewTool("list_resource_types", - mcp.WithDescription("List all resource types supported by this connector"), - mcp.WithNumber("page_size", - mcp.Description("Number of items per page (default 50)"), - ), - mcp.WithString("page_token", - mcp.Description("Pagination token from previous response"), - ), - ), - m.handleListResourceTypes, - ) - - // list_resources - List resources of a specific type. - m.server.AddTool( - mcp.NewTool("list_resources", - mcp.WithDescription("List resources of a specific type"), - mcp.WithString("resource_type_id", - mcp.Required(), - mcp.Description("The resource type ID to list (e.g., 'user', 'group')"), - ), - mcp.WithString("parent_resource_type", - mcp.Description("Parent resource type (optional, for hierarchical resources)"), - ), - mcp.WithString("parent_resource_id", - mcp.Description("Parent resource ID (optional, for hierarchical resources)"), - ), - mcp.WithNumber("page_size", - mcp.Description("Number of items per page (default 50)"), - ), - mcp.WithString("page_token", - mcp.Description("Pagination token from previous response"), - ), - ), - m.handleListResources, - ) - - // get_resource - Get a specific resource by ID. 
- m.server.AddTool( - mcp.NewTool("get_resource", - mcp.WithDescription("Get a specific resource by its type and ID"), - mcp.WithString("resource_type", - mcp.Required(), - mcp.Description("The resource type (e.g., 'user', 'group')"), - ), - mcp.WithString("resource_id", - mcp.Required(), - mcp.Description("The resource ID"), - ), - ), - m.handleGetResource, - ) - - // list_entitlements - List entitlements for a resource. - m.server.AddTool( - mcp.NewTool("list_entitlements", - mcp.WithDescription("List entitlements (permissions, roles, memberships) for a resource"), - mcp.WithString("resource_type", - mcp.Required(), - mcp.Description("The resource type"), - ), - mcp.WithString("resource_id", - mcp.Required(), - mcp.Description("The resource ID"), - ), - mcp.WithNumber("page_size", - mcp.Description("Number of items per page (default 50)"), - ), - mcp.WithString("page_token", - mcp.Description("Pagination token from previous response"), - ), - ), - m.handleListEntitlements, - ) - - // list_grants - List grants for a resource. 
- m.server.AddTool( - mcp.NewTool("list_grants", - mcp.WithDescription("List grants (who has what access) for a resource"), - mcp.WithString("resource_type", - mcp.Required(), - mcp.Description("The resource type"), - ), - mcp.WithString("resource_id", - mcp.Required(), - mcp.Description("The resource ID"), - ), - mcp.WithNumber("page_size", - mcp.Description("Number of items per page (default 50)"), - ), - mcp.WithString("page_token", - mcp.Description("Pagination token from previous response"), - ), - ), - m.handleListGrants, - ) + // get_metadata + mcp.AddTool(m.server, &mcp.Tool{ + Name: "get_metadata", + Description: "Get connector metadata including display name, description, and capabilities", + }, m.handleGetMetadata) + + // validate + mcp.AddTool(m.server, &mcp.Tool{ + Name: "validate", + Description: "Validate the connector configuration and connectivity", + }, m.handleValidate) + + // list_resource_types + mcp.AddTool(m.server, &mcp.Tool{ + Name: "list_resource_types", + Description: "List all resource types supported by this connector", + }, m.handleListResourceTypes) + + // list_resources + mcp.AddTool(m.server, &mcp.Tool{ + Name: "list_resources", + Description: "List resources of a specific type", + }, m.handleListResources) + + // get_resource + mcp.AddTool(m.server, &mcp.Tool{ + Name: "get_resource", + Description: "Get a specific resource by its type and ID", + }, m.handleGetResource) + + // list_entitlements + mcp.AddTool(m.server, &mcp.Tool{ + Name: "list_entitlements", + Description: "List entitlements (permissions, roles, memberships) for a resource", + }, m.handleListEntitlements) + + // list_grants + mcp.AddTool(m.server, &mcp.Tool{ + Name: "list_grants", + Description: "List grants (who has what access) for a resource", + }, m.handleListGrants) } // registerProvisioningTools registers tools for provisioning operations. func (m *MCPServer) registerProvisioningTools() { - // grant - Grant an entitlement to a principal. 
- m.server.AddTool( - mcp.NewTool("grant", - mcp.WithDescription("Grant an entitlement to a principal (user or group)"), - mcp.WithString("entitlement_resource_type", - mcp.Required(), - mcp.Description("Resource type of the entitlement"), - ), - mcp.WithString("entitlement_resource_id", - mcp.Required(), - mcp.Description("Resource ID of the entitlement"), - ), - mcp.WithString("entitlement_id", - mcp.Required(), - mcp.Description("The entitlement ID"), - ), - mcp.WithString("principal_resource_type", - mcp.Required(), - mcp.Description("Resource type of the principal (e.g., 'user', 'group')"), - ), - mcp.WithString("principal_resource_id", - mcp.Required(), - mcp.Description("Resource ID of the principal"), - ), - ), - m.handleGrant, - ) - - // revoke - Revoke a grant. - m.server.AddTool( - mcp.NewTool("revoke", - mcp.WithDescription("Revoke a grant from a principal"), - mcp.WithString("grant_id", - mcp.Required(), - mcp.Description("The grant ID to revoke"), - ), - mcp.WithString("entitlement_resource_type", - mcp.Required(), - mcp.Description("Resource type of the entitlement"), - ), - mcp.WithString("entitlement_resource_id", - mcp.Required(), - mcp.Description("Resource ID of the entitlement"), - ), - mcp.WithString("entitlement_id", - mcp.Required(), - mcp.Description("The entitlement ID"), - ), - mcp.WithString("principal_resource_type", - mcp.Required(), - mcp.Description("Resource type of the principal"), - ), - mcp.WithString("principal_resource_id", - mcp.Required(), - mcp.Description("Resource ID of the principal"), - ), - ), - m.handleRevoke, - ) - - // create_resource - Create a new resource. 
- m.server.AddTool( - mcp.NewTool("create_resource", - mcp.WithDescription("Create a new resource"), - mcp.WithString("resource_type", - mcp.Required(), - mcp.Description("The resource type to create"), - ), - mcp.WithString("display_name", - mcp.Required(), - mcp.Description("Display name for the new resource"), - ), - mcp.WithString("parent_resource_type", - mcp.Description("Parent resource type (optional)"), - ), - mcp.WithString("parent_resource_id", - mcp.Description("Parent resource ID (optional)"), - ), - ), - m.handleCreateResource, - ) - - // delete_resource - Delete a resource. - m.server.AddTool( - mcp.NewTool("delete_resource", - mcp.WithDescription("Delete a resource"), - mcp.WithString("resource_type", - mcp.Required(), - mcp.Description("The resource type"), - ), - mcp.WithString("resource_id", - mcp.Required(), - mcp.Description("The resource ID to delete"), - ), - ), - m.handleDeleteResource, - ) + // grant + mcp.AddTool(m.server, &mcp.Tool{ + Name: "grant", + Description: "Grant an entitlement to a principal (user or group)", + }, m.handleGrant) + + // revoke + mcp.AddTool(m.server, &mcp.Tool{ + Name: "revoke", + Description: "Revoke a grant from a principal", + }, m.handleRevoke) + + // create_resource + mcp.AddTool(m.server, &mcp.Tool{ + Name: "create_resource", + Description: "Create a new resource", + }, m.handleCreateResource) + + // delete_resource + mcp.AddTool(m.server, &mcp.Tool{ + Name: "delete_resource", + Description: "Delete a resource", + }, m.handleDeleteResource) } // registerTicketingTools registers tools for ticketing operations. func (m *MCPServer) registerTicketingTools() { - // list_ticket_schemas - List available ticket schemas. - m.server.AddTool( - mcp.NewTool("list_ticket_schemas", - mcp.WithDescription("List available ticket schemas"), - ), - m.handleListTicketSchemas, - ) - - // create_ticket - Create a new ticket. 
- m.server.AddTool( - mcp.NewTool("create_ticket", - mcp.WithDescription("Create a new ticket"), - mcp.WithString("schema_id", - mcp.Required(), - mcp.Description("The ticket schema ID"), - ), - mcp.WithString("display_name", - mcp.Required(), - mcp.Description("Display name for the ticket"), - ), - mcp.WithString("description", - mcp.Description("Description of the ticket"), - ), - ), - m.handleCreateTicket, - ) - - // get_ticket - Get a ticket by ID. - m.server.AddTool( - mcp.NewTool("get_ticket", - mcp.WithDescription("Get a ticket by ID"), - mcp.WithString("ticket_id", - mcp.Required(), - mcp.Description("The ticket ID"), - ), - ), - m.handleGetTicket, - ) + // list_ticket_schemas + mcp.AddTool(m.server, &mcp.Tool{ + Name: "list_ticket_schemas", + Description: "List available ticket schemas", + }, m.handleListTicketSchemas) + + // create_ticket + mcp.AddTool(m.server, &mcp.Tool{ + Name: "create_ticket", + Description: "Create a new ticket", + }, m.handleCreateTicket) + + // get_ticket + mcp.AddTool(m.server, &mcp.Tool{ + Name: "get_ticket", + Description: "Get a ticket by ID", + }, m.handleGetTicket) } diff --git a/vendor/github.com/bahlo/generic-list-go/LICENSE b/vendor/github.com/bahlo/generic-list-go/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/vendor/github.com/bahlo/generic-list-go/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. 
nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/bahlo/generic-list-go/README.md b/vendor/github.com/bahlo/generic-list-go/README.md deleted file mode 100644 index 68bbce9fb..000000000 --- a/vendor/github.com/bahlo/generic-list-go/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# generic-list-go [![CI](https://github.com/bahlo/generic-list-go/actions/workflows/ci.yml/badge.svg)](https://github.com/bahlo/generic-list-go/actions/workflows/ci.yml) - -Go [container/list](https://pkg.go.dev/container/list) but with generics. - -The code is based on `container/list` in `go1.18beta2`. diff --git a/vendor/github.com/bahlo/generic-list-go/list.go b/vendor/github.com/bahlo/generic-list-go/list.go deleted file mode 100644 index a06a7c612..000000000 --- a/vendor/github.com/bahlo/generic-list-go/list.go +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package list implements a doubly linked list. 
-// -// To iterate over a list (where l is a *List): -// for e := l.Front(); e != nil; e = e.Next() { -// // do something with e.Value -// } -// -package list - -// Element is an element of a linked list. -type Element[T any] struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *Element[T] - - // The list to which this element belongs. - list *List[T] - - // The value stored with this element. - Value T -} - -// Next returns the next list element or nil. -func (e *Element[T]) Next() *Element[T] { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *Element[T]) Prev() *Element[T] { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// List represents a doubly linked list. -// The zero value for List is an empty list ready to use. -type List[T any] struct { - root Element[T] // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *List[T]) Init() *List[T] { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// New returns an initialized list. -func New[T any]() *List[T] { return new(List[T]).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *List[T]) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *List[T]) Front() *Element[T] { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. 
-func (l *List[T]) Back() *Element[T] { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *List[T]) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *List[T]) insert(e, at *Element[T]) *Element[T] { - e.prev = at - e.next = at.next - e.prev.next = e - e.next.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { - return l.insert(&Element[T]{Value: v}, at) -} - -// remove removes e from its list, decrements l.len -func (l *List[T]) remove(e *Element[T]) { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- -} - -// move moves e to next to at. -func (l *List[T]) move(e, at *Element[T]) { - if e == at { - return - } - e.prev.next = e.next - e.next.prev = e.prev - - e.prev = at - e.next = at.next - e.prev.next = e - e.next.prev = e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *List[T]) Remove(e *Element[T]) T { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *List[T]) PushFront(v T) *Element[T] { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *List[T]) PushBack(v T) *Element[T] { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. 
-// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *List[T]) MoveToFront(e *Element[T]) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.move(e, &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *List[T]) MoveToBack(e *Element[T]) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.move(e, l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *List[T]) MoveBefore(e, mark *Element[T]) { - if e.list != l || e == mark || mark.list != l { - return - } - l.move(e, mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. 
-func (l *List[T]) MoveAfter(e, mark *Element[T]) { - if e.list != l || e == mark || mark.list != l { - return - } - l.move(e, mark) -} - -// PushBackList inserts a copy of another list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *List[T]) PushBackList(other *List[T]) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of another list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *List[T]) PushFrontList(other *List[T]) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/vendor/github.com/buger/jsonparser/.gitignore b/vendor/github.com/buger/jsonparser/.gitignore deleted file mode 100644 index 5598d8a56..000000000 --- a/vendor/github.com/buger/jsonparser/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ - -*.test - -*.out - -*.mprof - -.idea - -vendor/github.com/buger/goterm/ -prof.cpu -prof.mem diff --git a/vendor/github.com/buger/jsonparser/.travis.yml b/vendor/github.com/buger/jsonparser/.travis.yml deleted file mode 100644 index dbfb7cf98..000000000 --- a/vendor/github.com/buger/jsonparser/.travis.yml +++ /dev/null @@ -1,11 +0,0 @@ -language: go -arch: - - amd64 - - ppc64le -go: - - 1.7.x - - 1.8.x - - 1.9.x - - 1.10.x - - 1.11.x -script: go test -v ./. 
diff --git a/vendor/github.com/buger/jsonparser/Dockerfile b/vendor/github.com/buger/jsonparser/Dockerfile deleted file mode 100644 index 37fc9fd0b..000000000 --- a/vendor/github.com/buger/jsonparser/Dockerfile +++ /dev/null @@ -1,12 +0,0 @@ -FROM golang:1.6 - -RUN go get github.com/Jeffail/gabs -RUN go get github.com/bitly/go-simplejson -RUN go get github.com/pquerna/ffjson -RUN go get github.com/antonholmquist/jason -RUN go get github.com/mreiferson/go-ujson -RUN go get -tags=unsafe -u github.com/ugorji/go/codec -RUN go get github.com/mailru/easyjson - -WORKDIR /go/src/github.com/buger/jsonparser -ADD . /go/src/github.com/buger/jsonparser \ No newline at end of file diff --git a/vendor/github.com/buger/jsonparser/Makefile b/vendor/github.com/buger/jsonparser/Makefile deleted file mode 100644 index e843368cf..000000000 --- a/vendor/github.com/buger/jsonparser/Makefile +++ /dev/null @@ -1,36 +0,0 @@ -SOURCE = parser.go -CONTAINER = jsonparser -SOURCE_PATH = /go/src/github.com/buger/jsonparser -BENCHMARK = JsonParser -BENCHTIME = 5s -TEST = . -DRUN = docker run -v `pwd`:$(SOURCE_PATH) -i -t $(CONTAINER) - -build: - docker build -t $(CONTAINER) . - -race: - $(DRUN) --env GORACE="halt_on_error=1" go test ./. $(ARGS) -v -race -timeout 15s - -bench: - $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -benchtime $(BENCHTIME) -v - -bench_local: - $(DRUN) go test $(LDFLAGS) -test.benchmem -bench . $(ARGS) -benchtime $(BENCHTIME) -v - -profile: - $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -memprofile mem.mprof -v - $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -cpuprofile cpu.out -v - $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -c - -test: - $(DRUN) go test $(LDFLAGS) ./ -run $(TEST) -timeout 10s $(ARGS) -v - -fmt: - $(DRUN) go fmt ./... - -vet: - $(DRUN) go vet ./. 
- -bash: - $(DRUN) /bin/bash \ No newline at end of file diff --git a/vendor/github.com/buger/jsonparser/README.md b/vendor/github.com/buger/jsonparser/README.md deleted file mode 100644 index d7e0ec397..000000000 --- a/vendor/github.com/buger/jsonparser/README.md +++ /dev/null @@ -1,365 +0,0 @@ -[![Go Report Card](https://goreportcard.com/badge/github.com/buger/jsonparser)](https://goreportcard.com/report/github.com/buger/jsonparser) ![License](https://img.shields.io/dub/l/vibe-d.svg) -# Alternative JSON parser for Go (10x times faster standard library) - -It does not require you to know the structure of the payload (eg. create structs), and allows accessing fields by providing the path to them. It is up to **10 times faster** than standard `encoding/json` package (depending on payload size and usage), **allocates no memory**. See benchmarks below. - -## Rationale -Originally I made this for a project that relies on a lot of 3rd party APIs that can be unpredictable and complex. -I love simplicity and prefer to avoid external dependecies. `encoding/json` requires you to know exactly your data structures, or if you prefer to use `map[string]interface{}` instead, it will be very slow and hard to manage. -I investigated what's on the market and found that most libraries are just wrappers around `encoding/json`, there is few options with own parsers (`ffjson`, `easyjson`), but they still requires you to create data structures. - - -Goal of this project is to push JSON parser to the performance limits and not sacrifice with compliance and developer user experience. - -## Example -For the given JSON our goal is to extract the user's full name, number of github followers and avatar. - -```go -import "github.com/buger/jsonparser" - -... 
- -data := []byte(`{ - "person": { - "name": { - "first": "Leonid", - "last": "Bugaev", - "fullName": "Leonid Bugaev" - }, - "github": { - "handle": "buger", - "followers": 109 - }, - "avatars": [ - { "url": "https://avatars1.githubusercontent.com/u/14009?v=3&s=460", "type": "thumbnail" } - ] - }, - "company": { - "name": "Acme" - } -}`) - -// You can specify key path by providing arguments to Get function -jsonparser.Get(data, "person", "name", "fullName") - -// There is `GetInt` and `GetBoolean` helpers if you exactly know key data type -jsonparser.GetInt(data, "person", "github", "followers") - -// When you try to get object, it will return you []byte slice pointer to data containing it -// In `company` it will be `{"name": "Acme"}` -jsonparser.Get(data, "company") - -// If the key doesn't exist it will throw an error -var size int64 -if value, err := jsonparser.GetInt(data, "company", "size"); err == nil { - size = value -} - -// You can use `ArrayEach` helper to iterate items [item1, item2 .... itemN] -jsonparser.ArrayEach(data, func(value []byte, dataType jsonparser.ValueType, offset int, err error) { - fmt.Println(jsonparser.Get(value, "url")) -}, "person", "avatars") - -// Or use can access fields by index! -jsonparser.GetString(data, "person", "avatars", "[0]", "url") - -// You can use `ObjectEach` helper to iterate objects { "key1":object1, "key2":object2, .... 
"keyN":objectN } -jsonparser.ObjectEach(data, func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error { - fmt.Printf("Key: '%s'\n Value: '%s'\n Type: %s\n", string(key), string(value), dataType) - return nil -}, "person", "name") - -// The most efficient way to extract multiple keys is `EachKey` - -paths := [][]string{ - []string{"person", "name", "fullName"}, - []string{"person", "avatars", "[0]", "url"}, - []string{"company", "url"}, -} -jsonparser.EachKey(data, func(idx int, value []byte, vt jsonparser.ValueType, err error){ - switch idx { - case 0: // []string{"person", "name", "fullName"} - ... - case 1: // []string{"person", "avatars", "[0]", "url"} - ... - case 2: // []string{"company", "url"}, - ... - } -}, paths...) - -// For more information see docs below -``` - -## Need to speedup your app? - -I'm available for consulting and can help you push your app performance to the limits. Ping me at: leonsbox@gmail.com. - -## Reference - -Library API is really simple. You just need the `Get` method to perform any operation. The rest is just helpers around it. - -You also can view API at [godoc.org](https://godoc.org/github.com/buger/jsonparser) - - -### **`Get`** -```go -func Get(data []byte, keys ...string) (value []byte, dataType jsonparser.ValueType, offset int, err error) -``` -Receives data structure, and key path to extract value from. - -Returns: -* `value` - Pointer to original data structure containing key value, or just empty slice if nothing found or error -* `dataType` - Can be: `NotExist`, `String`, `Number`, `Object`, `Array`, `Boolean` or `Null` -* `offset` - Offset from provided data structure where key value ends. Used mostly internally, for example for `ArrayEach` helper. -* `err` - If the key is not found or any other parsing issue, it should return error. If key not found it also sets `dataType` to `NotExist` - -Accepts multiple keys to specify path to JSON value (in case of quering nested structures). 
-If no keys are provided it will try to extract the closest JSON value (simple ones or object/array), useful for reading streams or arrays, see `ArrayEach` implementation. - -Note that keys can be an array indexes: `jsonparser.GetInt("person", "avatars", "[0]", "url")`, pretty cool, yeah? - -### **`GetString`** -```go -func GetString(data []byte, keys ...string) (val string, err error) -``` -Returns strings properly handing escaped and unicode characters. Note that this will cause additional memory allocations. - -### **`GetUnsafeString`** -If you need string in your app, and ready to sacrifice with support of escaped symbols in favor of speed. It returns string mapped to existing byte slice memory, without any allocations: -```go -s, _, := jsonparser.GetUnsafeString(data, "person", "name", "title") -switch s { - case 'CEO': - ... - case 'Engineer' - ... - ... -} -``` -Note that `unsafe` here means that your string will exist until GC will free underlying byte slice, for most of cases it means that you can use this string only in current context, and should not pass it anywhere externally: through channels or any other way. - - -### **`GetBoolean`**, **`GetInt`** and **`GetFloat`** -```go -func GetBoolean(data []byte, keys ...string) (val bool, err error) - -func GetFloat(data []byte, keys ...string) (val float64, err error) - -func GetInt(data []byte, keys ...string) (val int64, err error) -``` -If you know the key type, you can use the helpers above. -If key data type do not match, it will return error. - -### **`ArrayEach`** -```go -func ArrayEach(data []byte, cb func(value []byte, dataType jsonparser.ValueType, offset int, err error), keys ...string) -``` -Needed for iterating arrays, accepts a callback function with the same return arguments as `Get`. 
- -### **`ObjectEach`** -```go -func ObjectEach(data []byte, callback func(key []byte, value []byte, dataType ValueType, offset int) error, keys ...string) (err error) -``` -Needed for iterating object, accepts a callback function. Example: -```go -var handler func([]byte, []byte, jsonparser.ValueType, int) error -handler = func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error { - //do stuff here -} -jsonparser.ObjectEach(myJson, handler) -``` - - -### **`EachKey`** -```go -func EachKey(data []byte, cb func(idx int, value []byte, dataType jsonparser.ValueType, err error), paths ...[]string) -``` -When you need to read multiple keys, and you do not afraid of low-level API `EachKey` is your friend. It read payload only single time, and calls callback function once path is found. For example when you call multiple times `Get`, it has to process payload multiple times, each time you call it. Depending on payload `EachKey` can be multiple times faster than `Get`. Path can use nested keys as well! - -```go -paths := [][]string{ - []string{"uuid"}, - []string{"tz"}, - []string{"ua"}, - []string{"st"}, -} -var data SmallPayload - -jsonparser.EachKey(smallFixture, func(idx int, value []byte, vt jsonparser.ValueType, err error){ - switch idx { - case 0: - data.Uuid, _ = value - case 1: - v, _ := jsonparser.ParseInt(value) - data.Tz = int(v) - case 2: - data.Ua, _ = value - case 3: - v, _ := jsonparser.ParseInt(value) - data.St = int(v) - } -}, paths...) -``` - -### **`Set`** -```go -func Set(data []byte, setValue []byte, keys ...string) (value []byte, err error) -``` -Receives existing data structure, key path to set, and value to set at that key. *This functionality is experimental.* - -Returns: -* `value` - Pointer to original data structure with updated or added key value. -* `err` - If any parsing issue, it should return error. - -Accepts multiple keys to specify path to JSON value (in case of updating or creating nested structures). 
- -Note that keys can be an array indexes: `jsonparser.Set(data, []byte("http://github.com"), "person", "avatars", "[0]", "url")` - -### **`Delete`** -```go -func Delete(data []byte, keys ...string) value []byte -``` -Receives existing data structure, and key path to delete. *This functionality is experimental.* - -Returns: -* `value` - Pointer to original data structure with key path deleted if it can be found. If there is no key path, then the whole data structure is deleted. - -Accepts multiple keys to specify path to JSON value (in case of updating or creating nested structures). - -Note that keys can be an array indexes: `jsonparser.Delete(data, "person", "avatars", "[0]", "url")` - - -## What makes it so fast? -* It does not rely on `encoding/json`, `reflection` or `interface{}`, the only real package dependency is `bytes`. -* Operates with JSON payload on byte level, providing you pointers to the original data structure: no memory allocation. -* No automatic type conversions, by default everything is a []byte, but it provides you value type, so you can convert by yourself (there is few helpers included). -* Does not parse full record, only keys you specified - - -## Benchmarks - -There are 3 benchmark types, trying to simulate real-life usage for small, medium and large JSON payloads. -For each metric, the lower value is better. Time/op is in nanoseconds. Values better than standard encoding/json marked as bold text. -Benchmarks run on standard Linode 1024 box. - -Compared libraries: -* https://golang.org/pkg/encoding/json -* https://github.com/Jeffail/gabs -* https://github.com/a8m/djson -* https://github.com/bitly/go-simplejson -* https://github.com/antonholmquist/jason -* https://github.com/mreiferson/go-ujson -* https://github.com/ugorji/go/codec -* https://github.com/pquerna/ffjson -* https://github.com/mailru/easyjson -* https://github.com/buger/jsonparser - -#### TLDR -If you want to skip next sections we have 2 winner: `jsonparser` and `easyjson`. 
-`jsonparser` is up to 10 times faster than standard `encoding/json` package (depending on payload size and usage), and almost infinitely (literally) better in memory consumption because it operates with data on byte level, and provide direct slice pointers. -`easyjson` wins in CPU in medium tests and frankly i'm impressed with this package: it is remarkable results considering that it is almost drop-in replacement for `encoding/json` (require some code generation). - -It's hard to fully compare `jsonparser` and `easyjson` (or `ffson`), they a true parsers and fully process record, unlike `jsonparser` which parse only keys you specified. - -If you searching for replacement of `encoding/json` while keeping structs, `easyjson` is an amazing choice. If you want to process dynamic JSON, have memory constrains, or more control over your data you should try `jsonparser`. - -`jsonparser` performance heavily depends on usage, and it works best when you do not need to process full record, only some keys. The more calls you need to make, the slower it will be, in contrast `easyjson` (or `ffjson`, `encoding/json`) parser record only 1 time, and then you can make as many calls as you want. - -With great power comes great responsibility! :) - - -#### Small payload - -Each test processes 190 bytes of http log as a JSON record. -It should read multiple fields. 
-https://github.com/buger/jsonparser/blob/master/benchmark/benchmark_small_payload_test.go - -Library | time/op | bytes/op | allocs/op - ------ | ------- | -------- | ------- -encoding/json struct | 7879 | 880 | 18 -encoding/json interface{} | 8946 | 1521 | 38 -Jeffail/gabs | 10053 | 1649 | 46 -bitly/go-simplejson | 10128 | 2241 | 36 -antonholmquist/jason | 27152 | 7237 | 101 -github.com/ugorji/go/codec | 8806 | 2176 | 31 -mreiferson/go-ujson | **7008** | **1409** | 37 -a8m/djson | 3862 | 1249 | 30 -pquerna/ffjson | **3769** | **624** | **15** -mailru/easyjson | **2002** | **192** | **9** -buger/jsonparser | **1367** | **0** | **0** -buger/jsonparser (EachKey API) | **809** | **0** | **0** - -Winners are ffjson, easyjson and jsonparser, where jsonparser is up to 9.8x faster than encoding/json and 4.6x faster than ffjson, and slightly faster than easyjson. -If you look at memory allocation, jsonparser has no rivals, as it makes no data copy and operates with raw []byte structures and pointers to it. - -#### Medium payload - -Each test processes a 2.4kb JSON record (based on Clearbit API). -It should read multiple nested fields and 1 array. 
- -https://github.com/buger/jsonparser/blob/master/benchmark/benchmark_medium_payload_test.go - -| Library | time/op | bytes/op | allocs/op | -| ------- | ------- | -------- | --------- | -| encoding/json struct | 57749 | 1336 | 29 | -| encoding/json interface{} | 79297 | 10627 | 215 | -| Jeffail/gabs | 83807 | 11202 | 235 | -| bitly/go-simplejson | 88187 | 17187 | 220 | -| antonholmquist/jason | 94099 | 19013 | 247 | -| github.com/ugorji/go/codec | 114719 | 6712 | 152 | -| mreiferson/go-ujson | **56972** | 11547 | 270 | -| a8m/djson | 28525 | 10196 | 198 | -| pquerna/ffjson | **20298** | **856** | **20** | -| mailru/easyjson | **10512** | **336** | **12** | -| buger/jsonparser | **15955** | **0** | **0** | -| buger/jsonparser (EachKey API) | **8916** | **0** | **0** | - -The difference between ffjson and jsonparser in CPU usage is smaller, while the memory consumption difference is growing. On the other hand `easyjson` shows remarkable performance for medium payload. - -`gabs`, `go-simplejson` and `jason` are based on encoding/json and map[string]interface{} and actually only helpers for unstructured JSON, their performance correlate with `encoding/json interface{}`, and they will skip next round. -`go-ujson` while have its own parser, shows same performance as `encoding/json`, also skips next round. Same situation with `ugorji/go/codec`, but it showed unexpectedly bad performance for complex payloads. - - -#### Large payload - -Each test processes a 24kb JSON record (based on Discourse API) -It should read 2 arrays, and for each item in array get a few fields. -Basically it means processing a full JSON file. 
- -https://github.com/buger/jsonparser/blob/master/benchmark/benchmark_large_payload_test.go - -| Library | time/op | bytes/op | allocs/op | -| --- | --- | --- | --- | -| encoding/json struct | 748336 | 8272 | 307 | -| encoding/json interface{} | 1224271 | 215425 | 3395 | -| a8m/djson | 510082 | 213682 | 2845 | -| pquerna/ffjson | **312271** | **7792** | **298** | -| mailru/easyjson | **154186** | **6992** | **288** | -| buger/jsonparser | **85308** | **0** | **0** | - -`jsonparser` now is a winner, but do not forget that it is way more lightweight parser than `ffson` or `easyjson`, and they have to parser all the data, while `jsonparser` parse only what you need. All `ffjson`, `easysjon` and `jsonparser` have their own parsing code, and does not depend on `encoding/json` or `interface{}`, thats one of the reasons why they are so fast. `easyjson` also use a bit of `unsafe` package to reduce memory consuption (in theory it can lead to some unexpected GC issue, but i did not tested enough) - -Also last benchmark did not included `EachKey` test, because in this particular case we need to read lot of Array values, and using `ArrayEach` is more efficient. - -## Questions and support - -All bug-reports and suggestions should go though Github Issues. - -## Contributing - -1. Fork it -2. Create your feature branch (git checkout -b my-new-feature) -3. Commit your changes (git commit -am 'Added some feature') -4. Push to the branch (git push origin my-new-feature) -5. Create new Pull Request - -## Development - -All my development happens using Docker, and repo include some Make tasks to simplify development. 
- -* `make build` - builds docker image, usually can be called only once -* `make test` - run tests -* `make fmt` - run go fmt -* `make bench` - run benchmarks (if you need to run only single benchmark modify `BENCHMARK` variable in make file) -* `make profile` - runs benchmark and generate 3 files- `cpu.out`, `mem.mprof` and `benchmark.test` binary, which can be used for `go tool pprof` -* `make bash` - enter container (i use it for running `go tool pprof` above) diff --git a/vendor/github.com/buger/jsonparser/bytes.go b/vendor/github.com/buger/jsonparser/bytes.go deleted file mode 100644 index 0bb0ff395..000000000 --- a/vendor/github.com/buger/jsonparser/bytes.go +++ /dev/null @@ -1,47 +0,0 @@ -package jsonparser - -import ( - bio "bytes" -) - -// minInt64 '-9223372036854775808' is the smallest representable number in int64 -const minInt64 = `9223372036854775808` - -// About 2x faster then strconv.ParseInt because it only supports base 10, which is enough for JSON -func parseInt(bytes []byte) (v int64, ok bool, overflow bool) { - if len(bytes) == 0 { - return 0, false, false - } - - var neg bool = false - if bytes[0] == '-' { - neg = true - bytes = bytes[1:] - } - - var b int64 = 0 - for _, c := range bytes { - if c >= '0' && c <= '9' { - b = (10 * v) + int64(c-'0') - } else { - return 0, false, false - } - if overflow = (b < v); overflow { - break - } - v = b - } - - if overflow { - if neg && bio.Equal(bytes, []byte(minInt64)) { - return b, true, false - } - return 0, false, true - } - - if neg { - return -v, true, false - } else { - return v, true, false - } -} diff --git a/vendor/github.com/buger/jsonparser/bytes_safe.go b/vendor/github.com/buger/jsonparser/bytes_safe.go deleted file mode 100644 index ff16a4a19..000000000 --- a/vendor/github.com/buger/jsonparser/bytes_safe.go +++ /dev/null @@ -1,25 +0,0 @@ -// +build appengine appenginevm - -package jsonparser - -import ( - "strconv" -) - -// See fastbytes_unsafe.go for explanation on why *[]byte is used 
(signatures must be consistent with those in that file) - -func equalStr(b *[]byte, s string) bool { - return string(*b) == s -} - -func parseFloat(b *[]byte) (float64, error) { - return strconv.ParseFloat(string(*b), 64) -} - -func bytesToString(b *[]byte) string { - return string(*b) -} - -func StringToBytes(s string) []byte { - return []byte(s) -} diff --git a/vendor/github.com/buger/jsonparser/bytes_unsafe.go b/vendor/github.com/buger/jsonparser/bytes_unsafe.go deleted file mode 100644 index 589fea87e..000000000 --- a/vendor/github.com/buger/jsonparser/bytes_unsafe.go +++ /dev/null @@ -1,44 +0,0 @@ -// +build !appengine,!appenginevm - -package jsonparser - -import ( - "reflect" - "strconv" - "unsafe" - "runtime" -) - -// -// The reason for using *[]byte rather than []byte in parameters is an optimization. As of Go 1.6, -// the compiler cannot perfectly inline the function when using a non-pointer slice. That is, -// the non-pointer []byte parameter version is slower than if its function body is manually -// inlined, whereas the pointer []byte version is equally fast to the manually inlined -// version. Instruction count in assembly taken from "go tool compile" confirms this difference. -// -// TODO: Remove hack after Go 1.7 release -// -func equalStr(b *[]byte, s string) bool { - return *(*string)(unsafe.Pointer(b)) == s -} - -func parseFloat(b *[]byte) (float64, error) { - return strconv.ParseFloat(*(*string)(unsafe.Pointer(b)), 64) -} - -// A hack until issue golang/go#2632 is fixed. 
-// See: https://github.com/golang/go/issues/2632 -func bytesToString(b *[]byte) string { - return *(*string)(unsafe.Pointer(b)) -} - -func StringToBytes(s string) []byte { - b := make([]byte, 0, 0) - bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) - sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) - bh.Data = sh.Data - bh.Cap = sh.Len - bh.Len = sh.Len - runtime.KeepAlive(s) - return b -} diff --git a/vendor/github.com/buger/jsonparser/escape.go b/vendor/github.com/buger/jsonparser/escape.go deleted file mode 100644 index 49669b942..000000000 --- a/vendor/github.com/buger/jsonparser/escape.go +++ /dev/null @@ -1,173 +0,0 @@ -package jsonparser - -import ( - "bytes" - "unicode/utf8" -) - -// JSON Unicode stuff: see https://tools.ietf.org/html/rfc7159#section-7 - -const supplementalPlanesOffset = 0x10000 -const highSurrogateOffset = 0xD800 -const lowSurrogateOffset = 0xDC00 - -const basicMultilingualPlaneReservedOffset = 0xDFFF -const basicMultilingualPlaneOffset = 0xFFFF - -func combineUTF16Surrogates(high, low rune) rune { - return supplementalPlanesOffset + (high-highSurrogateOffset)<<10 + (low - lowSurrogateOffset) -} - -const badHex = -1 - -func h2I(c byte) int { - switch { - case c >= '0' && c <= '9': - return int(c - '0') - case c >= 'A' && c <= 'F': - return int(c - 'A' + 10) - case c >= 'a' && c <= 'f': - return int(c - 'a' + 10) - } - return badHex -} - -// decodeSingleUnicodeEscape decodes a single \uXXXX escape sequence. The prefix \u is assumed to be present and -// is not checked. -// In JSON, these escapes can either come alone or as part of "UTF16 surrogate pairs" that must be handled together. -// This function only handles one; decodeUnicodeEscape handles this more complex case. 
-func decodeSingleUnicodeEscape(in []byte) (rune, bool) { - // We need at least 6 characters total - if len(in) < 6 { - return utf8.RuneError, false - } - - // Convert hex to decimal - h1, h2, h3, h4 := h2I(in[2]), h2I(in[3]), h2I(in[4]), h2I(in[5]) - if h1 == badHex || h2 == badHex || h3 == badHex || h4 == badHex { - return utf8.RuneError, false - } - - // Compose the hex digits - return rune(h1<<12 + h2<<8 + h3<<4 + h4), true -} - -// isUTF16EncodedRune checks if a rune is in the range for non-BMP characters, -// which is used to describe UTF16 chars. -// Source: https://en.wikipedia.org/wiki/Plane_(Unicode)#Basic_Multilingual_Plane -func isUTF16EncodedRune(r rune) bool { - return highSurrogateOffset <= r && r <= basicMultilingualPlaneReservedOffset -} - -func decodeUnicodeEscape(in []byte) (rune, int) { - if r, ok := decodeSingleUnicodeEscape(in); !ok { - // Invalid Unicode escape - return utf8.RuneError, -1 - } else if r <= basicMultilingualPlaneOffset && !isUTF16EncodedRune(r) { - // Valid Unicode escape in Basic Multilingual Plane - return r, 6 - } else if r2, ok := decodeSingleUnicodeEscape(in[6:]); !ok { // Note: previous decodeSingleUnicodeEscape success guarantees at least 6 bytes remain - // UTF16 "high surrogate" without manditory valid following Unicode escape for the "low surrogate" - return utf8.RuneError, -1 - } else if r2 < lowSurrogateOffset { - // Invalid UTF16 "low surrogate" - return utf8.RuneError, -1 - } else { - // Valid UTF16 surrogate pair - return combineUTF16Surrogates(r, r2), 12 - } -} - -// backslashCharEscapeTable: when '\X' is found for some byte X, it is to be replaced with backslashCharEscapeTable[X] -var backslashCharEscapeTable = [...]byte{ - '"': '"', - '\\': '\\', - '/': '/', - 'b': '\b', - 'f': '\f', - 'n': '\n', - 'r': '\r', - 't': '\t', -} - -// unescapeToUTF8 unescapes the single escape sequence starting at 'in' into 'out' and returns -// how many characters were consumed from 'in' and emitted into 'out'. 
-// If a valid escape sequence does not appear as a prefix of 'in', (-1, -1) to signal the error. -func unescapeToUTF8(in, out []byte) (inLen int, outLen int) { - if len(in) < 2 || in[0] != '\\' { - // Invalid escape due to insufficient characters for any escape or no initial backslash - return -1, -1 - } - - // https://tools.ietf.org/html/rfc7159#section-7 - switch e := in[1]; e { - case '"', '\\', '/', 'b', 'f', 'n', 'r', 't': - // Valid basic 2-character escapes (use lookup table) - out[0] = backslashCharEscapeTable[e] - return 2, 1 - case 'u': - // Unicode escape - if r, inLen := decodeUnicodeEscape(in); inLen == -1 { - // Invalid Unicode escape - return -1, -1 - } else { - // Valid Unicode escape; re-encode as UTF8 - outLen := utf8.EncodeRune(out, r) - return inLen, outLen - } - } - - return -1, -1 -} - -// unescape unescapes the string contained in 'in' and returns it as a slice. -// If 'in' contains no escaped characters: -// Returns 'in'. -// Else, if 'out' is of sufficient capacity (guaranteed if cap(out) >= len(in)): -// 'out' is used to build the unescaped string and is returned with no extra allocation -// Else: -// A new slice is allocated and returned. 
-func Unescape(in, out []byte) ([]byte, error) { - firstBackslash := bytes.IndexByte(in, '\\') - if firstBackslash == -1 { - return in, nil - } - - // Get a buffer of sufficient size (allocate if needed) - if cap(out) < len(in) { - out = make([]byte, len(in)) - } else { - out = out[0:len(in)] - } - - // Copy the first sequence of unescaped bytes to the output and obtain a buffer pointer (subslice) - copy(out, in[:firstBackslash]) - in = in[firstBackslash:] - buf := out[firstBackslash:] - - for len(in) > 0 { - // Unescape the next escaped character - inLen, bufLen := unescapeToUTF8(in, buf) - if inLen == -1 { - return nil, MalformedStringEscapeError - } - - in = in[inLen:] - buf = buf[bufLen:] - - // Copy everything up until the next backslash - nextBackslash := bytes.IndexByte(in, '\\') - if nextBackslash == -1 { - copy(buf, in) - buf = buf[len(in):] - break - } else { - copy(buf, in[:nextBackslash]) - buf = buf[nextBackslash:] - in = in[nextBackslash:] - } - } - - // Trim the out buffer to the amount that was actually emitted - return out[:len(out)-len(buf)], nil -} diff --git a/vendor/github.com/buger/jsonparser/fuzz.go b/vendor/github.com/buger/jsonparser/fuzz.go deleted file mode 100644 index 854bd11b2..000000000 --- a/vendor/github.com/buger/jsonparser/fuzz.go +++ /dev/null @@ -1,117 +0,0 @@ -package jsonparser - -func FuzzParseString(data []byte) int { - r, err := ParseString(data) - if err != nil || r == "" { - return 0 - } - return 1 -} - -func FuzzEachKey(data []byte) int { - paths := [][]string{ - {"name"}, - {"order"}, - {"nested", "a"}, - {"nested", "b"}, - {"nested2", "a"}, - {"nested", "nested3", "b"}, - {"arr", "[1]", "b"}, - {"arrInt", "[3]"}, - {"arrInt", "[5]"}, - {"nested"}, - {"arr", "["}, - {"a\n", "b\n"}, - } - EachKey(data, func(idx int, value []byte, vt ValueType, err error) {}, paths...) 
- return 1 -} - -func FuzzDelete(data []byte) int { - Delete(data, "test") - return 1 -} - -func FuzzSet(data []byte) int { - _, err := Set(data, []byte(`"new value"`), "test") - if err != nil { - return 0 - } - return 1 -} - -func FuzzObjectEach(data []byte) int { - _ = ObjectEach(data, func(key, value []byte, valueType ValueType, off int) error { - return nil - }) - return 1 -} - -func FuzzParseFloat(data []byte) int { - _, err := ParseFloat(data) - if err != nil { - return 0 - } - return 1 -} - -func FuzzParseInt(data []byte) int { - _, err := ParseInt(data) - if err != nil { - return 0 - } - return 1 -} - -func FuzzParseBool(data []byte) int { - _, err := ParseBoolean(data) - if err != nil { - return 0 - } - return 1 -} - -func FuzzTokenStart(data []byte) int { - _ = tokenStart(data) - return 1 -} - -func FuzzGetString(data []byte) int { - _, err := GetString(data, "test") - if err != nil { - return 0 - } - return 1 -} - -func FuzzGetFloat(data []byte) int { - _, err := GetFloat(data, "test") - if err != nil { - return 0 - } - return 1 -} - -func FuzzGetInt(data []byte) int { - _, err := GetInt(data, "test") - if err != nil { - return 0 - } - return 1 -} - -func FuzzGetBoolean(data []byte) int { - _, err := GetBoolean(data, "test") - if err != nil { - return 0 - } - return 1 -} - -func FuzzGetUnsafeString(data []byte) int { - _, err := GetUnsafeString(data, "test") - if err != nil { - return 0 - } - return 1 -} diff --git a/vendor/github.com/buger/jsonparser/oss-fuzz-build.sh b/vendor/github.com/buger/jsonparser/oss-fuzz-build.sh deleted file mode 100644 index c573b0e2d..000000000 --- a/vendor/github.com/buger/jsonparser/oss-fuzz-build.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -eu - -git clone https://github.com/dvyukov/go-fuzz-corpus -zip corpus.zip go-fuzz-corpus/json/corpus/* - -cp corpus.zip $OUT/fuzzparsestring_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzParseString fuzzparsestring - -cp corpus.zip 
$OUT/fuzzeachkey_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzEachKey fuzzeachkey - -cp corpus.zip $OUT/fuzzdelete_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzDelete fuzzdelete - -cp corpus.zip $OUT/fuzzset_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzSet fuzzset - -cp corpus.zip $OUT/fuzzobjecteach_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzObjectEach fuzzobjecteach - -cp corpus.zip $OUT/fuzzparsefloat_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzParseFloat fuzzparsefloat - -cp corpus.zip $OUT/fuzzparseint_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzParseInt fuzzparseint - -cp corpus.zip $OUT/fuzzparsebool_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzParseBool fuzzparsebool - -cp corpus.zip $OUT/fuzztokenstart_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzTokenStart fuzztokenstart - -cp corpus.zip $OUT/fuzzgetstring_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzGetString fuzzgetstring - -cp corpus.zip $OUT/fuzzgetfloat_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzGetFloat fuzzgetfloat - -cp corpus.zip $OUT/fuzzgetint_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzGetInt fuzzgetint - -cp corpus.zip $OUT/fuzzgetboolean_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzGetBoolean fuzzgetboolean - -cp corpus.zip $OUT/fuzzgetunsafestring_seed_corpus.zip -compile_go_fuzzer github.com/buger/jsonparser FuzzGetUnsafeString fuzzgetunsafestring - diff --git a/vendor/github.com/buger/jsonparser/parser.go b/vendor/github.com/buger/jsonparser/parser.go deleted file mode 100644 index 14b80bc48..000000000 --- a/vendor/github.com/buger/jsonparser/parser.go +++ /dev/null @@ -1,1283 +0,0 @@ -package jsonparser - -import ( - "bytes" - "errors" - "fmt" - "strconv" -) - -// Errors -var ( - KeyPathNotFoundError = errors.New("Key 
path not found") - UnknownValueTypeError = errors.New("Unknown value type") - MalformedJsonError = errors.New("Malformed JSON error") - MalformedStringError = errors.New("Value is string, but can't find closing '\"' symbol") - MalformedArrayError = errors.New("Value is array, but can't find closing ']' symbol") - MalformedObjectError = errors.New("Value looks like object, but can't find closing '}' symbol") - MalformedValueError = errors.New("Value looks like Number/Boolean/None, but can't find its end: ',' or '}' symbol") - OverflowIntegerError = errors.New("Value is number, but overflowed while parsing") - MalformedStringEscapeError = errors.New("Encountered an invalid escape sequence in a string") -) - -// How much stack space to allocate for unescaping JSON strings; if a string longer -// than this needs to be escaped, it will result in a heap allocation -const unescapeStackBufSize = 64 - -func tokenEnd(data []byte) int { - for i, c := range data { - switch c { - case ' ', '\n', '\r', '\t', ',', '}', ']': - return i - } - } - - return len(data) -} - -func findTokenStart(data []byte, token byte) int { - for i := len(data) - 1; i >= 0; i-- { - switch data[i] { - case token: - return i - case '[', '{': - return 0 - } - } - - return 0 -} - -func findKeyStart(data []byte, key string) (int, error) { - i := 0 - ln := len(data) - if ln > 0 && (data[0] == '{' || data[0] == '[') { - i = 1 - } - var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings - - if ku, err := Unescape(StringToBytes(key), stackbuf[:]); err == nil { - key = bytesToString(&ku) - } - - for i < ln { - switch data[i] { - case '"': - i++ - keyBegin := i - - strEnd, keyEscaped := stringEnd(data[i:]) - if strEnd == -1 { - break - } - i += strEnd - keyEnd := i - 1 - - valueOffset := nextToken(data[i:]) - if valueOffset == -1 { - break - } - - i += valueOffset - - // if string is a key, and key level match - k := data[keyBegin:keyEnd] - // for 
unescape: if there are no escape sequences, this is cheap; if there are, it is a - // bit more expensive, but causes no allocations unless len(key) > unescapeStackBufSize - if keyEscaped { - if ku, err := Unescape(k, stackbuf[:]); err != nil { - break - } else { - k = ku - } - } - - if data[i] == ':' && len(key) == len(k) && bytesToString(&k) == key { - return keyBegin - 1, nil - } - - case '[': - end := blockEnd(data[i:], data[i], ']') - if end != -1 { - i = i + end - } - case '{': - end := blockEnd(data[i:], data[i], '}') - if end != -1 { - i = i + end - } - } - i++ - } - - return -1, KeyPathNotFoundError -} - -func tokenStart(data []byte) int { - for i := len(data) - 1; i >= 0; i-- { - switch data[i] { - case '\n', '\r', '\t', ',', '{', '[': - return i - } - } - - return 0 -} - -// Find position of next character which is not whitespace -func nextToken(data []byte) int { - for i, c := range data { - switch c { - case ' ', '\n', '\r', '\t': - continue - default: - return i - } - } - - return -1 -} - -// Find position of last character which is not whitespace -func lastToken(data []byte) int { - for i := len(data) - 1; i >= 0; i-- { - switch data[i] { - case ' ', '\n', '\r', '\t': - continue - default: - return i - } - } - - return -1 -} - -// Tries to find the end of string -// Support if string contains escaped quote symbols. -func stringEnd(data []byte) (int, bool) { - escaped := false - for i, c := range data { - if c == '"' { - if !escaped { - return i + 1, false - } else { - j := i - 1 - for { - if j < 0 || data[j] != '\\' { - return i + 1, true // even number of backslashes - } - j-- - if j < 0 || data[j] != '\\' { - break // odd number of backslashes - } - j-- - - } - } - } else if c == '\\' { - escaped = true - } - } - - return -1, escaped -} - -// Find end of the data structure, array or object. 
-// For array openSym and closeSym will be '[' and ']', for object '{' and '}' -func blockEnd(data []byte, openSym byte, closeSym byte) int { - level := 0 - i := 0 - ln := len(data) - - for i < ln { - switch data[i] { - case '"': // If inside string, skip it - se, _ := stringEnd(data[i+1:]) - if se == -1 { - return -1 - } - i += se - case openSym: // If open symbol, increase level - level++ - case closeSym: // If close symbol, increase level - level-- - - // If we have returned to the original level, we're done - if level == 0 { - return i + 1 - } - } - i++ - } - - return -1 -} - -func searchKeys(data []byte, keys ...string) int { - keyLevel := 0 - level := 0 - i := 0 - ln := len(data) - lk := len(keys) - lastMatched := true - - if lk == 0 { - return 0 - } - - var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings - - for i < ln { - switch data[i] { - case '"': - i++ - keyBegin := i - - strEnd, keyEscaped := stringEnd(data[i:]) - if strEnd == -1 { - return -1 - } - i += strEnd - keyEnd := i - 1 - - valueOffset := nextToken(data[i:]) - if valueOffset == -1 { - return -1 - } - - i += valueOffset - - // if string is a key - if data[i] == ':' { - if level < 1 { - return -1 - } - - key := data[keyBegin:keyEnd] - - // for unescape: if there are no escape sequences, this is cheap; if there are, it is a - // bit more expensive, but causes no allocations unless len(key) > unescapeStackBufSize - var keyUnesc []byte - if !keyEscaped { - keyUnesc = key - } else if ku, err := Unescape(key, stackbuf[:]); err != nil { - return -1 - } else { - keyUnesc = ku - } - - if level <= len(keys) { - if equalStr(&keyUnesc, keys[level-1]) { - lastMatched = true - - // if key level match - if keyLevel == level-1 { - keyLevel++ - // If we found all keys in path - if keyLevel == lk { - return i + 1 - } - } - } else { - lastMatched = false - } - } else { - return -1 - } - } else { - i-- - } - case '{': - - // in case parent key is matched 
then only we will increase the level otherwise can directly - // can move to the end of this block - if !lastMatched { - end := blockEnd(data[i:], '{', '}') - if end == -1 { - return -1 - } - i += end - 1 - } else { - level++ - } - case '}': - level-- - if level == keyLevel { - keyLevel-- - } - case '[': - // If we want to get array element by index - if keyLevel == level && keys[level][0] == '[' { - var keyLen = len(keys[level]) - if keyLen < 3 || keys[level][0] != '[' || keys[level][keyLen-1] != ']' { - return -1 - } - aIdx, err := strconv.Atoi(keys[level][1 : keyLen-1]) - if err != nil { - return -1 - } - var curIdx int - var valueFound []byte - var valueOffset int - var curI = i - ArrayEach(data[i:], func(value []byte, dataType ValueType, offset int, err error) { - if curIdx == aIdx { - valueFound = value - valueOffset = offset - if dataType == String { - valueOffset = valueOffset - 2 - valueFound = data[curI+valueOffset : curI+valueOffset+len(value)+2] - } - } - curIdx += 1 - }) - - if valueFound == nil { - return -1 - } else { - subIndex := searchKeys(valueFound, keys[level+1:]...) 
- if subIndex < 0 { - return -1 - } - return i + valueOffset + subIndex - } - } else { - // Do not search for keys inside arrays - if arraySkip := blockEnd(data[i:], '[', ']'); arraySkip == -1 { - return -1 - } else { - i += arraySkip - 1 - } - } - case ':': // If encountered, JSON data is malformed - return -1 - } - - i++ - } - - return -1 -} - -func sameTree(p1, p2 []string) bool { - minLen := len(p1) - if len(p2) < minLen { - minLen = len(p2) - } - - for pi_1, p_1 := range p1[:minLen] { - if p2[pi_1] != p_1 { - return false - } - } - - return true -} - -func EachKey(data []byte, cb func(int, []byte, ValueType, error), paths ...[]string) int { - var x struct{} - pathFlags := make([]bool, len(paths)) - var level, pathsMatched, i int - ln := len(data) - - var maxPath int - for _, p := range paths { - if len(p) > maxPath { - maxPath = len(p) - } - } - - pathsBuf := make([]string, maxPath) - - for i < ln { - switch data[i] { - case '"': - i++ - keyBegin := i - - strEnd, keyEscaped := stringEnd(data[i:]) - if strEnd == -1 { - return -1 - } - i += strEnd - - keyEnd := i - 1 - - valueOffset := nextToken(data[i:]) - if valueOffset == -1 { - return -1 - } - - i += valueOffset - - // if string is a key, and key level match - if data[i] == ':' { - match := -1 - key := data[keyBegin:keyEnd] - - // for unescape: if there are no escape sequences, this is cheap; if there are, it is a - // bit more expensive, but causes no allocations unless len(key) > unescapeStackBufSize - var keyUnesc []byte - if !keyEscaped { - keyUnesc = key - } else { - var stackbuf [unescapeStackBufSize]byte - if ku, err := Unescape(key, stackbuf[:]); err != nil { - return -1 - } else { - keyUnesc = ku - } - } - - if maxPath >= level { - if level < 1 { - cb(-1, nil, Unknown, MalformedJsonError) - return -1 - } - - pathsBuf[level-1] = bytesToString(&keyUnesc) - for pi, p := range paths { - if len(p) != level || pathFlags[pi] || !equalStr(&keyUnesc, p[level-1]) || !sameTree(p, pathsBuf[:level]) { - continue 
- } - - match = pi - - pathsMatched++ - pathFlags[pi] = true - - v, dt, _, e := Get(data[i+1:]) - cb(pi, v, dt, e) - - if pathsMatched == len(paths) { - break - } - } - if pathsMatched == len(paths) { - return i - } - } - - if match == -1 { - tokenOffset := nextToken(data[i+1:]) - i += tokenOffset - - if data[i] == '{' { - blockSkip := blockEnd(data[i:], '{', '}') - i += blockSkip + 1 - } - } - - if i < ln { - switch data[i] { - case '{', '}', '[', '"': - i-- - } - } - } else { - i-- - } - case '{': - level++ - case '}': - level-- - case '[': - var ok bool - arrIdxFlags := make(map[int]struct{}) - pIdxFlags := make([]bool, len(paths)) - - if level < 0 { - cb(-1, nil, Unknown, MalformedJsonError) - return -1 - } - - for pi, p := range paths { - if len(p) < level+1 || pathFlags[pi] || p[level][0] != '[' || !sameTree(p, pathsBuf[:level]) { - continue - } - if len(p[level]) >= 2 { - aIdx, _ := strconv.Atoi(p[level][1 : len(p[level])-1]) - arrIdxFlags[aIdx] = x - pIdxFlags[pi] = true - } - } - - if len(arrIdxFlags) > 0 { - level++ - - var curIdx int - arrOff, _ := ArrayEach(data[i:], func(value []byte, dataType ValueType, offset int, err error) { - if _, ok = arrIdxFlags[curIdx]; ok { - for pi, p := range paths { - if pIdxFlags[pi] { - aIdx, _ := strconv.Atoi(p[level-1][1 : len(p[level-1])-1]) - - if curIdx == aIdx { - of := searchKeys(value, p[level:]...) - - pathsMatched++ - pathFlags[pi] = true - - if of != -1 { - v, dt, _, e := Get(value[of:]) - cb(pi, v, dt, e) - } - } - } - } - } - - curIdx += 1 - }) - - if pathsMatched == len(paths) { - return i - } - - i += arrOff - 1 - } else { - // Do not search for keys inside arrays - if arraySkip := blockEnd(data[i:], '[', ']'); arraySkip == -1 { - return -1 - } else { - i += arraySkip - 1 - } - } - case ']': - level-- - } - - i++ - } - - return -1 -} - -// Data types available in valid JSON data. 
-type ValueType int - -const ( - NotExist = ValueType(iota) - String - Number - Object - Array - Boolean - Null - Unknown -) - -func (vt ValueType) String() string { - switch vt { - case NotExist: - return "non-existent" - case String: - return "string" - case Number: - return "number" - case Object: - return "object" - case Array: - return "array" - case Boolean: - return "boolean" - case Null: - return "null" - default: - return "unknown" - } -} - -var ( - trueLiteral = []byte("true") - falseLiteral = []byte("false") - nullLiteral = []byte("null") -) - -func createInsertComponent(keys []string, setValue []byte, comma, object bool) []byte { - isIndex := string(keys[0][0]) == "[" - offset := 0 - lk := calcAllocateSpace(keys, setValue, comma, object) - buffer := make([]byte, lk, lk) - if comma { - offset += WriteToBuffer(buffer[offset:], ",") - } - if isIndex && !comma { - offset += WriteToBuffer(buffer[offset:], "[") - } else { - if object { - offset += WriteToBuffer(buffer[offset:], "{") - } - if !isIndex { - offset += WriteToBuffer(buffer[offset:], "\"") - offset += WriteToBuffer(buffer[offset:], keys[0]) - offset += WriteToBuffer(buffer[offset:], "\":") - } - } - - for i := 1; i < len(keys); i++ { - if string(keys[i][0]) == "[" { - offset += WriteToBuffer(buffer[offset:], "[") - } else { - offset += WriteToBuffer(buffer[offset:], "{\"") - offset += WriteToBuffer(buffer[offset:], keys[i]) - offset += WriteToBuffer(buffer[offset:], "\":") - } - } - offset += WriteToBuffer(buffer[offset:], string(setValue)) - for i := len(keys) - 1; i > 0; i-- { - if string(keys[i][0]) == "[" { - offset += WriteToBuffer(buffer[offset:], "]") - } else { - offset += WriteToBuffer(buffer[offset:], "}") - } - } - if isIndex && !comma { - offset += WriteToBuffer(buffer[offset:], "]") - } - if object && !isIndex { - offset += WriteToBuffer(buffer[offset:], "}") - } - return buffer -} - -func calcAllocateSpace(keys []string, setValue []byte, comma, object bool) int { - isIndex := 
string(keys[0][0]) == "[" - lk := 0 - if comma { - // , - lk += 1 - } - if isIndex && !comma { - // [] - lk += 2 - } else { - if object { - // { - lk += 1 - } - if !isIndex { - // "keys[0]" - lk += len(keys[0]) + 3 - } - } - - - lk += len(setValue) - for i := 1; i < len(keys); i++ { - if string(keys[i][0]) == "[" { - // [] - lk += 2 - } else { - // {"keys[i]":setValue} - lk += len(keys[i]) + 5 - } - } - - if object && !isIndex { - // } - lk += 1 - } - - return lk -} - -func WriteToBuffer(buffer []byte, str string) int { - copy(buffer, str) - return len(str) -} - -/* - -Del - Receives existing data structure, path to delete. - -Returns: -`data` - return modified data - -*/ -func Delete(data []byte, keys ...string) []byte { - lk := len(keys) - if lk == 0 { - return data[:0] - } - - array := false - if len(keys[lk-1]) > 0 && string(keys[lk-1][0]) == "[" { - array = true - } - - var startOffset, keyOffset int - endOffset := len(data) - var err error - if !array { - if len(keys) > 1 { - _, _, startOffset, endOffset, err = internalGet(data, keys[:lk-1]...) - if err == KeyPathNotFoundError { - // problem parsing the data - return data - } - } - - keyOffset, err = findKeyStart(data[startOffset:endOffset], keys[lk-1]) - if err == KeyPathNotFoundError { - // problem parsing the data - return data - } - keyOffset += startOffset - _, _, _, subEndOffset, _ := internalGet(data[startOffset:endOffset], keys[lk-1]) - endOffset = startOffset + subEndOffset - tokEnd := tokenEnd(data[endOffset:]) - tokStart := findTokenStart(data[:keyOffset], ","[0]) - - if data[endOffset+tokEnd] == ","[0] { - endOffset += tokEnd + 1 - } else if data[endOffset+tokEnd] == " "[0] && len(data) > endOffset+tokEnd+1 && data[endOffset+tokEnd+1] == ","[0] { - endOffset += tokEnd + 2 - } else if data[endOffset+tokEnd] == "}"[0] && data[tokStart] == ","[0] { - keyOffset = tokStart - } - } else { - _, _, keyOffset, endOffset, err = internalGet(data, keys...) 
- if err == KeyPathNotFoundError { - // problem parsing the data - return data - } - - tokEnd := tokenEnd(data[endOffset:]) - tokStart := findTokenStart(data[:keyOffset], ","[0]) - - if data[endOffset+tokEnd] == ","[0] { - endOffset += tokEnd + 1 - } else if data[endOffset+tokEnd] == "]"[0] && data[tokStart] == ","[0] { - keyOffset = tokStart - } - } - - // We need to remove remaining trailing comma if we delete las element in the object - prevTok := lastToken(data[:keyOffset]) - remainedValue := data[endOffset:] - - var newOffset int - if nextToken(remainedValue) > -1 && remainedValue[nextToken(remainedValue)] == '}' && data[prevTok] == ',' { - newOffset = prevTok - } else { - newOffset = prevTok + 1 - } - - // We have to make a copy here if we don't want to mangle the original data, because byte slices are - // accessed by reference and not by value - dataCopy := make([]byte, len(data)) - copy(dataCopy, data) - data = append(dataCopy[:newOffset], dataCopy[endOffset:]...) - - return data -} - -/* - -Set - Receives existing data structure, path to set, and data to set at that key. - -Returns: -`value` - modified byte array -`err` - On any parsing error - -*/ -func Set(data []byte, setValue []byte, keys ...string) (value []byte, err error) { - // ensure keys are set - if len(keys) == 0 { - return nil, KeyPathNotFoundError - } - - _, _, startOffset, endOffset, err := internalGet(data, keys...) - if err != nil { - if err != KeyPathNotFoundError { - // problem parsing the data - return nil, err - } - // full path doesnt exist - // does any subpath exist? - var depth int - for i := range keys { - _, _, start, end, sErr := internalGet(data, keys[:i+1]...) 
- if sErr != nil { - break - } else { - endOffset = end - startOffset = start - depth++ - } - } - comma := true - object := false - if endOffset == -1 { - firstToken := nextToken(data) - // We can't set a top-level key if data isn't an object - if firstToken < 0 || data[firstToken] != '{' { - return nil, KeyPathNotFoundError - } - // Don't need a comma if the input is an empty object - secondToken := firstToken + 1 + nextToken(data[firstToken+1:]) - if data[secondToken] == '}' { - comma = false - } - // Set the top level key at the end (accounting for any trailing whitespace) - // This assumes last token is valid like '}', could check and return error - endOffset = lastToken(data) - } - depthOffset := endOffset - if depth != 0 { - // if subpath is a non-empty object, add to it - // or if subpath is a non-empty array, add to it - if (data[startOffset] == '{' && data[startOffset+1+nextToken(data[startOffset+1:])] != '}') || - (data[startOffset] == '[' && data[startOffset+1+nextToken(data[startOffset+1:])] == '{') && keys[depth:][0][0] == 91 { - depthOffset-- - startOffset = depthOffset - // otherwise, over-write it with a new object - } else { - comma = false - object = true - } - } else { - startOffset = depthOffset - } - value = append(data[:startOffset], append(createInsertComponent(keys[depth:], setValue, comma, object), data[depthOffset:]...)...) 
- } else { - // path currently exists - startComponent := data[:startOffset] - endComponent := data[endOffset:] - - value = make([]byte, len(startComponent)+len(endComponent)+len(setValue)) - newEndOffset := startOffset + len(setValue) - copy(value[0:startOffset], startComponent) - copy(value[startOffset:newEndOffset], setValue) - copy(value[newEndOffset:], endComponent) - } - return value, nil -} - -func getType(data []byte, offset int) ([]byte, ValueType, int, error) { - var dataType ValueType - endOffset := offset - - // if string value - if data[offset] == '"' { - dataType = String - if idx, _ := stringEnd(data[offset+1:]); idx != -1 { - endOffset += idx + 1 - } else { - return nil, dataType, offset, MalformedStringError - } - } else if data[offset] == '[' { // if array value - dataType = Array - // break label, for stopping nested loops - endOffset = blockEnd(data[offset:], '[', ']') - - if endOffset == -1 { - return nil, dataType, offset, MalformedArrayError - } - - endOffset += offset - } else if data[offset] == '{' { // if object value - dataType = Object - // break label, for stopping nested loops - endOffset = blockEnd(data[offset:], '{', '}') - - if endOffset == -1 { - return nil, dataType, offset, MalformedObjectError - } - - endOffset += offset - } else { - // Number, Boolean or None - end := tokenEnd(data[endOffset:]) - - if end == -1 { - return nil, dataType, offset, MalformedValueError - } - - value := data[offset : endOffset+end] - - switch data[offset] { - case 't', 'f': // true or false - if bytes.Equal(value, trueLiteral) || bytes.Equal(value, falseLiteral) { - dataType = Boolean - } else { - return nil, Unknown, offset, UnknownValueTypeError - } - case 'u', 'n': // undefined or null - if bytes.Equal(value, nullLiteral) { - dataType = Null - } else { - return nil, Unknown, offset, UnknownValueTypeError - } - case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-': - dataType = Number - default: - return nil, Unknown, offset, 
UnknownValueTypeError - } - - endOffset += end - } - return data[offset:endOffset], dataType, endOffset, nil -} - -/* -Get - Receives data structure, and key path to extract value from. - -Returns: -`value` - Pointer to original data structure containing key value, or just empty slice if nothing found or error -`dataType` - Can be: `NotExist`, `String`, `Number`, `Object`, `Array`, `Boolean` or `Null` -`offset` - Offset from provided data structure where key value ends. Used mostly internally, for example for `ArrayEach` helper. -`err` - If key not found or any other parsing issue it should return error. If key not found it also sets `dataType` to `NotExist` - -Accept multiple keys to specify path to JSON value (in case of quering nested structures). -If no keys provided it will try to extract closest JSON value (simple ones or object/array), useful for reading streams or arrays, see `ArrayEach` implementation. -*/ -func Get(data []byte, keys ...string) (value []byte, dataType ValueType, offset int, err error) { - a, b, _, d, e := internalGet(data, keys...) - return a, b, d, e -} - -func internalGet(data []byte, keys ...string) (value []byte, dataType ValueType, offset, endOffset int, err error) { - if len(keys) > 0 { - if offset = searchKeys(data, keys...); offset == -1 { - return nil, NotExist, -1, -1, KeyPathNotFoundError - } - } - - // Go to closest value - nO := nextToken(data[offset:]) - if nO == -1 { - return nil, NotExist, offset, -1, MalformedJsonError - } - - offset += nO - value, dataType, endOffset, err = getType(data, offset) - if err != nil { - return value, dataType, offset, endOffset, err - } - - // Strip quotes from string values - if dataType == String { - value = value[1 : len(value)-1] - } - - return value[:len(value):len(value)], dataType, offset, endOffset, nil -} - -// ArrayEach is used when iterating arrays, accepts a callback function with the same return arguments as `Get`. 
-func ArrayEach(data []byte, cb func(value []byte, dataType ValueType, offset int, err error), keys ...string) (offset int, err error) { - if len(data) == 0 { - return -1, MalformedObjectError - } - - nT := nextToken(data) - if nT == -1 { - return -1, MalformedJsonError - } - - offset = nT + 1 - - if len(keys) > 0 { - if offset = searchKeys(data, keys...); offset == -1 { - return offset, KeyPathNotFoundError - } - - // Go to closest value - nO := nextToken(data[offset:]) - if nO == -1 { - return offset, MalformedJsonError - } - - offset += nO - - if data[offset] != '[' { - return offset, MalformedArrayError - } - - offset++ - } - - nO := nextToken(data[offset:]) - if nO == -1 { - return offset, MalformedJsonError - } - - offset += nO - - if data[offset] == ']' { - return offset, nil - } - - for true { - v, t, o, e := Get(data[offset:]) - - if e != nil { - return offset, e - } - - if o == 0 { - break - } - - if t != NotExist { - cb(v, t, offset+o-len(v), e) - } - - if e != nil { - break - } - - offset += o - - skipToToken := nextToken(data[offset:]) - if skipToToken == -1 { - return offset, MalformedArrayError - } - offset += skipToToken - - if data[offset] == ']' { - break - } - - if data[offset] != ',' { - return offset, MalformedArrayError - } - - offset++ - } - - return offset, nil -} - -// ObjectEach iterates over the key-value pairs of a JSON object, invoking a given callback for each such entry -func ObjectEach(data []byte, callback func(key []byte, value []byte, dataType ValueType, offset int) error, keys ...string) (err error) { - offset := 0 - - // Descend to the desired key, if requested - if len(keys) > 0 { - if off := searchKeys(data, keys...); off == -1 { - return KeyPathNotFoundError - } else { - offset = off - } - } - - // Validate and skip past opening brace - if off := nextToken(data[offset:]); off == -1 { - return MalformedObjectError - } else if offset += off; data[offset] != '{' { - return MalformedObjectError - } else { - offset++ - } - - // 
Skip to the first token inside the object, or stop if we find the ending brace - if off := nextToken(data[offset:]); off == -1 { - return MalformedJsonError - } else if offset += off; data[offset] == '}' { - return nil - } - - // Loop pre-condition: data[offset] points to what should be either the next entry's key, or the closing brace (if it's anything else, the JSON is malformed) - for offset < len(data) { - // Step 1: find the next key - var key []byte - - // Check what the the next token is: start of string, end of object, or something else (error) - switch data[offset] { - case '"': - offset++ // accept as string and skip opening quote - case '}': - return nil // we found the end of the object; stop and return success - default: - return MalformedObjectError - } - - // Find the end of the key string - var keyEscaped bool - if off, esc := stringEnd(data[offset:]); off == -1 { - return MalformedJsonError - } else { - key, keyEscaped = data[offset:offset+off-1], esc - offset += off - } - - // Unescape the string if needed - if keyEscaped { - var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings - if keyUnescaped, err := Unescape(key, stackbuf[:]); err != nil { - return MalformedStringEscapeError - } else { - key = keyUnescaped - } - } - - // Step 2: skip the colon - if off := nextToken(data[offset:]); off == -1 { - return MalformedJsonError - } else if offset += off; data[offset] != ':' { - return MalformedJsonError - } else { - offset++ - } - - // Step 3: find the associated value, then invoke the callback - if value, valueType, off, err := Get(data[offset:]); err != nil { - return err - } else if err := callback(key, value, valueType, offset+off); err != nil { // Invoke the callback here! 
- return err - } else { - offset += off - } - - // Step 4: skip over the next comma to the following token, or stop if we hit the ending brace - if off := nextToken(data[offset:]); off == -1 { - return MalformedArrayError - } else { - offset += off - switch data[offset] { - case '}': - return nil // Stop if we hit the close brace - case ',': - offset++ // Ignore the comma - default: - return MalformedObjectError - } - } - - // Skip to the next token after the comma - if off := nextToken(data[offset:]); off == -1 { - return MalformedArrayError - } else { - offset += off - } - } - - return MalformedObjectError // we shouldn't get here; it's expected that we will return via finding the ending brace -} - -// GetUnsafeString returns the value retrieved by `Get`, use creates string without memory allocation by mapping string to slice memory. It does not handle escape symbols. -func GetUnsafeString(data []byte, keys ...string) (val string, err error) { - v, _, _, e := Get(data, keys...) - - if e != nil { - return "", e - } - - return bytesToString(&v), nil -} - -// GetString returns the value retrieved by `Get`, cast to a string if possible, trying to properly handle escape and utf8 symbols -// If key data type do not match, it will return an error. -func GetString(data []byte, keys ...string) (val string, err error) { - v, t, _, e := Get(data, keys...) - - if e != nil { - return "", e - } - - if t != String { - return "", fmt.Errorf("Value is not a string: %s", string(v)) - } - - // If no escapes return raw content - if bytes.IndexByte(v, '\\') == -1 { - return string(v), nil - } - - return ParseString(v) -} - -// GetFloat returns the value retrieved by `Get`, cast to a float64 if possible. -// The offset is the same as in `Get`. -// If key data type do not match, it will return an error. -func GetFloat(data []byte, keys ...string) (val float64, err error) { - v, t, _, e := Get(data, keys...) 
- - if e != nil { - return 0, e - } - - if t != Number { - return 0, fmt.Errorf("Value is not a number: %s", string(v)) - } - - return ParseFloat(v) -} - -// GetInt returns the value retrieved by `Get`, cast to a int64 if possible. -// If key data type do not match, it will return an error. -func GetInt(data []byte, keys ...string) (val int64, err error) { - v, t, _, e := Get(data, keys...) - - if e != nil { - return 0, e - } - - if t != Number { - return 0, fmt.Errorf("Value is not a number: %s", string(v)) - } - - return ParseInt(v) -} - -// GetBoolean returns the value retrieved by `Get`, cast to a bool if possible. -// The offset is the same as in `Get`. -// If key data type do not match, it will return error. -func GetBoolean(data []byte, keys ...string) (val bool, err error) { - v, t, _, e := Get(data, keys...) - - if e != nil { - return false, e - } - - if t != Boolean { - return false, fmt.Errorf("Value is not a boolean: %s", string(v)) - } - - return ParseBoolean(v) -} - -// ParseBoolean parses a Boolean ValueType into a Go bool (not particularly useful, but here for completeness) -func ParseBoolean(b []byte) (bool, error) { - switch { - case bytes.Equal(b, trueLiteral): - return true, nil - case bytes.Equal(b, falseLiteral): - return false, nil - default: - return false, MalformedValueError - } -} - -// ParseString parses a String ValueType into a Go string (the main parsing work is unescaping the JSON string) -func ParseString(b []byte) (string, error) { - var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings - if bU, err := Unescape(b, stackbuf[:]); err != nil { - return "", MalformedValueError - } else { - return string(bU), nil - } -} - -// ParseNumber parses a Number ValueType into a Go float64 -func ParseFloat(b []byte) (float64, error) { - if v, err := parseFloat(&b); err != nil { - return 0, MalformedValueError - } else { - return v, nil - } -} - -// ParseInt parses a Number ValueType into 
a Go int64 -func ParseInt(b []byte) (int64, error) { - if v, ok, overflow := parseInt(b); !ok { - if overflow { - return 0, OverflowIntegerError - } - return 0, MalformedValueError - } else { - return v, nil - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/LICENSE b/vendor/github.com/google/jsonschema-go/LICENSE similarity index 95% rename from vendor/github.com/mark3labs/mcp-go/LICENSE rename to vendor/github.com/google/jsonschema-go/LICENSE index 3d4843545..1cb53e9df 100644 --- a/vendor/github.com/mark3labs/mcp-go/LICENSE +++ b/vendor/github.com/google/jsonschema-go/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2024 Anthropic, PBC +Copyright (c) 2025 JSON Schema Go Project Authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/annotations.go b/vendor/github.com/google/jsonschema-go/jsonschema/annotations.go new file mode 100644 index 000000000..d4dd6436b --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/annotations.go @@ -0,0 +1,76 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import "maps" + +// An annotations tracks certain properties computed by keywords that are used by validation. +// ("Annotation" is the spec's term.) +// In particular, the unevaluatedItems and unevaluatedProperties keywords need to know which +// items and properties were evaluated (validated successfully). 
+type annotations struct { + allItems bool // all items were evaluated + endIndex int // 1+largest index evaluated by prefixItems + evaluatedIndexes map[int]bool // set of indexes evaluated by contains + allProperties bool // all properties were evaluated + evaluatedProperties map[string]bool // set of properties evaluated by various keywords +} + +// noteIndex marks i as evaluated. +func (a *annotations) noteIndex(i int) { + if a.evaluatedIndexes == nil { + a.evaluatedIndexes = map[int]bool{} + } + a.evaluatedIndexes[i] = true +} + +// noteEndIndex marks items with index less than end as evaluated. +func (a *annotations) noteEndIndex(end int) { + if end > a.endIndex { + a.endIndex = end + } +} + +// noteProperty marks prop as evaluated. +func (a *annotations) noteProperty(prop string) { + if a.evaluatedProperties == nil { + a.evaluatedProperties = map[string]bool{} + } + a.evaluatedProperties[prop] = true +} + +// noteProperties marks all the properties in props as evaluated. +func (a *annotations) noteProperties(props map[string]bool) { + a.evaluatedProperties = merge(a.evaluatedProperties, props) +} + +// merge adds b's annotations to a. +// a must not be nil. +func (a *annotations) merge(b *annotations) { + if b == nil { + return + } + if b.allItems { + a.allItems = true + } + if b.endIndex > a.endIndex { + a.endIndex = b.endIndex + } + a.evaluatedIndexes = merge(a.evaluatedIndexes, b.evaluatedIndexes) + if b.allProperties { + a.allProperties = true + } + a.evaluatedProperties = merge(a.evaluatedProperties, b.evaluatedProperties) +} + +// merge adds t's keys to s and returns s. +// If s is nil, it returns a copy of t. 
+func merge[K comparable](s, t map[K]bool) map[K]bool { + if s == nil { + return maps.Clone(t) + } + maps.Copy(s, t) + return s +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/doc.go b/vendor/github.com/google/jsonschema-go/jsonschema/doc.go new file mode 100644 index 000000000..a34bab725 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/doc.go @@ -0,0 +1,101 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +/* +Package jsonschema is an implementation of the [JSON Schema specification], +a JSON-based format for describing the structure of JSON data. +The package can be used to read schemas for code generation, and to validate +data using the draft 2020-12 specification. Validation with other drafts +or custom meta-schemas is not supported. + +Construct a [Schema] as you would any Go struct (for example, by writing +a struct literal), or unmarshal a JSON schema into a [Schema] in the usual +way (with [encoding/json], for instance). It can then be used for code +generation or other purposes without further processing. +You can also infer a schema from a Go struct. + +# Resolution + +A Schema can refer to other schemas, both inside and outside itself. These +references must be resolved before a schema can be used for validation. +Call [Schema.Resolve] to obtain a resolved schema (called a [Resolved]). +If the schema has external references, pass a [ResolveOptions] with a [Loader] +to load them. To validate default values in a schema, set +[ResolveOptions.ValidateDefaults] to true. + +# Validation + +Call [Resolved.Validate] to validate a JSON value. The value must be a +Go value that looks like the result of unmarshaling a JSON value into an +[any] or a struct. 
For example, the JSON value + + {"name": "Al", "scores": [90, 80, 100]} + +could be represented as the Go value + + map[string]any{ + "name": "Al", + "scores": []any{90, 80, 100}, + } + +or as a value of this type: + + type Player struct { + Name string `json:"name"` + Scores []int `json:"scores"` + } + +# Inference + +The [For] function returns a [Schema] describing the given Go type. +Each field in the struct becomes a property of the schema. +The values of "json" tags are respected: the field's property name is taken +from the tag, and fields omitted from the JSON are omitted from the schema as +well. +For example, `jsonschema.For[Player]()` returns this schema: + + { + "properties": { + "name": { + "type": "string" + }, + "scores": { + "type": "array", + "items": {"type": "integer"} + } + "required": ["name", "scores"], + "additionalProperties": {"not": {}} + } + } + +Use the "jsonschema" struct tag to provide a description for the property: + + type Player struct { + Name string `json:"name" jsonschema:"player name"` + Scores []int `json:"scores" jsonschema:"scores of player's games"` + } + +# Deviations from the specification + +Regular expressions are processed with Go's regexp package, which differs +from ECMA 262, most significantly in not supporting back-references. +See [this table of differences] for more. + +The "format" keyword described in [section 7 of the validation spec] is recorded +in the Schema, but is ignored during validation. +It does not even produce [annotations]. +Use the "pattern" keyword instead: it will work more reliably across JSON Schema +implementations. See [learnjsonschema.com] for more recommendations about "format". + +The content keywords described in [section 8 of the validation spec] +are recorded in the schema, but ignored during validation. 
+ +[JSON Schema specification]: https://json-schema.org +[section 7 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 +[section 8 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 +[learnjsonschema.com]: https://www.learnjsonschema.com/2020-12/format-annotation/format/ +[this table of differences]: https://github.com/dlclark/regexp2?tab=readme-ov-file#compare-regexp-and-regexp2 +[annotations]: https://json-schema.org/draft/2020-12/json-schema-core#name-annotations +*/ +package jsonschema diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/infer.go b/vendor/github.com/google/jsonschema-go/jsonschema/infer.go new file mode 100644 index 000000000..ae624ad09 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/infer.go @@ -0,0 +1,248 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains functions that infer a schema from a Go type. + +package jsonschema + +import ( + "fmt" + "log/slog" + "maps" + "math/big" + "reflect" + "regexp" + "time" +) + +// ForOptions are options for the [For] and [ForType] functions. +type ForOptions struct { + // If IgnoreInvalidTypes is true, fields that can't be represented as a JSON + // Schema are ignored instead of causing an error. + // This allows callers to adjust the resulting schema using custom knowledge. + // For example, an interface type where all the possible implementations are + // known can be described with "oneof". + IgnoreInvalidTypes bool + + // TypeSchemas maps types to their schemas. + // If [For] encounters a type that is a key in this map, the + // corresponding value is used as the resulting schema (after cloning to + // ensure uniqueness). 
+	// Types in this map override the default translations, as described
+	// in [For]'s documentation.
+	TypeSchemas map[reflect.Type]*Schema
+}
+
+// For constructs a JSON schema object for the given type argument.
+// If non-nil, the provided options configure certain aspects of this construction,
+// described below.
+//
+// It translates Go types into compatible JSON schema types, as follows.
+// These defaults can be overridden by [ForOptions.TypeSchemas].
+//
+//   - Strings have schema type "string".
+//   - Bools have schema type "boolean".
+//   - Signed and unsigned integer types have schema type "integer".
+//   - Floating point types have schema type "number".
+//   - Slices and arrays have schema type "array", and a corresponding schema
+//     for items.
+//   - Maps with string key have schema type "object", and corresponding
+//     schema for additionalProperties.
+//   - Structs have schema type "object", and disallow additionalProperties.
+//     Their properties are derived from exported struct fields, using the
+//     struct field JSON name. Fields that are marked "omitempty" are
+//     considered optional; all other fields become required properties.
+//   - Some types in the standard library that implement json.Marshaler
+//     translate to schemas that match the values to which they marshal.
+//     For example, [time.Time] translates to the schema for strings.
+//
+// For will return an error if there is a cycle in the types.
+//
+// By default, For returns an error if t contains (possibly recursively) any of the
+// following Go types, as they are incompatible with the JSON schema spec.
+// If [ForOptions.IgnoreInvalidTypes] is true, then these types are ignored instead.
+//   - maps with key other than 'string'
+//   - function types
+//   - channel types
+//   - complex numbers
+//   - unsafe pointers
+//
+// This function recognizes struct field tags named "jsonschema".
+// A jsonschema tag on a field is used as the description for the corresponding property.
+// For future compatibility, descriptions must not start with "WORD=", where WORD is a +// sequence of non-whitespace characters. +func For[T any](opts *ForOptions) (*Schema, error) { + if opts == nil { + opts = &ForOptions{} + } + schemas := maps.Clone(initialSchemaMap) + // Add types from the options. They override the default ones. + maps.Copy(schemas, opts.TypeSchemas) + s, err := forType(reflect.TypeFor[T](), map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) + if err != nil { + var z T + return nil, fmt.Errorf("For[%T](): %w", z, err) + } + return s, nil +} + +// ForType is like [For], but takes a [reflect.Type] +func ForType(t reflect.Type, opts *ForOptions) (*Schema, error) { + schemas := maps.Clone(initialSchemaMap) + // Add types from the options. They override the default ones. + maps.Copy(schemas, opts.TypeSchemas) + s, err := forType(t, map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) + if err != nil { + return nil, fmt.Errorf("ForType(%s): %w", t, err) + } + return s, nil +} + +func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas map[reflect.Type]*Schema) (*Schema, error) { + // Follow pointers: the schema for *T is almost the same as for T, except that + // an explicit JSON "null" is allowed for the pointer. 
+	allowNull := false
+	for t.Kind() == reflect.Pointer {
+		allowNull = true
+		t = t.Elem()
+	}
+
+	// Check for cycles
+	// User defined types have a name, so we can skip those that are natively defined
+	if t.Name() != "" {
+		if seen[t] {
+			return nil, fmt.Errorf("cycle detected for type %v", t)
+		}
+		seen[t] = true
+		defer delete(seen, t)
+	}
+
+	if s := schemas[t]; s != nil {
+		return s.CloneSchemas(), nil
+	}
+
+	var (
+		s   = new(Schema)
+		err error
+	)
+
+	switch t.Kind() {
+	case reflect.Bool:
+		s.Type = "boolean"
+
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+		reflect.Uintptr:
+		s.Type = "integer"
+
+	case reflect.Float32, reflect.Float64:
+		s.Type = "number"
+
+	case reflect.Interface:
+		// Unrestricted
+
+	case reflect.Map:
+		if t.Key().Kind() != reflect.String {
+			if ignore {
+				return nil, nil // ignore
+			}
+			return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind())
+		}
+		s.Type = "object"
+		s.AdditionalProperties, err = forType(t.Elem(), seen, ignore, schemas)
+		if err != nil {
+			return nil, fmt.Errorf("computing map value schema: %v", err)
+		}
+		if ignore && s.AdditionalProperties == nil {
+			// Ignore if the element type is invalid.
+			return nil, nil
+		}
+
+	case reflect.Slice, reflect.Array:
+		s.Type = "array"
+		s.Items, err = forType(t.Elem(), seen, ignore, schemas)
+		if err != nil {
+			return nil, fmt.Errorf("computing element schema: %v", err)
+		}
+		if ignore && s.Items == nil {
+			// Ignore if the element type is invalid.
+ return nil, nil + } + if t.Kind() == reflect.Array { + s.MinItems = Ptr(t.Len()) + s.MaxItems = Ptr(t.Len()) + } + + case reflect.String: + s.Type = "string" + + case reflect.Struct: + s.Type = "object" + // no additional properties are allowed + s.AdditionalProperties = falseSchema() + for _, field := range reflect.VisibleFields(t) { + if field.Anonymous { + continue + } + + info := fieldJSONInfo(field) + if info.omit { + continue + } + if s.Properties == nil { + s.Properties = make(map[string]*Schema) + } + fs, err := forType(field.Type, seen, ignore, schemas) + if err != nil { + return nil, err + } + if ignore && fs == nil { + // Skip fields of invalid type. + continue + } + if tag, ok := field.Tag.Lookup("jsonschema"); ok { + if tag == "" { + return nil, fmt.Errorf("empty jsonschema tag on struct field %s.%s", t, field.Name) + } + if disallowedPrefixRegexp.MatchString(tag) { + return nil, fmt.Errorf("tag must not begin with 'WORD=': %q", tag) + } + fs.Description = tag + } + s.Properties[info.name] = fs + if !info.settings["omitempty"] && !info.settings["omitzero"] { + s.Required = append(s.Required, info.name) + } + } + + default: + if ignore { + // Ignore. + return nil, nil + } + return nil, fmt.Errorf("type %v is unsupported by jsonschema", t) + } + if allowNull && s.Type != "" { + s.Types = []string{"null", s.Type} + s.Type = "" + } + return s, nil +} + +// initialSchemaMap holds types from the standard library that have MarshalJSON methods. +var initialSchemaMap = make(map[reflect.Type]*Schema) + +func init() { + ss := &Schema{Type: "string"} + initialSchemaMap[reflect.TypeFor[time.Time]()] = ss + initialSchemaMap[reflect.TypeFor[slog.Level]()] = ss + initialSchemaMap[reflect.TypeFor[big.Int]()] = &Schema{Types: []string{"null", "string"}} + initialSchemaMap[reflect.TypeFor[big.Rat]()] = ss + initialSchemaMap[reflect.TypeFor[big.Float]()] = ss +} + +// Disallow jsonschema tag values beginning "WORD=", for future expansion. 
+var disallowedPrefixRegexp = regexp.MustCompile("^[^ \t\n]*=") diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/json_pointer.go b/vendor/github.com/google/jsonschema-go/jsonschema/json_pointer.go new file mode 100644 index 000000000..ed1b16991 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/json_pointer.go @@ -0,0 +1,146 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements JSON Pointers. +// A JSON Pointer is a path that refers to one JSON value within another. +// If the path is empty, it refers to the root value. +// Otherwise, it is a sequence of slash-prefixed strings, like "/points/1/x", +// selecting successive properties (for JSON objects) or items (for JSON arrays). +// For example, when applied to this JSON value: +// { +// "points": [ +// {"x": 1, "y": 2}, +// {"x": 3, "y": 4} +// ] +// } +// +// the JSON Pointer "/points/1/x" refers to the number 3. +// See the spec at https://datatracker.ietf.org/doc/html/rfc6901. + +package jsonschema + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" +) + +var ( + jsonPointerEscaper = strings.NewReplacer("~", "~0", "/", "~1") + jsonPointerUnescaper = strings.NewReplacer("~0", "~", "~1", "/") +) + +func escapeJSONPointerSegment(s string) string { + return jsonPointerEscaper.Replace(s) +} + +func unescapeJSONPointerSegment(s string) string { + return jsonPointerUnescaper.Replace(s) +} + +// parseJSONPointer splits a JSON Pointer into a sequence of segments. It doesn't +// convert strings to numbers, because that depends on the traversal: a segment +// is treated as a number when applied to an array, but a string when applied to +// an object. See section 4 of the spec. 
+func parseJSONPointer(ptr string) (segments []string, err error) { + if ptr == "" { + return nil, nil + } + if ptr[0] != '/' { + return nil, fmt.Errorf("JSON Pointer %q does not begin with '/'", ptr) + } + // Unlike file paths, consecutive slashes are not coalesced. + // Split is nicer than Cut here, because it gets a final "/" right. + segments = strings.Split(ptr[1:], "/") + if strings.Contains(ptr, "~") { + // Undo the simple escaping rules that allow one to include a slash in a segment. + for i := range segments { + segments[i] = unescapeJSONPointerSegment(segments[i]) + } + } + return segments, nil +} + +// dereferenceJSONPointer returns the Schema that sptr points to within s, +// or an error if none. +// This implementation suffices for JSON Schema: pointers are applied only to Schemas, +// and refer only to Schemas. +func dereferenceJSONPointer(s *Schema, sptr string) (_ *Schema, err error) { + defer wrapf(&err, "JSON Pointer %q", sptr) + + segments, err := parseJSONPointer(sptr) + if err != nil { + return nil, err + } + v := reflect.ValueOf(s) + for _, seg := range segments { + switch v.Kind() { + case reflect.Pointer: + v = v.Elem() + if !v.IsValid() { + return nil, errors.New("navigated to nil reference") + } + fallthrough // if valid, can only be a pointer to a Schema + + case reflect.Struct: + // The segment must refer to a field in a Schema. + if v.Type() != reflect.TypeFor[Schema]() { + return nil, fmt.Errorf("navigated to non-Schema %s", v.Type()) + } + v = lookupSchemaField(v, seg) + if !v.IsValid() { + return nil, fmt.Errorf("no schema field %q", seg) + } + case reflect.Slice, reflect.Array: + // The segment must be an integer without leading zeroes that refers to an item in the + // slice or array. 
+ if seg == "-" { + return nil, errors.New("the JSON Pointer array segment '-' is not supported") + } + if len(seg) > 1 && seg[0] == '0' { + return nil, fmt.Errorf("segment %q has leading zeroes", seg) + } + n, err := strconv.Atoi(seg) + if err != nil { + return nil, fmt.Errorf("invalid int: %q", seg) + } + if n < 0 || n >= v.Len() { + return nil, fmt.Errorf("index %d is out of bounds for array of length %d", n, v.Len()) + } + v = v.Index(n) + // Cannot be invalid. + case reflect.Map: + // The segment must be a key in the map. + v = v.MapIndex(reflect.ValueOf(seg)) + if !v.IsValid() { + return nil, fmt.Errorf("no key %q in map", seg) + } + default: + return nil, fmt.Errorf("value %s (%s) is not a schema, slice or map", v, v.Type()) + } + } + if s, ok := v.Interface().(*Schema); ok { + return s, nil + } + return nil, fmt.Errorf("does not refer to a schema, but to a %s", v.Type()) +} + +// lookupSchemaField returns the value of the field with the given name in v, +// or the zero value if there is no such field or it is not of type Schema or *Schema. +func lookupSchemaField(v reflect.Value, name string) reflect.Value { + if name == "type" { + // The "type" keyword may refer to Type or Types. + // At most one will be non-zero. + if t := v.FieldByName("Type"); !t.IsZero() { + return t + } + return v.FieldByName("Types") + } + if sf, ok := schemaFieldMap[name]; ok { + return v.FieldByIndex(sf.Index) + } + return reflect.Value{} +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/resolve.go b/vendor/github.com/google/jsonschema-go/jsonschema/resolve.go new file mode 100644 index 000000000..ece9be880 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/resolve.go @@ -0,0 +1,548 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +// This file deals with preparing a schema for validation, including various checks, +// optimizations, and the resolution of cross-schema references. + +package jsonschema + +import ( + "errors" + "fmt" + "net/url" + "reflect" + "regexp" + "strings" +) + +// A Resolved consists of a [Schema] along with associated information needed to +// validate documents against it. +// A Resolved has been validated against its meta-schema, and all its references +// (the $ref and $dynamicRef keywords) have been resolved to their referenced Schemas. +// Call [Schema.Resolve] to obtain a Resolved from a Schema. +type Resolved struct { + root *Schema + // map from $ids to their schemas + resolvedURIs map[string]*Schema + // map from schemas to additional info computed during resolution + resolvedInfos map[*Schema]*resolvedInfo +} + +func newResolved(s *Schema) *Resolved { + return &Resolved{ + root: s, + resolvedURIs: map[string]*Schema{}, + resolvedInfos: map[*Schema]*resolvedInfo{}, + } +} + +// resolvedInfo holds information specific to a schema that is computed by [Schema.Resolve]. +type resolvedInfo struct { + s *Schema + // The JSON Pointer path from the root schema to here. + // Used in errors. + path string + // The schema's base schema. + // If the schema is the root or has an ID, its base is itself. + // Otherwise, its base is the innermost enclosing schema whose base + // is itself. + // Intuitively, a base schema is one that can be referred to with a + // fragmentless URI. + base *Schema + // The URI for the schema, if it is the root or has an ID. + // Otherwise nil. + // Invariants: + // s.base.uri != nil. + // s.base == s <=> s.uri != nil + uri *url.URL + // The schema to which Ref refers. + resolvedRef *Schema + + // If the schema has a dynamic ref, exactly one of the next two fields + // will be non-zero after successful resolution. + // The schema to which the dynamic ref refers when it acts lexically. 
+ resolvedDynamicRef *Schema + // The anchor to look up on the stack when the dynamic ref acts dynamically. + dynamicRefAnchor string + + // The following fields are independent of arguments to Schema.Resolved, + // so they could live on the Schema. We put them here for simplicity. + + // The set of required properties. + isRequired map[string]bool + + // Compiled regexps. + pattern *regexp.Regexp + patternProperties map[*regexp.Regexp]*Schema + + // Map from anchors to subschemas. + anchors map[string]anchorInfo +} + +// Schema returns the schema that was resolved. +// It must not be modified. +func (r *Resolved) Schema() *Schema { return r.root } + +// schemaString returns a short string describing the schema. +func (r *Resolved) schemaString(s *Schema) string { + if s.ID != "" { + return s.ID + } + info := r.resolvedInfos[s] + if info.path != "" { + return info.path + } + return "" +} + +// A Loader reads and unmarshals the schema at uri, if any. +type Loader func(uri *url.URL) (*Schema, error) + +// ResolveOptions are options for [Schema.Resolve]. +type ResolveOptions struct { + // BaseURI is the URI relative to which the root schema should be resolved. + // If non-empty, must be an absolute URI (one that starts with a scheme). + // It is resolved (in the URI sense; see [url.ResolveReference]) with root's + // $id property. + // If the resulting URI is not absolute, then the schema cannot contain + // relative URI references. + BaseURI string + // Loader loads schemas that are referred to by a $ref but are not under the + // root schema (remote references). + // If nil, resolving a remote reference will return an error. + Loader Loader + // ValidateDefaults determines whether to validate values of "default" keywords + // against their schemas. + // The [JSON Schema specification] does not require this, but it is recommended + // if defaults will be used. 
+ // + // [JSON Schema specification]: https://json-schema.org/understanding-json-schema/reference/annotations + ValidateDefaults bool +} + +// Resolve resolves all references within the schema and performs other tasks that +// prepare the schema for validation. +// If opts is nil, the default values are used. +// The schema must not be changed after Resolve is called. +// The same schema may be resolved multiple times. +func (root *Schema) Resolve(opts *ResolveOptions) (*Resolved, error) { + // There are up to five steps required to prepare a schema to validate. + // 1. Load: read the schema from somewhere and unmarshal it. + // This schema (root) may have been loaded or created in memory, but other schemas that + // come into the picture in step 4 will be loaded by the given loader. + // 2. Check: validate the schema against a meta-schema, and perform other well-formedness checks. + // Precompute some values along the way. + // 3. Resolve URIs: determine the base URI of the root and all its subschemas, and + // resolve (in the URI sense) all identifiers and anchors with their bases. This step results + // in a map from URIs to schemas within root. + // 4. Resolve references: all refs in the schemas are replaced with the schema they refer to. + // 5. (Optional.) If opts.ValidateDefaults is true, validate the defaults. 
+ r := &resolver{loaded: map[string]*Resolved{}} + if opts != nil { + r.opts = *opts + } + var base *url.URL + if r.opts.BaseURI == "" { + base = &url.URL{} // so we can call ResolveReference on it + } else { + var err error + base, err = url.Parse(r.opts.BaseURI) + if err != nil { + return nil, fmt.Errorf("parsing base URI: %w", err) + } + } + + if r.opts.Loader == nil { + r.opts.Loader = func(uri *url.URL) (*Schema, error) { + return nil, errors.New("cannot resolve remote schemas: no loader passed to Schema.Resolve") + } + } + + resolved, err := r.resolve(root, base) + if err != nil { + return nil, err + } + if r.opts.ValidateDefaults { + if err := resolved.validateDefaults(); err != nil { + return nil, err + } + } + // TODO: before we return, throw away anything we don't need for validation. + return resolved, nil +} + +// A resolver holds the state for resolution. +type resolver struct { + opts ResolveOptions + // A cache of loaded and partly resolved schemas. (They may not have had their + // refs resolved.) The cache ensures that the loader will never be called more + // than once with the same URI, and that reference cycles are handled properly. + loaded map[string]*Resolved +} + +func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { + if baseURI.Fragment != "" { + return nil, fmt.Errorf("base URI %s must not have a fragment", baseURI) + } + rs := newResolved(s) + + if err := s.check(rs.resolvedInfos); err != nil { + return nil, err + } + + if err := resolveURIs(rs, baseURI); err != nil { + return nil, err + } + + // Remember the schema by both the URI we loaded it from and its canonical name, + // which may differ if the schema has an $id. + // We must set the map before calling resolveRefs, or ref cycles will cause unbounded recursion. 
+ r.loaded[baseURI.String()] = rs + r.loaded[rs.resolvedInfos[s].uri.String()] = rs + + if err := r.resolveRefs(rs); err != nil { + return nil, err + } + return rs, nil +} + +func (root *Schema) check(infos map[*Schema]*resolvedInfo) error { + // Check for structural validity. Do this first and fail fast: + // bad structure will cause other code to panic. + if err := root.checkStructure(infos); err != nil { + return err + } + + var errs []error + report := func(err error) { errs = append(errs, err) } + + for ss := range root.all() { + ss.checkLocal(report, infos) + } + return errors.Join(errs...) +} + +// checkStructure verifies that root and its subschemas form a tree. +// It also assigns each schema a unique path, to improve error messages. +func (root *Schema) checkStructure(infos map[*Schema]*resolvedInfo) error { + assert(len(infos) == 0, "non-empty infos") + + var check func(reflect.Value, []byte) error + check = func(v reflect.Value, path []byte) error { + // For the purpose of error messages, the root schema has path "root" + // and other schemas' paths are their JSON Pointer from the root. + p := "root" + if len(path) > 0 { + p = string(path) + } + s := v.Interface().(*Schema) + if s == nil { + return fmt.Errorf("jsonschema: schema at %s is nil", p) + } + if info, ok := infos[s]; ok { + // We've seen s before. + // The schema graph at root is not a tree, but it needs to + // be because a schema's base must be unique. + // A cycle would also put Schema.all into an infinite recursion. + return fmt.Errorf("jsonschema: schemas at %s do not form a tree; %s appears more than once (also at %s)", + root, info.path, p) + } + infos[s] = &resolvedInfo{s: s, path: p} + + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + // A field that contains an individual schema. + // A nil is valid: it just means the field isn't present. 
+ if !fv.IsNil() { + if err := check(fv, fmt.Appendf(path, "/%s", info.jsonName)); err != nil { + return err + } + } + + case schemaSliceType: + for i := range fv.Len() { + if err := check(fv.Index(i), fmt.Appendf(path, "/%s/%d", info.jsonName, i)); err != nil { + return err + } + } + + case schemaMapType: + iter := fv.MapRange() + for iter.Next() { + key := escapeJSONPointerSegment(iter.Key().String()) + if err := check(iter.Value(), fmt.Appendf(path, "/%s/%s", info.jsonName, key)); err != nil { + return err + } + } + } + + } + return nil + } + + return check(reflect.ValueOf(root), make([]byte, 0, 256)) +} + +// checkLocal checks s for validity, independently of other schemas it may refer to. +// Since checking a regexp involves compiling it, checkLocal saves those compiled regexps +// in the schema for later use. +// It appends the errors it finds to errs. +func (s *Schema) checkLocal(report func(error), infos map[*Schema]*resolvedInfo) { + addf := func(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + report(fmt.Errorf("jsonschema.Schema: %s: %s", s, msg)) + } + + if s == nil { + addf("nil subschema") + return + } + if err := s.basicChecks(); err != nil { + report(err) + return + } + + // TODO: validate the schema's properties, + // ideally by jsonschema-validating it against the meta-schema. + + // Some properties are present so that Schemas can round-trip, but we do not + // validate them. + // Currently, it's just the $vocabulary property. + // As a special case, we can validate the 2020-12 meta-schema. + if s.Vocabulary != nil && s.Schema != draft202012 { + addf("cannot validate a schema with $vocabulary") + } + + info := infos[s] + + // Check and compile regexps. 
+ if s.Pattern != "" { + re, err := regexp.Compile(s.Pattern) + if err != nil { + addf("pattern: %v", err) + } else { + info.pattern = re + } + } + if len(s.PatternProperties) > 0 { + info.patternProperties = map[*regexp.Regexp]*Schema{} + for reString, subschema := range s.PatternProperties { + re, err := regexp.Compile(reString) + if err != nil { + addf("patternProperties[%q]: %v", reString, err) + continue + } + info.patternProperties[re] = subschema + } + } + + // Build a set of required properties, to avoid quadratic behavior when validating + // a struct. + if len(s.Required) > 0 { + info.isRequired = map[string]bool{} + for _, r := range s.Required { + info.isRequired[r] = true + } + } +} + +// resolveURIs resolves the ids and anchors in all the schemas of root, relative +// to baseURI. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2, section +// 8.2.1. +// +// Every schema has a base URI and a parent base URI. +// +// The parent base URI is the base URI of the lexically enclosing schema, or for +// a root schema, the URI it was loaded from or the one supplied to [Schema.Resolve]. +// +// If the schema has no $id property, the base URI of a schema is that of its parent. +// If the schema does have an $id, it must be a URI, possibly relative. The schema's +// base URI is the $id resolved (in the sense of [url.URL.ResolveReference]) against +// the parent base. +// +// As an example, consider this schema loaded from http://a.com/root.json (quotes omitted): +// +// { +// allOf: [ +// {$id: "sub1.json", minLength: 5}, +// {$id: "http://b.com", minimum: 10}, +// {not: {maximum: 20}} +// ] +// } +// +// The base URIs are as follows. Schema locations are expressed in the JSON Pointer notation. 
+// +// schema base URI +// root http://a.com/root.json +// allOf/0 http://a.com/sub1.json +// allOf/1 http://b.com (absolute $id; doesn't matter that it's not under the loaded URI) +// allOf/2 http://a.com/root.json (inherited from parent) +// allOf/2/not http://a.com/root.json (inherited from parent) +func resolveURIs(rs *Resolved, baseURI *url.URL) error { + var resolve func(s, base *Schema) error + resolve = func(s, base *Schema) error { + info := rs.resolvedInfos[s] + baseInfo := rs.resolvedInfos[base] + + // ids are scoped to the root. + if s.ID != "" { + // A non-empty ID establishes a new base. + idURI, err := url.Parse(s.ID) + if err != nil { + return err + } + if idURI.Fragment != "" { + return fmt.Errorf("$id %s must not have a fragment", s.ID) + } + // The base URI for this schema is its $id resolved against the parent base. + info.uri = baseInfo.uri.ResolveReference(idURI) + if !info.uri.IsAbs() { + return fmt.Errorf("$id %s does not resolve to an absolute URI (base is %q)", s.ID, baseInfo.uri) + } + rs.resolvedURIs[info.uri.String()] = s + base = s // needed for anchors + baseInfo = rs.resolvedInfos[base] + } + info.base = base + + // Anchors and dynamic anchors are URI fragments that are scoped to their base. + // We treat them as keys in a map stored within the schema. + setAnchor := func(anchor string, dynamic bool) error { + if anchor != "" { + if _, ok := baseInfo.anchors[anchor]; ok { + return fmt.Errorf("duplicate anchor %q in %s", anchor, baseInfo.uri) + } + if baseInfo.anchors == nil { + baseInfo.anchors = map[string]anchorInfo{} + } + baseInfo.anchors[anchor] = anchorInfo{s, dynamic} + } + return nil + } + + setAnchor(s.Anchor, false) + setAnchor(s.DynamicAnchor, true) + + for c := range s.children() { + if err := resolve(c, base); err != nil { + return err + } + } + return nil + } + + // Set the root URI to the base for now. If the root has an $id, this will change. 
+ rs.resolvedInfos[rs.root].uri = baseURI + // The original base, even if changed, is still a valid way to refer to the root. + rs.resolvedURIs[baseURI.String()] = rs.root + + return resolve(rs.root, rs.root) +} + +// resolveRefs replaces every ref in the schemas with the schema it refers to. +// A reference that doesn't resolve within the schema may refer to some other schema +// that needs to be loaded. +func (r *resolver) resolveRefs(rs *Resolved) error { + for s := range rs.root.all() { + info := rs.resolvedInfos[s] + if s.Ref != "" { + refSchema, _, err := r.resolveRef(rs, s, s.Ref) + if err != nil { + return err + } + // Whether or not the anchor referred to by $ref fragment is dynamic, + // the ref still treats it lexically. + info.resolvedRef = refSchema + } + if s.DynamicRef != "" { + refSchema, frag, err := r.resolveRef(rs, s, s.DynamicRef) + if err != nil { + return err + } + if frag != "" { + // The dynamic ref's fragment points to a dynamic anchor. + // We must resolve the fragment at validation time. + info.dynamicRefAnchor = frag + } else { + // There is no dynamic anchor in the lexically referenced schema, + // so the dynamic ref behaves like a lexical ref. + info.resolvedDynamicRef = refSchema + } + } + } + return nil +} + +// resolveRef resolves the reference ref, which is either s.Ref or s.DynamicRef. +func (r *resolver) resolveRef(rs *Resolved, s *Schema, ref string) (_ *Schema, dynamicFragment string, err error) { + refURI, err := url.Parse(ref) + if err != nil { + return nil, "", err + } + // URI-resolve the ref against the current base URI to get a complete URI. + base := rs.resolvedInfos[s].base + refURI = rs.resolvedInfos[base].uri.ResolveReference(refURI) + // The non-fragment part of a ref URI refers to the base URI of some schema. + // This part is the same for dynamic refs too: their non-fragment part resolves + // lexically. + u := *refURI + u.Fragment = "" + fraglessRefURI := &u + // Look it up locally. 
+ referencedSchema := rs.resolvedURIs[fraglessRefURI.String()] + if referencedSchema == nil { + // The schema is remote. Maybe we've already loaded it. + // We assume that the non-fragment part of refURI refers to a top-level schema + // document. That is, we don't support the case exemplified by + // http://foo.com/bar.json/baz, where the document is in bar.json and + // the reference points to a subschema within it. + // TODO: support that case. + if lrs := r.loaded[fraglessRefURI.String()]; lrs != nil { + referencedSchema = lrs.root + } else { + // Try to load the schema. + ls, err := r.opts.Loader(fraglessRefURI) + if err != nil { + return nil, "", fmt.Errorf("loading %s: %w", fraglessRefURI, err) + } + lrs, err := r.resolve(ls, fraglessRefURI) + if err != nil { + return nil, "", err + } + referencedSchema = lrs.root + assert(referencedSchema != nil, "nil referenced schema") + // Copy the resolvedInfos from lrs into rs, without overwriting + // (hence we can't use maps.Insert). + for s, i := range lrs.resolvedInfos { + if rs.resolvedInfos[s] == nil { + rs.resolvedInfos[s] = i + } + } + } + } + + frag := refURI.Fragment + // Look up frag in refSchema. + // frag is either a JSON Pointer or the name of an anchor. + // A JSON Pointer is either the empty string or begins with a '/', + // whereas anchors are always non-empty strings that don't contain slashes. + if frag != "" && !strings.HasPrefix(frag, "/") { + resInfo := rs.resolvedInfos[referencedSchema] + info, found := resInfo.anchors[frag] + + if !found { + return nil, "", fmt.Errorf("no anchor %q in %s", frag, s) + } + if info.dynamic { + dynamicFragment = frag + } + return info.schema, dynamicFragment, nil + } + // frag is a JSON Pointer. 
+ s, err = dereferenceJSONPointer(referencedSchema, frag) + return s, "", err +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/schema.go b/vendor/github.com/google/jsonschema-go/jsonschema/schema.go new file mode 100644 index 000000000..3b4db9a6e --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/schema.go @@ -0,0 +1,436 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import ( + "bytes" + "cmp" + "encoding/json" + "errors" + "fmt" + "iter" + "maps" + "math" + "reflect" + "slices" +) + +// A Schema is a JSON schema object. +// It corresponds to the 2020-12 draft, as described in https://json-schema.org/draft/2020-12, +// specifically: +// - https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-01 +// - https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01 +// +// A Schema value may have non-zero values for more than one field: +// all relevant non-zero fields are used for validation. +// There is one exception to provide more Go type-safety: the Type and Types fields +// are mutually exclusive. +// +// Since this struct is a Go representation of a JSON value, it inherits JSON's +// distinction between nil and empty. Nil slices and maps are considered absent, +// but empty ones are present and affect validation. For example, +// +// Schema{Enum: nil} +// +// is equivalent to an empty schema, so it validates every instance. But +// +// Schema{Enum: []any{}} +// +// requires equality to some slice element, so it vacuously rejects every instance. +type Schema struct { + // core + ID string `json:"$id,omitempty"` + Schema string `json:"$schema,omitempty"` + Ref string `json:"$ref,omitempty"` + Comment string `json:"$comment,omitempty"` + Defs map[string]*Schema `json:"$defs,omitempty"` + // definitions is deprecated but still allowed. 
It is a synonym for $defs. + Definitions map[string]*Schema `json:"definitions,omitempty"` + + Anchor string `json:"$anchor,omitempty"` + DynamicAnchor string `json:"$dynamicAnchor,omitempty"` + DynamicRef string `json:"$dynamicRef,omitempty"` + Vocabulary map[string]bool `json:"$vocabulary,omitempty"` + + // metadata + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Default json.RawMessage `json:"default,omitempty"` + Deprecated bool `json:"deprecated,omitempty"` + ReadOnly bool `json:"readOnly,omitempty"` + WriteOnly bool `json:"writeOnly,omitempty"` + Examples []any `json:"examples,omitempty"` + + // validation + // Use Type for a single type, or Types for multiple types; never both. + Type string `json:"-"` + Types []string `json:"-"` + Enum []any `json:"enum,omitempty"` + // Const is *any because a JSON null (Go nil) is a valid value. + Const *any `json:"const,omitempty"` + MultipleOf *float64 `json:"multipleOf,omitempty"` + Minimum *float64 `json:"minimum,omitempty"` + Maximum *float64 `json:"maximum,omitempty"` + ExclusiveMinimum *float64 `json:"exclusiveMinimum,omitempty"` + ExclusiveMaximum *float64 `json:"exclusiveMaximum,omitempty"` + MinLength *int `json:"minLength,omitempty"` + MaxLength *int `json:"maxLength,omitempty"` + Pattern string `json:"pattern,omitempty"` + + // arrays + PrefixItems []*Schema `json:"prefixItems,omitempty"` + Items *Schema `json:"items,omitempty"` + MinItems *int `json:"minItems,omitempty"` + MaxItems *int `json:"maxItems,omitempty"` + AdditionalItems *Schema `json:"additionalItems,omitempty"` + UniqueItems bool `json:"uniqueItems,omitempty"` + Contains *Schema `json:"contains,omitempty"` + MinContains *int `json:"minContains,omitempty"` // *int, not int: default is 1, not 0 + MaxContains *int `json:"maxContains,omitempty"` + UnevaluatedItems *Schema `json:"unevaluatedItems,omitempty"` + + // objects + MinProperties *int `json:"minProperties,omitempty"` + MaxProperties *int 
`json:"maxProperties,omitempty"` + Required []string `json:"required,omitempty"` + DependentRequired map[string][]string `json:"dependentRequired,omitempty"` + Properties map[string]*Schema `json:"properties,omitempty"` + PatternProperties map[string]*Schema `json:"patternProperties,omitempty"` + AdditionalProperties *Schema `json:"additionalProperties,omitempty"` + PropertyNames *Schema `json:"propertyNames,omitempty"` + UnevaluatedProperties *Schema `json:"unevaluatedProperties,omitempty"` + + // logic + AllOf []*Schema `json:"allOf,omitempty"` + AnyOf []*Schema `json:"anyOf,omitempty"` + OneOf []*Schema `json:"oneOf,omitempty"` + Not *Schema `json:"not,omitempty"` + + // conditional + If *Schema `json:"if,omitempty"` + Then *Schema `json:"then,omitempty"` + Else *Schema `json:"else,omitempty"` + DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` + + // other + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 + ContentEncoding string `json:"contentEncoding,omitempty"` + ContentMediaType string `json:"contentMediaType,omitempty"` + ContentSchema *Schema `json:"contentSchema,omitempty"` + + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 + Format string `json:"format,omitempty"` + + // Extra allows for additional keywords beyond those specified. + Extra map[string]any `json:"-"` +} + +// falseSchema returns a new Schema tree that fails to validate any value. +func falseSchema() *Schema { + return &Schema{Not: &Schema{}} +} + +// anchorInfo records the subschema to which an anchor refers, and whether +// the anchor keyword is $anchor or $dynamicAnchor. +type anchorInfo struct { + schema *Schema + dynamic bool +} + +// String returns a short description of the schema. 
+func (s *Schema) String() string { + if s.ID != "" { + return s.ID + } + if a := cmp.Or(s.Anchor, s.DynamicAnchor); a != "" { + return fmt.Sprintf("anchor %s", a) + } + return "" +} + +// CloneSchemas returns a copy of s. +// The copy is shallow except for sub-schemas, which are themelves copied with CloneSchemas. +// This allows both s and s.CloneSchemas() to appear as sub-schemas of the same parent. +func (s *Schema) CloneSchemas() *Schema { + if s == nil { + return nil + } + s2 := *s + v := reflect.ValueOf(&s2) + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + sscss := fv.Interface().(*Schema) + fv.Set(reflect.ValueOf(sscss.CloneSchemas())) + + case schemaSliceType: + slice := fv.Interface().([]*Schema) + slice = slices.Clone(slice) + for i, ss := range slice { + slice[i] = ss.CloneSchemas() + } + fv.Set(reflect.ValueOf(slice)) + + case schemaMapType: + m := fv.Interface().(map[string]*Schema) + m = maps.Clone(m) + for k, ss := range m { + m[k] = ss.CloneSchemas() + } + fv.Set(reflect.ValueOf(m)) + } + } + return &s2 +} + +func (s *Schema) basicChecks() error { + if s.Type != "" && s.Types != nil { + return errors.New("both Type and Types are set; at most one should be") + } + if s.Defs != nil && s.Definitions != nil { + return errors.New("both Defs and Definitions are set; at most one should be") + } + return nil +} + +type schemaWithoutMethods Schema // doesn't implement json.{Unm,M}arshaler + +func (s *Schema) MarshalJSON() ([]byte, error) { + if err := s.basicChecks(); err != nil { + return nil, err + } + + // Marshal either Type or Types as "type". 
+ var typ any + switch { + case s.Type != "": + typ = s.Type + case s.Types != nil: + typ = s.Types + } + ms := struct { + Type any `json:"type,omitempty"` + *schemaWithoutMethods + }{ + Type: typ, + schemaWithoutMethods: (*schemaWithoutMethods)(s), + } + bs, err := marshalStructWithMap(&ms, "Extra") + if err != nil { + return nil, err + } + // Marshal {} as true and {"not": {}} as false. + // It is wasteful to do this here instead of earlier, but much easier. + switch { + case bytes.Equal(bs, []byte(`{}`)): + bs = []byte("true") + case bytes.Equal(bs, []byte(`{"not":true}`)): + bs = []byte("false") + } + return bs, nil +} + +func (s *Schema) UnmarshalJSON(data []byte) error { + // A JSON boolean is a valid schema. + var b bool + if err := json.Unmarshal(data, &b); err == nil { + if b { + // true is the empty schema, which validates everything. + *s = Schema{} + } else { + // false is the schema that validates nothing. + *s = *falseSchema() + } + return nil + } + + ms := struct { + Type json.RawMessage `json:"type,omitempty"` + Const json.RawMessage `json:"const,omitempty"` + MinLength *integer `json:"minLength,omitempty"` + MaxLength *integer `json:"maxLength,omitempty"` + MinItems *integer `json:"minItems,omitempty"` + MaxItems *integer `json:"maxItems,omitempty"` + MinProperties *integer `json:"minProperties,omitempty"` + MaxProperties *integer `json:"maxProperties,omitempty"` + MinContains *integer `json:"minContains,omitempty"` + MaxContains *integer `json:"maxContains,omitempty"` + + *schemaWithoutMethods + }{ + schemaWithoutMethods: (*schemaWithoutMethods)(s), + } + if err := unmarshalStructWithMap(data, &ms, "Extra"); err != nil { + return err + } + // Unmarshal "type" as either Type or Types. 
+ var err error + if len(ms.Type) > 0 { + switch ms.Type[0] { + case '"': + err = json.Unmarshal(ms.Type, &s.Type) + case '[': + err = json.Unmarshal(ms.Type, &s.Types) + default: + err = fmt.Errorf(`invalid value for "type": %q`, ms.Type) + } + } + if err != nil { + return err + } + + unmarshalAnyPtr := func(p **any, raw json.RawMessage) error { + if len(raw) == 0 { + return nil + } + if bytes.Equal(raw, []byte("null")) { + *p = new(any) + return nil + } + return json.Unmarshal(raw, p) + } + + // Setting Const to a pointer to null will marshal properly, but won't + // unmarshal: the *any is set to nil, not a pointer to nil. + if err := unmarshalAnyPtr(&s.Const, ms.Const); err != nil { + return err + } + + set := func(dst **int, src *integer) { + if src != nil { + *dst = Ptr(int(*src)) + } + } + + set(&s.MinLength, ms.MinLength) + set(&s.MaxLength, ms.MaxLength) + set(&s.MinItems, ms.MinItems) + set(&s.MaxItems, ms.MaxItems) + set(&s.MinProperties, ms.MinProperties) + set(&s.MaxProperties, ms.MaxProperties) + set(&s.MinContains, ms.MinContains) + set(&s.MaxContains, ms.MaxContains) + + return nil +} + +type integer int32 // for the integer-valued fields of Schema + +func (ip *integer) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + // nothing to do + return nil + } + // If there is a decimal point, src is a floating-point number. + var i int64 + if bytes.ContainsRune(data, '.') { + var f float64 + if err := json.Unmarshal(data, &f); err != nil { + return errors.New("not a number") + } + i = int64(f) + if float64(i) != f { + return errors.New("not an integer value") + } + } else { + if err := json.Unmarshal(data, &i); err != nil { + return errors.New("cannot be unmarshaled into an int") + } + } + // Ensure behavior is the same on both 32-bit and 64-bit systems. + if i < math.MinInt32 || i > math.MaxInt32 { + return errors.New("integer is out of range") + } + *ip = integer(i) + return nil +} + +// Ptr returns a pointer to a new variable whose value is x. 
+func Ptr[T any](x T) *T { return &x } + +// every applies f preorder to every schema under s including s. +// The second argument to f is the path to the schema appended to the argument path. +// It stops when f returns false. +func (s *Schema) every(f func(*Schema) bool) bool { + return f(s) && s.everyChild(func(s *Schema) bool { return s.every(f) }) +} + +// everyChild reports whether f is true for every immediate child schema of s. +func (s *Schema) everyChild(f func(*Schema) bool) bool { + v := reflect.ValueOf(s) + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + // A field that contains an individual schema. A nil is valid: it just means the field isn't present. + c := fv.Interface().(*Schema) + if c != nil && !f(c) { + return false + } + + case schemaSliceType: + slice := fv.Interface().([]*Schema) + for _, c := range slice { + if !f(c) { + return false + } + } + + case schemaMapType: + // Sort keys for determinism. + m := fv.Interface().(map[string]*Schema) + for _, k := range slices.Sorted(maps.Keys(m)) { + if !f(m[k]) { + return false + } + } + } + } + return true +} + +// all wraps every in an iterator. +func (s *Schema) all() iter.Seq[*Schema] { + return func(yield func(*Schema) bool) { s.every(yield) } +} + +// children wraps everyChild in an iterator. 
+func (s *Schema) children() iter.Seq[*Schema] { + return func(yield func(*Schema) bool) { s.everyChild(yield) } +} + +var ( + schemaType = reflect.TypeFor[*Schema]() + schemaSliceType = reflect.TypeFor[[]*Schema]() + schemaMapType = reflect.TypeFor[map[string]*Schema]() +) + +type structFieldInfo struct { + sf reflect.StructField + jsonName string +} + +var ( + // the visible fields of Schema that have a JSON name, sorted by that name + schemaFieldInfos []structFieldInfo + // map from JSON name to field + schemaFieldMap = map[string]reflect.StructField{} +) + +func init() { + for _, sf := range reflect.VisibleFields(reflect.TypeFor[Schema]()) { + info := fieldJSONInfo(sf) + if !info.omit { + schemaFieldInfos = append(schemaFieldInfos, structFieldInfo{sf, info.name}) + } + } + slices.SortFunc(schemaFieldInfos, func(i1, i2 structFieldInfo) int { + return cmp.Compare(i1.jsonName, i2.jsonName) + }) + for _, info := range schemaFieldInfos { + schemaFieldMap[info.jsonName] = info.sf + } +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/util.go b/vendor/github.com/google/jsonschema-go/jsonschema/util.go new file mode 100644 index 000000000..5cfa27dc6 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/util.go @@ -0,0 +1,463 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import ( + "bytes" + "cmp" + "encoding/binary" + "encoding/json" + "fmt" + "hash/maphash" + "math" + "math/big" + "reflect" + "slices" + "strings" + "sync" +) + +// Equal reports whether two Go values representing JSON values are equal according +// to the JSON Schema spec. +// The values must not contain cycles. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-4.2.2. +// It behaves like reflect.DeepEqual, except that numbers are compared according +// to mathematical equality. 
+func Equal(x, y any) bool { + return equalValue(reflect.ValueOf(x), reflect.ValueOf(y)) +} + +func equalValue(x, y reflect.Value) bool { + // Copied from src/reflect/deepequal.go, omitting the visited check (because JSON + // values are trees). + if !x.IsValid() || !y.IsValid() { + return x.IsValid() == y.IsValid() + } + + // Treat numbers specially. + rx, ok1 := jsonNumber(x) + ry, ok2 := jsonNumber(y) + if ok1 && ok2 { + return rx.Cmp(ry) == 0 + } + if x.Kind() != y.Kind() { + return false + } + switch x.Kind() { + case reflect.Array: + if x.Len() != y.Len() { + return false + } + for i := range x.Len() { + if !equalValue(x.Index(i), y.Index(i)) { + return false + } + } + return true + case reflect.Slice: + if x.IsNil() != y.IsNil() { + return false + } + if x.Len() != y.Len() { + return false + } + if x.UnsafePointer() == y.UnsafePointer() { + return true + } + // Special case for []byte, which is common. + if x.Type().Elem().Kind() == reflect.Uint8 && x.Type() == y.Type() { + return bytes.Equal(x.Bytes(), y.Bytes()) + } + for i := range x.Len() { + if !equalValue(x.Index(i), y.Index(i)) { + return false + } + } + return true + case reflect.Interface: + if x.IsNil() || y.IsNil() { + return x.IsNil() == y.IsNil() + } + return equalValue(x.Elem(), y.Elem()) + case reflect.Pointer: + if x.UnsafePointer() == y.UnsafePointer() { + return true + } + return equalValue(x.Elem(), y.Elem()) + case reflect.Struct: + t := x.Type() + if t != y.Type() { + return false + } + for i := range t.NumField() { + sf := t.Field(i) + if !sf.IsExported() { + continue + } + if !equalValue(x.FieldByIndex(sf.Index), y.FieldByIndex(sf.Index)) { + return false + } + } + return true + case reflect.Map: + if x.IsNil() != y.IsNil() { + return false + } + if x.Len() != y.Len() { + return false + } + if x.UnsafePointer() == y.UnsafePointer() { + return true + } + iter := x.MapRange() + for iter.Next() { + vx := iter.Value() + vy := y.MapIndex(iter.Key()) + if !vy.IsValid() || !equalValue(vx, vy) 
{ + return false + } + } + return true + case reflect.Func: + if x.Type() != y.Type() { + return false + } + if x.IsNil() && y.IsNil() { + return true + } + panic("cannot compare functions") + case reflect.String: + return x.String() == y.String() + case reflect.Bool: + return x.Bool() == y.Bool() + // Ints, uints and floats handled in jsonNumber, at top of function. + default: + panic(fmt.Sprintf("unsupported kind: %s", x.Kind())) + } +} + +// hashValue adds v to the data hashed by h. v must not have cycles. +// hashValue panics if the value contains functions or channels, or maps whose +// key type is not string. +// It ignores unexported fields of structs. +// Calls to hashValue with the equal values (in the sense +// of [Equal]) result in the same sequence of values written to the hash. +func hashValue(h *maphash.Hash, v reflect.Value) { + // TODO: replace writes of basic types with WriteComparable in 1.24. + + writeUint := func(u uint64) { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], u) + h.Write(buf[:]) + } + + var write func(reflect.Value) + write = func(v reflect.Value) { + if r, ok := jsonNumber(v); ok { + // We want 1.0 and 1 to hash the same. + // big.Rats are always normalized, so they will be. + // We could do this more efficiently by handling the int and float cases + // separately, but that's premature. + writeUint(uint64(r.Sign() + 1)) + h.Write(r.Num().Bytes()) + h.Write(r.Denom().Bytes()) + return + } + switch v.Kind() { + case reflect.Invalid: + h.WriteByte(0) + case reflect.String: + h.WriteString(v.String()) + case reflect.Bool: + if v.Bool() { + h.WriteByte(1) + } else { + h.WriteByte(0) + } + case reflect.Complex64, reflect.Complex128: + c := v.Complex() + writeUint(math.Float64bits(real(c))) + writeUint(math.Float64bits(imag(c))) + case reflect.Array, reflect.Slice: + // Although we could treat []byte more efficiently, + // JSON values are unlikely to contain them. 
+ writeUint(uint64(v.Len())) + for i := range v.Len() { + write(v.Index(i)) + } + case reflect.Interface, reflect.Pointer: + write(v.Elem()) + case reflect.Struct: + t := v.Type() + for i := range t.NumField() { + if sf := t.Field(i); sf.IsExported() { + write(v.FieldByIndex(sf.Index)) + } + } + case reflect.Map: + if v.Type().Key().Kind() != reflect.String { + panic("map with non-string key") + } + // Sort the keys so the hash is deterministic. + keys := v.MapKeys() + // Write the length. That distinguishes between, say, two consecutive + // maps with disjoint keys from one map that has the items of both. + writeUint(uint64(len(keys))) + slices.SortFunc(keys, func(x, y reflect.Value) int { return cmp.Compare(x.String(), y.String()) }) + for _, k := range keys { + write(k) + write(v.MapIndex(k)) + } + // Ints, uints and floats handled in jsonNumber, at top of function. + default: + panic(fmt.Sprintf("unsupported kind: %s", v.Kind())) + } + } + + write(v) +} + +// jsonNumber converts a numeric value or a json.Number to a [big.Rat]. +// If v is not a number, it returns nil, false. +func jsonNumber(v reflect.Value) (*big.Rat, bool) { + r := new(big.Rat) + switch { + case !v.IsValid(): + return nil, false + case v.CanInt(): + r.SetInt64(v.Int()) + case v.CanUint(): + r.SetUint64(v.Uint()) + case v.CanFloat(): + r.SetFloat64(v.Float()) + default: + jn, ok := v.Interface().(json.Number) + if !ok { + return nil, false + } + if _, ok := r.SetString(jn.String()); !ok { + // This can fail in rare cases; for example, "1e9999999". + // That is a valid JSON number, since the spec puts no limit on the size + // of the exponent. + return nil, false + } + } + return r, true +} + +// jsonType returns a string describing the type of the JSON value, +// as described in the JSON Schema specification: +// https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1. +// It returns "", false if the value is not valid JSON. 
+func jsonType(v reflect.Value) (string, bool) { + if !v.IsValid() { + // Not v.IsNil(): a nil []any is still a JSON array. + return "null", true + } + if v.CanInt() || v.CanUint() { + return "integer", true + } + if v.CanFloat() { + if _, f := math.Modf(v.Float()); f == 0 { + return "integer", true + } + return "number", true + } + switch v.Kind() { + case reflect.Bool: + return "boolean", true + case reflect.String: + return "string", true + case reflect.Slice, reflect.Array: + return "array", true + case reflect.Map, reflect.Struct: + return "object", true + default: + return "", false + } +} + +func assert(cond bool, msg string) { + if !cond { + panic("assertion failed: " + msg) + } +} + +// marshalStructWithMap marshals its first argument to JSON, treating the field named +// mapField as an embedded map. The first argument must be a pointer to +// a struct. The underlying type of mapField must be a map[string]any, and it must have +// a "-" json tag, meaning it will not be marshaled. +// +// For example, given this struct: +// +// type S struct { +// A int +// Extra map[string] any `json:"-"` +// } +// +// and this value: +// +// s := S{A: 1, Extra: map[string]any{"B": 2}} +// +// the call marshalJSONWithMap(s, "Extra") would return +// +// {"A": 1, "B": 2} +// +// It is an error if the map contains the same key as another struct field's +// JSON name. +// +// marshalStructWithMap calls json.Marshal on a value of type T, so T must not +// have a MarshalJSON method that calls this function, on pain of infinite regress. +// +// Note that there is a similar function in mcp/util.go, but they are not the same. +// Here the function requires `-` json tag, does not clear the mapField map, +// and handles embedded struct due to the implementation of jsonNames in this package. +// +// TODO: avoid this restriction on T by forcing it to marshal in a default way. +// See https://go.dev/play/p/EgXKJHxEx_R. 
+func marshalStructWithMap[T any](s *T, mapField string) ([]byte, error) { + // Marshal the struct and the map separately, and concatenate the bytes. + // This strategy is dramatically less complicated than + // constructing a synthetic struct or map with the combined keys. + if s == nil { + return []byte("null"), nil + } + s2 := *s + vMapField := reflect.ValueOf(&s2).Elem().FieldByName(mapField) + mapVal := vMapField.Interface().(map[string]any) + + // Check for duplicates. + names := jsonNames(reflect.TypeFor[T]()) + for key := range mapVal { + if names[key] { + return nil, fmt.Errorf("map key %q duplicates struct field", key) + } + } + + structBytes, err := json.Marshal(s2) + if err != nil { + return nil, fmt.Errorf("marshalStructWithMap(%+v): %w", s, err) + } + if len(mapVal) == 0 { + return structBytes, nil + } + mapBytes, err := json.Marshal(mapVal) + if err != nil { + return nil, err + } + if len(structBytes) == 2 { // must be "{}" + return mapBytes, nil + } + // "{X}" + "{Y}" => "{X,Y}" + res := append(structBytes[:len(structBytes)-1], ',') + res = append(res, mapBytes[1:]...) + return res, nil +} + +// unmarshalStructWithMap is the inverse of marshalStructWithMap. +// T has the same restrictions as in that function. +// +// Note that there is a similar function in mcp/util.go, but they are not the same. +// Here jsonNames also returns fields from embedded structs, hence this function +// handles embedded structs as well. +func unmarshalStructWithMap[T any](data []byte, v *T, mapField string) error { + // Unmarshal into the struct, ignoring unknown fields. + if err := json.Unmarshal(data, v); err != nil { + return err + } + // Unmarshal into the map. + m := map[string]any{} + if err := json.Unmarshal(data, &m); err != nil { + return err + } + // Delete from the map the fields of the struct. 
+ for n := range jsonNames(reflect.TypeFor[T]()) { + delete(m, n) + } + if len(m) != 0 { + reflect.ValueOf(v).Elem().FieldByName(mapField).Set(reflect.ValueOf(m)) + } + return nil +} + +var jsonNamesMap sync.Map // from reflect.Type to map[string]bool + +// jsonNames returns the set of JSON object keys that t will marshal into, +// including fields from embedded structs in t. +// t must be a struct type. +// +// Note that there is a similar function in mcp/util.go, but they are not the same +// Here the function recurses over embedded structs and includes fields from them. +func jsonNames(t reflect.Type) map[string]bool { + // Lock not necessary: at worst we'll duplicate work. + if val, ok := jsonNamesMap.Load(t); ok { + return val.(map[string]bool) + } + m := map[string]bool{} + for i := range t.NumField() { + field := t.Field(i) + // handle embedded structs + if field.Anonymous { + fieldType := field.Type + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + for n := range jsonNames(fieldType) { + m[n] = true + } + continue + } + info := fieldJSONInfo(field) + if !info.omit { + m[info.name] = true + } + } + jsonNamesMap.Store(t, m) + return m +} + +type jsonInfo struct { + omit bool // unexported or first tag element is "-" + name string // Go field name or first tag element. Empty if omit is true. + settings map[string]bool // "omitempty", "omitzero", etc. +} + +// fieldJSONInfo reports information about how encoding/json +// handles the given struct field. +// If the field is unexported, jsonInfo.omit is true and no other jsonInfo field +// is populated. +// If the field is exported and has no tag, then name is the field's name and all +// other fields are false. +// Otherwise, the information is obtained from the tag. 
+func fieldJSONInfo(f reflect.StructField) jsonInfo { + if !f.IsExported() { + return jsonInfo{omit: true} + } + info := jsonInfo{name: f.Name} + if tag, ok := f.Tag.Lookup("json"); ok { + name, rest, found := strings.Cut(tag, ",") + // "-" means omit, but "-," means the name is "-" + if name == "-" && !found { + return jsonInfo{omit: true} + } + if name != "" { + info.name = name + } + if len(rest) > 0 { + info.settings = map[string]bool{} + for _, s := range strings.Split(rest, ",") { + info.settings[s] = true + } + } + } + return info +} + +// wrapf wraps *errp with the given formatted message if *errp is not nil. +func wrapf(errp *error, format string, args ...any) { + if *errp != nil { + *errp = fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), *errp) + } +} diff --git a/vendor/github.com/google/jsonschema-go/jsonschema/validate.go b/vendor/github.com/google/jsonschema-go/jsonschema/validate.go new file mode 100644 index 000000000..b895bbd41 --- /dev/null +++ b/vendor/github.com/google/jsonschema-go/jsonschema/validate.go @@ -0,0 +1,789 @@ +// Copyright 2025 The JSON Schema Go Project Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonschema + +import ( + "encoding/json" + "errors" + "fmt" + "hash/maphash" + "iter" + "math" + "math/big" + "reflect" + "slices" + "strings" + "sync" + "unicode/utf8" +) + +// The value of the "$schema" keyword for the version that we can validate. +const draft202012 = "https://json-schema.org/draft/2020-12/schema" + +// Validate validates the instance, which must be a JSON value, against the schema. +// It returns nil if validation is successful or an error if it is not. +// If the schema type is "object", instance can be a map[string]any or a struct. 
+func (rs *Resolved) Validate(instance any) error { + if s := rs.root.Schema; s != "" && s != draft202012 { + return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) + } + st := &state{rs: rs} + return st.validate(reflect.ValueOf(instance), st.rs.root, nil) +} + +// validateDefaults walks the schema tree. If it finds a default, it validates it +// against the schema containing it. +// +// TODO(jba): account for dynamic refs. This algorithm simple-mindedly +// treats each schema with a default as its own root. +func (rs *Resolved) validateDefaults() error { + if s := rs.root.Schema; s != "" && s != draft202012 { + return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) + } + st := &state{rs: rs} + for s := range rs.root.all() { + // We checked for nil schemas in [Schema.Resolve]. + assert(s != nil, "nil schema") + if s.DynamicRef != "" { + return fmt.Errorf("jsonschema: %s: validateDefaults does not support dynamic refs", rs.schemaString(s)) + } + if s.Default != nil { + var d any + if err := json.Unmarshal(s.Default, &d); err != nil { + return fmt.Errorf("unmarshaling default value of schema %s: %w", rs.schemaString(s), err) + } + if err := st.validate(reflect.ValueOf(d), s, nil); err != nil { + return err + } + } + } + return nil +} + +// state is the state of single call to ResolvedSchema.Validate. +type state struct { + rs *Resolved + // stack holds the schemas from recursive calls to validate. + // These are the "dynamic scopes" used to resolve dynamic references. + // https://json-schema.org/draft/2020-12/json-schema-core#scopes + stack []*Schema +} + +// validate validates the reflected value of the instance. +func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations) (err error) { + defer wrapf(&err, "validating %s", st.rs.schemaString(schema)) + + // Maintain a stack for dynamic schema resolution. 
+ st.stack = append(st.stack, schema) // push + defer func() { + st.stack = st.stack[:len(st.stack)-1] // pop + }() + + // We checked for nil schemas in [Schema.Resolve]. + assert(schema != nil, "nil schema") + + // Step through interfaces and pointers. + for instance.Kind() == reflect.Pointer || instance.Kind() == reflect.Interface { + instance = instance.Elem() + } + + schemaInfo := st.rs.resolvedInfos[schema] + + // type: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1 + if schema.Type != "" || schema.Types != nil { + gotType, ok := jsonType(instance) + if !ok { + return fmt.Errorf("type: %v of type %[1]T is not a valid JSON value", instance) + } + if schema.Type != "" { + // "number" subsumes integers + if !(gotType == schema.Type || + gotType == "integer" && schema.Type == "number") { + return fmt.Errorf("type: %v has type %q, want %q", instance, gotType, schema.Type) + } + } else { + if !(slices.Contains(schema.Types, gotType) || (gotType == "integer" && slices.Contains(schema.Types, "number"))) { + return fmt.Errorf("type: %v has type %q, want one of %q", + instance, gotType, strings.Join(schema.Types, ", ")) + } + } + } + // enum: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.2 + if schema.Enum != nil { + ok := false + for _, e := range schema.Enum { + if equalValue(reflect.ValueOf(e), instance) { + ok = true + break + } + } + if !ok { + return fmt.Errorf("enum: %v does not equal any of: %v", instance, schema.Enum) + } + } + + // const: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.3 + if schema.Const != nil { + if !equalValue(reflect.ValueOf(*schema.Const), instance) { + return fmt.Errorf("const: %v does not equal %v", instance, *schema.Const) + } + } + + // numbers: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.2 + if schema.MultipleOf != nil || schema.Minimum != nil || schema.Maximum 
!= nil || schema.ExclusiveMinimum != nil || schema.ExclusiveMaximum != nil { + n, ok := jsonNumber(instance) + if ok { // these keywords don't apply to non-numbers + if schema.MultipleOf != nil { + // TODO: validate MultipleOf as non-zero. + // The test suite assumes floats. + nf, _ := n.Float64() // don't care if it's exact or not + if _, f := math.Modf(nf / *schema.MultipleOf); f != 0 { + return fmt.Errorf("multipleOf: %s is not a multiple of %f", n, *schema.MultipleOf) + } + } + + m := new(big.Rat) // reuse for all of the following + cmp := func(f float64) int { return n.Cmp(m.SetFloat64(f)) } + + if schema.Minimum != nil && cmp(*schema.Minimum) < 0 { + return fmt.Errorf("minimum: %s is less than %f", n, *schema.Minimum) + } + if schema.Maximum != nil && cmp(*schema.Maximum) > 0 { + return fmt.Errorf("maximum: %s is greater than %f", n, *schema.Maximum) + } + if schema.ExclusiveMinimum != nil && cmp(*schema.ExclusiveMinimum) <= 0 { + return fmt.Errorf("exclusiveMinimum: %s is less than or equal to %f", n, *schema.ExclusiveMinimum) + } + if schema.ExclusiveMaximum != nil && cmp(*schema.ExclusiveMaximum) >= 0 { + return fmt.Errorf("exclusiveMaximum: %s is greater than or equal to %f", n, *schema.ExclusiveMaximum) + } + } + } + + // strings: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.3 + if instance.Kind() == reflect.String && (schema.MinLength != nil || schema.MaxLength != nil || schema.Pattern != "") { + str := instance.String() + n := utf8.RuneCountInString(str) + if schema.MinLength != nil { + if m := *schema.MinLength; n < m { + return fmt.Errorf("minLength: %q contains %d Unicode code points, fewer than %d", str, n, m) + } + } + if schema.MaxLength != nil { + if m := *schema.MaxLength; n > m { + return fmt.Errorf("maxLength: %q contains %d Unicode code points, more than %d", str, n, m) + } + } + + if schema.Pattern != "" && !schemaInfo.pattern.MatchString(str) { + return fmt.Errorf("pattern: %q does not match 
regular expression %q", str, schema.Pattern) + } + } + + var anns annotations // all the annotations for this call and child calls + + // $ref: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.1 + if schema.Ref != "" { + if err := st.validate(instance, schemaInfo.resolvedRef, &anns); err != nil { + return err + } + } + + // $dynamicRef: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2 + if schema.DynamicRef != "" { + // The ref behaves lexically or dynamically, but not both. + assert((schemaInfo.resolvedDynamicRef == nil) != (schemaInfo.dynamicRefAnchor == ""), + "DynamicRef not resolved properly") + if schemaInfo.resolvedDynamicRef != nil { + // Same as $ref. + if err := st.validate(instance, schemaInfo.resolvedDynamicRef, &anns); err != nil { + return err + } + } else { + // Dynamic behavior. + // Look for the base of the outermost schema on the stack with this dynamic + // anchor. (Yes, outermost: the one farthest from here. This the opposite + // of how ordinary dynamic variables behave.) + // Why the base of the schema being validated and not the schema itself? + // Because the base is the scope for anchors. In fact it's possible to + // refer to a schema that is not on the stack, but a child of some base + // on the stack. + // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. 
+ var dynamicSchema *Schema + for _, s := range st.stack { + base := st.rs.resolvedInfos[s].base + info, ok := st.rs.resolvedInfos[base].anchors[schemaInfo.dynamicRefAnchor] + if ok && info.dynamic { + dynamicSchema = info.schema + break + } + } + if dynamicSchema == nil { + return fmt.Errorf("missing dynamic anchor %q", schemaInfo.dynamicRefAnchor) + } + if err := st.validate(instance, dynamicSchema, &anns); err != nil { + return err + } + } + } + + // logic + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2 + // These must happen before arrays and objects because if they evaluate an item or property, + // then the unevaluatedItems/Properties schemas don't apply to it. + // See https://json-schema.org/draft/2020-12/json-schema-core#section-11.2, paragraph 4. + // + // If any of these fail, then validation fails, even if there is an unevaluatedXXX + // keyword in the schema. The spec is unclear about this, but that is the intention. + + valid := func(s *Schema, anns *annotations) bool { return st.validate(instance, s, anns) == nil } + + if schema.AllOf != nil { + for _, ss := range schema.AllOf { + if err := st.validate(instance, ss, &anns); err != nil { + return err + } + } + } + if schema.AnyOf != nil { + // We must visit them all, to collect annotations. + ok := false + for _, ss := range schema.AnyOf { + if valid(ss, &anns) { + ok = true + } + } + if !ok { + return fmt.Errorf("anyOf: did not validate against any of %v", schema.AnyOf) + } + } + if schema.OneOf != nil { + // Exactly one. + var okSchema *Schema + for _, ss := range schema.OneOf { + if valid(ss, &anns) { + if okSchema != nil { + return fmt.Errorf("oneOf: validated against both %v and %v", okSchema, ss) + } + okSchema = ss + } + } + if okSchema == nil { + return fmt.Errorf("oneOf: did not validate against any of %v", schema.OneOf) + } + } + if schema.Not != nil { + // Ignore annotations from "not". 
+ if valid(schema.Not, nil) { + return fmt.Errorf("not: validated against %v", schema.Not) + } + } + if schema.If != nil { + var ss *Schema + if valid(schema.If, &anns) { + ss = schema.Then + } else { + ss = schema.Else + } + if ss != nil { + if err := st.validate(instance, ss, &anns); err != nil { + return err + } + } + } + + // arrays + // TODO(jba): consider arrays of structs. + if instance.Kind() == reflect.Array || instance.Kind() == reflect.Slice { + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.1 + // This validate call doesn't collect annotations for the items of the instance; they are separate + // instances in their own right. + // TODO(jba): if the test suite doesn't cover this case, add a test. For example, nested arrays. + for i, ischema := range schema.PrefixItems { + if i >= instance.Len() { + break // shorter is OK + } + if err := st.validate(instance.Index(i), ischema, nil); err != nil { + return err + } + } + anns.noteEndIndex(min(len(schema.PrefixItems), instance.Len())) + + if schema.Items != nil { + for i := len(schema.PrefixItems); i < instance.Len(); i++ { + if err := st.validate(instance.Index(i), schema.Items, nil); err != nil { + return err + } + } + // Note that all the items in this array have been validated. + anns.allItems = true + } + + nContains := 0 + if schema.Contains != nil { + for i := range instance.Len() { + if err := st.validate(instance.Index(i), schema.Contains, nil); err == nil { + nContains++ + anns.noteIndex(i) + } + } + if nContains == 0 && (schema.MinContains == nil || *schema.MinContains > 0) { + return fmt.Errorf("contains: %s does not have an item matching %s", instance, schema.Contains) + } + } + + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.4 + // TODO(jba): check that these next four keywords' values are integers. 
+ if schema.MinContains != nil && schema.Contains != nil { + if m := *schema.MinContains; nContains < m { + return fmt.Errorf("minContains: contains validated %d items, less than %d", nContains, m) + } + } + if schema.MaxContains != nil && schema.Contains != nil { + if m := *schema.MaxContains; nContains > m { + return fmt.Errorf("maxContains: contains validated %d items, greater than %d", nContains, m) + } + } + if schema.MinItems != nil { + if m := *schema.MinItems; instance.Len() < m { + return fmt.Errorf("minItems: array length %d is less than %d", instance.Len(), m) + } + } + if schema.MaxItems != nil { + if m := *schema.MaxItems; instance.Len() > m { + return fmt.Errorf("maxItems: array length %d is greater than %d", instance.Len(), m) + } + } + if schema.UniqueItems { + if instance.Len() > 1 { + // Hash each item and compare the hashes. + // If two hashes differ, the items differ. + // If two hashes are the same, compare the collisions for equality. + // (The same logic as hash table lookup.) + // TODO(jba): Use container/hash.Map when it becomes available (https://go.dev/issue/69559), + hashes := map[uint64][]int{} // from hash to indices + seed := maphash.MakeSeed() + for i := range instance.Len() { + item := instance.Index(i) + var h maphash.Hash + h.SetSeed(seed) + hashValue(&h, item) + hv := h.Sum64() + if sames := hashes[hv]; len(sames) > 0 { + for _, j := range sames { + if equalValue(item, instance.Index(j)) { + return fmt.Errorf("uniqueItems: array items %d and %d are equal", i, j) + } + } + } + hashes[hv] = append(hashes[hv], i) + } + } + } + + // https://json-schema.org/draft/2020-12/json-schema-core#section-11.2 + if schema.UnevaluatedItems != nil && !anns.allItems { + // Apply this subschema to all items in the array that haven't been successfully validated. + // That includes validations by subschemas on the same instance, like allOf. 
+ for i := anns.endIndex; i < instance.Len(); i++ { + if !anns.evaluatedIndexes[i] { + if err := st.validate(instance.Index(i), schema.UnevaluatedItems, nil); err != nil { + return err + } + } + } + anns.allItems = true + } + } + + // objects + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.2 + // Validating structs is problematic. See https://github.com/google/jsonschema-go/issues/23. + if instance.Kind() == reflect.Struct { + return errors.New("cannot validate against a struct; see https://github.com/google/jsonschema-go/issues/23 for details") + } + if instance.Kind() == reflect.Map { + if kt := instance.Type().Key(); kt.Kind() != reflect.String { + return fmt.Errorf("map key type %s is not a string", kt) + } + // Track the evaluated properties for just this schema, to support additionalProperties. + // If we used anns here, then we'd be including properties evaluated in subschemas + // from allOf, etc., which additionalProperties shouldn't observe. + evalProps := map[string]bool{} + for prop, subschema := range schema.Properties { + val := property(instance, prop) + if !val.IsValid() { + // It's OK if the instance doesn't have the property. + continue + } + // If the instance is a struct and an optional property has the zero + // value, then we could interpret it as present or missing. Be generous: + // assume it's missing, and thus always validates successfully. + if instance.Kind() == reflect.Struct && val.IsZero() && !schemaInfo.isRequired[prop] { + continue + } + if err := st.validate(val, subschema, nil); err != nil { + return err + } + evalProps[prop] = true + } + if len(schema.PatternProperties) > 0 { + for prop, val := range properties(instance) { + // Check every matching pattern. 
+ for re, schema := range schemaInfo.patternProperties { + if re.MatchString(prop) { + if err := st.validate(val, schema, nil); err != nil { + return err + } + evalProps[prop] = true + } + } + } + } + if schema.AdditionalProperties != nil { + // Special case for a better error message when additional properties is + // 'falsy' + // + // If additionalProperties is {"not":{}} (which is how we + // unmarshal "false"), we can produce a better error message that + // summarizes all the extra properties. Otherwise, we fall back to the + // default validation. + // + // Note: this is much faster than comparing with falseSchema using Equal. + isFalsy := schema.AdditionalProperties.Not != nil && reflect.ValueOf(*schema.AdditionalProperties.Not).IsZero() + if isFalsy { + var disallowed []string + for prop := range properties(instance) { + if !evalProps[prop] { + disallowed = append(disallowed, prop) + } + } + if len(disallowed) > 0 { + return fmt.Errorf("unexpected additional properties %q", disallowed) + } + } else { + // Apply to all properties not handled above. + for prop, val := range properties(instance) { + if !evalProps[prop] { + if err := st.validate(val, schema.AdditionalProperties, nil); err != nil { + return err + } + evalProps[prop] = true + } + } + } + } + anns.noteProperties(evalProps) + if schema.PropertyNames != nil { + // Note: properties unnecessarily fetches each value. We could define a propertyNames function + // if performance ever matters. 
+ for prop := range properties(instance) { + if err := st.validate(reflect.ValueOf(prop), schema.PropertyNames, nil); err != nil { + return err + } + } + } + + // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.5 + var min, max int + if schema.MinProperties != nil || schema.MaxProperties != nil { + min, max = numPropertiesBounds(instance, schemaInfo.isRequired) + } + if schema.MinProperties != nil { + if n, m := max, *schema.MinProperties; n < m { + return fmt.Errorf("minProperties: object has %d properties, less than %d", n, m) + } + } + if schema.MaxProperties != nil { + if n, m := min, *schema.MaxProperties; n > m { + return fmt.Errorf("maxProperties: object has %d properties, greater than %d", n, m) + } + } + + hasProperty := func(prop string) bool { + return property(instance, prop).IsValid() + } + + missingProperties := func(props []string) []string { + var missing []string + for _, p := range props { + if !hasProperty(p) { + missing = append(missing, p) + } + } + return missing + } + + if schema.Required != nil { + if m := missingProperties(schema.Required); len(m) > 0 { + return fmt.Errorf("required: missing properties: %q", m) + } + } + if schema.DependentRequired != nil { + // "Validation succeeds if, for each name that appears in both the instance + // and as a name within this keyword's value, every item in the corresponding + // array is also the name of a property in the instance." §6.5.4 + for dprop, reqs := range schema.DependentRequired { + if hasProperty(dprop) { + if m := missingProperties(reqs); len(m) > 0 { + return fmt.Errorf("dependentRequired[%q]: missing properties %q", dprop, m) + } + } + } + } + + // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2.2.4 + if schema.DependentSchemas != nil { + // This does not collect annotations, although it seems like it should. 
+ for dprop, ss := range schema.DependentSchemas { + if hasProperty(dprop) { + // TODO: include dependentSchemas[dprop] in the errors. + err := st.validate(instance, ss, &anns) + if err != nil { + return err + } + } + } + } + if schema.UnevaluatedProperties != nil && !anns.allProperties { + // This looks a lot like AdditionalProperties, but depends on in-place keywords like allOf + // in addition to sibling keywords. + for prop, val := range properties(instance) { + if !anns.evaluatedProperties[prop] { + if err := st.validate(val, schema.UnevaluatedProperties, nil); err != nil { + return err + } + } + } + // The spec says the annotation should be the set of evaluated properties, but we can optimize + // by setting a single boolean, since after this succeeds all properties will be validated. + // See https://json-schema.slack.com/archives/CT7FF623C/p1745592564381459. + anns.allProperties = true + } + } + + if callerAnns != nil { + // Our caller wants to know what we've validated. + callerAnns.merge(&anns) + } + return nil +} + +// resolveDynamicRef returns the schema referred to by the argument schema's +// $dynamicRef value. +// It returns an error if the dynamic reference has no referent. +// If there is no $dynamicRef, resolveDynamicRef returns nil, nil. +// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2. +func (st *state) resolveDynamicRef(schema *Schema) (*Schema, error) { + if schema.DynamicRef == "" { + return nil, nil + } + info := st.rs.resolvedInfos[schema] + // The ref behaves lexically or dynamically, but not both. + assert((info.resolvedDynamicRef == nil) != (info.dynamicRefAnchor == ""), + "DynamicRef not statically resolved properly") + if r := info.resolvedDynamicRef; r != nil { + // Same as $ref. + return r, nil + } + // Dynamic behavior. + // Look for the base of the outermost schema on the stack with this dynamic + // anchor. (Yes, outermost: the one farthest from here. 
This the opposite + // of how ordinary dynamic variables behave.) + // Why the base of the schema being validated and not the schema itself? + // Because the base is the scope for anchors. In fact it's possible to + // refer to a schema that is not on the stack, but a child of some base + // on the stack. + // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. + for _, s := range st.stack { + base := st.rs.resolvedInfos[s].base + info, ok := st.rs.resolvedInfos[base].anchors[info.dynamicRefAnchor] + if ok && info.dynamic { + return info.schema, nil + } + } + return nil, fmt.Errorf("missing dynamic anchor %q", info.dynamicRefAnchor) +} + +// ApplyDefaults modifies an instance by applying the schema's defaults to it. If +// a schema or sub-schema has a default, then a corresponding zero instance value +// is set to the default. +// +// The JSON Schema specification does not describe how defaults should be interpreted. +// This method honors defaults only on properties, and only those that are not required. +// If the instance is a map and the property is missing, the property is added to +// the map with the default. +// If the instance is a struct, the field corresponding to the property exists, and +// its value is zero, the field is set to the default. +// ApplyDefaults can panic if a default cannot be assigned to a field. +// +// The argument must be a pointer to the instance. +// (In case we decide that top-level defaults are meaningful.) +// +// It is recommended to first call Resolve with a ValidateDefaults option of true, +// then call this method, and lastly call Validate. +func (rs *Resolved) ApplyDefaults(instancep any) error { + // TODO(jba): consider what defaults on top-level or array instances might mean. + // TODO(jba): follow $ref and $dynamicRef + // TODO(jba): apply defaults on sub-schemas to corresponding sub-instances. 
+ st := &state{rs: rs} + return st.applyDefaults(reflect.ValueOf(instancep), rs.root) +} + +// Leave this as a potentially recursive helper function, because we'll surely want +// to apply defaults on sub-schemas someday. +func (st *state) applyDefaults(instancep reflect.Value, schema *Schema) (err error) { + defer wrapf(&err, "applyDefaults: schema %s, instance %v", st.rs.schemaString(schema), instancep) + + schemaInfo := st.rs.resolvedInfos[schema] + instance := instancep.Elem() + if instance.Kind() == reflect.Interface && instance.IsValid() { + // If we unmarshalled into 'any', the default object unmarshalling will be map[string]any. + instance = instance.Elem() + } + if instance.Kind() == reflect.Map || instance.Kind() == reflect.Struct { + if instance.Kind() == reflect.Map { + if kt := instance.Type().Key(); kt.Kind() != reflect.String { + return fmt.Errorf("map key type %s is not a string", kt) + } + } + for prop, subschema := range schema.Properties { + // Ignore defaults on required properties. (A required property shouldn't have a default.) + if schemaInfo.isRequired[prop] { + continue + } + val := property(instance, prop) + switch instance.Kind() { + case reflect.Map: + // If there is a default for this property, and the map key is missing, + // set the map value to the default. + if subschema.Default != nil && !val.IsValid() { + // Create an lvalue, since map values aren't addressable. + lvalue := reflect.New(instance.Type().Elem()) + if err := json.Unmarshal(subschema.Default, lvalue.Interface()); err != nil { + return err + } + instance.SetMapIndex(reflect.ValueOf(prop), lvalue.Elem()) + } + case reflect.Struct: + // If there is a default for this property, and the field exists but is zero, + // set the field to the default. 
+ if subschema.Default != nil && val.IsValid() && val.IsZero() { + if err := json.Unmarshal(subschema.Default, val.Addr().Interface()); err != nil { + return err + } + } + default: + panic(fmt.Sprintf("applyDefaults: property %s: bad value %s of kind %s", + prop, instance, instance.Kind())) + } + } + } + return nil +} + +// property returns the value of the property of v with the given name, or the invalid +// reflect.Value if there is none. +// If v is a map, the property is the value of the map whose key is name. +// If v is a struct, the property is the value of the field with the given name according +// to the encoding/json package (see [jsonName]). +// If v is anything else, property panics. +func property(v reflect.Value, name string) reflect.Value { + switch v.Kind() { + case reflect.Map: + return v.MapIndex(reflect.ValueOf(name)) + case reflect.Struct: + props := structPropertiesOf(v.Type()) + // Ignore nonexistent properties. + if sf, ok := props[name]; ok { + return v.FieldByIndex(sf.Index) + } + return reflect.Value{} + default: + panic(fmt.Sprintf("property(%q): bad value %s of kind %s", name, v, v.Kind())) + } +} + +// properties returns an iterator over the names and values of all properties +// in v, which must be a map or a struct. +// If a struct, zero-valued properties that are marked omitempty or omitzero +// are excluded. 
+func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { + return func(yield func(string, reflect.Value) bool) { + switch v.Kind() { + case reflect.Map: + for k, e := range v.Seq2() { + if !yield(k.String(), e) { + return + } + } + case reflect.Struct: + for name, sf := range structPropertiesOf(v.Type()) { + val := v.FieldByIndex(sf.Index) + if val.IsZero() { + info := fieldJSONInfo(sf) + if info.settings["omitempty"] || info.settings["omitzero"] { + continue + } + } + if !yield(name, val) { + return + } + } + default: + panic(fmt.Sprintf("bad value %s of kind %s", v, v.Kind())) + } + } +} + +// numPropertiesBounds returns bounds on the number of v's properties. +// v must be a map or a struct. +// If v is a map, both bounds are the map's size. +// If v is a struct, the max is the number of struct properties. +// But since we don't know whether a zero value indicates a missing optional property +// or not, be generous and use the number of non-zero properties as the min. +func numPropertiesBounds(v reflect.Value, isRequired map[string]bool) (int, int) { + switch v.Kind() { + case reflect.Map: + return v.Len(), v.Len() + case reflect.Struct: + sp := structPropertiesOf(v.Type()) + min := 0 + for prop, sf := range sp { + if !v.FieldByIndex(sf.Index).IsZero() || isRequired[prop] { + min++ + } + } + return min, len(sp) + default: + panic(fmt.Sprintf("properties: bad value: %s of kind %s", v, v.Kind())) + } +} + +// A propertyMap is a map from property name to struct field index. +type propertyMap = map[string]reflect.StructField + +var structProperties sync.Map // from reflect.Type to propertyMap + +// structPropertiesOf returns the JSON Schema properties for the struct type t. +// The caller must not mutate the result. +func structPropertiesOf(t reflect.Type) propertyMap { + // Mutex not necessary: at worst we'll recompute the same value. 
+ if props, ok := structProperties.Load(t); ok { + return props.(propertyMap) + } + props := map[string]reflect.StructField{} + for _, sf := range reflect.VisibleFields(t) { + if sf.Anonymous { + continue + } + info := fieldJSONInfo(sf) + if !info.omit { + props[info.name] = sf + } + } + structProperties.Store(t, props) + return props +} diff --git a/vendor/github.com/invopop/jsonschema/.gitignore b/vendor/github.com/invopop/jsonschema/.gitignore deleted file mode 100644 index 8ef0e14fc..000000000 --- a/vendor/github.com/invopop/jsonschema/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -vendor/ -.idea/ diff --git a/vendor/github.com/invopop/jsonschema/.golangci.yml b/vendor/github.com/invopop/jsonschema/.golangci.yml deleted file mode 100644 index b89b2e124..000000000 --- a/vendor/github.com/invopop/jsonschema/.golangci.yml +++ /dev/null @@ -1,69 +0,0 @@ -run: - tests: true - max-same-issues: 50 - -output: - print-issued-lines: false - -linters: - enable: - - gocyclo - - gocritic - - goconst - - dupl - - unconvert - - goimports - - unused - - govet - - nakedret - - errcheck - - revive - - ineffassign - - goconst - - unparam - - gofmt - -linters-settings: - vet: - check-shadowing: true - use-installed-packages: true - dupl: - threshold: 100 - goconst: - min-len: 8 - min-occurrences: 3 - gocyclo: - min-complexity: 20 - gocritic: - disabled-checks: - - ifElseChain - gofmt: - rewrite-rules: - - pattern: "interface{}" - replacement: "any" - - pattern: "a[b:len(a)]" - replacement: "a[b:]" - -issues: - max-per-linter: 0 - max-same: 0 - exclude-dirs: - - resources - - old - exclude-files: - - cmd/protopkg/main.go - exclude-use-default: false - exclude: - # Captured by errcheck. - - "^(G104|G204):" - # Very commonly not checked. - - 'Error return value of .(.*\.Help|.*\.MarkFlagRequired|(os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*Print(f|ln|)|os\.(Un)?Setenv). is not checked' - # Weird error only seen on Kochiku... 
- - "internal error: no range for" - - 'exported method `.*\.(MarshalJSON|UnmarshalJSON|URN|Payload|GoString|Close|Provides|Requires|ExcludeFromHash|MarshalText|UnmarshalText|Description|Check|Poll|Severity)` should have comment or be unexported' - - "composite literal uses unkeyed fields" - - 'declaration of "err" shadows declaration' - - "by other packages, and that stutters" - - "Potential file inclusion via variable" - - "at least one file in a package should have a package comment" - - "bad syntax for struct tag pair" diff --git a/vendor/github.com/invopop/jsonschema/COPYING b/vendor/github.com/invopop/jsonschema/COPYING deleted file mode 100644 index 2993ec085..000000000 --- a/vendor/github.com/invopop/jsonschema/COPYING +++ /dev/null @@ -1,19 +0,0 @@ -Copyright (C) 2014 Alec Thomas - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of the Software, and to permit persons to whom the Software is furnished to do -so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
diff --git a/vendor/github.com/invopop/jsonschema/README.md b/vendor/github.com/invopop/jsonschema/README.md deleted file mode 100644 index 27b362e1d..000000000 --- a/vendor/github.com/invopop/jsonschema/README.md +++ /dev/null @@ -1,374 +0,0 @@ -# Go JSON Schema Reflection - -[![Lint](https://github.com/invopop/jsonschema/actions/workflows/lint.yaml/badge.svg)](https://github.com/invopop/jsonschema/actions/workflows/lint.yaml) -[![Test Go](https://github.com/invopop/jsonschema/actions/workflows/test.yaml/badge.svg)](https://github.com/invopop/jsonschema/actions/workflows/test.yaml) -[![Go Report Card](https://goreportcard.com/badge/github.com/invopop/jsonschema)](https://goreportcard.com/report/github.com/invopop/jsonschema) -[![GoDoc](https://godoc.org/github.com/invopop/jsonschema?status.svg)](https://godoc.org/github.com/invopop/jsonschema) -[![codecov](https://codecov.io/gh/invopop/jsonschema/graph/badge.svg?token=JMEB8W8GNZ)](https://codecov.io/gh/invopop/jsonschema) -![Latest Tag](https://img.shields.io/github/v/tag/invopop/jsonschema) - -This package can be used to generate [JSON Schemas](http://json-schema.org/latest/json-schema-validation.html) from Go types through reflection. - -- Supports arbitrarily complex types, including `interface{}`, maps, slices, etc. -- Supports json-schema features such as minLength, maxLength, pattern, format, etc. -- Supports simple string and numeric enums. -- Supports custom property fields via the `jsonschema_extras` struct tag. - -This repository is a fork of the original [jsonschema](https://github.com/alecthomas/jsonschema) by [@alecthomas](https://github.com/alecthomas). At [Invopop](https://invopop.com) we use jsonschema as a cornerstone in our [GOBL library](https://github.com/invopop/gobl), and wanted to be able to continue building and adding features without taking up Alec's time. 
There have been a few significant changes that probably mean this version is a not compatible with with Alec's: - -- The original was stuck on the draft-04 version of JSON Schema, we've now moved to the latest JSON Schema Draft 2020-12. -- Schema IDs are added automatically from the current Go package's URL in order to be unique, and can be disabled with the `Anonymous` option. -- Support for the `FullyQualifyTypeName` option has been removed. If you have conflicts, you should use multiple schema files with different IDs, set the `DoNotReference` option to true to hide definitions completely, or add your own naming strategy using the `Namer` property. -- Support for `yaml` tags and related options has been dropped for the sake of simplification. There were a [few inconsistencies](https://github.com/invopop/jsonschema/pull/21) around this that have now been fixed. - -## Versions - -This project is still under v0 scheme, as per Go convention, breaking changes are likely. Please pin go modules to version tags or branches, and reach out if you think something can be improved. - -Go version >= 1.18 is required as generics are now being used. 
- -## Example - -The following Go type: - -```go -type TestUser struct { - ID int `json:"id"` - Name string `json:"name" jsonschema:"title=the name,description=The name of a friend,example=joe,example=lucy,default=alex"` - Friends []int `json:"friends,omitempty" jsonschema_description:"The list of IDs, omitted when empty"` - Tags map[string]interface{} `json:"tags,omitempty" jsonschema_extras:"a=b,foo=bar,foo=bar1"` - BirthDate time.Time `json:"birth_date,omitempty" jsonschema:"oneof_required=date"` - YearOfBirth string `json:"year_of_birth,omitempty" jsonschema:"oneof_required=year"` - Metadata interface{} `json:"metadata,omitempty" jsonschema:"oneof_type=string;array"` - FavColor string `json:"fav_color,omitempty" jsonschema:"enum=red,enum=green,enum=blue"` -} -``` - -Results in following JSON Schema: - -```go -jsonschema.Reflect(&TestUser{}) -``` - -```json -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://github.com/invopop/jsonschema_test/test-user", - "$ref": "#/$defs/TestUser", - "$defs": { - "TestUser": { - "oneOf": [ - { - "required": ["birth_date"], - "title": "date" - }, - { - "required": ["year_of_birth"], - "title": "year" - } - ], - "properties": { - "id": { - "type": "integer" - }, - "name": { - "type": "string", - "title": "the name", - "description": "The name of a friend", - "default": "alex", - "examples": ["joe", "lucy"] - }, - "friends": { - "items": { - "type": "integer" - }, - "type": "array", - "description": "The list of IDs, omitted when empty" - }, - "tags": { - "type": "object", - "a": "b", - "foo": ["bar", "bar1"] - }, - "birth_date": { - "type": "string", - "format": "date-time" - }, - "year_of_birth": { - "type": "string" - }, - "metadata": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array" - } - ] - }, - "fav_color": { - "type": "string", - "enum": ["red", "green", "blue"] - } - }, - "additionalProperties": false, - "type": "object", - "required": ["id", "name"] - } - } -} -``` - -## YAML 
- -Support for `yaml` tags has now been removed. If you feel very strongly about this, we've opened a discussion to hear your comments: https://github.com/invopop/jsonschema/discussions/28 - -The recommended approach if you need to deal with YAML data is to first convert to JSON. The [invopop/yaml](https://github.com/invopop/yaml) library will make this trivial. - -## Configurable behaviour - -The behaviour of the schema generator can be altered with parameters when a `jsonschema.Reflector` -instance is created. - -### ExpandedStruct - -If set to `true`, makes the top level struct not to reference itself in the definitions. But type passed should be a struct type. - -eg. - -```go -type GrandfatherType struct { - FamilyName string `json:"family_name" jsonschema:"required"` -} - -type SomeBaseType struct { - SomeBaseProperty int `json:"some_base_property"` - // The jsonschema required tag is nonsensical for private and ignored properties. - // Their presence here tests that the fields *will not* be required in the output - // schema, even if they are tagged required. 
- somePrivateBaseProperty string `json:"i_am_private" jsonschema:"required"` - SomeIgnoredBaseProperty string `json:"-" jsonschema:"required"` - SomeSchemaIgnoredProperty string `jsonschema:"-,required"` - SomeUntaggedBaseProperty bool `jsonschema:"required"` - someUnexportedUntaggedBaseProperty bool - Grandfather GrandfatherType `json:"grand"` -} -``` - -will output: - -```json -{ - "$schema": "http://json-schema.org/draft/2020-12/schema", - "required": ["some_base_property", "grand", "SomeUntaggedBaseProperty"], - "properties": { - "SomeUntaggedBaseProperty": { - "type": "boolean" - }, - "grand": { - "$schema": "http://json-schema.org/draft/2020-12/schema", - "$ref": "#/definitions/GrandfatherType" - }, - "some_base_property": { - "type": "integer" - } - }, - "type": "object", - "$defs": { - "GrandfatherType": { - "required": ["family_name"], - "properties": { - "family_name": { - "type": "string" - } - }, - "additionalProperties": false, - "type": "object" - } - } -} -``` - -### Using Go Comments - -Writing a good schema with descriptions inside tags can become cumbersome and tedious, especially if you already have some Go comments around your types and field definitions. If you'd like to take advantage of these existing comments, you can use the `AddGoComments(base, path string)` method that forms part of the reflector to parse your go files and automatically generate a dictionary of Go import paths, types, and fields, to individual comments. These will then be used automatically as description fields, and can be overridden with a manual definition if needed. - -Take a simplified example of a User struct which for the sake of simplicity we assume is defined inside this package: - -```go -package main - -// User is used as a base to provide tests for comments. -type User struct { - // Unique sequential identifier. 
- ID int `json:"id" jsonschema:"required"` - // Name of the user - Name string `json:"name"` -} -``` - -To get the comments provided into your JSON schema, use a regular `Reflector` and add the go code using an import module URL and path. Fully qualified go module paths cannot be determined reliably by the `go/parser` library, so we need to introduce this manually: - -```go -r := new(Reflector) -if err := r.AddGoComments("github.com/invopop/jsonschema", "./"); err != nil { - // deal with error -} -s := r.Reflect(&User{}) -// output -``` - -Expect the results to be similar to: - -```json -{ - "$schema": "http://json-schema.org/draft/2020-12/schema", - "$ref": "#/$defs/User", - "$defs": { - "User": { - "required": ["id"], - "properties": { - "id": { - "type": "integer", - "description": "Unique sequential identifier." - }, - "name": { - "type": "string", - "description": "Name of the user" - } - }, - "additionalProperties": false, - "type": "object", - "description": "User is used as a base to provide tests for comments." - } - } -} -``` - -### Custom Key Naming - -In some situations, the keys actually used to write files are different from Go structs'. - -This is often the case when writing a configuration file to YAML or JSON from a Go struct, or when returning a JSON response for a Web API: APIs typically use snake_case, while Go uses PascalCase. - -You can pass a `func(string) string` function to `Reflector`'s `KeyNamer` option to map Go field names to JSON key names and reflect the aforementioned transformations, without having to specify `json:"..."` on every struct field. 
- -For example, consider the following struct - -```go -type User struct { - GivenName string - PasswordSalted []byte `json:"salted_password"` -} -``` - -We can transform field names to snake_case in the generated JSON schema: - -```go -r := new(jsonschema.Reflector) -r.KeyNamer = strcase.SnakeCase // from package github.com/stoewer/go-strcase - -r.Reflect(&User{}) -``` - -Will yield - -```diff - { - "$schema": "http://json-schema.org/draft/2020-12/schema", - "$ref": "#/$defs/User", - "$defs": { - "User": { - "properties": { -- "GivenName": { -+ "given_name": { - "type": "string" - }, - "salted_password": { - "type": "string", - "contentEncoding": "base64" - } - }, - "additionalProperties": false, - "type": "object", -- "required": ["GivenName", "salted_password"] -+ "required": ["given_name", "salted_password"] - } - } - } -``` - -As you can see, if a field name has a `json:""` tag set, the `key` argument to `KeyNamer` will have the value of that tag. - -### Custom Type Definitions - -Sometimes it can be useful to have custom JSON Marshal and Unmarshal methods in your structs that automatically convert for example a string into an object. - -This library will recognize and attempt to call four different methods that help you adjust schemas to your specific needs: - -- `JSONSchema() *Schema` - will prevent auto-generation of the schema so that you can provide your own definition. -- `JSONSchemaExtend(schema *jsonschema.Schema)` - will be called _after_ the schema has been generated, allowing you to add or manipulate the fields easily. -- `JSONSchemaAlias() any` - is called when reflecting the type of object and allows for an alternative to be used instead. -- `JSONSchemaProperty(prop string) any` - will be called for every property inside a struct giving you the chance to provide an alternative object to convert into a schema. - -Note that all of these methods **must** be defined on a non-pointer object for them to be called. 
- -Take the following simplified example of a `CompactDate` that only includes the Year and Month: - -```go -type CompactDate struct { - Year int - Month int -} - -func (d *CompactDate) UnmarshalJSON(data []byte) error { - if len(data) != 9 { - return errors.New("invalid compact date length") - } - var err error - d.Year, err = strconv.Atoi(string(data[1:5])) - if err != nil { - return err - } - d.Month, err = strconv.Atoi(string(data[7:8])) - if err != nil { - return err - } - return nil -} - -func (d *CompactDate) MarshalJSON() ([]byte, error) { - buf := new(bytes.Buffer) - buf.WriteByte('"') - buf.WriteString(fmt.Sprintf("%d-%02d", d.Year, d.Month)) - buf.WriteByte('"') - return buf.Bytes(), nil -} - -func (CompactDate) JSONSchema() *Schema { - return &Schema{ - Type: "string", - Title: "Compact Date", - Description: "Short date that only includes year and month", - Pattern: "^[0-9]{4}-[0-1][0-9]$", - } -} -``` - -The resulting schema generated for this struct would look like: - -```json -{ - "$schema": "http://json-schema.org/draft/2020-12/schema", - "$ref": "#/$defs/CompactDate", - "$defs": { - "CompactDate": { - "pattern": "^[0-9]{4}-[0-1][0-9]$", - "type": "string", - "title": "Compact Date", - "description": "Short date that only includes year and month" - } - } -} -``` diff --git a/vendor/github.com/invopop/jsonschema/id.go b/vendor/github.com/invopop/jsonschema/id.go deleted file mode 100644 index 73fafb38d..000000000 --- a/vendor/github.com/invopop/jsonschema/id.go +++ /dev/null @@ -1,76 +0,0 @@ -package jsonschema - -import ( - "errors" - "fmt" - "net/url" - "strings" -) - -// ID represents a Schema ID type which should always be a URI. -// See draft-bhutton-json-schema-00 section 8.2.1 -type ID string - -// EmptyID is used to explicitly define an ID with no value. -const EmptyID ID = "" - -// Validate is used to check if the ID looks like a proper schema. -// This is done by parsing the ID as a URL and checking it has all the -// relevant parts. 
-func (id ID) Validate() error { - u, err := url.Parse(id.String()) - if err != nil { - return fmt.Errorf("invalid URL: %w", err) - } - if u.Hostname() == "" { - return errors.New("missing hostname") - } - if !strings.Contains(u.Hostname(), ".") { - return errors.New("hostname does not look valid") - } - if u.Path == "" { - return errors.New("path is expected") - } - if u.Scheme != "https" && u.Scheme != "http" { - return errors.New("unexpected schema") - } - return nil -} - -// Anchor sets the anchor part of the schema URI. -func (id ID) Anchor(name string) ID { - b := id.Base() - return ID(b.String() + "#" + name) -} - -// Def adds or replaces a definition identifier. -func (id ID) Def(name string) ID { - b := id.Base() - return ID(b.String() + "#/$defs/" + name) -} - -// Add appends the provided path to the id, and removes any -// anchor data that might be there. -func (id ID) Add(path string) ID { - b := id.Base() - if !strings.HasPrefix(path, "/") { - path = "/" + path - } - return ID(b.String() + path) -} - -// Base removes any anchor information from the schema -func (id ID) Base() ID { - s := id.String() - i := strings.LastIndex(s, "#") - if i != -1 { - s = s[0:i] - } - s = strings.TrimRight(s, "/") - return ID(s) -} - -// String provides string version of ID -func (id ID) String() string { - return string(id) -} diff --git a/vendor/github.com/invopop/jsonschema/reflect.go b/vendor/github.com/invopop/jsonschema/reflect.go deleted file mode 100644 index 73ce7e465..000000000 --- a/vendor/github.com/invopop/jsonschema/reflect.go +++ /dev/null @@ -1,1148 +0,0 @@ -// Package jsonschema uses reflection to generate JSON Schemas from Go types [1]. -// -// If json tags are present on struct fields, they will be used to infer -// property names and if a property is required (omitempty is present). 
-// -// [1] http://json-schema.org/latest/json-schema-validation.html -package jsonschema - -import ( - "bytes" - "encoding/json" - "net" - "net/url" - "reflect" - "strconv" - "strings" - "time" -) - -// customSchemaImpl is used to detect if the type provides it's own -// custom Schema Type definition to use instead. Very useful for situations -// where there are custom JSON Marshal and Unmarshal methods. -type customSchemaImpl interface { - JSONSchema() *Schema -} - -// Function to be run after the schema has been generated. -// this will let you modify a schema afterwards -type extendSchemaImpl interface { - JSONSchemaExtend(*Schema) -} - -// If the object to be reflected defines a `JSONSchemaAlias` method, its type will -// be used instead of the original type. -type aliasSchemaImpl interface { - JSONSchemaAlias() any -} - -// If an object to be reflected defines a `JSONSchemaPropertyAlias` method, -// it will be called for each property to determine if another object -// should be used for the contents. 
-type propertyAliasSchemaImpl interface { - JSONSchemaProperty(prop string) any -} - -var customAliasSchema = reflect.TypeOf((*aliasSchemaImpl)(nil)).Elem() -var customPropertyAliasSchema = reflect.TypeOf((*propertyAliasSchemaImpl)(nil)).Elem() - -var customType = reflect.TypeOf((*customSchemaImpl)(nil)).Elem() -var extendType = reflect.TypeOf((*extendSchemaImpl)(nil)).Elem() - -// customSchemaGetFieldDocString -type customSchemaGetFieldDocString interface { - GetFieldDocString(fieldName string) string -} - -type customGetFieldDocString func(fieldName string) string - -var customStructGetFieldDocString = reflect.TypeOf((*customSchemaGetFieldDocString)(nil)).Elem() - -// Reflect reflects to Schema from a value using the default Reflector -func Reflect(v any) *Schema { - return ReflectFromType(reflect.TypeOf(v)) -} - -// ReflectFromType generates root schema using the default Reflector -func ReflectFromType(t reflect.Type) *Schema { - r := &Reflector{} - return r.ReflectFromType(t) -} - -// A Reflector reflects values into a Schema. -type Reflector struct { - // BaseSchemaID defines the URI that will be used as a base to determine Schema - // IDs for models. For example, a base Schema ID of `https://invopop.com/schemas` - // when defined with a struct called `User{}`, will result in a schema with an - // ID set to `https://invopop.com/schemas/user`. - // - // If no `BaseSchemaID` is provided, we'll take the type's complete package path - // and use that as a base instead. Set `Anonymous` to try if you do not want to - // include a schema ID. - BaseSchemaID ID - - // Anonymous when true will hide the auto-generated Schema ID and provide what is - // known as an "anonymous schema". As a rule, this is not recommended. - Anonymous bool - - // AssignAnchor when true will use the original struct's name as an anchor inside - // every definition, including the root schema. 
These can be useful for having a - // reference to the original struct's name in CamelCase instead of the snake-case used - // by default for URI compatibility. - // - // Anchors do not appear to be widely used out in the wild, so at this time the - // anchors themselves will not be used inside generated schema. - AssignAnchor bool - - // AllowAdditionalProperties will cause the Reflector to generate a schema - // without additionalProperties set to 'false' for all struct types. This means - // the presence of additional keys in JSON objects will not cause validation - // to fail. Note said additional keys will simply be dropped when the - // validated JSON is unmarshaled. - AllowAdditionalProperties bool - - // RequiredFromJSONSchemaTags will cause the Reflector to generate a schema - // that requires any key tagged with `jsonschema:required`, overriding the - // default of requiring any key *not* tagged with `json:,omitempty`. - RequiredFromJSONSchemaTags bool - - // Do not reference definitions. This will remove the top-level $defs map and - // instead cause the entire structure of types to be output in one tree. The - // list of type definitions (`$defs`) will not be included. - DoNotReference bool - - // ExpandedStruct when true will include the reflected type's definition in the - // root as opposed to a definition with a reference. - ExpandedStruct bool - - // FieldNameTag will change the tag used to get field names. json tags are used by default. - FieldNameTag string - - // IgnoredTypes defines a slice of types that should be ignored in the schema, - // switching to just allowing additional properties instead. - IgnoredTypes []any - - // Lookup allows a function to be defined that will provide a custom mapping of - // types to Schema IDs. This allows existing schema documents to be referenced - // by their ID instead of being embedded into the current schema definitions. - // Reflected types will never be pointers, only underlying elements. 
- Lookup func(reflect.Type) ID - - // Mapper is a function that can be used to map custom Go types to jsonschema schemas. - Mapper func(reflect.Type) *Schema - - // Namer allows customizing of type names. The default is to use the type's name - // provided by the reflect package. - Namer func(reflect.Type) string - - // KeyNamer allows customizing of key names. - // The default is to use the key's name as is, or the json tag if present. - // If a json tag is present, KeyNamer will receive the tag's name as an argument, not the original key name. - KeyNamer func(string) string - - // AdditionalFields allows adding structfields for a given type - AdditionalFields func(reflect.Type) []reflect.StructField - - // LookupComment allows customizing comment lookup. Given a reflect.Type and optionally - // a field name, it should return the comment string associated with this type or field. - // - // If the field name is empty, it should return the type's comment; otherwise, the field's - // comment should be returned. If no comment is found, an empty string should be returned. - // - // When set, this function is called before the below CommentMap lookup mechanism. However, - // if it returns an empty string, the CommentMap is still consulted. - LookupComment func(reflect.Type, string) string - - // CommentMap is a dictionary of fully qualified go types and fields to comment - // strings that will be used if a description has not already been provided in - // the tags. Types and fields are added to the package path using "." as a - // separator. 
- // - // Type descriptions should be defined like: - // - // map[string]string{"github.com/invopop/jsonschema.Reflector": "A Reflector reflects values into a Schema."} - // - // And Fields defined as: - // - // map[string]string{"github.com/invopop/jsonschema.Reflector.DoNotReference": "Do not reference definitions."} - // - // See also: AddGoComments, LookupComment - CommentMap map[string]string -} - -// Reflect reflects to Schema from a value. -func (r *Reflector) Reflect(v any) *Schema { - return r.ReflectFromType(reflect.TypeOf(v)) -} - -// ReflectFromType generates root schema -func (r *Reflector) ReflectFromType(t reflect.Type) *Schema { - if t.Kind() == reflect.Ptr { - t = t.Elem() // re-assign from pointer - } - - name := r.typeName(t) - - s := new(Schema) - definitions := Definitions{} - s.Definitions = definitions - bs := r.reflectTypeToSchemaWithID(definitions, t) - if r.ExpandedStruct { - *s = *definitions[name] - delete(definitions, name) - } else { - *s = *bs - } - - // Attempt to set the schema ID - if !r.Anonymous && s.ID == EmptyID { - baseSchemaID := r.BaseSchemaID - if baseSchemaID == EmptyID { - id := ID("https://" + t.PkgPath()) - if err := id.Validate(); err == nil { - // it's okay to silently ignore URL errors - baseSchemaID = id - } - } - if baseSchemaID != EmptyID { - s.ID = baseSchemaID.Add(ToSnakeCase(name)) - } - } - - s.Version = Version - if !r.DoNotReference { - s.Definitions = definitions - } - - return s -} - -// Available Go defined types for JSON Schema Validation. 
-// RFC draft-wright-json-schema-validation-00, section 7.3 -var ( - timeType = reflect.TypeOf(time.Time{}) // date-time RFC section 7.3.1 - ipType = reflect.TypeOf(net.IP{}) // ipv4 and ipv6 RFC section 7.3.4, 7.3.5 - uriType = reflect.TypeOf(url.URL{}) // uri RFC section 7.3.6 -) - -// Byte slices will be encoded as base64 -var byteSliceType = reflect.TypeOf([]byte(nil)) - -// Except for json.RawMessage -var rawMessageType = reflect.TypeOf(json.RawMessage{}) - -// Go code generated from protobuf enum types should fulfil this interface. -type protoEnum interface { - EnumDescriptor() ([]byte, []int) -} - -var protoEnumType = reflect.TypeOf((*protoEnum)(nil)).Elem() - -// SetBaseSchemaID is a helper use to be able to set the reflectors base -// schema ID from a string as opposed to then ID instance. -func (r *Reflector) SetBaseSchemaID(id string) { - r.BaseSchemaID = ID(id) -} - -func (r *Reflector) refOrReflectTypeToSchema(definitions Definitions, t reflect.Type) *Schema { - id := r.lookupID(t) - if id != EmptyID { - return &Schema{ - Ref: id.String(), - } - } - - // Already added to definitions? - if def := r.refDefinition(definitions, t); def != nil { - return def - } - - return r.reflectTypeToSchemaWithID(definitions, t) -} - -func (r *Reflector) reflectTypeToSchemaWithID(defs Definitions, t reflect.Type) *Schema { - s := r.reflectTypeToSchema(defs, t) - if s != nil { - if r.Lookup != nil { - id := r.Lookup(t) - if id != EmptyID { - s.ID = id - } - } - } - return s -} - -func (r *Reflector) reflectTypeToSchema(definitions Definitions, t reflect.Type) *Schema { - // only try to reflect non-pointers - if t.Kind() == reflect.Ptr { - return r.refOrReflectTypeToSchema(definitions, t.Elem()) - } - - // Check if the there is an alias method that provides an object - // that we should use instead of this one. 
- if t.Implements(customAliasSchema) { - v := reflect.New(t) - o := v.Interface().(aliasSchemaImpl) - t = reflect.TypeOf(o.JSONSchemaAlias()) - return r.refOrReflectTypeToSchema(definitions, t) - } - - // Do any pre-definitions exist? - if r.Mapper != nil { - if t := r.Mapper(t); t != nil { - return t - } - } - if rt := r.reflectCustomSchema(definitions, t); rt != nil { - return rt - } - - // Prepare a base to which details can be added - st := new(Schema) - - // jsonpb will marshal protobuf enum options as either strings or integers. - // It will unmarshal either. - if t.Implements(protoEnumType) { - st.OneOf = []*Schema{ - {Type: "string"}, - {Type: "integer"}, - } - return st - } - - // Defined format types for JSON Schema Validation - // RFC draft-wright-json-schema-validation-00, section 7.3 - // TODO email RFC section 7.3.2, hostname RFC section 7.3.3, uriref RFC section 7.3.7 - if t == ipType { - // TODO differentiate ipv4 and ipv6 RFC section 7.3.4, 7.3.5 - st.Type = "string" - st.Format = "ipv4" - return st - } - - switch t.Kind() { - case reflect.Struct: - r.reflectStruct(definitions, t, st) - - case reflect.Slice, reflect.Array: - r.reflectSliceOrArray(definitions, t, st) - - case reflect.Map: - r.reflectMap(definitions, t, st) - - case reflect.Interface: - // empty - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - st.Type = "integer" - - case reflect.Float32, reflect.Float64: - st.Type = "number" - - case reflect.Bool: - st.Type = "boolean" - - case reflect.String: - st.Type = "string" - - default: - panic("unsupported type " + t.String()) - } - - r.reflectSchemaExtend(definitions, t, st) - - // Always try to reference the definition which may have just been created - if def := r.refDefinition(definitions, t); def != nil { - return def - } - - return st -} - -func (r *Reflector) reflectCustomSchema(definitions Definitions, t reflect.Type) 
*Schema { - if t.Kind() == reflect.Ptr { - return r.reflectCustomSchema(definitions, t.Elem()) - } - - if t.Implements(customType) { - v := reflect.New(t) - o := v.Interface().(customSchemaImpl) - st := o.JSONSchema() - r.addDefinition(definitions, t, st) - if ref := r.refDefinition(definitions, t); ref != nil { - return ref - } - return st - } - - return nil -} - -func (r *Reflector) reflectSchemaExtend(definitions Definitions, t reflect.Type, s *Schema) *Schema { - if t.Implements(extendType) { - v := reflect.New(t) - o := v.Interface().(extendSchemaImpl) - o.JSONSchemaExtend(s) - if ref := r.refDefinition(definitions, t); ref != nil { - return ref - } - } - - return s -} - -func (r *Reflector) reflectSliceOrArray(definitions Definitions, t reflect.Type, st *Schema) { - if t == rawMessageType { - return - } - - r.addDefinition(definitions, t, st) - - if st.Description == "" { - st.Description = r.lookupComment(t, "") - } - - if t.Kind() == reflect.Array { - l := uint64(t.Len()) - st.MinItems = &l - st.MaxItems = &l - } - if t.Kind() == reflect.Slice && t.Elem() == byteSliceType.Elem() { - st.Type = "string" - // NOTE: ContentMediaType is not set here - st.ContentEncoding = "base64" - } else { - st.Type = "array" - st.Items = r.refOrReflectTypeToSchema(definitions, t.Elem()) - } -} - -func (r *Reflector) reflectMap(definitions Definitions, t reflect.Type, st *Schema) { - r.addDefinition(definitions, t, st) - - st.Type = "object" - if st.Description == "" { - st.Description = r.lookupComment(t, "") - } - - switch t.Key().Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - st.PatternProperties = map[string]*Schema{ - "^[0-9]+$": r.refOrReflectTypeToSchema(definitions, t.Elem()), - } - st.AdditionalProperties = FalseSchema - return - } - if t.Elem().Kind() != reflect.Interface { - st.AdditionalProperties = r.refOrReflectTypeToSchema(definitions, t.Elem()) - } -} - -// Reflects a struct to a JSON Schema type. 
-func (r *Reflector) reflectStruct(definitions Definitions, t reflect.Type, s *Schema) { - // Handle special types - switch t { - case timeType: // date-time RFC section 7.3.1 - s.Type = "string" - s.Format = "date-time" - return - case uriType: // uri RFC section 7.3.6 - s.Type = "string" - s.Format = "uri" - return - } - - r.addDefinition(definitions, t, s) - s.Type = "object" - s.Properties = NewProperties() - s.Description = r.lookupComment(t, "") - if r.AssignAnchor { - s.Anchor = t.Name() - } - if !r.AllowAdditionalProperties && s.AdditionalProperties == nil { - s.AdditionalProperties = FalseSchema - } - - ignored := false - for _, it := range r.IgnoredTypes { - if reflect.TypeOf(it) == t { - ignored = true - break - } - } - if !ignored { - r.reflectStructFields(s, definitions, t) - } -} - -func (r *Reflector) reflectStructFields(st *Schema, definitions Definitions, t reflect.Type) { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - if t.Kind() != reflect.Struct { - return - } - - var getFieldDocString customGetFieldDocString - if t.Implements(customStructGetFieldDocString) { - v := reflect.New(t) - o := v.Interface().(customSchemaGetFieldDocString) - getFieldDocString = o.GetFieldDocString - } - - customPropertyMethod := func(string) any { - return nil - } - if t.Implements(customPropertyAliasSchema) { - v := reflect.New(t) - o := v.Interface().(propertyAliasSchemaImpl) - customPropertyMethod = o.JSONSchemaProperty - } - - handleField := func(f reflect.StructField) { - name, shouldEmbed, required, nullable := r.reflectFieldName(f) - // if anonymous and exported type should be processed recursively - // current type should inherit properties of anonymous one - if name == "" { - if shouldEmbed { - r.reflectStructFields(st, definitions, f.Type) - } - return - } - - // If a JSONSchemaAlias(prop string) method is defined, attempt to use - // the provided object's type instead of the field's type. 
- var property *Schema - if alias := customPropertyMethod(name); alias != nil { - property = r.refOrReflectTypeToSchema(definitions, reflect.TypeOf(alias)) - } else { - property = r.refOrReflectTypeToSchema(definitions, f.Type) - } - - property.structKeywordsFromTags(f, st, name) - if property.Description == "" { - property.Description = r.lookupComment(t, f.Name) - } - if getFieldDocString != nil { - property.Description = getFieldDocString(f.Name) - } - - if nullable { - property = &Schema{ - OneOf: []*Schema{ - property, - { - Type: "null", - }, - }, - } - } - - st.Properties.Set(name, property) - if required { - st.Required = appendUniqueString(st.Required, name) - } - } - - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - handleField(f) - } - if r.AdditionalFields != nil { - if af := r.AdditionalFields(t); af != nil { - for _, sf := range af { - handleField(sf) - } - } - } -} - -func appendUniqueString(base []string, value string) []string { - for _, v := range base { - if v == value { - return base - } - } - return append(base, value) -} - -// addDefinition will append the provided schema. If needed, an ID and anchor will also be added. -func (r *Reflector) addDefinition(definitions Definitions, t reflect.Type, s *Schema) { - name := r.typeName(t) - if name == "" { - return - } - definitions[name] = s -} - -// refDefinition will provide a schema with a reference to an existing definition. 
-func (r *Reflector) refDefinition(definitions Definitions, t reflect.Type) *Schema { - if r.DoNotReference { - return nil - } - name := r.typeName(t) - if name == "" { - return nil - } - if _, ok := definitions[name]; !ok { - return nil - } - return &Schema{ - Ref: "#/$defs/" + name, - } -} - -func (r *Reflector) lookupID(t reflect.Type) ID { - if r.Lookup != nil { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - return r.Lookup(t) - - } - return EmptyID -} - -func (t *Schema) structKeywordsFromTags(f reflect.StructField, parent *Schema, propertyName string) { - t.Description = f.Tag.Get("jsonschema_description") - - tags := splitOnUnescapedCommas(f.Tag.Get("jsonschema")) - tags = t.genericKeywords(tags, parent, propertyName) - - switch t.Type { - case "string": - t.stringKeywords(tags) - case "number": - t.numericalKeywords(tags) - case "integer": - t.numericalKeywords(tags) - case "array": - t.arrayKeywords(tags) - case "boolean": - t.booleanKeywords(tags) - } - extras := strings.Split(f.Tag.Get("jsonschema_extras"), ",") - t.extraKeywords(extras) -} - -// read struct tags for generic keywords -func (t *Schema) genericKeywords(tags []string, parent *Schema, propertyName string) []string { //nolint:gocyclo - unprocessed := make([]string, 0, len(tags)) - for _, tag := range tags { - nameValue := strings.SplitN(tag, "=", 2) - if len(nameValue) == 2 { - name, val := nameValue[0], nameValue[1] - switch name { - case "title": - t.Title = val - case "description": - t.Description = val - case "type": - t.Type = val - case "anchor": - t.Anchor = val - case "oneof_required": - var typeFound *Schema - for i := range parent.OneOf { - if parent.OneOf[i].Title == nameValue[1] { - typeFound = parent.OneOf[i] - } - } - if typeFound == nil { - typeFound = &Schema{ - Title: nameValue[1], - Required: []string{}, - } - parent.OneOf = append(parent.OneOf, typeFound) - } - typeFound.Required = append(typeFound.Required, propertyName) - case "anyof_required": - var typeFound 
*Schema - for i := range parent.AnyOf { - if parent.AnyOf[i].Title == nameValue[1] { - typeFound = parent.AnyOf[i] - } - } - if typeFound == nil { - typeFound = &Schema{ - Title: nameValue[1], - Required: []string{}, - } - parent.AnyOf = append(parent.AnyOf, typeFound) - } - typeFound.Required = append(typeFound.Required, propertyName) - case "oneof_ref": - subSchema := t - if t.Items != nil { - subSchema = t.Items - } - if subSchema.OneOf == nil { - subSchema.OneOf = make([]*Schema, 0, 1) - } - subSchema.Ref = "" - refs := strings.Split(nameValue[1], ";") - for _, r := range refs { - subSchema.OneOf = append(subSchema.OneOf, &Schema{ - Ref: r, - }) - } - case "oneof_type": - if t.OneOf == nil { - t.OneOf = make([]*Schema, 0, 1) - } - t.Type = "" - types := strings.Split(nameValue[1], ";") - for _, ty := range types { - t.OneOf = append(t.OneOf, &Schema{ - Type: ty, - }) - } - case "anyof_ref": - subSchema := t - if t.Items != nil { - subSchema = t.Items - } - if subSchema.AnyOf == nil { - subSchema.AnyOf = make([]*Schema, 0, 1) - } - subSchema.Ref = "" - refs := strings.Split(nameValue[1], ";") - for _, r := range refs { - subSchema.AnyOf = append(subSchema.AnyOf, &Schema{ - Ref: r, - }) - } - case "anyof_type": - if t.AnyOf == nil { - t.AnyOf = make([]*Schema, 0, 1) - } - t.Type = "" - types := strings.Split(nameValue[1], ";") - for _, ty := range types { - t.AnyOf = append(t.AnyOf, &Schema{ - Type: ty, - }) - } - default: - unprocessed = append(unprocessed, tag) - } - } - } - return unprocessed -} - -// read struct tags for boolean type keywords -func (t *Schema) booleanKeywords(tags []string) { - for _, tag := range tags { - nameValue := strings.Split(tag, "=") - if len(nameValue) != 2 { - continue - } - name, val := nameValue[0], nameValue[1] - if name == "default" { - if val == "true" { - t.Default = true - } else if val == "false" { - t.Default = false - } - } - } -} - -// read struct tags for string type keywords -func (t *Schema) stringKeywords(tags 
[]string) { - for _, tag := range tags { - nameValue := strings.SplitN(tag, "=", 2) - if len(nameValue) == 2 { - name, val := nameValue[0], nameValue[1] - switch name { - case "minLength": - t.MinLength = parseUint(val) - case "maxLength": - t.MaxLength = parseUint(val) - case "pattern": - t.Pattern = val - case "format": - t.Format = val - case "readOnly": - i, _ := strconv.ParseBool(val) - t.ReadOnly = i - case "writeOnly": - i, _ := strconv.ParseBool(val) - t.WriteOnly = i - case "default": - t.Default = val - case "example": - t.Examples = append(t.Examples, val) - case "enum": - t.Enum = append(t.Enum, val) - } - } - } -} - -// read struct tags for numerical type keywords -func (t *Schema) numericalKeywords(tags []string) { - for _, tag := range tags { - nameValue := strings.Split(tag, "=") - if len(nameValue) == 2 { - name, val := nameValue[0], nameValue[1] - switch name { - case "multipleOf": - t.MultipleOf, _ = toJSONNumber(val) - case "minimum": - t.Minimum, _ = toJSONNumber(val) - case "maximum": - t.Maximum, _ = toJSONNumber(val) - case "exclusiveMaximum": - t.ExclusiveMaximum, _ = toJSONNumber(val) - case "exclusiveMinimum": - t.ExclusiveMinimum, _ = toJSONNumber(val) - case "default": - if num, ok := toJSONNumber(val); ok { - t.Default = num - } - case "example": - if num, ok := toJSONNumber(val); ok { - t.Examples = append(t.Examples, num) - } - case "enum": - if num, ok := toJSONNumber(val); ok { - t.Enum = append(t.Enum, num) - } - } - } - } -} - -// read struct tags for object type keywords -// func (t *Type) objectKeywords(tags []string) { -// for _, tag := range tags{ -// nameValue := strings.Split(tag, "=") -// name, val := nameValue[0], nameValue[1] -// switch name{ -// case "dependencies": -// t.Dependencies = val -// break; -// case "patternProperties": -// t.PatternProperties = val -// break; -// } -// } -// } - -// read struct tags for array type keywords -func (t *Schema) arrayKeywords(tags []string) { - var defaultValues []any - - 
unprocessed := make([]string, 0, len(tags)) - for _, tag := range tags { - nameValue := strings.Split(tag, "=") - if len(nameValue) == 2 { - name, val := nameValue[0], nameValue[1] - switch name { - case "minItems": - t.MinItems = parseUint(val) - case "maxItems": - t.MaxItems = parseUint(val) - case "uniqueItems": - t.UniqueItems = true - case "default": - defaultValues = append(defaultValues, val) - case "format": - t.Items.Format = val - case "pattern": - t.Items.Pattern = val - default: - unprocessed = append(unprocessed, tag) // left for further processing by underlying type - } - } - } - if len(defaultValues) > 0 { - t.Default = defaultValues - } - - if len(unprocessed) == 0 { - // we don't have anything else to process - return - } - - switch t.Items.Type { - case "string": - t.Items.stringKeywords(unprocessed) - case "number": - t.Items.numericalKeywords(unprocessed) - case "integer": - t.Items.numericalKeywords(unprocessed) - case "array": - // explicitly don't support traversal for the [][]..., as it's unclear where the array tags belong - case "boolean": - t.Items.booleanKeywords(unprocessed) - } -} - -func (t *Schema) extraKeywords(tags []string) { - for _, tag := range tags { - nameValue := strings.SplitN(tag, "=", 2) - if len(nameValue) == 2 { - t.setExtra(nameValue[0], nameValue[1]) - } - } -} - -func (t *Schema) setExtra(key, val string) { - if t.Extras == nil { - t.Extras = map[string]any{} - } - if existingVal, ok := t.Extras[key]; ok { - switch existingVal := existingVal.(type) { - case string: - t.Extras[key] = []string{existingVal, val} - case []string: - t.Extras[key] = append(existingVal, val) - case int: - t.Extras[key], _ = strconv.Atoi(val) - case bool: - t.Extras[key] = (val == "true" || val == "t") - } - } else { - switch key { - case "minimum": - t.Extras[key], _ = strconv.Atoi(val) - default: - var x any - if val == "true" { - x = true - } else if val == "false" { - x = false - } else { - x = val - } - t.Extras[key] = x - } - } -} - 
-func requiredFromJSONTags(tags []string, val *bool) { - if ignoredByJSONTags(tags) { - return - } - - for _, tag := range tags[1:] { - if tag == "omitempty" { - *val = false - return - } - } - *val = true -} - -func requiredFromJSONSchemaTags(tags []string, val *bool) { - if ignoredByJSONSchemaTags(tags) { - return - } - for _, tag := range tags { - if tag == "required" { - *val = true - } - } -} - -func nullableFromJSONSchemaTags(tags []string) bool { - if ignoredByJSONSchemaTags(tags) { - return false - } - for _, tag := range tags { - if tag == "nullable" { - return true - } - } - return false -} - -func ignoredByJSONTags(tags []string) bool { - return tags[0] == "-" -} - -func ignoredByJSONSchemaTags(tags []string) bool { - return tags[0] == "-" -} - -func inlinedByJSONTags(tags []string) bool { - for _, tag := range tags[1:] { - if tag == "inline" { - return true - } - } - return false -} - -// toJSONNumber converts string to *json.Number. -// It'll aso return whether the number is valid. 
-func toJSONNumber(s string) (json.Number, bool) { - num := json.Number(s) - if _, err := num.Int64(); err == nil { - return num, true - } - if _, err := num.Float64(); err == nil { - return num, true - } - return json.Number(""), false -} - -func parseUint(num string) *uint64 { - val, err := strconv.ParseUint(num, 10, 64) - if err != nil { - return nil - } - return &val -} - -func (r *Reflector) fieldNameTag() string { - if r.FieldNameTag != "" { - return r.FieldNameTag - } - return "json" -} - -func (r *Reflector) reflectFieldName(f reflect.StructField) (string, bool, bool, bool) { - jsonTagString := f.Tag.Get(r.fieldNameTag()) - jsonTags := strings.Split(jsonTagString, ",") - - if ignoredByJSONTags(jsonTags) { - return "", false, false, false - } - - schemaTags := strings.Split(f.Tag.Get("jsonschema"), ",") - if ignoredByJSONSchemaTags(schemaTags) { - return "", false, false, false - } - - var required bool - if !r.RequiredFromJSONSchemaTags { - requiredFromJSONTags(jsonTags, &required) - } - requiredFromJSONSchemaTags(schemaTags, &required) - - nullable := nullableFromJSONSchemaTags(schemaTags) - - if f.Anonymous && jsonTags[0] == "" { - // As per JSON Marshal rules, anonymous structs are inherited - if f.Type.Kind() == reflect.Struct { - return "", true, false, false - } - - // As per JSON Marshal rules, anonymous pointer to structs are inherited - if f.Type.Kind() == reflect.Ptr && f.Type.Elem().Kind() == reflect.Struct { - return "", true, false, false - } - } - - // As per JSON Marshal rules, inline nested structs that have `inline` tag. 
- if inlinedByJSONTags(jsonTags) { - return "", true, false, false - } - - // Try to determine the name from the different combos - name := f.Name - if jsonTags[0] != "" { - name = jsonTags[0] - } - if !f.Anonymous && f.PkgPath != "" { - // field not anonymous and not export has no export name - name = "" - } else if r.KeyNamer != nil { - name = r.KeyNamer(name) - } - - return name, false, required, nullable -} - -// UnmarshalJSON is used to parse a schema object or boolean. -func (t *Schema) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, []byte("true")) { - *t = *TrueSchema - return nil - } else if bytes.Equal(data, []byte("false")) { - *t = *FalseSchema - return nil - } - type SchemaAlt Schema - aux := &struct { - *SchemaAlt - }{ - SchemaAlt: (*SchemaAlt)(t), - } - return json.Unmarshal(data, aux) -} - -// MarshalJSON is used to serialize a schema object or boolean. -func (t *Schema) MarshalJSON() ([]byte, error) { - if t.boolean != nil { - if *t.boolean { - return []byte("true"), nil - } - return []byte("false"), nil - } - if reflect.DeepEqual(&Schema{}, t) { - // Don't bother returning empty schemas - return []byte("true"), nil - } - type SchemaAlt Schema - b, err := json.Marshal((*SchemaAlt)(t)) - if err != nil { - return nil, err - } - if len(t.Extras) == 0 { - return b, nil - } - m, err := json.Marshal(t.Extras) - if err != nil { - return nil, err - } - if len(b) == 2 { - return m, nil - } - b[len(b)-1] = ',' - return append(b, m[1:]...), nil -} - -func (r *Reflector) typeName(t reflect.Type) string { - if r.Namer != nil { - if name := r.Namer(t); name != "" { - return name - } - } - return t.Name() -} - -// Split on commas that are not preceded by `\`. 
-// This way, we prevent splitting regexes -func splitOnUnescapedCommas(tagString string) []string { - ret := make([]string, 0) - separated := strings.Split(tagString, ",") - ret = append(ret, separated[0]) - i := 0 - for _, nextTag := range separated[1:] { - if len(ret[i]) == 0 { - ret = append(ret, nextTag) - i++ - continue - } - - if ret[i][len(ret[i])-1] == '\\' { - ret[i] = ret[i][:len(ret[i])-1] + "," + nextTag - } else { - ret = append(ret, nextTag) - i++ - } - } - - return ret -} - -func fullyQualifiedTypeName(t reflect.Type) string { - return t.PkgPath() + "." + t.Name() -} diff --git a/vendor/github.com/invopop/jsonschema/reflect_comments.go b/vendor/github.com/invopop/jsonschema/reflect_comments.go deleted file mode 100644 index ff374c75c..000000000 --- a/vendor/github.com/invopop/jsonschema/reflect_comments.go +++ /dev/null @@ -1,146 +0,0 @@ -package jsonschema - -import ( - "fmt" - "io/fs" - gopath "path" - "path/filepath" - "reflect" - "strings" - - "go/ast" - "go/doc" - "go/parser" - "go/token" -) - -type commentOptions struct { - fullObjectText bool // use the first sentence only? -} - -// CommentOption allows for special configuration options when preparing Go -// source files for comment extraction. -type CommentOption func(*commentOptions) - -// WithFullComment will configure the comment extraction to process to use an -// object type's full comment text instead of just the synopsis. -func WithFullComment() CommentOption { - return func(o *commentOptions) { - o.fullObjectText = true - } -} - -// AddGoComments will update the reflectors comment map with all the comments -// found in the provided source directories including sub-directories, in order to -// generate a dictionary of comments associated with Types and Fields. The results -// will be added to the `Reflect.CommentMap` ready to use with Schema "description" -// fields. 
-// -// The `go/parser` library is used to extract all the comments and unfortunately doesn't -// have a built-in way to determine the fully qualified name of a package. The `base` -// parameter, the URL used to import that package, is thus required to be able to match -// reflected types. -// -// When parsing type comments, by default we use the `go/doc`'s Synopsis method to extract -// the first phrase only. Field comments, which tend to be much shorter, will include everything. -// This behavior can be changed by using the `WithFullComment` option. -func (r *Reflector) AddGoComments(base, path string, opts ...CommentOption) error { - if r.CommentMap == nil { - r.CommentMap = make(map[string]string) - } - co := new(commentOptions) - for _, opt := range opts { - opt(co) - } - - return r.extractGoComments(base, path, r.CommentMap, co) -} - -func (r *Reflector) extractGoComments(base, path string, commentMap map[string]string, opts *commentOptions) error { - fset := token.NewFileSet() - dict := make(map[string][]*ast.Package) - err := filepath.Walk(path, func(path string, info fs.FileInfo, err error) error { - if err != nil { - return err - } - if info.IsDir() { - d, err := parser.ParseDir(fset, path, nil, parser.ParseComments) - if err != nil { - return err - } - for _, v := range d { - // paths may have multiple packages, like for tests - k := gopath.Join(base, path) - dict[k] = append(dict[k], v) - } - } - return nil - }) - if err != nil { - return err - } - - for pkg, p := range dict { - for _, f := range p { - gtxt := "" - typ := "" - ast.Inspect(f, func(n ast.Node) bool { - switch x := n.(type) { - case *ast.TypeSpec: - typ = x.Name.String() - if !ast.IsExported(typ) { - typ = "" - } else { - txt := x.Doc.Text() - if txt == "" && gtxt != "" { - txt = gtxt - gtxt = "" - } - if !opts.fullObjectText { - txt = doc.Synopsis(txt) - } - commentMap[fmt.Sprintf("%s.%s", pkg, typ)] = strings.TrimSpace(txt) - } - case *ast.Field: - txt := x.Doc.Text() - if txt == "" { - 
txt = x.Comment.Text() - } - if typ != "" && txt != "" { - for _, n := range x.Names { - if ast.IsExported(n.String()) { - k := fmt.Sprintf("%s.%s.%s", pkg, typ, n) - commentMap[k] = strings.TrimSpace(txt) - } - } - } - case *ast.GenDecl: - // remember for the next type - gtxt = x.Doc.Text() - } - return true - }) - } - } - - return nil -} - -func (r *Reflector) lookupComment(t reflect.Type, name string) string { - if r.LookupComment != nil { - if comment := r.LookupComment(t, name); comment != "" { - return comment - } - } - - if r.CommentMap == nil { - return "" - } - - n := fullyQualifiedTypeName(t) - if name != "" { - n = n + "." + name - } - - return r.CommentMap[n] -} diff --git a/vendor/github.com/invopop/jsonschema/schema.go b/vendor/github.com/invopop/jsonschema/schema.go deleted file mode 100644 index 2d914b8c8..000000000 --- a/vendor/github.com/invopop/jsonschema/schema.go +++ /dev/null @@ -1,94 +0,0 @@ -package jsonschema - -import ( - "encoding/json" - - orderedmap "github.com/wk8/go-ordered-map/v2" -) - -// Version is the JSON Schema version. -var Version = "https://json-schema.org/draft/2020-12/schema" - -// Schema represents a JSON Schema object type. 
-// RFC draft-bhutton-json-schema-00 section 4.3 -type Schema struct { - // RFC draft-bhutton-json-schema-00 - Version string `json:"$schema,omitempty"` // section 8.1.1 - ID ID `json:"$id,omitempty"` // section 8.2.1 - Anchor string `json:"$anchor,omitempty"` // section 8.2.2 - Ref string `json:"$ref,omitempty"` // section 8.2.3.1 - DynamicRef string `json:"$dynamicRef,omitempty"` // section 8.2.3.2 - Definitions Definitions `json:"$defs,omitempty"` // section 8.2.4 - Comments string `json:"$comment,omitempty"` // section 8.3 - // RFC draft-bhutton-json-schema-00 section 10.2.1 (Sub-schemas with logic) - AllOf []*Schema `json:"allOf,omitempty"` // section 10.2.1.1 - AnyOf []*Schema `json:"anyOf,omitempty"` // section 10.2.1.2 - OneOf []*Schema `json:"oneOf,omitempty"` // section 10.2.1.3 - Not *Schema `json:"not,omitempty"` // section 10.2.1.4 - // RFC draft-bhutton-json-schema-00 section 10.2.2 (Apply sub-schemas conditionally) - If *Schema `json:"if,omitempty"` // section 10.2.2.1 - Then *Schema `json:"then,omitempty"` // section 10.2.2.2 - Else *Schema `json:"else,omitempty"` // section 10.2.2.3 - DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` // section 10.2.2.4 - // RFC draft-bhutton-json-schema-00 section 10.3.1 (arrays) - PrefixItems []*Schema `json:"prefixItems,omitempty"` // section 10.3.1.1 - Items *Schema `json:"items,omitempty"` // section 10.3.1.2 (replaces additionalItems) - Contains *Schema `json:"contains,omitempty"` // section 10.3.1.3 - // RFC draft-bhutton-json-schema-00 section 10.3.2 (sub-schemas) - Properties *orderedmap.OrderedMap[string, *Schema] `json:"properties,omitempty"` // section 10.3.2.1 - PatternProperties map[string]*Schema `json:"patternProperties,omitempty"` // section 10.3.2.2 - AdditionalProperties *Schema `json:"additionalProperties,omitempty"` // section 10.3.2.3 - PropertyNames *Schema `json:"propertyNames,omitempty"` // section 10.3.2.4 - // RFC draft-bhutton-json-schema-validation-00, section 6 - 
Type string `json:"type,omitempty"` // section 6.1.1 - Enum []any `json:"enum,omitempty"` // section 6.1.2 - Const any `json:"const,omitempty"` // section 6.1.3 - MultipleOf json.Number `json:"multipleOf,omitempty"` // section 6.2.1 - Maximum json.Number `json:"maximum,omitempty"` // section 6.2.2 - ExclusiveMaximum json.Number `json:"exclusiveMaximum,omitempty"` // section 6.2.3 - Minimum json.Number `json:"minimum,omitempty"` // section 6.2.4 - ExclusiveMinimum json.Number `json:"exclusiveMinimum,omitempty"` // section 6.2.5 - MaxLength *uint64 `json:"maxLength,omitempty"` // section 6.3.1 - MinLength *uint64 `json:"minLength,omitempty"` // section 6.3.2 - Pattern string `json:"pattern,omitempty"` // section 6.3.3 - MaxItems *uint64 `json:"maxItems,omitempty"` // section 6.4.1 - MinItems *uint64 `json:"minItems,omitempty"` // section 6.4.2 - UniqueItems bool `json:"uniqueItems,omitempty"` // section 6.4.3 - MaxContains *uint64 `json:"maxContains,omitempty"` // section 6.4.4 - MinContains *uint64 `json:"minContains,omitempty"` // section 6.4.5 - MaxProperties *uint64 `json:"maxProperties,omitempty"` // section 6.5.1 - MinProperties *uint64 `json:"minProperties,omitempty"` // section 6.5.2 - Required []string `json:"required,omitempty"` // section 6.5.3 - DependentRequired map[string][]string `json:"dependentRequired,omitempty"` // section 6.5.4 - // RFC draft-bhutton-json-schema-validation-00, section 7 - Format string `json:"format,omitempty"` - // RFC draft-bhutton-json-schema-validation-00, section 8 - ContentEncoding string `json:"contentEncoding,omitempty"` // section 8.3 - ContentMediaType string `json:"contentMediaType,omitempty"` // section 8.4 - ContentSchema *Schema `json:"contentSchema,omitempty"` // section 8.5 - // RFC draft-bhutton-json-schema-validation-00, section 9 - Title string `json:"title,omitempty"` // section 9.1 - Description string `json:"description,omitempty"` // section 9.1 - Default any `json:"default,omitempty"` // section 9.2 - 
Deprecated bool `json:"deprecated,omitempty"` // section 9.3 - ReadOnly bool `json:"readOnly,omitempty"` // section 9.4 - WriteOnly bool `json:"writeOnly,omitempty"` // section 9.4 - Examples []any `json:"examples,omitempty"` // section 9.5 - - Extras map[string]any `json:"-"` - - // Special boolean representation of the Schema - section 4.3.2 - boolean *bool -} - -var ( - // TrueSchema defines a schema with a true value - TrueSchema = &Schema{boolean: &[]bool{true}[0]} - // FalseSchema defines a schema with a false value - FalseSchema = &Schema{boolean: &[]bool{false}[0]} -) - -// Definitions hold schema definitions. -// http://json-schema.org/latest/json-schema-validation.html#rfc.section.5.26 -// RFC draft-wright-json-schema-validation-00, section 5.26 -type Definitions map[string]*Schema diff --git a/vendor/github.com/invopop/jsonschema/utils.go b/vendor/github.com/invopop/jsonschema/utils.go deleted file mode 100644 index ed8edf741..000000000 --- a/vendor/github.com/invopop/jsonschema/utils.go +++ /dev/null @@ -1,26 +0,0 @@ -package jsonschema - -import ( - "regexp" - "strings" - - orderedmap "github.com/wk8/go-ordered-map/v2" -) - -var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") -var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") - -// ToSnakeCase converts the provided string into snake case using dashes. -// This is useful for Schema IDs and definitions to be coherent with -// common JSON Schema examples. -func ToSnakeCase(str string) string { - snake := matchFirstCap.ReplaceAllString(str, "${1}-${2}") - snake = matchAllCap.ReplaceAllString(snake, "${1}-${2}") - return strings.ToLower(snake) -} - -// NewProperties is a helper method to instantiate a new properties ordered -// map. 
-func NewProperties() *orderedmap.OrderedMap[string, *Schema] { - return orderedmap.New[string, *Schema]() -} diff --git a/vendor/github.com/mailru/easyjson/LICENSE b/vendor/github.com/mailru/easyjson/LICENSE deleted file mode 100644 index fbff658f7..000000000 --- a/vendor/github.com/mailru/easyjson/LICENSE +++ /dev/null @@ -1,7 +0,0 @@ -Copyright (c) 2016 Mail.Ru Group - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/mailru/easyjson/buffer/pool.go b/vendor/github.com/mailru/easyjson/buffer/pool.go deleted file mode 100644 index 598a54af9..000000000 --- a/vendor/github.com/mailru/easyjson/buffer/pool.go +++ /dev/null @@ -1,278 +0,0 @@ -// Package buffer implements a buffer for serialization, consisting of a chain of []byte-s to -// reduce copying and to allow reuse of individual chunks. -package buffer - -import ( - "io" - "net" - "sync" -) - -// PoolConfig contains configuration for the allocation and reuse strategy. 
-type PoolConfig struct { - StartSize int // Minimum chunk size that is allocated. - PooledSize int // Minimum chunk size that is reused, reusing chunks too small will result in overhead. - MaxSize int // Maximum chunk size that will be allocated. -} - -var config = PoolConfig{ - StartSize: 128, - PooledSize: 512, - MaxSize: 32768, -} - -// Reuse pool: chunk size -> pool. -var buffers = map[int]*sync.Pool{} - -func initBuffers() { - for l := config.PooledSize; l <= config.MaxSize; l *= 2 { - buffers[l] = new(sync.Pool) - } -} - -func init() { - initBuffers() -} - -// Init sets up a non-default pooling and allocation strategy. Should be run before serialization is done. -func Init(cfg PoolConfig) { - config = cfg - initBuffers() -} - -// putBuf puts a chunk to reuse pool if it can be reused. -func putBuf(buf []byte) { - size := cap(buf) - if size < config.PooledSize { - return - } - if c := buffers[size]; c != nil { - c.Put(buf[:0]) - } -} - -// getBuf gets a chunk from reuse pool or creates a new one if reuse failed. -func getBuf(size int) []byte { - if size >= config.PooledSize { - if c := buffers[size]; c != nil { - v := c.Get() - if v != nil { - return v.([]byte) - } - } - } - return make([]byte, 0, size) -} - -// Buffer is a buffer optimized for serialization without extra copying. -type Buffer struct { - - // Buf is the current chunk that can be used for serialization. - Buf []byte - - toPool []byte - bufs [][]byte -} - -// EnsureSpace makes sure that the current chunk contains at least s free bytes, -// possibly creating a new chunk. -func (b *Buffer) EnsureSpace(s int) { - if cap(b.Buf)-len(b.Buf) < s { - b.ensureSpaceSlow(s) - } -} - -func (b *Buffer) ensureSpaceSlow(s int) { - l := len(b.Buf) - if l > 0 { - if cap(b.toPool) != cap(b.Buf) { - // Chunk was reallocated, toPool can be pooled. 
- putBuf(b.toPool) - } - if cap(b.bufs) == 0 { - b.bufs = make([][]byte, 0, 8) - } - b.bufs = append(b.bufs, b.Buf) - l = cap(b.toPool) * 2 - } else { - l = config.StartSize - } - - if l > config.MaxSize { - l = config.MaxSize - } - b.Buf = getBuf(l) - b.toPool = b.Buf -} - -// AppendByte appends a single byte to buffer. -func (b *Buffer) AppendByte(data byte) { - b.EnsureSpace(1) - b.Buf = append(b.Buf, data) -} - -// AppendBytes appends a byte slice to buffer. -func (b *Buffer) AppendBytes(data []byte) { - if len(data) <= cap(b.Buf)-len(b.Buf) { - b.Buf = append(b.Buf, data...) // fast path - } else { - b.appendBytesSlow(data) - } -} - -func (b *Buffer) appendBytesSlow(data []byte) { - for len(data) > 0 { - b.EnsureSpace(1) - - sz := cap(b.Buf) - len(b.Buf) - if sz > len(data) { - sz = len(data) - } - - b.Buf = append(b.Buf, data[:sz]...) - data = data[sz:] - } -} - -// AppendString appends a string to buffer. -func (b *Buffer) AppendString(data string) { - if len(data) <= cap(b.Buf)-len(b.Buf) { - b.Buf = append(b.Buf, data...) // fast path - } else { - b.appendStringSlow(data) - } -} - -func (b *Buffer) appendStringSlow(data string) { - for len(data) > 0 { - b.EnsureSpace(1) - - sz := cap(b.Buf) - len(b.Buf) - if sz > len(data) { - sz = len(data) - } - - b.Buf = append(b.Buf, data[:sz]...) - data = data[sz:] - } -} - -// Size computes the size of a buffer by adding sizes of every chunk. -func (b *Buffer) Size() int { - size := len(b.Buf) - for _, buf := range b.bufs { - size += len(buf) - } - return size -} - -// DumpTo outputs the contents of a buffer to a writer and resets the buffer. 
-func (b *Buffer) DumpTo(w io.Writer) (written int, err error) { - bufs := net.Buffers(b.bufs) - if len(b.Buf) > 0 { - bufs = append(bufs, b.Buf) - } - n, err := bufs.WriteTo(w) - - for _, buf := range b.bufs { - putBuf(buf) - } - putBuf(b.toPool) - - b.bufs = nil - b.Buf = nil - b.toPool = nil - - return int(n), err -} - -// BuildBytes creates a single byte slice with all the contents of the buffer. Data is -// copied if it does not fit in a single chunk. You can optionally provide one byte -// slice as argument that it will try to reuse. -func (b *Buffer) BuildBytes(reuse ...[]byte) []byte { - if len(b.bufs) == 0 { - ret := b.Buf - b.toPool = nil - b.Buf = nil - return ret - } - - var ret []byte - size := b.Size() - - // If we got a buffer as argument and it is big enough, reuse it. - if len(reuse) == 1 && cap(reuse[0]) >= size { - ret = reuse[0][:0] - } else { - ret = make([]byte, 0, size) - } - for _, buf := range b.bufs { - ret = append(ret, buf...) - putBuf(buf) - } - - ret = append(ret, b.Buf...) - putBuf(b.toPool) - - b.bufs = nil - b.toPool = nil - b.Buf = nil - - return ret -} - -type readCloser struct { - offset int - bufs [][]byte -} - -func (r *readCloser) Read(p []byte) (n int, err error) { - for _, buf := range r.bufs { - // Copy as much as we can. - x := copy(p[n:], buf[r.offset:]) - n += x // Increment how much we filled. - - // Did we empty the whole buffer? - if r.offset+x == len(buf) { - // On to the next buffer. - r.offset = 0 - r.bufs = r.bufs[1:] - - // We can release this buffer. - putBuf(buf) - } else { - r.offset += x - } - - if n == len(p) { - break - } - } - // No buffers left or nothing read? - if len(r.bufs) == 0 { - err = io.EOF - } - return -} - -func (r *readCloser) Close() error { - // Release all remaining buffers. - for _, buf := range r.bufs { - putBuf(buf) - } - // In case Close gets called multiple times. - r.bufs = nil - - return nil -} - -// ReadCloser creates an io.ReadCloser with all the contents of the buffer. 
-func (b *Buffer) ReadCloser() io.ReadCloser { - ret := &readCloser{0, append(b.bufs, b.Buf)} - - b.bufs = nil - b.toPool = nil - b.Buf = nil - - return ret -} diff --git a/vendor/github.com/mailru/easyjson/jwriter/writer.go b/vendor/github.com/mailru/easyjson/jwriter/writer.go deleted file mode 100644 index 2c5b20105..000000000 --- a/vendor/github.com/mailru/easyjson/jwriter/writer.go +++ /dev/null @@ -1,405 +0,0 @@ -// Package jwriter contains a JSON writer. -package jwriter - -import ( - "io" - "strconv" - "unicode/utf8" - - "github.com/mailru/easyjson/buffer" -) - -// Flags describe various encoding options. The behavior may be actually implemented in the encoder, but -// Flags field in Writer is used to set and pass them around. -type Flags int - -const ( - NilMapAsEmpty Flags = 1 << iota // Encode nil map as '{}' rather than 'null'. - NilSliceAsEmpty // Encode nil slice as '[]' rather than 'null'. -) - -// Writer is a JSON writer. -type Writer struct { - Flags Flags - - Error error - Buffer buffer.Buffer - NoEscapeHTML bool -} - -// Size returns the size of the data that was written out. -func (w *Writer) Size() int { - return w.Buffer.Size() -} - -// DumpTo outputs the data to given io.Writer, resetting the buffer. -func (w *Writer) DumpTo(out io.Writer) (written int, err error) { - return w.Buffer.DumpTo(out) -} - -// BuildBytes returns writer data as a single byte slice. You can optionally provide one byte slice -// as argument that it will try to reuse. -func (w *Writer) BuildBytes(reuse ...[]byte) ([]byte, error) { - if w.Error != nil { - return nil, w.Error - } - - return w.Buffer.BuildBytes(reuse...), nil -} - -// ReadCloser returns an io.ReadCloser that can be used to read the data. -// ReadCloser also resets the buffer. -func (w *Writer) ReadCloser() (io.ReadCloser, error) { - if w.Error != nil { - return nil, w.Error - } - - return w.Buffer.ReadCloser(), nil -} - -// RawByte appends raw binary data to the buffer. 
-func (w *Writer) RawByte(c byte) { - w.Buffer.AppendByte(c) -} - -// RawByte appends raw binary data to the buffer. -func (w *Writer) RawString(s string) { - w.Buffer.AppendString(s) -} - -// Raw appends raw binary data to the buffer or sets the error if it is given. Useful for -// calling with results of MarshalJSON-like functions. -func (w *Writer) Raw(data []byte, err error) { - switch { - case w.Error != nil: - return - case err != nil: - w.Error = err - case len(data) > 0: - w.Buffer.AppendBytes(data) - default: - w.RawString("null") - } -} - -// RawText encloses raw binary data in quotes and appends in to the buffer. -// Useful for calling with results of MarshalText-like functions. -func (w *Writer) RawText(data []byte, err error) { - switch { - case w.Error != nil: - return - case err != nil: - w.Error = err - case len(data) > 0: - w.String(string(data)) - default: - w.RawString("null") - } -} - -// Base64Bytes appends data to the buffer after base64 encoding it -func (w *Writer) Base64Bytes(data []byte) { - if data == nil { - w.Buffer.AppendString("null") - return - } - w.Buffer.AppendByte('"') - w.base64(data) - w.Buffer.AppendByte('"') -} - -func (w *Writer) Uint8(n uint8) { - w.Buffer.EnsureSpace(3) - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) -} - -func (w *Writer) Uint16(n uint16) { - w.Buffer.EnsureSpace(5) - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) -} - -func (w *Writer) Uint32(n uint32) { - w.Buffer.EnsureSpace(10) - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) -} - -func (w *Writer) Uint(n uint) { - w.Buffer.EnsureSpace(20) - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) -} - -func (w *Writer) Uint64(n uint64) { - w.Buffer.EnsureSpace(20) - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, n, 10) -} - -func (w *Writer) Int8(n int8) { - w.Buffer.EnsureSpace(4) - w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) -} - -func (w *Writer) Int16(n int16) { - 
w.Buffer.EnsureSpace(6) - w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) -} - -func (w *Writer) Int32(n int32) { - w.Buffer.EnsureSpace(11) - w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) -} - -func (w *Writer) Int(n int) { - w.Buffer.EnsureSpace(21) - w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) -} - -func (w *Writer) Int64(n int64) { - w.Buffer.EnsureSpace(21) - w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, n, 10) -} - -func (w *Writer) Uint8Str(n uint8) { - w.Buffer.EnsureSpace(3) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) Uint16Str(n uint16) { - w.Buffer.EnsureSpace(5) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) Uint32Str(n uint32) { - w.Buffer.EnsureSpace(10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) UintStr(n uint) { - w.Buffer.EnsureSpace(20) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) Uint64Str(n uint64) { - w.Buffer.EnsureSpace(20) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, n, 10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) UintptrStr(n uintptr) { - w.Buffer.EnsureSpace(20) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) Int8Str(n int8) { - w.Buffer.EnsureSpace(4) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) - w.Buffer.Buf = 
append(w.Buffer.Buf, '"') -} - -func (w *Writer) Int16Str(n int16) { - w.Buffer.EnsureSpace(6) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) Int32Str(n int32) { - w.Buffer.EnsureSpace(11) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) IntStr(n int) { - w.Buffer.EnsureSpace(21) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) Int64Str(n int64) { - w.Buffer.EnsureSpace(21) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, n, 10) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) Float32(n float32) { - w.Buffer.EnsureSpace(20) - w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 32) -} - -func (w *Writer) Float32Str(n float32) { - w.Buffer.EnsureSpace(20) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 32) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) Float64(n float64) { - w.Buffer.EnsureSpace(20) - w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, n, 'g', -1, 64) -} - -func (w *Writer) Float64Str(n float64) { - w.Buffer.EnsureSpace(20) - w.Buffer.Buf = append(w.Buffer.Buf, '"') - w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 64) - w.Buffer.Buf = append(w.Buffer.Buf, '"') -} - -func (w *Writer) Bool(v bool) { - w.Buffer.EnsureSpace(5) - if v { - w.Buffer.Buf = append(w.Buffer.Buf, "true"...) - } else { - w.Buffer.Buf = append(w.Buffer.Buf, "false"...) 
- } -} - -const chars = "0123456789abcdef" - -func getTable(falseValues ...int) [128]bool { - table := [128]bool{} - - for i := 0; i < 128; i++ { - table[i] = true - } - - for _, v := range falseValues { - table[v] = false - } - - return table -} - -var ( - htmlEscapeTable = getTable(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, '"', '&', '<', '>', '\\') - htmlNoEscapeTable = getTable(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, '"', '\\') -) - -func (w *Writer) String(s string) { - w.Buffer.AppendByte('"') - - // Portions of the string that contain no escapes are appended as - // byte slices. - - p := 0 // last non-escape symbol - - escapeTable := &htmlEscapeTable - if w.NoEscapeHTML { - escapeTable = &htmlNoEscapeTable - } - - for i := 0; i < len(s); { - c := s[i] - - if c < utf8.RuneSelf { - if escapeTable[c] { - // single-width character, no escaping is required - i++ - continue - } - - w.Buffer.AppendString(s[p:i]) - switch c { - case '\t': - w.Buffer.AppendString(`\t`) - case '\r': - w.Buffer.AppendString(`\r`) - case '\n': - w.Buffer.AppendString(`\n`) - case '\\': - w.Buffer.AppendString(`\\`) - case '"': - w.Buffer.AppendString(`\"`) - default: - w.Buffer.AppendString(`\u00`) - w.Buffer.AppendByte(chars[c>>4]) - w.Buffer.AppendByte(chars[c&0xf]) - } - - i++ - p = i - continue - } - - // broken utf - runeValue, runeWidth := utf8.DecodeRuneInString(s[i:]) - if runeValue == utf8.RuneError && runeWidth == 1 { - w.Buffer.AppendString(s[p:i]) - w.Buffer.AppendString(`\ufffd`) - i++ - p = i - continue - } - - // jsonp stuff - tab separator and line separator - if runeValue == '\u2028' || runeValue == '\u2029' { - w.Buffer.AppendString(s[p:i]) - w.Buffer.AppendString(`\u202`) - w.Buffer.AppendByte(chars[runeValue&0xf]) - i += runeWidth - p = i - continue - } - i += runeWidth - } - w.Buffer.AppendString(s[p:]) - 
w.Buffer.AppendByte('"') -} - -const encode = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" -const padChar = '=' - -func (w *Writer) base64(in []byte) { - - if len(in) == 0 { - return - } - - w.Buffer.EnsureSpace(((len(in)-1)/3 + 1) * 4) - - si := 0 - n := (len(in) / 3) * 3 - - for si < n { - // Convert 3x 8bit source bytes into 4 bytes - val := uint(in[si+0])<<16 | uint(in[si+1])<<8 | uint(in[si+2]) - - w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>18&0x3F], encode[val>>12&0x3F], encode[val>>6&0x3F], encode[val&0x3F]) - - si += 3 - } - - remain := len(in) - si - if remain == 0 { - return - } - - // Add the remaining small block - val := uint(in[si+0]) << 16 - if remain == 2 { - val |= uint(in[si+1]) << 8 - } - - w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>18&0x3F], encode[val>>12&0x3F]) - - switch remain { - case 2: - w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>6&0x3F], byte(padChar)) - case 1: - w.Buffer.Buf = append(w.Buffer.Buf, byte(padChar), byte(padChar)) - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/consts.go b/vendor/github.com/mark3labs/mcp-go/mcp/consts.go deleted file mode 100644 index 66eb3803b..000000000 --- a/vendor/github.com/mark3labs/mcp-go/mcp/consts.go +++ /dev/null @@ -1,9 +0,0 @@ -package mcp - -const ( - ContentTypeText = "text" - ContentTypeImage = "image" - ContentTypeAudio = "audio" - ContentTypeLink = "resource_link" - ContentTypeResource = "resource" -) diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/errors.go b/vendor/github.com/mark3labs/mcp-go/mcp/errors.go deleted file mode 100644 index aead24744..000000000 --- a/vendor/github.com/mark3labs/mcp-go/mcp/errors.go +++ /dev/null @@ -1,85 +0,0 @@ -package mcp - -import ( - "errors" - "fmt" -) - -// Sentinel errors for common JSON-RPC error codes. -var ( - // ErrParseError indicates a JSON parsing error (code: PARSE_ERROR). 
- ErrParseError = errors.New("parse error") - - // ErrInvalidRequest indicates an invalid JSON-RPC request (code: INVALID_REQUEST). - ErrInvalidRequest = errors.New("invalid request") - - // ErrMethodNotFound indicates the requested method does not exist (code: METHOD_NOT_FOUND). - ErrMethodNotFound = errors.New("method not found") - - // ErrInvalidParams indicates invalid method parameters (code: INVALID_PARAMS). - ErrInvalidParams = errors.New("invalid params") - - // ErrInternalError indicates an internal JSON-RPC error (code: INTERNAL_ERROR). - ErrInternalError = errors.New("internal error") - - // ErrRequestInterrupted indicates a request was cancelled or timed out (code: REQUEST_INTERRUPTED). - ErrRequestInterrupted = errors.New("request interrupted") - - // ErrResourceNotFound indicates a requested resource was not found (code: RESOURCE_NOT_FOUND). - ErrResourceNotFound = errors.New("resource not found") -) - -// UnsupportedProtocolVersionError is returned when the server responds with -// a protocol version that the client doesn't support. -type UnsupportedProtocolVersionError struct { - Version string -} - -func (e UnsupportedProtocolVersionError) Error() string { - return fmt.Sprintf("unsupported protocol version: %q", e.Version) -} - -// Is implements the errors.Is interface for better error handling -func (e UnsupportedProtocolVersionError) Is(target error) bool { - _, ok := target.(UnsupportedProtocolVersionError) - return ok -} - -// IsUnsupportedProtocolVersion checks if an error is an UnsupportedProtocolVersionError -func IsUnsupportedProtocolVersion(err error) bool { - _, ok := err.(UnsupportedProtocolVersionError) - return ok -} - -// AsError maps JSONRPCErrorDetails to a Go error. -// Returns sentinel errors wrapped with custom messages for known codes. -// Defaults to a generic error with the original message when the code is not mapped. 
-func (e *JSONRPCErrorDetails) AsError() error { - var err error - - switch e.Code { - case PARSE_ERROR: - err = ErrParseError - case INVALID_REQUEST: - err = ErrInvalidRequest - case METHOD_NOT_FOUND: - err = ErrMethodNotFound - case INVALID_PARAMS: - err = ErrInvalidParams - case INTERNAL_ERROR: - err = ErrInternalError - case REQUEST_INTERRUPTED: - err = ErrRequestInterrupted - case RESOURCE_NOT_FOUND: - err = ErrResourceNotFound - default: - return errors.New(e.Message) - } - - // Wrap the sentinel error with the custom message if it differs from the sentinel. - if e.Message != "" && e.Message != err.Error() { - return fmt.Errorf("%w: %s", err, e.Message) - } - - return err -} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go deleted file mode 100644 index 9b0b48ed2..000000000 --- a/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go +++ /dev/null @@ -1,176 +0,0 @@ -package mcp - -import "net/http" - -/* Prompts */ - -// ListPromptsRequest is sent from the client to request a list of prompts and -// prompt templates the server has. -type ListPromptsRequest struct { - PaginatedRequest - Header http.Header `json:"-"` -} - -// ListPromptsResult is the server's response to a prompts/list request from -// the client. -type ListPromptsResult struct { - PaginatedResult - Prompts []Prompt `json:"prompts"` -} - -// GetPromptRequest is used by the client to get a prompt provided by the -// server. -type GetPromptRequest struct { - Request - Params GetPromptParams `json:"params"` - Header http.Header `json:"-"` -} - -type GetPromptParams struct { - // The name of the prompt or prompt template. - Name string `json:"name"` - // Arguments to use for templating the prompt. - Arguments map[string]string `json:"arguments,omitempty"` -} - -// GetPromptResult is the server's response to a prompts/get request from the -// client. -type GetPromptResult struct { - Result - // An optional description for the prompt. 
- Description string `json:"description,omitempty"` - Messages []PromptMessage `json:"messages"` -} - -// Prompt represents a prompt or prompt template that the server offers. -// If Arguments is non-nil and non-empty, this indicates the prompt is a template -// that requires argument values to be provided when calling prompts/get. -// If Arguments is nil or empty, this is a static prompt that takes no arguments. -type Prompt struct { - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` - // The name of the prompt or prompt template. - Name string `json:"name"` - // An optional description of what this prompt provides - Description string `json:"description,omitempty"` - // A list of arguments to use for templating the prompt. - // The presence of arguments indicates this is a template prompt. - Arguments []PromptArgument `json:"arguments,omitempty"` -} - -// GetName returns the name of the prompt. -func (p Prompt) GetName() string { - return p.Name -} - -// PromptArgument describes an argument that a prompt template can accept. -// When a prompt includes arguments, clients must provide values for all -// required arguments when making a prompts/get request. -type PromptArgument struct { - // The name of the argument. - Name string `json:"name"` - // A human-readable description of the argument. - Description string `json:"description,omitempty"` - // Whether this argument must be provided. - // If true, clients must include this argument when calling prompts/get. - Required bool `json:"required,omitempty"` -} - -// Role represents the sender or recipient of messages and data in a -// conversation. -type Role string - -const ( - RoleUser Role = "user" - RoleAssistant Role = "assistant" -) - -// PromptMessage describes a message returned as part of a prompt. -// -// This is similar to `SamplingMessage`, but also supports the embedding of -// resources from the MCP server. 
-type PromptMessage struct { - Role Role `json:"role"` - Content Content `json:"content"` // Can be TextContent, ImageContent, AudioContent or EmbeddedResource -} - -// PromptListChangedNotification is an optional notification from the server -// to the client, informing it that the list of prompts it offers has changed. This -// may be issued by servers without any previous subscription from the client. -type PromptListChangedNotification struct { - Notification -} - -// PromptOption is a function that configures a Prompt. -// It provides a flexible way to set various properties of a Prompt using the functional options pattern. -type PromptOption func(*Prompt) - -// ArgumentOption is a function that configures a PromptArgument. -// It allows for flexible configuration of prompt arguments using the functional options pattern. -type ArgumentOption func(*PromptArgument) - -// -// Core Prompt Functions -// - -// NewPrompt creates a new Prompt with the given name and options. -// The prompt will be configured based on the provided options. -// Options are applied in order, allowing for flexible prompt configuration. -func NewPrompt(name string, opts ...PromptOption) Prompt { - prompt := Prompt{ - Name: name, - } - - for _, opt := range opts { - opt(&prompt) - } - - return prompt -} - -// WithPromptDescription adds a description to the Prompt. -// The description should provide a clear, human-readable explanation of what the prompt does. -func WithPromptDescription(description string) PromptOption { - return func(p *Prompt) { - p.Description = description - } -} - -// WithArgument adds an argument to the prompt's argument list. -// The argument will be configured based on the provided options. 
-func WithArgument(name string, opts ...ArgumentOption) PromptOption { - return func(p *Prompt) { - arg := PromptArgument{ - Name: name, - } - - for _, opt := range opts { - opt(&arg) - } - - if p.Arguments == nil { - p.Arguments = make([]PromptArgument, 0) - } - p.Arguments = append(p.Arguments, arg) - } -} - -// -// Argument Options -// - -// ArgumentDescription adds a description to a prompt argument. -// The description should explain the purpose and expected values of the argument. -func ArgumentDescription(desc string) ArgumentOption { - return func(arg *PromptArgument) { - arg.Description = desc - } -} - -// RequiredArgument marks an argument as required in the prompt. -// Required arguments must be provided when getting the prompt. -func RequiredArgument() ArgumentOption { - return func(arg *PromptArgument) { - arg.Required = true - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/resources.go b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go deleted file mode 100644 index 07a59a322..000000000 --- a/vendor/github.com/mark3labs/mcp-go/mcp/resources.go +++ /dev/null @@ -1,99 +0,0 @@ -package mcp - -import "github.com/yosida95/uritemplate/v3" - -// ResourceOption is a function that configures a Resource. -// It provides a flexible way to set various properties of a Resource using the functional options pattern. -type ResourceOption func(*Resource) - -// NewResource creates a new Resource with the given URI, name and options. -// The resource will be configured based on the provided options. -// Options are applied in order, allowing for flexible resource configuration. -func NewResource(uri string, name string, opts ...ResourceOption) Resource { - resource := Resource{ - URI: uri, - Name: name, - } - - for _, opt := range opts { - opt(&resource) - } - - return resource -} - -// WithResourceDescription adds a description to the Resource. -// The description should provide a clear, human-readable explanation of what the resource represents. 
-func WithResourceDescription(description string) ResourceOption { - return func(r *Resource) { - r.Description = description - } -} - -// WithMIMEType sets the MIME type for the Resource. -// This should indicate the format of the resource's contents. -func WithMIMEType(mimeType string) ResourceOption { - return func(r *Resource) { - r.MIMEType = mimeType - } -} - -// WithAnnotations adds annotations to the Resource. -// Annotations can provide additional metadata about the resource's intended use. -func WithAnnotations(audience []Role, priority float64) ResourceOption { - return func(r *Resource) { - if r.Annotations == nil { - r.Annotations = &Annotations{} - } - r.Annotations.Audience = audience - r.Annotations.Priority = priority - } -} - -// ResourceTemplateOption is a function that configures a ResourceTemplate. -// It provides a flexible way to set various properties of a ResourceTemplate using the functional options pattern. -type ResourceTemplateOption func(*ResourceTemplate) - -// NewResourceTemplate creates a new ResourceTemplate with the given URI template, name and options. -// The template will be configured based on the provided options. -// Options are applied in order, allowing for flexible template configuration. -func NewResourceTemplate(uriTemplate string, name string, opts ...ResourceTemplateOption) ResourceTemplate { - template := ResourceTemplate{ - URITemplate: &URITemplate{Template: uritemplate.MustNew(uriTemplate)}, - Name: name, - } - - for _, opt := range opts { - opt(&template) - } - - return template -} - -// WithTemplateDescription adds a description to the ResourceTemplate. -// The description should provide a clear, human-readable explanation of what resources this template represents. -func WithTemplateDescription(description string) ResourceTemplateOption { - return func(t *ResourceTemplate) { - t.Description = description - } -} - -// WithTemplateMIMEType sets the MIME type for the ResourceTemplate. 
-// This should only be set if all resources matching this template will have the same type. -func WithTemplateMIMEType(mimeType string) ResourceTemplateOption { - return func(t *ResourceTemplate) { - t.MIMEType = mimeType - } -} - -// WithTemplateAnnotations adds annotations to the ResourceTemplate. -// Annotations can provide additional metadata about the template's intended use. -func WithTemplateAnnotations(audience []Role, priority float64) ResourceTemplateOption { - return func(t *ResourceTemplate) { - if t.Annotations == nil { - t.Annotations = &Annotations{} - } - t.Annotations.Audience = audience - t.Annotations.Priority = priority - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go deleted file mode 100644 index 42e888d52..000000000 --- a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go +++ /dev/null @@ -1,1331 +0,0 @@ -package mcp - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "reflect" - "strconv" - - "github.com/invopop/jsonschema" -) - -var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSchema, not both") - -// ListToolsRequest is sent from the client to request a list of tools the -// server has. -type ListToolsRequest struct { - PaginatedRequest - Header http.Header `json:"-"` -} - -// ListToolsResult is the server's response to a tools/list request from the -// client. -type ListToolsResult struct { - PaginatedResult - Tools []Tool `json:"tools"` -} - -// CallToolResult is the server's response to a tool call. -// -// Any errors that originate from the tool SHOULD be reported inside the result -// object, with `isError` set to true, _not_ as an MCP protocol-level error -// response. Otherwise, the LLM would not be able to see that an error occurred -// and self-correct. 
-// -// However, any errors in _finding_ the tool, an error indicating that the -// server does not support tool calls, or any other exceptional conditions, -// should be reported as an MCP error response. -type CallToolResult struct { - Result - Content []Content `json:"content"` // Can be TextContent, ImageContent, AudioContent, or EmbeddedResource - // Structured content returned as a JSON object in the structuredContent field of a result. - // For backwards compatibility, a tool that returns structured content SHOULD also return - // functionally equivalent unstructured content. - StructuredContent any `json:"structuredContent,omitempty"` - // Whether the tool call ended in an error. - // - // If not set, this is assumed to be false (the call was successful). - IsError bool `json:"isError,omitempty"` -} - -// CallToolRequest is used by the client to invoke a tool provided by the server. -type CallToolRequest struct { - Request - Header http.Header `json:"-"` // HTTP headers from the original request - Params CallToolParams `json:"params"` -} - -type CallToolParams struct { - Name string `json:"name"` - Arguments any `json:"arguments,omitempty"` - Meta *Meta `json:"_meta,omitempty"` -} - -// GetArguments returns the Arguments as map[string]any for backward compatibility -// If Arguments is not a map, it returns an empty map -func (r CallToolRequest) GetArguments() map[string]any { - if args, ok := r.Params.Arguments.(map[string]any); ok { - return args - } - return nil -} - -// GetRawArguments returns the Arguments as-is without type conversion -// This allows users to access the raw arguments in any format -func (r CallToolRequest) GetRawArguments() any { - return r.Params.Arguments -} - -// BindArguments unmarshals the Arguments into the provided struct -// This is useful for working with strongly-typed arguments -func (r CallToolRequest) BindArguments(target any) error { - if target == nil || reflect.ValueOf(target).Kind() != reflect.Ptr { - return 
fmt.Errorf("target must be a non-nil pointer") - } - - // Fast-path: already raw JSON - if raw, ok := r.Params.Arguments.(json.RawMessage); ok { - return json.Unmarshal(raw, target) - } - - data, err := json.Marshal(r.Params.Arguments) - if err != nil { - return fmt.Errorf("failed to marshal arguments: %w", err) - } - - return json.Unmarshal(data, target) -} - -// GetString returns a string argument by key, or the default value if not found -func (r CallToolRequest) GetString(key string, defaultValue string) string { - args := r.GetArguments() - if val, ok := args[key]; ok { - if str, ok := val.(string); ok { - return str - } - } - return defaultValue -} - -// RequireString returns a string argument by key, or an error if not found or not a string -func (r CallToolRequest) RequireString(key string) (string, error) { - args := r.GetArguments() - if val, ok := args[key]; ok { - if str, ok := val.(string); ok { - return str, nil - } - return "", fmt.Errorf("argument %q is not a string", key) - } - return "", fmt.Errorf("required argument %q not found", key) -} - -// GetInt returns an int argument by key, or the default value if not found -func (r CallToolRequest) GetInt(key string, defaultValue int) int { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case int: - return v - case float64: - return int(v) - case string: - if i, err := strconv.Atoi(v); err == nil { - return i - } - } - } - return defaultValue -} - -// RequireInt returns an int argument by key, or an error if not found or not convertible to int -func (r CallToolRequest) RequireInt(key string) (int, error) { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case int: - return v, nil - case float64: - return int(v), nil - case string: - if i, err := strconv.Atoi(v); err == nil { - return i, nil - } - return 0, fmt.Errorf("argument %q cannot be converted to int", key) - default: - return 0, fmt.Errorf("argument %q is not an int", 
key) - } - } - return 0, fmt.Errorf("required argument %q not found", key) -} - -// GetFloat returns a float64 argument by key, or the default value if not found -func (r CallToolRequest) GetFloat(key string, defaultValue float64) float64 { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case float64: - return v - case int: - return float64(v) - case string: - if f, err := strconv.ParseFloat(v, 64); err == nil { - return f - } - } - } - return defaultValue -} - -// RequireFloat returns a float64 argument by key, or an error if not found or not convertible to float64 -func (r CallToolRequest) RequireFloat(key string) (float64, error) { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case float64: - return v, nil - case int: - return float64(v), nil - case string: - if f, err := strconv.ParseFloat(v, 64); err == nil { - return f, nil - } - return 0, fmt.Errorf("argument %q cannot be converted to float64", key) - default: - return 0, fmt.Errorf("argument %q is not a float64", key) - } - } - return 0, fmt.Errorf("required argument %q not found", key) -} - -// GetBool returns a bool argument by key, or the default value if not found -func (r CallToolRequest) GetBool(key string, defaultValue bool) bool { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case bool: - return v - case string: - if b, err := strconv.ParseBool(v); err == nil { - return b - } - case int: - return v != 0 - case float64: - return v != 0 - } - } - return defaultValue -} - -// RequireBool returns a bool argument by key, or an error if not found or not convertible to bool -func (r CallToolRequest) RequireBool(key string) (bool, error) { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case bool: - return v, nil - case string: - if b, err := strconv.ParseBool(v); err == nil { - return b, nil - } - return false, fmt.Errorf("argument %q cannot be 
converted to bool", key) - case int: - return v != 0, nil - case float64: - return v != 0, nil - default: - return false, fmt.Errorf("argument %q is not a bool", key) - } - } - return false, fmt.Errorf("required argument %q not found", key) -} - -// GetStringSlice returns a string slice argument by key, or the default value if not found -func (r CallToolRequest) GetStringSlice(key string, defaultValue []string) []string { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case []string: - return v - case []any: - result := make([]string, 0, len(v)) - for _, item := range v { - if str, ok := item.(string); ok { - result = append(result, str) - } - } - return result - } - } - return defaultValue -} - -// RequireStringSlice returns a string slice argument by key, or an error if not found or not convertible to string slice -func (r CallToolRequest) RequireStringSlice(key string) ([]string, error) { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case []string: - return v, nil - case []any: - result := make([]string, 0, len(v)) - for i, item := range v { - if str, ok := item.(string); ok { - result = append(result, str) - } else { - return nil, fmt.Errorf("item %d in argument %q is not a string", i, key) - } - } - return result, nil - default: - return nil, fmt.Errorf("argument %q is not a string slice", key) - } - } - return nil, fmt.Errorf("required argument %q not found", key) -} - -// GetIntSlice returns an int slice argument by key, or the default value if not found -func (r CallToolRequest) GetIntSlice(key string, defaultValue []int) []int { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case []int: - return v - case []any: - result := make([]int, 0, len(v)) - for _, item := range v { - switch num := item.(type) { - case int: - result = append(result, num) - case float64: - result = append(result, int(num)) - case string: - if i, err := 
strconv.Atoi(num); err == nil { - result = append(result, i) - } - } - } - return result - } - } - return defaultValue -} - -// RequireIntSlice returns an int slice argument by key, or an error if not found or not convertible to int slice -func (r CallToolRequest) RequireIntSlice(key string) ([]int, error) { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case []int: - return v, nil - case []any: - result := make([]int, 0, len(v)) - for i, item := range v { - switch num := item.(type) { - case int: - result = append(result, num) - case float64: - result = append(result, int(num)) - case string: - if i, err := strconv.Atoi(num); err == nil { - result = append(result, i) - } else { - return nil, fmt.Errorf("item %d in argument %q cannot be converted to int", i, key) - } - default: - return nil, fmt.Errorf("item %d in argument %q is not an int", i, key) - } - } - return result, nil - default: - return nil, fmt.Errorf("argument %q is not an int slice", key) - } - } - return nil, fmt.Errorf("required argument %q not found", key) -} - -// GetFloatSlice returns a float64 slice argument by key, or the default value if not found -func (r CallToolRequest) GetFloatSlice(key string, defaultValue []float64) []float64 { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case []float64: - return v - case []any: - result := make([]float64, 0, len(v)) - for _, item := range v { - switch num := item.(type) { - case float64: - result = append(result, num) - case int: - result = append(result, float64(num)) - case string: - if f, err := strconv.ParseFloat(num, 64); err == nil { - result = append(result, f) - } - } - } - return result - } - } - return defaultValue -} - -// RequireFloatSlice returns a float64 slice argument by key, or an error if not found or not convertible to float64 slice -func (r CallToolRequest) RequireFloatSlice(key string) ([]float64, error) { - args := r.GetArguments() - if val, ok := 
args[key]; ok { - switch v := val.(type) { - case []float64: - return v, nil - case []any: - result := make([]float64, 0, len(v)) - for i, item := range v { - switch num := item.(type) { - case float64: - result = append(result, num) - case int: - result = append(result, float64(num)) - case string: - if f, err := strconv.ParseFloat(num, 64); err == nil { - result = append(result, f) - } else { - return nil, fmt.Errorf("item %d in argument %q cannot be converted to float64", i, key) - } - default: - return nil, fmt.Errorf("item %d in argument %q is not a float64", i, key) - } - } - return result, nil - default: - return nil, fmt.Errorf("argument %q is not a float64 slice", key) - } - } - return nil, fmt.Errorf("required argument %q not found", key) -} - -// GetBoolSlice returns a bool slice argument by key, or the default value if not found -func (r CallToolRequest) GetBoolSlice(key string, defaultValue []bool) []bool { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case []bool: - return v - case []any: - result := make([]bool, 0, len(v)) - for _, item := range v { - switch b := item.(type) { - case bool: - result = append(result, b) - case string: - if parsed, err := strconv.ParseBool(b); err == nil { - result = append(result, parsed) - } - case int: - result = append(result, b != 0) - case float64: - result = append(result, b != 0) - } - } - return result - } - } - return defaultValue -} - -// RequireBoolSlice returns a bool slice argument by key, or an error if not found or not convertible to bool slice -func (r CallToolRequest) RequireBoolSlice(key string) ([]bool, error) { - args := r.GetArguments() - if val, ok := args[key]; ok { - switch v := val.(type) { - case []bool: - return v, nil - case []any: - result := make([]bool, 0, len(v)) - for i, item := range v { - switch b := item.(type) { - case bool: - result = append(result, b) - case string: - if parsed, err := strconv.ParseBool(b); err == nil { - result = 
append(result, parsed) - } else { - return nil, fmt.Errorf("item %d in argument %q cannot be converted to bool", i, key) - } - case int: - result = append(result, b != 0) - case float64: - result = append(result, b != 0) - default: - return nil, fmt.Errorf("item %d in argument %q is not a bool", i, key) - } - } - return result, nil - default: - return nil, fmt.Errorf("argument %q is not a bool slice", key) - } - } - return nil, fmt.Errorf("required argument %q not found", key) -} - -// MarshalJSON implements custom JSON marshaling for CallToolResult -func (r CallToolResult) MarshalJSON() ([]byte, error) { - m := make(map[string]any) - - // Marshal Meta if present - if r.Meta != nil { - m["_meta"] = r.Meta - } - - // Marshal Content array - content := make([]any, len(r.Content)) - for i, c := range r.Content { - content[i] = c - } - m["content"] = content - - // Marshal StructuredContent if present - if r.StructuredContent != nil { - m["structuredContent"] = r.StructuredContent - } - - // Marshal IsError if true - if r.IsError { - m["isError"] = r.IsError - } - - return json.Marshal(m) -} - -// UnmarshalJSON implements custom JSON unmarshaling for CallToolResult -func (r *CallToolResult) UnmarshalJSON(data []byte) error { - var raw map[string]any - if err := json.Unmarshal(data, &raw); err != nil { - return err - } - - // Unmarshal Meta - if meta, ok := raw["_meta"]; ok { - if metaMap, ok := meta.(map[string]any); ok { - r.Meta = NewMetaFromMap(metaMap) - } - } - - // Unmarshal Content array - if contentRaw, ok := raw["content"]; ok { - if contentArray, ok := contentRaw.([]any); ok { - r.Content = make([]Content, len(contentArray)) - for i, item := range contentArray { - itemBytes, err := json.Marshal(item) - if err != nil { - return err - } - content, err := UnmarshalContent(itemBytes) - if err != nil { - return err - } - r.Content[i] = content - } - } - } - - // Unmarshal StructuredContent if present - if structured, ok := raw["structuredContent"]; ok { - 
r.StructuredContent = structured - } - - // Unmarshal IsError - if isError, ok := raw["isError"]; ok { - if isErrorBool, ok := isError.(bool); ok { - r.IsError = isErrorBool - } - } - - return nil -} - -// ToolListChangedNotification is an optional notification from the server to -// the client, informing it that the list of tools it offers has changed. This may -// be issued by servers without any previous subscription from the client. -type ToolListChangedNotification struct { - Notification -} - -// Tool represents the definition for a tool the client can call. -type Tool struct { - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` - // The name of the tool. - Name string `json:"name"` - // A human-readable description of the tool. - Description string `json:"description,omitempty"` - // A JSON Schema object defining the expected parameters for the tool. - InputSchema ToolInputSchema `json:"inputSchema"` - // Alternative to InputSchema - allows arbitrary JSON Schema to be provided - RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling - // A JSON Schema object defining the expected output returned by the tool . - OutputSchema ToolOutputSchema `json:"outputSchema,omitempty"` - // Optional JSON Schema defining expected output structure - RawOutputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling - // Optional properties describing tool behavior - Annotations ToolAnnotation `json:"annotations"` -} - -// GetName returns the name of the tool. -func (t Tool) GetName() string { - return t.Name -} - -// MarshalJSON implements the json.Marshaler interface for Tool. -// It handles marshaling either InputSchema or RawInputSchema based on which is set. 
-func (t Tool) MarshalJSON() ([]byte, error) { - // Create a map to build the JSON structure - m := make(map[string]any, 5) - - // Add the name and description - m["name"] = t.Name - if t.Description != "" { - m["description"] = t.Description - } - - // Determine which input schema to use - if t.RawInputSchema != nil { - if t.InputSchema.Type != "" { - return nil, fmt.Errorf("tool %s has both InputSchema and RawInputSchema set: %w", t.Name, errToolSchemaConflict) - } - m["inputSchema"] = t.RawInputSchema - } else { - // Use the structured InputSchema - m["inputSchema"] = t.InputSchema - } - - // Add output schema if present - if t.RawOutputSchema != nil { - if t.OutputSchema.Type != "" { - return nil, fmt.Errorf("tool %s has both OutputSchema and RawOutputSchema set: %w", t.Name, errToolSchemaConflict) - } - m["outputSchema"] = t.RawOutputSchema - } else if t.OutputSchema.Type != "" { // If no output schema is specified, do not return anything - m["outputSchema"] = t.OutputSchema - } - - m["annotations"] = t.Annotations - - // Marshal Meta if present - if t.Meta != nil { - m["_meta"] = t.Meta - } - - return json.Marshal(m) -} - -// ToolArgumentsSchema represents a JSON Schema for tool arguments. -type ToolArgumentsSchema struct { - Defs map[string]any `json:"$defs,omitempty"` - Type string `json:"type"` - Properties map[string]any `json:"properties,omitempty"` - Required []string `json:"required,omitempty"` -} - -type ToolInputSchema ToolArgumentsSchema // For retro-compatibility -type ToolOutputSchema ToolArgumentsSchema - -// MarshalJSON implements the json.Marshaler interface for ToolInputSchema. 
-func (tis ToolArgumentsSchema) MarshalJSON() ([]byte, error) { - m := make(map[string]any) - m["type"] = tis.Type - - if tis.Defs != nil { - m["$defs"] = tis.Defs - } - - // Marshal Properties to '{}' rather than `nil` when its length equals zero - if tis.Properties != nil { - m["properties"] = tis.Properties - } - - if len(tis.Required) > 0 { - m["required"] = tis.Required - } - - return json.Marshal(m) -} - -// UnmarshalJSON implements the json.Unmarshaler interface for ToolArgumentsSchema. -// It handles both "$defs" (JSON Schema 2019-09+) and "definitions" (JSON Schema draft-07) -// by reading either field and storing it in the Defs field. -func (tis *ToolArgumentsSchema) UnmarshalJSON(data []byte) error { - // Use a temporary type to avoid infinite recursion - type Alias ToolArgumentsSchema - aux := &struct { - Definitions map[string]any `json:"definitions,omitempty"` - *Alias - }{ - Alias: (*Alias)(tis), - } - - if err := json.Unmarshal(data, aux); err != nil { - return err - } - - // If $defs wasn't provided but definitions was, use definitions - if tis.Defs == nil && aux.Definitions != nil { - tis.Defs = aux.Definitions - } - - return nil -} - -type ToolAnnotation struct { - // Human-readable title for the tool - Title string `json:"title,omitempty"` - // If true, the tool does not modify its environment - ReadOnlyHint *bool `json:"readOnlyHint,omitempty"` - // If true, the tool may perform destructive updates - DestructiveHint *bool `json:"destructiveHint,omitempty"` - // If true, repeated calls with same args have no additional effect - IdempotentHint *bool `json:"idempotentHint,omitempty"` - // If true, tool interacts with external entities - OpenWorldHint *bool `json:"openWorldHint,omitempty"` -} - -// ToolOption is a function that configures a Tool. -// It provides a flexible way to set various properties of a Tool using the functional options pattern. 
-type ToolOption func(*Tool) - -// PropertyOption is a function that configures a property in a Tool's input schema. -// It allows for flexible configuration of JSON Schema properties using the functional options pattern. -type PropertyOption func(map[string]any) - -// -// Core Tool Functions -// - -// NewTool creates a new Tool with the given name and options. -// The tool will have an object-type input schema with configurable properties. -// Options are applied in order, allowing for flexible tool configuration. -func NewTool(name string, opts ...ToolOption) Tool { - tool := Tool{ - Name: name, - InputSchema: ToolInputSchema{ - Type: "object", - Properties: make(map[string]any), - Required: nil, // Will be omitted from JSON if empty - }, - Annotations: ToolAnnotation{ - Title: "", - ReadOnlyHint: ToBoolPtr(false), - DestructiveHint: ToBoolPtr(true), - IdempotentHint: ToBoolPtr(false), - OpenWorldHint: ToBoolPtr(true), - }, - } - - for _, opt := range opts { - opt(&tool) - } - - return tool -} - -// NewToolWithRawSchema creates a new Tool with the given name and a raw JSON -// Schema. This allows for arbitrary JSON Schema to be used for the tool's input -// schema. -// -// NOTE a [Tool] built in such a way is incompatible with the [ToolOption] and -// runtime errors will result from supplying a [ToolOption] to a [Tool] built -// with this function. -func NewToolWithRawSchema(name, description string, schema json.RawMessage) Tool { - tool := Tool{ - Name: name, - Description: description, - RawInputSchema: schema, - } - - return tool -} - -// WithDescription adds a description to the Tool. -// The description should provide a clear, human-readable explanation of what the tool does. -func WithDescription(description string) ToolOption { - return func(t *Tool) { - t.Description = description - } -} - -// WithInputSchema creates a ToolOption that sets the input schema for a tool. 
-// It accepts any Go type, usually a struct, and automatically generates a JSON schema from it. -func WithInputSchema[T any]() ToolOption { - return func(t *Tool) { - var zero T - - // Generate schema using invopop/jsonschema library - // Configure reflector to generate clean, MCP-compatible schemas - reflector := jsonschema.Reflector{ - DoNotReference: true, // Removes $defs map, outputs entire structure inline - Anonymous: true, // Hides auto-generated Schema IDs - AllowAdditionalProperties: true, // Removes additionalProperties: false - } - schema := reflector.Reflect(zero) - - // Clean up schema for MCP compliance - schema.Version = "" // Remove $schema field - - // Convert to raw JSON for MCP - mcpSchema, err := json.Marshal(schema) - if err != nil { - // Skip and maintain backward compatibility - return - } - - t.InputSchema.Type = "" - t.RawInputSchema = json.RawMessage(mcpSchema) - } -} - -// WithRawInputSchema sets a raw JSON schema for the tool's input. -// Use this when you need full control over the schema or when working with -// complex schemas that can't be generated from Go types. The jsonschema library -// can handle complex schemas and provides nice extension points, so be sure to -// check that out before using this. -func WithRawInputSchema(schema json.RawMessage) ToolOption { - return func(t *Tool) { - t.RawInputSchema = schema - } -} - -// WithOutputSchema creates a ToolOption that sets the output schema for a tool. -// It accepts any Go type, usually a struct, and automatically generates a JSON schema from it. 
-func WithOutputSchema[T any]() ToolOption { - return func(t *Tool) { - var zero T - - // Generate schema using invopop/jsonschema library - // Configure reflector to generate clean, MCP-compatible schemas - reflector := jsonschema.Reflector{ - DoNotReference: true, // Removes $defs map, outputs entire structure inline - Anonymous: true, // Hides auto-generated Schema IDs - AllowAdditionalProperties: true, // Removes additionalProperties: false - } - schema := reflector.Reflect(zero) - - // Clean up schema for MCP compliance - schema.Version = "" // Remove $schema field - - // Convert to raw JSON for MCP - mcpSchema, err := json.Marshal(schema) - if err != nil { - // Skip and maintain backward compatibility - return - } - - // Retrieve the schema from raw JSON - if err := json.Unmarshal(mcpSchema, &t.OutputSchema); err != nil { - // Skip and maintain backward compatibility - return - } - - // Always set the type to "object" as of the current MCP spec - // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema - t.OutputSchema.Type = "object" - } -} - -// WithRawOutputSchema sets a raw JSON schema for the tool's output. -// Use this when you need full control over the schema or when working with -// complex schemas that can't be generated from Go types. The jsonschema library -// can handle complex schemas and provides nice extension points, so be sure to -// check that out before using this. -func WithRawOutputSchema(schema json.RawMessage) ToolOption { - return func(t *Tool) { - t.RawOutputSchema = schema - } -} - -// WithToolAnnotation adds optional hints about the Tool. -func WithToolAnnotation(annotation ToolAnnotation) ToolOption { - return func(t *Tool) { - t.Annotations = annotation - } -} - -// WithTitleAnnotation sets the Title field of the Tool's Annotations. -// It provides a human-readable title for the tool. 
-func WithTitleAnnotation(title string) ToolOption { - return func(t *Tool) { - t.Annotations.Title = title - } -} - -// WithReadOnlyHintAnnotation sets the ReadOnlyHint field of the Tool's Annotations. -// If true, it indicates the tool does not modify its environment. -func WithReadOnlyHintAnnotation(value bool) ToolOption { - return func(t *Tool) { - t.Annotations.ReadOnlyHint = &value - } -} - -// WithDestructiveHintAnnotation sets the DestructiveHint field of the Tool's Annotations. -// If true, it indicates the tool may perform destructive updates. -func WithDestructiveHintAnnotation(value bool) ToolOption { - return func(t *Tool) { - t.Annotations.DestructiveHint = &value - } -} - -// WithIdempotentHintAnnotation sets the IdempotentHint field of the Tool's Annotations. -// If true, it indicates repeated calls with the same arguments have no additional effect. -func WithIdempotentHintAnnotation(value bool) ToolOption { - return func(t *Tool) { - t.Annotations.IdempotentHint = &value - } -} - -// WithOpenWorldHintAnnotation sets the OpenWorldHint field of the Tool's Annotations. -// If true, it indicates the tool interacts with external entities. -func WithOpenWorldHintAnnotation(value bool) ToolOption { - return func(t *Tool) { - t.Annotations.OpenWorldHint = &value - } -} - -// -// Common Property Options -// - -// Description adds a description to a property in the JSON Schema. -// The description should explain the purpose and expected values of the property. -func Description(desc string) PropertyOption { - return func(schema map[string]any) { - schema["description"] = desc - } -} - -// Required marks a property as required in the tool's input schema. -// Required properties must be provided when using the tool. -func Required() PropertyOption { - return func(schema map[string]any) { - schema["required"] = true - } -} - -// Title adds a display-friendly title to a property in the JSON Schema. 
-// This title can be used by UI components to show a more readable property name. -func Title(title string) PropertyOption { - return func(schema map[string]any) { - schema["title"] = title - } -} - -// -// String Property Options -// - -// DefaultString sets the default value for a string property. -// This value will be used if the property is not explicitly provided. -func DefaultString(value string) PropertyOption { - return func(schema map[string]any) { - schema["default"] = value - } -} - -// Enum specifies a list of allowed values for a string property. -// The property value must be one of the specified enum values. -func Enum(values ...string) PropertyOption { - return func(schema map[string]any) { - schema["enum"] = values - } -} - -// MaxLength sets the maximum length for a string property. -// The string value must not exceed this length. -func MaxLength(max int) PropertyOption { - return func(schema map[string]any) { - schema["maxLength"] = max - } -} - -// MinLength sets the minimum length for a string property. -// The string value must be at least this length. -func MinLength(min int) PropertyOption { - return func(schema map[string]any) { - schema["minLength"] = min - } -} - -// Pattern sets a regex pattern that a string property must match. -// The string value must conform to the specified regular expression. -func Pattern(pattern string) PropertyOption { - return func(schema map[string]any) { - schema["pattern"] = pattern - } -} - -// -// Number Property Options -// - -// DefaultNumber sets the default value for a number property. -// This value will be used if the property is not explicitly provided. -func DefaultNumber(value float64) PropertyOption { - return func(schema map[string]any) { - schema["default"] = value - } -} - -// Max sets the maximum value for a number property. -// The number value must not exceed this maximum. 
-func Max(max float64) PropertyOption { - return func(schema map[string]any) { - schema["maximum"] = max - } -} - -// Min sets the minimum value for a number property. -// The number value must not be less than this minimum. -func Min(min float64) PropertyOption { - return func(schema map[string]any) { - schema["minimum"] = min - } -} - -// MultipleOf specifies that a number must be a multiple of the given value. -// The number value must be divisible by this value. -func MultipleOf(value float64) PropertyOption { - return func(schema map[string]any) { - schema["multipleOf"] = value - } -} - -// -// Boolean Property Options -// - -// DefaultBool sets the default value for a boolean property. -// This value will be used if the property is not explicitly provided. -func DefaultBool(value bool) PropertyOption { - return func(schema map[string]any) { - schema["default"] = value - } -} - -// -// Array Property Options -// - -// DefaultArray sets the default value for an array property. -// This value will be used if the property is not explicitly provided. -func DefaultArray[T any](value []T) PropertyOption { - return func(schema map[string]any) { - schema["default"] = value - } -} - -// -// Property Type Helpers -// - -// WithBoolean adds a boolean property to the tool schema. -// It accepts property options to configure the boolean property's behavior and constraints. -func WithBoolean(name string, opts ...PropertyOption) ToolOption { - return func(t *Tool) { - schema := map[string]any{ - "type": "boolean", - } - - for _, opt := range opts { - opt(schema) - } - - // Remove required from property schema and add to InputSchema.required - if required, ok := schema["required"].(bool); ok && required { - delete(schema, "required") - t.InputSchema.Required = append(t.InputSchema.Required, name) - } - - t.InputSchema.Properties[name] = schema - } -} - -// WithNumber adds a number property to the tool schema. 
-// It accepts property options to configure the number property's behavior and constraints. -func WithNumber(name string, opts ...PropertyOption) ToolOption { - return func(t *Tool) { - schema := map[string]any{ - "type": "number", - } - - for _, opt := range opts { - opt(schema) - } - - // Remove required from property schema and add to InputSchema.required - if required, ok := schema["required"].(bool); ok && required { - delete(schema, "required") - t.InputSchema.Required = append(t.InputSchema.Required, name) - } - - t.InputSchema.Properties[name] = schema - } -} - -// WithString adds a string property to the tool schema. -// It accepts property options to configure the string property's behavior and constraints. -func WithString(name string, opts ...PropertyOption) ToolOption { - return func(t *Tool) { - schema := map[string]any{ - "type": "string", - } - - for _, opt := range opts { - opt(schema) - } - - // Remove required from property schema and add to InputSchema.required - if required, ok := schema["required"].(bool); ok && required { - delete(schema, "required") - t.InputSchema.Required = append(t.InputSchema.Required, name) - } - - t.InputSchema.Properties[name] = schema - } -} - -// WithObject adds an object property to the tool schema. -// It accepts property options to configure the object property's behavior and constraints. -func WithObject(name string, opts ...PropertyOption) ToolOption { - return func(t *Tool) { - schema := map[string]any{ - "type": "object", - "properties": map[string]any{}, - } - - for _, opt := range opts { - opt(schema) - } - - // Remove required from property schema and add to InputSchema.required - if required, ok := schema["required"].(bool); ok && required { - delete(schema, "required") - t.InputSchema.Required = append(t.InputSchema.Required, name) - } - - t.InputSchema.Properties[name] = schema - } -} - -// WithArray returns a ToolOption that adds an array-typed property with the given name to a Tool's input schema. 
-// It applies provided PropertyOption functions to configure the property's schema, moves a `required` flag -// from the property schema into the Tool's InputSchema.Required slice when present, and registers the resulting -// schema under InputSchema.Properties[name]. -func WithArray(name string, opts ...PropertyOption) ToolOption { - return func(t *Tool) { - schema := map[string]any{ - "type": "array", - } - - for _, opt := range opts { - opt(schema) - } - - // Remove required from property schema and add to InputSchema.required - if required, ok := schema["required"].(bool); ok && required { - delete(schema, "required") - t.InputSchema.Required = append(t.InputSchema.Required, name) - } - - t.InputSchema.Properties[name] = schema - } -} - -// WithAny adds an input property named name with no predefined JSON Schema type to the Tool's input schema. -// The returned ToolOption applies the provided PropertyOption functions to the property's schema, moves a property-level -// `required` flag into the Tool's InputSchema.Required list if present, and stores the resulting schema under InputSchema.Properties[name]. -func WithAny(name string, opts ...PropertyOption) ToolOption { - return func(t *Tool) { - schema := map[string]any{} - - for _, opt := range opts { - opt(schema) - } - - // Remove required from property schema and add to InputSchema.required - if required, ok := schema["required"].(bool); ok && required { - delete(schema, "required") - t.InputSchema.Required = append(t.InputSchema.Required, name) - } - - t.InputSchema.Properties[name] = schema - } -} - -// Properties sets the "properties" map for an object schema. -// The returned PropertyOption stores the provided map under the schema's "properties" key. 
-func Properties(props map[string]any) PropertyOption { - return func(schema map[string]any) { - schema["properties"] = props - } -} - -// AdditionalProperties specifies whether additional properties are allowed in the object -// or defines a schema for additional properties -func AdditionalProperties(schema any) PropertyOption { - return func(schemaMap map[string]any) { - schemaMap["additionalProperties"] = schema - } -} - -// MinProperties sets the minimum number of properties for an object -func MinProperties(min int) PropertyOption { - return func(schema map[string]any) { - schema["minProperties"] = min - } -} - -// MaxProperties sets the maximum number of properties for an object -func MaxProperties(max int) PropertyOption { - return func(schema map[string]any) { - schema["maxProperties"] = max - } -} - -// PropertyNames defines a schema for property names in an object -func PropertyNames(schema map[string]any) PropertyOption { - return func(schemaMap map[string]any) { - schemaMap["propertyNames"] = schema - } -} - -// Items defines the schema for array items. -// Accepts any schema definition for maximum flexibility. -// -// Example: -// -// Items(map[string]any{ -// "type": "object", -// "properties": map[string]any{ -// "name": map[string]any{"type": "string"}, -// "age": map[string]any{"type": "number"}, -// }, -// }) -// -// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead. 
-func Items(schema any) PropertyOption { - return func(schemaMap map[string]any) { - schemaMap["items"] = schema - } -} - -// MinItems sets the minimum number of items for an array -func MinItems(min int) PropertyOption { - return func(schema map[string]any) { - schema["minItems"] = min - } -} - -// MaxItems sets the maximum number of items for an array -func MaxItems(max int) PropertyOption { - return func(schema map[string]any) { - schema["maxItems"] = max - } -} - -// UniqueItems specifies whether array items must be unique -func UniqueItems(unique bool) PropertyOption { - return func(schema map[string]any) { - schema["uniqueItems"] = unique - } -} - -// WithStringItems configures an array's items to be of type string. -// -// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern() -// Note: Options like Required() are not valid for item schemas and will be ignored. -// -// Examples: -// -// mcp.WithArray("tags", mcp.WithStringItems()) -// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue"))) -// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50))) -// -// Limitations: Only supports simple string arrays. Use Items() for complex objects. -func WithStringItems(opts ...PropertyOption) PropertyOption { - return func(schema map[string]any) { - itemSchema := map[string]any{ - "type": "string", - } - - for _, opt := range opts { - opt(itemSchema) - } - - schema["items"] = itemSchema - } -} - -// WithStringEnumItems configures an array's items to be of type string with a specified enum. -// Example: -// -// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"})) -// -// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility. 
-func WithStringEnumItems(values []string) PropertyOption { - return func(schema map[string]any) { - schema["items"] = map[string]any{ - "type": "string", - "enum": values, - } - } -} - -// WithNumberItems configures an array's items to be of type number. -// -// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf() -// Note: Options like Required() are not valid for item schemas and will be ignored. -// -// Examples: -// -// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100))) -// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0))) -// -// Limitations: Only supports simple number arrays. Use Items() for complex objects. -func WithNumberItems(opts ...PropertyOption) PropertyOption { - return func(schema map[string]any) { - itemSchema := map[string]any{ - "type": "number", - } - - for _, opt := range opts { - opt(itemSchema) - } - - schema["items"] = itemSchema - } -} - -// WithBooleanItems configures an array's items to be of type boolean. -// -// Supported options: Description(), DefaultBool() -// Note: Options like Required() are not valid for item schemas and will be ignored. -// -// Examples: -// -// mcp.WithArray("flags", mcp.WithBooleanItems()) -// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions"))) -// -// Limitations: Only supports simple boolean arrays. Use Items() for complex objects. 
-func WithBooleanItems(opts ...PropertyOption) PropertyOption { - return func(schema map[string]any) { - itemSchema := map[string]any{ - "type": "boolean", - } - - for _, opt := range opts { - opt(itemSchema) - } - - schema["items"] = itemSchema - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go deleted file mode 100644 index a03a19dd7..000000000 --- a/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go +++ /dev/null @@ -1,42 +0,0 @@ -package mcp - -import ( - "context" - "fmt" -) - -// TypedToolHandlerFunc is a function that handles a tool call with typed arguments -type TypedToolHandlerFunc[T any] func(ctx context.Context, request CallToolRequest, args T) (*CallToolResult, error) - -// StructuredToolHandlerFunc is a function that handles a tool call with typed arguments and returns structured output -type StructuredToolHandlerFunc[TArgs any, TResult any] func(ctx context.Context, request CallToolRequest, args TArgs) (TResult, error) - -// NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct -func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { - return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { - var args T - if err := request.BindArguments(&args); err != nil { - return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil - } - return handler(ctx, request, args) - } -} - -// NewStructuredToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct -// and returns structured output. It automatically creates both structured and -// text content (from the structured output) for backwards compatibility. 
-func NewStructuredToolHandler[TArgs any, TResult any](handler StructuredToolHandlerFunc[TArgs, TResult]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { - return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { - var args TArgs - if err := request.BindArguments(&args); err != nil { - return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil - } - - result, err := handler(ctx, request, args) - if err != nil { - return NewToolResultError(fmt.Sprintf("tool execution failed: %v", err)), nil - } - - return NewToolResultStructuredOnly(result), nil - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/types.go b/vendor/github.com/mark3labs/mcp-go/mcp/types.go deleted file mode 100644 index 6e447c61c..000000000 --- a/vendor/github.com/mark3labs/mcp-go/mcp/types.go +++ /dev/null @@ -1,1252 +0,0 @@ -// Package mcp defines the core types and interfaces for the Model Context Protocol (MCP). -// MCP is a protocol for communication between LLM-powered applications and their supporting services. -package mcp - -import ( - "encoding/json" - "fmt" - "maps" - "net/http" - "strconv" - - "github.com/yosida95/uritemplate/v3" -) - -type MCPMethod string - -const ( - // MethodInitialize initiates connection and negotiates protocol capabilities. - // https://modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization - MethodInitialize MCPMethod = "initialize" - - // MethodPing verifies connection liveness between client and server. - // https://modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ - MethodPing MCPMethod = "ping" - - // MethodResourcesList lists all available server resources. - // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ - MethodResourcesList MCPMethod = "resources/list" - - // MethodResourcesTemplatesList provides URI templates for constructing resource URIs. 
- // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ - MethodResourcesTemplatesList MCPMethod = "resources/templates/list" - - // MethodResourcesRead retrieves content of a specific resource by URI. - // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ - MethodResourcesRead MCPMethod = "resources/read" - - // MethodPromptsList lists all available prompt templates. - // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ - MethodPromptsList MCPMethod = "prompts/list" - - // MethodPromptsGet retrieves a specific prompt template with filled parameters. - // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ - MethodPromptsGet MCPMethod = "prompts/get" - - // MethodToolsList lists all available executable tools. - // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ - MethodToolsList MCPMethod = "tools/list" - - // MethodToolsCall invokes a specific tool with provided parameters. - // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ - MethodToolsCall MCPMethod = "tools/call" - - // MethodSetLogLevel configures the minimum log level for client - // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging - MethodSetLogLevel MCPMethod = "logging/setLevel" - - // MethodElicitationCreate requests additional information from the user during interactions. - // https://modelcontextprotocol.io/docs/concepts/elicitation - MethodElicitationCreate MCPMethod = "elicitation/create" - - // MethodListRoots requests roots list from the client during interactions. - // https://modelcontextprotocol.io/specification/2025-06-18/client/roots - MethodListRoots MCPMethod = "roots/list" - - // MethodNotificationResourcesListChanged notifies when the list of available resources changes. 
- // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification - MethodNotificationResourcesListChanged = "notifications/resources/list_changed" - - MethodNotificationResourceUpdated = "notifications/resources/updated" - - // MethodNotificationPromptsListChanged notifies when the list of available prompt templates changes. - // https://modelcontextprotocol.io/specification/2025-03-26/server/prompts#list-changed-notification - MethodNotificationPromptsListChanged = "notifications/prompts/list_changed" - - // MethodNotificationToolsListChanged notifies when the list of available tools changes. - // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#list-changed-notification - MethodNotificationToolsListChanged = "notifications/tools/list_changed" - - // MethodNotificationRootsListChanged notifies when the list of available roots changes. - // https://modelcontextprotocol.io/specification/2025-06-18/client/roots#root-list-changes - MethodNotificationRootsListChanged = "notifications/roots/list_changed" -) - -type URITemplate struct { - *uritemplate.Template -} - -func (t *URITemplate) MarshalJSON() ([]byte, error) { - return json.Marshal(t.Raw()) -} - -func (t *URITemplate) UnmarshalJSON(data []byte) error { - var raw string - if err := json.Unmarshal(data, &raw); err != nil { - return err - } - template, err := uritemplate.New(raw) - if err != nil { - return err - } - t.Template = template - return nil -} - -/* JSON-RPC types */ - -// JSONRPCMessage represents either a JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, or JSONRPCError -type JSONRPCMessage any - -// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol. -const LATEST_PROTOCOL_VERSION = "2025-06-18" - -// ValidProtocolVersions lists all known valid MCP protocol versions. 
-var ValidProtocolVersions = []string{ - LATEST_PROTOCOL_VERSION, - "2025-03-26", - "2024-11-05", -} - -// JSONRPC_VERSION is the version of JSON-RPC used by MCP. -const JSONRPC_VERSION = "2.0" - -// ProgressToken is used to associate progress notifications with the original request. -type ProgressToken any - -// Cursor is an opaque token used to represent a cursor for pagination. -type Cursor string - -// Meta is metadata attached to a request's parameters. This can include fields -// formally defined by the protocol or other arbitrary data. -type Meta struct { - // If specified, the caller is requesting out-of-band progress - // notifications for this request (as represented by - // notifications/progress). The value of this parameter is an - // opaque token that will be attached to any subsequent - // notifications. The receiver is not obligated to provide these - // notifications. - ProgressToken ProgressToken - - // AdditionalFields are any fields present in the Meta that are not - // otherwise defined in the protocol. 
- AdditionalFields map[string]any -} - -func (m *Meta) MarshalJSON() ([]byte, error) { - raw := make(map[string]any) - if m.ProgressToken != nil { - raw["progressToken"] = m.ProgressToken - } - maps.Copy(raw, m.AdditionalFields) - - return json.Marshal(raw) -} - -func (m *Meta) UnmarshalJSON(data []byte) error { - raw := make(map[string]any) - if err := json.Unmarshal(data, &raw); err != nil { - return err - } - m.ProgressToken = raw["progressToken"] - delete(raw, "progressToken") - m.AdditionalFields = raw - return nil -} - -func NewMetaFromMap(m map[string]any) *Meta { - progressToken := m["progressToken"] - if progressToken != nil { - delete(m, "progressToken") - } - - return &Meta{ - ProgressToken: progressToken, - AdditionalFields: m, - } -} - -type Request struct { - Method string `json:"method"` - Params RequestParams `json:"params,omitempty"` -} - -type RequestParams struct { - Meta *Meta `json:"_meta,omitempty"` -} - -type Params map[string]any - -type Notification struct { - Method string `json:"method"` - Params NotificationParams `json:"params,omitempty"` -} - -type NotificationParams struct { - // This parameter name is reserved by MCP to allow clients and - // servers to attach additional metadata to their notifications. 
- Meta map[string]any `json:"_meta,omitempty"` - - // Additional fields can be added to this map - AdditionalFields map[string]any `json:"-"` -} - -// MarshalJSON implements custom JSON marshaling -func (p NotificationParams) MarshalJSON() ([]byte, error) { - // Create a map to hold all fields - m := make(map[string]any) - - // Add Meta if it exists - if p.Meta != nil { - m["_meta"] = p.Meta - } - - // Add all additional fields - for k, v := range p.AdditionalFields { - // Ensure we don't override the _meta field - if k != "_meta" { - m[k] = v - } - } - - return json.Marshal(m) -} - -// UnmarshalJSON implements custom JSON unmarshaling -func (p *NotificationParams) UnmarshalJSON(data []byte) error { - // Create a map to hold all fields - var m map[string]any - if err := json.Unmarshal(data, &m); err != nil { - return err - } - - // Initialize maps if they're nil - if p.Meta == nil { - p.Meta = make(map[string]any) - } - if p.AdditionalFields == nil { - p.AdditionalFields = make(map[string]any) - } - - // Process all fields - for k, v := range m { - if k == "_meta" { - // Handle Meta field - if meta, ok := v.(map[string]any); ok { - p.Meta = meta - } - } else { - // Handle additional fields - p.AdditionalFields[k] = v - } - } - - return nil -} - -type Result struct { - // This result property is reserved by the protocol to allow clients and - // servers to attach additional metadata to their responses. - Meta *Meta `json:"_meta,omitempty"` -} - -// RequestId is a uniquely identifying ID for a request in JSON-RPC. -// It can be any JSON-serializable value, typically a number or string. 
-type RequestId struct { - value any -} - -// NewRequestId creates a new RequestId with the given value -func NewRequestId(value any) RequestId { - return RequestId{value: value} -} - -// Value returns the underlying value of the RequestId -func (r RequestId) Value() any { - return r.value -} - -// String returns a string representation of the RequestId -func (r RequestId) String() string { - switch v := r.value.(type) { - case string: - return "string:" + v - case int64: - return "int64:" + strconv.FormatInt(v, 10) - case float64: - if v == float64(int64(v)) { - return "int64:" + strconv.FormatInt(int64(v), 10) - } - return "float64:" + strconv.FormatFloat(v, 'f', -1, 64) - case nil: - return "" - default: - return "unknown:" + fmt.Sprintf("%v", v) - } -} - -// IsNil returns true if the RequestId is nil -func (r RequestId) IsNil() bool { - return r.value == nil -} - -func (r RequestId) MarshalJSON() ([]byte, error) { - return json.Marshal(r.value) -} - -func (r *RequestId) UnmarshalJSON(data []byte) error { - if string(data) == "null" { - r.value = nil - return nil - } - - // Try unmarshaling as string first - var s string - if err := json.Unmarshal(data, &s); err == nil { - r.value = s - return nil - } - - // JSON numbers are unmarshaled as float64 in Go - var f float64 - if err := json.Unmarshal(data, &f); err == nil { - if f == float64(int64(f)) { - r.value = int64(f) - } else { - r.value = f - } - return nil - } - - return fmt.Errorf("invalid request id: %s", string(data)) -} - -// JSONRPCRequest represents a request that expects a response. -type JSONRPCRequest struct { - JSONRPC string `json:"jsonrpc"` - ID RequestId `json:"id"` - Params any `json:"params,omitempty"` - Request -} - -// JSONRPCNotification represents a notification which does not expect a response. -type JSONRPCNotification struct { - JSONRPC string `json:"jsonrpc"` - Notification -} - -// JSONRPCResponse represents a successful (non-error) response to a request. 
-type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - ID RequestId `json:"id"` - Result any `json:"result"` -} - -// JSONRPCError represents a non-successful (error) response to a request. -type JSONRPCError struct { - JSONRPC string `json:"jsonrpc"` - ID RequestId `json:"id"` - Error JSONRPCErrorDetails `json:"error"` -} - -// JSONRPCErrorDetails represents a JSON-RPC error for Go error handling. -// This is separate from the JSONRPCError type which represents the full JSON-RPC error response structure. -type JSONRPCErrorDetails struct { - // The error type that occurred. - Code int `json:"code"` - // A short description of the error. The message SHOULD be limited - // to a concise single sentence. - Message string `json:"message"` - // Additional information about the error. The value of this member - // is defined by the sender (e.g. detailed error information, nested errors etc.). - Data any `json:"data,omitempty"` -} - -// Standard JSON-RPC error codes -const ( - // PARSE_ERROR indicates invalid JSON was received by the server. - PARSE_ERROR = -32700 - - // INVALID_REQUEST indicates the JSON sent is not a valid Request object. - INVALID_REQUEST = -32600 - - // METHOD_NOT_FOUND indicates the method does not exist/is not available. - METHOD_NOT_FOUND = -32601 - - // INVALID_PARAMS indicates invalid method parameter(s). - INVALID_PARAMS = -32602 - - // INTERNAL_ERROR indicates internal JSON-RPC error. - INTERNAL_ERROR = -32603 - - // REQUEST_INTERRUPTED indicates a request was cancelled or timed out. - REQUEST_INTERRUPTED = -32800 -) - -// MCP error codes -const ( - // RESOURCE_NOT_FOUND indicates a requested resource was not found. - RESOURCE_NOT_FOUND = -32002 -) - -/* Empty result */ - -// EmptyResult represents a response that indicates success but carries no data. -type EmptyResult Result - -/* Cancellation */ - -// CancelledNotification can be sent by either side to indicate that it is -// cancelling a previously-issued request. 
-// -// The request SHOULD still be in-flight, but due to communication latency, it -// is always possible that this notification MAY arrive after the request has -// already finished. -// -// This notification indicates that the result will be unused, so any -// associated processing SHOULD cease. -// -// A client MUST NOT attempt to cancel its `initialize` request. -type CancelledNotification struct { - Notification - Params CancelledNotificationParams `json:"params"` -} - -type CancelledNotificationParams struct { - // The ID of the request to cancel. - // - // This MUST correspond to the ID of a request previously issued - // in the same direction. - RequestId RequestId `json:"requestId"` - - // An optional string describing the reason for the cancellation. This MAY - // be logged or presented to the user. - Reason string `json:"reason,omitempty"` -} - -/* Initialization */ - -// InitializeRequest is sent from the client to the server when it first -// connects, asking it to begin initialization. -type InitializeRequest struct { - Request - Params InitializeParams `json:"params"` - Header http.Header `json:"-"` -} - -type InitializeParams struct { - // The latest version of the Model Context Protocol that the client supports. - // The client MAY decide to support older versions as well. - ProtocolVersion string `json:"protocolVersion"` - Capabilities ClientCapabilities `json:"capabilities"` - ClientInfo Implementation `json:"clientInfo"` -} - -// InitializeResult is sent after receiving an initialize request from the -// client. -type InitializeResult struct { - Result - // The version of the Model Context Protocol that the server wants to use. - // This may not match the version that the client requested. If the client cannot - // support this version, it MUST disconnect. 
- ProtocolVersion string `json:"protocolVersion"` - Capabilities ServerCapabilities `json:"capabilities"` - ServerInfo Implementation `json:"serverInfo"` - // Instructions describing how to use the server and its features. - // - // This can be used by clients to improve the LLM's understanding of - // available tools, resources, etc. It can be thought of like a "hint" to the model. - // For example, this information MAY be added to the system prompt. - Instructions string `json:"instructions,omitempty"` -} - -// InitializedNotification is sent from the client to the server after -// initialization has finished. -type InitializedNotification struct { - Notification -} - -// ClientCapabilities represents capabilities a client may support. Known -// capabilities are defined here, in this schema, but this is not a closed set: any -// client can define its own, additional capabilities. -type ClientCapabilities struct { - // Experimental, non-standard capabilities that the client supports. - Experimental map[string]any `json:"experimental,omitempty"` - // Present if the client supports listing roots. - Roots *struct { - // Whether the client supports notifications for changes to the roots list. - ListChanged bool `json:"listChanged,omitempty"` - } `json:"roots,omitempty"` - // Present if the client supports sampling from an LLM. - Sampling *struct{} `json:"sampling,omitempty"` - // Present if the client supports elicitation requests from the server. - Elicitation *struct{} `json:"elicitation,omitempty"` -} - -// ServerCapabilities represents capabilities that a server may support. Known -// capabilities are defined here, in this schema, but this is not a closed set: any -// server can define its own, additional capabilities. -type ServerCapabilities struct { - // Experimental, non-standard capabilities that the server supports. - Experimental map[string]any `json:"experimental,omitempty"` - // Present if the server supports sending log messages to the client. 
- Logging *struct{} `json:"logging,omitempty"` - // Present if the server offers any prompt templates. - Prompts *struct { - // Whether this server supports notifications for changes to the prompt list. - ListChanged bool `json:"listChanged,omitempty"` - } `json:"prompts,omitempty"` - // Present if the server offers any resources to read. - Resources *struct { - // Whether this server supports subscribing to resource updates. - Subscribe bool `json:"subscribe,omitempty"` - // Whether this server supports notifications for changes to the resource - // list. - ListChanged bool `json:"listChanged,omitempty"` - } `json:"resources,omitempty"` - // Present if the server supports sending sampling requests to clients. - Sampling *struct{} `json:"sampling,omitempty"` - // Present if the server offers any tools to call. - Tools *struct { - // Whether this server supports notifications for changes to the tool list. - ListChanged bool `json:"listChanged,omitempty"` - } `json:"tools,omitempty"` - // Present if the server supports elicitation requests to the client. - Elicitation *struct{} `json:"elicitation,omitempty"` - // Present if the server supports roots requests to the client. - Roots *struct{} `json:"roots,omitempty"` -} - -// Implementation describes the name and version of an MCP implementation. -type Implementation struct { - Name string `json:"name"` - Version string `json:"version"` - Title string `json:"title,omitempty"` -} - -/* Ping */ - -// PingRequest represents a ping, issued by either the server or the client, -// to check that the other party is still alive. The receiver must promptly respond, -// or else may be disconnected. -type PingRequest struct { - Request - Header http.Header `json:"-"` -} - -/* Progress notifications */ - -// ProgressNotification is an out-of-band notification used to inform the -// receiver of a progress update for a long-running request. 
-type ProgressNotification struct { - Notification - Params ProgressNotificationParams `json:"params"` -} - -type ProgressNotificationParams struct { - // The progress token which was given in the initial request, used to - // associate this notification with the request that is proceeding. - ProgressToken ProgressToken `json:"progressToken"` - // The progress thus far. This should increase every time progress is made, - // even if the total is unknown. - Progress float64 `json:"progress"` - // Total number of items to process (or total progress required), if known. - Total float64 `json:"total,omitempty"` - // Message related to progress. This should provide relevant human-readable - // progress information. - Message string `json:"message,omitempty"` -} - -/* Pagination */ - -type PaginatedRequest struct { - Request - Params PaginatedParams `json:"params,omitempty"` -} - -type PaginatedParams struct { - // An opaque token representing the current pagination position. - // If provided, the server should return results starting after this cursor. - Cursor Cursor `json:"cursor,omitempty"` -} - -type PaginatedResult struct { - Result - // An opaque token representing the pagination position after the last - // returned result. - // If present, there may be more results available. - NextCursor Cursor `json:"nextCursor,omitempty"` -} - -/* Resources */ - -// ListResourcesRequest is sent from the client to request a list of resources -// the server has. -type ListResourcesRequest struct { - PaginatedRequest - Header http.Header `json:"-"` -} - -// ListResourcesResult is the server's response to a resources/list request -// from the client. -type ListResourcesResult struct { - PaginatedResult - Resources []Resource `json:"resources"` -} - -// ListResourceTemplatesRequest is sent from the client to request a list of -// resource templates the server has. 
-type ListResourceTemplatesRequest struct { - PaginatedRequest - Header http.Header `json:"-"` -} - -// ListResourceTemplatesResult is the server's response to a -// resources/templates/list request from the client. -type ListResourceTemplatesResult struct { - PaginatedResult - ResourceTemplates []ResourceTemplate `json:"resourceTemplates"` -} - -// ReadResourceRequest is sent from the client to the server, to read a -// specific resource URI. -type ReadResourceRequest struct { - Request - Header http.Header `json:"-"` - Params ReadResourceParams `json:"params"` -} - -type ReadResourceParams struct { - // The URI of the resource to read. The URI can use any protocol; it is up - // to the server how to interpret it. - URI string `json:"uri"` - // Arguments to pass to the resource handler - Arguments map[string]any `json:"arguments,omitempty"` -} - -// ReadResourceResult is the server's response to a resources/read request -// from the client. -type ReadResourceResult struct { - Result - Contents []ResourceContents `json:"contents"` // Can be TextResourceContents or BlobResourceContents -} - -// ResourceListChangedNotification is an optional notification from the server -// to the client, informing it that the list of resources it can read from has -// changed. This may be issued by servers without any previous subscription from -// the client. -type ResourceListChangedNotification struct { - Notification -} - -// SubscribeRequest is sent from the client to request resources/updated -// notifications from the server whenever a particular resource changes. -type SubscribeRequest struct { - Request - Params SubscribeParams `json:"params"` - Header http.Header `json:"-"` -} - -type SubscribeParams struct { - // The URI of the resource to subscribe to. The URI can use any protocol; it - // is up to the server how to interpret it. 
- URI string `json:"uri"` -} - -// UnsubscribeRequest is sent from the client to request cancellation of -// resources/updated notifications from the server. This should follow a previous -// resources/subscribe request. -type UnsubscribeRequest struct { - Request - Params UnsubscribeParams `json:"params"` - Header http.Header `json:"-"` -} - -type UnsubscribeParams struct { - // The URI of the resource to unsubscribe from. - URI string `json:"uri"` -} - -// ResourceUpdatedNotification is a notification from the server to the client, -// informing it that a resource has changed and may need to be read again. This -// should only be sent if the client previously sent a resources/subscribe request. -type ResourceUpdatedNotification struct { - Notification - Params ResourceUpdatedNotificationParams `json:"params"` -} -type ResourceUpdatedNotificationParams struct { - // The URI of the resource that has been updated. This might be a sub- - // resource of the one that the client actually subscribed to. - URI string `json:"uri"` -} - -// Resource represents a known resource that the server is capable of reading. -type Resource struct { - Annotated - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` - // The URI of this resource. - URI string `json:"uri"` - // A human-readable name for this resource. - // - // This can be used by clients to populate UI elements. - Name string `json:"name"` - // A description of what this resource represents. - // - // This can be used by clients to improve the LLM's understanding of - // available resources. It can be thought of like a "hint" to the model. - Description string `json:"description,omitempty"` - // The MIME type of this resource, if known. - MIMEType string `json:"mimeType,omitempty"` -} - -// GetName returns the name of the resource. 
-func (r Resource) GetName() string { - return r.Name -} - -// ResourceTemplate represents a template description for resources available -// on the server. -type ResourceTemplate struct { - Annotated - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` - // A URI template (according to RFC 6570) that can be used to construct - // resource URIs. - URITemplate *URITemplate `json:"uriTemplate"` - // A human-readable name for the type of resource this template refers to. - // - // This can be used by clients to populate UI elements. - Name string `json:"name"` - // A description of what this template is for. - // - // This can be used by clients to improve the LLM's understanding of - // available resources. It can be thought of like a "hint" to the model. - Description string `json:"description,omitempty"` - // The MIME type for all resources that match this template. This should only - // be included if all resources matching this template have the same type. - MIMEType string `json:"mimeType,omitempty"` -} - -// GetName returns the name of the resourceTemplate. -func (rt ResourceTemplate) GetName() string { - return rt.Name -} - -// ResourceContents represents the contents of a specific resource or sub- -// resource. -type ResourceContents interface { - isResourceContents() -} - -type TextResourceContents struct { - // Raw per‑resource metadata; pass‑through as defined by MCP. Not the same as mcp.Meta. - // Allows _meta to be used for MCP-UI features for example. Does not assume any specific format. - Meta map[string]any `json:"_meta,omitempty"` - // The URI of this resource. - URI string `json:"uri"` - // The MIME type of this resource, if known. - MIMEType string `json:"mimeType,omitempty"` - // The text of the item. This must only be set if the item can actually be - // represented as text (not binary data). 
- Text string `json:"text"` -} - -func (TextResourceContents) isResourceContents() {} - -type BlobResourceContents struct { - // Raw per‑resource metadata; pass‑through as defined by MCP. Not the same as mcp.Meta. - // Allows _meta to be used for MCP-UI features for example. Does not assume any specific format. - Meta map[string]any `json:"_meta,omitempty"` - // The URI of this resource. - URI string `json:"uri"` - // The MIME type of this resource, if known. - MIMEType string `json:"mimeType,omitempty"` - // A base64-encoded string representing the binary data of the item. - Blob string `json:"blob"` -} - -func (BlobResourceContents) isResourceContents() {} - -/* Logging */ - -// SetLevelRequest is a request from the client to the server, to enable or -// adjust logging. -type SetLevelRequest struct { - Request - Params SetLevelParams `json:"params"` - Header http.Header `json:"-"` -} - -type SetLevelParams struct { - // The level of logging that the client wants to receive from the server. - // The server should send all logs at this level and higher (i.e., more severe) to - // the client as notifications/logging/message. - Level LoggingLevel `json:"level"` -} - -// LoggingMessageNotification is a notification of a log message passed from -// server to client. If no logging/setLevel request has been sent from the client, -// the server MAY decide which messages to send automatically. -type LoggingMessageNotification struct { - Notification - Params LoggingMessageNotificationParams `json:"params"` -} - -type LoggingMessageNotificationParams struct { - // The severity of this log message. - Level LoggingLevel `json:"level"` - // An optional name of the logger issuing this message. - Logger string `json:"logger,omitempty"` - // The data to be logged, such as a string message or an object. Any JSON - // serializable type is allowed here. - Data any `json:"data"` -} - -// LoggingLevel represents the severity of a log message. 
-// -// These map to syslog message severities, as specified in RFC-5424: -// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 -type LoggingLevel string - -const ( - LoggingLevelDebug LoggingLevel = "debug" - LoggingLevelInfo LoggingLevel = "info" - LoggingLevelNotice LoggingLevel = "notice" - LoggingLevelWarning LoggingLevel = "warning" - LoggingLevelError LoggingLevel = "error" - LoggingLevelCritical LoggingLevel = "critical" - LoggingLevelAlert LoggingLevel = "alert" - LoggingLevelEmergency LoggingLevel = "emergency" -) - -var levelToInt = map[LoggingLevel]int{ - LoggingLevelDebug: 0, - LoggingLevelInfo: 1, - LoggingLevelNotice: 2, - LoggingLevelWarning: 3, - LoggingLevelError: 4, - LoggingLevelCritical: 5, - LoggingLevelAlert: 6, - LoggingLevelEmergency: 7, -} - -func (l LoggingLevel) ShouldSendTo(minLevel LoggingLevel) bool { - ia, oka := levelToInt[l] - ib, okb := levelToInt[minLevel] - if !oka || !okb { - return false - } - return ia >= ib -} - -/* Elicitation */ - -// ElicitationRequest is a request from the server to the client to request additional -// information from the user during an interaction. -type ElicitationRequest struct { - Request - Params ElicitationParams `json:"params"` -} - -// ElicitationParams contains the parameters for an elicitation request. -type ElicitationParams struct { - // A human-readable message explaining what information is being requested and why. - Message string `json:"message"` - // A JSON Schema defining the expected structure of the user's response. - RequestedSchema any `json:"requestedSchema"` -} - -// ElicitationResult represents the result of an elicitation request. -type ElicitationResult struct { - Result - ElicitationResponse -} - -// ElicitationResponse represents the user's response to an elicitation request. -type ElicitationResponse struct { - // Action indicates whether the user accepted, declined, or cancelled. 
- Action ElicitationResponseAction `json:"action"` - // Content contains the user's response data if they accepted. - // Should conform to the requestedSchema from the ElicitationRequest. - Content any `json:"content,omitempty"` -} - -// ElicitationResponseAction indicates how the user responded to an elicitation request. -type ElicitationResponseAction string - -const ( - // ElicitationResponseActionAccept indicates the user provided the requested information. - ElicitationResponseActionAccept ElicitationResponseAction = "accept" - // ElicitationResponseActionDecline indicates the user explicitly declined to provide information. - ElicitationResponseActionDecline ElicitationResponseAction = "decline" - // ElicitationResponseActionCancel indicates the user cancelled without making a choice. - ElicitationResponseActionCancel ElicitationResponseAction = "cancel" -) - -/* Sampling */ - -const ( - // MethodSamplingCreateMessage allows servers to request LLM completions from clients - MethodSamplingCreateMessage MCPMethod = "sampling/createMessage" -) - -// CreateMessageRequest is a request from the server to sample an LLM via the -// client. The client has full discretion over which model to select. The client -// should also inform the user before beginning sampling, to allow them to inspect -// the request (human in the loop) and decide whether to approve it. 
-type CreateMessageRequest struct { - Request - CreateMessageParams `json:"params"` -} - -type CreateMessageParams struct { - Messages []SamplingMessage `json:"messages"` - ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` - SystemPrompt string `json:"systemPrompt,omitempty"` - IncludeContext string `json:"includeContext,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - MaxTokens int `json:"maxTokens"` - StopSequences []string `json:"stopSequences,omitempty"` - Metadata any `json:"metadata,omitempty"` -} - -// CreateMessageResult is the client's response to a sampling/create_message -// request from the server. The client should inform the user before returning the -// sampled message, to allow them to inspect the response (human in the loop) and -// decide whether to allow the server to see it. -type CreateMessageResult struct { - Result - SamplingMessage - // The name of the model that generated the message. - Model string `json:"model"` - // The reason why sampling stopped, if known. - StopReason string `json:"stopReason,omitempty"` -} - -// SamplingMessage describes a message issued to or received from an LLM API. -type SamplingMessage struct { - Role Role `json:"role"` - Content any `json:"content"` // Can be TextContent, ImageContent or AudioContent -} - -type Annotations struct { - // Describes who the intended customer of this object or data is. - // - // It can include multiple entries to indicate content useful for multiple - // audiences (e.g., `["user", "assistant"]`). - Audience []Role `json:"audience,omitempty"` - - // Describes how important this data is for operating the server. - // - // A value of 1 means "most important," and indicates that the data is - // effectively required, while 0 means "least important," and indicates that - // the data is entirely optional. - Priority float64 `json:"priority,omitempty"` -} - -// Annotated is the base for objects that include optional annotations for the -// client. 
The client can use annotations to inform how objects are used or -// displayed -type Annotated struct { - Annotations *Annotations `json:"annotations,omitempty"` -} - -type Content interface { - isContent() -} - -// TextContent represents text provided to or from an LLM. -// It must have Type set to "text". -type TextContent struct { - Annotated - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` - Type string `json:"type"` // Must be "text" - // The text content of the message. - Text string `json:"text"` -} - -func (TextContent) isContent() {} - -// ImageContent represents an image provided to or from an LLM. -// It must have Type set to "image". -type ImageContent struct { - Annotated - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` - Type string `json:"type"` // Must be "image" - // The base64-encoded image data. - Data string `json:"data"` - // The MIME type of the image. Different providers may support different image types. - MIMEType string `json:"mimeType"` -} - -func (ImageContent) isContent() {} - -// AudioContent represents the contents of audio, embedded into a prompt or tool call result. -// It must have Type set to "audio". -type AudioContent struct { - Annotated - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` - Type string `json:"type"` // Must be "audio" - // The base64-encoded audio data. - Data string `json:"data"` - // The MIME type of the audio. Different providers may support different audio types. - MIMEType string `json:"mimeType"` -} - -func (AudioContent) isContent() {} - -// ResourceLink represents a link to a resource that the client can access. -type ResourceLink struct { - Annotated - Type string `json:"type"` // Must be "resource_link" - // The URI of the resource. 
- URI string `json:"uri"` - // The name of the resource. - Name string `json:"name"` - // The description of the resource. - Description string `json:"description"` - // The MIME type of the resource. - MIMEType string `json:"mimeType"` -} - -func (ResourceLink) isContent() {} - -// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. -// -// It is up to the client how best to render embedded resources for the -// benefit of the LLM and/or the user. -type EmbeddedResource struct { - Annotated - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` - Type string `json:"type"` - Resource ResourceContents `json:"resource"` -} - -func (EmbeddedResource) isContent() {} - -// ModelPreferences represents the server's preferences for model selection, -// requested of the client during sampling. -// -// Because LLMs can vary along multiple dimensions, choosing the "best" modelis -// rarely straightforward. Different models excel in different areas—some are -// faster but less capable, others are more capable but more expensive, and so -// on. This interface allows servers to express their priorities across multiple -// dimensions to help clients make an appropriate selection for their use case. -// -// These preferences are always advisory. The client MAY ignore them. It is also -// up to the client to decide how to interpret these preferences and how to -// balance them against other considerations. -type ModelPreferences struct { - // Optional hints to use for model selection. - // - // If multiple hints are specified, the client MUST evaluate them in order - // (such that the first match is taken). - // - // The client SHOULD prioritize these hints over the numeric priorities, but - // MAY still use the priorities to select from ambiguous matches. - Hints []ModelHint `json:"hints,omitempty"` - - // How much to prioritize cost when selecting a model. 
A value of 0 means cost - // is not important, while a value of 1 means cost is the most important - // factor. - CostPriority float64 `json:"costPriority,omitempty"` - - // How much to prioritize sampling speed (latency) when selecting a model. A - // value of 0 means speed is not important, while a value of 1 means speed is - // the most important factor. - SpeedPriority float64 `json:"speedPriority,omitempty"` - - // How much to prioritize intelligence and capabilities when selecting a - // model. A value of 0 means intelligence is not important, while a value of 1 - // means intelligence is the most important factor. - IntelligencePriority float64 `json:"intelligencePriority,omitempty"` -} - -// ModelHint represents hints to use for model selection. -// -// Keys not declared here are currently left unspecified by the spec and are up -// to the client to interpret. -type ModelHint struct { - // A hint for a model name. - // - // The client SHOULD treat this as a substring of a model name; for example: - // - `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022` - // - `sonnet` should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc. - // - `claude` should match any Claude model - // - // The client MAY also map the string to a different provider's model name or - // a different model family, as long as it fills a similar niche; for example: - // - `gemini-1.5-flash` could match `claude-3-haiku-20240307` - Name string `json:"name,omitempty"` -} - -/* Autocomplete */ - -// CompleteRequest is a request from the client to the server, to ask for completion options. -type CompleteRequest struct { - Request - Params CompleteParams `json:"params"` - Header http.Header `json:"-"` -} - -type CompleteParams struct { - Ref any `json:"ref"` // Can be PromptReference or ResourceReference - Argument struct { - // The name of the argument - Name string `json:"name"` - // The value of the argument to use for completion matching. 
- Value string `json:"value"` - } `json:"argument"` -} - -// CompleteResult is the server's response to a completion/complete request -type CompleteResult struct { - Result - Completion struct { - // An array of completion values. Must not exceed 100 items. - Values []string `json:"values"` - // The total number of completion options available. This can exceed the - // number of values actually sent in the response. - Total int `json:"total,omitempty"` - // Indicates whether there are additional completion options beyond those - // provided in the current response, even if the exact total is unknown. - HasMore bool `json:"hasMore,omitempty"` - } `json:"completion"` -} - -// ResourceReference is a reference to a resource or resource template definition. -type ResourceReference struct { - Type string `json:"type"` - // The URI or URI template of the resource. - URI string `json:"uri"` -} - -// PromptReference identifies a prompt. -type PromptReference struct { - Type string `json:"type"` - // The name of the prompt or prompt template - Name string `json:"name"` -} - -/* Roots */ - -// ListRootsRequest is sent from the server to request a list of root URIs from the client. Roots allow -// servers to ask for specific directories or files to operate on. A common example -// for roots is providing a set of repositories or directories a server should operate -// on. -// -// This request is typically used when the server needs to understand the file system -// structure or access specific locations that the client has permission to read from. -type ListRootsRequest struct { - Request -} - -// ListRootsResult is the client's response to a roots/list request from the server. -// This result contains an array of Root objects, each representing a root directory -// or file that the server can operate on. -type ListRootsResult struct { - Result - Roots []Root `json:"roots"` -} - -// Root represents a root directory or file that the server can operate on. 
-type Root struct { - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` - // The URI identifying the root. This *must* start with file:// for now. - // This restriction may be relaxed in future versions of the protocol to allow - // other URI schemes. - URI string `json:"uri"` - // An optional name for the root. This can be used to provide a human-readable - // identifier for the root, which may be useful for display purposes or for - // referencing the root in other parts of the application. - Name string `json:"name,omitempty"` -} - -// RootsListChangedNotification is a notification from the client to the -// server, informing it that the list of roots has changed. -// This notification should be sent whenever the client adds, removes, or modifies any root. -// The server should then request an updated list of roots using the ListRootsRequest. -type RootsListChangedNotification struct { - Notification -} - -// ClientRequest represents any request that can be sent from client to server. -type ClientRequest any - -// ClientNotification represents any notification that can be sent from client to server. -type ClientNotification any - -// ClientResult represents any result that can be sent from client to server. -type ClientResult any - -// ServerRequest represents any request that can be sent from server to client. -type ServerRequest any - -// ServerNotification represents any notification that can be sent from server to client. -type ServerNotification any - -// ServerResult represents any result that can be sent from server to client. 
-type ServerResult any - -type Named interface { - GetName() string -} - -// MarshalJSON implements custom JSON marshaling for Content interface -func MarshalContent(content Content) ([]byte, error) { - return json.Marshal(content) -} - -// UnmarshalContent implements custom JSON unmarshaling for Content interface -func UnmarshalContent(data []byte) (Content, error) { - var raw map[string]any - if err := json.Unmarshal(data, &raw); err != nil { - return nil, err - } - - contentType, ok := raw["type"].(string) - if !ok { - return nil, fmt.Errorf("missing or invalid type field") - } - - switch contentType { - case ContentTypeText: - var content TextContent - err := json.Unmarshal(data, &content) - return content, err - case ContentTypeImage: - var content ImageContent - err := json.Unmarshal(data, &content) - return content, err - case ContentTypeAudio: - var content AudioContent - err := json.Unmarshal(data, &content) - return content, err - case ContentTypeLink: - var content ResourceLink - err := json.Unmarshal(data, &content) - return content, err - case ContentTypeResource: - var content EmbeddedResource - err := json.Unmarshal(data, &content) - return content, err - default: - return nil, fmt.Errorf("unknown content type: %s", contentType) - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go deleted file mode 100644 index 904a3dd6b..000000000 --- a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go +++ /dev/null @@ -1,979 +0,0 @@ -package mcp - -import ( - "encoding/json" - "fmt" - - "github.com/spf13/cast" -) - -// ClientRequest types -var ( - _ ClientRequest = (*PingRequest)(nil) - _ ClientRequest = (*InitializeRequest)(nil) - _ ClientRequest = (*CompleteRequest)(nil) - _ ClientRequest = (*SetLevelRequest)(nil) - _ ClientRequest = (*GetPromptRequest)(nil) - _ ClientRequest = (*ListPromptsRequest)(nil) - _ ClientRequest = (*ListResourcesRequest)(nil) - _ ClientRequest = (*ReadResourceRequest)(nil) - 
_ ClientRequest = (*SubscribeRequest)(nil) - _ ClientRequest = (*UnsubscribeRequest)(nil) - _ ClientRequest = (*CallToolRequest)(nil) - _ ClientRequest = (*ListToolsRequest)(nil) -) - -// ClientNotification types -var ( - _ ClientNotification = (*CancelledNotification)(nil) - _ ClientNotification = (*ProgressNotification)(nil) - _ ClientNotification = (*InitializedNotification)(nil) - _ ClientNotification = (*RootsListChangedNotification)(nil) -) - -// ClientResult types -var ( - _ ClientResult = (*EmptyResult)(nil) - _ ClientResult = (*CreateMessageResult)(nil) - _ ClientResult = (*ListRootsResult)(nil) -) - -// ServerRequest types -var ( - _ ServerRequest = (*PingRequest)(nil) - _ ServerRequest = (*CreateMessageRequest)(nil) - _ ServerRequest = (*ListRootsRequest)(nil) -) - -// ServerNotification types -var ( - _ ServerNotification = (*CancelledNotification)(nil) - _ ServerNotification = (*ProgressNotification)(nil) - _ ServerNotification = (*LoggingMessageNotification)(nil) - _ ServerNotification = (*ResourceUpdatedNotification)(nil) - _ ServerNotification = (*ResourceListChangedNotification)(nil) - _ ServerNotification = (*ToolListChangedNotification)(nil) - _ ServerNotification = (*PromptListChangedNotification)(nil) -) - -// ServerResult types -var ( - _ ServerResult = (*EmptyResult)(nil) - _ ServerResult = (*InitializeResult)(nil) - _ ServerResult = (*CompleteResult)(nil) - _ ServerResult = (*GetPromptResult)(nil) - _ ServerResult = (*ListPromptsResult)(nil) - _ ServerResult = (*ListResourcesResult)(nil) - _ ServerResult = (*ReadResourceResult)(nil) - _ ServerResult = (*CallToolResult)(nil) - _ ServerResult = (*ListToolsResult)(nil) -) - -// Helper functions for type assertions - -// asType attempts to cast the given interface to the given type -func asType[T any](content any) (*T, bool) { - tc, ok := content.(T) - if !ok { - return nil, false - } - return &tc, true -} - -// AsTextContent attempts to cast the given interface to TextContent -func 
AsTextContent(content any) (*TextContent, bool) { - return asType[TextContent](content) -} - -// AsImageContent attempts to cast the given interface to ImageContent -func AsImageContent(content any) (*ImageContent, bool) { - return asType[ImageContent](content) -} - -// AsAudioContent attempts to cast the given interface to AudioContent -func AsAudioContent(content any) (*AudioContent, bool) { - return asType[AudioContent](content) -} - -// AsEmbeddedResource attempts to cast the given interface to EmbeddedResource -func AsEmbeddedResource(content any) (*EmbeddedResource, bool) { - return asType[EmbeddedResource](content) -} - -// AsTextResourceContents attempts to cast the given interface to TextResourceContents -func AsTextResourceContents(content any) (*TextResourceContents, bool) { - return asType[TextResourceContents](content) -} - -// AsBlobResourceContents attempts to cast the given interface to BlobResourceContents -func AsBlobResourceContents(content any) (*BlobResourceContents, bool) { - return asType[BlobResourceContents](content) -} - -// Helper function for JSON-RPC - -// NewJSONRPCResponse creates a new JSONRPCResponse with the given id and result. -// NOTE: This function expects a Result struct, but JSONRPCResponse.Result is typed as `any`. -// The Result struct wraps the actual result data with optional metadata. -// For direct result assignment, use NewJSONRPCResultResponse instead. -func NewJSONRPCResponse(id RequestId, result Result) JSONRPCResponse { - return JSONRPCResponse{ - JSONRPC: JSONRPC_VERSION, - ID: id, - Result: result, - } -} - -// NewJSONRPCResultResponse creates a new JSONRPCResponse with the given id and result. -// This function accepts any type for the result, matching the JSONRPCResponse.Result field type. 
-func NewJSONRPCResultResponse(id RequestId, result any) JSONRPCResponse { - return JSONRPCResponse{ - JSONRPC: JSONRPC_VERSION, - ID: id, - Result: result, - } -} - -// NewJSONRPCErrorDetails creates a new JSONRPCErrorDetails with the given code, message, and data. -func NewJSONRPCErrorDetails(code int, message string, data any) JSONRPCErrorDetails { - return JSONRPCErrorDetails{ - Code: code, - Message: message, - Data: data, - } -} - -// NewJSONRPCError creates a new JSONRPCResponse with the given id, code, and message -func NewJSONRPCError( - id RequestId, - code int, - message string, - data any, -) JSONRPCError { - return JSONRPCError{ - JSONRPC: JSONRPC_VERSION, - ID: id, - Error: NewJSONRPCErrorDetails(code, message, data), - } -} - -// NewProgressNotification -// Helper function for creating a progress notification -func NewProgressNotification( - token ProgressToken, - progress float64, - total *float64, - message *string, -) ProgressNotification { - notification := ProgressNotification{ - Notification: Notification{ - Method: "notifications/progress", - }, - Params: struct { - ProgressToken ProgressToken `json:"progressToken"` - Progress float64 `json:"progress"` - Total float64 `json:"total,omitempty"` - Message string `json:"message,omitempty"` - }{ - ProgressToken: token, - Progress: progress, - }, - } - if total != nil { - notification.Params.Total = *total - } - if message != nil { - notification.Params.Message = *message - } - return notification -} - -// NewLoggingMessageNotification -// Helper function for creating a logging message notification -func NewLoggingMessageNotification( - level LoggingLevel, - logger string, - data any, -) LoggingMessageNotification { - return LoggingMessageNotification{ - Notification: Notification{ - Method: "notifications/message", - }, - Params: struct { - Level LoggingLevel `json:"level"` - Logger string `json:"logger,omitempty"` - Data any `json:"data"` - }{ - Level: level, - Logger: logger, - Data: data, - }, - 
} -} - -// NewPromptMessage -// Helper function to create a new PromptMessage -func NewPromptMessage(role Role, content Content) PromptMessage { - return PromptMessage{ - Role: role, - Content: content, - } -} - -// NewTextContent -// Helper function to create a new TextContent -func NewTextContent(text string) TextContent { - return TextContent{ - Type: ContentTypeText, - Text: text, - } -} - -// NewImageContent -// Helper function to create a new ImageContent -func NewImageContent(data, mimeType string) ImageContent { - return ImageContent{ - Type: ContentTypeImage, - Data: data, - MIMEType: mimeType, - } -} - -// Helper function to create a new AudioContent -func NewAudioContent(data, mimeType string) AudioContent { - return AudioContent{ - Type: ContentTypeAudio, - Data: data, - MIMEType: mimeType, - } -} - -// Helper function to create a new ResourceLink -func NewResourceLink(uri, name, description, mimeType string) ResourceLink { - return ResourceLink{ - Type: ContentTypeLink, - URI: uri, - Name: name, - Description: description, - MIMEType: mimeType, - } -} - -// Helper function to create a new EmbeddedResource -func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { - return EmbeddedResource{ - Type: ContentTypeResource, - Resource: resource, - } -} - -// NewToolResultText creates a new CallToolResult with a text content -func NewToolResultText(text string) *CallToolResult { - return &CallToolResult{ - Content: []Content{ - TextContent{ - Type: ContentTypeText, - Text: text, - }, - }, - } -} - -// NewToolResultJSON creates a new CallToolResult with a JSON content. 
-func NewToolResultJSON[T any](data T) (*CallToolResult, error) { - b, err := json.Marshal(data) - if err != nil { - return nil, fmt.Errorf("unable to marshal JSON: %w", err) - } - - return &CallToolResult{ - Content: []Content{ - TextContent{ - Type: ContentTypeText, - Text: string(b), - }, - }, - StructuredContent: data, - }, nil -} - -// NewToolResultStructured creates a new CallToolResult with structured content. -// It includes both the structured content and a text representation for backward compatibility. -func NewToolResultStructured(structured any, fallbackText string) *CallToolResult { - return &CallToolResult{ - Content: []Content{ - TextContent{ - Type: "text", - Text: fallbackText, - }, - }, - StructuredContent: structured, - } -} - -// NewToolResultStructuredOnly creates a new CallToolResult with structured -// content and creates a JSON string fallback for backwards compatibility. -// This is useful when you want to provide structured data without any specific text fallback. 
-func NewToolResultStructuredOnly(structured any) *CallToolResult { - var fallbackText string - // Convert to JSON string for backward compatibility - jsonBytes, err := json.Marshal(structured) - if err != nil { - fallbackText = fmt.Sprintf("Error serializing structured content: %v", err) - } else { - fallbackText = string(jsonBytes) - } - - return &CallToolResult{ - Content: []Content{ - TextContent{ - Type: "text", - Text: fallbackText, - }, - }, - StructuredContent: structured, - } -} - -// NewToolResultImage creates a new CallToolResult with both text and image content -func NewToolResultImage(text, imageData, mimeType string) *CallToolResult { - return &CallToolResult{ - Content: []Content{ - TextContent{ - Type: ContentTypeText, - Text: text, - }, - ImageContent{ - Type: ContentTypeImage, - Data: imageData, - MIMEType: mimeType, - }, - }, - } -} - -// NewToolResultAudio creates a new CallToolResult with both text and audio content -func NewToolResultAudio(text, imageData, mimeType string) *CallToolResult { - return &CallToolResult{ - Content: []Content{ - TextContent{ - Type: ContentTypeText, - Text: text, - }, - AudioContent{ - Type: ContentTypeAudio, - Data: imageData, - MIMEType: mimeType, - }, - }, - } -} - -// NewToolResultResource creates a new CallToolResult with an embedded resource -func NewToolResultResource( - text string, - resource ResourceContents, -) *CallToolResult { - return &CallToolResult{ - Content: []Content{ - TextContent{ - Type: ContentTypeText, - Text: text, - }, - EmbeddedResource{ - Type: ContentTypeResource, - Resource: resource, - }, - }, - } -} - -// NewToolResultError creates a new CallToolResult with an error message. -// Any errors that originate from the tool SHOULD be reported inside the result object. 
-func NewToolResultError(text string) *CallToolResult { - return &CallToolResult{ - Content: []Content{ - TextContent{ - Type: ContentTypeText, - Text: text, - }, - }, - IsError: true, - } -} - -// NewToolResultErrorFromErr creates a new CallToolResult with an error message. -// If an error is provided, its details will be appended to the text message. -// Any errors that originate from the tool SHOULD be reported inside the result object. -func NewToolResultErrorFromErr(text string, err error) *CallToolResult { - if err != nil { - text = fmt.Sprintf("%s: %v", text, err) - } - return &CallToolResult{ - Content: []Content{ - TextContent{ - Type: ContentTypeText, - Text: text, - }, - }, - IsError: true, - } -} - -// NewToolResultErrorf creates a new CallToolResult with an error message. -// The error message is formatted using the fmt package. -// Any errors that originate from the tool SHOULD be reported inside the result object. -func NewToolResultErrorf(format string, a ...any) *CallToolResult { - return &CallToolResult{ - Content: []Content{ - TextContent{ - Type: ContentTypeText, - Text: fmt.Sprintf(format, a...), - }, - }, - IsError: true, - } -} - -// NewListResourcesResult creates a new ListResourcesResult -func NewListResourcesResult( - resources []Resource, - nextCursor Cursor, -) *ListResourcesResult { - return &ListResourcesResult{ - PaginatedResult: PaginatedResult{ - NextCursor: nextCursor, - }, - Resources: resources, - } -} - -// NewListResourceTemplatesResult creates a new ListResourceTemplatesResult -func NewListResourceTemplatesResult( - templates []ResourceTemplate, - nextCursor Cursor, -) *ListResourceTemplatesResult { - return &ListResourceTemplatesResult{ - PaginatedResult: PaginatedResult{ - NextCursor: nextCursor, - }, - ResourceTemplates: templates, - } -} - -// NewReadResourceResult creates a new ReadResourceResult with text content -func NewReadResourceResult(text string) *ReadResourceResult { - return &ReadResourceResult{ - Contents: 
[]ResourceContents{ - TextResourceContents{ - Text: text, - }, - }, - } -} - -// NewListPromptsResult creates a new ListPromptsResult -func NewListPromptsResult( - prompts []Prompt, - nextCursor Cursor, -) *ListPromptsResult { - return &ListPromptsResult{ - PaginatedResult: PaginatedResult{ - NextCursor: nextCursor, - }, - Prompts: prompts, - } -} - -// NewGetPromptResult creates a new GetPromptResult -func NewGetPromptResult( - description string, - messages []PromptMessage, -) *GetPromptResult { - return &GetPromptResult{ - Description: description, - Messages: messages, - } -} - -// NewListToolsResult creates a new ListToolsResult -func NewListToolsResult(tools []Tool, nextCursor Cursor) *ListToolsResult { - return &ListToolsResult{ - PaginatedResult: PaginatedResult{ - NextCursor: nextCursor, - }, - Tools: tools, - } -} - -// NewInitializeResult creates a new InitializeResult -func NewInitializeResult( - protocolVersion string, - capabilities ServerCapabilities, - serverInfo Implementation, - instructions string, -) *InitializeResult { - return &InitializeResult{ - ProtocolVersion: protocolVersion, - Capabilities: capabilities, - ServerInfo: serverInfo, - Instructions: instructions, - } -} - -// FormatNumberResult -// Helper for formatting numbers in tool results -func FormatNumberResult(value float64) *CallToolResult { - return NewToolResultText(fmt.Sprintf("%.2f", value)) -} - -func ExtractString(data map[string]any, key string) string { - if value, ok := data[key]; ok { - if str, ok := value.(string); ok { - return str - } - } - return "" -} - -func ParseAnnotations(data map[string]any) *Annotations { - if data == nil { - return nil - } - annotations := &Annotations{} - if value, ok := data["priority"]; ok { - annotations.Priority = cast.ToFloat64(value) - } - - if value, ok := data["audience"]; ok { - for _, a := range cast.ToStringSlice(value) { - a := Role(a) - if a == RoleUser || a == RoleAssistant { - annotations.Audience = append(annotations.Audience, 
a) - } - } - } - return annotations - -} - -func ExtractMap(data map[string]any, key string) map[string]any { - if value, ok := data[key]; ok { - if m, ok := value.(map[string]any); ok { - return m - } - } - return nil -} - -func ParseContent(contentMap map[string]any) (Content, error) { - contentType := ExtractString(contentMap, "type") - - var annotations *Annotations - if annotationsMap := ExtractMap(contentMap, "annotations"); annotationsMap != nil { - annotations = ParseAnnotations(annotationsMap) - } - - switch contentType { - case ContentTypeText: - text := ExtractString(contentMap, "text") - c := NewTextContent(text) - c.Annotations = annotations - return c, nil - - case ContentTypeImage: - data := ExtractString(contentMap, "data") - mimeType := ExtractString(contentMap, "mimeType") - if data == "" || mimeType == "" { - return nil, fmt.Errorf("image data or mimeType is missing") - } - c := NewImageContent(data, mimeType) - c.Annotations = annotations - return c, nil - - case ContentTypeAudio: - data := ExtractString(contentMap, "data") - mimeType := ExtractString(contentMap, "mimeType") - if data == "" || mimeType == "" { - return nil, fmt.Errorf("audio data or mimeType is missing") - } - c := NewAudioContent(data, mimeType) - c.Annotations = annotations - return c, nil - - case ContentTypeLink: - uri := ExtractString(contentMap, "uri") - name := ExtractString(contentMap, "name") - description := ExtractString(contentMap, "description") - mimeType := ExtractString(contentMap, "mimeType") - if uri == "" || name == "" { - return nil, fmt.Errorf("resource_link uri or name is missing") - } - c := NewResourceLink(uri, name, description, mimeType) - c.Annotations = annotations - return c, nil - - case ContentTypeResource: - resourceMap := ExtractMap(contentMap, "resource") - if resourceMap == nil { - return nil, fmt.Errorf("resource is missing") - } - - resourceContents, err := ParseResourceContents(resourceMap) - if err != nil { - return nil, err - } - - c := 
NewEmbeddedResource(resourceContents) - c.Annotations = annotations - return c, nil - } - - return nil, fmt.Errorf("unsupported content type: %s", contentType) -} - -func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) { - if rawMessage == nil { - return nil, fmt.Errorf("response is nil") - } - - var jsonContent map[string]any - if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - result := GetPromptResult{} - - meta, ok := jsonContent["_meta"] - if ok { - if metaMap, ok := meta.(map[string]any); ok { - result.Meta = NewMetaFromMap(metaMap) - } - } - - description, ok := jsonContent["description"] - if ok { - if descriptionStr, ok := description.(string); ok { - result.Description = descriptionStr - } - } - - messages, ok := jsonContent["messages"] - if ok { - messagesArr, ok := messages.([]any) - if !ok { - return nil, fmt.Errorf("messages is not an array") - } - - for _, message := range messagesArr { - messageMap, ok := message.(map[string]any) - if !ok { - return nil, fmt.Errorf("message is not an object") - } - - // Extract role - roleStr := ExtractString(messageMap, "role") - if roleStr == "" || (roleStr != string(RoleAssistant) && roleStr != string(RoleUser)) { - return nil, fmt.Errorf("unsupported role: %s", roleStr) - } - - // Extract content - contentMap, ok := messageMap["content"].(map[string]any) - if !ok { - return nil, fmt.Errorf("content is not an object") - } - - // Process content - content, err := ParseContent(contentMap) - if err != nil { - return nil, err - } - - // Append processed message - result.Messages = append(result.Messages, NewPromptMessage(Role(roleStr), content)) - - } - } - - return &result, nil -} - -func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) { - if rawMessage == nil { - return nil, fmt.Errorf("response is nil") - } - - var jsonContent map[string]any - if err := 
json.Unmarshal(*rawMessage, &jsonContent); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - var result CallToolResult - - meta, ok := jsonContent["_meta"] - if ok { - if metaMap, ok := meta.(map[string]any); ok { - result.Meta = NewMetaFromMap(metaMap) - } - } - - isError, ok := jsonContent["isError"] - if ok { - if isErrorBool, ok := isError.(bool); ok { - result.IsError = isErrorBool - } - } - - contents, ok := jsonContent["content"] - if !ok { - return nil, fmt.Errorf("content is missing") - } - - contentArr, ok := contents.([]any) - if !ok { - return nil, fmt.Errorf("content is not an array") - } - - for _, content := range contentArr { - // Extract content - contentMap, ok := content.(map[string]any) - if !ok { - return nil, fmt.Errorf("content is not an object") - } - - // Process content - content, err := ParseContent(contentMap) - if err != nil { - return nil, err - } - - result.Content = append(result.Content, content) - } - - // Handle structured content - structuredContent, ok := jsonContent["structuredContent"] - if ok { - result.StructuredContent = structuredContent - } - - return &result, nil -} - -func ParseResourceContents(contentMap map[string]any) (ResourceContents, error) { - uri := ExtractString(contentMap, "uri") - if uri == "" { - return nil, fmt.Errorf("resource uri is missing") - } - - mimeType := ExtractString(contentMap, "mimeType") - - meta := ExtractMap(contentMap, "_meta") - - if _, present := contentMap["_meta"]; present && meta == nil { - return nil, fmt.Errorf("_meta must be an object") - } - - if text := ExtractString(contentMap, "text"); text != "" { - return TextResourceContents{ - Meta: meta, - URI: uri, - MIMEType: mimeType, - Text: text, - }, nil - } - - if blob := ExtractString(contentMap, "blob"); blob != "" { - return BlobResourceContents{ - Meta: meta, - URI: uri, - MIMEType: mimeType, - Blob: blob, - }, nil - } - - return nil, fmt.Errorf("unsupported resource type") -} - -func 
ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, error) { - if rawMessage == nil { - return nil, fmt.Errorf("response is nil") - } - - var jsonContent map[string]any - if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - var result ReadResourceResult - - meta, ok := jsonContent["_meta"] - if ok { - if metaMap, ok := meta.(map[string]any); ok { - result.Meta = NewMetaFromMap(metaMap) - } - } - - contents, ok := jsonContent["contents"] - if !ok { - return nil, fmt.Errorf("contents is missing") - } - - contentArr, ok := contents.([]any) - if !ok { - return nil, fmt.Errorf("contents is not an array") - } - - for _, content := range contentArr { - // Extract content - contentMap, ok := content.(map[string]any) - if !ok { - return nil, fmt.Errorf("content is not an object") - } - - // Process content - content, err := ParseResourceContents(contentMap) - if err != nil { - return nil, err - } - - result.Contents = append(result.Contents, content) - } - - return &result, nil -} - -func ParseArgument(request CallToolRequest, key string, defaultVal any) any { - args := request.GetArguments() - if _, ok := args[key]; !ok { - return defaultVal - } else { - return args[key] - } -} - -// ParseBoolean extracts and converts a boolean parameter from a CallToolRequest. -// If the key is not found in the Arguments map, the defaultValue is returned. -// The function uses cast.ToBool for conversion which handles various string representations -// such as "true", "yes", "1", etc. -func ParseBoolean(request CallToolRequest, key string, defaultValue bool) bool { - v := ParseArgument(request, key, defaultValue) - return cast.ToBool(v) -} - -// ParseInt64 extracts and converts an int64 parameter from a CallToolRequest. -// If the key is not found in the Arguments map, the defaultValue is returned. 
-func ParseInt64(request CallToolRequest, key string, defaultValue int64) int64 { - v := ParseArgument(request, key, defaultValue) - return cast.ToInt64(v) -} - -// ParseInt32 extracts and converts an int32 parameter from a CallToolRequest. -func ParseInt32(request CallToolRequest, key string, defaultValue int32) int32 { - v := ParseArgument(request, key, defaultValue) - return cast.ToInt32(v) -} - -// ParseInt16 extracts and converts an int16 parameter from a CallToolRequest. -func ParseInt16(request CallToolRequest, key string, defaultValue int16) int16 { - v := ParseArgument(request, key, defaultValue) - return cast.ToInt16(v) -} - -// ParseInt8 extracts and converts an int8 parameter from a CallToolRequest. -func ParseInt8(request CallToolRequest, key string, defaultValue int8) int8 { - v := ParseArgument(request, key, defaultValue) - return cast.ToInt8(v) -} - -// ParseInt extracts and converts an int parameter from a CallToolRequest. -func ParseInt(request CallToolRequest, key string, defaultValue int) int { - v := ParseArgument(request, key, defaultValue) - return cast.ToInt(v) -} - -// ParseUInt extracts and converts an uint parameter from a CallToolRequest. -func ParseUInt(request CallToolRequest, key string, defaultValue uint) uint { - v := ParseArgument(request, key, defaultValue) - return cast.ToUint(v) -} - -// ParseUInt64 extracts and converts an uint64 parameter from a CallToolRequest. -func ParseUInt64(request CallToolRequest, key string, defaultValue uint64) uint64 { - v := ParseArgument(request, key, defaultValue) - return cast.ToUint64(v) -} - -// ParseUInt32 extracts and converts an uint32 parameter from a CallToolRequest. -func ParseUInt32(request CallToolRequest, key string, defaultValue uint32) uint32 { - v := ParseArgument(request, key, defaultValue) - return cast.ToUint32(v) -} - -// ParseUInt16 extracts and converts an uint16 parameter from a CallToolRequest. 
-func ParseUInt16(request CallToolRequest, key string, defaultValue uint16) uint16 { - v := ParseArgument(request, key, defaultValue) - return cast.ToUint16(v) -} - -// ParseUInt8 extracts and converts an uint8 parameter from a CallToolRequest. -func ParseUInt8(request CallToolRequest, key string, defaultValue uint8) uint8 { - v := ParseArgument(request, key, defaultValue) - return cast.ToUint8(v) -} - -// ParseFloat32 extracts and converts a float32 parameter from a CallToolRequest. -func ParseFloat32(request CallToolRequest, key string, defaultValue float32) float32 { - v := ParseArgument(request, key, defaultValue) - return cast.ToFloat32(v) -} - -// ParseFloat64 extracts and converts a float64 parameter from a CallToolRequest. -func ParseFloat64(request CallToolRequest, key string, defaultValue float64) float64 { - v := ParseArgument(request, key, defaultValue) - return cast.ToFloat64(v) -} - -// ParseString extracts and converts a string parameter from a CallToolRequest. -func ParseString(request CallToolRequest, key string, defaultValue string) string { - v := ParseArgument(request, key, defaultValue) - return cast.ToString(v) -} - -// ParseStringMap extracts and converts a string map parameter from a CallToolRequest. -func ParseStringMap(request CallToolRequest, key string, defaultValue map[string]any) map[string]any { - v := ParseArgument(request, key, defaultValue) - return cast.ToStringMap(v) -} - -// ToBoolPtr returns a pointer to the given boolean value -func ToBoolPtr(b bool) *bool { - return &b -} - -// GetTextFromContent extracts text from a Content interface that might be a TextContent struct -// or a map[string]any that was unmarshaled from JSON. This is useful when dealing with content -// that comes from different transport layers that may handle JSON differently. -// -// This function uses fallback behavior for non-text content - it returns a string representation -// via fmt.Sprintf for any content that cannot be extracted as text. 
This is a lossy operation -// intended for convenience in logging and display scenarios. -// -// For strict type validation, use ParseContent() instead, which returns an error for invalid content. -func GetTextFromContent(content any) string { - switch c := content.(type) { - case TextContent: - return c.Text - case map[string]any: - // Handle JSON unmarshaled content - if contentType, exists := c["type"]; exists && contentType == "text" { - if text, exists := c["text"].(string); exists { - return text - } - } - return fmt.Sprintf("%v", content) - case string: - return c - default: - return fmt.Sprintf("%v", content) - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/constants.go b/vendor/github.com/mark3labs/mcp-go/server/constants.go deleted file mode 100644 index e071b2ef4..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/constants.go +++ /dev/null @@ -1,7 +0,0 @@ -package server - -// Common HTTP header constants used across server transports -const ( - HeaderKeySessionID = "Mcp-Session-Id" - HeaderKeyProtocolVersion = "Mcp-Protocol-Version" -) diff --git a/vendor/github.com/mark3labs/mcp-go/server/ctx.go b/vendor/github.com/mark3labs/mcp-go/server/ctx.go deleted file mode 100644 index 43f01bb68..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/ctx.go +++ /dev/null @@ -1,8 +0,0 @@ -package server - -type contextKey int - -const ( - // This const is used as key for context value lookup - requestHeader contextKey = iota -) diff --git a/vendor/github.com/mark3labs/mcp-go/server/elicitation.go b/vendor/github.com/mark3labs/mcp-go/server/elicitation.go deleted file mode 100644 index d3e6d3d4c..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/elicitation.go +++ /dev/null @@ -1,32 +0,0 @@ -package server - -import ( - "context" - "errors" - - "github.com/mark3labs/mcp-go/mcp" -) - -var ( - // ErrNoActiveSession is returned when there is no active session in the context - ErrNoActiveSession = errors.New("no active session") - // 
ErrElicitationNotSupported is returned when the session does not support elicitation - ErrElicitationNotSupported = errors.New("session does not support elicitation") -) - -// RequestElicitation sends an elicitation request to the client. -// The client must have declared elicitation capability during initialization. -// The session must implement SessionWithElicitation to support this operation. -func (s *MCPServer) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { - session := ClientSessionFromContext(ctx) - if session == nil { - return nil, ErrNoActiveSession - } - - // Check if the session supports elicitation requests - if elicitationSession, ok := session.(SessionWithElicitation); ok { - return elicitationSession.RequestElicitation(ctx, request) - } - - return nil, ErrElicitationNotSupported -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/errors.go b/vendor/github.com/mark3labs/mcp-go/server/errors.go deleted file mode 100644 index 5e65f0760..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/errors.go +++ /dev/null @@ -1,36 +0,0 @@ -package server - -import ( - "errors" - "fmt" -) - -var ( - // Common server errors - ErrUnsupported = errors.New("not supported") - ErrResourceNotFound = errors.New("resource not found") - ErrPromptNotFound = errors.New("prompt not found") - ErrToolNotFound = errors.New("tool not found") - - // Session-related errors - ErrSessionNotFound = errors.New("session not found") - ErrSessionExists = errors.New("session already exists") - ErrSessionNotInitialized = errors.New("session not properly initialized") - ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") - ErrSessionDoesNotSupportResources = errors.New("session does not support per-session resources") - ErrSessionDoesNotSupportResourceTemplates = errors.New("session does not support resource templates") - ErrSessionDoesNotSupportLogging = errors.New("session does not 
support setting logging level") - - // Notification-related errors - ErrNotificationNotInitialized = errors.New("notification channel not initialized") - ErrNotificationChannelBlocked = errors.New("notification channel queue is full - client may not be processing notifications fast enough") -) - -// ErrDynamicPathConfig is returned when attempting to use static path methods with dynamic path configuration -type ErrDynamicPathConfig struct { - Method string -} - -func (e *ErrDynamicPathConfig) Error() string { - return fmt.Sprintf("%s cannot be used with WithDynamicBasePath. Use dynamic path logic in your router.", e.Method) -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/hooks.go b/vendor/github.com/mark3labs/mcp-go/server/hooks.go deleted file mode 100644 index 4baa1c4e0..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/hooks.go +++ /dev/null @@ -1,532 +0,0 @@ -// Code generated by `go generate`. DO NOT EDIT. -// source: server/internal/gen/hooks.go.tmpl -package server - -import ( - "context" - - "github.com/mark3labs/mcp-go/mcp" -) - -// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. -type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) - -// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered. -type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession) - -// BeforeAnyHookFunc is a function that is called after the request is -// parsed but before the method is called. -type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any) - -// OnSuccessHookFunc is a hook that will be called after the request -// successfully generates a result, but before the result is sent to the client. 
-type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) - -// OnErrorHookFunc is a hook that will be called when an error occurs, -// either during the request parsing or the method execution. -// -// Example usage: -// ``` -// -// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { -// // Check for specific error types using errors.Is -// if errors.Is(err, ErrUnsupported) { -// // Handle capability not supported errors -// log.Printf("Capability not supported: %v", err) -// } -// -// // Use errors.As to get specific error types -// var parseErr = &UnparsableMessageError{} -// if errors.As(err, &parseErr) { -// // Access specific methods/fields of the error type -// log.Printf("Failed to parse message for method %s: %v", -// parseErr.GetMethod(), parseErr.Unwrap()) -// // Access the raw message that failed to parse -// rawMsg := parseErr.GetMessage() -// } -// -// // Check for specific resource/prompt/tool errors -// switch { -// case errors.Is(err, ErrResourceNotFound): -// log.Printf("Resource not found: %v", err) -// case errors.Is(err, ErrPromptNotFound): -// log.Printf("Prompt not found: %v", err) -// case errors.Is(err, ErrToolNotFound): -// log.Printf("Tool not found: %v", err) -// } -// }) -type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) - -// OnRequestInitializationFunc is a function that called before handle diff request method -// Should any errors arise during func execution, the service will promptly return the corresponding error message. 
-type OnRequestInitializationFunc func(ctx context.Context, id any, message any) error - -type OnBeforeInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest) -type OnAfterInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) - -type OnBeforePingFunc func(ctx context.Context, id any, message *mcp.PingRequest) -type OnAfterPingFunc func(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) - -type OnBeforeSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest) -type OnAfterSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) - -type OnBeforeListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest) -type OnAfterListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) - -type OnBeforeListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) -type OnAfterListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) - -type OnBeforeReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest) -type OnAfterReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) - -type OnBeforeListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest) -type OnAfterListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) - -type OnBeforeGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest) -type OnAfterGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) - -type OnBeforeListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest) -type 
OnAfterListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) - -type OnBeforeCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest) -type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) - -type Hooks struct { - OnRegisterSession []OnRegisterSessionHookFunc - OnUnregisterSession []OnUnregisterSessionHookFunc - OnBeforeAny []BeforeAnyHookFunc - OnSuccess []OnSuccessHookFunc - OnError []OnErrorHookFunc - OnRequestInitialization []OnRequestInitializationFunc - OnBeforeInitialize []OnBeforeInitializeFunc - OnAfterInitialize []OnAfterInitializeFunc - OnBeforePing []OnBeforePingFunc - OnAfterPing []OnAfterPingFunc - OnBeforeSetLevel []OnBeforeSetLevelFunc - OnAfterSetLevel []OnAfterSetLevelFunc - OnBeforeListResources []OnBeforeListResourcesFunc - OnAfterListResources []OnAfterListResourcesFunc - OnBeforeListResourceTemplates []OnBeforeListResourceTemplatesFunc - OnAfterListResourceTemplates []OnAfterListResourceTemplatesFunc - OnBeforeReadResource []OnBeforeReadResourceFunc - OnAfterReadResource []OnAfterReadResourceFunc - OnBeforeListPrompts []OnBeforeListPromptsFunc - OnAfterListPrompts []OnAfterListPromptsFunc - OnBeforeGetPrompt []OnBeforeGetPromptFunc - OnAfterGetPrompt []OnAfterGetPromptFunc - OnBeforeListTools []OnBeforeListToolsFunc - OnAfterListTools []OnAfterListToolsFunc - OnBeforeCallTool []OnBeforeCallToolFunc - OnAfterCallTool []OnAfterCallToolFunc -} - -func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) { - c.OnBeforeAny = append(c.OnBeforeAny, hook) -} - -func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { - c.OnSuccess = append(c.OnSuccess, hook) -} - -// AddOnError registers a hook function that will be called when an error occurs. -// The error parameter contains the actual error object, which can be interrogated -// using Go's error handling patterns like errors.Is and errors.As. 
-// -// Example: -// ``` -// // Create a channel to receive errors for testing -// errChan := make(chan error, 1) -// -// // Register hook to capture and inspect errors -// hooks := &Hooks{} -// -// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { -// // For capability-related errors -// if errors.Is(err, ErrUnsupported) { -// // Handle capability not supported -// errChan <- err -// return -// } -// -// // For parsing errors -// var parseErr = &UnparsableMessageError{} -// if errors.As(err, &parseErr) { -// // Handle unparsable message errors -// fmt.Printf("Failed to parse %s request: %v\n", -// parseErr.GetMethod(), parseErr.Unwrap()) -// errChan <- parseErr -// return -// } -// -// // For resource/prompt/tool not found errors -// if errors.Is(err, ErrResourceNotFound) || -// errors.Is(err, ErrPromptNotFound) || -// errors.Is(err, ErrToolNotFound) { -// // Handle not found errors -// errChan <- err -// return -// } -// -// // For other errors -// errChan <- err -// }) -// -// server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) -// ``` -func (c *Hooks) AddOnError(hook OnErrorHookFunc) { - c.OnError = append(c.OnError, hook) -} - -func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) { - if c == nil { - return - } - for _, hook := range c.OnBeforeAny { - hook(ctx, id, method, message) - } -} - -func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) { - if c == nil { - return - } - for _, hook := range c.OnSuccess { - hook(ctx, id, method, message, result) - } -} - -// onError calls all registered error hooks with the error object. -// The err parameter contains the actual error that occurred, which implements -// the standard error interface and may be a wrapped error or custom error type. 
-// -// This allows consumer code to use Go's error handling patterns: -// - errors.Is(err, ErrUnsupported) to check for specific sentinel errors -// - errors.As(err, &customErr) to extract custom error types -// -// Common error types include: -// - ErrUnsupported: When a capability is not enabled -// - UnparsableMessageError: When request parsing fails -// - ErrResourceNotFound: When a resource is not found -// - ErrPromptNotFound: When a prompt is not found -// - ErrToolNotFound: When a tool is not found -func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { - if c == nil { - return - } - for _, hook := range c.OnError { - hook(ctx, id, method, message, err) - } -} - -func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) { - c.OnRegisterSession = append(c.OnRegisterSession, hook) -} - -func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { - if c == nil { - return - } - for _, hook := range c.OnRegisterSession { - hook(ctx, session) - } -} - -func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) { - c.OnUnregisterSession = append(c.OnUnregisterSession, hook) -} - -func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) { - if c == nil { - return - } - for _, hook := range c.OnUnregisterSession { - hook(ctx, session) - } -} - -func (c *Hooks) AddOnRequestInitialization(hook OnRequestInitializationFunc) { - c.OnRequestInitialization = append(c.OnRequestInitialization, hook) -} - -func (c *Hooks) onRequestInitialization(ctx context.Context, id any, message any) error { - if c == nil { - return nil - } - for _, hook := range c.OnRequestInitialization { - err := hook(ctx, id, message) - if err != nil { - return err - } - } - return nil -} -func (c *Hooks) AddBeforeInitialize(hook OnBeforeInitializeFunc) { - c.OnBeforeInitialize = append(c.OnBeforeInitialize, hook) -} - -func (c *Hooks) AddAfterInitialize(hook OnAfterInitializeFunc) { - 
c.OnAfterInitialize = append(c.OnAfterInitialize, hook) -} - -func (c *Hooks) beforeInitialize(ctx context.Context, id any, message *mcp.InitializeRequest) { - c.beforeAny(ctx, id, mcp.MethodInitialize, message) - if c == nil { - return - } - for _, hook := range c.OnBeforeInitialize { - hook(ctx, id, message) - } -} - -func (c *Hooks) afterInitialize(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { - c.onSuccess(ctx, id, mcp.MethodInitialize, message, result) - if c == nil { - return - } - for _, hook := range c.OnAfterInitialize { - hook(ctx, id, message, result) - } -} -func (c *Hooks) AddBeforePing(hook OnBeforePingFunc) { - c.OnBeforePing = append(c.OnBeforePing, hook) -} - -func (c *Hooks) AddAfterPing(hook OnAfterPingFunc) { - c.OnAfterPing = append(c.OnAfterPing, hook) -} - -func (c *Hooks) beforePing(ctx context.Context, id any, message *mcp.PingRequest) { - c.beforeAny(ctx, id, mcp.MethodPing, message) - if c == nil { - return - } - for _, hook := range c.OnBeforePing { - hook(ctx, id, message) - } -} - -func (c *Hooks) afterPing(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) { - c.onSuccess(ctx, id, mcp.MethodPing, message, result) - if c == nil { - return - } - for _, hook := range c.OnAfterPing { - hook(ctx, id, message, result) - } -} -func (c *Hooks) AddBeforeSetLevel(hook OnBeforeSetLevelFunc) { - c.OnBeforeSetLevel = append(c.OnBeforeSetLevel, hook) -} - -func (c *Hooks) AddAfterSetLevel(hook OnAfterSetLevelFunc) { - c.OnAfterSetLevel = append(c.OnAfterSetLevel, hook) -} - -func (c *Hooks) beforeSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest) { - c.beforeAny(ctx, id, mcp.MethodSetLogLevel, message) - if c == nil { - return - } - for _, hook := range c.OnBeforeSetLevel { - hook(ctx, id, message) - } -} - -func (c *Hooks) afterSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) { - c.onSuccess(ctx, id, 
mcp.MethodSetLogLevel, message, result) - if c == nil { - return - } - for _, hook := range c.OnAfterSetLevel { - hook(ctx, id, message, result) - } -} -func (c *Hooks) AddBeforeListResources(hook OnBeforeListResourcesFunc) { - c.OnBeforeListResources = append(c.OnBeforeListResources, hook) -} - -func (c *Hooks) AddAfterListResources(hook OnAfterListResourcesFunc) { - c.OnAfterListResources = append(c.OnAfterListResources, hook) -} - -func (c *Hooks) beforeListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest) { - c.beforeAny(ctx, id, mcp.MethodResourcesList, message) - if c == nil { - return - } - for _, hook := range c.OnBeforeListResources { - hook(ctx, id, message) - } -} - -func (c *Hooks) afterListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) { - c.onSuccess(ctx, id, mcp.MethodResourcesList, message, result) - if c == nil { - return - } - for _, hook := range c.OnAfterListResources { - hook(ctx, id, message, result) - } -} -func (c *Hooks) AddBeforeListResourceTemplates(hook OnBeforeListResourceTemplatesFunc) { - c.OnBeforeListResourceTemplates = append(c.OnBeforeListResourceTemplates, hook) -} - -func (c *Hooks) AddAfterListResourceTemplates(hook OnAfterListResourceTemplatesFunc) { - c.OnAfterListResourceTemplates = append(c.OnAfterListResourceTemplates, hook) -} - -func (c *Hooks) beforeListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) { - c.beforeAny(ctx, id, mcp.MethodResourcesTemplatesList, message) - if c == nil { - return - } - for _, hook := range c.OnBeforeListResourceTemplates { - hook(ctx, id, message) - } -} - -func (c *Hooks) afterListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) { - c.onSuccess(ctx, id, mcp.MethodResourcesTemplatesList, message, result) - if c == nil { - return - } - for _, hook := range c.OnAfterListResourceTemplates { 
- hook(ctx, id, message, result) - } -} -func (c *Hooks) AddBeforeReadResource(hook OnBeforeReadResourceFunc) { - c.OnBeforeReadResource = append(c.OnBeforeReadResource, hook) -} - -func (c *Hooks) AddAfterReadResource(hook OnAfterReadResourceFunc) { - c.OnAfterReadResource = append(c.OnAfterReadResource, hook) -} - -func (c *Hooks) beforeReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest) { - c.beforeAny(ctx, id, mcp.MethodResourcesRead, message) - if c == nil { - return - } - for _, hook := range c.OnBeforeReadResource { - hook(ctx, id, message) - } -} - -func (c *Hooks) afterReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) { - c.onSuccess(ctx, id, mcp.MethodResourcesRead, message, result) - if c == nil { - return - } - for _, hook := range c.OnAfterReadResource { - hook(ctx, id, message, result) - } -} -func (c *Hooks) AddBeforeListPrompts(hook OnBeforeListPromptsFunc) { - c.OnBeforeListPrompts = append(c.OnBeforeListPrompts, hook) -} - -func (c *Hooks) AddAfterListPrompts(hook OnAfterListPromptsFunc) { - c.OnAfterListPrompts = append(c.OnAfterListPrompts, hook) -} - -func (c *Hooks) beforeListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest) { - c.beforeAny(ctx, id, mcp.MethodPromptsList, message) - if c == nil { - return - } - for _, hook := range c.OnBeforeListPrompts { - hook(ctx, id, message) - } -} - -func (c *Hooks) afterListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) { - c.onSuccess(ctx, id, mcp.MethodPromptsList, message, result) - if c == nil { - return - } - for _, hook := range c.OnAfterListPrompts { - hook(ctx, id, message, result) - } -} -func (c *Hooks) AddBeforeGetPrompt(hook OnBeforeGetPromptFunc) { - c.OnBeforeGetPrompt = append(c.OnBeforeGetPrompt, hook) -} - -func (c *Hooks) AddAfterGetPrompt(hook OnAfterGetPromptFunc) { - c.OnAfterGetPrompt = append(c.OnAfterGetPrompt, hook) -} - 
-func (c *Hooks) beforeGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest) { - c.beforeAny(ctx, id, mcp.MethodPromptsGet, message) - if c == nil { - return - } - for _, hook := range c.OnBeforeGetPrompt { - hook(ctx, id, message) - } -} - -func (c *Hooks) afterGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) { - c.onSuccess(ctx, id, mcp.MethodPromptsGet, message, result) - if c == nil { - return - } - for _, hook := range c.OnAfterGetPrompt { - hook(ctx, id, message, result) - } -} -func (c *Hooks) AddBeforeListTools(hook OnBeforeListToolsFunc) { - c.OnBeforeListTools = append(c.OnBeforeListTools, hook) -} - -func (c *Hooks) AddAfterListTools(hook OnAfterListToolsFunc) { - c.OnAfterListTools = append(c.OnAfterListTools, hook) -} - -func (c *Hooks) beforeListTools(ctx context.Context, id any, message *mcp.ListToolsRequest) { - c.beforeAny(ctx, id, mcp.MethodToolsList, message) - if c == nil { - return - } - for _, hook := range c.OnBeforeListTools { - hook(ctx, id, message) - } -} - -func (c *Hooks) afterListTools(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) { - c.onSuccess(ctx, id, mcp.MethodToolsList, message, result) - if c == nil { - return - } - for _, hook := range c.OnAfterListTools { - hook(ctx, id, message, result) - } -} -func (c *Hooks) AddBeforeCallTool(hook OnBeforeCallToolFunc) { - c.OnBeforeCallTool = append(c.OnBeforeCallTool, hook) -} - -func (c *Hooks) AddAfterCallTool(hook OnAfterCallToolFunc) { - c.OnAfterCallTool = append(c.OnAfterCallTool, hook) -} - -func (c *Hooks) beforeCallTool(ctx context.Context, id any, message *mcp.CallToolRequest) { - c.beforeAny(ctx, id, mcp.MethodToolsCall, message) - if c == nil { - return - } - for _, hook := range c.OnBeforeCallTool { - hook(ctx, id, message) - } -} - -func (c *Hooks) afterCallTool(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { - c.onSuccess(ctx, 
id, mcp.MethodToolsCall, message, result) - if c == nil { - return - } - for _, hook := range c.OnAfterCallTool { - hook(ctx, id, message, result) - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go b/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go deleted file mode 100644 index 4f5ad53d0..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go +++ /dev/null @@ -1,11 +0,0 @@ -package server - -import ( - "context" - "net/http" -) - -// HTTPContextFunc is a function that takes an existing context and the current -// request and returns a potentially modified context based on the request -// content. This can be used to inject context values from headers, for example. -type HTTPContextFunc func(ctx context.Context, r *http.Request) context.Context diff --git a/vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go b/vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go deleted file mode 100644 index 59ab0f366..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go +++ /dev/null @@ -1,165 +0,0 @@ -package server - -import ( - "context" - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/mark3labs/mcp-go/mcp" -) - -// SamplingHandler defines the interface for handling sampling requests from servers. -type SamplingHandler interface { - CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) -} - -// ElicitationHandler defines the interface for handling elicitation requests from servers. -type ElicitationHandler interface { - Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) -} - -// RootsHandler defines the interface for handling roots list requests from servers. 
-type RootsHandler interface { - ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) -} - -type InProcessSession struct { - sessionID string - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value - clientCapabilities atomic.Value - samplingHandler SamplingHandler - elicitationHandler ElicitationHandler - rootsHandler RootsHandler - mu sync.RWMutex -} - -func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InProcessSession { - return &InProcessSession{ - sessionID: sessionID, - notifications: make(chan mcp.JSONRPCNotification, 100), - samplingHandler: samplingHandler, - } -} - -func NewInProcessSessionWithHandlers(sessionID string, samplingHandler SamplingHandler, elicitationHandler ElicitationHandler, rootsHandler RootsHandler) *InProcessSession { - return &InProcessSession{ - sessionID: sessionID, - notifications: make(chan mcp.JSONRPCNotification, 100), - samplingHandler: samplingHandler, - elicitationHandler: elicitationHandler, - rootsHandler: rootsHandler, - } -} - -func (s *InProcessSession) SessionID() string { - return s.sessionID -} - -func (s *InProcessSession) NotificationChannel() chan<- mcp.JSONRPCNotification { - return s.notifications -} - -func (s *InProcessSession) Initialize() { - s.loggingLevel.Store(mcp.LoggingLevelError) - s.initialized.Store(true) -} - -func (s *InProcessSession) Initialized() bool { - return s.initialized.Load() -} - -func (s *InProcessSession) GetClientInfo() mcp.Implementation { - if value := s.clientInfo.Load(); value != nil { - if clientInfo, ok := value.(mcp.Implementation); ok { - return clientInfo - } - } - return mcp.Implementation{} -} - -func (s *InProcessSession) SetClientInfo(clientInfo mcp.Implementation) { - s.clientInfo.Store(clientInfo) -} - -func (s *InProcessSession) GetClientCapabilities() mcp.ClientCapabilities { - if value := s.clientCapabilities.Load(); value != nil { - if 
clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { - return clientCapabilities - } - } - return mcp.ClientCapabilities{} -} - -func (s *InProcessSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { - s.clientCapabilities.Store(clientCapabilities) -} - -func (s *InProcessSession) SetLogLevel(level mcp.LoggingLevel) { - s.loggingLevel.Store(level) -} - -func (s *InProcessSession) GetLogLevel() mcp.LoggingLevel { - level := s.loggingLevel.Load() - if level == nil { - return mcp.LoggingLevelError - } - return level.(mcp.LoggingLevel) -} - -func (s *InProcessSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { - s.mu.RLock() - handler := s.samplingHandler - s.mu.RUnlock() - - if handler == nil { - return nil, fmt.Errorf("no sampling handler available") - } - - return handler.CreateMessage(ctx, request) -} - -func (s *InProcessSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { - s.mu.RLock() - handler := s.elicitationHandler - s.mu.RUnlock() - - if handler == nil { - return nil, fmt.Errorf("no elicitation handler available") - } - - return handler.Elicit(ctx, request) -} - -// ListRoots sends a list roots request to the client and waits for the response. -// Returns an error if no roots handler is available. 
-func (s *InProcessSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { - s.mu.RLock() - handler := s.rootsHandler - s.mu.RUnlock() - - if handler == nil { - return nil, fmt.Errorf("no roots handler available") - } - - return handler.ListRoots(ctx, request) -} - -// GenerateInProcessSessionID generates a unique session ID for inprocess clients -func GenerateInProcessSessionID() string { - return fmt.Sprintf("inprocess-%d", time.Now().UnixNano()) -} - -// Ensure interface compliance -var ( - _ ClientSession = (*InProcessSession)(nil) - _ SessionWithLogging = (*InProcessSession)(nil) - _ SessionWithClientInfo = (*InProcessSession)(nil) - _ SessionWithSampling = (*InProcessSession)(nil) - _ SessionWithElicitation = (*InProcessSession)(nil) - _ SessionWithRoots = (*InProcessSession)(nil) -) diff --git a/vendor/github.com/mark3labs/mcp-go/server/request_handler.go b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go deleted file mode 100644 index b9175dc4e..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/request_handler.go +++ /dev/null @@ -1,339 +0,0 @@ -// Code generated by `go generate`. DO NOT EDIT. 
-// source: server/internal/gen/request_handler.go.tmpl -package server - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - - "github.com/mark3labs/mcp-go/mcp" -) - -// HandleMessage processes an incoming JSON-RPC message and returns an appropriate response -func (s *MCPServer) HandleMessage( - ctx context.Context, - message json.RawMessage, -) mcp.JSONRPCMessage { - // Add server to context - ctx = context.WithValue(ctx, serverKey{}, s) - var err *requestError - - var baseMessage struct { - JSONRPC string `json:"jsonrpc"` - Method mcp.MCPMethod `json:"method"` - ID any `json:"id,omitempty"` - Result any `json:"result,omitempty"` - } - - if err := json.Unmarshal(message, &baseMessage); err != nil { - return createErrorResponse( - nil, - mcp.PARSE_ERROR, - "Failed to parse message", - ) - } - - // Check for valid JSONRPC version - if baseMessage.JSONRPC != mcp.JSONRPC_VERSION { - return createErrorResponse( - baseMessage.ID, - mcp.INVALID_REQUEST, - "Invalid JSON-RPC version", - ) - } - - if baseMessage.ID == nil { - var notification mcp.JSONRPCNotification - if err := json.Unmarshal(message, ¬ification); err != nil { - return createErrorResponse( - nil, - mcp.PARSE_ERROR, - "Failed to parse notification", - ) - } - s.handleNotification(ctx, notification) - return nil // Return nil for notifications - } - - if baseMessage.Result != nil { - // this is a response to a request sent by the server (e.g. 
from a ping - // sent due to WithKeepAlive option) - return nil - } - - handleErr := s.hooks.onRequestInitialization(ctx, baseMessage.ID, message) - if handleErr != nil { - return createErrorResponse( - baseMessage.ID, - mcp.INVALID_REQUEST, - handleErr.Error(), - ) - } - - // Get request header from ctx - h := ctx.Value(requestHeader) - headers, ok := h.(http.Header) - - if headers == nil || !ok { - headers = make(http.Header) - } - - switch baseMessage.Method { - case mcp.MethodInitialize: - var request mcp.InitializeRequest - var result *mcp.InitializeResult - if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.INVALID_REQUEST, - err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, - } - } else { - request.Header = headers - s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) - result, err = s.handleInitialize(ctx, baseMessage.ID, request) - } - if err != nil { - s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) - return err.ToJSONRPCError() - } - s.hooks.afterInitialize(ctx, baseMessage.ID, &request, result) - return createResponse(baseMessage.ID, *result) - case mcp.MethodPing: - var request mcp.PingRequest - var result *mcp.EmptyResult - if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.INVALID_REQUEST, - err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, - } - } else { - request.Header = headers - s.hooks.beforePing(ctx, baseMessage.ID, &request) - result, err = s.handlePing(ctx, baseMessage.ID, request) - } - if err != nil { - s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) - return err.ToJSONRPCError() - } - s.hooks.afterPing(ctx, baseMessage.ID, &request, result) - return createResponse(baseMessage.ID, *result) - case mcp.MethodSetLogLevel: - var request 
mcp.SetLevelRequest - var result *mcp.EmptyResult - if s.capabilities.logging == nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.METHOD_NOT_FOUND, - err: fmt.Errorf("logging %w", ErrUnsupported), - } - } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.INVALID_REQUEST, - err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, - } - } else { - request.Header = headers - s.hooks.beforeSetLevel(ctx, baseMessage.ID, &request) - result, err = s.handleSetLevel(ctx, baseMessage.ID, request) - } - if err != nil { - s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) - return err.ToJSONRPCError() - } - s.hooks.afterSetLevel(ctx, baseMessage.ID, &request, result) - return createResponse(baseMessage.ID, *result) - case mcp.MethodResourcesList: - var request mcp.ListResourcesRequest - var result *mcp.ListResourcesResult - if s.capabilities.resources == nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.METHOD_NOT_FOUND, - err: fmt.Errorf("resources %w", ErrUnsupported), - } - } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.INVALID_REQUEST, - err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, - } - } else { - request.Header = headers - s.hooks.beforeListResources(ctx, baseMessage.ID, &request) - result, err = s.handleListResources(ctx, baseMessage.ID, request) - } - if err != nil { - s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) - return err.ToJSONRPCError() - } - s.hooks.afterListResources(ctx, baseMessage.ID, &request, result) - return createResponse(baseMessage.ID, *result) - case mcp.MethodResourcesTemplatesList: - var request mcp.ListResourceTemplatesRequest - var result *mcp.ListResourceTemplatesResult - if s.capabilities.resources == nil { 
- err = &requestError{ - id: baseMessage.ID, - code: mcp.METHOD_NOT_FOUND, - err: fmt.Errorf("resources %w", ErrUnsupported), - } - } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.INVALID_REQUEST, - err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, - } - } else { - request.Header = headers - s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) - result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request) - } - if err != nil { - s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) - return err.ToJSONRPCError() - } - s.hooks.afterListResourceTemplates(ctx, baseMessage.ID, &request, result) - return createResponse(baseMessage.ID, *result) - case mcp.MethodResourcesRead: - var request mcp.ReadResourceRequest - var result *mcp.ReadResourceResult - if s.capabilities.resources == nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.METHOD_NOT_FOUND, - err: fmt.Errorf("resources %w", ErrUnsupported), - } - } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.INVALID_REQUEST, - err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, - } - } else { - request.Header = headers - s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) - result, err = s.handleReadResource(ctx, baseMessage.ID, request) - } - if err != nil { - s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) - return err.ToJSONRPCError() - } - s.hooks.afterReadResource(ctx, baseMessage.ID, &request, result) - return createResponse(baseMessage.ID, *result) - case mcp.MethodPromptsList: - var request mcp.ListPromptsRequest - var result *mcp.ListPromptsResult - if s.capabilities.prompts == nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.METHOD_NOT_FOUND, - err: 
fmt.Errorf("prompts %w", ErrUnsupported), - } - } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.INVALID_REQUEST, - err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, - } - } else { - request.Header = headers - s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) - result, err = s.handleListPrompts(ctx, baseMessage.ID, request) - } - if err != nil { - s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) - return err.ToJSONRPCError() - } - s.hooks.afterListPrompts(ctx, baseMessage.ID, &request, result) - return createResponse(baseMessage.ID, *result) - case mcp.MethodPromptsGet: - var request mcp.GetPromptRequest - var result *mcp.GetPromptResult - if s.capabilities.prompts == nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.METHOD_NOT_FOUND, - err: fmt.Errorf("prompts %w", ErrUnsupported), - } - } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.INVALID_REQUEST, - err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, - } - } else { - request.Header = headers - s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) - result, err = s.handleGetPrompt(ctx, baseMessage.ID, request) - } - if err != nil { - s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) - return err.ToJSONRPCError() - } - s.hooks.afterGetPrompt(ctx, baseMessage.ID, &request, result) - return createResponse(baseMessage.ID, *result) - case mcp.MethodToolsList: - var request mcp.ListToolsRequest - var result *mcp.ListToolsResult - if s.capabilities.tools == nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.METHOD_NOT_FOUND, - err: fmt.Errorf("tools %w", ErrUnsupported), - } - } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { - err = 
&requestError{ - id: baseMessage.ID, - code: mcp.INVALID_REQUEST, - err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, - } - } else { - request.Header = headers - s.hooks.beforeListTools(ctx, baseMessage.ID, &request) - result, err = s.handleListTools(ctx, baseMessage.ID, request) - } - if err != nil { - s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) - return err.ToJSONRPCError() - } - s.hooks.afterListTools(ctx, baseMessage.ID, &request, result) - return createResponse(baseMessage.ID, *result) - case mcp.MethodToolsCall: - var request mcp.CallToolRequest - var result *mcp.CallToolResult - if s.capabilities.tools == nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.METHOD_NOT_FOUND, - err: fmt.Errorf("tools %w", ErrUnsupported), - } - } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { - err = &requestError{ - id: baseMessage.ID, - code: mcp.INVALID_REQUEST, - err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, - } - } else { - request.Header = headers - s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) - result, err = s.handleToolCall(ctx, baseMessage.ID, request) - } - if err != nil { - s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) - return err.ToJSONRPCError() - } - s.hooks.afterCallTool(ctx, baseMessage.ID, &request, result) - return createResponse(baseMessage.ID, *result) - default: - return createErrorResponse( - baseMessage.ID, - mcp.METHOD_NOT_FOUND, - fmt.Sprintf("Method %s not found", baseMessage.Method), - ) - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/roots.go b/vendor/github.com/mark3labs/mcp-go/server/roots.go deleted file mode 100644 index 29e0b94d1..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/roots.go +++ /dev/null @@ -1,32 +0,0 @@ -package server - -import ( - "context" - "errors" - - "github.com/mark3labs/mcp-go/mcp" -) - -var ( - // 
ErrNoClientSession is returned when there is no active client session in the context - ErrNoClientSession = errors.New("no active client session") - // ErrRootsNotSupported is returned when the session does not support roots - ErrRootsNotSupported = errors.New("session does not support roots") -) - -// RequestRoots sends an list roots request to the client. -// The client must have declared roots capability during initialization. -// The session must implement SessionWithRoots to support this operation. -func (s *MCPServer) RequestRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { - session := ClientSessionFromContext(ctx) - if session == nil { - return nil, ErrNoClientSession - } - - // Check if the session supports roots requests - if rootsSession, ok := session.(SessionWithRoots); ok { - return rootsSession.ListRoots(ctx, request) - } - - return nil, ErrRootsNotSupported -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/sampling.go b/vendor/github.com/mark3labs/mcp-go/server/sampling.go deleted file mode 100644 index 2118db155..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/sampling.go +++ /dev/null @@ -1,61 +0,0 @@ -package server - -import ( - "context" - "fmt" - - "github.com/mark3labs/mcp-go/mcp" -) - -// EnableSampling enables sampling capabilities for the server. -// This allows the server to send sampling requests to clients that support it. -func (s *MCPServer) EnableSampling() { - s.capabilitiesMu.Lock() - defer s.capabilitiesMu.Unlock() - - enabled := true - s.capabilities.sampling = &enabled -} - -// RequestSampling sends a sampling request to the client. -// The client must have declared sampling capability during initialization. 
-func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { - session := ClientSessionFromContext(ctx) - if session == nil { - return nil, fmt.Errorf("no active session") - } - - // Check if the session supports sampling requests - if samplingSession, ok := session.(SessionWithSampling); ok { - return samplingSession.RequestSampling(ctx, request) - } - - // Check for inprocess sampling handler in context - if handler := InProcessSamplingHandlerFromContext(ctx); handler != nil { - return handler.CreateMessage(ctx, request) - } - - return nil, fmt.Errorf("session does not support sampling") -} - -// SessionWithSampling extends ClientSession to support sampling requests. -type SessionWithSampling interface { - ClientSession - RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) -} - -// inProcessSamplingHandlerKey is the context key for storing inprocess sampling handler -type inProcessSamplingHandlerKey struct{} - -// WithInProcessSamplingHandler adds a sampling handler to the context for inprocess clients -func WithInProcessSamplingHandler(ctx context.Context, handler SamplingHandler) context.Context { - return context.WithValue(ctx, inProcessSamplingHandlerKey{}, handler) -} - -// InProcessSamplingHandlerFromContext retrieves the inprocess sampling handler from context -func InProcessSamplingHandlerFromContext(ctx context.Context) SamplingHandler { - if handler, ok := ctx.Value(inProcessSamplingHandlerKey{}).(SamplingHandler); ok { - return handler - } - return nil -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/server.go b/vendor/github.com/mark3labs/mcp-go/server/server.go deleted file mode 100644 index d46fc868d..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/server.go +++ /dev/null @@ -1,1337 +0,0 @@ -// Package server provides MCP (Model Context Protocol) server implementations. 
-package server - -import ( - "cmp" - "context" - "encoding/base64" - "encoding/json" - "fmt" - "maps" - "slices" - "sort" - "sync" - - "github.com/mark3labs/mcp-go/mcp" -) - -// resourceEntry holds both a resource and its handler -type resourceEntry struct { - resource mcp.Resource - handler ResourceHandlerFunc -} - -// resourceTemplateEntry holds both a template and its handler -type resourceTemplateEntry struct { - template mcp.ResourceTemplate - handler ResourceTemplateHandlerFunc -} - -// ServerOption is a function that configures an MCPServer. -type ServerOption func(*MCPServer) - -// ResourceHandlerFunc is a function that returns resource contents. -type ResourceHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) - -// ResourceTemplateHandlerFunc is a function that returns a resource template. -type ResourceTemplateHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) - -// PromptHandlerFunc handles prompt requests with given arguments. -type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) - -// ToolHandlerFunc handles tool calls with given arguments. -type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - -// ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. -type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc - -// ResourceHandlerMiddleware is a middleware function that wraps a ResourceHandlerFunc. -type ResourceHandlerMiddleware func(ResourceHandlerFunc) ResourceHandlerFunc - -// ToolFilterFunc is a function that filters tools based on context, typically using session information. -type ToolFilterFunc func(ctx context.Context, tools []mcp.Tool) []mcp.Tool - -// ServerTool combines a Tool with its ToolHandlerFunc. 
-type ServerTool struct { - Tool mcp.Tool - Handler ToolHandlerFunc -} - -// ServerPrompt combines a Prompt with its handler function. -type ServerPrompt struct { - Prompt mcp.Prompt - Handler PromptHandlerFunc -} - -// ServerResource combines a Resource with its handler function. -type ServerResource struct { - Resource mcp.Resource - Handler ResourceHandlerFunc -} - -// ServerResourceTemplate combines a ResourceTemplate with its handler function. -type ServerResourceTemplate struct { - Template mcp.ResourceTemplate - Handler ResourceTemplateHandlerFunc -} - -// serverKey is the context key for storing the server instance -type serverKey struct{} - -// ServerFromContext retrieves the MCPServer instance from a context -func ServerFromContext(ctx context.Context) *MCPServer { - if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { - return srv - } - return nil -} - -// UnparsableMessageError is attached to the RequestError when json.Unmarshal -// fails on the request. -type UnparsableMessageError struct { - message json.RawMessage - method mcp.MCPMethod - err error -} - -func (e *UnparsableMessageError) Error() string { - return fmt.Sprintf("unparsable %s request: %s", e.method, e.err) -} - -func (e *UnparsableMessageError) Unwrap() error { - return e.err -} - -func (e *UnparsableMessageError) GetMessage() json.RawMessage { - return e.message -} - -func (e *UnparsableMessageError) GetMethod() mcp.MCPMethod { - return e.method -} - -// RequestError is an error that can be converted to a JSON-RPC error. -// Implements Unwrap() to allow inspecting the error chain. 
-type requestError struct { - id any - code int - err error -} - -func (e *requestError) Error() string { - return fmt.Sprintf("request error: %s", e.err) -} - -func (e *requestError) ToJSONRPCError() mcp.JSONRPCError { - return mcp.JSONRPCError{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: mcp.NewRequestId(e.id), - Error: mcp.NewJSONRPCErrorDetails(e.code, e.err.Error(), nil), - } -} - -func (e *requestError) Unwrap() error { - return e.err -} - -// NotificationHandlerFunc handles incoming notifications. -type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCNotification) - -// MCPServer implements a Model Context Protocol server that can handle various types of requests -// including resources, prompts, and tools. -type MCPServer struct { - // Separate mutexes for different resource types - resourcesMu sync.RWMutex - resourceMiddlewareMu sync.RWMutex - promptsMu sync.RWMutex - toolsMu sync.RWMutex - toolMiddlewareMu sync.RWMutex - notificationHandlersMu sync.RWMutex - capabilitiesMu sync.RWMutex - toolFiltersMu sync.RWMutex - - name string - version string - instructions string - resources map[string]resourceEntry - resourceTemplates map[string]resourceTemplateEntry - prompts map[string]mcp.Prompt - promptHandlers map[string]PromptHandlerFunc - tools map[string]ServerTool - toolHandlerMiddlewares []ToolHandlerMiddleware - resourceHandlerMiddlewares []ResourceHandlerMiddleware - toolFilters []ToolFilterFunc - notificationHandlers map[string]NotificationHandlerFunc - capabilities serverCapabilities - paginationLimit *int - sessions sync.Map - hooks *Hooks -} - -// WithPaginationLimit sets the pagination limit for the server. 
-func WithPaginationLimit(limit int) ServerOption { - return func(s *MCPServer) { - s.paginationLimit = &limit - } -} - -// serverCapabilities defines the supported features of the MCP server -type serverCapabilities struct { - tools *toolCapabilities - resources *resourceCapabilities - prompts *promptCapabilities - logging *bool - sampling *bool - elicitation *bool - roots *bool -} - -// resourceCapabilities defines the supported resource-related features -type resourceCapabilities struct { - subscribe bool - listChanged bool -} - -// promptCapabilities defines the supported prompt-related features -type promptCapabilities struct { - listChanged bool -} - -// toolCapabilities defines the supported tool-related features -type toolCapabilities struct { - listChanged bool -} - -// WithResourceCapabilities configures resource-related server capabilities -func WithResourceCapabilities(subscribe, listChanged bool) ServerOption { - return func(s *MCPServer) { - // Always create a non-nil capability object - s.capabilities.resources = &resourceCapabilities{ - subscribe: subscribe, - listChanged: listChanged, - } - } -} - -// WithToolHandlerMiddleware allows adding a middleware for the -// tool handler call chain. -func WithToolHandlerMiddleware( - toolHandlerMiddleware ToolHandlerMiddleware, -) ServerOption { - return func(s *MCPServer) { - s.toolMiddlewareMu.Lock() - s.toolHandlerMiddlewares = append(s.toolHandlerMiddlewares, toolHandlerMiddleware) - s.toolMiddlewareMu.Unlock() - } -} - -// WithResourceHandlerMiddleware allows adding a middleware for the -// resource handler call chain. 
-func WithResourceHandlerMiddleware( - resourceHandlerMiddleware ResourceHandlerMiddleware, -) ServerOption { - return func(s *MCPServer) { - s.resourceMiddlewareMu.Lock() - s.resourceHandlerMiddlewares = append(s.resourceHandlerMiddlewares, resourceHandlerMiddleware) - s.resourceMiddlewareMu.Unlock() - } -} - -// WithResourceRecovery adds a middleware that recovers from panics in resource handlers. -func WithResourceRecovery() ServerOption { - return WithResourceHandlerMiddleware(func(next ResourceHandlerFunc) ResourceHandlerFunc { - return func(ctx context.Context, request mcp.ReadResourceRequest) (result []mcp.ResourceContents, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf( - "panic recovered in %s resource handler: %v", - request.Params.URI, - r, - ) - } - }() - return next(ctx, request) - } - }) -} - -// WithToolFilter adds a filter function that will be applied to tools before they are returned in list_tools -func WithToolFilter( - toolFilter ToolFilterFunc, -) ServerOption { - return func(s *MCPServer) { - s.toolFiltersMu.Lock() - s.toolFilters = append(s.toolFilters, toolFilter) - s.toolFiltersMu.Unlock() - } -} - -// WithRecovery adds a middleware that recovers from panics in tool handlers. -func WithRecovery() ServerOption { - return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc { - return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf( - "panic recovered in %s tool handler: %v", - request.Params.Name, - r, - ) - } - }() - return next(ctx, request) - } - }) -} - -// WithHooks allows adding hooks that will be called before or after -// either [all] requests or before / after specific request methods, or else -// prior to returning an error to the client. 
-func WithHooks(hooks *Hooks) ServerOption { - return func(s *MCPServer) { - s.hooks = hooks - } -} - -// WithPromptCapabilities configures prompt-related server capabilities -func WithPromptCapabilities(listChanged bool) ServerOption { - return func(s *MCPServer) { - // Always create a non-nil capability object - s.capabilities.prompts = &promptCapabilities{ - listChanged: listChanged, - } - } -} - -// WithToolCapabilities configures tool-related server capabilities -func WithToolCapabilities(listChanged bool) ServerOption { - return func(s *MCPServer) { - // Always create a non-nil capability object - s.capabilities.tools = &toolCapabilities{ - listChanged: listChanged, - } - } -} - -// WithLogging enables logging capabilities for the server -func WithLogging() ServerOption { - return func(s *MCPServer) { - s.capabilities.logging = mcp.ToBoolPtr(true) - } -} - -// WithElicitation enables elicitation capabilities for the server -func WithElicitation() ServerOption { - return func(s *MCPServer) { - s.capabilities.elicitation = mcp.ToBoolPtr(true) - } -} - -// WithRoots returns a ServerOption that enables the roots capability on the MCPServer -func WithRoots() ServerOption { - return func(s *MCPServer) { - s.capabilities.roots = mcp.ToBoolPtr(true) - } -} - -// WithInstructions sets the server instructions for the client returned in the initialize response -func WithInstructions(instructions string) ServerOption { - return func(s *MCPServer) { - s.instructions = instructions - } -} - -// NewMCPServer creates a new MCP server instance with the given name, version and options -func NewMCPServer( - name, version string, - opts ...ServerOption, -) *MCPServer { - s := &MCPServer{ - resources: make(map[string]resourceEntry), - resourceTemplates: make(map[string]resourceTemplateEntry), - prompts: make(map[string]mcp.Prompt), - promptHandlers: make(map[string]PromptHandlerFunc), - tools: make(map[string]ServerTool), - toolHandlerMiddlewares: make([]ToolHandlerMiddleware, 
0), - resourceHandlerMiddlewares: make([]ResourceHandlerMiddleware, 0), - name: name, - version: version, - notificationHandlers: make(map[string]NotificationHandlerFunc), - capabilities: serverCapabilities{ - tools: nil, - resources: nil, - prompts: nil, - logging: nil, - }, - } - - for _, opt := range opts { - opt(s) - } - - return s -} - -// GenerateInProcessSessionID generates a unique session ID for inprocess clients -func (s *MCPServer) GenerateInProcessSessionID() string { - return GenerateInProcessSessionID() -} - -// AddResources registers multiple resources at once -func (s *MCPServer) AddResources(resources ...ServerResource) { - s.implicitlyRegisterResourceCapabilities() - - s.resourcesMu.Lock() - for _, entry := range resources { - s.resources[entry.Resource.URI] = resourceEntry{ - resource: entry.Resource, - handler: entry.Handler, - } - } - s.resourcesMu.Unlock() - - // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification - if s.capabilities.resources.listChanged { - // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) - } -} - -// SetResources replaces all existing resources with the provided list -func (s *MCPServer) SetResources(resources ...ServerResource) { - s.resourcesMu.Lock() - s.resources = make(map[string]resourceEntry, len(resources)) - s.resourcesMu.Unlock() - s.AddResources(resources...) 
-} - -// AddResource registers a new resource and its handler -func (s *MCPServer) AddResource( - resource mcp.Resource, - handler ResourceHandlerFunc, -) { - s.AddResources(ServerResource{Resource: resource, Handler: handler}) -} - -// DeleteResources removes resources from the server -func (s *MCPServer) DeleteResources(uris ...string) { - s.resourcesMu.Lock() - var exists bool - for _, uri := range uris { - if _, ok := s.resources[uri]; ok { - delete(s.resources, uri) - exists = true - } - } - s.resourcesMu.Unlock() - - // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource - if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged { - s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) - } -} - -// RemoveResource removes a resource from the server -func (s *MCPServer) RemoveResource(uri string) { - s.resourcesMu.Lock() - _, exists := s.resources[uri] - if exists { - delete(s.resources, uri) - } - s.resourcesMu.Unlock() - - // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource - if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged { - s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) - } -} - -// AddResourceTemplates registers multiple resource templates at once -func (s *MCPServer) AddResourceTemplates(resourceTemplates ...ServerResourceTemplate) { - s.implicitlyRegisterResourceCapabilities() - - s.resourcesMu.Lock() - for _, entry := range resourceTemplates { - s.resourceTemplates[entry.Template.URITemplate.Raw()] = resourceTemplateEntry{ - template: entry.Template, - handler: entry.Handler, - } - } - s.resourcesMu.Unlock() - - // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification - if s.capabilities.resources.listChanged { - // Send notification 
to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) - } -} - -// SetResourceTemplates replaces all existing resource templates with the provided list -func (s *MCPServer) SetResourceTemplates(templates ...ServerResourceTemplate) { - s.resourcesMu.Lock() - s.resourceTemplates = make(map[string]resourceTemplateEntry, len(templates)) - s.resourcesMu.Unlock() - s.AddResourceTemplates(templates...) -} - -// AddResourceTemplate registers a new resource template and its handler -func (s *MCPServer) AddResourceTemplate( - template mcp.ResourceTemplate, - handler ResourceTemplateHandlerFunc, -) { - s.AddResourceTemplates(ServerResourceTemplate{Template: template, Handler: handler}) -} - -// AddPrompts registers multiple prompts at once -func (s *MCPServer) AddPrompts(prompts ...ServerPrompt) { - s.implicitlyRegisterPromptCapabilities() - - s.promptsMu.Lock() - for _, entry := range prompts { - s.prompts[entry.Prompt.Name] = entry.Prompt - s.promptHandlers[entry.Prompt.Name] = entry.Handler - } - s.promptsMu.Unlock() - - // When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification. - if s.capabilities.prompts.listChanged { - // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) - } -} - -// AddPrompt registers a new prompt handler with the given name -func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { - s.AddPrompts(ServerPrompt{Prompt: prompt, Handler: handler}) -} - -// SetPrompts replaces all existing prompts with the provided list -func (s *MCPServer) SetPrompts(prompts ...ServerPrompt) { - s.promptsMu.Lock() - s.prompts = make(map[string]mcp.Prompt, len(prompts)) - s.promptHandlers = make(map[string]PromptHandlerFunc, len(prompts)) - s.promptsMu.Unlock() - s.AddPrompts(prompts...) 
-} - -// DeletePrompts removes prompts from the server -func (s *MCPServer) DeletePrompts(names ...string) { - s.promptsMu.Lock() - var exists bool - for _, name := range names { - if _, ok := s.prompts[name]; ok { - delete(s.prompts, name) - delete(s.promptHandlers, name) - exists = true - } - } - s.promptsMu.Unlock() - - // Send notification to all initialized sessions if listChanged capability is enabled, and we actually remove a prompt - if exists && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged { - // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) - } -} - -// AddTool registers a new tool and its handler -func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { - s.AddTools(ServerTool{Tool: tool, Handler: handler}) -} - -// Register tool capabilities due to a tool being added. Default to -// listChanged: true, but don't change the value if we've already explicitly -// registered tools.listChanged false. 
-func (s *MCPServer) implicitlyRegisterToolCapabilities() { - s.implicitlyRegisterCapabilities( - func() bool { return s.capabilities.tools != nil }, - func() { s.capabilities.tools = &toolCapabilities{listChanged: true} }, - ) -} - -func (s *MCPServer) implicitlyRegisterResourceCapabilities() { - s.implicitlyRegisterCapabilities( - func() bool { return s.capabilities.resources != nil }, - func() { s.capabilities.resources = &resourceCapabilities{} }, - ) -} - -func (s *MCPServer) implicitlyRegisterPromptCapabilities() { - s.implicitlyRegisterCapabilities( - func() bool { return s.capabilities.prompts != nil }, - func() { s.capabilities.prompts = &promptCapabilities{} }, - ) -} - -func (s *MCPServer) implicitlyRegisterCapabilities(check func() bool, register func()) { - s.capabilitiesMu.RLock() - if check() { - s.capabilitiesMu.RUnlock() - return - } - s.capabilitiesMu.RUnlock() - - s.capabilitiesMu.Lock() - if !check() { - register() - } - s.capabilitiesMu.Unlock() -} - -// AddTools registers multiple tools at once -func (s *MCPServer) AddTools(tools ...ServerTool) { - s.implicitlyRegisterToolCapabilities() - - s.toolsMu.Lock() - for _, entry := range tools { - s.tools[entry.Tool.Name] = entry - } - s.toolsMu.Unlock() - - // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. - if s.capabilities.tools.listChanged { - // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) - } -} - -// SetTools replaces all existing tools with the provided list -func (s *MCPServer) SetTools(tools ...ServerTool) { - s.toolsMu.Lock() - s.tools = make(map[string]ServerTool, len(tools)) - s.toolsMu.Unlock() - s.AddTools(tools...) 
-} - -// GetTool retrieves the specified tool -func (s *MCPServer) GetTool(toolName string) *ServerTool { - s.toolsMu.RLock() - defer s.toolsMu.RUnlock() - if tool, ok := s.tools[toolName]; ok { - return &tool - } - return nil -} - -func (s *MCPServer) ListTools() map[string]*ServerTool { - s.toolsMu.RLock() - defer s.toolsMu.RUnlock() - if len(s.tools) == 0 { - return nil - } - // Create a copy to prevent external modification - toolsCopy := make(map[string]*ServerTool, len(s.tools)) - for name, tool := range s.tools { - toolsCopy[name] = &tool - } - return toolsCopy -} - -// DeleteTools removes tools from the server -func (s *MCPServer) DeleteTools(names ...string) { - s.toolsMu.Lock() - var exists bool - for _, name := range names { - if _, ok := s.tools[name]; ok { - delete(s.tools, name) - exists = true - } - } - s.toolsMu.Unlock() - - // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. - if exists && s.capabilities.tools != nil && s.capabilities.tools.listChanged { - // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) - } -} - -// AddNotificationHandler registers a new handler for incoming notifications -func (s *MCPServer) AddNotificationHandler( - method string, - handler NotificationHandlerFunc, -) { - s.notificationHandlersMu.Lock() - defer s.notificationHandlersMu.Unlock() - s.notificationHandlers[method] = handler -} - -func (s *MCPServer) handleInitialize( - ctx context.Context, - _ any, - request mcp.InitializeRequest, -) (*mcp.InitializeResult, *requestError) { - capabilities := mcp.ServerCapabilities{} - - // Only add resource capabilities if they're configured - if s.capabilities.resources != nil { - capabilities.Resources = &struct { - Subscribe bool `json:"subscribe,omitempty"` - ListChanged bool `json:"listChanged,omitempty"` - }{ - Subscribe: s.capabilities.resources.subscribe, - ListChanged: 
s.capabilities.resources.listChanged, - } - } - - // Only add prompt capabilities if they're configured - if s.capabilities.prompts != nil { - capabilities.Prompts = &struct { - ListChanged bool `json:"listChanged,omitempty"` - }{ - ListChanged: s.capabilities.prompts.listChanged, - } - } - - // Only add tool capabilities if they're configured - if s.capabilities.tools != nil { - capabilities.Tools = &struct { - ListChanged bool `json:"listChanged,omitempty"` - }{ - ListChanged: s.capabilities.tools.listChanged, - } - } - - if s.capabilities.logging != nil && *s.capabilities.logging { - capabilities.Logging = &struct{}{} - } - - if s.capabilities.sampling != nil && *s.capabilities.sampling { - capabilities.Sampling = &struct{}{} - } - - if s.capabilities.elicitation != nil && *s.capabilities.elicitation { - capabilities.Elicitation = &struct{}{} - } - - if s.capabilities.roots != nil && *s.capabilities.roots { - capabilities.Roots = &struct{}{} - } - - result := mcp.InitializeResult{ - ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), - ServerInfo: mcp.Implementation{ - Name: s.name, - Version: s.version, - }, - Capabilities: capabilities, - Instructions: s.instructions, - } - - if session := ClientSessionFromContext(ctx); session != nil { - session.Initialize() - - // Store client info if the session supports it - if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok { - sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo) - sessionWithClientInfo.SetClientCapabilities(request.Params.Capabilities) - } - } - - return &result, nil -} - -func (s *MCPServer) protocolVersion(clientVersion string) string { - // For backwards compatibility, if the server does not receive an MCP-Protocol-Version header, - // and has no other way to identify the version - for example, by relying on the protocol version negotiated - // during initialization - the server SHOULD assume protocol version 2025-03-26 - // 
https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header - if len(clientVersion) == 0 { - clientVersion = "2025-03-26" - } - - if slices.Contains(mcp.ValidProtocolVersions, clientVersion) { - return clientVersion - } - - return mcp.LATEST_PROTOCOL_VERSION -} - -func (s *MCPServer) handlePing( - _ context.Context, - _ any, - _ mcp.PingRequest, -) (*mcp.EmptyResult, *requestError) { - return &mcp.EmptyResult{}, nil -} - -func (s *MCPServer) handleSetLevel( - ctx context.Context, - id any, - request mcp.SetLevelRequest, -) (*mcp.EmptyResult, *requestError) { - clientSession := ClientSessionFromContext(ctx) - if clientSession == nil || !clientSession.Initialized() { - return nil, &requestError{ - id: id, - code: mcp.INTERNAL_ERROR, - err: ErrSessionNotInitialized, - } - } - - sessionLogging, ok := clientSession.(SessionWithLogging) - if !ok { - return nil, &requestError{ - id: id, - code: mcp.INTERNAL_ERROR, - err: ErrSessionDoesNotSupportLogging, - } - } - - level := request.Params.Level - // Validate logging level - switch level { - case mcp.LoggingLevelDebug, mcp.LoggingLevelInfo, mcp.LoggingLevelNotice, - mcp.LoggingLevelWarning, mcp.LoggingLevelError, mcp.LoggingLevelCritical, - mcp.LoggingLevelAlert, mcp.LoggingLevelEmergency: - // Valid level - default: - return nil, &requestError{ - id: id, - code: mcp.INVALID_PARAMS, - err: fmt.Errorf("invalid logging level '%s'", level), - } - } - - sessionLogging.SetLogLevel(level) - - return &mcp.EmptyResult{}, nil -} - -func listByPagination[T mcp.Named]( - _ context.Context, - s *MCPServer, - cursor mcp.Cursor, - allElements []T, -) ([]T, mcp.Cursor, error) { - startPos := 0 - if cursor != "" { - c, err := base64.StdEncoding.DecodeString(string(cursor)) - if err != nil { - return nil, "", err - } - cString := string(c) - startPos = sort.Search(len(allElements), func(i int) bool { - return allElements[i].GetName() > cString - }) - } - endPos := len(allElements) - if 
s.paginationLimit != nil { - if len(allElements) > startPos+*s.paginationLimit { - endPos = startPos + *s.paginationLimit - } - } - elementsToReturn := allElements[startPos:endPos] - // set the next cursor - nextCursor := func() mcp.Cursor { - if s.paginationLimit != nil && len(elementsToReturn) >= *s.paginationLimit { - nc := elementsToReturn[len(elementsToReturn)-1].GetName() - toString := base64.StdEncoding.EncodeToString([]byte(nc)) - return mcp.Cursor(toString) - } - return "" - }() - return elementsToReturn, nextCursor, nil -} - -func (s *MCPServer) handleListResources( - ctx context.Context, - id any, - request mcp.ListResourcesRequest, -) (*mcp.ListResourcesResult, *requestError) { - s.resourcesMu.RLock() - resourceMap := make(map[string]mcp.Resource, len(s.resources)) - for uri, entry := range s.resources { - resourceMap[uri] = entry.resource - } - s.resourcesMu.RUnlock() - - // Check if there are session-specific resources - session := ClientSessionFromContext(ctx) - if session != nil { - if sessionWithResources, ok := session.(SessionWithResources); ok { - if sessionResources := sessionWithResources.GetSessionResources(); sessionResources != nil { - // Merge session-specific resources with global resources - for uri, serverResource := range sessionResources { - resourceMap[uri] = serverResource.Resource - } - } - } - } - - // Sort the resources by name - resourcesList := slices.SortedFunc(maps.Values(resourceMap), func(a, b mcp.Resource) int { - return cmp.Compare(a.Name, b.Name) - }) - - // Apply pagination - resourcesToReturn, nextCursor, err := listByPagination( - ctx, - s, - request.Params.Cursor, - resourcesList, - ) - if err != nil { - return nil, &requestError{ - id: id, - code: mcp.INVALID_PARAMS, - err: err, - } - } - result := mcp.ListResourcesResult{ - Resources: resourcesToReturn, - PaginatedResult: mcp.PaginatedResult{ - NextCursor: nextCursor, - }, - } - return &result, nil -} - -func (s *MCPServer) handleListResourceTemplates( - ctx 
context.Context, - id any, - request mcp.ListResourceTemplatesRequest, -) (*mcp.ListResourceTemplatesResult, *requestError) { - // Get global templates - s.resourcesMu.RLock() - templateMap := make(map[string]mcp.ResourceTemplate, len(s.resourceTemplates)) - for uri, entry := range s.resourceTemplates { - templateMap[uri] = entry.template - } - s.resourcesMu.RUnlock() - - // Check if there are session-specific resource templates - session := ClientSessionFromContext(ctx) - if session != nil { - if sessionWithTemplates, ok := session.(SessionWithResourceTemplates); ok { - if sessionTemplates := sessionWithTemplates.GetSessionResourceTemplates(); sessionTemplates != nil { - // Merge session-specific templates with global templates - // Session templates override global ones - for uriTemplate, serverTemplate := range sessionTemplates { - templateMap[uriTemplate] = serverTemplate.Template - } - } - } - } - - // Convert map to slice for sorting and pagination - templates := make([]mcp.ResourceTemplate, 0, len(templateMap)) - for _, template := range templateMap { - templates = append(templates, template) - } - - sort.Slice(templates, func(i, j int) bool { - return templates[i].Name < templates[j].Name - }) - templatesToReturn, nextCursor, err := listByPagination( - ctx, - s, - request.Params.Cursor, - templates, - ) - if err != nil { - return nil, &requestError{ - id: id, - code: mcp.INVALID_PARAMS, - err: err, - } - } - result := mcp.ListResourceTemplatesResult{ - ResourceTemplates: templatesToReturn, - PaginatedResult: mcp.PaginatedResult{ - NextCursor: nextCursor, - }, - } - return &result, nil -} - -func (s *MCPServer) handleReadResource( - ctx context.Context, - id any, - request mcp.ReadResourceRequest, -) (*mcp.ReadResourceResult, *requestError) { - s.resourcesMu.RLock() - - // First check session-specific resources - var handler ResourceHandlerFunc - var ok bool - - session := ClientSessionFromContext(ctx) - if session != nil { - if sessionWithResources, 
typeAssertOk := session.(SessionWithResources); typeAssertOk { - if sessionResources := sessionWithResources.GetSessionResources(); sessionResources != nil { - resource, sessionOk := sessionResources[request.Params.URI] - if sessionOk { - handler = resource.Handler - ok = true - } - } - } - } - - // If not found in session tools, check global tools - if !ok { - globalResource, rok := s.resources[request.Params.URI] - if rok { - handler = globalResource.handler - ok = true - } - } - - // First try direct resource handlers - if ok { - s.resourcesMu.RUnlock() - - finalHandler := handler - s.resourceMiddlewareMu.RLock() - mw := s.resourceHandlerMiddlewares - // Apply middlewares in reverse order - for i := len(mw) - 1; i >= 0; i-- { - finalHandler = mw[i](finalHandler) - } - s.resourceMiddlewareMu.RUnlock() - - contents, err := finalHandler(ctx, request) - if err != nil { - return nil, &requestError{ - id: id, - code: mcp.INTERNAL_ERROR, - err: err, - } - } - return &mcp.ReadResourceResult{Contents: contents}, nil - } - - // If no direct handler found, try matching against templates - var matchedHandler ResourceTemplateHandlerFunc - var matched bool - - // First check session templates if available - if session != nil { - if sessionWithTemplates, ok := session.(SessionWithResourceTemplates); ok { - sessionTemplates := sessionWithTemplates.GetSessionResourceTemplates() - for _, serverTemplate := range sessionTemplates { - if serverTemplate.Template.URITemplate == nil { - continue - } - if matchesTemplate(request.Params.URI, serverTemplate.Template.URITemplate) { - matchedHandler = serverTemplate.Handler - matched = true - matchedVars := serverTemplate.Template.URITemplate.Match(request.Params.URI) - // Convert matched variables to a map - request.Params.Arguments = make(map[string]any, len(matchedVars)) - for name, value := range matchedVars { - request.Params.Arguments[name] = value.V - } - break - } - } - } - } - - // If not found in session templates, check global 
templates - if !matched { - for _, entry := range s.resourceTemplates { - template := entry.template - if template.URITemplate == nil { - continue - } - if matchesTemplate(request.Params.URI, template.URITemplate) { - matchedHandler = entry.handler - matched = true - matchedVars := template.URITemplate.Match(request.Params.URI) - // Convert matched variables to a map - request.Params.Arguments = make(map[string]any, len(matchedVars)) - for name, value := range matchedVars { - request.Params.Arguments[name] = value.V - } - break - } - } - } - s.resourcesMu.RUnlock() - - if matched { - // If a match is found, then we have a final handler and can - // apply middlewares. - s.resourceMiddlewareMu.RLock() - finalHandler := ResourceHandlerFunc(matchedHandler) - mw := s.resourceHandlerMiddlewares - // Apply middlewares in reverse order - for i := len(mw) - 1; i >= 0; i-- { - finalHandler = mw[i](finalHandler) - } - s.resourceMiddlewareMu.RUnlock() - contents, err := finalHandler(ctx, request) - if err != nil { - return nil, &requestError{ - id: id, - code: mcp.INTERNAL_ERROR, - err: err, - } - } - return &mcp.ReadResourceResult{Contents: contents}, nil - } - - return nil, &requestError{ - id: id, - code: mcp.RESOURCE_NOT_FOUND, - err: fmt.Errorf( - "handler not found for resource URI '%s': %w", - request.Params.URI, - ErrResourceNotFound, - ), - } -} - -// matchesTemplate checks if a URI matches a URI template pattern -func matchesTemplate(uri string, template *mcp.URITemplate) bool { - return template.Regexp().MatchString(uri) -} - -func (s *MCPServer) handleListPrompts( - ctx context.Context, - id any, - request mcp.ListPromptsRequest, -) (*mcp.ListPromptsResult, *requestError) { - s.promptsMu.RLock() - prompts := make([]mcp.Prompt, 0, len(s.prompts)) - for _, prompt := range s.prompts { - prompts = append(prompts, prompt) - } - s.promptsMu.RUnlock() - - // sort prompts by name - sort.Slice(prompts, func(i, j int) bool { - return prompts[i].Name < prompts[j].Name - }) - 
promptsToReturn, nextCursor, err := listByPagination( - ctx, - s, - request.Params.Cursor, - prompts, - ) - if err != nil { - return nil, &requestError{ - id: id, - code: mcp.INVALID_PARAMS, - err: err, - } - } - result := mcp.ListPromptsResult{ - Prompts: promptsToReturn, - PaginatedResult: mcp.PaginatedResult{ - NextCursor: nextCursor, - }, - } - return &result, nil -} - -func (s *MCPServer) handleGetPrompt( - ctx context.Context, - id any, - request mcp.GetPromptRequest, -) (*mcp.GetPromptResult, *requestError) { - s.promptsMu.RLock() - handler, ok := s.promptHandlers[request.Params.Name] - s.promptsMu.RUnlock() - - if !ok { - return nil, &requestError{ - id: id, - code: mcp.INVALID_PARAMS, - err: fmt.Errorf("prompt '%s' not found: %w", request.Params.Name, ErrPromptNotFound), - } - } - - result, err := handler(ctx, request) - if err != nil { - return nil, &requestError{ - id: id, - code: mcp.INTERNAL_ERROR, - err: err, - } - } - - return result, nil -} - -func (s *MCPServer) handleListTools( - ctx context.Context, - id any, - request mcp.ListToolsRequest, -) (*mcp.ListToolsResult, *requestError) { - // Get the base tools from the server - s.toolsMu.RLock() - tools := make([]mcp.Tool, 0, len(s.tools)) - - // Get all tool names for consistent ordering - toolNames := make([]string, 0, len(s.tools)) - for name := range s.tools { - toolNames = append(toolNames, name) - } - - // Sort the tool names for consistent ordering - sort.Strings(toolNames) - - // Add tools in sorted order - for _, name := range toolNames { - tools = append(tools, s.tools[name].Tool) - } - s.toolsMu.RUnlock() - - // Check if there are session-specific tools - session := ClientSessionFromContext(ctx) - if session != nil { - if sessionWithTools, ok := session.(SessionWithTools); ok { - if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { - // Override or add session-specific tools - // We need to create a map first to merge the tools properly - toolMap := 
make(map[string]mcp.Tool) - - // Add global tools first - for _, tool := range tools { - toolMap[tool.Name] = tool - } - - // Then override with session-specific tools - for name, serverTool := range sessionTools { - toolMap[name] = serverTool.Tool - } - - // Convert back to slice - tools = make([]mcp.Tool, 0, len(toolMap)) - for _, tool := range toolMap { - tools = append(tools, tool) - } - - // Sort again to maintain consistent ordering - sort.Slice(tools, func(i, j int) bool { - return tools[i].Name < tools[j].Name - }) - } - } - } - - // Apply tool filters if any are defined - s.toolFiltersMu.RLock() - if len(s.toolFilters) > 0 { - for _, filter := range s.toolFilters { - tools = filter(ctx, tools) - } - } - s.toolFiltersMu.RUnlock() - - // Apply pagination - toolsToReturn, nextCursor, err := listByPagination( - ctx, - s, - request.Params.Cursor, - tools, - ) - if err != nil { - return nil, &requestError{ - id: id, - code: mcp.INVALID_PARAMS, - err: err, - } - } - - result := mcp.ListToolsResult{ - Tools: toolsToReturn, - PaginatedResult: mcp.PaginatedResult{ - NextCursor: nextCursor, - }, - } - return &result, nil -} - -func (s *MCPServer) handleToolCall( - ctx context.Context, - id any, - request mcp.CallToolRequest, -) (*mcp.CallToolResult, *requestError) { - // First check session-specific tools - var tool ServerTool - var ok bool - - session := ClientSessionFromContext(ctx) - if session != nil { - if sessionWithTools, typeAssertOk := session.(SessionWithTools); typeAssertOk { - if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { - var sessionOk bool - tool, sessionOk = sessionTools[request.Params.Name] - if sessionOk { - ok = true - } - } - } - } - - // If not found in session tools, check global tools - if !ok { - s.toolsMu.RLock() - tool, ok = s.tools[request.Params.Name] - s.toolsMu.RUnlock() - } - - if !ok { - return nil, &requestError{ - id: id, - code: mcp.INVALID_PARAMS, - err: fmt.Errorf("tool '%s' not found: %w", 
request.Params.Name, ErrToolNotFound), - } - } - - finalHandler := tool.Handler - - s.toolMiddlewareMu.RLock() - mw := s.toolHandlerMiddlewares - - // Apply middlewares in reverse order - for i := len(mw) - 1; i >= 0; i-- { - finalHandler = mw[i](finalHandler) - } - s.toolMiddlewareMu.RUnlock() - - result, err := finalHandler(ctx, request) - if err != nil { - return nil, &requestError{ - id: id, - code: mcp.INTERNAL_ERROR, - err: err, - } - } - - return result, nil -} - -func (s *MCPServer) handleNotification( - ctx context.Context, - notification mcp.JSONRPCNotification, -) mcp.JSONRPCMessage { - s.notificationHandlersMu.RLock() - handler, ok := s.notificationHandlers[notification.Method] - s.notificationHandlersMu.RUnlock() - - if ok { - handler(ctx, notification) - } - return nil -} - -func createResponse(id any, result any) mcp.JSONRPCMessage { - return mcp.NewJSONRPCResultResponse(mcp.NewRequestId(id), result) -} - -func createErrorResponse( - id any, - code int, - message string, -) mcp.JSONRPCMessage { - return mcp.JSONRPCError{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: mcp.NewRequestId(id), - Error: mcp.NewJSONRPCErrorDetails(code, message, nil), - } -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/session.go b/vendor/github.com/mark3labs/mcp-go/server/session.go deleted file mode 100644 index 48fd52d75..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/session.go +++ /dev/null @@ -1,770 +0,0 @@ -package server - -import ( - "context" - "fmt" - "net/url" - - "github.com/mark3labs/mcp-go/mcp" -) - -// ClientSession represents an active session that can be used by MCPServer to interact with client. -type ClientSession interface { - // Initialize marks session as fully initialized and ready for notifications - Initialize() - // Initialized returns if session is ready to accept notifications - Initialized() bool - // NotificationChannel provides a channel suitable for sending notifications to client. 
- NotificationChannel() chan<- mcp.JSONRPCNotification - // SessionID is a unique identifier used to track user session. - SessionID() string -} - -// SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level -type SessionWithLogging interface { - ClientSession - // SetLogLevel sets the minimum log level - SetLogLevel(level mcp.LoggingLevel) - // GetLogLevel retrieves the minimum log level - GetLogLevel() mcp.LoggingLevel -} - -// SessionWithTools is an extension of ClientSession that can store session-specific tool data -type SessionWithTools interface { - ClientSession - // GetSessionTools returns the tools specific to this session, if any - // This method must be thread-safe for concurrent access - GetSessionTools() map[string]ServerTool - // SetSessionTools sets tools specific to this session - // This method must be thread-safe for concurrent access - SetSessionTools(tools map[string]ServerTool) -} - -// SessionWithResources is an extension of ClientSession that can store session-specific resource data -type SessionWithResources interface { - ClientSession - // GetSessionResources returns the resources specific to this session, if any - // This method must be thread-safe for concurrent access - GetSessionResources() map[string]ServerResource - // SetSessionResources sets resources specific to this session - // This method must be thread-safe for concurrent access - SetSessionResources(resources map[string]ServerResource) -} - -// SessionWithResourceTemplates is an extension of ClientSession that can store session-specific resource template data -type SessionWithResourceTemplates interface { - ClientSession - // GetSessionResourceTemplates returns the resource templates specific to this session, if any - // This method must be thread-safe for concurrent access - GetSessionResourceTemplates() map[string]ServerResourceTemplate - // SetSessionResourceTemplates sets resource templates specific to this session - 
// This method must be thread-safe for concurrent access - SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) -} - -// SessionWithClientInfo is an extension of ClientSession that can store client info -type SessionWithClientInfo interface { - ClientSession - // GetClientInfo returns the client information for this session - GetClientInfo() mcp.Implementation - // SetClientInfo sets the client information for this session - SetClientInfo(clientInfo mcp.Implementation) - // GetClientCapabilities returns the client capabilities for this session - GetClientCapabilities() mcp.ClientCapabilities - // SetClientCapabilities sets the client capabilities for this session - SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) -} - -// SessionWithElicitation is an extension of ClientSession that can send elicitation requests -type SessionWithElicitation interface { - ClientSession - // RequestElicitation sends an elicitation request to the client and waits for response - RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) -} - -// SessionWithRoots is an extension of ClientSession that can send list roots requests -type SessionWithRoots interface { - ClientSession - // ListRoots sends an list roots request to the client and waits for response - ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) -} - -// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations -type SessionWithStreamableHTTPConfig interface { - ClientSession - // UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server - // sends notifications to the client - // - // The protocol specification: - // - If the server response contains any JSON-RPC notifications, it MUST either: - // - Return Content-Type: text/event-stream to initiate an SSE stream, OR - // - Return Content-Type: 
application/json for a single JSON object - // - The client MUST support both response types. - // - // Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server - UpgradeToSSEWhenReceiveNotification() -} - -// clientSessionKey is the context key for storing current client notification channel. -type clientSessionKey struct{} - -// ClientSessionFromContext retrieves current client notification context from context. -func ClientSessionFromContext(ctx context.Context) ClientSession { - if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { - return session - } - return nil -} - -// WithContext sets the current client session and returns the provided context -func (s *MCPServer) WithContext( - ctx context.Context, - session ClientSession, -) context.Context { - return context.WithValue(ctx, clientSessionKey{}, session) -} - -// RegisterSession saves session that should be notified in case if some server attributes changed. 
-func (s *MCPServer) RegisterSession( - ctx context.Context, - session ClientSession, -) error { - sessionID := session.SessionID() - if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { - return ErrSessionExists - } - s.hooks.RegisterSession(ctx, session) - return nil -} - -func (s *MCPServer) buildLogNotification(notification mcp.LoggingMessageNotification) mcp.JSONRPCNotification { - return mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: notification.Method, - Params: mcp.NotificationParams{ - AdditionalFields: map[string]any{ - "level": notification.Params.Level, - "logger": notification.Params.Logger, - "data": notification.Params.Data, - }, - }, - }, - } -} - -func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp.LoggingMessageNotification) error { - session := ClientSessionFromContext(ctx) - if session == nil || !session.Initialized() { - return ErrNotificationNotInitialized - } - sessionLogging, ok := session.(SessionWithLogging) - if !ok { - return ErrSessionDoesNotSupportLogging - } - if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) { - return nil - } - return s.sendNotificationCore(ctx, session, s.buildLogNotification(notification)) -} - -func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) { - s.sessions.Range(func(k, v any) bool { - if session, ok := v.(ClientSession); ok && session.Initialized() { - if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { - sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() - } - select { - case session.NotificationChannel() <- notification: - // Successfully sent notification - default: - // Channel is blocked, if there's an error hook, use it - if s.hooks != nil && len(s.hooks.OnError) > 0 { - err := ErrNotificationChannelBlocked - // Copy hooks pointer to local variable to avoid race condition - hooks := s.hooks 
- go func(sessionID string, hooks *Hooks) { - ctx := context.Background() - // Use the error hook to report the blocked channel - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": notification.Method, - "sessionID": sessionID, - }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) - }(session.SessionID(), hooks) - } - } - } - return true - }) -} - -func (s *MCPServer) sendNotificationToSpecificClient(session ClientSession, notification mcp.JSONRPCNotification) error { - // upgrades the client-server communication to SSE stream when the server sends notifications to the client - if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { - sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() - } - select { - case session.NotificationChannel() <- notification: - return nil - default: - // Channel is blocked, if there's an error hook, use it - if s.hooks != nil && len(s.hooks.OnError) > 0 { - err := ErrNotificationChannelBlocked - ctx := context.Background() - // Copy hooks pointer to local variable to avoid race condition - hooks := s.hooks - go func(sID string, hooks *Hooks) { - // Use the error hook to report the blocked channel - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": notification.Method, - "sessionID": sID, - }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) - }(session.SessionID(), hooks) - } - return ErrNotificationChannelBlocked - } -} - -func (s *MCPServer) SendLogMessageToSpecificClient(sessionID string, notification mcp.LoggingMessageNotification) error { - sessionValue, ok := s.sessions.Load(sessionID) - if !ok { - return ErrSessionNotFound - } - session, ok := sessionValue.(ClientSession) - if !ok || !session.Initialized() { - return ErrSessionNotInitialized - } - sessionLogging, ok := session.(SessionWithLogging) - if !ok { - return ErrSessionDoesNotSupportLogging - } - if 
!notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) { - return nil - } - return s.sendNotificationToSpecificClient(session, s.buildLogNotification(notification)) -} - -// UnregisterSession removes from storage session that is shut down. -func (s *MCPServer) UnregisterSession( - ctx context.Context, - sessionID string, -) { - sessionValue, ok := s.sessions.LoadAndDelete(sessionID) - if !ok { - return - } - if session, ok := sessionValue.(ClientSession); ok { - s.hooks.UnregisterSession(ctx, session) - } -} - -// SendNotificationToAllClients sends a notification to all the currently active clients. -func (s *MCPServer) SendNotificationToAllClients( - method string, - params map[string]any, -) { - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: method, - Params: mcp.NotificationParams{ - AdditionalFields: params, - }, - }, - } - s.sendNotificationToAllClients(notification) -} - -// SendNotificationToClient sends a notification to the current client -func (s *MCPServer) sendNotificationCore( - ctx context.Context, - session ClientSession, - notification mcp.JSONRPCNotification, -) error { - // upgrades the client-server communication to SSE stream when the server sends notifications to the client - if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { - sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() - } - select { - case session.NotificationChannel() <- notification: - return nil - default: - // Channel is blocked, if there's an error hook, use it - if s.hooks != nil && len(s.hooks.OnError) > 0 { - method := notification.Method - err := ErrNotificationChannelBlocked - // Copy hooks pointer to local variable to avoid race condition - hooks := s.hooks - go func(sessionID string, hooks *Hooks) { - // Use the error hook to report the blocked channel - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": method, - 
"sessionID": sessionID, - }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) - }(session.SessionID(), hooks) - } - return ErrNotificationChannelBlocked - } -} - -// SendNotificationToClient sends a notification to the current client -func (s *MCPServer) SendNotificationToClient( - ctx context.Context, - method string, - params map[string]any, -) error { - session := ClientSessionFromContext(ctx) - if session == nil || !session.Initialized() { - return ErrNotificationNotInitialized - } - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: method, - Params: mcp.NotificationParams{ - AdditionalFields: params, - }, - }, - } - return s.sendNotificationCore(ctx, session, notification) -} - -// SendNotificationToSpecificClient sends a notification to a specific client by session ID -func (s *MCPServer) SendNotificationToSpecificClient( - sessionID string, - method string, - params map[string]any, -) error { - sessionValue, ok := s.sessions.Load(sessionID) - if !ok { - return ErrSessionNotFound - } - session, ok := sessionValue.(ClientSession) - if !ok || !session.Initialized() { - return ErrSessionNotInitialized - } - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: method, - Params: mcp.NotificationParams{ - AdditionalFields: params, - }, - }, - } - return s.sendNotificationToSpecificClient(session, notification) -} - -// AddSessionTool adds a tool for a specific session -func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error { - return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler}) -} - -// AddSessionTools adds tools for a specific session -func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error { - sessionValue, ok := s.sessions.Load(sessionID) - if !ok { - return ErrSessionNotFound - } - - session, ok := 
sessionValue.(SessionWithTools) - if !ok { - return ErrSessionDoesNotSupportTools - } - - s.implicitlyRegisterToolCapabilities() - - // Get existing tools (this should return a thread-safe copy) - sessionTools := session.GetSessionTools() - - // Create a new map to avoid concurrent modification issues - newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools)) - - // Copy existing tools - for k, v := range sessionTools { - newSessionTools[k] = v - } - - // Add new tools - for _, tool := range tools { - newSessionTools[tool.Tool.Name] = tool - } - - // Set the tools (this should be thread-safe) - session.SetSessionTools(newSessionTools) - - // It only makes sense to send tool notifications to initialized sessions -- - // if we're not initialized yet the client can't possibly have sent their - // initial tools/list message. - // - // For initialized sessions, honor tools.listChanged, which is specifically - // about whether notifications will be sent or not. - // see - if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged { - // Send notification only to this session - if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { - // Log the error but don't fail the operation - // The tools were successfully added, but notification failed - if s.hooks != nil && len(s.hooks.OnError) > 0 { - hooks := s.hooks - go func(sID string, hooks *Hooks) { - ctx := context.Background() - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": "notifications/tools/list_changed", - "sessionID": sID, - }, fmt.Errorf("failed to send notification after adding tools: %w", err)) - }(sessionID, hooks) - } - } - } - - return nil -} - -// DeleteSessionTools removes tools from a specific session -func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error { - sessionValue, ok := s.sessions.Load(sessionID) - if !ok { - return ErrSessionNotFound - } - - 
session, ok := sessionValue.(SessionWithTools) - if !ok { - return ErrSessionDoesNotSupportTools - } - - // Get existing tools (this should return a thread-safe copy) - sessionTools := session.GetSessionTools() - if sessionTools == nil { - return nil - } - - // Create a new map to avoid concurrent modification issues - newSessionTools := make(map[string]ServerTool, len(sessionTools)) - - // Copy existing tools except those being deleted - for k, v := range sessionTools { - newSessionTools[k] = v - } - - // Remove specified tools - for _, name := range names { - delete(newSessionTools, name) - } - - // Set the tools (this should be thread-safe) - session.SetSessionTools(newSessionTools) - - // It only makes sense to send tool notifications to initialized sessions -- - // if we're not initialized yet the client can't possibly have sent their - // initial tools/list message. - // - // For initialized sessions, honor tools.listChanged, which is specifically - // about whether notifications will be sent or not. 
- // see - if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged { - // Send notification only to this session - if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { - // Log the error but don't fail the operation - // The tools were successfully deleted, but notification failed - if s.hooks != nil && len(s.hooks.OnError) > 0 { - hooks := s.hooks - go func(sID string, hooks *Hooks) { - ctx := context.Background() - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": "notifications/tools/list_changed", - "sessionID": sID, - }, fmt.Errorf("failed to send notification after deleting tools: %w", err)) - }(sessionID, hooks) - } - } - } - - return nil -} - -// AddSessionResource adds a resource for a specific session -func (s *MCPServer) AddSessionResource(sessionID string, resource mcp.Resource, handler ResourceHandlerFunc) error { - return s.AddSessionResources(sessionID, ServerResource{Resource: resource, Handler: handler}) -} - -// AddSessionResources adds resources for a specific session -func (s *MCPServer) AddSessionResources(sessionID string, resources ...ServerResource) error { - sessionValue, ok := s.sessions.Load(sessionID) - if !ok { - return ErrSessionNotFound - } - - session, ok := sessionValue.(SessionWithResources) - if !ok { - return ErrSessionDoesNotSupportResources - } - - // For session resources, we want listChanged enabled by default - s.implicitlyRegisterCapabilities( - func() bool { return s.capabilities.resources != nil }, - func() { s.capabilities.resources = &resourceCapabilities{listChanged: true} }, - ) - - // Get existing resources (this should return a thread-safe copy) - sessionResources := session.GetSessionResources() - - // Create a new map to avoid concurrent modification issues - newSessionResources := make(map[string]ServerResource, len(sessionResources)+len(resources)) - - // Copy existing resources - for k, v := range 
sessionResources { - newSessionResources[k] = v - } - - // Add new resources with validation - for _, resource := range resources { - // Validate that URI is non-empty - if resource.Resource.URI == "" { - return fmt.Errorf("resource URI cannot be empty") - } - - // Validate that URI conforms to RFC 3986 - if _, err := url.ParseRequestURI(resource.Resource.URI); err != nil { - return fmt.Errorf("invalid resource URI: %w", err) - } - - newSessionResources[resource.Resource.URI] = resource - } - - // Set the resources (this should be thread-safe) - session.SetSessionResources(newSessionResources) - - // It only makes sense to send resource notifications to initialized sessions -- - // if we're not initialized yet the client can't possibly have sent their - // initial resources/list message. - // - // For initialized sessions, honor resources.listChanged, which is specifically - // about whether notifications will be sent or not. - // see - if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged { - // Send notification only to this session - if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil { - // Log the error but don't fail the operation - // The resources were successfully added, but notification failed - if s.hooks != nil && len(s.hooks.OnError) > 0 { - hooks := s.hooks - go func(sID string, hooks *Hooks) { - ctx := context.Background() - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": "notifications/resources/list_changed", - "sessionID": sID, - }, fmt.Errorf("failed to send notification after adding resources: %w", err)) - }(sessionID, hooks) - } - } - } - - return nil -} - -// DeleteSessionResources removes resources from a specific session -func (s *MCPServer) DeleteSessionResources(sessionID string, uris ...string) error { - sessionValue, ok := s.sessions.Load(sessionID) - if !ok { - return ErrSessionNotFound - } - - session, ok := 
sessionValue.(SessionWithResources) - if !ok { - return ErrSessionDoesNotSupportResources - } - - // Get existing resources (this should return a thread-safe copy) - sessionResources := session.GetSessionResources() - if sessionResources == nil { - return nil - } - - // Create a new map to avoid concurrent modification issues - newSessionResources := make(map[string]ServerResource, len(sessionResources)) - - // Copy existing resources except those being deleted - for k, v := range sessionResources { - newSessionResources[k] = v - } - - // Remove specified resources and track if anything was actually deleted - actuallyDeleted := false - for _, uri := range uris { - if _, exists := newSessionResources[uri]; exists { - delete(newSessionResources, uri) - actuallyDeleted = true - } - } - - // Skip no-op write if nothing was actually deleted - if !actuallyDeleted { - return nil - } - - // Set the resources (this should be thread-safe) - session.SetSessionResources(newSessionResources) - - // It only makes sense to send resource notifications to initialized sessions -- - // if we're not initialized yet the client can't possibly have sent their - // initial resources/list message. - // - // For initialized sessions, honor resources.listChanged, which is specifically - // about whether notifications will be sent or not. 
- // see - // Only send notification if something was actually deleted - if actuallyDeleted && session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged { - // Send notification only to this session - if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil { - // Log the error but don't fail the operation - // The resources were successfully deleted, but notification failed - if s.hooks != nil && len(s.hooks.OnError) > 0 { - hooks := s.hooks - go func(sID string, hooks *Hooks) { - ctx := context.Background() - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": "notifications/resources/list_changed", - "sessionID": sID, - }, fmt.Errorf("failed to send notification after deleting resources: %w", err)) - }(sessionID, hooks) - } - } - } - - return nil -} - -// AddSessionResourceTemplate adds a resource template for a specific session -func (s *MCPServer) AddSessionResourceTemplate(sessionID string, template mcp.ResourceTemplate, handler ResourceTemplateHandlerFunc) error { - return s.AddSessionResourceTemplates(sessionID, ServerResourceTemplate{ - Template: template, - Handler: handler, - }) -} - -// AddSessionResourceTemplates adds resource templates for a specific session -func (s *MCPServer) AddSessionResourceTemplates(sessionID string, templates ...ServerResourceTemplate) error { - sessionValue, ok := s.sessions.Load(sessionID) - if !ok { - return ErrSessionNotFound - } - - session, ok := sessionValue.(SessionWithResourceTemplates) - if !ok { - return ErrSessionDoesNotSupportResourceTemplates - } - - // For session resource templates, enable listChanged by default - // This is the same behavior as session resources - s.implicitlyRegisterCapabilities( - func() bool { return s.capabilities.resources != nil }, - func() { s.capabilities.resources = &resourceCapabilities{listChanged: true} }, - ) - - // Get existing templates (this returns a thread-safe copy) 
- sessionTemplates := session.GetSessionResourceTemplates() - - // Create a new map to avoid modifying the returned copy - newTemplates := make(map[string]ServerResourceTemplate, len(sessionTemplates)+len(templates)) - - // Copy existing templates - for k, v := range sessionTemplates { - newTemplates[k] = v - } - - // Validate and add new templates - for _, t := range templates { - if t.Template.URITemplate == nil { - return fmt.Errorf("resource template URITemplate cannot be nil") - } - raw := t.Template.URITemplate.Raw() - if raw == "" { - return fmt.Errorf("resource template URITemplate cannot be empty") - } - if t.Template.Name == "" { - return fmt.Errorf("resource template name cannot be empty") - } - newTemplates[raw] = t - } - - // Set the new templates (this method must handle thread-safety) - session.SetSessionResourceTemplates(newTemplates) - - // Send notification if the session is initialized and listChanged is enabled - if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged { - if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil { - // Log the error but don't fail the operation - if s.hooks != nil && len(s.hooks.OnError) > 0 { - hooks := s.hooks - go func(sID string, hooks *Hooks) { - ctx := context.Background() - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": "notifications/resources/list_changed", - "sessionID": sID, - }, fmt.Errorf("failed to send notification after adding resource templates: %w", err)) - }(sessionID, hooks) - } - } - } - - return nil -} - -// DeleteSessionResourceTemplates removes resource templates from a specific session -func (s *MCPServer) DeleteSessionResourceTemplates(sessionID string, uriTemplates ...string) error { - sessionValue, ok := s.sessions.Load(sessionID) - if !ok { - return ErrSessionNotFound - } - - session, ok := sessionValue.(SessionWithResourceTemplates) - if !ok { - return 
ErrSessionDoesNotSupportResourceTemplates - } - - // Get existing templates (this returns a thread-safe copy) - sessionTemplates := session.GetSessionResourceTemplates() - - // Track if any were actually deleted - deletedAny := false - - // Create a new map without the deleted templates - newTemplates := make(map[string]ServerResourceTemplate, len(sessionTemplates)) - for k, v := range sessionTemplates { - newTemplates[k] = v - } - - // Delete specified templates - for _, uriTemplate := range uriTemplates { - if _, exists := newTemplates[uriTemplate]; exists { - delete(newTemplates, uriTemplate) - deletedAny = true - } - } - - // Only update if something was actually deleted - if deletedAny { - // Set the new templates (this method must handle thread-safety) - session.SetSessionResourceTemplates(newTemplates) - - // Send notification if the session is initialized and listChanged is enabled - if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged { - if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil { - // Log the error but don't fail the operation - if s.hooks != nil && len(s.hooks.OnError) > 0 { - hooks := s.hooks - go func(sID string, hooks *Hooks) { - ctx := context.Background() - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": "notifications/resources/list_changed", - "sessionID": sID, - }, fmt.Errorf("failed to send notification after deleting resource templates: %w", err)) - }(sessionID, hooks) - } - } - } - } - - return nil -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/sse.go b/vendor/github.com/mark3labs/mcp-go/server/sse.go deleted file mode 100644 index 97c765cc7..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/sse.go +++ /dev/null @@ -1,797 +0,0 @@ -package server - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net/http" - "net/http/httptest" - "net/url" - "path" - "strings" - "sync" - 
"sync/atomic" - "time" - - "github.com/google/uuid" - - "github.com/mark3labs/mcp-go/mcp" -) - -// sseSession represents an active SSE connection. -type sseSession struct { - done chan struct{} - eventQueue chan string // Channel for queuing events - sessionID string - requestID atomic.Int64 - notificationChannel chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - tools sync.Map // stores session-specific tools - resources sync.Map // stores session-specific resources - resourceTemplates sync.Map // stores session-specific resource templates - clientInfo atomic.Value // stores session-specific client info - clientCapabilities atomic.Value // stores session-specific client capabilities -} - -// SSEContextFunc is a function that takes an existing context and the current -// request and returns a potentially modified context based on the request -// content. This can be used to inject context values from headers, for example. -type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context - -// DynamicBasePathFunc allows the user to provide a function to generate the -// base path for a given request and sessionID. This is useful for cases where -// the base path is not known at the time of SSE server creation, such as when -// using a reverse proxy or when the base path is dynamically generated. The -// function should return the base path (e.g., "/mcp/tenant123"). 
-type DynamicBasePathFunc func(r *http.Request, sessionID string) string - -func (s *sseSession) SessionID() string { - return s.sessionID -} - -func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification { - return s.notificationChannel -} - -func (s *sseSession) Initialize() { - // set default logging level - s.loggingLevel.Store(mcp.LoggingLevelError) - s.initialized.Store(true) -} - -func (s *sseSession) Initialized() bool { - return s.initialized.Load() -} - -func (s *sseSession) SetLogLevel(level mcp.LoggingLevel) { - s.loggingLevel.Store(level) -} - -func (s *sseSession) GetLogLevel() mcp.LoggingLevel { - level := s.loggingLevel.Load() - if level == nil { - return mcp.LoggingLevelError - } - return level.(mcp.LoggingLevel) -} - -func (s *sseSession) GetSessionResources() map[string]ServerResource { - resources := make(map[string]ServerResource) - s.resources.Range(func(key, value any) bool { - if resource, ok := value.(ServerResource); ok { - resources[key.(string)] = resource - } - return true - }) - return resources -} - -func (s *sseSession) SetSessionResources(resources map[string]ServerResource) { - // Clear existing resources - s.resources.Clear() - - // Set new resources - for name, resource := range resources { - s.resources.Store(name, resource) - } -} - -func (s *sseSession) GetSessionResourceTemplates() map[string]ServerResourceTemplate { - templates := make(map[string]ServerResourceTemplate) - s.resourceTemplates.Range(func(key, value any) bool { - if template, ok := value.(ServerResourceTemplate); ok { - templates[key.(string)] = template - } - return true - }) - return templates -} - -func (s *sseSession) SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) { - // Clear existing templates - s.resourceTemplates.Clear() - - // Set new templates - for uriTemplate, template := range templates { - s.resourceTemplates.Store(uriTemplate, template) - } -} - -func (s *sseSession) GetSessionTools() map[string]ServerTool 
{ - tools := make(map[string]ServerTool) - s.tools.Range(func(key, value any) bool { - if tool, ok := value.(ServerTool); ok { - tools[key.(string)] = tool - } - return true - }) - return tools -} - -func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { - // Clear existing tools - s.tools.Clear() - - // Set new tools - for name, tool := range tools { - s.tools.Store(name, tool) - } -} - -func (s *sseSession) GetClientInfo() mcp.Implementation { - if value := s.clientInfo.Load(); value != nil { - if clientInfo, ok := value.(mcp.Implementation); ok { - return clientInfo - } - } - return mcp.Implementation{} -} - -func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { - s.clientInfo.Store(clientInfo) -} - -func (s *sseSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { - s.clientCapabilities.Store(clientCapabilities) -} - -func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities { - if value := s.clientCapabilities.Load(); value != nil { - if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { - return clientCapabilities - } - } - return mcp.ClientCapabilities{} -} - -var ( - _ ClientSession = (*sseSession)(nil) - _ SessionWithTools = (*sseSession)(nil) - _ SessionWithResources = (*sseSession)(nil) - _ SessionWithResourceTemplates = (*sseSession)(nil) - _ SessionWithLogging = (*sseSession)(nil) - _ SessionWithClientInfo = (*sseSession)(nil) -) - -// SSEServer implements a Server-Sent Events (SSE) based MCP server. -// It provides real-time communication capabilities over HTTP using the SSE protocol. 
-type SSEServer struct { - server *MCPServer - baseURL string - basePath string - appendQueryToMessageEndpoint bool - useFullURLForMessageEndpoint bool - messageEndpoint string - sseEndpoint string - sessions sync.Map - srv *http.Server - contextFunc SSEContextFunc - dynamicBasePathFunc DynamicBasePathFunc - - keepAlive bool - keepAliveInterval time.Duration - - mu sync.RWMutex -} - -// SSEOption defines a function type for configuring SSEServer -type SSEOption func(*SSEServer) - -// WithBaseURL sets the base URL for the SSE server -func WithBaseURL(baseURL string) SSEOption { - return func(s *SSEServer) { - if baseURL != "" { - u, err := url.Parse(baseURL) - if err != nil { - return - } - if u.Scheme != "http" && u.Scheme != "https" { - return - } - // Check if the host is empty or only contains a port - if u.Host == "" || strings.HasPrefix(u.Host, ":") { - return - } - if len(u.Query()) > 0 { - return - } - } - s.baseURL = strings.TrimSuffix(baseURL, "/") - } -} - -// WithStaticBasePath adds a new option for setting a static base path -func WithStaticBasePath(basePath string) SSEOption { - return func(s *SSEServer) { - s.basePath = normalizeURLPath(basePath) - } -} - -// WithBasePath adds a new option for setting a static base path. -// -// Deprecated: Use WithStaticBasePath instead. This will be removed in a future version. -// -//go:deprecated -func WithBasePath(basePath string) SSEOption { - return WithStaticBasePath(basePath) -} - -// WithDynamicBasePath accepts a function for generating the base path. This is -// useful for cases where the base path is not known at the time of SSE server -// creation, such as when using a reverse proxy or when the server is mounted -// at a dynamic path. 
-func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption { - return func(s *SSEServer) { - if fn != nil { - s.dynamicBasePathFunc = func(r *http.Request, sid string) string { - bp := fn(r, sid) - return normalizeURLPath(bp) - } - } - } -} - -// WithMessageEndpoint sets the message endpoint path -func WithMessageEndpoint(endpoint string) SSEOption { - return func(s *SSEServer) { - s.messageEndpoint = endpoint - } -} - -// WithAppendQueryToMessageEndpoint configures the SSE server to append the original request's -// query parameters to the message endpoint URL that is sent to clients during the SSE connection -// initialization. This is useful when you need to preserve query parameters from the initial -// SSE connection request and carry them over to subsequent message requests, maintaining -// context or authentication details across the communication channel. -func WithAppendQueryToMessageEndpoint() SSEOption { - return func(s *SSEServer) { - s.appendQueryToMessageEndpoint = true - } -} - -// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL) -// or just the path portion for the message endpoint. Set to false when clients will concatenate -// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message". -func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption { - return func(s *SSEServer) { - s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint - } -} - -// WithSSEEndpoint sets the SSE endpoint path -func WithSSEEndpoint(endpoint string) SSEOption { - return func(s *SSEServer) { - s.sseEndpoint = endpoint - } -} - -// WithHTTPServer sets the HTTP server instance. -// NOTE: When providing a custom HTTP server, you must handle routing yourself -// If routing is not set up, the server will start but won't handle any MCP requests. 
-func WithHTTPServer(srv *http.Server) SSEOption { - return func(s *SSEServer) { - s.srv = srv - } -} - -func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption { - return func(s *SSEServer) { - s.keepAlive = true - s.keepAliveInterval = keepAliveInterval - } -} - -func WithKeepAlive(keepAlive bool) SSEOption { - return func(s *SSEServer) { - s.keepAlive = keepAlive - } -} - -// WithSSEContextFunc sets a function that will be called to customise the context -// to the server using the incoming request. -func WithSSEContextFunc(fn SSEContextFunc) SSEOption { - return func(s *SSEServer) { - s.contextFunc = fn - } -} - -// NewSSEServer creates a new SSE server instance with the given MCP server and options. -func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { - s := &SSEServer{ - server: server, - sseEndpoint: "/sse", - messageEndpoint: "/message", - useFullURLForMessageEndpoint: true, - keepAlive: false, - keepAliveInterval: 10 * time.Second, - } - - // Apply all options - for _, opt := range opts { - opt(s) - } - - return s -} - -// NewTestServer creates a test server for testing purposes -func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { - sseServer := NewSSEServer(server, opts...) - - testServer := httptest.NewServer(sseServer) - sseServer.baseURL = testServer.URL - return testServer -} - -// Start begins serving SSE connections on the specified address. -// It sets up HTTP handlers for SSE and message endpoints. 
-func (s *SSEServer) Start(addr string) error { - s.mu.Lock() - if s.srv == nil { - s.srv = &http.Server{ - Addr: addr, - Handler: s, - } - } else { - if s.srv.Addr == "" { - s.srv.Addr = addr - } else if s.srv.Addr != addr { - return fmt.Errorf("conflicting listen address: WithHTTPServer(%q) vs Start(%q)", s.srv.Addr, addr) - } - } - srv := s.srv - s.mu.Unlock() - - return srv.ListenAndServe() -} - -// Shutdown gracefully stops the SSE server, closing all active sessions -// and shutting down the HTTP server. -func (s *SSEServer) Shutdown(ctx context.Context) error { - s.mu.RLock() - srv := s.srv - s.mu.RUnlock() - - if srv != nil { - s.sessions.Range(func(key, value any) bool { - if session, ok := value.(*sseSession); ok { - close(session.done) - } - s.sessions.Delete(key) - return true - }) - - return srv.Shutdown(ctx) - } - return nil -} - -// handleSSE handles incoming SSE connection requests. -// It sets up appropriate headers and creates a new session for the client. -func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "Streaming unsupported", http.StatusInternalServerError) - return - } - - sessionID := uuid.New().String() - session := &sseSession{ - done: make(chan struct{}), - eventQueue: make(chan string, 100), // Buffer for events - sessionID: sessionID, - notificationChannel: make(chan mcp.JSONRPCNotification, 100), - } - - s.sessions.Store(sessionID, session) - defer s.sessions.Delete(sessionID) - - if err := s.server.RegisterSession(r.Context(), session); err != nil { - http.Error( - w, - fmt.Sprintf("Session registration failed: %v", err), - 
http.StatusInternalServerError, - ) - return - } - defer s.server.UnregisterSession(r.Context(), sessionID) - - // Start notification handler for this session - go func() { - for { - select { - case notification := <-session.notificationChannel: - eventData, err := json.Marshal(notification) - if err == nil { - select { - case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): - // Event queued successfully - case <-session.done: - return - } - } - case <-session.done: - return - case <-r.Context().Done(): - return - } - } - }() - - // Start keep alive : ping - if s.keepAlive { - go func() { - ticker := time.NewTicker(s.keepAliveInterval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - message := mcp.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcp.NewRequestId(session.requestID.Add(1)), - Request: mcp.Request{ - Method: "ping", - }, - } - messageBytes, _ := json.Marshal(message) - pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes) - select { - case session.eventQueue <- pingMsg: - // Message sent successfully - case <-session.done: - return - } - case <-session.done: - return - case <-r.Context().Done(): - return - } - } - }() - } - - // Send the initial endpoint event - endpoint := s.GetMessageEndpointForClient(r, sessionID) - if s.appendQueryToMessageEndpoint && len(r.URL.RawQuery) > 0 { - endpoint += "&" + r.URL.RawQuery - } - fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", endpoint) - flusher.Flush() - - // Main event loop - this runs in the HTTP handler goroutine - for { - select { - case event := <-session.eventQueue: - // Write the event to the response - fmt.Fprint(w, event) - flusher.Flush() - case <-r.Context().Done(): - close(session.done) - return - case <-session.done: - return - } - } -} - -// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID -// for the given request. This is the canonical way to compute the message endpoint for a client. 
-// It handles both dynamic and static path modes, and honors the WithUseFullURLForMessageEndpoint flag. -func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID string) string { - basePath := s.basePath - if s.dynamicBasePathFunc != nil { - basePath = s.dynamicBasePathFunc(r, sessionID) - } - - endpointPath := normalizeURLPath(basePath, s.messageEndpoint) - if s.useFullURLForMessageEndpoint && s.baseURL != "" { - endpointPath = s.baseURL + endpointPath - } - - return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID) -} - -// handleMessage processes incoming JSON-RPC messages from clients and sends responses -// back through the SSE connection and 202 code to HTTP response. -func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed") - return - } - - sessionID := r.URL.Query().Get("sessionId") - if sessionID == "" { - s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") - return - } - sessionI, ok := s.sessions.Load(sessionID) - if !ok { - s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") - return - } - session := sessionI.(*sseSession) - - // Set the client context before handling the message - ctx := s.server.WithContext(r.Context(), session) - if s.contextFunc != nil { - ctx = s.contextFunc(ctx, r) - } - - // Parse message as raw JSON - var rawMessage json.RawMessage - if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil { - s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error") - return - } - - // Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled. 
- // this is required because the http ctx will be canceled when the client disconnects - detachedCtx := context.WithoutCancel(ctx) - - // quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE - w.WriteHeader(http.StatusAccepted) - - // Create a new context for handling the message that will be canceled when the message handling is done - messageCtx := context.WithValue(detachedCtx, requestHeader, r.Header) - messageCtx, cancel := context.WithCancel(messageCtx) - - go func(ctx context.Context) { - defer cancel() - // Use the context that will be canceled when session is done - // Process message through MCPServer - response := s.server.HandleMessage(ctx, rawMessage) - // Only send response if there is one (not for notifications) - if response != nil { - var message string - if eventData, err := json.Marshal(response); err != nil { - // If there is an error marshalling the response, send a generic error response - log.Printf("failed to marshal response: %v", err) - message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n" - } else { - message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData) - } - - // Queue the event for sending via SSE - select { - case session.eventQueue <- message: - // Event queued successfully - case <-session.done: - // Session is closed, don't try to queue - default: - // Queue is full, log this situation - log.Printf("Event queue full for session %s", sessionID) - } - } - }(messageCtx) -} - -// writeJSONRPCError writes a JSON-RPC error response with the given error details. 
-func (s *SSEServer) writeJSONRPCError( - w http.ResponseWriter, - id any, - code int, - message string, -) { - response := createErrorResponse(id, code, message) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error( - w, - fmt.Sprintf("Failed to encode response: %v", err), - http.StatusInternalServerError, - ) - return - } -} - -// SendEventToSession sends an event to a specific SSE session identified by sessionID. -// Returns an error if the session is not found or closed. -func (s *SSEServer) SendEventToSession( - sessionID string, - event any, -) error { - sessionI, ok := s.sessions.Load(sessionID) - if !ok { - return fmt.Errorf("session not found: %s", sessionID) - } - session := sessionI.(*sseSession) - - eventData, err := json.Marshal(event) - if err != nil { - return err - } - - // Queue the event for sending via SSE - select { - case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): - return nil - case <-session.done: - return fmt.Errorf("session closed") - default: - return fmt.Errorf("event queue full") - } -} - -func (s *SSEServer) GetUrlPath(input string) (string, error) { - parse, err := url.Parse(input) - if err != nil { - return "", fmt.Errorf("failed to parse URL %s: %w", input, err) - } - return parse.Path, nil -} - -func (s *SSEServer) CompleteSseEndpoint() (string, error) { - if s.dynamicBasePathFunc != nil { - return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"} - } - - path := normalizeURLPath(s.basePath, s.sseEndpoint) - return s.baseURL + path, nil -} - -func (s *SSEServer) CompleteSsePath() string { - path, err := s.CompleteSseEndpoint() - if err != nil { - return normalizeURLPath(s.basePath, s.sseEndpoint) - } - urlPath, err := s.GetUrlPath(path) - if err != nil { - return normalizeURLPath(s.basePath, s.sseEndpoint) - } - return urlPath -} - -func (s *SSEServer) CompleteMessageEndpoint() 
(string, error) { - if s.dynamicBasePathFunc != nil { - return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"} - } - path := normalizeURLPath(s.basePath, s.messageEndpoint) - return s.baseURL + path, nil -} - -func (s *SSEServer) CompleteMessagePath() string { - path, err := s.CompleteMessageEndpoint() - if err != nil { - return normalizeURLPath(s.basePath, s.messageEndpoint) - } - urlPath, err := s.GetUrlPath(path) - if err != nil { - return normalizeURLPath(s.basePath, s.messageEndpoint) - } - return urlPath -} - -// SSEHandler returns an http.Handler for the SSE endpoint. -// -// This method allows you to mount the SSE handler at any arbitrary path -// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is -// intended for advanced scenarios where you want to control the routing or -// support dynamic segments. -// -// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios, -// you must use the WithDynamicBasePath option to ensure the correct base path -// is communicated to clients. -// -// Example usage: -// -// // Advanced/dynamic: -// sseServer := NewSSEServer(mcpServer, -// WithDynamicBasePath(func(r *http.Request, sessionID string) string { -// tenant := r.PathValue("tenant") -// return "/mcp/" + tenant -// }), -// WithBaseURL("http://localhost:8080") -// ) -// mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler()) -// mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler()) -// -// For non-dynamic cases, use ServeHTTP method instead. -func (s *SSEServer) SSEHandler() http.Handler { - return http.HandlerFunc(s.handleSSE) -} - -// MessageHandler returns an http.Handler for the message endpoint. -// -// This method allows you to mount the message handler at any arbitrary path -// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is -// intended for advanced scenarios where you want to control the routing or -// support dynamic segments. 
-// -// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios, -// you must use the WithDynamicBasePath option to ensure the correct base path -// is communicated to clients. -// -// Example usage: -// -// // Advanced/dynamic: -// sseServer := NewSSEServer(mcpServer, -// WithDynamicBasePath(func(r *http.Request, sessionID string) string { -// tenant := r.PathValue("tenant") -// return "/mcp/" + tenant -// }), -// WithBaseURL("http://localhost:8080") -// ) -// mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler()) -// mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler()) -// -// For non-dynamic cases, use ServeHTTP method instead. -func (s *SSEServer) MessageHandler() http.Handler { - return http.HandlerFunc(s.handleMessage) -} - -// ServeHTTP implements the http.Handler interface. -func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if s.dynamicBasePathFunc != nil { - http.Error( - w, - (&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(), - http.StatusInternalServerError, - ) - return - } - path := r.URL.Path - // Use exact path matching rather than Contains - ssePath := s.CompleteSsePath() - if ssePath != "" && path == ssePath { - s.handleSSE(w, r) - return - } - messagePath := s.CompleteMessagePath() - if messagePath != "" && path == messagePath { - s.handleMessage(w, r) - return - } - - http.NotFound(w, r) -} - -// normalizeURLPath joins path elements like path.Join but ensures the -// result always starts with a leading slash and never ends with a slash -func normalizeURLPath(elem ...string) string { - joined := path.Join(elem...) 
- - // Ensure leading slash - if !strings.HasPrefix(joined, "/") { - joined = "/" + joined - } - - // Remove trailing slash if not just "/" - if len(joined) > 1 && strings.HasSuffix(joined, "/") { - joined = joined[:len(joined)-1] - } - - return joined -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/stdio.go b/vendor/github.com/mark3labs/mcp-go/server/stdio.go deleted file mode 100644 index f5c8ddfd2..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/stdio.go +++ /dev/null @@ -1,877 +0,0 @@ -package server - -import ( - "bufio" - "context" - "encoding/json" - "fmt" - "io" - "log" - "os" - "os/signal" - "sync" - "sync/atomic" - "syscall" - - "github.com/mark3labs/mcp-go/mcp" -) - -// StdioContextFunc is a function that takes an existing context and returns -// a potentially modified context. -// This can be used to inject context values from environment variables, -// for example. -type StdioContextFunc func(ctx context.Context) context.Context - -// StdioServer wraps a MCPServer and handles stdio communication. -// It provides a simple way to create command-line MCP servers that -// communicate via standard input/output streams using JSON-RPC messages. 
-type StdioServer struct { - server *MCPServer - errLogger *log.Logger - contextFunc StdioContextFunc - - // Thread-safe tool call processing - toolCallQueue chan *toolCallWork - workerWg sync.WaitGroup - workerPoolSize int - queueSize int - writeMu sync.Mutex // Protects concurrent writes -} - -// toolCallWork represents a queued tool call request -type toolCallWork struct { - ctx context.Context - message json.RawMessage - writer io.Writer -} - -// StdioOption defines a function type for configuring StdioServer -type StdioOption func(*StdioServer) - -// WithErrorLogger sets the error logger for the server -func WithErrorLogger(logger *log.Logger) StdioOption { - return func(s *StdioServer) { - s.errLogger = logger - } -} - -// WithStdioContextFunc sets a function that will be called to customise the context -// to the server. Note that the stdio server uses the same context for all requests, -// so this function will only be called once per server instance. -func WithStdioContextFunc(fn StdioContextFunc) StdioOption { - return func(s *StdioServer) { - s.contextFunc = fn - } -} - -// WithWorkerPoolSize sets the number of workers for processing tool calls -func WithWorkerPoolSize(size int) StdioOption { - return func(s *StdioServer) { - const maxWorkerPoolSize = 100 - if size > 0 && size <= maxWorkerPoolSize { - s.workerPoolSize = size - } else if size > maxWorkerPoolSize { - s.errLogger.Printf("Worker pool size %d exceeds maximum (%d), using maximum", size, maxWorkerPoolSize) - s.workerPoolSize = maxWorkerPoolSize - } - } -} - -// WithQueueSize sets the size of the tool call queue -func WithQueueSize(size int) StdioOption { - return func(s *StdioServer) { - const maxQueueSize = 10000 - if size > 0 && size <= maxQueueSize { - s.queueSize = size - } else if size > maxQueueSize { - s.errLogger.Printf("Queue size %d exceeds maximum (%d), using maximum", size, maxQueueSize) - s.queueSize = maxQueueSize - } - } -} - -// stdioSession is a static client session, since 
stdio has only one client. -type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info - clientCapabilities atomic.Value // stores session-specific client capabilities - writer io.Writer // for sending requests to client - requestID atomic.Int64 // for generating unique request IDs - mu sync.RWMutex // protects writer - pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests - pendingElicitations map[int64]chan *elicitationResponse // for tracking pending elicitation requests - pendingRoots map[int64]chan *rootsResponse // for tracking pending list roots requests - pendingMu sync.RWMutex // protects pendingRequests and pendingElicitations -} - -// samplingResponse represents a response to a sampling request -type samplingResponse struct { - result *mcp.CreateMessageResult - err error -} - -// elicitationResponse represents a response to an elicitation request -type elicitationResponse struct { - result *mcp.ElicitationResult - err error -} - -// rootsResponse represents a response to an list root request -type rootsResponse struct { - result *mcp.ListRootsResult - err error -} - -func (s *stdioSession) SessionID() string { - return "stdio" -} - -func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification { - return s.notifications -} - -func (s *stdioSession) Initialize() { - // set default logging level - s.loggingLevel.Store(mcp.LoggingLevelError) - s.initialized.Store(true) -} - -func (s *stdioSession) Initialized() bool { - return s.initialized.Load() -} - -func (s *stdioSession) GetClientInfo() mcp.Implementation { - if value := s.clientInfo.Load(); value != nil { - if clientInfo, ok := value.(mcp.Implementation); ok { - return clientInfo - } - } - return mcp.Implementation{} -} - -func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { - 
s.clientInfo.Store(clientInfo) -} - -func (s *stdioSession) GetClientCapabilities() mcp.ClientCapabilities { - if value := s.clientCapabilities.Load(); value != nil { - if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { - return clientCapabilities - } - } - return mcp.ClientCapabilities{} -} - -func (s *stdioSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { - s.clientCapabilities.Store(clientCapabilities) -} - -func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { - s.loggingLevel.Store(level) -} - -func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { - level := s.loggingLevel.Load() - if level == nil { - return mcp.LoggingLevelError - } - return level.(mcp.LoggingLevel) -} - -// RequestSampling sends a sampling request to the client and waits for the response. -func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { - s.mu.RLock() - writer := s.writer - s.mu.RUnlock() - - if writer == nil { - return nil, fmt.Errorf("no writer available for sending requests") - } - - // Generate a unique request ID - id := s.requestID.Add(1) - - // Create a response channel for this request - responseChan := make(chan *samplingResponse, 1) - s.pendingMu.Lock() - s.pendingRequests[id] = responseChan - s.pendingMu.Unlock() - - // Cleanup function to remove the pending request - cleanup := func() { - s.pendingMu.Lock() - delete(s.pendingRequests, id) - s.pendingMu.Unlock() - } - defer cleanup() - - // Create the JSON-RPC request - jsonRPCRequest := struct { - JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` - Method string `json:"method"` - Params mcp.CreateMessageParams `json:"params"` - }{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Method: string(mcp.MethodSamplingCreateMessage), - Params: request.CreateMessageParams, - } - - // Marshal and send the request - requestBytes, err := json.Marshal(jsonRPCRequest) - if err != nil { - return nil, 
fmt.Errorf("failed to marshal sampling request: %w", err) - } - requestBytes = append(requestBytes, '\n') - - if _, err := writer.Write(requestBytes); err != nil { - return nil, fmt.Errorf("failed to write sampling request: %w", err) - } - - // Wait for the response or context cancellation - select { - case <-ctx.Done(): - return nil, ctx.Err() - case response := <-responseChan: - if response.err != nil { - return nil, response.err - } - return response.result, nil - } -} - -// ListRoots sends an list roots request to the client and waits for the response. -func (s *stdioSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { - s.mu.RLock() - writer := s.writer - s.mu.RUnlock() - - if writer == nil { - return nil, fmt.Errorf("no writer available for sending requests") - } - - // Generate a unique request ID - id := s.requestID.Add(1) - - // Create a response channel for this request - responseChan := make(chan *rootsResponse, 1) - s.pendingMu.Lock() - s.pendingRoots[id] = responseChan - s.pendingMu.Unlock() - - // Cleanup function to remove the pending request - cleanup := func() { - s.pendingMu.Lock() - delete(s.pendingRoots, id) - s.pendingMu.Unlock() - } - defer cleanup() - - // Create the JSON-RPC request - jsonRPCRequest := struct { - JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` - Method string `json:"method"` - }{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Method: string(mcp.MethodListRoots), - } - - // Marshal and send the request - requestBytes, err := json.Marshal(jsonRPCRequest) - if err != nil { - return nil, fmt.Errorf("failed to marshal list roots request: %w", err) - } - requestBytes = append(requestBytes, '\n') - - if _, err := writer.Write(requestBytes); err != nil { - return nil, fmt.Errorf("failed to write list roots request: %w", err) - } - - // Wait for the response or context cancellation - select { - case <-ctx.Done(): - return nil, ctx.Err() - case response := <-responseChan: - if 
response.err != nil { - return nil, response.err - } - return response.result, nil - } -} - -// RequestElicitation sends an elicitation request to the client and waits for the response. -func (s *stdioSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { - s.mu.RLock() - writer := s.writer - s.mu.RUnlock() - - if writer == nil { - return nil, fmt.Errorf("no writer available for sending requests") - } - - // Generate a unique request ID - id := s.requestID.Add(1) - - // Create a response channel for this request - responseChan := make(chan *elicitationResponse, 1) - s.pendingMu.Lock() - s.pendingElicitations[id] = responseChan - s.pendingMu.Unlock() - - // Cleanup function to remove the pending request - cleanup := func() { - s.pendingMu.Lock() - delete(s.pendingElicitations, id) - s.pendingMu.Unlock() - } - defer cleanup() - - // Create the JSON-RPC request - jsonRPCRequest := struct { - JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` - Method string `json:"method"` - Params mcp.ElicitationParams `json:"params"` - }{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: id, - Method: string(mcp.MethodElicitationCreate), - Params: request.Params, - } - - // Marshal and send the request - requestBytes, err := json.Marshal(jsonRPCRequest) - if err != nil { - return nil, fmt.Errorf("failed to marshal elicitation request: %w", err) - } - requestBytes = append(requestBytes, '\n') - - if _, err := writer.Write(requestBytes); err != nil { - return nil, fmt.Errorf("failed to write elicitation request: %w", err) - } - - // Wait for the response or context cancellation - select { - case <-ctx.Done(): - return nil, ctx.Err() - case response := <-responseChan: - if response.err != nil { - return nil, response.err - } - return response.result, nil - } -} - -// SetWriter sets the writer for sending requests to the client. 
-func (s *stdioSession) SetWriter(writer io.Writer) { - s.mu.Lock() - defer s.mu.Unlock() - s.writer = writer -} - -var ( - _ ClientSession = (*stdioSession)(nil) - _ SessionWithLogging = (*stdioSession)(nil) - _ SessionWithClientInfo = (*stdioSession)(nil) - _ SessionWithSampling = (*stdioSession)(nil) - _ SessionWithElicitation = (*stdioSession)(nil) - _ SessionWithRoots = (*stdioSession)(nil) -) - -var stdioSessionInstance = stdioSession{ - notifications: make(chan mcp.JSONRPCNotification, 100), - pendingRequests: make(map[int64]chan *samplingResponse), - pendingElicitations: make(map[int64]chan *elicitationResponse), - pendingRoots: make(map[int64]chan *rootsResponse), -} - -// NewStdioServer creates a new stdio server wrapper around an MCPServer. -// It initializes the server with a default error logger that discards all output. -func NewStdioServer(server *MCPServer) *StdioServer { - return &StdioServer{ - server: server, - errLogger: log.New( - os.Stderr, - "", - log.LstdFlags, - ), // Default to discarding logs - workerPoolSize: 5, // Default worker pool size - queueSize: 100, // Default queue size - } -} - -// SetErrorLogger configures where error messages from the StdioServer are logged. -// The provided logger will receive all error messages generated during server operation. -func (s *StdioServer) SetErrorLogger(logger *log.Logger) { - s.errLogger = logger -} - -// SetContextFunc sets a function that will be called to customise the context -// to the server. Note that the stdio server uses the same context for all requests, -// so this function will only be called once per server instance. -func (s *StdioServer) SetContextFunc(fn StdioContextFunc) { - s.contextFunc = fn -} - -// handleNotifications continuously processes notifications from the session's notification channel -// and writes them to the provided output. It runs until the context is cancelled. -// Any errors encountered while writing notifications are logged but do not stop the handler. 
-func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) { - for { - select { - case notification := <-stdioSessionInstance.notifications: - if err := s.writeResponse(notification, stdout); err != nil { - s.errLogger.Printf("Error writing notification: %v", err) - } - case <-ctx.Done(): - return - } - } -} - -// processInputStream continuously reads and processes messages from the input stream. -// It handles EOF gracefully as a normal termination condition. -// The function returns when either: -// - The context is cancelled (returns context.Err()) -// - EOF is encountered (returns nil) -// - An error occurs while reading or processing messages (returns the error) -func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error { - for { - if err := ctx.Err(); err != nil { - return err - } - - line, err := s.readNextLine(ctx, reader) - if err != nil { - if err == io.EOF { - return nil - } - s.errLogger.Printf("Error reading input: %v", err) - return err - } - - if err := s.processMessage(ctx, line, stdout); err != nil { - if err == io.EOF { - return nil - } - s.errLogger.Printf("Error handling message: %v", err) - return err - } - } -} - -// toolCallWorker processes tool calls from the queue -func (s *StdioServer) toolCallWorker(ctx context.Context) { - defer s.workerWg.Done() - - for { - select { - case work, ok := <-s.toolCallQueue: - if !ok { - // Channel closed, exit worker - return - } - // Process the tool call - response := s.server.HandleMessage(work.ctx, work.message) - if response != nil { - if err := s.writeResponse(response, work.writer); err != nil { - s.errLogger.Printf("Error writing tool response: %v", err) - } - } - case <-ctx.Done(): - return - } - } -} - -// readNextLine reads a single line from the input reader in a context-aware manner. -// It uses channels to make the read operation cancellable via context. -// Returns the read line and any error encountered. 
If the context is cancelled, -// returns an empty string and the context's error. EOF is returned when the input -// stream is closed. -func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) { - type result struct { - line string - err error - } - - resultCh := make(chan result, 1) - - go func() { - line, err := reader.ReadString('\n') - resultCh <- result{line: line, err: err} - }() - - select { - case <-ctx.Done(): - return "", nil - case res := <-resultCh: - return res.line, res.err - } -} - -// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output. -// It runs until the context is cancelled or an error occurs. -// Returns an error if there are issues with reading input or writing output. -func (s *StdioServer) Listen( - ctx context.Context, - stdin io.Reader, - stdout io.Writer, -) error { - // Initialize the tool call queue - s.toolCallQueue = make(chan *toolCallWork, s.queueSize) - - // Set a static client context since stdio only has one client - if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil { - return fmt.Errorf("register session: %w", err) - } - defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) - ctx = s.server.WithContext(ctx, &stdioSessionInstance) - - // Set the writer for sending requests to the client - stdioSessionInstance.SetWriter(stdout) - - // Add in any custom context. 
- if s.contextFunc != nil { - ctx = s.contextFunc(ctx) - } - - reader := bufio.NewReader(stdin) - - // Start worker pool for tool calls - for i := 0; i < s.workerPoolSize; i++ { - s.workerWg.Add(1) - go s.toolCallWorker(ctx) - } - - // Start notification handler - go s.handleNotifications(ctx, stdout) - - // Process input stream - err := s.processInputStream(ctx, reader, stdout) - - // Shutdown workers gracefully - close(s.toolCallQueue) - s.workerWg.Wait() - - return err -} - -// processMessage handles a single JSON-RPC message and writes the response. -// It parses the message, processes it through the wrapped MCPServer, and writes any response. -// Returns an error if there are issues with message processing or response writing. -func (s *StdioServer) processMessage( - ctx context.Context, - line string, - writer io.Writer, -) error { - // If line is empty, likely due to ctx cancellation - if len(line) == 0 { - return nil - } - - // Parse the message as raw JSON - var rawMessage json.RawMessage - if err := json.Unmarshal([]byte(line), &rawMessage); err != nil { - response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error") - return s.writeResponse(response, writer) - } - - // Check if this is a response to a sampling request - if s.handleSamplingResponse(rawMessage) { - return nil - } - - // Check if this is a response to an elicitation request - if s.handleElicitationResponse(rawMessage) { - return nil - } - - // Check if this is a response to an list roots request - if s.handleListRootsResponse(rawMessage) { - return nil - } - - // Check if this is a tool call that might need sampling (and thus should be processed concurrently) - var baseMessage struct { - Method string `json:"method"` - } - if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" { - // Queue tool calls for processing by workers - select { - case s.toolCallQueue <- &toolCallWork{ - ctx: ctx, - message: rawMessage, - writer: writer, - }: - return nil - 
case <-ctx.Done(): - return ctx.Err() - default: - // Queue is full, process synchronously as fallback - s.errLogger.Printf("Tool call queue full, processing synchronously") - response := s.server.HandleMessage(ctx, rawMessage) - if response != nil { - return s.writeResponse(response, writer) - } - return nil - } - } - - // Handle other messages synchronously - response := s.server.HandleMessage(ctx, rawMessage) - - // Only write response if there is one (not for notifications) - if response != nil { - if err := s.writeResponse(response, writer); err != nil { - return fmt.Errorf("failed to write response: %w", err) - } - } - - return nil -} - -// handleSamplingResponse checks if the message is a response to a sampling request -// and routes it to the appropriate pending request channel. -func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool { - return stdioSessionInstance.handleSamplingResponse(rawMessage) -} - -// handleSamplingResponse handles incoming sampling responses for this session -func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { - // Try to parse as a JSON-RPC response - var response struct { - JSONRPC string `json:"jsonrpc"` - ID json.Number `json:"id"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal(rawMessage, &response); err != nil { - return false - } - // Parse the ID as int64 - idInt64, err := response.ID.Int64() - if err != nil || (response.Result == nil && response.Error == nil) { - return false - } - - // Look for a pending request with this ID - s.pendingMu.RLock() - responseChan, exists := s.pendingRequests[idInt64] - s.pendingMu.RUnlock() - - if !exists { - return false - } // Parse and send the response - samplingResp := &samplingResponse{} - - if response.Error != nil { - samplingResp.err = fmt.Errorf("sampling request failed: %s", 
response.Error.Message) - } else { - var result mcp.CreateMessageResult - if err := json.Unmarshal(response.Result, &result); err != nil { - samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err) - } else { - // Parse content from map[string]any to proper Content type (TextContent, ImageContent, AudioContent) - if contentMap, ok := result.Content.(map[string]any); ok { - content, err := mcp.ParseContent(contentMap) - if err != nil { - samplingResp.err = fmt.Errorf("failed to parse sampling response content: %w", err) - } else { - result.Content = content - samplingResp.result = &result - } - } else { - samplingResp.result = &result - } - } - } - - // Send the response (non-blocking) - select { - case responseChan <- samplingResp: - default: - // Channel is full or closed, ignore - } - - return true -} - -// handleElicitationResponse checks if the message is a response to an elicitation request -// and routes it to the appropriate pending request channel. -func (s *StdioServer) handleElicitationResponse(rawMessage json.RawMessage) bool { - return stdioSessionInstance.handleElicitationResponse(rawMessage) -} - -// handleElicitationResponse handles incoming elicitation responses for this session -func (s *stdioSession) handleElicitationResponse(rawMessage json.RawMessage) bool { - // Try to parse as a JSON-RPC response - var response struct { - JSONRPC string `json:"jsonrpc"` - ID json.Number `json:"id"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal(rawMessage, &response); err != nil { - return false - } - // Parse the ID as int64 - id, err := response.ID.Int64() - if err != nil || (response.Result == nil && response.Error == nil) { - return false - } - - // Check if we have a pending elicitation request with this ID - s.pendingMu.RLock() - responseChan, exists := s.pendingElicitations[id] - 
s.pendingMu.RUnlock() - - if !exists { - return false - } - - // Parse and send the response - elicitationResp := &elicitationResponse{} - - if response.Error != nil { - elicitationResp.err = fmt.Errorf("elicitation request failed: %s", response.Error.Message) - } else { - var result mcp.ElicitationResult - if err := json.Unmarshal(response.Result, &result); err != nil { - elicitationResp.err = fmt.Errorf("failed to unmarshal elicitation response: %w", err) - } else { - elicitationResp.result = &result - } - } - - // Send the response (non-blocking) - select { - case responseChan <- elicitationResp: - default: - // Channel is full or closed, ignore - } - - return true -} - -// handleListRootsResponse checks if the message is a response to an list roots request -// and routes it to the appropriate pending request channel. -func (s *StdioServer) handleListRootsResponse(rawMessage json.RawMessage) bool { - return stdioSessionInstance.handleListRootsResponse(rawMessage) -} - -// handleListRootsResponse handles incoming list root responses for this session -func (s *stdioSession) handleListRootsResponse(rawMessage json.RawMessage) bool { - // Try to parse as a JSON-RPC response - var response struct { - JSONRPC string `json:"jsonrpc"` - ID json.Number `json:"id"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` - } - - if err := json.Unmarshal(rawMessage, &response); err != nil { - return false - } - // Parse the ID as int64 - id, err := response.ID.Int64() - if err != nil || (response.Result == nil && response.Error == nil) { - return false - } - - // Check if we have a pending list root request with this ID - s.pendingMu.RLock() - responseChan, exists := s.pendingRoots[id] - s.pendingMu.RUnlock() - - if !exists { - return false - } - - // Parse and send the response - rootsResp := &rootsResponse{} - - if response.Error != nil { - rootsResp.err = fmt.Errorf("list 
root request failed: %s", response.Error.Message) - } else { - var result mcp.ListRootsResult - if err := json.Unmarshal(response.Result, &result); err != nil { - rootsResp.err = fmt.Errorf("failed to unmarshal list root response: %w", err) - } else { - rootsResp.result = &result - } - } - - // Send the response (non-blocking) - select { - case responseChan <- rootsResp: - default: - // Channel is full or closed, ignore - } - - return true -} - -// writeResponse marshals and writes a JSON-RPC response message followed by a newline. -// Returns an error if marshaling or writing fails. -func (s *StdioServer) writeResponse( - response mcp.JSONRPCMessage, - writer io.Writer, -) error { - responseBytes, err := json.Marshal(response) - if err != nil { - return err - } - - // Protect concurrent writes - s.writeMu.Lock() - defer s.writeMu.Unlock() - - // Write response followed by newline - if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil { - return err - } - - return nil -} - -// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout. -// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT. -// Returns an error if the server encounters any issues during operation. 
-func ServeStdio(server *MCPServer, opts ...StdioOption) error { - s := NewStdioServer(server) - - for _, opt := range opts { - opt(s) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Set up signal handling - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) - - go func() { - <-sigChan - cancel() - }() - - return s.Listen(ctx, os.Stdin, os.Stdout) -} diff --git a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go deleted file mode 100644 index 4535943da..000000000 --- a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go +++ /dev/null @@ -1,1434 +0,0 @@ -package server - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "maps" - "mime" - "net/http" - "net/http/httptest" - "os" - "strings" - "sync" - "sync/atomic" - "time" - "unicode" - - "github.com/google/uuid" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/util" -) - -// StreamableHTTPOption defines a function type for configuring StreamableHTTPServer -type StreamableHTTPOption func(*StreamableHTTPServer) - -// WithEndpointPath sets the endpoint path for the server. -// The default is "/mcp". -// It's only works for `Start` method. When used as a http.Handler, it has no effect. -func WithEndpointPath(endpointPath string) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - // Normalize the endpoint path to ensure it starts with a slash and doesn't end with one - normalizedPath := "/" + strings.Trim(endpointPath, "/") - s.endpointPath = normalizedPath - } -} - -// WithStateLess sets the server to stateless mode. -// If true, the server will manage no session information. Every request will be treated -// as a new session. No session id returned to the client. -// The default is false. -// -// Note: This is a convenience method. 
It's identical to set WithSessionIdManager option -// to StatelessSessionIdManager. -func WithStateLess(stateLess bool) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - if stateLess { - s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{}) - } - } -} - -// WithSessionIdManager sets a custom session id generator for the server. -// By default, the server uses StatelessGeneratingSessionIdManager (generates IDs but no local validation). -// Note: Options are applied in order; the last one wins. If combined with -// WithStateLess or WithSessionIdManagerResolver, whichever is applied last takes effect. -func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - if manager == nil { - s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{}) - return - } - s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(manager) - } -} - -// WithSessionIdManagerResolver sets a custom session id manager resolver for the server. -// This allows for request-based session id management strategies. -// Note: Options are applied in order; the last one wins. If combined with -// WithStateLess or WithSessionIdManager, whichever is applied last takes effect. -func WithSessionIdManagerResolver(resolver SessionIdManagerResolver) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - if resolver == nil { - s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{}) - return - } - s.sessionIdManagerResolver = resolver - } -} - -// WithStateful enables stateful session management using InsecureStatefulSessionIdManager. -// This requires sticky sessions in multi-instance deployments. 
-func WithStateful(stateful bool) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - if stateful { - s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}) - } - } -} - -// WithHeartbeatInterval sets the heartbeat interval. Positive interval means the -// server will send a heartbeat to the client through the GET connection, to keep -// the connection alive from being closed by the network infrastructure (e.g. -// gateways). If the client does not establish a GET connection, it has no -// effect. The default is not to send heartbeats. -func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - s.listenHeartbeatInterval = interval - } -} - -// WithDisableStreaming prevents the server from responding to GET requests with -// a streaming response. Instead, it will respond with a 405 Method Not Allowed status. -// This can be useful in scenarios where streaming is not desired or supported. -// The default is false, meaning streaming is enabled. -func WithDisableStreaming(disable bool) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - s.disableStreaming = disable - } -} - -// WithHTTPContextFunc sets a function that will be called to customise the context -// to the server using the incoming request. -// This can be used to inject context values from headers, for example. -func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - s.contextFunc = fn - } -} - -// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer. -// NOTE: When providing a custom HTTP server, you must handle routing yourself -// If routing is not set up, the server will start but won't handle any MCP requests. 
-func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - s.httpServer = srv - } -} - -// WithLogger sets the logger for the server -func WithLogger(logger util.Logger) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - s.logger = logger - } -} - -// WithTLSCert sets the TLS certificate and key files for HTTPS support. -// Both certFile and keyFile must be provided to enable TLS. -func WithTLSCert(certFile, keyFile string) StreamableHTTPOption { - return func(s *StreamableHTTPServer) { - s.tlsCertFile = certFile - s.tlsKeyFile = keyFile - } -} - -// StreamableHTTPServer implements a Streamable-http based MCP server. -// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams. -// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http -// -// Usage: -// -// server := NewStreamableHTTPServer(mcpServer) -// server.Start(":8080") // The final url for client is http://xxxx:8080/mcp by default -// -// or the server itself can be used as a http.Handler, which is convenient to -// integrate with existing http servers, or advanced usage: -// -// handler := NewStreamableHTTPServer(mcpServer) -// http.Handle("/streamable-http", handler) -// http.ListenAndServe(":8080", nil) -// -// Notice: -// Except for the GET handlers(listening), the POST handlers(request/notification) will -// not trigger the session registration. So the methods like `SendNotificationToSpecificClient` -// or `hooks.onRegisterSession` will not be triggered for POST messages. 
-// -// The current implementation does not support the following features from the specification: -// - Stream Resumability -type StreamableHTTPServer struct { - server *MCPServer - sessionTools *sessionToolsStore - sessionResources *sessionResourcesStore - sessionResourceTemplates *sessionResourceTemplatesStore - sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64) - activeSessions sync.Map // sessionId --> *streamableHttpSession (for sampling responses) - - httpServer *http.Server - mu sync.RWMutex - - endpointPath string - contextFunc HTTPContextFunc - sessionIdManagerResolver SessionIdManagerResolver - listenHeartbeatInterval time.Duration - logger util.Logger - sessionLogLevels *sessionLogLevelsStore - disableStreaming bool - - tlsCertFile string - tlsKeyFile string -} - -// NewStreamableHTTPServer creates a new streamable-http server instance -func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer { - s := &StreamableHTTPServer{ - server: server, - sessionTools: newSessionToolsStore(), - sessionLogLevels: newSessionLogLevelsStore(), - endpointPath: "/mcp", - sessionIdManagerResolver: NewDefaultSessionIdManagerResolver(&StatelessGeneratingSessionIdManager{}), - logger: util.DefaultLogger(), - sessionResources: newSessionResourcesStore(), - sessionResourceTemplates: newSessionResourceTemplatesStore(), - } - - // Apply all options - for _, opt := range opts { - opt(s) - } - return s -} - -// ServeHTTP implements the http.Handler interface. -func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodPost: - s.handlePost(w, r) - case http.MethodGet: - s.handleGet(w, r) - case http.MethodDelete: - s.handleDelete(w, r) - default: - http.NotFound(w, r) - } -} - -// Start begins serving the http server on the specified address and path -// (endpointPath). 
like: -// -// s.Start(":8080") -func (s *StreamableHTTPServer) Start(addr string) error { - s.mu.Lock() - if s.httpServer == nil { - mux := http.NewServeMux() - mux.Handle(s.endpointPath, s) - s.httpServer = &http.Server{ - Addr: addr, - Handler: mux, - } - } else { - if s.httpServer.Addr == "" { - s.httpServer.Addr = addr - } else if s.httpServer.Addr != addr { - return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr) - } - } - srv := s.httpServer - s.mu.Unlock() - - if s.tlsCertFile != "" || s.tlsKeyFile != "" { - if s.tlsCertFile == "" || s.tlsKeyFile == "" { - return fmt.Errorf("both TLS cert and key must be provided") - } - if _, err := os.Stat(s.tlsCertFile); err != nil { - return fmt.Errorf("failed to find TLS certificate file: %w", err) - } - if _, err := os.Stat(s.tlsKeyFile); err != nil { - return fmt.Errorf("failed to find TLS key file: %w", err) - } - return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile) - } - - return srv.ListenAndServe() -} - -// Shutdown gracefully stops the server, closing all active sessions -// and shutting down the HTTP server. 
-func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { - - // shutdown the server if needed (may use as a http.Handler) - s.mu.RLock() - srv := s.httpServer - s.mu.RUnlock() - if srv != nil { - return srv.Shutdown(ctx) - } - return nil -} - -// --- internal methods --- - -func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { - // post request carry request/notification message - - // Check content type - contentType := r.Header.Get("Content-Type") - mediaType, _, err := mime.ParseMediaType(contentType) - if err != nil || mediaType != "application/json" { - http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest) - return - } - - // Check the request body is valid json, meanwhile, get the request Method - rawData, err := io.ReadAll(r.Body) - if err != nil { - s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err)) - return - } - // First, try to parse as a response (sampling responses don't have a method field) - var jsonMessage struct { - ID json.RawMessage `json:"id"` - Result json.RawMessage `json:"result,omitempty"` - Error json.RawMessage `json:"error,omitempty"` - Method mcp.MCPMethod `json:"method,omitempty"` - } - if err := json.Unmarshal(rawData, &jsonMessage); err != nil { - s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json") - return - } - - // detect empty ping response, skip session ID validation - isPingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && - (isJSONEmpty(jsonMessage.Result) && isJSONEmpty(jsonMessage.Error)) - - if isPingResponse { - return - } - - // Check if this is a sampling response (has result/error but no method) - isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && - (jsonMessage.Result != nil || jsonMessage.Error != nil) - - isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize - - // Handle sampling responses separately - if 
isSamplingResponse { - if err := s.handleSamplingResponse(w, r, jsonMessage); err != nil { - s.logger.Errorf("Failed to handle sampling response: %v", err) - http.Error(w, "Failed to handle sampling response", http.StatusInternalServerError) - } - return - } - - // Prepare the session for the mcp server - // The session is ephemeral. Its life is the same as the request. It's only created - // for interaction with the mcp server. - var sessionID string - sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r) - if isInitializeRequest { - // generate a new one for initialize request - sessionID = sessionIdManager.Generate() - } else { - // Get session ID from header. - // Stateful servers need the client to carry the session ID. - sessionID = r.Header.Get(HeaderKeySessionID) - isTerminated, err := sessionIdManager.Validate(sessionID) - if err != nil { - http.Error(w, "Invalid session ID", http.StatusBadRequest) - return - } - if isTerminated { - http.Error(w, "Session terminated", http.StatusNotFound) - return - } - } - - // For non-initialize requests, try to reuse existing registered session - var session *streamableHttpSession - if !isInitializeRequest { - if sessionValue, ok := s.server.sessions.Load(sessionID); ok { - if existingSession, ok := sessionValue.(*streamableHttpSession); ok { - session = existingSession - } - } - } - - // Check if a persistent session exists (for sampling support), otherwise create ephemeral session - // Persistent sessions are created by GET (continuous listening) connections - if session == nil { - if sessionInterface, exists := s.activeSessions.Load(sessionID); exists { - if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok { - session = persistentSession - } - } - } - - // Create ephemeral session if no persistent session exists - if session == nil { - session = newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionResourceTemplates, s.sessionLogLevels) - } - - // 
Set the client context before handling the message - ctx := s.server.WithContext(r.Context(), session) - if s.contextFunc != nil { - ctx = s.contextFunc(ctx, r) - } - - // handle potential notifications - mu := sync.Mutex{} - upgradedHeader := false - done := make(chan struct{}) - - ctx = context.WithValue(ctx, requestHeader, r.Header) - go func() { - for { - select { - case nt := <-session.notificationChannel: - func() { - mu.Lock() - defer mu.Unlock() - // if the done chan is closed, as the request is terminated, just return - select { - case <-done: - return - default: - } - defer func() { - flusher, ok := w.(http.Flusher) - if ok { - flusher.Flush() - } - }() - - // if there's notifications, upgradedHeader to SSE response - if !upgradedHeader { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Cache-Control", "no-cache") - w.WriteHeader(http.StatusOK) - upgradedHeader = true - } - err := writeSSEEvent(w, nt) - if err != nil { - s.logger.Errorf("Failed to write SSE event: %v", err) - return - } - }() - case <-done: - return - case <-ctx.Done(): - return - } - } - }() - - // Process message through MCPServer - response := s.server.HandleMessage(ctx, rawData) - if response == nil { - // For notifications, just send 202 Accepted with no body - w.WriteHeader(http.StatusAccepted) - return - } - - // Write response - mu.Lock() - defer mu.Unlock() - // close the done chan before unlock - defer close(done) - if ctx.Err() != nil { - return - } - // If client-server communication already upgraded to SSE stream - if session.upgradeToSSE.Load() { - if !upgradedHeader { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Cache-Control", "no-cache") - w.WriteHeader(http.StatusOK) - upgradedHeader = true - } - if err := writeSSEEvent(w, response); err != nil { - s.logger.Errorf("Failed to write final SSE response event: %v", err) - } - } else { - 
w.Header().Set("Content-Type", "application/json") - if isInitializeRequest && sessionID != "" { - // send the session ID back to the client - w.Header().Set(HeaderKeySessionID, sessionID) - } - w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - if err != nil { - s.logger.Errorf("Failed to write response: %v", err) - } - } - - // Register session after successful initialization - // Only register if not already registered (e.g., by a GET connection) - if isInitializeRequest && sessionID != "" { - if _, exists := s.server.sessions.Load(sessionID); !exists { - // Store in activeSessions to prevent duplicate registration from GET - s.activeSessions.Store(sessionID, session) - // Register the session with the MCPServer for notification support - if err := s.server.RegisterSession(ctx, session); err != nil { - s.logger.Errorf("Failed to register POST session: %v", err) - s.activeSessions.Delete(sessionID) - // Don't fail the request, just log the error - } - } - } -} - -func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) { - // get request is for listening to notifications - // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server - if s.disableStreaming { - s.logger.Infof("Rejected GET request: streaming is disabled (session: %s)", r.Header.Get(HeaderKeySessionID)) - http.Error(w, "Streaming is disabled on this server", http.StatusMethodNotAllowed) - return - } - - sessionID := r.Header.Get(HeaderKeySessionID) - // the specification didn't say we should validate the session id - - if sessionID == "" { - // It's a stateless server, - // but the MCP server requires a unique ID for registering, so we use a random one - sessionID = uuid.New().String() - } - - // Get or create session atomically to prevent TOCTOU races - // where concurrent GETs could both create and register duplicate sessions - var session *streamableHttpSession - newSession := 
newStreamableHttpSession(sessionID, s.sessionTools, s.sessionResources, s.sessionResourceTemplates, s.sessionLogLevels) - actual, loaded := s.activeSessions.LoadOrStore(sessionID, newSession) - session = actual.(*streamableHttpSession) - - if !loaded { - // We created a new session, need to register it - if err := s.server.RegisterSession(r.Context(), session); err != nil { - s.activeSessions.Delete(sessionID) - http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) - return - } - defer s.server.UnregisterSession(r.Context(), sessionID) - defer s.activeSessions.Delete(sessionID) - } - - // Set the client context before handling the message - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusOK) - - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "Streaming unsupported", http.StatusInternalServerError) - return - } - flusher.Flush() - - // Start notification handler for this session - done := make(chan struct{}) - defer close(done) - writeChan := make(chan any, 16) - - go func() { - for { - select { - case nt := <-session.notificationChannel: - select { - case writeChan <- &nt: - case <-done: - return - } - case samplingReq := <-session.samplingRequestChan: - // Send sampling request to client via SSE - jsonrpcRequest := mcp.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcp.NewRequestId(samplingReq.requestID), - Request: mcp.Request{ - Method: string(mcp.MethodSamplingCreateMessage), - }, - Params: samplingReq.request.CreateMessageParams, - } - select { - case writeChan <- jsonrpcRequest: - case <-done: - return - } - case elicitationReq := <-session.elicitationRequestChan: - // Send elicitation request to client via SSE - jsonrpcRequest := mcp.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcp.NewRequestId(elicitationReq.requestID), - Request: mcp.Request{ - Method: string(mcp.MethodElicitationCreate), - }, - 
Params: elicitationReq.request.Params, - } - select { - case writeChan <- jsonrpcRequest: - case <-done: - return - } - case rootsReq := <-session.rootsRequestChan: - // Send list roots request to client via SSE - jsonrpcRequest := mcp.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcp.NewRequestId(rootsReq.requestID), - Request: mcp.Request{ - Method: string(mcp.MethodListRoots), - }, - } - select { - case writeChan <- jsonrpcRequest: - case <-done: - return - } - case <-done: - return - } - } - }() - - if s.listenHeartbeatInterval > 0 { - // heartbeat to keep the connection alive - go func() { - ticker := time.NewTicker(s.listenHeartbeatInterval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - message := mcp.JSONRPCRequest{ - JSONRPC: "2.0", - ID: mcp.NewRequestId(s.nextRequestID(sessionID)), - Request: mcp.Request{ - Method: "ping", - }, - } - select { - case writeChan <- message: - case <-done: - return - } - case <-done: - return - } - } - }() - } - - // Keep the connection open until the client disconnects - // - // There's will a Available() check when handler ends, and it maybe race with Flush(), - // so we use a separate channel to send the data, inteading of flushing directly in other goroutine. 
- for { - select { - case data := <-writeChan: - if data == nil { - continue - } - if err := writeSSEEvent(w, data); err != nil { - s.logger.Errorf("Failed to write SSE event: %v", err) - return - } - flusher.Flush() - case <-r.Context().Done(): - return - } - } -} - -func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) { - // delete request terminate the session - sessionID := r.Header.Get(HeaderKeySessionID) - sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r) - notAllowed, err := sessionIdManager.Terminate(sessionID) - if err != nil { - http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError) - return - } - if notAllowed { - http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed) - return - } - - // remove the session relateddata from the sessionToolsStore - s.sessionTools.delete(sessionID) - s.sessionResources.delete(sessionID) - s.sessionResourceTemplates.delete(sessionID) - s.sessionLogLevels.delete(sessionID) - // remove current session's requstID information - s.sessionRequestIDs.Delete(sessionID) - - w.WriteHeader(http.StatusOK) -} - -func writeSSEEvent(w io.Writer, data any) error { - jsonData, err := json.Marshal(data) - if err != nil { - return fmt.Errorf("failed to marshal data: %w", err) - } - _, err = fmt.Fprintf(w, "event: message\ndata: %s\n\n", jsonData) - if err != nil { - return fmt.Errorf("failed to write SSE event: %w", err) - } - return nil -} - -// handleSamplingResponse processes incoming sampling responses from clients -func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *http.Request, responseMessage struct { - ID json.RawMessage `json:"id"` - Result json.RawMessage `json:"result,omitempty"` - Error json.RawMessage `json:"error,omitempty"` - Method mcp.MCPMethod `json:"method,omitempty"` -}) error { - // Get session ID from header - sessionID := r.Header.Get(HeaderKeySessionID) - if sessionID == 
"" { - http.Error(w, "Missing session ID for sampling response", http.StatusBadRequest) - return fmt.Errorf("missing session ID") - } - - // Validate session - sessionIdManager := s.sessionIdManagerResolver.ResolveSessionIdManager(r) - isTerminated, err := sessionIdManager.Validate(sessionID) - if err != nil { - http.Error(w, "Invalid session ID", http.StatusBadRequest) - return err - } - if isTerminated { - http.Error(w, "Session terminated", http.StatusNotFound) - return fmt.Errorf("session terminated") - } - - // Parse the request ID - var requestID int64 - if err := json.Unmarshal(responseMessage.ID, &requestID); err != nil { - http.Error(w, "Invalid request ID in sampling response", http.StatusBadRequest) - return err - } - - // Create the sampling response item - response := samplingResponseItem{ - requestID: requestID, - } - - // Parse result or error - if responseMessage.Error != nil { - // Parse error - var jsonrpcError struct { - Code int `json:"code"` - Message string `json:"message"` - } - if err := json.Unmarshal(responseMessage.Error, &jsonrpcError); err != nil { - response.err = fmt.Errorf("failed to parse error: %v", err) - } else { - response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message) - } - } else if responseMessage.Result != nil { - // Store the result to be unmarshaled later - response.result = responseMessage.Result - } else { - response.err = fmt.Errorf("sampling response has neither result nor error") - } - - // Find the corresponding session and deliver the response - // The response is delivered to the specific session identified by sessionID - if err := s.deliverSamplingResponse(sessionID, response); err != nil { - s.logger.Errorf("Failed to deliver sampling response: %v", err) - http.Error(w, "Failed to deliver response", http.StatusInternalServerError) - return err - } - - // Acknowledge receipt - w.WriteHeader(http.StatusOK) - return nil -} - -// deliverSamplingResponse delivers a sampling response 
to the appropriate session -func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) error { - // Look up the active session - sessionInterface, ok := s.activeSessions.Load(sessionID) - if !ok { - return fmt.Errorf("no active session found for session %s", sessionID) - } - - session, ok := sessionInterface.(*streamableHttpSession) - if !ok { - return fmt.Errorf("invalid session type for session %s", sessionID) - } - - // Look up the dedicated response channel for this specific request - responseChannelInterface, exists := session.samplingRequests.Load(response.requestID) - if !exists { - return fmt.Errorf("no pending request found for session %s, request %d", sessionID, response.requestID) - } - - responseChan, ok := responseChannelInterface.(chan samplingResponseItem) - if !ok { - return fmt.Errorf("invalid response channel type for session %s, request %d", sessionID, response.requestID) - } - - // Attempt to deliver the response with timeout to prevent indefinite blocking - select { - case responseChan <- response: - s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID) - return nil - default: - return fmt.Errorf("failed to deliver sampling response for session %s, request %d: channel full or blocked", sessionID, response.requestID) - } -} - -// writeJSONRPCError writes a JSON-RPC error response with the given error details. 
-func (s *StreamableHTTPServer) writeJSONRPCError( - w http.ResponseWriter, - id any, - code int, - message string, -) { - response := createErrorResponse(id, code, message) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - err := json.NewEncoder(w).Encode(response) - if err != nil { - s.logger.Errorf("Failed to write JSONRPCError: %v", err) - } -} - -// nextRequestID gets the next incrementing requestID for the current session -func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 { - actual, _ := s.sessionRequestIDs.LoadOrStore(sessionID, new(atomic.Int64)) - counter := actual.(*atomic.Int64) - return counter.Add(1) -} - -// --- session --- -type sessionLogLevelsStore struct { - mu sync.RWMutex - logs map[string]mcp.LoggingLevel -} - -func newSessionLogLevelsStore() *sessionLogLevelsStore { - return &sessionLogLevelsStore{ - logs: make(map[string]mcp.LoggingLevel), - } -} - -func (s *sessionLogLevelsStore) get(sessionID string) mcp.LoggingLevel { - s.mu.RLock() - defer s.mu.RUnlock() - val, ok := s.logs[sessionID] - if !ok { - return mcp.LoggingLevelError - } - return val -} - -func (s *sessionLogLevelsStore) set(sessionID string, level mcp.LoggingLevel) { - s.mu.Lock() - defer s.mu.Unlock() - s.logs[sessionID] = level -} - -func (s *sessionLogLevelsStore) delete(sessionID string) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.logs, sessionID) -} - -type sessionResourcesStore struct { - mu sync.RWMutex - resources map[string]map[string]ServerResource // sessionID -> resourceURI -> resource -} - -func newSessionResourcesStore() *sessionResourcesStore { - return &sessionResourcesStore{ - resources: make(map[string]map[string]ServerResource), - } -} - -func (s *sessionResourcesStore) get(sessionID string) map[string]ServerResource { - s.mu.RLock() - defer s.mu.RUnlock() - cloned := make(map[string]ServerResource, len(s.resources[sessionID])) - maps.Copy(cloned, s.resources[sessionID]) - return cloned -} - 
-func (s *sessionResourcesStore) set(sessionID string, resources map[string]ServerResource) { - s.mu.Lock() - defer s.mu.Unlock() - cloned := make(map[string]ServerResource, len(resources)) - maps.Copy(cloned, resources) - s.resources[sessionID] = cloned -} - -func (s *sessionResourcesStore) delete(sessionID string) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.resources, sessionID) -} - -type sessionResourceTemplatesStore struct { - mu sync.RWMutex - templates map[string]map[string]ServerResourceTemplate // sessionID -> uriTemplate -> template -} - -func newSessionResourceTemplatesStore() *sessionResourceTemplatesStore { - return &sessionResourceTemplatesStore{ - templates: make(map[string]map[string]ServerResourceTemplate), - } -} - -func (s *sessionResourceTemplatesStore) get(sessionID string) map[string]ServerResourceTemplate { - s.mu.RLock() - defer s.mu.RUnlock() - cloned := make(map[string]ServerResourceTemplate, len(s.templates[sessionID])) - maps.Copy(cloned, s.templates[sessionID]) - return cloned -} - -func (s *sessionResourceTemplatesStore) set(sessionID string, templates map[string]ServerResourceTemplate) { - s.mu.Lock() - defer s.mu.Unlock() - cloned := make(map[string]ServerResourceTemplate, len(templates)) - maps.Copy(cloned, templates) - s.templates[sessionID] = cloned -} - -func (s *sessionResourceTemplatesStore) delete(sessionID string) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.templates, sessionID) -} - -type sessionToolsStore struct { - mu sync.RWMutex - tools map[string]map[string]ServerTool // sessionID -> toolName -> tool -} - -func newSessionToolsStore() *sessionToolsStore { - return &sessionToolsStore{ - tools: make(map[string]map[string]ServerTool), - } -} - -func (s *sessionToolsStore) get(sessionID string) map[string]ServerTool { - s.mu.RLock() - defer s.mu.RUnlock() - cloned := make(map[string]ServerTool, len(s.tools[sessionID])) - maps.Copy(cloned, s.tools[sessionID]) - return cloned -} - -func (s *sessionToolsStore) 
set(sessionID string, tools map[string]ServerTool) { - s.mu.Lock() - defer s.mu.Unlock() - cloned := make(map[string]ServerTool, len(tools)) - maps.Copy(cloned, tools) - s.tools[sessionID] = cloned -} - -func (s *sessionToolsStore) delete(sessionID string) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.tools, sessionID) -} - -// Sampling support types for HTTP transport -type samplingRequestItem struct { - requestID int64 - request mcp.CreateMessageRequest - response chan samplingResponseItem -} - -type samplingResponseItem struct { - requestID int64 - result json.RawMessage - err error -} - -// Elicitation support types for HTTP transport -type elicitationRequestItem struct { - requestID int64 - request mcp.ElicitationRequest - response chan samplingResponseItem -} - -// Roots support types for HTTP transport -type rootsRequestItem struct { - requestID int64 - request mcp.ListRootsRequest - response chan samplingResponseItem -} - -// streamableHttpSession is a session for streamable-http transport -// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. -// When in GET handlers(listening), it's a real session, and will be registered in the MCP server. 
-type streamableHttpSession struct { - sessionID string - notificationChannel chan mcp.JSONRPCNotification // server -> client notifications - tools *sessionToolsStore - resources *sessionResourcesStore - resourceTemplates *sessionResourceTemplatesStore - upgradeToSSE atomic.Bool - logLevels *sessionLogLevelsStore - clientInfo atomic.Value // stores session-specific client info - clientCapabilities atomic.Value // stores session-specific client capabilities - - // Sampling support for bidirectional communication - samplingRequestChan chan samplingRequestItem // server -> client sampling requests - elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests - rootsRequestChan chan rootsRequestItem // server -> client list roots requests - - samplingRequests sync.Map // requestID -> pending sampling request context - requestIDCounter atomic.Int64 // for generating unique request IDs -} - -func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, resourcesStore *sessionResourcesStore, templatesStore *sessionResourceTemplatesStore, levels *sessionLogLevelsStore) *streamableHttpSession { - s := &streamableHttpSession{ - sessionID: sessionID, - notificationChannel: make(chan mcp.JSONRPCNotification, 100), - tools: toolStore, - resources: resourcesStore, - resourceTemplates: templatesStore, - logLevels: levels, - samplingRequestChan: make(chan samplingRequestItem, 10), - elicitationRequestChan: make(chan elicitationRequestItem, 10), - rootsRequestChan: make(chan rootsRequestItem, 10), - } - return s -} - -func (s *streamableHttpSession) SessionID() string { - return s.sessionID -} - -func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification { - return s.notificationChannel -} - -func (s *streamableHttpSession) Initialize() { - // do nothing - // the session is ephemeral, no real initialized action needed -} - -func (s *streamableHttpSession) Initialized() bool { - // the session is ephemeral, 
no real initialized action needed - return true -} - -func (s *streamableHttpSession) SetLogLevel(level mcp.LoggingLevel) { - s.logLevels.set(s.sessionID, level) -} - -func (s *streamableHttpSession) GetLogLevel() mcp.LoggingLevel { - return s.logLevels.get(s.sessionID) -} - -var _ ClientSession = (*streamableHttpSession)(nil) - -func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool { - return s.tools.get(s.sessionID) -} - -func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) { - s.tools.set(s.sessionID, tools) -} - -func (s *streamableHttpSession) GetSessionResources() map[string]ServerResource { - return s.resources.get(s.sessionID) -} - -func (s *streamableHttpSession) SetSessionResources(resources map[string]ServerResource) { - s.resources.set(s.sessionID, resources) -} - -func (s *streamableHttpSession) GetSessionResourceTemplates() map[string]ServerResourceTemplate { - return s.resourceTemplates.get(s.sessionID) -} - -func (s *streamableHttpSession) SetSessionResourceTemplates(templates map[string]ServerResourceTemplate) { - s.resourceTemplates.set(s.sessionID, templates) -} - -func (s *streamableHttpSession) GetClientInfo() mcp.Implementation { - if value := s.clientInfo.Load(); value != nil { - if clientInfo, ok := value.(mcp.Implementation); ok { - return clientInfo - } - } - return mcp.Implementation{} -} - -func (s *streamableHttpSession) SetClientInfo(clientInfo mcp.Implementation) { - s.clientInfo.Store(clientInfo) -} - -func (s *streamableHttpSession) GetClientCapabilities() mcp.ClientCapabilities { - if value := s.clientCapabilities.Load(); value != nil { - if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { - return clientCapabilities - } - } - return mcp.ClientCapabilities{} -} - -func (s *streamableHttpSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { - s.clientCapabilities.Store(clientCapabilities) -} - -var ( - _ SessionWithTools = (*streamableHttpSession)(nil) - _ 
SessionWithResources = (*streamableHttpSession)(nil) - _ SessionWithResourceTemplates = (*streamableHttpSession)(nil) - _ SessionWithLogging = (*streamableHttpSession)(nil) - _ SessionWithClientInfo = (*streamableHttpSession)(nil) -) - -func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { - s.upgradeToSSE.Store(true) -} - -var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) - -// RequestSampling implements SessionWithSampling interface for HTTP transport -func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { - // Generate unique request ID - requestID := s.requestIDCounter.Add(1) - - // Create response channel for this specific request - responseChan := make(chan samplingResponseItem, 1) - - // Create the sampling request item - samplingRequest := samplingRequestItem{ - requestID: requestID, - request: request, - response: responseChan, - } - - // Store the pending request - s.samplingRequests.Store(requestID, responseChan) - defer s.samplingRequests.Delete(requestID) - - // Send the sampling request via the channel (non-blocking) - select { - case s.samplingRequestChan <- samplingRequest: - // Request queued successfully - case <-ctx.Done(): - return nil, ctx.Err() - default: - return nil, fmt.Errorf("sampling request queue is full - server overloaded") - } - - // Wait for response or context cancellation - select { - case response := <-responseChan: - if response.err != nil { - return nil, response.err - } - var result mcp.CreateMessageResult - if err := json.Unmarshal(response.result, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal sampling response: %v", err) - } - - // Parse content from map[string]any to proper Content type (TextContent, ImageContent, AudioContent) - // HTTP transport unmarshals Content as map[string]any, we need to convert it to the proper type - if contentMap, ok := result.Content.(map[string]any); 
ok { - content, err := mcp.ParseContent(contentMap) - if err != nil { - return nil, fmt.Errorf("failed to parse sampling response content: %w", err) - } - result.Content = content - } - - return &result, nil - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// ListRoots implements SessionWithRoots interface for HTTP transport. -// It sends a list roots request to the client via SSE and waits for the response. -func (s *streamableHttpSession) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) { - // Generate unique request ID - requestID := s.requestIDCounter.Add(1) - - // Create response channel for this specific request - responseChan := make(chan samplingResponseItem, 1) - - // Create the roots request item - rootsRequest := rootsRequestItem{ - requestID: requestID, - request: request, - response: responseChan, - } - - // Store the pending request - s.samplingRequests.Store(requestID, responseChan) - defer s.samplingRequests.Delete(requestID) - - // Send the list roots request via the channel (non-blocking) - select { - case s.rootsRequestChan <- rootsRequest: - // Request queued successfully - case <-ctx.Done(): - return nil, ctx.Err() - default: - return nil, fmt.Errorf("list roots request queue is full - server overloaded") - } - - // Wait for response or context cancellation - select { - case response := <-responseChan: - if response.err != nil { - return nil, response.err - } - var result mcp.ListRootsResult - if err := json.Unmarshal(response.result, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal list roots response: %v", err) - } - return &result, nil - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// RequestElicitation implements SessionWithElicitation interface for HTTP transport -func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { - // Generate unique request ID - requestID := 
s.requestIDCounter.Add(1) - - // Create response channel for this specific request - responseChan := make(chan samplingResponseItem, 1) - - // Create the sampling request item - elicitationRequest := elicitationRequestItem{ - requestID: requestID, - request: request, - response: responseChan, - } - - // Store the pending request - s.samplingRequests.Store(requestID, responseChan) - defer s.samplingRequests.Delete(requestID) - - // Send the sampling request via the channel (non-blocking) - select { - case s.elicitationRequestChan <- elicitationRequest: - // Request queued successfully - case <-ctx.Done(): - return nil, ctx.Err() - default: - return nil, fmt.Errorf("elicitation request queue is full - server overloaded") - } - - // Wait for response or context cancellation - select { - case response := <-responseChan: - if response.err != nil { - return nil, response.err - } - var result mcp.ElicitationResult - if err := json.Unmarshal(response.result, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal elicitation response: %v", err) - } - return &result, nil - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -var _ SessionWithSampling = (*streamableHttpSession)(nil) -var _ SessionWithElicitation = (*streamableHttpSession)(nil) -var _ SessionWithRoots = (*streamableHttpSession)(nil) - -// --- session id manager --- - -// SessionIdManagerResolver resolves a SessionIdManager based on the HTTP request -type SessionIdManagerResolver interface { - ResolveSessionIdManager(r *http.Request) SessionIdManager -} - -type SessionIdManager interface { - Generate() string - // Validate checks if a session ID is valid and not terminated. - // Returns isTerminated=true if the ID is valid but belongs to a terminated session. - // Returns err!=nil if the ID format is invalid or lookup failed. - Validate(sessionID string) (isTerminated bool, err error) - // Terminate marks a session ID as terminated. 
- // Returns isNotAllowed=true if the server policy prevents client termination. - // Returns err!=nil if the ID is invalid or termination failed. - Terminate(sessionID string) (isNotAllowed bool, err error) -} - -// DefaultSessionIdManagerResolver is a simple resolver that returns the same SessionIdManager for all requests -type DefaultSessionIdManagerResolver struct { - manager SessionIdManager -} - -// NewDefaultSessionIdManagerResolver creates a new DefaultSessionIdManagerResolver with the given SessionIdManager -func NewDefaultSessionIdManagerResolver(manager SessionIdManager) *DefaultSessionIdManagerResolver { - if manager == nil { - manager = &StatelessSessionIdManager{} - } - return &DefaultSessionIdManagerResolver{manager: manager} -} - -// ResolveSessionIdManager returns the configured SessionIdManager for all requests -func (r *DefaultSessionIdManagerResolver) ResolveSessionIdManager(_ *http.Request) SessionIdManager { - return r.manager -} - -// StatelessSessionIdManager does nothing, which means it has no session management, which is stateless. -type StatelessSessionIdManager struct{} - -func (s *StatelessSessionIdManager) Generate() string { - return "" -} - -func (s *StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { - // In stateless mode, ignore session IDs completely - don't validate or reject them - return false, nil -} - -func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { - return false, nil -} - -// StatelessGeneratingSessionIdManager generates session IDs but doesn't validate them locally. -// This allows session IDs to be generated for clients while working across multiple instances. 
-type StatelessGeneratingSessionIdManager struct{} - -func (s *StatelessGeneratingSessionIdManager) Generate() string { - return idPrefix + uuid.New().String() -} - -func (s *StatelessGeneratingSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { - // Only validate format, not existence - allows cross-instance operation - if !strings.HasPrefix(sessionID, idPrefix) { - return false, fmt.Errorf("invalid session id: %s", sessionID) - } - if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil { - return false, fmt.Errorf("invalid session id: %s", sessionID) - } - return false, nil -} - -func (s *StatelessGeneratingSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { - // No-op termination since we don't track sessions - return false, nil -} - -// InsecureStatefulSessionIdManager generate id with uuid and tracks active sessions. -// It validates both format and existence of session IDs. -// For more secure session id, use a more complex generator, like a JWT. 
-type InsecureStatefulSessionIdManager struct { - sessions sync.Map - terminated sync.Map -} - -const idPrefix = "mcp-session-" - -func (s *InsecureStatefulSessionIdManager) Generate() string { - sessionID := idPrefix + uuid.New().String() - s.sessions.Store(sessionID, true) - return sessionID -} - -func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { - if !strings.HasPrefix(sessionID, idPrefix) { - return false, fmt.Errorf("invalid session id: %s", sessionID) - } - if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil { - return false, fmt.Errorf("invalid session id: %s", sessionID) - } - if _, exists := s.terminated.Load(sessionID); exists { - return true, nil - } - if _, exists := s.sessions.Load(sessionID); !exists { - return false, fmt.Errorf("session not found: %s", sessionID) - } - return false, nil -} - -func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { - if _, exists := s.terminated.Load(sessionID); exists { - return false, nil - } - if _, exists := s.sessions.Load(sessionID); !exists { - return false, nil - } - s.terminated.Store(sessionID, true) - s.sessions.Delete(sessionID) - return false, nil -} - -// NewTestStreamableHTTPServer creates a test server for testing purposes -func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server { - sseServer := NewStreamableHTTPServer(server, opts...) - testServer := httptest.NewServer(sseServer) - return testServer -} - -// isJSONEmpty reports whether the provided JSON value is "empty": -// - null -// - empty object: {} -// - empty array: [] -// -// It also treats nil/whitespace-only input as empty. -// It does NOT treat 0, false, "" or non-empty composites as empty. 
-func isJSONEmpty(data json.RawMessage) bool { - if len(data) == 0 { - return true - } - - trimmed := bytes.TrimSpace(data) - if len(trimmed) == 0 { - return true - } - - switch trimmed[0] { - case '{': - if len(trimmed) == 2 && trimmed[1] == '}' { - return true - } - for i := 1; i < len(trimmed); i++ { - if !unicode.IsSpace(rune(trimmed[i])) { - return trimmed[i] == '}' - } - } - case '[': - if len(trimmed) == 2 && trimmed[1] == ']' { - return true - } - for i := 1; i < len(trimmed); i++ { - if !unicode.IsSpace(rune(trimmed[i])) { - return trimmed[i] == ']' - } - } - - case '"': // treat "" as not empty - return false - - case 'n': // null - return len(trimmed) == 4 && - trimmed[1] == 'u' && - trimmed[2] == 'l' && - trimmed[3] == 'l' - } - return false -} diff --git a/vendor/github.com/mark3labs/mcp-go/util/logger.go b/vendor/github.com/mark3labs/mcp-go/util/logger.go deleted file mode 100644 index 8d7555ce3..000000000 --- a/vendor/github.com/mark3labs/mcp-go/util/logger.go +++ /dev/null @@ -1,33 +0,0 @@ -package util - -import ( - "log" -) - -// Logger defines a minimal logging interface -type Logger interface { - Infof(format string, v ...any) - Errorf(format string, v ...any) -} - -// --- Standard Library Logger Wrapper --- - -// DefaultStdLogger implements Logger using the standard library's log.Logger. -func DefaultLogger() Logger { - return &stdLogger{ - logger: log.Default(), - } -} - -// stdLogger wraps the standard library's log.Logger. -type stdLogger struct { - logger *log.Logger -} - -func (l *stdLogger) Infof(format string, v ...any) { - l.logger.Printf("INFO: "+format, v...) -} - -func (l *stdLogger) Errorf(format string, v ...any) { - l.logger.Printf("ERROR: "+format, v...) 
-} diff --git a/vendor/github.com/buger/jsonparser/LICENSE b/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE similarity index 96% rename from vendor/github.com/buger/jsonparser/LICENSE rename to vendor/github.com/modelcontextprotocol/go-sdk/LICENSE index ac25aeb7d..508be9266 100644 --- a/vendor/github.com/buger/jsonparser/LICENSE +++ b/vendor/github.com/modelcontextprotocol/go-sdk/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2016 Leonid Bugaev +Copyright (c) 2025 Go MCP SDK Authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go b/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go new file mode 100644 index 000000000..87665121c --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go @@ -0,0 +1,168 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "slices" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// TokenInfo holds information from a bearer token. +type TokenInfo struct { + Scopes []string + Expiration time.Time + // UserID is an optional identifier for the authenticated user. + // If set by a TokenVerifier, it can be used by transports to prevent + // session hijacking by ensuring that all requests for a given session + // come from the same user. + UserID string + // TODO: add standard JWT fields + Extra map[string]any +} + +// The error that a TokenVerifier should return if the token cannot be verified. +var ErrInvalidToken = errors.New("invalid token") + +// The error that a TokenVerifier should return for OAuth-specific protocol errors. 
+var ErrOAuth = errors.New("oauth error") + +// A TokenVerifier checks the validity of a bearer token, and extracts information +// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken. +// The HTTP request is provided in case verifying the token involves checking it. +type TokenVerifier func(ctx context.Context, token string, req *http.Request) (*TokenInfo, error) + +// RequireBearerTokenOptions are options for [RequireBearerToken]. +type RequireBearerTokenOptions struct { + // The URL for the resource server metadata OAuth flow, to be returned as part + // of the WWW-Authenticate header. + ResourceMetadataURL string + // The required scopes. + Scopes []string +} + +type tokenInfoKey struct{} + +// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none. +func TokenInfoFromContext(ctx context.Context) *TokenInfo { + ti := ctx.Value(tokenInfoKey{}) + if ti == nil { + return nil + } + return ti.(*TokenInfo) +} + +// RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier. +// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds. +// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header +// is populated to enable [protected resource metadata]. +// +// [protected resource metadata]: https://datatracker.ietf.org/doc/rfc9728 +func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) func(http.Handler) http.Handler { + // Based on typescript-sdk/src/server/auth/middleware/bearerAuth.ts. 
+ + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenInfo, errmsg, code := verify(r, verifier, opts) + if code != 0 { + if code == http.StatusUnauthorized || code == http.StatusForbidden { + if opts != nil && opts.ResourceMetadataURL != "" { + w.Header().Add("WWW-Authenticate", "Bearer resource_metadata="+opts.ResourceMetadataURL) + } + } + http.Error(w, errmsg, code) + return + } + r = r.WithContext(context.WithValue(r.Context(), tokenInfoKey{}, tokenInfo)) + handler.ServeHTTP(w, r) + }) + } +} + +func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenOptions) (_ *TokenInfo, errmsg string, code int) { + // Extract bearer token. + authHeader := req.Header.Get("Authorization") + fields := strings.Fields(authHeader) + if len(fields) != 2 || strings.ToLower(fields[0]) != "bearer" { + return nil, "no bearer token", http.StatusUnauthorized + } + + // Verify the token and get information from it. + tokenInfo, err := verifier(req.Context(), fields[1], req) + if err != nil { + if errors.Is(err, ErrInvalidToken) { + return nil, err.Error(), http.StatusUnauthorized + } + if errors.Is(err, ErrOAuth) { + return nil, err.Error(), http.StatusBadRequest + } + return nil, err.Error(), http.StatusInternalServerError + } + + // Check scopes. All must be present. + if opts != nil { + // Note: quadratic, but N is small. + for _, s := range opts.Scopes { + if !slices.Contains(tokenInfo.Scopes, s) { + return nil, "insufficient scope", http.StatusForbidden + } + } + } + + // Check expiration. + if tokenInfo.Expiration.IsZero() { + return nil, "token missing expiration", http.StatusUnauthorized + } + if tokenInfo.Expiration.Before(time.Now()) { + return nil, "token expired", http.StatusUnauthorized + } + return tokenInfo, "", 0 +} + +// ProtectedResourceMetadataHandler returns an http.Handler that serves OAuth 2.0 +// protected resource metadata (RFC 9728) with CORS support. 
+// +// This handler allows cross-origin requests from any origin (Access-Control-Allow-Origin: *) +// because OAuth metadata is public information intended for client discovery (RFC 9728 §3.1). +// The metadata contains only non-sensitive configuration data about authorization servers +// and supported scopes. +// +// No validation of metadata fields is performed; ensure metadata accuracy at configuration time. +// +// For more sophisticated CORS policies or to restrict origins, wrap this handler with a +// CORS middleware like github.com/rs/cors or github.com/jub0bs/cors. +func ProtectedResourceMetadataHandler(metadata *oauthex.ProtectedResourceMetadata) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Set CORS headers for cross-origin client discovery. + // OAuth metadata is public information, so allowing any origin is safe. + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + // Handle CORS preflight requests + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + // Only GET allowed for metadata retrieval + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(metadata); err != nil { + http.Error(w, "Failed to encode metadata", http.StatusInternalServerError) + return + } + }) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go b/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go new file mode 100644 index 000000000..acadc51be --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/auth/client.go @@ -0,0 +1,123 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. 
+// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "bytes" + "errors" + "io" + "net/http" + "sync" + + "golang.org/x/oauth2" +) + +// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization +// is approved, or an error if not. +// The handler receives the HTTP request and response that triggered the authentication flow. +// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. +type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) + +// HTTPTransport is an [http.RoundTripper] that follows the MCP +// OAuth protocol when it encounters a 401 Unauthorized response. +type HTTPTransport struct { + handler OAuthHandler + mu sync.Mutex // protects opts.Base + opts HTTPTransportOptions +} + +// NewHTTPTransport returns a new [*HTTPTransport]. +// The handler is invoked when an HTTP request results in a 401 Unauthorized status. +// It is called only once per transport. Once a TokenSource is obtained, it is used +// for the lifetime of the transport; subsequent 401s are not processed. +func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) { + if handler == nil { + return nil, errors.New("handler cannot be nil") + } + t := &HTTPTransport{ + handler: handler, + } + if opts != nil { + t.opts = *opts + } + if t.opts.Base == nil { + t.opts.Base = http.DefaultTransport + } + return t, nil +} + +// HTTPTransportOptions are options to [NewHTTPTransport]. +type HTTPTransportOptions struct { + // Base is the [http.RoundTripper] to use. + // If nil, [http.DefaultTransport] is used. 
+ Base http.RoundTripper +} + +func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + base := t.opts.Base + t.mu.Unlock() + + var ( + // If haveBody is set, the request has a nontrivial body, and we need avoid + // reading (or closing) it multiple times. In that case, bodyBytes is its + // content. + haveBody bool + bodyBytes []byte + ) + if req.Body != nil && req.Body != http.NoBody { + // if we're setting Body, we must mutate first. + req = req.Clone(req.Context()) + haveBody = true + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + // Now that we've read the request body, http.RoundTripper requires that we + // close it. + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + resp, err := base.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if _, ok := base.(*oauth2.Transport); ok { + // We failed to authorize even with a token source; give up. + return resp, nil + } + + resp.Body.Close() + // Try to authorize. + t.mu.Lock() + defer t.mu.Unlock() + // If we don't have a token source, get one by following the OAuth flow. + // (We may have obtained one while t.mu was not held above.) + // TODO: We hold the lock for the entire OAuth flow. This could be a long + // time. Is there a better way? + if _, ok := t.opts.Base.(*oauth2.Transport); !ok { + ts, err := t.handler(req, resp) + if err != nil { + return nil, err + } + t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} + } + + // If we don't have a body, the request is reusable, though it will be cloned + // by the base. However, if we've had to read the body, we must clone. 
+ if haveBody { + req = req.Clone(req.Context()) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + return t.opts.Base.RoundTrip(req) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go new file mode 100644 index 000000000..627ffe7b6 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go @@ -0,0 +1,841 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" +) + +// Binder builds a connection configuration. +// This may be used in servers to generate a new configuration per connection. +// ConnectionOptions itself implements Binder returning itself unmodified, to +// allow for the simple cases where no per connection information is needed. +type Binder interface { + // Bind returns the ConnectionOptions to use when establishing the passed-in + // Connection. + // + // The connection is not ready to use when Bind is called, + // but Bind may close it without reading or writing to it. + Bind(context.Context, *Connection) ConnectionOptions +} + +// A BinderFunc implements the Binder interface for a standalone Bind function. +type BinderFunc func(context.Context, *Connection) ConnectionOptions + +func (f BinderFunc) Bind(ctx context.Context, c *Connection) ConnectionOptions { + return f(ctx, c) +} + +var _ Binder = BinderFunc(nil) + +// ConnectionOptions holds the options for new connections. +type ConnectionOptions struct { + // Framer allows control over the message framing and encoding. + // If nil, HeaderFramer will be used. + Framer Framer + // Preempter allows registration of a pre-queue message handler. + // If nil, no messages will be preempted. 
+ Preempter Preempter + // Handler is used as the queued message handler for inbound messages. + // If nil, all responses will be ErrNotHandled. + Handler Handler + // OnInternalError, if non-nil, is called with any internal errors that occur + // while serving the connection, such as protocol errors or invariant + // violations. (If nil, internal errors result in panics.) + OnInternalError func(error) +} + +// Connection manages the jsonrpc2 protocol, connecting responses back to their +// calls. Connection is bidirectional; it does not have a designated server or +// client end. +// +// Note that the word 'Connection' is overloaded: the mcp.Connection represents +// the bidirectional stream of messages between client an server. The +// jsonrpc2.Connection layers RPC logic on top of that stream, dispatching RPC +// handlers, and correlating requests with responses from the peer. +// +// Some of the complexity of the Connection type is grown out of its usage in +// gopls: it could probably be simplified based on our usage in MCP. +type Connection struct { + seq int64 // must only be accessed using atomic operations + + stateMu sync.Mutex + state inFlightState // accessed only in updateInFlight + done chan struct{} // closed (under stateMu) when state.closed is true and all goroutines have completed + + writer Writer + handler Handler + + onInternalError func(error) + onDone func() +} + +// inFlightState records the state of the incoming and outgoing calls on a +// Connection. +type inFlightState struct { + connClosing bool // true when the Connection's Close method has been called + reading bool // true while the readIncoming goroutine is running + readErr error // non-nil when the readIncoming goroutine exits (typically io.EOF) + writeErr error // non-nil if a call to the Writer has failed with a non-canceled Context + + // closer shuts down and cleans up the Reader and Writer state, ideally + // interrupting any Read or Write call that is currently blocked. 
It is closed + // when the state is idle and one of: connClosing is true, readErr is non-nil, + // or writeErr is non-nil. + // + // After the closer has been invoked, the closer field is set to nil + // and the closeErr field is simultaneously set to its result. + closer io.Closer + closeErr error // error returned from closer.Close + + outgoingCalls map[ID]*AsyncCall // calls only + outgoingNotifications int // # of notifications awaiting "write" + + // incoming stores the total number of incoming calls and notifications + // that have not yet written or processed a result. + incoming int + + incomingByID map[ID]*incomingRequest // calls only + + // handlerQueue stores the backlog of calls and notifications that were not + // already handled by a preempter. + // The queue does not include the request currently being handled (if any). + handlerQueue []*incomingRequest + handlerRunning bool +} + +// updateInFlight locks the state of the connection's in-flight requests, allows +// f to mutate that state, and closes the connection if it is idle and either +// is closing or has a read or write error. +func (c *Connection) updateInFlight(f func(*inFlightState)) { + c.stateMu.Lock() + defer c.stateMu.Unlock() + + s := &c.state + + f(s) + + select { + case <-c.done: + // The connection was already completely done at the start of this call to + // updateInFlight, so it must remain so. (The call to f should have noticed + // that and avoided making any updates that would cause the state to be + // non-idle.) + if !s.idle() { + panic("jsonrpc2: updateInFlight transitioned to non-idle when already done") + } + return + default: + } + + if s.idle() && s.shuttingDown(ErrUnknown) != nil { + if s.closer != nil { + s.closeErr = s.closer.Close() + s.closer = nil // prevent duplicate Close calls + } + if s.reading { + // The readIncoming goroutine is still running. 
Our call to Close should + // cause it to exit soon, at which point it will make another call to + // updateInFlight, set s.reading to false, and mark the Connection done. + } else { + // The readIncoming goroutine has exited, or never started to begin with. + // Since everything else is idle, we're completely done. + if c.onDone != nil { + c.onDone() + } + close(c.done) + } + } +} + +// idle reports whether the connection is in a state with no pending calls or +// notifications. +// +// If idle returns true, the readIncoming goroutine may still be running, +// but no other goroutines are doing work on behalf of the connection. +func (s *inFlightState) idle() bool { + return len(s.outgoingCalls) == 0 && s.outgoingNotifications == 0 && s.incoming == 0 && !s.handlerRunning +} + +// shuttingDown reports whether the connection is in a state that should +// disallow new (incoming and outgoing) calls. It returns either nil or +// an error that is or wraps the provided errClosing. +func (s *inFlightState) shuttingDown(errClosing error) error { + if s.connClosing { + // If Close has been called explicitly, it doesn't matter what state the + // Reader and Writer are in: we shouldn't be starting new work because the + // caller told us not to start new work. + return errClosing + } + if s.readErr != nil { + // If the read side of the connection is broken, we cannot read new call + // requests, and cannot read responses to our outgoing calls. + return fmt.Errorf("%w: %v", errClosing, s.readErr) + } + if s.writeErr != nil { + // If the write side of the connection is broken, we cannot write responses + // for incoming calls, and cannot write requests for outgoing calls. 
+ return fmt.Errorf("%w: %v", errClosing, s.writeErr) + } + return nil +} + +// incomingRequest is used to track an incoming request as it is being handled +type incomingRequest struct { + *Request // the request being processed + ctx context.Context + cancel context.CancelFunc +} + +// Bind returns the options unmodified. +func (o ConnectionOptions) Bind(context.Context, *Connection) ConnectionOptions { + return o +} + +// A ConnectionConfig configures a bidirectional jsonrpc2 connection. +type ConnectionConfig struct { + Reader Reader // required + Writer Writer // required + Closer io.Closer // required + Preempter Preempter // optional + Bind func(*Connection) Handler // required + OnDone func() // optional + OnInternalError func(error) // optional +} + +// NewConnection creates a new [Connection] object and starts processing +// incoming messages. +func NewConnection(ctx context.Context, cfg ConnectionConfig) *Connection { + ctx = notDone{ctx} + + c := &Connection{ + state: inFlightState{closer: cfg.Closer}, + done: make(chan struct{}), + writer: cfg.Writer, + onDone: cfg.OnDone, + onInternalError: cfg.OnInternalError, + } + c.handler = cfg.Bind(c) + c.start(ctx, cfg.Reader, cfg.Preempter) + return c +} + +// bindConnection creates a new connection and runs it. +// +// This is used by the Dial and Serve functions to build the actual connection. +// +// The connection is closed automatically (and its resources cleaned up) when +// the last request has completed after the underlying ReadWriteCloser breaks, +// but it may be stopped earlier by calling Close (for a clean shutdown). +func bindConnection(bindCtx context.Context, rwc io.ReadWriteCloser, binder Binder, onDone func()) *Connection { + // TODO: Should we create a new event span here? + // This will propagate cancellation from ctx; should it? 
+ ctx := notDone{bindCtx} + + c := &Connection{ + state: inFlightState{closer: rwc}, + done: make(chan struct{}), + onDone: onDone, + } + // It's tempting to set a finalizer on c to verify that the state has gone + // idle when the connection becomes unreachable. Unfortunately, the Binder + // interface makes that unsafe: it allows the Handler to close over the + // Connection, which could create a reference cycle that would cause the + // Connection to become uncollectable. + + options := binder.Bind(bindCtx, c) + framer := options.Framer + if framer == nil { + framer = HeaderFramer() + } + c.handler = options.Handler + if c.handler == nil { + c.handler = defaultHandler{} + } + c.onInternalError = options.OnInternalError + + c.writer = framer.Writer(rwc) + reader := framer.Reader(rwc) + c.start(ctx, reader, options.Preempter) + return c +} + +func (c *Connection) start(ctx context.Context, reader Reader, preempter Preempter) { + c.updateInFlight(func(s *inFlightState) { + select { + case <-c.done: + // Bind already closed the connection; don't start a goroutine to read it. + return + default: + } + + // The goroutine started here will continue until the underlying stream is closed. + // + // (If the Binder closed the Connection already, this should error out and + // return almost immediately.) + s.reading = true + go c.readIncoming(ctx, reader, preempter) + }) +} + +// Notify invokes the target method but does not wait for a response. +// The params will be marshaled to JSON before sending over the wire, and will +// be handed to the method invoked. +func (c *Connection) Notify(ctx context.Context, method string, params any) (err error) { + attempted := false + + defer func() { + if attempted { + c.updateInFlight(func(s *inFlightState) { + s.outgoingNotifications-- + }) + } + }() + + c.updateInFlight(func(s *inFlightState) { + // If the connection is shutting down, allow outgoing notifications only if + // there is at least one call still in flight. 
The number of calls in flight + // cannot increase once shutdown begins, and allowing outgoing notifications + // may permit notifications that will cancel in-flight calls. + if len(s.outgoingCalls) == 0 && len(s.incomingByID) == 0 { + err = s.shuttingDown(ErrClientClosing) + if err != nil { + return + } + } + s.outgoingNotifications++ + attempted = true + }) + if err != nil { + return err + } + + notify, err := NewNotification(method, params) + if err != nil { + return fmt.Errorf("marshaling notify parameters: %v", err) + } + + return c.write(ctx, notify) +} + +// Call invokes the target method and returns an object that can be used to await the response. +// The params will be marshaled to JSON before sending over the wire, and will +// be handed to the method invoked. +// You do not have to wait for the response, it can just be ignored if not needed. +// If sending the call failed, the response will be ready and have the error in it. +func (c *Connection) Call(ctx context.Context, method string, params any) *AsyncCall { + // Generate a new request identifier. + id := Int64ID(atomic.AddInt64(&c.seq, 1)) + + ac := &AsyncCall{ + id: id, + ready: make(chan struct{}), + } + // When this method returns, either ac is retired, or the request has been + // written successfully and the call is awaiting a response (to be provided by + // the readIncoming goroutine). + + call, err := NewCall(ac.id, method, params) + if err != nil { + ac.retire(&Response{ID: id, Error: fmt.Errorf("marshaling call parameters: %w", err)}) + return ac + } + + c.updateInFlight(func(s *inFlightState) { + err = s.shuttingDown(ErrClientClosing) + if err != nil { + return + } + if s.outgoingCalls == nil { + s.outgoingCalls = make(map[ID]*AsyncCall) + } + s.outgoingCalls[ac.id] = ac + }) + if err != nil { + ac.retire(&Response{ID: id, Error: err}) + return ac + } + + if err := c.write(ctx, call); err != nil { + // Sending failed. 
We will never get a response, so deliver a fake one if it + // wasn't already retired by the connection breaking. + c.Retire(ac, err) + } + return ac +} + +// Retire stops tracking the call, and reports err as its terminal error. +// +// Retire is safe to call multiple times: if the call is already no longer +// tracked, Retire is a no op. +func (c *Connection) Retire(ac *AsyncCall, err error) { + c.updateInFlight(func(s *inFlightState) { + if s.outgoingCalls[ac.id] == ac { + delete(s.outgoingCalls, ac.id) + ac.retire(&Response{ID: ac.id, Error: err}) + } else { + // ac was already retired elsewhere. + } + }) +} + +// Async, signals that the current jsonrpc2 request may be handled +// asynchronously to subsequent requests, when ctx is the request context. +// +// Async must be called at most once on each request's context (and its +// descendants). +func Async(ctx context.Context) { + if r, ok := ctx.Value(asyncKey).(*releaser); ok { + r.release(false) + } +} + +type asyncKeyType struct{} + +var asyncKey = asyncKeyType{} + +// A releaser implements concurrency safe 'releasing' of async requests. (A +// request is released when it is allowed to run concurrent with other +// requests, via a call to [Async].) +type releaser struct { + mu sync.Mutex + ch chan struct{} + released bool +} + +// release closes the associated channel. If soft is set, multiple calls to +// release are allowed. +func (r *releaser) release(soft bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.released { + if !soft { + panic("jsonrpc2.Async called multiple times") + } + } else { + close(r.ch) + r.released = true + } +} + +type AsyncCall struct { + id ID + ready chan struct{} // closed after response has been set + response *Response +} + +// ID used for this call. +// This can be used to cancel the call if needed. +func (ac *AsyncCall) ID() ID { return ac.id } + +// IsReady can be used to check if the result is already prepared. 
+// This is guaranteed to return true on a result for which Await has already +// returned, or a call that failed to send in the first place. +func (ac *AsyncCall) IsReady() bool { + select { + case <-ac.ready: + return true + default: + return false + } +} + +// retire processes the response to the call. +// +// It is an error to call retire more than once: retire is guarded by the +// connection's outgoingCalls map. +func (ac *AsyncCall) retire(response *Response) { + select { + case <-ac.ready: + panic(fmt.Sprintf("jsonrpc2: retire called twice for ID %v", ac.id)) + default: + } + + ac.response = response + close(ac.ready) +} + +// Await waits for (and decodes) the results of a Call. +// The response will be unmarshaled from JSON into the result. +// +// If the call is cancelled due to context cancellation, the result is +// ctx.Err(). +func (ac *AsyncCall) Await(ctx context.Context, result any) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ac.ready: + } + if ac.response.Error != nil { + return ac.response.Error + } + if result == nil { + return nil + } + return json.Unmarshal(ac.response.Result, result) +} + +// Cancel cancels the Context passed to the Handle call for the inbound message +// with the given ID. +// +// Cancel will not complain if the ID is not a currently active message, and it +// will not cause any messages that have not arrived yet with that ID to be +// cancelled. +func (c *Connection) Cancel(id ID) { + var req *incomingRequest + c.updateInFlight(func(s *inFlightState) { + req = s.incomingByID[id] + }) + if req != nil { + req.cancel() + } +} + +// Wait blocks until the connection is fully closed, but does not close it. +func (c *Connection) Wait() error { + return c.wait(true) +} + +// wait for the connection to close, and aggregates the most cause of its +// termination, if abnormal. +// +// The fromWait argument allows this logic to be shared with Close, where we +// only want to expose the closeErr. 
+// +// (Previously, Wait also only returned the closeErr, which was misleading if +// the connection was broken for another reason). +func (c *Connection) wait(fromWait bool) error { + var err error + <-c.done + c.updateInFlight(func(s *inFlightState) { + if fromWait { + if !errors.Is(s.readErr, io.EOF) { + err = s.readErr + } + if err == nil && !errors.Is(s.writeErr, io.EOF) { + err = s.writeErr + } + } + if err == nil { + err = s.closeErr + } + }) + return err +} + +// Close stops accepting new requests, waits for in-flight requests and enqueued +// Handle calls to complete, and then closes the underlying stream. +// +// After the start of a Close, notification requests (that lack IDs and do not +// receive responses) will continue to be passed to the Preempter, but calls +// with IDs will receive immediate responses with ErrServerClosing, and no new +// requests (not even notifications!) will be enqueued to the Handler. +func (c *Connection) Close() error { + // Stop handling new requests, and interrupt the reader (by closing the + // connection) as soon as the active requests finish. + c.updateInFlight(func(s *inFlightState) { s.connClosing = true }) + return c.wait(false) +} + +// readIncoming collects inbound messages from the reader and delivers them, either responding +// to outgoing calls or feeding requests to the queue. +func (c *Connection) readIncoming(ctx context.Context, reader Reader, preempter Preempter) { + var err error + for { + var msg Message + msg, err = reader.Read(ctx) + if err != nil { + break + } + + switch msg := msg.(type) { + case *Request: + c.acceptRequest(ctx, msg, preempter) + + case *Response: + c.updateInFlight(func(s *inFlightState) { + if ac, ok := s.outgoingCalls[msg.ID]; ok { + delete(s.outgoingCalls, msg.ID) + ac.retire(msg) + } else { + // TODO: How should we report unexpected responses? 
+ } + }) + + default: + c.internalErrorf("Read returned an unexpected message of type %T", msg) + } + } + + c.updateInFlight(func(s *inFlightState) { + s.reading = false + s.readErr = err + + // Retire any outgoing requests that were still in flight: with the Reader no + // longer being processed, they necessarily cannot receive a response. + for id, ac := range s.outgoingCalls { + ac.retire(&Response{ID: id, Error: err}) + } + s.outgoingCalls = nil + }) +} + +// acceptRequest either handles msg synchronously or enqueues it to be handled +// asynchronously. +func (c *Connection) acceptRequest(ctx context.Context, msg *Request, preempter Preempter) { + // In theory notifications cannot be cancelled, but we build them a cancel + // context anyway. + reqCtx, cancel := context.WithCancel(ctx) + req := &incomingRequest{ + Request: msg, + ctx: reqCtx, + cancel: cancel, + } + + // If the request is a call, add it to the incoming map so it can be + // cancelled (or responded) by ID. + var err error + c.updateInFlight(func(s *inFlightState) { + s.incoming++ + + if req.IsCall() { + if s.incomingByID[req.ID] != nil { + err = fmt.Errorf("%w: request ID %v already in use", ErrInvalidRequest, req.ID) + req.ID = ID{} // Don't misattribute this error to the existing request. + return + } + + if s.incomingByID == nil { + s.incomingByID = make(map[ID]*incomingRequest) + } + s.incomingByID[req.ID] = req + + // When shutting down, reject all new Call requests, even if they could + // theoretically be handled by the preempter. The preempter could return + // ErrAsyncResponse, which would increase the amount of work in flight + // when we're trying to ensure that it strictly decreases. 
+ err = s.shuttingDown(ErrServerClosing) + } + }) + if err != nil { + c.processResult("acceptRequest", req, nil, err) + return + } + + if preempter != nil { + result, err := preempter.Preempt(req.ctx, req.Request) + + if !errors.Is(err, ErrNotHandled) { + c.processResult("Preempt", req, result, err) + return + } + } + + c.updateInFlight(func(s *inFlightState) { + // If the connection is shutting down, don't enqueue anything to the + // handler — not even notifications. That ensures that if the handler + // continues to make progress, it will eventually become idle and + // close the connection. + err = s.shuttingDown(ErrServerClosing) + if err != nil { + return + } + + // We enqueue requests that have not been preempted to an unbounded slice. + // Unfortunately, we cannot in general limit the size of the handler + // queue: we have to read every response that comes in on the wire + // (because it may be responding to a request issued by, say, an + // asynchronous handler), and in order to get to that response we have + // to read all of the requests that came in ahead of it. + s.handlerQueue = append(s.handlerQueue, req) + if !s.handlerRunning { + // We start the handleAsync goroutine when it has work to do, and let it + // exit when the queue empties. + // + // Otherwise, in order to synchronize the handler we would need some other + // goroutine (probably readIncoming?) to explicitly wait for handleAsync + // to finish, and that would complicate error reporting: either the error + // report from the goroutine would be blocked on the handler emptying its + // queue (which was tried, and introduced a deadlock detected by + // TestCloseCallRace), or the error would need to be reported separately + // from synchronizing completion. Allowing the handler goroutine to exit + // when idle seems simpler than trying to implement either of those + // alternatives correctly. 
+ s.handlerRunning = true + go c.handleAsync() + } + }) + if err != nil { + c.processResult("acceptRequest", req, nil, err) + } +} + +// handleAsync invokes the handler on the requests in the handler queue +// sequentially until the queue is empty. +func (c *Connection) handleAsync() { + for { + var req *incomingRequest + c.updateInFlight(func(s *inFlightState) { + if len(s.handlerQueue) > 0 { + req, s.handlerQueue = s.handlerQueue[0], s.handlerQueue[1:] + } else { + s.handlerRunning = false + } + }) + if req == nil { + return + } + + // Only deliver to the Handler if not already canceled. + if err := req.ctx.Err(); err != nil { + c.updateInFlight(func(s *inFlightState) { + if s.writeErr != nil { + // Assume that req.ctx was canceled due to s.writeErr. + // TODO(#51365): use a Context API to plumb this through req.ctx. + err = fmt.Errorf("%w: %v", ErrServerClosing, s.writeErr) + } + }) + c.processResult("handleAsync", req, nil, err) + continue + } + + releaser := &releaser{ch: make(chan struct{})} + ctx := context.WithValue(req.ctx, asyncKey, releaser) + go func() { + defer releaser.release(true) + result, err := c.handler.Handle(ctx, req.Request) + c.processResult(c.handler, req, result, err) + }() + <-releaser.ch + } +} + +// processResult processes the result of a request and, if appropriate, sends a response. +func (c *Connection) processResult(from any, req *incomingRequest, result any, err error) error { + switch err { + case ErrNotHandled, ErrMethodNotFound: + // Add detail describing the unhandled method. + err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method) + } + + if result != nil && err != nil { + c.internalErrorf("%#v returned a non-nil result with a non-nil error for %s:\n%v\n%#v", from, req.Method, err, result) + result = nil // Discard the spurious result and respond with err. 
+ } + + if req.IsCall() { + if result == nil && err == nil { + err = c.internalErrorf("%#v returned a nil result and nil error for a %q Request that requires a Response", from, req.Method) + } + + response, respErr := NewResponse(req.ID, result, err) + + // The caller could theoretically reuse the request's ID as soon as we've + // sent the response, so ensure that it is removed from the incoming map + // before sending. + c.updateInFlight(func(s *inFlightState) { + delete(s.incomingByID, req.ID) + }) + if respErr == nil { + writeErr := c.write(notDone{req.ctx}, response) + if err == nil { + err = writeErr + } + } else { + err = c.internalErrorf("%#v returned a malformed result for %q: %w", from, req.Method, respErr) + } + } else { // req is a notification + if result != nil { + err = c.internalErrorf("%#v returned a non-nil result for a %q Request without an ID", from, req.Method) + } else if err != nil { + err = fmt.Errorf("%w: %q notification failed: %v", ErrInternal, req.Method, err) + } + } + if err != nil { + // TODO: can/should we do anything with this error beyond writing it to the event log? + // (Is this the right label to attach to the log?) + } + + // Cancel the request to free any associated resources. + req.cancel() + c.updateInFlight(func(s *inFlightState) { + if s.incoming == 0 { + panic("jsonrpc2: processResult called when incoming count is already zero") + } + s.incoming-- + }) + return nil +} + +// write is used by all things that write outgoing messages, including replies. +// it makes sure that writes are atomic +func (c *Connection) write(ctx context.Context, msg Message) error { + var err error + // Fail writes immediately if the connection is shutting down. + // + // TODO(rfindley): should we allow cancellation notifications through? It + // could be the case that writes can still succeed. 
+ c.updateInFlight(func(s *inFlightState) { + err = s.shuttingDown(ErrServerClosing) + }) + if err == nil { + err = c.writer.Write(ctx, msg) + } + + // For cancelled or rejected requests, we don't set the writeErr (which would + // break the connection). They can just be returned to the caller. + if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) { + // The call to Write failed, and since ctx.Err() is nil we can't attribute + // the failure (even indirectly) to Context cancellation. The writer appears + // to be broken, and future writes are likely to also fail. + // + // If the read side of the connection is also broken, we might not even be + // able to receive cancellation notifications. Since we can't reliably write + // the results of incoming calls and can't receive explicit cancellations, + // cancel the calls now. + c.updateInFlight(func(s *inFlightState) { + if s.writeErr == nil { + s.writeErr = err + for _, r := range s.incomingByID { + r.cancel() + } + } + }) + } + + return err +} + +// internalErrorf reports an internal error. By default it panics, but if +// c.onInternalError is non-nil it instead calls that and returns an error +// wrapping ErrInternal. +func (c *Connection) internalErrorf(format string, args ...any) error { + err := fmt.Errorf(format, args...) + if c.onInternalError == nil { + panic("jsonrpc2: " + err.Error()) + } + c.onInternalError(err) + + return fmt.Errorf("%w: %v", ErrInternal, err) +} + +// notDone is a context.Context wrapper that returns a nil Done channel. 
+type notDone struct{ ctx context.Context } + +func (ic notDone) Value(key any) any { + return ic.ctx.Value(key) +} + +func (notDone) Done() <-chan struct{} { return nil } +func (notDone) Err() error { return nil } +func (notDone) Deadline() (time.Time, bool) { return time.Time{}, false } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go new file mode 100644 index 000000000..46fcc9db9 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/frame.go @@ -0,0 +1,208 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + "sync" +) + +// Reader abstracts the transport mechanics from the JSON RPC protocol. +// A Conn reads messages from the reader it was provided on construction, +// and assumes that each call to Read fully transfers a single message, +// or returns an error. +// +// A reader is not safe for concurrent use, it is expected it will be used by +// a single Conn in a safe manner. +type Reader interface { + // Read gets the next message from the stream. + Read(context.Context) (Message, error) +} + +// Writer abstracts the transport mechanics from the JSON RPC protocol. +// A Conn writes messages using the writer it was provided on construction, +// and assumes that each call to Write fully transfers a single message, +// or returns an error. +// +// A writer must be safe for concurrent use, as writes may occur concurrently +// in practice: libraries may make calls or respond to requests asynchronously. +type Writer interface { + // Write sends a message to the stream. 
+ Write(context.Context, Message) error +} + +// Framer wraps low level byte readers and writers into jsonrpc2 message +// readers and writers. +// It is responsible for the framing and encoding of messages into wire form. +// +// TODO(rfindley): rethink the framer interface, as with JSONRPC2 batching +// there is a need for Reader and Writer to be correlated, and while the +// implementation of framing here allows that, it is not made explicit by the +// interface. +// +// Perhaps a better interface would be +// +// Frame(io.ReadWriteCloser) (Reader, Writer). +type Framer interface { + // Reader wraps a byte reader into a message reader. + Reader(io.Reader) Reader + // Writer wraps a byte writer into a message writer. + Writer(io.Writer) Writer +} + +// RawFramer returns a new Framer. +// The messages are sent with no wrapping, and rely on json decode consistency +// to determine message boundaries. +func RawFramer() Framer { return rawFramer{} } + +type rawFramer struct{} +type rawReader struct{ in *json.Decoder } +type rawWriter struct { + mu sync.Mutex + out io.Writer +} + +func (rawFramer) Reader(rw io.Reader) Reader { + return &rawReader{in: json.NewDecoder(rw)} +} + +func (rawFramer) Writer(rw io.Writer) Writer { + return &rawWriter{out: rw} +} + +func (r *rawReader) Read(ctx context.Context) (Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + var raw json.RawMessage + if err := r.in.Decode(&raw); err != nil { + return nil, err + } + msg, err := DecodeMessage(raw) + return msg, err +} + +func (w *rawWriter) Write(ctx context.Context, msg Message) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + data, err := EncodeMessage(msg) + if err != nil { + return fmt.Errorf("marshaling message: %v", err) + } + + w.mu.Lock() + defer w.mu.Unlock() + _, err = w.out.Write(data) + return err +} + +// HeaderFramer returns a new Framer. 
+// The messages are sent with HTTP content length and MIME type headers. +// This is the format used by LSP and others. +func HeaderFramer() Framer { return headerFramer{} } + +type headerFramer struct{} +type headerReader struct{ in *bufio.Reader } +type headerWriter struct { + mu sync.Mutex + out io.Writer +} + +func (headerFramer) Reader(rw io.Reader) Reader { + return &headerReader{in: bufio.NewReader(rw)} +} + +func (headerFramer) Writer(rw io.Writer) Writer { + return &headerWriter{out: rw} +} + +func (r *headerReader) Read(ctx context.Context) (Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + firstRead := true // to detect a clean EOF below + var contentLength int64 + // read the header, stop on the first empty line + for { + line, err := r.in.ReadString('\n') + if err != nil { + if err == io.EOF { + if firstRead && line == "" { + return nil, io.EOF // clean EOF + } + err = io.ErrUnexpectedEOF + } + return nil, fmt.Errorf("failed reading header line: %w", err) + } + firstRead = false + + line = strings.TrimSpace(line) + // check we have a header line + if line == "" { + break + } + colon := strings.IndexRune(line, ':') + if colon < 0 { + return nil, fmt.Errorf("invalid header line %q", line) + } + name, value := line[:colon], strings.TrimSpace(line[colon+1:]) + switch name { + case "Content-Length": + if contentLength, err = strconv.ParseInt(value, 10, 32); err != nil { + return nil, fmt.Errorf("failed parsing Content-Length: %v", value) + } + if contentLength <= 0 { + return nil, fmt.Errorf("invalid Content-Length: %v", contentLength) + } + default: + // ignoring unknown headers + } + } + if contentLength == 0 { + return nil, fmt.Errorf("missing Content-Length header") + } + data := make([]byte, contentLength) + _, err := io.ReadFull(r.in, data) + if err != nil { + return nil, err + } + msg, err := DecodeMessage(data) + return msg, err +} + +func (w *headerWriter) Write(ctx context.Context, msg Message) error { + 
select { + case <-ctx.Done(): + return ctx.Err() + default: + } + w.mu.Lock() + defer w.mu.Unlock() + + data, err := EncodeMessage(msg) + if err != nil { + return fmt.Errorf("marshaling message: %v", err) + } + _, err = fmt.Fprintf(w.out, "Content-Length: %v\r\n\r\n", len(data)) + if err == nil { + _, err = w.out.Write(data) + } + return err +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/jsonrpc2.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/jsonrpc2.go new file mode 100644 index 000000000..234e6ee3a --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/jsonrpc2.go @@ -0,0 +1,121 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package jsonrpc2 is a minimal implementation of the JSON RPC 2 spec. +// https://www.jsonrpc.org/specification +// It is intended to be compatible with other implementations at the wire level. +package jsonrpc2 + +import ( + "context" + "errors" +) + +var ( + // ErrIdleTimeout is returned when serving timed out waiting for new connections. + ErrIdleTimeout = errors.New("timed out waiting for new connections") + + // ErrNotHandled is returned from a Handler or Preempter to indicate it did + // not handle the request. + // + // If a Handler returns ErrNotHandled, the server replies with + // ErrMethodNotFound. + ErrNotHandled = errors.New("JSON RPC not handled") +) + +// Preempter handles messages on a connection before they are queued to the main +// handler. +// Primarily this is used for cancel handlers or notifications for which out of +// order processing is not an issue. +type Preempter interface { + // Preempt is invoked for each incoming request before it is queued for handling. + // + // If Preempt returns ErrNotHandled, the request will be queued, + // and eventually passed to a Handle call. 
+ // + // Otherwise, the result and error are processed as if returned by Handle. + // + // Preempt must not block. (The Context passed to it is for Values only.) + Preempt(ctx context.Context, req *Request) (result any, err error) +} + +// A PreempterFunc implements the Preempter interface for a standalone Preempt function. +type PreempterFunc func(ctx context.Context, req *Request) (any, error) + +func (f PreempterFunc) Preempt(ctx context.Context, req *Request) (any, error) { + return f(ctx, req) +} + +var _ Preempter = PreempterFunc(nil) + +// Handler handles messages on a connection. +type Handler interface { + // Handle is invoked sequentially for each incoming request that has not + // already been handled by a Preempter. + // + // If the Request has a nil ID, Handle must return a nil result, + // and any error may be logged but will not be reported to the caller. + // + // If the Request has a non-nil ID, Handle must return either a + // non-nil, JSON-marshalable result, or a non-nil error. + // + // The Context passed to Handle will be canceled if the + // connection is broken or the request is canceled or completed. + // (If Handle returns ErrAsyncResponse, ctx will remain uncanceled + // until either Cancel or Respond is called for the request's ID.) + Handle(ctx context.Context, req *Request) (result any, err error) +} + +type defaultHandler struct{} + +func (defaultHandler) Preempt(context.Context, *Request) (any, error) { + return nil, ErrNotHandled +} + +func (defaultHandler) Handle(context.Context, *Request) (any, error) { + return nil, ErrNotHandled +} + +// A HandlerFunc implements the Handler interface for a standalone Handle function. +type HandlerFunc func(ctx context.Context, req *Request) (any, error) + +func (f HandlerFunc) Handle(ctx context.Context, req *Request) (any, error) { + return f(ctx, req) +} + +var _ Handler = HandlerFunc(nil) + +// async is a small helper for operations with an asynchronous result that you +// can wait for. 
+type async struct { + ready chan struct{} // closed when done + firstErr chan error // 1-buffered; contains either nil or the first non-nil error +} + +func newAsync() *async { + var a async + a.ready = make(chan struct{}) + a.firstErr = make(chan error, 1) + a.firstErr <- nil + return &a +} + +func (a *async) done() { + close(a.ready) +} + +func (a *async) wait() error { + <-a.ready + err := <-a.firstErr + a.firstErr <- err + return err +} + +func (a *async) setError(err error) { + storedErr := <-a.firstErr + if storedErr == nil { + storedErr = err + } + a.firstErr <- storedErr +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go new file mode 100644 index 000000000..791e698d9 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/messages.go @@ -0,0 +1,212 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "encoding/json" + "errors" + "fmt" +) + +// ID is a Request identifier, which is defined by the spec to be a string, integer, or null. +// https://www.jsonrpc.org/specification#request_object +type ID struct { + value any +} + +// MakeID coerces the given Go value to an ID. The value should be the +// default JSON marshaling of a Request identifier: nil, float64, or string. +// +// Returns an error if the value type was not a valid Request ID type. +// +// TODO: ID can't be a json.Marshaler/Unmarshaler, because we want to omitzero. +// Simplify this package by making ID json serializable once we can rely on +// omitzero. 
+func MakeID(v any) (ID, error) { + switch v := v.(type) { + case nil: + return ID{}, nil + case float64: + return Int64ID(int64(v)), nil + case string: + return StringID(v), nil + } + return ID{}, fmt.Errorf("%w: invalid ID type %T", ErrParse, v) +} + +// Message is the interface to all jsonrpc2 message types. +// They share no common functionality, but are a closed set of concrete types +// that are allowed to implement this interface. The message types are *Request +// and *Response. +type Message interface { + // marshal builds the wire form from the API form. + // It is private, which makes the set of Message implementations closed. + marshal(to *wireCombined) +} + +// Request is a Message sent to a peer to request behavior. +// If it has an ID it is a call, otherwise it is a notification. +type Request struct { + // ID of this request, used to tie the Response back to the request. + // This will be nil for notifications. + ID ID + // Method is a string containing the method name to invoke. + Method string + // Params is either a struct or an array with the parameters of the method. + Params json.RawMessage + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the application to the underlying transport. + Extra any +} + +// Response is a Message used as a reply to a call Request. +// It will have the same ID as the call it is a response to. +type Response struct { + // result is the content of the response. + Result json.RawMessage + // err is set only if the call failed. + Error error + // id of the request this is a response to. + ID ID + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the underlying transport to the application. + Extra any +} + +// StringID creates a new string request identifier. +func StringID(s string) ID { return ID{value: s} } + +// Int64ID creates a new integer request identifier. 
+func Int64ID(i int64) ID { return ID{value: i} } + +// IsValid returns true if the ID is a valid identifier. +// The default value for ID will return false. +func (id ID) IsValid() bool { return id.value != nil } + +// Raw returns the underlying value of the ID. +func (id ID) Raw() any { return id.value } + +// NewNotification constructs a new Notification message for the supplied +// method and parameters. +func NewNotification(method string, params any) (*Request, error) { + p, merr := marshalToRaw(params) + return &Request{Method: method, Params: p}, merr +} + +// NewCall constructs a new Call message for the supplied ID, method and +// parameters. +func NewCall(id ID, method string, params any) (*Request, error) { + p, merr := marshalToRaw(params) + return &Request{ID: id, Method: method, Params: p}, merr +} + +func (msg *Request) IsCall() bool { return msg.ID.IsValid() } + +func (msg *Request) marshal(to *wireCombined) { + to.ID = msg.ID.value + to.Method = msg.Method + to.Params = msg.Params +} + +// NewResponse constructs a new Response message that is a reply to the +// supplied. If err is set result may be ignored. 
+func NewResponse(id ID, result any, rerr error) (*Response, error) { + r, merr := marshalToRaw(result) + return &Response{ID: id, Result: r, Error: rerr}, merr +} + +func (msg *Response) marshal(to *wireCombined) { + to.ID = msg.ID.value + to.Error = toWireError(msg.Error) + to.Result = msg.Result +} + +func toWireError(err error) *WireError { + if err == nil { + // no error, the response is complete + return nil + } + if err, ok := err.(*WireError); ok { + // already a wire error, just use it + return err + } + result := &WireError{Message: err.Error()} + var wrapped *WireError + if errors.As(err, &wrapped) { + // if we wrapped a wire error, keep the code from the wrapped error + // but the message from the outer error + result.Code = wrapped.Code + } + return result +} + +func EncodeMessage(msg Message) ([]byte, error) { + wire := wireCombined{VersionTag: wireVersion} + msg.marshal(&wire) + data, err := json.Marshal(&wire) + if err != nil { + return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + } + return data, nil +} + +// EncodeIndent is like EncodeMessage, but honors indents. +// TODO(rfindley): refactor so that this concern is handled independently. +// Perhaps we should pass in a json.Encoder? 
+func EncodeIndent(msg Message, prefix, indent string) ([]byte, error) { + wire := wireCombined{VersionTag: wireVersion} + msg.marshal(&wire) + data, err := json.MarshalIndent(&wire, prefix, indent) + if err != nil { + return data, fmt.Errorf("marshaling jsonrpc message: %w", err) + } + return data, nil +} + +func DecodeMessage(data []byte) (Message, error) { + msg := wireCombined{} + if err := json.Unmarshal(data, &msg); err != nil { + return nil, fmt.Errorf("unmarshaling jsonrpc message: %w", err) + } + if msg.VersionTag != wireVersion { + return nil, fmt.Errorf("invalid message version tag %q; expected %q", msg.VersionTag, wireVersion) + } + id, err := MakeID(msg.ID) + if err != nil { + return nil, err + } + if msg.Method != "" { + // has a method, must be a call + return &Request{ + Method: msg.Method, + ID: id, + Params: msg.Params, + }, nil + } + // no method, should be a response + if !id.IsValid() { + return nil, ErrInvalidRequest + } + resp := &Response{ + ID: id, + Result: msg.Result, + } + // we have to check if msg.Error is nil to avoid a typed error + if msg.Error != nil { + resp.Error = msg.Error + } + return resp, nil +} + +func marshalToRaw(obj any) (json.RawMessage, error) { + if obj == nil { + return nil, nil + } + data, err := json.Marshal(obj) + if err != nil { + return nil, err + } + return json.RawMessage(data), nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/net.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/net.go new file mode 100644 index 000000000..05db06261 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/net.go @@ -0,0 +1,138 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +package jsonrpc2 + +import ( + "context" + "io" + "net" + "os" +) + +// This file contains implementations of the transport primitives that use the standard network +// package. + +// NetListenOptions is the optional arguments to the NetListen function. +type NetListenOptions struct { + NetListenConfig net.ListenConfig + NetDialer net.Dialer +} + +// NetListener returns a new Listener that listens on a socket using the net package. +func NetListener(ctx context.Context, network, address string, options NetListenOptions) (Listener, error) { + ln, err := options.NetListenConfig.Listen(ctx, network, address) + if err != nil { + return nil, err + } + return &netListener{net: ln}, nil +} + +// netListener is the implementation of Listener for connections made using the net package. +type netListener struct { + net net.Listener +} + +// Accept blocks waiting for an incoming connection to the listener. +func (l *netListener) Accept(context.Context) (io.ReadWriteCloser, error) { + return l.net.Accept() +} + +// Close will cause the listener to stop listening. It will not close any connections that have +// already been accepted. +func (l *netListener) Close() error { + addr := l.net.Addr() + err := l.net.Close() + if addr.Network() == "unix" { + rerr := os.Remove(addr.String()) + if rerr != nil && err == nil { + err = rerr + } + } + return err +} + +// Dialer returns a dialer that can be used to connect to the listener. +func (l *netListener) Dialer() Dialer { + return NetDialer(l.net.Addr().Network(), l.net.Addr().String(), net.Dialer{}) +} + +// NetDialer returns a Dialer using the supplied standard network dialer. 
+func NetDialer(network, address string, nd net.Dialer) Dialer { + return &netDialer{ + network: network, + address: address, + dialer: nd, + } +} + +type netDialer struct { + network string + address string + dialer net.Dialer +} + +func (n *netDialer) Dial(ctx context.Context) (io.ReadWriteCloser, error) { + return n.dialer.DialContext(ctx, n.network, n.address) +} + +// NetPipeListener returns a new Listener that listens using net.Pipe. +// It is only possibly to connect to it using the Dialer returned by the +// Dialer method, each call to that method will generate a new pipe the other +// side of which will be returned from the Accept call. +func NetPipeListener(ctx context.Context) (Listener, error) { + return &netPiper{ + done: make(chan struct{}), + dialed: make(chan io.ReadWriteCloser), + }, nil +} + +// netPiper is the implementation of Listener build on top of net.Pipes. +type netPiper struct { + done chan struct{} + dialed chan io.ReadWriteCloser +} + +// Accept blocks waiting for an incoming connection to the listener. +func (l *netPiper) Accept(context.Context) (io.ReadWriteCloser, error) { + // Block until the pipe is dialed or the listener is closed, + // preferring the latter if already closed at the start of Accept. + select { + case <-l.done: + return nil, net.ErrClosed + default: + } + select { + case rwc := <-l.dialed: + return rwc, nil + case <-l.done: + return nil, net.ErrClosed + } +} + +// Close will cause the listener to stop listening. It will not close any connections that have +// already been accepted. 
+func (l *netPiper) Close() error { + // unblock any accept calls that are pending + close(l.done) + return nil +} + +func (l *netPiper) Dialer() Dialer { + return l +} + +func (l *netPiper) Dial(ctx context.Context) (io.ReadWriteCloser, error) { + client, server := net.Pipe() + + select { + case l.dialed <- server: + return client, nil + + case <-l.done: + client.Close() + server.Close() + return nil, net.ErrClosed + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/serve.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/serve.go new file mode 100644 index 000000000..424163aaf --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/serve.go @@ -0,0 +1,330 @@ +// Copyright 2020 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package jsonrpc2 + +import ( + "context" + "fmt" + "io" + "runtime" + "sync" + "sync/atomic" + "time" +) + +// Listener is implemented by protocols to accept new inbound connections. +type Listener interface { + // Accept accepts an inbound connection to a server. + // It blocks until either an inbound connection is made, or the listener is closed. + Accept(context.Context) (io.ReadWriteCloser, error) + + // Close closes the listener. + // Any blocked Accept or Dial operations will unblock and return errors. + Close() error + + // Dialer returns a dialer that can be used to connect to this listener + // locally. + // If a listener does not implement this it will return nil. + Dialer() Dialer +} + +// Dialer is used by clients to dial a server. +type Dialer interface { + // Dial returns a new communication byte stream to a listening server. + Dial(ctx context.Context) (io.ReadWriteCloser, error) +} + +// Server is a running server that is accepting incoming connections. 
+type Server struct { + listener Listener + binder Binder + async *async + + shutdownOnce sync.Once + closing int32 // atomic: set to nonzero when Shutdown is called +} + +// Dial uses the dialer to make a new connection, wraps the returned +// reader and writer using the framer to make a stream, and then builds +// a connection on top of that stream using the binder. +// +// The returned Connection will operate independently using the Preempter and/or +// Handler provided by the Binder, and will release its own resources when the +// connection is broken, but the caller may Close it earlier to stop accepting +// (or sending) new requests. +// +// If non-nil, the onDone function is called when the connection is closed. +func Dial(ctx context.Context, dialer Dialer, binder Binder, onDone func()) (*Connection, error) { + // dial a server + rwc, err := dialer.Dial(ctx) + if err != nil { + return nil, err + } + return bindConnection(ctx, rwc, binder, onDone), nil +} + +// NewServer starts a new server listening for incoming connections and returns +// it. +// This returns a fully running and connected server, it does not block on +// the listener. +// You can call Wait to block on the server, or Shutdown to get the sever to +// terminate gracefully. +// To notice incoming connections, use an intercepting Binder. +func NewServer(ctx context.Context, listener Listener, binder Binder) *Server { + server := &Server{ + listener: listener, + binder: binder, + async: newAsync(), + } + go server.run(ctx) + return server +} + +// Wait returns only when the server has shut down. +func (s *Server) Wait() error { + return s.async.wait() +} + +// Shutdown informs the server to stop accepting new connections. 
+func (s *Server) Shutdown() { + s.shutdownOnce.Do(func() { + atomic.StoreInt32(&s.closing, 1) + s.listener.Close() + }) +} + +// run accepts incoming connections from the listener, +// If IdleTimeout is non-zero, run exits after there are no clients for this +// duration, otherwise it exits only on error. +func (s *Server) run(ctx context.Context) { + defer s.async.done() + + var activeConns sync.WaitGroup + for { + rwc, err := s.listener.Accept(ctx) + if err != nil { + // Only Shutdown closes the listener. If we get an error after Shutdown is + // called, assume that was the cause and don't report the error; + // otherwise, report the error in case it is unexpected. + if atomic.LoadInt32(&s.closing) == 0 { + s.async.setError(err) + } + // We are done generating new connections for good. + break + } + + // A new inbound connection. + activeConns.Add(1) + _ = bindConnection(ctx, rwc, s.binder, activeConns.Done) // unregisters itself when done + } + activeConns.Wait() +} + +// NewIdleListener wraps a listener with an idle timeout. +// +// When there are no active connections for at least the timeout duration, +// calls to Accept will fail with ErrIdleTimeout. +// +// A connection is considered inactive as soon as its Close method is called. +func NewIdleListener(timeout time.Duration, wrap Listener) Listener { + l := &idleListener{ + wrapped: wrap, + timeout: timeout, + active: make(chan int, 1), + timedOut: make(chan struct{}), + idleTimer: make(chan *time.Timer, 1), + } + l.idleTimer <- time.AfterFunc(l.timeout, l.timerExpired) + return l +} + +type idleListener struct { + wrapped Listener + timeout time.Duration + + // Only one of these channels is receivable at any given time. + active chan int // count of active connections; closed when Close is called if not timed out + timedOut chan struct{} // closed when the idle timer expires + idleTimer chan *time.Timer // holds the timer only when idle +} + +// Accept accepts an incoming connection. 
+// +// If an incoming connection is accepted concurrent to the listener being closed +// due to idleness, the new connection is immediately closed. +func (l *idleListener) Accept(ctx context.Context) (io.ReadWriteCloser, error) { + rwc, err := l.wrapped.Accept(ctx) + + select { + case n, ok := <-l.active: + if err != nil { + if ok { + l.active <- n + } + return nil, err + } + if ok { + l.active <- n + 1 + } else { + // l.wrapped.Close Close has been called, but Accept returned a + // connection. This race can occur with concurrent Accept and Close calls + // with any net.Listener, and it is benign: since the listener was closed + // explicitly, it can't have also timed out. + } + return l.newConn(rwc), nil + + case <-l.timedOut: + if err == nil { + // Keeping the connection open would leave the listener simultaneously + // active and closed due to idleness, which would be contradictory and + // confusing. Close the connection and pretend that it never happened. + rwc.Close() + } else { + // In theory the timeout could have raced with an unrelated error return + // from Accept. However, ErrIdleTimeout is arguably still valid (since we + // would have closed due to the timeout independent of the error), and the + // harm from returning a spurious ErrIdleTimeout is negligible anyway. + } + return nil, ErrIdleTimeout + + case timer := <-l.idleTimer: + if err != nil { + // The idle timer doesn't run until it receives itself from the idleTimer + // channel, so it can't have called l.wrapped.Close yet and thus err can't + // be ErrIdleTimeout. Leave the idle timer as it was and return whatever + // error we got. + l.idleTimer <- timer + return nil, err + } + + if !timer.Stop() { + // Failed to stop the timer — the timer goroutine is in the process of + // firing. Send the timer back to the timer goroutine so that it can + // safely close the timedOut channel, and then wait for the listener to + // actually be closed before we return ErrIdleTimeout. 
+ l.idleTimer <- timer + rwc.Close() + <-l.timedOut + return nil, ErrIdleTimeout + } + + l.active <- 1 + return l.newConn(rwc), nil + } +} + +func (l *idleListener) Close() error { + select { + case _, ok := <-l.active: + if ok { + close(l.active) + } + + case <-l.timedOut: + // Already closed by the timer; take care not to double-close if the caller + // only explicitly invokes this Close method once, since the io.Closer + // interface explicitly leaves doubled Close calls undefined. + return ErrIdleTimeout + + case timer := <-l.idleTimer: + if !timer.Stop() { + // Couldn't stop the timer. It shouldn't take long to run, so just wait + // (so that the Listener is guaranteed to be closed before we return) + // and pretend that this call happened afterward. + // That way we won't leak any timers or goroutines when Close returns. + l.idleTimer <- timer + <-l.timedOut + return ErrIdleTimeout + } + close(l.active) + } + + return l.wrapped.Close() +} + +func (l *idleListener) Dialer() Dialer { + return l.wrapped.Dialer() +} + +func (l *idleListener) timerExpired() { + select { + case n, ok := <-l.active: + if ok { + panic(fmt.Sprintf("jsonrpc2: idleListener idle timer fired with %d connections still active", n)) + } else { + panic("jsonrpc2: Close finished with idle timer still running") + } + + case <-l.timedOut: + panic("jsonrpc2: idleListener idle timer fired more than once") + + case <-l.idleTimer: + // The timer for this very call! + } + + // Close the Listener with all channels still blocked to ensure that this call + // to l.wrapped.Close doesn't race with the one in l.Close. + defer close(l.timedOut) + l.wrapped.Close() +} + +func (l *idleListener) connClosed() { + select { + case n, ok := <-l.active: + if !ok { + // l is already closed, so it can't close due to idleness, + // and we don't need to track the number of active connections any more. 
+ return + } + n-- + if n == 0 { + l.idleTimer <- time.AfterFunc(l.timeout, l.timerExpired) + } else { + l.active <- n + } + + case <-l.timedOut: + panic("jsonrpc2: idleListener idle timer fired before last active connection was closed") + + case <-l.idleTimer: + panic("jsonrpc2: idleListener idle timer active before last active connection was closed") + } +} + +type idleListenerConn struct { + wrapped io.ReadWriteCloser + l *idleListener + closeOnce sync.Once +} + +func (l *idleListener) newConn(rwc io.ReadWriteCloser) *idleListenerConn { + c := &idleListenerConn{ + wrapped: rwc, + l: l, + } + + // A caller that forgets to call Close may disrupt the idleListener's + // accounting, even though the file descriptor for the underlying connection + // may eventually be garbage-collected anyway. + // + // Set a (best-effort) finalizer to verify that a Close call always occurs. + // (We will clear the finalizer explicitly in Close.) + runtime.SetFinalizer(c, func(c *idleListenerConn) { + panic("jsonrpc2: IdleListener connection became unreachable without a call to Close") + }) + + return c +} + +func (c *idleListenerConn) Read(p []byte) (int, error) { return c.wrapped.Read(p) } +func (c *idleListenerConn) Write(p []byte) (int, error) { return c.wrapped.Write(p) } + +func (c *idleListenerConn) Close() error { + defer c.closeOnce.Do(func() { + c.l.connClosed() + runtime.SetFinalizer(c, nil) + }) + return c.wrapped.Close() +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go new file mode 100644 index 000000000..c0a41bffb --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go @@ -0,0 +1,97 @@ +// Copyright 2018 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +package jsonrpc2 + +import ( + "encoding/json" +) + +// This file contains the go forms of the wire specification. +// see http://www.jsonrpc.org/specification for details + +var ( + // ErrParse is used when invalid JSON was received by the server. + ErrParse = NewError(-32700, "parse error") + // ErrInvalidRequest is used when the JSON sent is not a valid Request object. + ErrInvalidRequest = NewError(-32600, "invalid request") + // ErrMethodNotFound should be returned by the handler when the method does + // not exist / is not available. + ErrMethodNotFound = NewError(-32601, "method not found") + // ErrInvalidParams should be returned by the handler when method + // parameter(s) were invalid. + ErrInvalidParams = NewError(-32602, "invalid params") + // ErrInternal indicates a failure to process a call correctly + ErrInternal = NewError(-32603, "internal error") + + // The following errors are not part of the json specification, but + // compliant extensions specific to this implementation. + + // ErrServerOverloaded is returned when a message was refused due to a + // server being temporarily unable to accept any new messages. + ErrServerOverloaded = NewError(-32000, "overloaded") + // ErrUnknown should be used for all non coded errors. + ErrUnknown = NewError(-32001, "unknown error") + // ErrServerClosing is returned for calls that arrive while the server is closing. + ErrServerClosing = NewError(-32004, "server is closing") + // ErrClientClosing is a dummy error returned for calls initiated while the client is closing. + ErrClientClosing = NewError(-32003, "client is closing") + + // The following errors have special semantics for MCP transports + + // ErrRejected may be wrapped to return errors from calls to Writer.Write + // that signal that the request was rejected by the transport layer as + // invalid. 
+ // + // Such failures do not indicate that the connection is broken, but rather + // should be returned to the caller to indicate that the specific request is + // invalid in the current context. + ErrRejected = NewError(-32005, "rejected by transport") +) + +const wireVersion = "2.0" + +// wireCombined has all the fields of both Request and Response. +// We can decode this and then work out which it is. +type wireCombined struct { + VersionTag string `json:"jsonrpc"` + ID any `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *WireError `json:"error,omitempty"` +} + +// WireError represents a structured error in a Response. +type WireError struct { + // Code is an error code indicating the type of failure. + Code int64 `json:"code"` + // Message is a short description of the error. + Message string `json:"message"` + // Data is optional structured data containing additional information about the error. + Data json.RawMessage `json:"data,omitempty"` +} + +// NewError returns an error that will encode on the wire correctly. +// The standard codes are made available from this package, this function should +// only be used to build errors for application specific codes as allowed by the +// specification. 
+func NewError(code int64, message string) error { + return &WireError{ + Code: code, + Message: message, + } +} + +func (err *WireError) Error() string { + return err.Message +} + +func (err *WireError) Is(other error) bool { + w, ok := other.(*WireError) + if !ok { + return false + } + return err.Code == w.Code +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/util.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/util.go new file mode 100644 index 000000000..4b5c325fa --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/util/util.go @@ -0,0 +1,44 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package util + +import ( + "cmp" + "fmt" + "iter" + "slices" +) + +// Helpers below are copied from gopls' moremaps package. + +// Sorted returns an iterator over the entries of m in key order. +func Sorted[M ~map[K]V, K cmp.Ordered, V any](m M) iter.Seq2[K, V] { + // TODO(adonovan): use maps.Sorted if proposal #68598 is accepted. + return func(yield func(K, V) bool) { + keys := KeySlice(m) + slices.Sort(keys) + for _, k := range keys { + if !yield(k, m[k]) { + break + } + } + } +} + +// KeySlice returns the keys of the map M, like slices.Collect(maps.Keys(m)). +func KeySlice[M ~map[K]V, K comparable, V any](m M) []K { + r := make([]K, 0, len(m)) + for k := range m { + r = append(r, k) + } + return r +} + +// Wrapf wraps *errp with the given formatted message if *errp is not nil. 
+func Wrapf(errp *error, format string, args ...any) { + if *errp != nil { + *errp = fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), *errp) + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/internal/xcontext/xcontext.go b/vendor/github.com/modelcontextprotocol/go-sdk/internal/xcontext/xcontext.go new file mode 100644 index 000000000..849060d57 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/internal/xcontext/xcontext.go @@ -0,0 +1,23 @@ +// Copyright 2019 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package xcontext is a package to offer the extra functionality we need +// from contexts that is not available from the standard context package. +package xcontext + +import ( + "context" + "time" +) + +// Detach returns a context that keeps all the values of its parent context +// but detaches from the cancellation and error handling. +func Detach(ctx context.Context) context.Context { return detachedContext{ctx} } + +type detachedContext struct{ parent context.Context } + +func (v detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false } +func (v detachedContext) Done() <-chan struct{} { return nil } +func (v detachedContext) Err() error { return nil } +func (v detachedContext) Value(key any) any { return v.parent.Value(key) } diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go b/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go new file mode 100644 index 000000000..a9ea78fa8 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go @@ -0,0 +1,56 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package jsonrpc exposes part of a JSON-RPC v2 implementation +// for use by mcp transport authors. 
+package jsonrpc + +import "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + +type ( + // ID is a JSON-RPC request ID. + ID = jsonrpc2.ID + // Message is a JSON-RPC message. + Message = jsonrpc2.Message + // Request is a JSON-RPC request. + Request = jsonrpc2.Request + // Response is a JSON-RPC response. + Response = jsonrpc2.Response + // Error is a structured error in a JSON-RPC response. + Error = jsonrpc2.WireError +) + +// MakeID coerces the given Go value to an ID. The value should be the +// default JSON marshaling of a Request identifier: nil, float64, or string. +// +// Returns an error if the value type was not a valid Request ID type. +func MakeID(v any) (ID, error) { + return jsonrpc2.MakeID(v) +} + +// EncodeMessage serializes a JSON-RPC message to its wire format. +func EncodeMessage(msg Message) ([]byte, error) { + return jsonrpc2.EncodeMessage(msg) +} + +// DecodeMessage deserializes JSON-RPC wire format data into a Message. +// It returns either a Request or Response based on the message content. +func DecodeMessage(data []byte) (Message, error) { + return jsonrpc2.DecodeMessage(data) +} + +// Standard JSON-RPC 2.0 error codes. +// See https://www.jsonrpc.org/specification#error_object +const ( + // CodeParseError indicates invalid JSON was received by the server. + CodeParseError = -32700 + // CodeInvalidRequest indicates the JSON sent is not a valid Request object. + CodeInvalidRequest = -32600 + // CodeMethodNotFound indicates the method does not exist or is not available. + CodeMethodNotFound = -32601 + // CodeInvalidParams indicates invalid method parameter(s). + CodeInvalidParams = -32602 + // CodeInternalError indicates an internal JSON-RPC error. 
+ CodeInternalError = -32603 +) diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go new file mode 100644 index 000000000..2dc1a86c0 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go @@ -0,0 +1,1075 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "iter" + "log/slog" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +// A Client is an MCP client, which may be connected to an MCP server +// using the [Client.Connect] method. +type Client struct { + impl *Implementation + opts ClientOptions + logger *slog.Logger // TODO: file proposal to export this + mu sync.Mutex + roots *featureSet[*Root] + sessions []*ClientSession + sendingMethodHandler_ MethodHandler + receivingMethodHandler_ MethodHandler +} + +// NewClient creates a new [Client]. +// +// Use [Client.Connect] to connect it to an MCP server. +// +// The first argument must not be nil. +// +// If non-nil, the provided options configure the Client. +func NewClient(impl *Implementation, opts *ClientOptions) *Client { + if impl == nil { + panic("nil Implementation") + } + c := &Client{ + impl: impl, + logger: ensureLogger(nil), // ensure we have a logger + roots: newFeatureSet(func(r *Root) string { return r.URI }), + sendingMethodHandler_: defaultSendingMethodHandler, + receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], + } + if opts != nil { + c.opts = *opts + } + return c +} + +// ClientOptions configures the behavior of the client. 
+type ClientOptions struct { + // CreateMessageHandler handles incoming requests for sampling/createMessage. + // + // Setting CreateMessageHandler to a non-nil value automatically causes the + // client to advertise the sampling capability, with default value + // &SamplingCapabilities{}. If [ClientOptions.Capabilities] is set and has a + // non nil value for [ClientCapabilities.Sampling], that value overrides the + // inferred capability. + CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) + // ElicitationHandler handles incoming requests for elicitation/create. + // + // Setting ElicitationHandler to a non-nil value automatically causes the + // client to advertise the elicitation capability, with default value + // &ElicitationCapabilities{}. If [ClientOptions.Capabilities] is set and has + // a non nil value for [ClientCapabilities.ELicitattion], that value + // overrides the inferred capability. + ElicitationHandler func(context.Context, *ElicitRequest) (*ElicitResult, error) + // Capabilities optionally configures the client's default capabilities, + // before any capabilities are inferred from other configuration. + // + // If Capabilities is nil, the default client capabilities are + // {"roots":{"listChanged":true}}, for historical reasons. Setting + // Capabilities to a non-nil value overrides this default. As a special case, + // to work around #607, Capabilities.Roots is ignored: set + // Capabilities.RootsV2 to configure the roots capability. This allows the + // "roots" capability to be disabled entirely. + // + // For example: + // - To disable the "roots" capability, use &ClientCapabilities{} + // - To configure "roots", but disable "listChanged" notifications, use + // &ClientCapabilities{RootsV2:&RootCapabilities{}}. 
+ // + // # Interaction with capability inference + // + // Sampling and elicitation capabilities are automatically added when their + // corresponding handlers are set, with the default value described at + // [ClientOptions.CreateMessageHandler] and + // [ClientOptions.ElicitationHandler]. If the Sampling or Elicitation fields + // are set in the Capabilities field, their values override the inferred + // value. + // + // For example, to to configure elicitation modes: + // + // Capabilities: &ClientCapabilities{ + // Elicitation: &ElicitationCapabilities{ + // Form: &FormElicitationCapabilities{}, + // URL: &URLElicitationCapabilities{}, + // }, + // } + // + // Conversely, if Capabilities does not set a field (for example, if the + // Elicitation field is nil), the inferred elicitation capability will be + // used. + Capabilities *ClientCapabilities + // ElicitationCompleteHandler handles incoming notifications for notifications/elicitation/complete. + ElicitationCompleteHandler func(context.Context, *ElicitationCompleteNotificationRequest) + // Handlers for notifications from the server. + ToolListChangedHandler func(context.Context, *ToolListChangedRequest) + PromptListChangedHandler func(context.Context, *PromptListChangedRequest) + ResourceListChangedHandler func(context.Context, *ResourceListChangedRequest) + ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest) + LoggingMessageHandler func(context.Context, *LoggingMessageRequest) + ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest) + // If non-zero, defines an interval for regular "ping" requests. + // If the peer fails to respond to pings originating from the keepalive check, + // the session is automatically closed. + KeepAlive time.Duration +} + +// bind implements the binder[*ClientSession] interface, so that Clients can +// be connected using [connect]. 
+func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession { + assert(mcpConn != nil && conn != nil, "nil connection") + cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c, onClose: onClose} + if state != nil { + cs.state = *state + } + c.mu.Lock() + defer c.mu.Unlock() + c.sessions = append(c.sessions, cs) + return cs +} + +// disconnect implements the binder[*Client] interface, so that +// Clients can be connected using [connect]. +func (c *Client) disconnect(cs *ClientSession) { + c.mu.Lock() + defer c.mu.Unlock() + c.sessions = slices.DeleteFunc(c.sessions, func(cs2 *ClientSession) bool { + return cs2 == cs + }) +} + +// TODO: Consider exporting this type and its field. +type unsupportedProtocolVersionError struct { + version string +} + +func (e unsupportedProtocolVersionError) Error() string { + return fmt.Sprintf("unsupported protocol version: %q", e.version) +} + +// ClientSessionOptions is reserved for future use. +type ClientSessionOptions struct { + // protocolVersion overrides the protocol version sent in the initialize + // request, for testing. If empty, latestProtocolVersion is used. + protocolVersion string +} + +func (c *Client) capabilities(protocolVersion string) *ClientCapabilities { + // Start with user-provided capabilities as defaults, or use SDK defaults. + var caps *ClientCapabilities + if c.opts.Capabilities != nil { + // Deep copy the user-provided capabilities to avoid mutation. + caps = c.opts.Capabilities.clone() + } else { + // SDK defaults: roots with listChanged. + // (this was the default behavior at v1.0.0, and so cannot be changed) + caps = &ClientCapabilities{ + RootsV2: &RootCapabilities{ + ListChanged: true, + }, + } + } + + // Sync Roots from RootsV2 for backward compatibility (#607). + if caps.RootsV2 != nil { + caps.Roots = *caps.RootsV2 + } + + // Augment with sampling capability if handler is set. 
+ if c.opts.CreateMessageHandler != nil { + if caps.Sampling == nil { + caps.Sampling = &SamplingCapabilities{} + } + } + + // Augment with elicitation capability if handler is set. + if c.opts.ElicitationHandler != nil { + if caps.Elicitation == nil { + caps.Elicitation = &ElicitationCapabilities{} + // Form elicitation was added in 2025-11-25; for older versions, + // {} is treated the same as {"form":{}}. + if protocolVersion >= protocolVersion20251125 { + caps.Elicitation.Form = &FormElicitationCapabilities{} + } + } + } + return caps +} + +// Connect begins an MCP session by connecting to a server over the given +// transport. The resulting session is initialized, and ready to use. +// +// Typically, it is the responsibility of the client to close the connection +// when it is no longer needed. However, if the connection is closed by the +// server, calls or notifications will return an error wrapping +// [ErrConnectionClosed]. +func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOptions) (cs *ClientSession, err error) { + cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil) + if err != nil { + return nil, err + } + + protocolVersion := latestProtocolVersion + if opts != nil && opts.protocolVersion != "" { + protocolVersion = opts.protocolVersion + } + params := &InitializeParams{ + ProtocolVersion: protocolVersion, + ClientInfo: c.impl, + Capabilities: c.capabilities(protocolVersion), + } + req := &InitializeRequest{Session: cs, Params: params} + res, err := handleSend[*InitializeResult](ctx, methodInitialize, req) + if err != nil { + _ = cs.Close() + return nil, err + } + if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) { + return nil, unsupportedProtocolVersionError{res.ProtocolVersion} + } + cs.state.InitializeResult = res + if hc, ok := cs.mcpConn.(clientConnection); ok { + hc.sessionUpdated(cs.state) + } + req2 := &initializedClientRequest{Session: cs, Params: &InitializedParams{}} + if err := 
handleNotify(ctx, notificationInitialized, req2); err != nil { + _ = cs.Close() + return nil, err + } + + if c.opts.KeepAlive > 0 { + cs.startKeepalive(c.opts.KeepAlive) + } + + return cs, nil +} + +// A ClientSession is a logical connection with an MCP server. Its +// methods can be used to send requests or notifications to the server. Create +// a session by calling [Client.Connect]. +// +// Call [ClientSession.Close] to close the connection, or await server +// termination with [ClientSession.Wait]. +type ClientSession struct { + // Ensure that onClose is called at most once. + // We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the + // onClose callback triggers a re-entrant call to Close. + calledOnClose atomic.Bool + onClose func() + + conn *jsonrpc2.Connection + client *Client + keepaliveCancel context.CancelFunc + mcpConn Connection + + // No mutex is (currently) required to guard the session state, because it is + // only set synchronously during Client.Connect. + state clientSessionState + + // Pending URL elicitations waiting for completion notifications. + pendingElicitationsMu sync.Mutex + pendingElicitations map[string]chan struct{} +} + +type clientSessionState struct { + InitializeResult *InitializeResult +} + +func (cs *ClientSession) InitializeResult() *InitializeResult { return cs.state.InitializeResult } + +func (cs *ClientSession) ID() string { + if c, ok := cs.mcpConn.(hasSessionID); ok { + return c.SessionID() + } + return "" +} + +// Close performs a graceful close of the connection, preventing new requests +// from being handled, and waiting for ongoing requests to return. Close then +// terminates the connection. +// +// Close is idempotent and concurrency safe. +func (cs *ClientSession) Close() error { + // Note: keepaliveCancel access is safe without a mutex because: + // 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls) + // 2. 
context.CancelFunc is safe to call multiple times and from multiple goroutines + // 3. The keepalive goroutine calls Close on ping failure, but this is safe since + // Close is idempotent and conn.Close() handles concurrent calls correctly + if cs.keepaliveCancel != nil { + cs.keepaliveCancel() + } + err := cs.conn.Close() + + if cs.onClose != nil && cs.calledOnClose.CompareAndSwap(false, true) { + cs.onClose() + } + + return err +} + +// Wait waits for the connection to be closed by the server. +// Generally, clients should be responsible for closing the connection. +func (cs *ClientSession) Wait() error { + return cs.conn.Wait() +} + +// registerElicitationWaiter registers a waiter for an elicitation complete +// notification with the given elicitation ID. It returns two functions: an await +// function that waits for the notification or context cancellation, and a cleanup +// function that must be called to unregister the waiter. This must be called before +// triggering the elicitation to avoid a race condition where the notification +// arrives before the waiter is registered. +// +// The cleanup function must be called even if the await function is never called, +// to prevent leaking the registration. +func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await func(context.Context) error, cleanup func()) { + // Create a channel for this elicitation. + ch := make(chan struct{}, 1) + + // Register the channel. + cs.pendingElicitationsMu.Lock() + if cs.pendingElicitations == nil { + cs.pendingElicitations = make(map[string]chan struct{}) + } + cs.pendingElicitations[elicitationID] = ch + cs.pendingElicitationsMu.Unlock() + + // Return await and cleanup functions. 
+ await = func(ctx context.Context) error { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled while waiting for elicitation completion: %w", ctx.Err()) + case <-ch: + return nil + } + } + + cleanup = func() { + cs.pendingElicitationsMu.Lock() + delete(cs.pendingElicitations, elicitationID) + cs.pendingElicitationsMu.Unlock() + } + + return await, cleanup +} + +// startKeepalive starts the keepalive mechanism for this client session. +func (cs *ClientSession) startKeepalive(interval time.Duration) { + startKeepalive(cs, interval, &cs.keepaliveCancel) +} + +// AddRoots adds the given roots to the client, +// replacing any with the same URIs, +// and notifies any connected servers. +func (c *Client) AddRoots(roots ...*Root) { + // Only notify if something could change. + if len(roots) == 0 { + return + } + changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{}, + func() bool { c.roots.add(roots...); return true }) +} + +// RemoveRoots removes the roots with the given URIs, +// and notifies any connected servers if the list has changed. +// It is not an error to remove a nonexistent root. +func (c *Client) RemoveRoots(uris ...string) { + changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{}, + func() bool { return c.roots.remove(uris...) }) +} + +// changeAndNotify is called when a feature is added or removed. +// It calls change, which should do the work and report whether a change actually occurred. +// If there was a change, it notifies a snapshot of the sessions. +func changeAndNotify[P Params](c *Client, notification string, params P, change func() bool) { + var sessions []*ClientSession + // Lock for the change, but not for the notification. + c.mu.Lock() + if change() { + // Check if listChanged is enabled for this notification type. 
+ if c.shouldSendListChangedNotification(notification) { + sessions = slices.Clone(c.sessions) + } + } + c.mu.Unlock() + notifySessions(sessions, notification, params, c.logger) +} + +// shouldSendListChangedNotification checks if the client's capabilities allow +// sending the given list-changed notification. +func (c *Client) shouldSendListChangedNotification(notification string) bool { + // Get effective capabilities (considering user-provided defaults). + caps := c.opts.Capabilities + + switch notification { + case notificationRootsListChanged: + // If user didn't specify capabilities, default behavior sends notifications. + if caps == nil { + return true + } + // Check RootsV2 first (preferred), then fall back to Roots. + if caps.RootsV2 != nil { + return caps.RootsV2.ListChanged + } + return caps.Roots.ListChanged + default: + // Unknown notification, allow by default. + return true + } +} + +func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRootsResult, error) { + c.mu.Lock() + defer c.mu.Unlock() + roots := slices.Collect(c.roots.all()) + if roots == nil { + roots = []*Root{} // avoid JSON null + } + return &ListRootsResult{ + Roots: roots, + }, nil +} + +func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + if c.opts.CreateMessageHandler == nil { + // TODO: wrap or annotate this error? Pick a standard code? + return nil, &jsonrpc.Error{Code: codeUnsupportedMethod, Message: "client does not support CreateMessage"} + } + return c.opts.CreateMessageHandler(ctx, req) +} + +// urlElicitationMiddleware returns middleware that automatically handles URL elicitation +// required errors by executing the elicitation handler, waiting for completion notifications, +// and retrying the operation. 
+// +// This middleware should be added to clients that want automatic URL elicitation handling: +// +// client := mcp.NewClient(impl, opts) +// client.AddSendingMiddleware(mcp.urlElicitationMiddleware()) +// +// TODO(rfindley): this isn't strictly necessary for the SEP, but may be +// useful. Propose exporting it. +func urlElicitationMiddleware() Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + // Call the underlying handler. + res, err := next(ctx, method, req) + if err == nil { + return res, nil + } + + // Check if this is a URL elicitation required error. + var rpcErr *jsonrpc.Error + if !errors.As(err, &rpcErr) || rpcErr.Code != CodeURLElicitationRequired { + return res, err + } + + // Notifications don't support retries. + if strings.HasPrefix(method, "notifications/") { + return res, err + } + + // Extract the client session. + cs, ok := req.GetSession().(*ClientSession) + if !ok { + return res, err + } + + // Check if the client has an elicitation handler. + if cs.client.opts.ElicitationHandler == nil { + return res, err + } + + // Parse the elicitations from the error data. + var errorData struct { + Elicitations []*ElicitParams `json:"elicitations"` + } + if rpcErr.Data != nil { + if err := json.Unmarshal(rpcErr.Data, &errorData); err != nil { + return nil, fmt.Errorf("failed to parse URL elicitation error data: %w", err) + } + } + + // Validate that all elicitations are URL mode. + for _, elicit := range errorData.Elicitations { + mode := elicit.Mode + if mode == "" { + mode = "form" // Default mode. + } + if mode != "url" { + return nil, fmt.Errorf("URLElicitationRequired error must only contain URL mode elicitations, got %q", mode) + } + } + + // Register waiters for all elicitations before executing handlers + // to avoid race condition where notification arrives before waiter is registered. 
+ type waiter struct { + await func(context.Context) error + cleanup func() + } + waiters := make([]waiter, 0, len(errorData.Elicitations)) + for _, elicitParams := range errorData.Elicitations { + await, cleanup := cs.registerElicitationWaiter(elicitParams.ElicitationID) + waiters = append(waiters, waiter{await: await, cleanup: cleanup}) + } + + // Ensure cleanup happens even if we return early. + defer func() { + for _, w := range waiters { + w.cleanup() + } + }() + + // Execute the elicitation handler for each elicitation. + for _, elicitParams := range errorData.Elicitations { + elicitReq := newClientRequest(cs, elicitParams) + _, elicitErr := cs.client.elicit(ctx, elicitReq) + if elicitErr != nil { + return nil, fmt.Errorf("URL elicitation failed: %w", elicitErr) + } + } + + // Wait for all elicitations to complete. + for _, w := range waiters { + if err := w.await(ctx); err != nil { + return nil, err + } + } + + // All elicitations complete, retry the original operation. + return next(ctx, method, req) + } + } +} + +func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + if c.opts.ElicitationHandler == nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "client does not support elicitation"} + } + + // Validate the elicitation parameters based on the mode. + mode := req.Params.Mode + if mode == "" { + mode = "form" + } + + switch mode { + case "form": + if req.Params.URL != "" { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must not be set for form elicitation"} + } + schema, err := validateElicitSchema(req.Params.RequestedSchema) + if err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: err.Error()} + } + res, err := c.opts.ElicitationHandler(ctx, req) + if err != nil { + return nil, err + } + // Validate elicitation result content against requested schema. 
+ if schema != nil && res.Content != nil { + resolved, err := schema.Resolve(nil) + if err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to resolve requested schema: %v", err)} + } + if err := resolved.Validate(res.Content); err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("elicitation result content does not match requested schema: %v", err)} + } + err = resolved.ApplyDefaults(&res.Content) + if err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err)} + } + } + return res, nil + case "url": + if req.Params.RequestedSchema != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "requestedSchema must not be set for URL elicitation"} + } + if req.Params.URL == "" { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must be set for URL elicitation"} + } + // No schema validation for URL mode, just pass through to handler. + return c.opts.ElicitationHandler(ctx, req) + default: + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("unsupported elicitation mode: %q", mode)} + } +} + +// validateElicitSchema validates that the schema conforms to MCP elicitation schema requirements. +// Per the MCP specification, elicitation schemas are limited to flat objects with primitive properties only. 
+func validateElicitSchema(wireSchema any) (*jsonschema.Schema, error) { + if wireSchema == nil { + return nil, nil // nil schema is allowed + } + + var schema *jsonschema.Schema + if err := remarshal(wireSchema, &schema); err != nil { + return nil, err + } + if schema == nil { + return nil, nil + } + + // The root schema must be of type "object" if specified + if schema.Type != "" && schema.Type != "object" { + return nil, fmt.Errorf("elicit schema must be of type 'object', got %q", schema.Type) + } + + // Check if the schema has properties + if schema.Properties != nil { + for propName, propSchema := range schema.Properties { + if propSchema == nil { + continue + } + + if err := validateElicitProperty(propName, propSchema); err != nil { + return nil, err + } + } + } + + return schema, nil +} + +// validateElicitProperty validates a single property in an elicitation schema. +func validateElicitProperty(propName string, propSchema *jsonschema.Schema) error { + // Check if this property has nested properties (not allowed) + if len(propSchema.Properties) > 0 { + return fmt.Errorf("elicit schema property %q contains nested properties, only primitive properties are allowed", propName) + } + // Validate based on the property type - only primitives are supported + switch propSchema.Type { + case "string": + return validateElicitStringProperty(propName, propSchema) + case "number", "integer": + return validateElicitNumberProperty(propName, propSchema) + case "boolean": + return validateElicitBooleanProperty(propName, propSchema) + default: + return fmt.Errorf("elicit schema property %q has unsupported type %q, only string, number, integer, and boolean are allowed", propName, propSchema.Type) + } +} + +// validateElicitStringProperty validates string-type properties, including enums. 
+func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema) error { + // Handle enum validation (enums are a special case of strings) + if len(propSchema.Enum) > 0 { + // Enums must be string type (or untyped which defaults to string) + if propSchema.Type != "" && propSchema.Type != "string" { + return fmt.Errorf("elicit schema property %q has enum values but type is %q, enums are only supported for string type", propName, propSchema.Type) + } + // Enum values themselves are validated by the JSON schema library + // Validate enumNames if present - must match enum length + if propSchema.Extra != nil { + if enumNamesRaw, exists := propSchema.Extra["enumNames"]; exists { + // Type check enumNames - should be a slice + if enumNamesSlice, ok := enumNamesRaw.([]any); ok { + if len(enumNamesSlice) != len(propSchema.Enum) { + return fmt.Errorf("elicit schema property %q has %d enum values but %d enumNames, they must match", propName, len(propSchema.Enum), len(enumNamesSlice)) + } + } else { + return fmt.Errorf("elicit schema property %q has invalid enumNames type, must be an array", propName) + } + } + } + return nil + } + + // Validate format if specified - only specific formats are allowed + if propSchema.Format != "" { + allowedFormats := map[string]bool{ + "email": true, + "uri": true, + "date": true, + "date-time": true, + } + if !allowedFormats[propSchema.Format] { + return fmt.Errorf("elicit schema property %q has unsupported format %q, only email, uri, date, and date-time are allowed", propName, propSchema.Format) + } + } + + // Validate minLength constraint if specified + if propSchema.MinLength != nil { + if *propSchema.MinLength < 0 { + return fmt.Errorf("elicit schema property %q has invalid minLength %d, must be non-negative", propName, *propSchema.MinLength) + } + } + + // Validate maxLength constraint if specified + if propSchema.MaxLength != nil { + if *propSchema.MaxLength < 0 { + return fmt.Errorf("elicit schema property %q has 
invalid maxLength %d, must be non-negative", propName, *propSchema.MaxLength) + } + // Check that maxLength >= minLength if both are specified + if propSchema.MinLength != nil && *propSchema.MaxLength < *propSchema.MinLength { + return fmt.Errorf("elicit schema property %q has maxLength %d less than minLength %d", propName, *propSchema.MaxLength, *propSchema.MinLength) + } + } + + return validateDefaultProperty[string](propName, propSchema) +} + +// validateElicitNumberProperty validates number and integer-type properties. +func validateElicitNumberProperty(propName string, propSchema *jsonschema.Schema) error { + if propSchema.Minimum != nil && propSchema.Maximum != nil { + if *propSchema.Maximum < *propSchema.Minimum { + return fmt.Errorf("elicit schema property %q has maximum %g less than minimum %g", propName, *propSchema.Maximum, *propSchema.Minimum) + } + } + + intDefaultError := validateDefaultProperty[int](propName, propSchema) + floatDefaultError := validateDefaultProperty[float64](propName, propSchema) + if intDefaultError != nil && floatDefaultError != nil { + return fmt.Errorf("elicit schema property %q has default value that cannot be interpreted as an int or float", propName) + } + + return nil +} + +// validateElicitBooleanProperty validates boolean-type properties. +func validateElicitBooleanProperty(propName string, propSchema *jsonschema.Schema) error { + return validateDefaultProperty[bool](propName, propSchema) +} + +func validateDefaultProperty[T any](propName string, propSchema *jsonschema.Schema) error { + // Validate default value if specified - must be a valid T + if propSchema.Default != nil { + var defaultValue T + if err := json.Unmarshal(propSchema.Default, &defaultValue); err != nil { + return fmt.Errorf("elicit schema property %q has invalid default value, must be a %T: %v", propName, defaultValue, err) + } + } + return nil +} + +// AddSendingMiddleware wraps the current sending method handler using the provided +// middleware. 
Middleware is applied from right to left, so that the first one is +// executed first. +// +// For example, AddSendingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Sending middleware is called when a request is sent. It is useful for tasks +// such as tracing, metrics, and adding progress tokens. +func (c *Client) AddSendingMiddleware(middleware ...Middleware) { + c.mu.Lock() + defer c.mu.Unlock() + addMiddleware(&c.sendingMethodHandler_, middleware) +} + +// AddReceivingMiddleware wraps the current receiving method handler using +// the provided middleware. Middleware is applied from right to left, so that the +// first one is executed first. +// +// For example, AddReceivingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Receiving middleware is called when a request is received. It is useful for tasks +// such as authentication, request logging and metrics. +func (c *Client) AddReceivingMiddleware(middleware ...Middleware) { + c.mu.Lock() + defer c.mu.Unlock() + addMiddleware(&c.receivingMethodHandler_, middleware) +} + +// clientMethodInfos maps from the RPC method name to serverMethodInfos. +// +// The 'allowMissingParams' values are extracted from the protocol schema. +// TODO(rfindley): actually load and validate the protocol schema, rather than +// curating these method flags. 
+var clientMethodInfos = map[string]methodInfo{ + methodComplete: newClientMethodInfo(clientSessionMethod((*ClientSession).Complete), 0), + methodPing: newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK), + methodListRoots: newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK), + methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0), + methodElicit: newClientMethodInfo(clientMethod((*Client).elicit), missingParamsOK), + notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK), + notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), + notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), + notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK), + notificationResourceUpdated: newClientMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), notification|missingParamsOK), + notificationLoggingMessage: newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification), + notificationProgress: newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification), + notificationElicitationComplete: newClientMethodInfo(clientMethod((*Client).callElicitationCompleteHandler), notification|missingParamsOK), +} + +func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { + return serverMethodInfos +} + +func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { + return clientMethodInfos +} + +func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { + if req.IsCall() { + jsonrpc2.Async(ctx) + } + return handleReceive(ctx, cs, req) +} + +func (cs *ClientSession) sendingMethodHandler() MethodHandler { + 
cs.client.mu.Lock() + defer cs.client.mu.Unlock() + return cs.client.sendingMethodHandler_ +} + +func (cs *ClientSession) receivingMethodHandler() MethodHandler { + cs.client.mu.Lock() + defer cs.client.mu.Unlock() + return cs.client.receivingMethodHandler_ +} + +// getConn implements [Session.getConn]. +func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn } + +func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) { + return &emptyResult{}, nil +} + +// cancel is a placeholder: cancellation is handled by the jsonrpc2 package. +// +// It should never be invoked in practice because cancellation is preempted, +// but having its signature here facilitates the construction of methodInfo +// that can be used to validate incoming cancellation notifications. +func (*ClientSession) cancel(context.Context, *CancelledParams) (Result, error) { + return nil, nil +} + +func newClientRequest[P Params](cs *ClientSession, params P) *ClientRequest[P] { + return &ClientRequest[P]{Session: cs, Params: params} +} + +// Ping makes an MCP "ping" request to the server. +func (cs *ClientSession) Ping(ctx context.Context, params *PingParams) error { + _, err := handleSend[*emptyResult](ctx, methodPing, newClientRequest(cs, orZero[Params](params))) + return err +} + +// ListPrompts lists prompts that are currently available on the server. +func (cs *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { + return handleSend[*ListPromptsResult](ctx, methodListPrompts, newClientRequest(cs, orZero[Params](params))) +} + +// GetPrompt gets a prompt from the server. +func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { + return handleSend[*GetPromptResult](ctx, methodGetPrompt, newClientRequest(cs, orZero[Params](params))) +} + +// ListTools lists tools that are currently available on the server. 
+func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { + return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) +} + +// CallTool calls the tool with the given parameters. +// +// The params.Arguments can be any value that marshals into a JSON object. +func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) { + if params == nil { + params = new(CallToolParams) + } + if params.Arguments == nil { + // Avoid sending nil over the wire. + params.Arguments = map[string]any{} + } + return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) +} + +func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLoggingLevelParams) error { + _, err := handleSend[*emptyResult](ctx, methodSetLevel, newClientRequest(cs, orZero[Params](params))) + return err +} + +// ListResources lists the resources that are currently available on the server. +func (cs *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { + return handleSend[*ListResourcesResult](ctx, methodListResources, newClientRequest(cs, orZero[Params](params))) +} + +// ListResourceTemplates lists the resource templates that are currently available on the server. +func (cs *ClientSession) ListResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) { + return handleSend[*ListResourceTemplatesResult](ctx, methodListResourceTemplates, newClientRequest(cs, orZero[Params](params))) +} + +// ReadResource asks the server to read a resource and return its contents. 
+func (cs *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { + return handleSend[*ReadResourceResult](ctx, methodReadResource, newClientRequest(cs, orZero[Params](params))) +} + +func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) (*CompleteResult, error) { + return handleSend[*CompleteResult](ctx, methodComplete, newClientRequest(cs, orZero[Params](params))) +} + +// Subscribe sends a "resources/subscribe" request to the server, asking for +// notifications when the specified resource changes. +func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, methodSubscribe, newClientRequest(cs, orZero[Params](params))) + return err +} + +// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling +// a previous subscription. +func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, methodUnsubscribe, newClientRequest(cs, orZero[Params](params))) + return err +} + +func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) { + if h := c.opts.ToolListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (c *Client) callPromptChangedHandler(ctx context.Context, req *PromptListChangedRequest) (Result, error) { + if h := c.opts.PromptListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (c *Client) callResourceChangedHandler(ctx context.Context, req *ResourceListChangedRequest) (Result, error) { + if h := c.opts.ResourceListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ResourceUpdatedNotificationRequest) (Result, error) { + if h := c.opts.ResourceUpdatedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (c *Client) 
callLoggingHandler(ctx context.Context, req *LoggingMessageRequest) (Result, error) { + if h := c.opts.LoggingMessageHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, params *ProgressNotificationParams) (Result, error) { + if h := cs.client.opts.ProgressNotificationHandler; h != nil { + h(ctx, clientRequestFor(cs, params)) + } + return nil, nil +} + +func (c *Client) callElicitationCompleteHandler(ctx context.Context, req *ElicitationCompleteNotificationRequest) (Result, error) { + // Check if there's a pending elicitation waiting for this notification. + if cs, ok := req.GetSession().(*ClientSession); ok { + cs.pendingElicitationsMu.Lock() + if ch, exists := cs.pendingElicitations[req.Params.ElicitationID]; exists { + select { + case ch <- struct{}{}: + default: + // Channel already signaled. + } + } + cs.pendingElicitationsMu.Unlock() + } + + // Call the user's handler if provided. + if h := c.opts.ElicitationCompleteHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +// NotifyProgress sends a progress notification from the client to the server +// associated with this session. +// This can be used if the client is performing a long-running task that was +// initiated by the server. +func (cs *ClientSession) NotifyProgress(ctx context.Context, params *ProgressNotificationParams) error { + return handleNotify(ctx, notificationProgress, newClientRequest(cs, orZero[Params](params))) +} + +// Tools provides an iterator for all tools available on the server, +// automatically fetching pages and managing cursors. +// The params argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. 
+func (cs *ClientSession) Tools(ctx context.Context, params *ListToolsParams) iter.Seq2[*Tool, error] { + if params == nil { + params = &ListToolsParams{} + } + return paginate(ctx, params, cs.ListTools, func(res *ListToolsResult) []*Tool { + return res.Tools + }) +} + +// Resources provides an iterator for all resources available on the server, +// automatically fetching pages and managing cursors. +// The params argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. +func (cs *ClientSession) Resources(ctx context.Context, params *ListResourcesParams) iter.Seq2[*Resource, error] { + if params == nil { + params = &ListResourcesParams{} + } + return paginate(ctx, params, cs.ListResources, func(res *ListResourcesResult) []*Resource { + return res.Resources + }) +} + +// ResourceTemplates provides an iterator for all resource templates available on the server, +// automatically fetching pages and managing cursors. +// The params argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. +func (cs *ClientSession) ResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) iter.Seq2[*ResourceTemplate, error] { + if params == nil { + params = &ListResourceTemplatesParams{} + } + return paginate(ctx, params, cs.ListResourceTemplates, func(res *ListResourceTemplatesResult) []*ResourceTemplate { + return res.ResourceTemplates + }) +} + +// Prompts provides an iterator for all prompts available on the server, +// automatically fetching pages and managing cursors. +// The params argument can set the initial cursor. +// Iteration stops at the first encountered error, which will be yielded. 
+func (cs *ClientSession) Prompts(ctx context.Context, params *ListPromptsParams) iter.Seq2[*Prompt, error] { + if params == nil { + params = &ListPromptsParams{} + } + return paginate(ctx, params, cs.ListPrompts, func(res *ListPromptsResult) []*Prompt { + return res.Prompts + }) +} + +// paginate is a generic helper function to provide a paginated iterator. +func paginate[P listParams, R listResult[T], T any](ctx context.Context, params P, listFunc func(context.Context, P) (R, error), items func(R) []*T) iter.Seq2[*T, error] { + return func(yield func(*T, error) bool) { + for { + res, err := listFunc(ctx, params) + if err != nil { + yield(nil, err) + return + } + for _, r := range items(res) { + if !yield(r, nil) { + return + } + } + nextCursorVal := res.nextCursorPtr() + if nextCursorVal == nil || *nextCursorVal == "" { + return + } + *params.cursorPtr() = *nextCursorVal + } + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/cmd.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/cmd.go new file mode 100644 index 000000000..b531eaf13 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/cmd.go @@ -0,0 +1,108 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "io" + "os/exec" + "syscall" + "time" +) + +var defaultTerminateDuration = 5 * time.Second // mutable for testing + +// A CommandTransport is a [Transport] that runs a command and communicates +// with it over stdin/stdout, using newline-delimited JSON. +type CommandTransport struct { + Command *exec.Cmd + // TerminateDuration controls how long Close waits after closing stdin + // for the process to exit before sending SIGTERM. + // If zero or negative, the default of 5s is used. + TerminateDuration time.Duration +} + +// Connect starts the command, and connects to it over stdin/stdout. 
+func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) { + stdout, err := t.Command.StdoutPipe() + if err != nil { + return nil, err + } + stdout = io.NopCloser(stdout) // close the connection by closing stdin, not stdout + stdin, err := t.Command.StdinPipe() + if err != nil { + return nil, err + } + if err := t.Command.Start(); err != nil { + return nil, err + } + td := t.TerminateDuration + if td <= 0 { + td = defaultTerminateDuration + } + return newIOConn(&pipeRWC{t.Command, stdout, stdin, td}), nil +} + +// A pipeRWC is an io.ReadWriteCloser that communicates with a subprocess over +// stdin/stdout pipes. +type pipeRWC struct { + cmd *exec.Cmd + stdout io.ReadCloser + stdin io.WriteCloser + terminateDuration time.Duration +} + +func (s *pipeRWC) Read(p []byte) (n int, err error) { + return s.stdout.Read(p) +} + +func (s *pipeRWC) Write(p []byte) (n int, err error) { + return s.stdin.Write(p) +} + +// Close closes the input stream to the child process, and awaits normal +// termination of the command. If the command does not exit, it is signalled to +// terminate, and then eventually killed. +func (s *pipeRWC) Close() error { + // Spec: + // "For the stdio transport, the client SHOULD initiate shutdown by:... + + // "...First, closing the input stream to the child process (the server)" + if err := s.stdin.Close(); err != nil { + return fmt.Errorf("closing stdin: %v", err) + } + resChan := make(chan error, 1) + go func() { + resChan <- s.cmd.Wait() + }() + // "...Waiting for the server to exit, or sending SIGTERM if the server does not exit within a reasonable time" + wait := func() (error, bool) { + select { + case err := <-resChan: + return err, true + case <-time.After(s.terminateDuration): + } + return nil, false + } + if err, ok := wait(); ok { + return err + } + // Note the condition here: if sending SIGTERM fails, don't wait and just + // move on to SIGKILL. 
+ if err := s.cmd.Process.Signal(syscall.SIGTERM); err == nil { + if err, ok := wait(); ok { + return err + } + } + // "...Sending SIGKILL if the server does not exit within a reasonable time after SIGTERM" + if err := s.cmd.Process.Kill(); err != nil { + return err + } + if err, ok := wait(); ok { + return err + } + return fmt.Errorf("unresponsive subprocess") +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go new file mode 100644 index 000000000..fb1a0d1e5 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go @@ -0,0 +1,289 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// TODO(findleyr): update JSON marshalling of all content types to preserve required fields. +// (See [TextContent.MarshalJSON], which handles this for text content). + +package mcp + +import ( + "encoding/json" + "errors" + "fmt" +) + +// A Content is a [TextContent], [ImageContent], [AudioContent], +// [ResourceLink], or [EmbeddedResource]. +type Content interface { + MarshalJSON() ([]byte, error) + fromWire(*wireContent) +} + +// TextContent is a textual content. +type TextContent struct { + Text string + Meta Meta + Annotations *Annotations +} + +func (c *TextContent) MarshalJSON() ([]byte, error) { + // Custom wire format to ensure the required "text" field is always included, even when empty. 
+ wire := struct { + Type string `json:"type"` + Text string `json:"text"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` + }{ + Type: "text", + Text: c.Text, + Meta: c.Meta, + Annotations: c.Annotations, + } + return json.Marshal(wire) +} + +func (c *TextContent) fromWire(wire *wireContent) { + c.Text = wire.Text + c.Meta = wire.Meta + c.Annotations = wire.Annotations +} + +// ImageContent contains base64-encoded image data. +type ImageContent struct { + Meta Meta + Annotations *Annotations + Data []byte // base64-encoded + MIMEType string +} + +func (c *ImageContent) MarshalJSON() ([]byte, error) { + // Custom wire format to ensure required fields are always included, even when empty. + data := c.Data + if data == nil { + data = []byte{} + } + wire := imageAudioWire{ + Type: "image", + MIMEType: c.MIMEType, + Data: data, + Meta: c.Meta, + Annotations: c.Annotations, + } + return json.Marshal(wire) +} + +func (c *ImageContent) fromWire(wire *wireContent) { + c.MIMEType = wire.MIMEType + c.Data = wire.Data + c.Meta = wire.Meta + c.Annotations = wire.Annotations +} + +// AudioContent contains base64-encoded audio data. +type AudioContent struct { + Data []byte + MIMEType string + Meta Meta + Annotations *Annotations +} + +func (c AudioContent) MarshalJSON() ([]byte, error) { + // Custom wire format to ensure required fields are always included, even when empty. + data := c.Data + if data == nil { + data = []byte{} + } + wire := imageAudioWire{ + Type: "audio", + MIMEType: c.MIMEType, + Data: data, + Meta: c.Meta, + Annotations: c.Annotations, + } + return json.Marshal(wire) +} + +func (c *AudioContent) fromWire(wire *wireContent) { + c.MIMEType = wire.MIMEType + c.Data = wire.Data + c.Meta = wire.Meta + c.Annotations = wire.Annotations +} + +// Custom wire format to ensure required fields are always included, even when empty. 
+type imageAudioWire struct { + Type string `json:"type"` + MIMEType string `json:"mimeType"` + Data []byte `json:"data"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` +} + +// ResourceLink is a link to a resource +type ResourceLink struct { + URI string + Name string + Title string + Description string + MIMEType string + Size *int64 + Meta Meta + Annotations *Annotations + // Icons for the resource link, if any. + Icons []Icon `json:"icons,omitempty"` +} + +func (c *ResourceLink) MarshalJSON() ([]byte, error) { + return json.Marshal(&wireContent{ + Type: "resource_link", + URI: c.URI, + Name: c.Name, + Title: c.Title, + Description: c.Description, + MIMEType: c.MIMEType, + Size: c.Size, + Meta: c.Meta, + Annotations: c.Annotations, + Icons: c.Icons, + }) +} + +func (c *ResourceLink) fromWire(wire *wireContent) { + c.URI = wire.URI + c.Name = wire.Name + c.Title = wire.Title + c.Description = wire.Description + c.MIMEType = wire.MIMEType + c.Size = wire.Size + c.Meta = wire.Meta + c.Annotations = wire.Annotations + c.Icons = wire.Icons +} + +// EmbeddedResource contains embedded resources. +type EmbeddedResource struct { + Resource *ResourceContents + Meta Meta + Annotations *Annotations +} + +func (c *EmbeddedResource) MarshalJSON() ([]byte, error) { + return json.Marshal(&wireContent{ + Type: "resource", + Resource: c.Resource, + Meta: c.Meta, + Annotations: c.Annotations, + }) +} + +func (c *EmbeddedResource) fromWire(wire *wireContent) { + c.Resource = wire.Resource + c.Meta = wire.Meta + c.Annotations = wire.Annotations +} + +// ResourceContents contains the contents of a specific resource or +// sub-resource. 
+type ResourceContents struct { + URI string `json:"uri"` + MIMEType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` + Blob []byte `json:"blob,omitempty"` + Meta Meta `json:"_meta,omitempty"` +} + +func (r *ResourceContents) MarshalJSON() ([]byte, error) { + // If we could assume Go 1.24, we could use omitzero for Blob and avoid this method. + if r.URI == "" { + return nil, errors.New("ResourceContents missing URI") + } + if r.Blob == nil { + // Text. Marshal normally. + type wireResourceContents ResourceContents // (lacks MarshalJSON method) + return json.Marshal((wireResourceContents)(*r)) + } + // Blob. + if r.Text != "" { + return nil, errors.New("ResourceContents has non-zero Text and Blob fields") + } + // r.Blob may be the empty slice, so marshal with an alternative definition. + br := struct { + URI string `json:"uri,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Blob []byte `json:"blob"` + Meta Meta `json:"_meta,omitempty"` + }{ + URI: r.URI, + MIMEType: r.MIMEType, + Blob: r.Blob, + Meta: r.Meta, + } + return json.Marshal(br) +} + +// wireContent is the wire format for content. +// It represents the protocol types TextContent, ImageContent, AudioContent, +// ResourceLink, and EmbeddedResource. +// The Type field distinguishes them. In the protocol, each type has a constant +// value for the field. +// At most one of Text, Data, Resource, and URI is non-zero. 
+type wireContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + Data []byte `json:"data,omitempty"` + Resource *ResourceContents `json:"resource,omitempty"` + URI string `json:"uri,omitempty"` + Name string `json:"name,omitempty"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Size *int64 `json:"size,omitempty"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` + Icons []Icon `json:"icons,omitempty"` +} + +func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, error) { + var blocks []Content + for _, wire := range wires { + block, err := contentFromWire(wire, allow) + if err != nil { + return nil, err + } + blocks = append(blocks, block) + } + return blocks, nil +} + +func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) { + if wire == nil { + return nil, fmt.Errorf("nil content") + } + if allow != nil && !allow[wire.Type] { + return nil, fmt.Errorf("invalid content type %q", wire.Type) + } + switch wire.Type { + case "text": + v := new(TextContent) + v.fromWire(wire) + return v, nil + case "image": + v := new(ImageContent) + v.fromWire(wire) + return v, nil + case "audio": + v := new(AudioContent) + v.fromWire(wire) + return v, nil + case "resource_link": + v := new(ResourceLink) + v.fromWire(wire) + return v, nil + case "resource": + v := new(EmbeddedResource) + v.fromWire(wire) + return v, nil + } + return nil, fmt.Errorf("internal error: unrecognized content type %s", wire.Type) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go new file mode 100644 index 000000000..5c322c4a3 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go @@ -0,0 +1,429 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. 
+// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file is for SSE events. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events. + +package mcp + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "iter" + "maps" + "net/http" + "slices" + "strings" + "sync" +) + +// If true, MemoryEventStore will do frequent validation to check invariants, slowing it down. +// Enable for debugging. +const validateMemoryEventStore = false + +// An Event is a server-sent event. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields. +type Event struct { + Name string // the "event" field + ID string // the "id" field + Data []byte // the "data" field + Retry string // the "retry" field +} + +// Empty reports whether the Event is empty. +func (e Event) Empty() bool { + return e.Name == "" && e.ID == "" && len(e.Data) == 0 && e.Retry == "" +} + +// writeEvent writes the event to w, and flushes. +func writeEvent(w io.Writer, evt Event) (int, error) { + var b bytes.Buffer + if evt.Name != "" { + fmt.Fprintf(&b, "event: %s\n", evt.Name) + } + if evt.ID != "" { + fmt.Fprintf(&b, "id: %s\n", evt.ID) + } + if evt.Retry != "" { + fmt.Fprintf(&b, "retry: %s\n", evt.Retry) + } + fmt.Fprintf(&b, "data: %s\n\n", string(evt.Data)) + n, err := w.Write(b.Bytes()) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return n, err +} + +// scanEvents iterates SSE events in the given scanner. The iterated error is +// terminal: if encountered, the stream is corrupt or broken and should no +// longer be used. +// +// TODO(rfindley): consider a different API here that makes failure modes more +// apparent. 
+func scanEvents(r io.Reader) iter.Seq2[Event, error] { + scanner := bufio.NewScanner(r) + const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size + scanner.Buffer(nil, maxTokenSize) + + // TODO: investigate proper behavior when events are out of order, or have + // non-standard names. + var ( + eventKey = []byte("event") + idKey = []byte("id") + dataKey = []byte("data") + retryKey = []byte("retry") + ) + + return func(yield func(Event, error) bool) { + // iterate event from the wire. + // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#examples + // + // - `key: value` line records. + // - Consecutive `data: ...` fields are joined with newlines. + // - Unrecognized fields are ignored. Since we only care about 'event', 'id', and + // 'data', these are the only three we consider. + // - Lines starting with ":" are ignored. + // - Records are terminated with two consecutive newlines. + var ( + evt Event + dataBuf *bytes.Buffer // if non-nil, preceding field was also data + ) + flushData := func() { + if dataBuf != nil { + evt.Data = dataBuf.Bytes() + dataBuf = nil + } + } + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + flushData() + // \n\n is the record delimiter + if !evt.Empty() && !yield(evt, nil) { + return + } + evt = Event{} + continue + } + before, after, found := bytes.Cut(line, []byte{':'}) + if !found { + yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line))) + return + } + if !bytes.Equal(before, dataKey) { + flushData() + } + switch { + case bytes.Equal(before, eventKey): + evt.Name = strings.TrimSpace(string(after)) + case bytes.Equal(before, idKey): + evt.ID = strings.TrimSpace(string(after)) + case bytes.Equal(before, retryKey): + evt.Retry = strings.TrimSpace(string(after)) + case bytes.Equal(before, dataKey): + data := bytes.TrimSpace(after) + if dataBuf != nil { + dataBuf.WriteByte('\n') + dataBuf.Write(data) + } else { + dataBuf = new(bytes.Buffer) + 
dataBuf.Write(data) + } + } + } + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) + } + if !yield(Event{}, err) { + return + } + } + flushData() + if !evt.Empty() { + yield(evt, nil) + } + } +} + +// An EventStore tracks data for SSE streams. +// A single EventStore suffices for all sessions, since session IDs are +// globally unique. So one EventStore can be created per process, for +// all Servers in the process. +// Such a store is able to bound resource usage for the entire process. +// +// All of an EventStore's methods must be safe for use by multiple goroutines. +type EventStore interface { + // Open is called when a new stream is created. It may be used to ensure that + // the underlying data structure for the stream is initialized, making it + // ready to store and replay event streams. + Open(_ context.Context, sessionID, streamID string) error + + // Append appends data for an outgoing event to given stream, which is part of the + // given session. + Append(_ context.Context, sessionID, streamID string, data []byte) error + + // After returns an iterator over the data for the given session and stream, beginning + // just after the given index. + // + // Once the iterator yields a non-nil error, it will stop. + // After's iterator must return an error immediately if any data after index was + // dropped; it must not return partial results. + // The stream must have been opened previously (see [EventStore.Open]). + After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] + + // SessionClosed informs the store that the given session is finished, along + // with all of its streams. + // + // A store cannot rely on this method being called for cleanup. It should institute + // additional mechanisms, such as timeouts, to reclaim storage. 
+ SessionClosed(_ context.Context, sessionID string) error + + // There is no StreamClosed method. A server doesn't know when a stream is finished, because + // the client can always send a GET with a Last-Event-ID referring to the stream. +} + +// A dataList is a list of []byte. +// The zero dataList is ready to use. +type dataList struct { + size int // total size of data bytes + first int // the stream index of the first element in data + data [][]byte +} + +func (dl *dataList) appendData(d []byte) { + // Empty data consumes memory but doesn't increment size. However, it should + // be rare. + dl.data = append(dl.data, d) + dl.size += len(d) +} + +// removeFirst removes the first data item in dl, returning the size of the item. +// It panics if dl is empty. +func (dl *dataList) removeFirst() int { + if len(dl.data) == 0 { + panic("empty dataList") + } + r := len(dl.data[0]) + dl.size -= r + dl.data[0] = nil // help GC + dl.data = dl.data[1:] + dl.first++ + return r +} + +// A MemoryEventStore is an [EventStore] backed by memory. +type MemoryEventStore struct { + mu sync.Mutex + maxBytes int // max total size of all data + nBytes int // current total size of all data + store map[string]map[string]*dataList // session ID -> stream ID -> *dataList +} + +// MemoryEventStoreOptions are options for a [MemoryEventStore]. +type MemoryEventStoreOptions struct{} + +// MaxBytes returns the maximum number of bytes that the store will retain before +// purging data. +func (s *MemoryEventStore) MaxBytes() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.maxBytes +} + +// SetMaxBytes sets the maximum number of bytes the store will retain before purging +// data. The argument must not be negative. If it is zero, a suitable default will be used. +// SetMaxBytes can be called at any time. The size of the store will be adjusted +// immediately. 
+func (s *MemoryEventStore) SetMaxBytes(n int) { + s.mu.Lock() + defer s.mu.Unlock() + switch { + case n < 0: + panic("negative argument") + case n == 0: + s.maxBytes = defaultMaxBytes + default: + s.maxBytes = n + } + s.purge() +} + +const defaultMaxBytes = 10 << 20 // 10 MiB + +// NewMemoryEventStore creates a [MemoryEventStore] with the default value +// for MaxBytes. +func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { + return &MemoryEventStore{ + maxBytes: defaultMaxBytes, + store: make(map[string]map[string]*dataList), + } +} + +// Open implements [EventStore.Open]. It ensures that the underlying data +// structures for the given session are initialized and ready for use. +func (s *MemoryEventStore) Open(_ context.Context, sessionID, streamID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.init(sessionID, streamID) + return nil +} + +// init is an internal helper function that ensures the nested map structure for a +// given sessionID and streamID exists, creating it if necessary. It returns the +// dataList associated with the specified IDs. +// Requires s.mu. +func (s *MemoryEventStore) init(sessionID, streamID string) *dataList { + streamMap, ok := s.store[sessionID] + if !ok { + streamMap = make(map[string]*dataList) + s.store[sessionID] = streamMap + } + dl, ok := streamMap[streamID] + if !ok { + dl = &dataList{} + streamMap[streamID] = dl + } + return dl +} + +// Append implements [EventStore.Append] by recording data in memory. +func (s *MemoryEventStore) Append(_ context.Context, sessionID, streamID string, data []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + dl := s.init(sessionID, streamID) + // Purge before adding, so at least the current data item will be present. + // (That could result in nBytes > maxBytes, but we'll live with that.) 
+ s.purge() + dl.appendData(data) + s.nBytes += len(data) + return nil +} + +// ErrEventsPurged is the error that [EventStore.After] should return if the event just after the +// index is no longer available. +var ErrEventsPurged = errors.New("data purged") + +// After implements [EventStore.After]. +func (s *MemoryEventStore) After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] { + // Return the data items to yield. + // We must copy, because dataList.removeFirst nils out slice elements. + copyData := func() ([][]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + streamMap, ok := s.store[sessionID] + if !ok { + return nil, fmt.Errorf("MemoryEventStore.After: unknown session ID %q", sessionID) + } + dl, ok := streamMap[streamID] + if !ok { + return nil, fmt.Errorf("MemoryEventStore.After: unknown stream ID %v in session %q", streamID, sessionID) + } + start := index + 1 + if dl.first > start { + return nil, fmt.Errorf("MemoryEventStore.After: index %d, stream ID %v, session %q: %w", + index, streamID, sessionID, ErrEventsPurged) + } + return slices.Clone(dl.data[start-dl.first:]), nil + } + + return func(yield func([]byte, error) bool) { + ds, err := copyData() + if err != nil { + yield(nil, err) + return + } + for _, d := range ds { + if !yield(d, nil) { + return + } + } + } +} + +// SessionClosed implements [EventStore.SessionClosed]. +func (s *MemoryEventStore) SessionClosed(_ context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + for _, dl := range s.store[sessionID] { + s.nBytes -= dl.size + } + delete(s.store, sessionID) + s.validate() + return nil +} + +// purge removes data until no more than s.maxBytes bytes are in use. +// It must be called with s.mu held. +func (s *MemoryEventStore) purge() { + // Remove the first element of every dataList until below the max. 
+ for s.nBytes > s.maxBytes { + changed := false + for _, sm := range s.store { + for _, dl := range sm { + if dl.size > 0 { + r := dl.removeFirst() + if r > 0 { + changed = true + s.nBytes -= r + } + } + } + } + if !changed { + panic("no progress during purge") + } + } + s.validate() +} + +// validate checks that the store's data structures are valid. +// It must be called with s.mu held. +func (s *MemoryEventStore) validate() { + if !validateMemoryEventStore { + return + } + // Check that we're accounting for the size correctly. + n := 0 + for _, sm := range s.store { + for _, dl := range sm { + for _, d := range dl.data { + n += len(d) + } + } + } + if n != s.nBytes { + panic("sizes don't add up") + } +} + +// debugString returns a string containing the state of s. +// Used in tests. +func (s *MemoryEventStore) debugString() string { + s.mu.Lock() + defer s.mu.Unlock() + var b strings.Builder + for i, sess := range slices.Sorted(maps.Keys(s.store)) { + if i > 0 { + fmt.Fprintf(&b, "; ") + } + sm := s.store[sess] + for i, sid := range slices.Sorted(maps.Keys(sm)) { + if i > 0 { + fmt.Fprintf(&b, "; ") + } + dl := sm[sid] + fmt.Fprintf(&b, "%s %s first=%d", sess, sid, dl.first) + for _, d := range dl.data { + fmt.Fprintf(&b, " %s", d) + } + } + } + return b.String() +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/features.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/features.go new file mode 100644 index 000000000..438370fe5 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/features.go @@ -0,0 +1,114 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "iter" + "maps" + "slices" +) + +// This file contains implementations that are common to all features. +// A feature is an item provided to a peer. 
In the 2025-03-26 spec, +// the features are prompt, tool, resource and root. + +// A featureSet is a collection of features of type T. +// Every feature has a unique ID, and the spec never mentions +// an ordering for the List calls, so what it calls a "list" is actually a set. +// +// An alternative implementation would use an ordered map, but that's probably +// not necessary as adds and removes are rare, and usually batched. +type featureSet[T any] struct { + uniqueID func(T) string + features map[string]T + sortedKeys []string // lazily computed; nil after add or remove +} + +// newFeatureSet creates a new featureSet for features of type T. +// The argument function should return the unique ID for a single feature. +func newFeatureSet[T any](uniqueIDFunc func(T) string) *featureSet[T] { + return &featureSet[T]{ + uniqueID: uniqueIDFunc, + features: make(map[string]T), + } +} + +// add adds each feature to the set if it is not present, +// or replaces an existing feature. +func (s *featureSet[T]) add(fs ...T) { + for _, f := range fs { + s.features[s.uniqueID(f)] = f + } + s.sortedKeys = nil +} + +// remove removes all features with the given uids from the set if present, +// and returns whether any were removed. +// It is not an error to remove a nonexistent feature. +func (s *featureSet[T]) remove(uids ...string) bool { + changed := false + for _, uid := range uids { + if _, ok := s.features[uid]; ok { + changed = true + delete(s.features, uid) + } + } + if changed { + s.sortedKeys = nil + } + return changed +} + +// get returns the feature with the given uid. +// If there is none, it returns zero, false. +func (s *featureSet[T]) get(uid string) (T, bool) { + t, ok := s.features[uid] + return t, ok +} + +// len returns the number of features in the set. +func (s *featureSet[T]) len() int { return len(s.features) } + +// all returns an iterator over of all the features in the set +// sorted by unique ID. 
+func (s *featureSet[T]) all() iter.Seq[T] { + s.sortKeys() + return func(yield func(T) bool) { + s.yieldFrom(0, yield) + } +} + +// above returns an iterator over features in the set whose unique IDs are +// greater than `uid`, in ascending ID order. +func (s *featureSet[T]) above(uid string) iter.Seq[T] { + s.sortKeys() + index, found := slices.BinarySearch(s.sortedKeys, uid) + if found { + index++ + } + return func(yield func(T) bool) { + s.yieldFrom(index, yield) + } +} + +// sortKeys is a helper that maintains a sorted list of feature IDs. It +// computes this list lazily upon its first call after a modification, or +// if it's nil. +func (s *featureSet[T]) sortKeys() { + if s.sortedKeys != nil { + return + } + s.sortedKeys = slices.Sorted(maps.Keys(s.features)) +} + +// yieldFrom is a helper that iterates over the features in the set, +// starting at the given index, and calls the yield function for each one. +func (s *featureSet[T]) yieldFrom(index int, yield func(T) bool) { + for i := index; i < len(s.sortedKeys); i++ { + if !yield(s.features[s.sortedKeys[i]]) { + return + } + } +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go new file mode 100644 index 000000000..208427e22 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/logging.go @@ -0,0 +1,207 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "log/slog" + "sync" + "time" +) + +// Logging levels. 
+const ( + LevelDebug = slog.LevelDebug + LevelInfo = slog.LevelInfo + LevelNotice = (slog.LevelInfo + slog.LevelWarn) / 2 + LevelWarning = slog.LevelWarn + LevelError = slog.LevelError + LevelCritical = slog.LevelError + 4 + LevelAlert = slog.LevelError + 8 + LevelEmergency = slog.LevelError + 12 +) + +var slogToMCP = map[slog.Level]LoggingLevel{ + LevelDebug: "debug", + LevelInfo: "info", + LevelNotice: "notice", + LevelWarning: "warning", + LevelError: "error", + LevelCritical: "critical", + LevelAlert: "alert", + LevelEmergency: "emergency", +} + +var mcpToSlog = make(map[LoggingLevel]slog.Level) + +func init() { + for sl, ml := range slogToMCP { + mcpToSlog[ml] = sl + } +} + +func slogLevelToMCP(sl slog.Level) LoggingLevel { + if ml, ok := slogToMCP[sl]; ok { + return ml + } + return "debug" // for lack of a better idea +} + +func mcpLevelToSlog(ll LoggingLevel) slog.Level { + if sl, ok := mcpToSlog[ll]; ok { + return sl + } + // TODO: is there a better default? + return LevelDebug +} + +// compareLevels behaves like [cmp.Compare] for [LoggingLevel]s. +func compareLevels(l1, l2 LoggingLevel) int { + return cmp.Compare(mcpLevelToSlog(l1), mcpLevelToSlog(l2)) +} + +// LoggingHandlerOptions are options for a LoggingHandler. +type LoggingHandlerOptions struct { + // The value for the "logger" field of logging notifications. + LoggerName string + // Limits the rate at which log messages are sent. + // Excess messages are dropped. + // If zero, there is no rate limiting. + MinInterval time.Duration +} + +// A LoggingHandler is a [slog.Handler] for MCP. +type LoggingHandler struct { + opts LoggingHandlerOptions + ss *ServerSession + // Ensures that the buffer reset is atomic with the write (see Handle). + // A pointer so that clones share the mutex. See + // https://github.com/golang/example/blob/master/slog-handler-guide/README.md#getting-the-mutex-right. 
+ mu *sync.Mutex + lastMessageSent time.Time // for rate-limiting + buf *bytes.Buffer + handler slog.Handler +} + +// discardHandler is a slog.Handler that drops all logs. +// TODO: use slog.DiscardHandler when we require Go 1.24+. +type discardHandler struct{} + +func (discardHandler) Enabled(context.Context, slog.Level) bool { return false } +func (discardHandler) Handle(context.Context, slog.Record) error { return nil } +func (discardHandler) WithAttrs([]slog.Attr) slog.Handler { return discardHandler{} } +func (discardHandler) WithGroup(string) slog.Handler { return discardHandler{} } + +// ensureLogger returns l if non-nil, otherwise a discard logger. +func ensureLogger(l *slog.Logger) *slog.Logger { + if l != nil { + return l + } + return slog.New(discardHandler{}) +} + +// NewLoggingHandler creates a [LoggingHandler] that logs to the given [ServerSession] using a +// [slog.JSONHandler]. +func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingHandler { + var buf bytes.Buffer + jsonHandler := slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { + // Remove level: it appears in LoggingMessageParams. + if a.Key == slog.LevelKey { + return slog.Attr{} + } + return a + }, + }) + lh := &LoggingHandler{ + ss: ss, + mu: new(sync.Mutex), + buf: &buf, + handler: jsonHandler, + } + if opts != nil { + lh.opts = *opts + } + return lh +} + +// Enabled implements [slog.Handler.Enabled] by comparing level to the [ServerSession]'s level. +func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { + // This is also checked in ServerSession.LoggingMessage, so checking it here + // is just an optimization that skips building the JSON. + h.ss.mu.Lock() + mcpLevel := h.ss.state.LogLevel + h.ss.mu.Unlock() + return level >= mcpLevelToSlog(mcpLevel) +} + +// WithAttrs implements [slog.Handler.WithAttrs]. 
+func (h *LoggingHandler) WithAttrs(as []slog.Attr) slog.Handler { + h2 := *h + h2.handler = h.handler.WithAttrs(as) + return &h2 +} + +// WithGroup implements [slog.Handler.WithGroup]. +func (h *LoggingHandler) WithGroup(name string) slog.Handler { + h2 := *h + h2.handler = h.handler.WithGroup(name) + return &h2 +} + +// Handle implements [slog.Handler.Handle] by writing the Record to a JSONHandler, +// then calling [ServerSession.LoggingMessage] with the result. +func (h *LoggingHandler) Handle(ctx context.Context, r slog.Record) error { + err := h.handle(ctx, r) + // TODO(jba): find a way to surface the error. + // The return value will probably be ignored. + return err +} + +func (h *LoggingHandler) handle(ctx context.Context, r slog.Record) error { + // Observe the rate limit. + // TODO(jba): use golang.org/x/time/rate. (We can't here because it would require adding + // golang.org/x/time to the go.mod file.) + h.mu.Lock() + skip := time.Since(h.lastMessageSent) < h.opts.MinInterval + h.mu.Unlock() + if skip { + return nil + } + + var err error + // Make the buffer reset atomic with the record write. + // We are careful here in the unlikely event that the handler panics. + // We don't want to hold the lock for the entire function, because Notify is + // an I/O operation. + // This can result in out-of-order delivery. + func() { + h.mu.Lock() + defer h.mu.Unlock() + h.buf.Reset() + err = h.handler.Handle(ctx, r) + }() + if err != nil { + return err + } + + h.mu.Lock() + h.lastMessageSent = time.Now() + h.mu.Unlock() + + params := &LoggingMessageParams{ + Logger: h.opts.LoggerName, + Level: slogLevelToMCP(r.Level), + Data: json.RawMessage(h.buf.Bytes()), + } + // We pass the argument context to Notify, even though slog.Handler.Handle's + // documentation says not to. + // In this case logging is a service to clients, not a means for debugging the + // server, so we want to cancel the log message. 
+	return h.ss.Log(ctx, params)
+}
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/mcp.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/mcp.go
new file mode 100644
index 000000000..56e950b86
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/mcp.go
@@ -0,0 +1,88 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+// The mcp package provides an SDK for writing model context protocol clients
+// and servers.
+//
+// To get started, create either a [Client] or [Server], add features to it
+// using `AddXXX` functions, and connect it to a peer using a [Transport].
+//
+// For example, to run a simple server on the [StdioTransport]:
+//
+//	server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil)
+//
+//	// Using the generic AddTool automatically populates the input and output
+//	// schema of the tool.
+//	type args struct {
+//		Name string `json:"name" jsonschema:"the person to greet"`
+//	}
+//	mcp.AddTool(server, &mcp.Tool{
+//		Name:        "greet",
+//		Description: "say hi",
+//	}, func(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) {
+//		return &mcp.CallToolResult{
+//			Content: []mcp.Content{
+//				&mcp.TextContent{Text: "Hi " + args.Name},
+//			},
+//		}, nil, nil
+//	})
+//
+//	// Run the server on the stdio transport.
+// if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { +// log.Printf("Server failed: %v", err) +// } +// +// To connect to this server, use the [CommandTransport]: +// +// client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) +// transport := &mcp.CommandTransport{Command: exec.Command("myserver")} +// session, err := client.Connect(ctx, transport, nil) +// if err != nil { +// log.Fatal(err) +// } +// defer session.Close() +// +// params := &mcp.CallToolParams{ +// Name: "greet", +// Arguments: map[string]any{"name": "you"}, +// } +// res, err := session.CallTool(ctx, params) +// if err != nil { +// log.Fatalf("CallTool failed: %v", err) +// } +// +// # Clients, servers, and sessions +// +// In this SDK, both a [Client] and [Server] may handle many concurrent +// connections. Each time a client or server is connected to a peer using a +// [Transport], it creates a new session (either a [ClientSession] or +// [ServerSession]): +// +// Client Server +// ⇅ (jsonrpc2) ⇅ +// ClientSession ⇄ Client Transport ⇄ Server Transport ⇄ ServerSession +// +// The session types expose an API to interact with its peer. For example, +// [ClientSession.CallTool] or [ServerSession.ListRoots]. +// +// # Adding features +// +// Add MCP servers to your Client or Server using AddXXX methods (for example +// [Client.AddRoot] or [Server.AddPrompt]). If any peers are connected when +// AddXXX is called, they will receive a corresponding change notification +// (for example notifications/roots/list_changed). +// +// Adding tools is special: tools may be bound to ordinary Go functions by +// using the top-level generic [AddTool] function, which allows specifying an +// input and output type. When AddTool is used, the tool's input schema and +// output schema are automatically populated, and inputs are automatically +// validated. As a special case, if the output type is 'any', no output schema +// is generated. 
+//
+//	func double(_ context.Context, _ *mcp.CallToolRequest, in In) (*mcp.CallToolResult, Out, error) {
+//		return nil, Out{Answer: 2*in.Number}, nil
+//	}
+//	...
+//	mcp.AddTool(server, &mcp.Tool{Name: "double"}, double)
+package mcp
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/prompt.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/prompt.go
new file mode 100644
index 000000000..62f38a36a
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/prompt.go
@@ -0,0 +1,17 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package mcp
+
+import (
+	"context"
+)
+
+// A PromptHandler handles a call to prompts/get.
+type PromptHandler func(context.Context, *GetPromptRequest) (*GetPromptResult, error)
+
+type serverPrompt struct {
+	prompt  *Prompt
+	handler PromptHandler
+}
diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go
new file mode 100644
index 000000000..26c8982f8
--- /dev/null
+++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go
@@ -0,0 +1,1357 @@
+// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package mcp
+
+// Protocol types for version 2025-06-18.
+// To see the schema changes from the previous version, run:
+//
+//	prefix=https://raw.githubusercontent.com/modelcontextprotocol/modelcontextprotocol/refs/heads/main/schema
+//	sdiff -l <(curl $prefix/2025-03-26/schema.ts) <(curl $prefix/2025-06-18/schema.ts)
+
+import (
+	"encoding/json"
+	"fmt"
+)
+
+// Optional annotations for the client. The client can use annotations to inform
+// how objects are used or displayed.
+type Annotations struct {
+	// Describes who the intended customer of this object or data is.
+ // + // It can include multiple entries to indicate content useful for multiple + // audiences (e.g., []Role{"user", "assistant"}). + Audience []Role `json:"audience,omitempty"` + // The moment the resource was last modified, as an ISO 8601 formatted string. + // + // Should be an ISO 8601 formatted string (e.g., "2025-01-12T15:00:58Z"). + // + // Examples: last activity timestamp in an open file, timestamp when the + // resource was attached, etc. + LastModified string `json:"lastModified,omitempty"` + // Describes how important this data is for operating the server. + // + // A value of 1 means "most important," and indicates that the data is + // effectively required, while 0 means "least important," and indicates that the + // data is entirely optional. + Priority float64 `json:"priority,omitempty"` +} + +// CallToolParams is used by clients to call a tool. +type CallToolParams struct { + // Meta is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Name is the name of the tool to call. + Name string `json:"name"` + // Arguments holds the tool arguments. It can hold any value that can be + // marshaled to JSON. + Arguments any `json:"arguments,omitempty"` +} + +// CallToolParamsRaw is passed to tool handlers on the server. Its arguments +// are not yet unmarshaled (hence "raw"), so that the handlers can perform +// unmarshaling themselves. +type CallToolParamsRaw struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Name is the name of the tool being called. + Name string `json:"name"` + // Arguments is the raw arguments received over the wire from the client. It + // is the responsibility of the tool handler to unmarshal and validate the + // Arguments (see [AddTool]). 
+ Arguments json.RawMessage `json:"arguments,omitempty"` +} + +// A CallToolResult is the server's response to a tool call. +// +// The [ToolHandler] and [ToolHandlerFor] handler functions return this result, +// though [ToolHandlerFor] populates much of it automatically as documented at +// each field. +type CallToolResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + + // A list of content objects that represent the unstructured result of the tool + // call. + // + // When using a [ToolHandlerFor] with structured output, if Content is unset + // it will be populated with JSON text content corresponding to the + // structured output value. + Content []Content `json:"content"` + + // StructuredContent is an optional value that represents the structured + // result of the tool call. It must marshal to a JSON object. + // + // When using a [ToolHandlerFor] with structured output, you should not + // populate this field. It will be automatically populated with the typed Out + // value. + StructuredContent any `json:"structuredContent,omitempty"` + + // IsError reports whether the tool call ended in an error. + // + // If not set, this is assumed to be false (the call was successful). + // + // Any errors that originate from the tool should be reported inside the + // Content field, with IsError set to true, not as an MCP protocol-level + // error response. Otherwise, the LLM would not be able to see that an error + // occurred and self-correct. + // + // However, any errors in finding the tool, an error indicating that the + // server does not support tool calls, or any other exceptional conditions, + // should be reported as an MCP error response. + // + // When using a [ToolHandlerFor], this field is automatically set when the + // tool handler returns an error, and the error string is included as text in + // the Content field. 
+ IsError bool `json:"isError,omitempty"` + + // The error passed to setError, if any. + // It is not marshaled, and therefore it is only visible on the server. + // Its only use is in server sending middleware, where it can be accessed + // with getError. + err error +} + +// TODO(#64): consider exposing setError (and getError), by adding an error +// field on CallToolResult. +func (r *CallToolResult) setError(err error) { + r.Content = []Content{&TextContent{Text: err.Error()}} + r.IsError = true + r.err = err +} + +// getError returns the error set with setError, or nil if none. +// This function always returns nil on clients. +func (r *CallToolResult) getError() error { + return r.err +} + +func (*CallToolResult) isResult() {} + +// UnmarshalJSON handles the unmarshalling of content into the Content +// interface. +func (x *CallToolResult) UnmarshalJSON(data []byte) error { + type res CallToolResult // avoid recursion + var wire struct { + res + Content []*wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { + return err + } + *x = CallToolResult(wire.res) + return nil +} + +func (x *CallToolParams) isParams() {} +func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } + +func (x *CallToolParamsRaw) isParams() {} +func (x *CallToolParamsRaw) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParamsRaw) SetProgressToken(t any) { setProgressToken(x, t) } + +type CancelledParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An optional string describing the reason for the cancellation. This may be + // logged or presented to the user. 
+	Reason string `json:"reason,omitempty"`
+	// The ID of the request to cancel.
+	//
+	// This must correspond to the ID of a request previously issued in the same
+	// direction.
+	RequestID any `json:"requestId"`
+}
+
+func (x *CancelledParams) isParams() {}
+func (x *CancelledParams) GetProgressToken() any  { return getProgressToken(x) }
+func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) }
+
+// RootCapabilities describes a client's support for roots.
+type RootCapabilities struct {
+	// ListChanged reports whether the client supports notifications for
+	// changes to the roots list.
+	ListChanged bool `json:"listChanged,omitempty"`
+}
+
+// Capabilities a client may support. Known capabilities are defined here, in
+// this schema, but this is not a closed set: any client can define its own,
+// additional capabilities.
+type ClientCapabilities struct {
+
+	// NOTE: any addition to ClientCapabilities must also be reflected in
+	// [ClientCapabilities.clone].
+
+	// Experimental reports non-standard capabilities that the client supports.
+	Experimental map[string]any `json:"experimental,omitempty"`
+	// Roots describes the client's support for roots.
+	//
+	// Deprecated: use RootsV2. As described in #607, Roots should have been a
+	// pointer to a RootCapabilities value. Roots will continue to be
+	// populated, but any new fields will only be added in the RootsV2 field.
+	Roots struct {
+		// ListChanged reports whether the client supports notifications for
+		// changes to the roots list.
+		ListChanged bool `json:"listChanged,omitempty"`
+	} `json:"roots,omitempty"`
+	// RootsV2 is present if the client supports roots. When capabilities are explicitly configured via [ClientOptions.Capabilities]
+	RootsV2 *RootCapabilities `json:"-"`
+	// Sampling is present if the client supports sampling from an LLM.
+ Sampling *SamplingCapabilities `json:"sampling,omitempty"` + // Elicitation is present if the client supports elicitation from the server. + Elicitation *ElicitationCapabilities `json:"elicitation,omitempty"` +} + +// clone returns a deep copy of the ClientCapabilities. +func (c *ClientCapabilities) clone() *ClientCapabilities { + cp := *c + cp.RootsV2 = shallowClone(c.RootsV2) + cp.Sampling = shallowClone(c.Sampling) + if c.Elicitation != nil { + x := *c.Elicitation + x.Form = shallowClone(c.Elicitation.Form) + x.URL = shallowClone(c.Elicitation.URL) + cp.Elicitation = &x + } + return &cp +} + +// shallowClone returns a shallow clone of *p, or nil if p is nil. +func shallowClone[T any](p *T) *T { + if p == nil { + return nil + } + x := *p + return &x +} + +func (c *ClientCapabilities) toV2() *clientCapabilitiesV2 { + return &clientCapabilitiesV2{ + ClientCapabilities: *c, + Roots: c.RootsV2, + } +} + +// clientCapabilitiesV2 is a version of ClientCapabilities that fixes the bug +// described in #607: Roots should have been a pointer to value type +// RootCapabilities. +type clientCapabilitiesV2 struct { + ClientCapabilities + Roots *RootCapabilities `json:"roots,omitempty"` +} + +func (c *clientCapabilitiesV2) toV1() *ClientCapabilities { + caps := c.ClientCapabilities + caps.RootsV2 = c.Roots + // Sync Roots from RootsV2 for backward compatibility (#607). + if caps.RootsV2 != nil { + caps.Roots = *caps.RootsV2 + } + return &caps +} + +type CompleteParamsArgument struct { + // The name of the argument + Name string `json:"name"` + // The value of the argument to use for completion matching. + Value string `json:"value"` +} + +// CompleteContext represents additional, optional context for completions. +type CompleteContext struct { + // Previously-resolved variables in a URI template or prompt. + Arguments map[string]string `json:"arguments,omitempty"` +} + +// CompleteReference represents a completion reference type (ref/prompt ref/resource). 
+// The Type field determines which other fields are relevant. +type CompleteReference struct { + Type string `json:"type"` + // Name is relevant when Type is "ref/prompt". + Name string `json:"name,omitempty"` + // URI is relevant when Type is "ref/resource". + URI string `json:"uri,omitempty"` +} + +func (r *CompleteReference) UnmarshalJSON(data []byte) error { + type wireCompleteReference CompleteReference // for naive unmarshaling + var r2 wireCompleteReference + if err := json.Unmarshal(data, &r2); err != nil { + return err + } + switch r2.Type { + case "ref/prompt", "ref/resource": + if r2.Type == "ref/prompt" && r2.URI != "" { + return fmt.Errorf("reference of type %q must not have a URI set", r2.Type) + } + if r2.Type == "ref/resource" && r2.Name != "" { + return fmt.Errorf("reference of type %q must not have a Name set", r2.Type) + } + default: + return fmt.Errorf("unrecognized content type %q", r2.Type) + } + *r = CompleteReference(r2) + return nil +} + +func (r *CompleteReference) MarshalJSON() ([]byte, error) { + // Validation for marshalling: ensure consistency before converting to JSON. + switch r.Type { + case "ref/prompt": + if r.URI != "" { + return nil, fmt.Errorf("reference of type %q must not have a URI set for marshalling", r.Type) + } + case "ref/resource": + if r.Name != "" { + return nil, fmt.Errorf("reference of type %q must not have a Name set for marshalling", r.Type) + } + default: + return nil, fmt.Errorf("unrecognized reference type %q for marshalling", r.Type) + } + + type wireReference CompleteReference + return json.Marshal(wireReference(*r)) +} + +type CompleteParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. 
+ Meta `json:"_meta,omitempty"` + // The argument's information + Argument CompleteParamsArgument `json:"argument"` + Context *CompleteContext `json:"context,omitempty"` + Ref *CompleteReference `json:"ref"` +} + +func (*CompleteParams) isParams() {} + +type CompletionResultDetails struct { + HasMore bool `json:"hasMore,omitempty"` + Total int `json:"total,omitempty"` + Values []string `json:"values"` +} + +// The server's response to a completion/complete request +type CompleteResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Completion CompletionResultDetails `json:"completion"` +} + +func (*CompleteResult) isResult() {} + +type CreateMessageParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // A request to include context from one or more MCP servers (including the + // caller), to be attached to the prompt. The client may ignore this request. + IncludeContext string `json:"includeContext,omitempty"` + // The maximum number of tokens to sample, as requested by the server. The + // client may choose to sample fewer tokens than requested. + MaxTokens int64 `json:"maxTokens"` + Messages []*SamplingMessage `json:"messages"` + // Optional metadata to pass through to the LLM provider. The format of this + // metadata is provider-specific. + Metadata any `json:"metadata,omitempty"` + // The server's preferences for which model to select. The client may ignore + // these preferences. + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + // An optional system prompt the server wants to use for sampling. The client + // may modify or omit this prompt. 
+ SystemPrompt string `json:"systemPrompt,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} + +func (x *CreateMessageParams) isParams() {} +func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to a sampling/create_message request from the server. +// The client should inform the user before returning the sampled message, to +// allow them to inspect the response (human in the loop) and decide whether to +// allow the server to see it. +type CreateMessageResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Content Content `json:"content"` + // The name of the model that generated the message. + Model string `json:"model"` + Role Role `json:"role"` + // The reason why sampling stopped, if known. + StopReason string `json:"stopReason,omitempty"` +} + +func (*CreateMessageResult) isResult() {} +func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { + type result CreateMessageResult // avoid recursion + var wire struct { + result + Content *wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + return err + } + *r = CreateMessageResult(wire.result) + return nil +} + +type GetPromptParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Arguments to use for templating the prompt. + Arguments map[string]string `json:"arguments,omitempty"` + // The name of the prompt or prompt template. 
+ Name string `json:"name"` +} + +func (x *GetPromptParams) isParams() {} +func (x *GetPromptParams) GetProgressToken() any { return getProgressToken(x) } +func (x *GetPromptParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The server's response to a prompts/get request from the client. +type GetPromptResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An optional description for the prompt. + Description string `json:"description,omitempty"` + Messages []*PromptMessage `json:"messages"` +} + +func (*GetPromptResult) isResult() {} + +// InitializeParams is sent by the client to initialize the session. +type InitializeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Capabilities describes the client's capabilities. + Capabilities *ClientCapabilities `json:"capabilities"` + // ClientInfo provides information about the client. + ClientInfo *Implementation `json:"clientInfo"` + // ProtocolVersion is the latest version of the Model Context Protocol that + // the client supports. + ProtocolVersion string `json:"protocolVersion"` +} + +func (p *InitializeParams) toV2() *initializeParamsV2 { + return &initializeParamsV2{ + InitializeParams: *p, + Capabilities: p.Capabilities.toV2(), + } +} + +// initializeParamsV2 works around the mistake in #607: Capabilities.Roots +// should have been a pointer. 
+type initializeParamsV2 struct { + InitializeParams + Capabilities *clientCapabilitiesV2 `json:"capabilities"` +} + +func (p *initializeParamsV2) toV1() *InitializeParams { + p1 := p.InitializeParams + if p.Capabilities != nil { + p1.Capabilities = p.Capabilities.toV1() + } + return &p1 +} + +func (x *InitializeParams) isParams() {} +func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } +func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// InitializeResult is sent by the server in response to an initialize request +// from the client. +type InitializeResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Capabilities describes the server's capabilities. + Capabilities *ServerCapabilities `json:"capabilities"` + // Instructions describing how to use the server and its features. + // + // This can be used by clients to improve the LLM's understanding of available + // tools, resources, etc. It can be thought of like a "hint" to the model. For + // example, this information may be added to the system prompt. + Instructions string `json:"instructions,omitempty"` + // The version of the Model Context Protocol that the server wants to use. This + // may not match the version that the client requested. If the client cannot + // support this version, it must disconnect. + ProtocolVersion string `json:"protocolVersion"` + ServerInfo *Implementation `json:"serverInfo"` +} + +func (*InitializeResult) isResult() {} + +type InitializedParams struct { + // Meta is reserved by the protocol to allow clients and servers to attach + // additional metadata to their responses. 
+ Meta `json:"_meta,omitempty"` +} + +func (x *InitializedParams) isParams() {} +func (x *InitializedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *InitializedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +type ListPromptsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. + Cursor string `json:"cursor,omitempty"` +} + +func (x *ListPromptsParams) isParams() {} +func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListPromptsParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a prompts/list request from the client. +type ListPromptsResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + Prompts []*Prompt `json:"prompts"` +} + +func (x *ListPromptsResult) isResult() {} +func (x *ListPromptsResult) nextCursorPtr() *string { return &x.NextCursor } + +type ListResourceTemplatesParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. 
+ Cursor string `json:"cursor,omitempty"` +} + +func (x *ListResourceTemplatesParams) isParams() {} +func (x *ListResourceTemplatesParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListResourceTemplatesParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListResourceTemplatesParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a resources/templates/list request from the client. +type ListResourceTemplatesResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + ResourceTemplates []*ResourceTemplate `json:"resourceTemplates"` +} + +func (x *ListResourceTemplatesResult) isResult() {} +func (x *ListResourceTemplatesResult) nextCursorPtr() *string { return &x.NextCursor } + +type ListResourcesParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. + Cursor string `json:"cursor,omitempty"` +} + +func (x *ListResourcesParams) isParams() {} +func (x *ListResourcesParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListResourcesParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a resources/list request from the client. +type ListResourcesResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. 
+ Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + Resources []*Resource `json:"resources"` +} + +func (x *ListResourcesResult) isResult() {} +func (x *ListResourcesResult) nextCursorPtr() *string { return &x.NextCursor } + +type ListRootsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *ListRootsParams) isParams() {} +func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to a roots/list request from the server. This result +// contains an array of Root objects, each representing a root directory or file +// that the server can operate on. +type ListRootsResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Roots []*Root `json:"roots"` +} + +func (*ListRootsResult) isResult() {} + +type ListToolsParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the current pagination position. If provided, + // the server should return results starting after this cursor. + Cursor string `json:"cursor,omitempty"` +} + +func (x *ListToolsParams) isParams() {} +func (x *ListToolsParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ListToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } + +// The server's response to a tools/list request from the client. 
+type ListToolsResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // An opaque token representing the pagination position after the last returned + // result. If present, there may be more results available. + NextCursor string `json:"nextCursor,omitempty"` + Tools []*Tool `json:"tools"` +} + +func (x *ListToolsResult) isResult() {} +func (x *ListToolsResult) nextCursorPtr() *string { return &x.NextCursor } + +// The severity of a log message. +// +// These map to syslog message severities, as specified in RFC-5424: +// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 +type LoggingLevel string + +type LoggingMessageParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The data to be logged, such as a string message or an object. Any JSON + // serializable type is allowed here. + Data any `json:"data"` + // The severity of this log message. + Level LoggingLevel `json:"level"` + // An optional name of the logger issuing this message. + Logger string `json:"logger,omitempty"` +} + +func (x *LoggingMessageParams) isParams() {} +func (x *LoggingMessageParams) GetProgressToken() any { return getProgressToken(x) } +func (x *LoggingMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// Hints to use for model selection. +// +// Keys not declared here are currently left unspecified by the spec and are up +// to the client to interpret. +type ModelHint struct { + // A hint for a model name. + // + // The client should treat this as a substring of a model name; for example: - + // `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022` - `sonnet` + // should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc. 
- + // `claude` should match any Claude model + // + // The client may also map the string to a different provider's model name or a + // different model family, as long as it fills a similar niche; for example: - + // `gemini-1.5-flash` could match `claude-3-haiku-20240307` + Name string `json:"name,omitempty"` +} + +// The server's preferences for model selection, requested of the client during +// sampling. +// +// Because LLMs can vary along multiple dimensions, choosing the "best" model is +// rarely straightforward. Different models excel in different areas—some are +// faster but less capable, others are more capable but more expensive, and so +// on. This interface allows servers to express their priorities across multiple +// dimensions to help clients make an appropriate selection for their use case. +// +// These preferences are always advisory. The client may ignore them. It is also +// up to the client to decide how to interpret these preferences and how to +// balance them against other considerations. +type ModelPreferences struct { + // How much to prioritize cost when selecting a model. A value of 0 means cost + // is not important, while a value of 1 means cost is the most important factor. + CostPriority float64 `json:"costPriority,omitempty"` + // Optional hints to use for model selection. + // + // If multiple hints are specified, the client must evaluate them in order (such + // that the first match is taken). + // + // The client should prioritize these hints over the numeric priorities, but may + // still use the priorities to select from ambiguous matches. + Hints []*ModelHint `json:"hints,omitempty"` + // How much to prioritize intelligence and capabilities when selecting a model. + // A value of 0 means intelligence is not important, while a value of 1 means + // intelligence is the most important factor. 
+ IntelligencePriority float64 `json:"intelligencePriority,omitempty"` + // How much to prioritize sampling speed (latency) when selecting a model. A + // value of 0 means speed is not important, while a value of 1 means speed is + // the most important factor. + SpeedPriority float64 `json:"speedPriority,omitempty"` +} + +type PingParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *PingParams) isParams() {} +func (x *PingParams) GetProgressToken() any { return getProgressToken(x) } +func (x *PingParams) SetProgressToken(t any) { setProgressToken(x, t) } + +type ProgressNotificationParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The progress token which was given in the initial request, used to associate + // this notification with the request that is proceeding. + ProgressToken any `json:"progressToken"` + // An optional message describing the current progress. + Message string `json:"message,omitempty"` + // The progress thus far. This should increase every time progress is made, even + // if the total is unknown. + Progress float64 `json:"progress"` + // Total number of items to process (or total progress required), if known. + // Zero means unknown. + Total float64 `json:"total,omitempty"` +} + +func (*ProgressNotificationParams) isParams() {} + +// IconTheme specifies the theme an icon is designed for. +type IconTheme string + +const ( + // IconThemeLight indicates the icon is designed for a light background. + IconThemeLight IconTheme = "light" + // IconThemeDark indicates the icon is designed for a dark background. 
+ IconThemeDark IconTheme = "dark" +) + +// Icon provides visual identifiers for their resources, tools, prompts, and implementations +// See [/specification/draft/basic/index#icons] for notes on icons +// +// TODO(iamsurajbobade): update specification url from draft. +type Icon struct { + // Source is A URI pointing to the icon resource (required). This can be: + // - An HTTP/HTTPS URL pointing to an image file + // - A data URI with base64-encoded image data + Source string `json:"src"` + // Optional MIME type if the server's type is missing or generic + MIMEType string `json:"mimeType,omitempty"` + // Optional size specification (e.g., ["48x48"], ["any"] for scalable formats like SVG, or ["48x48", "96x96"] for multiple sizes) + Sizes []string `json:"sizes,omitempty"` + // Optional theme specifier. "light" indicates the icon is designed for a light + // background, "dark" indicates the icon is designed for a dark background. + Theme IconTheme `json:"theme,omitempty"` +} + +// A prompt or prompt template that the server offers. +type Prompt struct { + // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta + // usage. + Meta `json:"_meta,omitempty"` + // A list of arguments to use for templating the prompt. + Arguments []*PromptArgument `json:"arguments,omitempty"` + // An optional description of what this prompt provides + Description string `json:"description,omitempty"` + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. + Title string `json:"title,omitempty"` + // Icons for the prompt, if any. + Icons []Icon `json:"icons,omitempty"` +} + +// Describes an argument that a prompt can accept. 
+type PromptArgument struct { + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. + Title string `json:"title,omitempty"` + // A human-readable description of the argument. + Description string `json:"description,omitempty"` + // Whether this argument must be provided. + Required bool `json:"required,omitempty"` +} + +type PromptListChangedParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *PromptListChangedParams) isParams() {} +func (x *PromptListChangedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *PromptListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// Describes a message returned as part of a prompt. +// +// This is similar to SamplingMessage, but also supports the embedding of +// resources from the MCP server. +type PromptMessage struct { + Content Content `json:"content"` + Role Role `json:"role"` +} + +// UnmarshalJSON handles the unmarshalling of content into the Content +// interface. +func (m *PromptMessage) UnmarshalJSON(data []byte) error { + type msg PromptMessage // avoid recursion + var wire struct { + msg + Content *wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.msg.Content, err = contentFromWire(wire.Content, nil); err != nil { + return err + } + *m = PromptMessage(wire.msg) + return nil +} + +type ReadResourceParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. 
+ Meta `json:"_meta,omitempty"` + // The URI of the resource to read. The URI can use any protocol; it is up to + // the server how to interpret it. + URI string `json:"uri"` +} + +func (x *ReadResourceParams) isParams() {} +func (x *ReadResourceParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ReadResourceParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The server's response to a resources/read request from the client. +type ReadResourceResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Contents []*ResourceContents `json:"contents"` +} + +func (*ReadResourceResult) isResult() {} + +// A known resource that the server is capable of reading. +type Resource struct { + // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta + // usage. + Meta `json:"_meta,omitempty"` + // Optional annotations for the client. + Annotations *Annotations `json:"annotations,omitempty"` + // A description of what this resource represents. + // + // This can be used by clients to improve the LLM's understanding of available + // resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // The size of the raw resource content, in bytes (i.e., before base64 encoding + // or any tokenization), if known. + // + // This can be used by Hosts to display file sizes and estimate context window + // usage. + Size int64 `json:"size,omitempty"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. 
+ // + // If not provided, the name should be used for display (except for Tool, where + // Annotations.Title should be given precedence over using name, if + // present). + Title string `json:"title,omitempty"` + // The URI of this resource. + URI string `json:"uri"` + // Icons for the resource, if any. + Icons []Icon `json:"icons,omitempty"` +} + +type ResourceListChangedParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *ResourceListChangedParams) isParams() {} +func (x *ResourceListChangedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ResourceListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// A template description for resources available on the server. +type ResourceTemplate struct { + // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta + // usage. + Meta `json:"_meta,omitempty"` + // Optional annotations for the client. + Annotations *Annotations `json:"annotations,omitempty"` + // A description of what this template is for. + // + // This can be used by clients to improve the LLM's understanding of available + // resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type for all resources that match this template. This should only be + // included if all resources matching this template have the same type. + MIMEType string `json:"mimeType,omitempty"` + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. 
+ // + // If not provided, the name should be used for display (except for Tool, where + // Annotations.Title should be given precedence over using name, if + // present). + Title string `json:"title,omitempty"` + // A URI template (according to RFC 6570) that can be used to construct resource + // URIs. + URITemplate string `json:"uriTemplate"` + // Icons for the resource template, if any. + Icons []Icon `json:"icons,omitempty"` +} + +// The sender or recipient of messages and data in a conversation. +type Role string + +// Represents a root directory or file that the server can operate on. +type Root struct { + // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta + // usage. + Meta `json:"_meta,omitempty"` + // An optional name for the root. This can be used to provide a human-readable + // identifier for the root, which may be useful for display purposes or for + // referencing the root in other parts of the application. + Name string `json:"name,omitempty"` + // The URI identifying the root. This *must* start with file:// for now. This + // restriction may be relaxed in future versions of the protocol to allow other + // URI schemes. + URI string `json:"uri"` +} + +type RootsListChangedParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *RootsListChangedParams) isParams() {} +func (x *RootsListChangedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// TODO: to be consistent with ServerCapabilities, move the capability types +// below directly above ClientCapabilities. + +// SamplingCapabilities describes the client's support for sampling. +type SamplingCapabilities struct{} + +// ElicitationCapabilities describes the capabilities for elicitation. 
+// +// If neither Form nor URL is set, the 'Form' capability is assumed. +type ElicitationCapabilities struct { + Form *FormElicitationCapabilities + URL *URLElicitationCapabilities +} + +// FormElicitationCapabilities describes capabilities for form elicitation. +type FormElicitationCapabilities struct { +} + +// URLElicitationCapabilities describes capabilities for url elicitation. +type URLElicitationCapabilities struct { +} + +// Describes a message issued to or received from an LLM API. +type SamplingMessage struct { + Content Content `json:"content"` + Role Role `json:"role"` +} + +// UnmarshalJSON handles the unmarshalling of content into the Content +// interface. +func (m *SamplingMessage) UnmarshalJSON(data []byte) error { + type msg SamplingMessage // avoid recursion + var wire struct { + msg + Content *wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + return err + } + *m = SamplingMessage(wire.msg) + return nil +} + +type SetLoggingLevelParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The level of logging that the client wants to receive from the server. The + // server should send all logs at this level and higher (i.e., more severe) to + // the client as notifications/message. + Level LoggingLevel `json:"level"` +} + +func (x *SetLoggingLevelParams) isParams() {} +func (x *SetLoggingLevelParams) GetProgressToken() any { return getProgressToken(x) } +func (x *SetLoggingLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// Definition for a tool the client can call. +type Tool struct { + // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta + // usage. 
+ Meta `json:"_meta,omitempty"` + // Optional additional tool information. + // + // Display name precedence order is: title, annotations.title, then name. + Annotations *ToolAnnotations `json:"annotations,omitempty"` + // A human-readable description of the tool. + // + // This can be used by clients to improve the LLM's understanding of available + // tools. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // InputSchema holds a JSON Schema object defining the expected parameters + // for the tool. + // + // From the server, this field may be set to any value that JSON-marshals to + // valid JSON schema (including json.RawMessage). However, for tools added + // using [AddTool], which automatically validates inputs and outputs, the + // schema must be in a draft the SDK understands. Currently, the SDK uses + // github.com/google/jsonschema-go for inference and validation, which only + // supports the 2020-12 draft of JSON schema. To do your own validation, use + // [Server.AddTool]. + // + // From the client, this field will hold the default JSON marshaling of the + // server's input schema (a map[string]any). + InputSchema any `json:"inputSchema"` + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // OutputSchema holds an optional JSON Schema object defining the structure + // of the tool's output returned in the StructuredContent field of a + // CallToolResult. + // + // From the server, this field may be set to any value that JSON-marshals to + // valid JSON schema (including json.RawMessage). However, for tools added + // using [AddTool], which automatically validates inputs and outputs, the + // schema must be in a draft the SDK understands. Currently, the SDK uses + // github.com/google/jsonschema-go for inference and validation, which only + // supports the 2020-12 draft of JSON schema. 
To do your own validation, use + // [Server.AddTool]. + // + // From the client, this field will hold the default JSON marshaling of the + // server's output schema (a map[string]any). + OutputSchema any `json:"outputSchema,omitempty"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. + // If not provided, Annotations.Title should be used for display if present, + // otherwise Name. + Title string `json:"title,omitempty"` + // Icons for the tool, if any. + Icons []Icon `json:"icons,omitempty"` +} + +// Additional properties describing a Tool to clients. +// +// NOTE: all properties in ToolAnnotations are hints. They are not +// guaranteed to provide a faithful description of tool behavior (including +// descriptive properties like title). +// +// Clients should never make tool use decisions based on ToolAnnotations +// received from untrusted servers. +type ToolAnnotations struct { + // If true, the tool may perform destructive updates to its environment. If + // false, the tool performs only additive updates. + // + // (This property is meaningful only when ReadOnlyHint == false.) + // + // Default: true + DestructiveHint *bool `json:"destructiveHint,omitempty"` + // If true, calling the tool repeatedly with the same arguments will have no + // additional effect on its environment. + // + // (This property is meaningful only when ReadOnlyHint == false.) + // + // Default: false + IdempotentHint bool `json:"idempotentHint,omitempty"` + // If true, this tool may interact with an "open world" of external entities. If + // false, the tool's domain of interaction is closed. For example, the world of + // a web search tool is open, whereas that of a memory tool is not. + // + // Default: true + OpenWorldHint *bool `json:"openWorldHint,omitempty"` + // If true, the tool does not modify its environment. 
+ // + // Default: false + ReadOnlyHint bool `json:"readOnlyHint,omitempty"` + // A human-readable title for the tool. + Title string `json:"title,omitempty"` +} + +type ToolListChangedParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` +} + +func (x *ToolListChangedParams) isParams() {} +func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// Sent from the client to request resources/updated notifications from the +// server whenever a particular resource changes. +type SubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to subscribe to. + URI string `json:"uri"` +} + +func (*SubscribeParams) isParams() {} + +// Sent from the client to request cancellation of resources/updated +// notifications from the server. This should follow a previous +// resources/subscribe request. +type UnsubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` +} + +func (*UnsubscribeParams) isParams() {} + +// A notification from the server to the client, informing it that a resource +// has changed and may need to be read again. This should only be sent if the +// client previously sent a resources/subscribe request. +type ResourceUpdatedNotificationParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource that has been updated. 
This might be a sub-resource of the one that the client actually subscribed to. + URI string `json:"uri"` +} + +func (*ResourceUpdatedNotificationParams) isParams() {} + +// TODO(jba): add CompleteRequest and related types. + +// A request from the server to elicit additional information from the user via the client. +type ElicitParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The mode of elicitation to use. + // + // If unset, will be inferred from the other fields. + Mode string `json:"mode"` + // The message to present to the user. + Message string `json:"message"` + // A JSON schema object defining the requested elicitation schema. + // + // From the server, this field may be set to any value that can JSON-marshal + // to valid JSON schema (including json.RawMessage for raw schema values). + // Internally, the SDK uses github.com/google/jsonschema-go for validation, + // which only supports the 2020-12 draft of the JSON schema spec. + // + // From the client, this field will use the default JSON marshaling (a + // map[string]any). + // + // Only top-level properties are allowed, without nesting. + // + // This is only used for "form" elicitation. + RequestedSchema any `json:"requestedSchema,omitempty"` + // The URL to present to the user. + // + // This is only used for "url" elicitation. + URL string `json:"url,omitempty"` + // The ID of the elicitation. + // + // This is only used for "url" elicitation. + ElicitationID string `json:"elicitationId,omitempty"` +} + +func (x *ElicitParams) isParams() {} + +func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to an elicitation/create request from the server. 
+type ElicitResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The user action in response to the elicitation. + // - "accept": User submitted the form/confirmed the action + // - "decline": User explicitly declined the action + // - "cancel": User dismissed without making an explicit choice + Action string `json:"action"` + // The submitted form data, only present when action is "accept". + // Contains values matching the requested schema. + Content map[string]any `json:"content,omitempty"` +} + +func (*ElicitResult) isResult() {} + +// ElicitationCompleteParams is sent from the server to the client, informing it that an out-of-band elicitation interaction has completed. +type ElicitationCompleteParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The ID of the elicitation that has completed. This must correspond to the + // elicitationId from the original elicitation/create request. + ElicitationID string `json:"elicitationId"` +} + +func (*ElicitationCompleteParams) isParams() {} + +// An Implementation describes the name and version of an MCP implementation, with an optional +// title for UI representation. +type Implementation struct { + // Intended for programmatic or logical use, but used as a display name in past + // specs or fallback (if title isn't present). + Name string `json:"name"` + // Intended for UI and end-user contexts — optimized to be human-readable and + // easily understood, even by those unfamiliar with domain-specific terminology. + Title string `json:"title,omitempty"` + Version string `json:"version"` + // WebsiteURL for the server, if any. + WebsiteURL string `json:"websiteUrl,omitempty"` + // Icons for the Server, if any. 
+ Icons []Icon `json:"icons,omitempty"` +} + +// CompletionCapabilities describes the server's support for argument autocompletion. +type CompletionCapabilities struct{} + +// LoggingCapabilities describes the server's support for sending log messages to the client. +type LoggingCapabilities struct{} + +// PromptCapabilities describes the server's support for prompts. +type PromptCapabilities struct { + // Whether this server supports notifications for changes to the prompt list. + ListChanged bool `json:"listChanged,omitempty"` +} + +// ResourceCapabilities describes the server's support for resources. +type ResourceCapabilities struct { + // ListChanged reports whether the client supports notifications for + // changes to the resource list. + ListChanged bool `json:"listChanged,omitempty"` + // Subscribe reports whether this server supports subscribing to resource + // updates. + Subscribe bool `json:"subscribe,omitempty"` +} + +// ToolCapabilities describes the server's support for tools. +type ToolCapabilities struct { + // ListChanged reports whether the client supports notifications for + // changes to the tool list. + ListChanged bool `json:"listChanged,omitempty"` +} + +// ServerCapabilities describes capabilities that a server supports. +type ServerCapabilities struct { + + // NOTE: any addition to ServerCapabilities must also be reflected in + // [ServerCapabilities.clone]. + + // Experimental reports non-standard capabilities that the server supports. + Experimental map[string]any `json:"experimental,omitempty"` + // Completions is present if the server supports argument autocompletion + // suggestions. + Completions *CompletionCapabilities `json:"completions,omitempty"` + // Logging is present if the server supports log messages. + Logging *LoggingCapabilities `json:"logging,omitempty"` + // Prompts is present if the server supports prompts. 
+ Prompts *PromptCapabilities `json:"prompts,omitempty"` + // Resources is present if the server supports resources. + Resources *ResourceCapabilities `json:"resources,omitempty"` + // Tools is present if the server supports tools. + Tools *ToolCapabilities `json:"tools,omitempty"` +} + +// clone returns a deep copy of the ServerCapabilities. +func (c *ServerCapabilities) clone() *ServerCapabilities { + cp := *c + cp.Completions = shallowClone(c.Completions) + cp.Logging = shallowClone(c.Logging) + cp.Prompts = shallowClone(c.Prompts) + cp.Resources = shallowClone(c.Resources) + cp.Tools = shallowClone(c.Tools) + return &cp +} + +const ( + methodCallTool = "tools/call" + notificationCancelled = "notifications/cancelled" + methodComplete = "completion/complete" + methodCreateMessage = "sampling/createMessage" + methodElicit = "elicitation/create" + notificationElicitationComplete = "notifications/elicitation/complete" + methodGetPrompt = "prompts/get" + methodInitialize = "initialize" + notificationInitialized = "notifications/initialized" + methodListPrompts = "prompts/list" + methodListResourceTemplates = "resources/templates/list" + methodListResources = "resources/list" + methodListRoots = "roots/list" + methodListTools = "tools/list" + notificationLoggingMessage = "notifications/message" + methodPing = "ping" + notificationProgress = "notifications/progress" + notificationPromptListChanged = "notifications/prompts/list_changed" + methodReadResource = "resources/read" + notificationResourceListChanged = "notifications/resources/list_changed" + notificationResourceUpdated = "notifications/resources/updated" + notificationRootsListChanged = "notifications/roots/list_changed" + methodSetLevel = "logging/setLevel" + methodSubscribe = "resources/subscribe" + notificationToolListChanged = "notifications/tools/list_changed" + methodUnsubscribe = "resources/unsubscribe" +) diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go 
b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go new file mode 100644 index 000000000..f64d6fb62 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go @@ -0,0 +1,38 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file holds the request types. + +package mcp + +type ( + CallToolRequest = ServerRequest[*CallToolParamsRaw] + CompleteRequest = ServerRequest[*CompleteParams] + GetPromptRequest = ServerRequest[*GetPromptParams] + InitializedRequest = ServerRequest[*InitializedParams] + ListPromptsRequest = ServerRequest[*ListPromptsParams] + ListResourcesRequest = ServerRequest[*ListResourcesParams] + ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] + ListToolsRequest = ServerRequest[*ListToolsParams] + ProgressNotificationServerRequest = ServerRequest[*ProgressNotificationParams] + ReadResourceRequest = ServerRequest[*ReadResourceParams] + RootsListChangedRequest = ServerRequest[*RootsListChangedParams] + SubscribeRequest = ServerRequest[*SubscribeParams] + UnsubscribeRequest = ServerRequest[*UnsubscribeParams] +) + +type ( + CreateMessageRequest = ClientRequest[*CreateMessageParams] + ElicitRequest = ClientRequest[*ElicitParams] + initializedClientRequest = ClientRequest[*InitializedParams] + InitializeRequest = ClientRequest[*InitializeParams] + ListRootsRequest = ClientRequest[*ListRootsParams] + LoggingMessageRequest = ClientRequest[*LoggingMessageParams] + ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams] + PromptListChangedRequest = ClientRequest[*PromptListChangedParams] + ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams] + ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams] + ToolListChangedRequest = ClientRequest[*ToolListChangedParams] + 
ElicitationCompleteNotificationRequest = ClientRequest[*ElicitationCompleteParams] +) diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go new file mode 100644 index 000000000..dc657f5dd --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go @@ -0,0 +1,164 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/url" + "os" + "path/filepath" + "strings" + + "github.com/modelcontextprotocol/go-sdk/internal/util" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/yosida95/uritemplate/v3" +) + +// A serverResource associates a Resource with its handler. +type serverResource struct { + resource *Resource + handler ResourceHandler +} + +// A serverResourceTemplate associates a ResourceTemplate with its handler. +type serverResourceTemplate struct { + resourceTemplate *ResourceTemplate + handler ResourceHandler +} + +// A ResourceHandler is a function that reads a resource. +// It will be called when the client calls [ClientSession.ReadResource]. +// If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. +type ResourceHandler func(context.Context, *ReadResourceRequest) (*ReadResourceResult, error) + +// ResourceNotFoundError returns an error indicating that a resource being read could +// not be found. +func ResourceNotFoundError(uri string) error { + return &jsonrpc.Error{ + Code: CodeResourceNotFound, + Message: "Resource not found", + Data: json.RawMessage(fmt.Sprintf(`{"uri":%q}`, uri)), + } +} + +// readFileResource reads from the filesystem at a URI relative to dirFilepath, respecting +// the roots. +// dirFilepath and rootFilepaths are absolute filesystem paths. 
+func readFileResource(rawURI, dirFilepath string, rootFilepaths []string) ([]byte, error) { + uriFilepath, err := computeURIFilepath(rawURI, dirFilepath, rootFilepaths) + if err != nil { + return nil, err + } + + var data []byte + err = withFile(dirFilepath, uriFilepath, func(f *os.File) error { + var err error + data, err = io.ReadAll(f) + return err + }) + if os.IsNotExist(err) { + err = ResourceNotFoundError(rawURI) + } + return data, err +} + +// computeURIFilepath returns a path relative to dirFilepath. +// The dirFilepath and rootFilepaths are absolute file paths. +func computeURIFilepath(rawURI, dirFilepath string, rootFilepaths []string) (string, error) { + // We use "file path" to mean a filesystem path. + uri, err := url.Parse(rawURI) + if err != nil { + return "", err + } + if uri.Scheme != "file" { + return "", fmt.Errorf("URI is not a file: %s", uri) + } + if uri.Path == "" { + // A more specific error than the one below, to catch the + // common mistake "file://foo". + return "", errors.New("empty path") + } + // The URI's path is interpreted relative to dirFilepath, and in the local filesystem. + // It must not try to escape its directory. + uriFilepathRel, err := filepath.Localize(strings.TrimPrefix(uri.Path, "/")) + if err != nil { + return "", fmt.Errorf("%q cannot be localized: %w", uriFilepathRel, err) + } + + // Check roots, if there are any. + if len(rootFilepaths) > 0 { + // To check against the roots, we need an absolute file path, not relative to the directory. + // uriFilepath is local, so the joined path is under dirFilepath. + uriFilepathAbs := filepath.Join(dirFilepath, uriFilepathRel) + rootOK := false + // Check that the requested file path is under some root. + // Since both paths are absolute, that's equivalent to filepath.Rel constructing + // a local path. 
+ for _, rootFilepathAbs := range rootFilepaths { + if rel, err := filepath.Rel(rootFilepathAbs, uriFilepathAbs); err == nil && filepath.IsLocal(rel) { + rootOK = true + break + } + } + if !rootOK { + return "", fmt.Errorf("URI path %q is not under any root", uriFilepathAbs) + } + } + return uriFilepathRel, nil +} + +// fileRoots transforms the Roots obtained from the client into absolute paths on +// the local filesystem. +// TODO(jba): expose this functionality to user ResourceHandlers, +// so they don't have to repeat it. +func fileRoots(rawRoots []*Root) ([]string, error) { + var fileRoots []string + for _, r := range rawRoots { + fr, err := fileRoot(r) + if err != nil { + return nil, err + } + fileRoots = append(fileRoots, fr) + } + return fileRoots, nil +} + +// fileRoot returns the absolute path for Root. +func fileRoot(root *Root) (_ string, err error) { + defer util.Wrapf(&err, "root %q", root.URI) + + // Convert to absolute file path. + rurl, err := url.Parse(root.URI) + if err != nil { + return "", err + } + if rurl.Scheme != "file" { + return "", errors.New("not a file URI") + } + if rurl.Path == "" { + // A more specific error than the one below, to catch the + // common mistake "file://foo". + return "", errors.New("empty path") + } + // We don't want Localize here: we want an absolute path, which is not local. + fileRoot := filepath.Clean(filepath.FromSlash(rurl.Path)) + if !filepath.IsAbs(fileRoot) { + return "", errors.New("not an absolute path") + } + return fileRoot, nil +} + +// Matches reports whether the receiver's uri template matches the uri. 
+func (sr *serverResourceTemplate) Matches(uri string) bool { + tmpl, err := uritemplate.New(sr.resourceTemplate.URITemplate) + if err != nil { + return false + } + return tmpl.Regexp().MatchString(uri) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go new file mode 100644 index 000000000..4a35603c6 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_go124.go @@ -0,0 +1,29 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build go1.24 + +package mcp + +import ( + "errors" + "os" +) + +// withFile calls f on the file at join(dir, rel), +// protecting against path traversal attacks. +func withFile(dir, rel string, f func(*os.File) error) (err error) { + r, err := os.OpenRoot(dir) + if err != nil { + return err + } + defer r.Close() + file, err := r.Open(rel) + if err != nil { + return err + } + // Record error, in case f writes. + defer func() { err = errors.Join(err, file.Close()) }() + return f(file) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go new file mode 100644 index 000000000..d1f72eedc --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource_pre_go124.go @@ -0,0 +1,25 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build !go1.24 + +package mcp + +import ( + "errors" + "os" + "path/filepath" +) + +// withFile calls f on the file at join(dir, rel). +// It does not protect against path traversal attacks. 
+func withFile(dir, rel string, f func(*os.File) error) (err error) { + file, err := os.Open(filepath.Join(dir, rel)) + if err != nil { + return err + } + // Record error, in case f writes. + defer func() { err = errors.Join(err, file.Close()) }() + return f(file) +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go new file mode 100644 index 000000000..1f7edf9c5 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go @@ -0,0 +1,1497 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/gob" + "encoding/json" + "errors" + "fmt" + "iter" + "log/slog" + "maps" + "net/url" + "path/filepath" + "reflect" + "slices" + "sync" + "sync/atomic" + "time" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/util" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/yosida95/uritemplate/v3" +) + +// DefaultPageSize is the default for [ServerOptions.PageSize]. +const DefaultPageSize = 1000 + +// A Server is an instance of an MCP server. +// +// Servers expose server-side MCP features, which can serve one or more MCP +// sessions by using [Server.Run]. 
+type Server struct { + // fixed at creation + impl *Implementation + opts ServerOptions + + mu sync.Mutex + prompts *featureSet[*serverPrompt] + tools *featureSet[*serverTool] + resources *featureSet[*serverResource] + resourceTemplates *featureSet[*serverResourceTemplate] + sessions []*ServerSession + sendingMethodHandler_ MethodHandler + receivingMethodHandler_ MethodHandler + resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send +} + +// ServerOptions is used to configure behavior of the server. +type ServerOptions struct { + // Optional instructions for connected clients. + Instructions string + // If non-nil, log server activity. + Logger *slog.Logger + // If non-nil, called when "notifications/initialized" is received. + InitializedHandler func(context.Context, *InitializedRequest) + // PageSize is the maximum number of items to return in a single page for + // list methods (e.g. ListTools). + // + // If zero, defaults to [DefaultPageSize]. + PageSize int + // If non-nil, called when "notifications/roots/list_changed" is received. + RootsListChangedHandler func(context.Context, *RootsListChangedRequest) + // If non-nil, called when "notifications/progress" is received. + ProgressNotificationHandler func(context.Context, *ProgressNotificationServerRequest) + // If non-nil, called when "completion/complete" is received. + CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error) + // If non-zero, defines an interval for regular "ping" requests. + // If the peer fails to respond to pings originating from the keepalive check, + // the session is automatically closed. + KeepAlive time.Duration + // Function called when a client session subscribes to a resource. + SubscribeHandler func(context.Context, *SubscribeRequest) error + // Function called when a client session unsubscribes from a resource. 
+ UnsubscribeHandler func(context.Context, *UnsubscribeRequest) error + + // Capabilities optionally configures the server's default capabilities, + // before any capabilities are inferred from other configuration or server + // features. + // + // If Capabilities is nil, the default server capabilities are {"logging":{}}, + // for historical reasons. Setting Capabilities to a non-nil value overrides + // this default. For example, setting Capabilities to `&ServerCapabilities{}` + // disables the logging capability. + // + // # Interaction with capability inference + // + // "tools", "prompts", and "resources" capabilities are automatically added when + // tools, prompts, or resources are added to the server (for example, via + // [Server.AddPrompt]), with default value `{"listChanged":true}`. Similarly, + // if the [ClientOptions.SubscribeHandler] or + // [ClientOptions.CompletionHandler] are set, the inferred capabilities are + // adjusted accordingly. + // + // Any non-nil field in Capabilities overrides the inferred value. + // For example: + // + // - To advertise the "tools" capability, even if no tools are added, set + // Capabilities.Tools to &ToolCapabilities{ListChanged:true}. + // - To disable tool list notifications, set Capabilities.Tools to + // &ToolCapabilities{}. + // + // Conversely, if Capabilities does not set a field (for example, if the + // Prompts field is nil), the inferred capability will be used. + Capabilities *ServerCapabilities + + // If true, advertises the prompts capability during initialization, + // even if no prompts have been registered. + // + // Deprecated: Use Capabilities instead. + HasPrompts bool + // If true, advertises the resources capability during initialization, + // even if no resources have been registered. + // + // Deprecated: Use Capabilities instead. + HasResources bool + // If true, advertises the tools capability during initialization, + // even if no tools have been registered. 
+ // + // Deprecated: Use Capabilities instead. + HasTools bool + + // GetSessionID provides the next session ID to use for an incoming request. + // If nil, a default randomly generated ID will be used. + // + // Session IDs should be globally unique across the scope of the server, + // which may span multiple processes in the case of distributed servers. + // + // As a special case, if GetSessionID returns the empty string, the + // Mcp-Session-Id header will not be set. + GetSessionID func() string +} + +// NewServer creates a new MCP server. The resulting server has no features: +// add features using the various Server.AddXXX methods, and the [AddTool] function. +// +// The server can be connected to one or more MCP clients using [Server.Run]. +// +// The first argument must not be nil. +// +// If non-nil, the provided options are used to configure the server. +func NewServer(impl *Implementation, options *ServerOptions) *Server { + if impl == nil { + panic("nil Implementation") + } + var opts ServerOptions + if options != nil { + opts = *options + } + options = nil // prevent reuse + if opts.PageSize < 0 { + panic(fmt.Errorf("invalid page size %d", opts.PageSize)) + } + if opts.PageSize == 0 { + opts.PageSize = DefaultPageSize + } + if opts.SubscribeHandler != nil && opts.UnsubscribeHandler == nil { + panic("SubscribeHandler requires UnsubscribeHandler") + } + if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil { + panic("UnsubscribeHandler requires SubscribeHandler") + } + + if opts.GetSessionID == nil { + opts.GetSessionID = randText + } + + if opts.Logger == nil { // ensure we have a logger + opts.Logger = ensureLogger(nil) + } + + return &Server{ + impl: impl, + opts: opts, + prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), + tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), + resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), + resourceTemplates: 
newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), + sendingMethodHandler_: defaultSendingMethodHandler, + receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], + resourceSubscriptions: make(map[string]map[*ServerSession]bool), + pendingNotifications: make(map[string]*time.Timer), + } +} + +// AddPrompt adds a [Prompt] to the server, or replaces one with the same name. +func (s *Server) AddPrompt(p *Prompt, h PromptHandler) { + // Assume there was a change, since add replaces existing items. + // (It's possible an item was replaced with an identical one, but not worth checking.) + s.changeAndNotify( + notificationPromptListChanged, + func() bool { s.prompts.add(&serverPrompt{p, h}); return true }) +} + +// RemovePrompts removes the prompts with the given names. +// It is not an error to remove a nonexistent prompt. +func (s *Server) RemovePrompts(names ...string) { + s.changeAndNotify(notificationPromptListChanged, func() bool { return s.prompts.remove(names...) }) +} + +// AddTool adds a [Tool] to the server, or replaces one with the same name. +// The Tool argument must not be modified after this call. +// +// The tool's input schema must be non-nil and have the type "object". For a tool +// that takes no input, or one where any input is valid, set [Tool.InputSchema] to +// `{"type": "object"}`, using your preferred library or `json.RawMessage`. +// +// If present, [Tool.OutputSchema] must also have type "object". +// +// When the handler is invoked as part of a CallTool request, req.Params.Arguments +// will be a json.RawMessage. +// +// Unmarshaling the arguments and validating them against the input schema are the +// caller's responsibility. +// +// Validating the result against the output schema, if any, is the caller's responsibility. +// +// Setting the result's Content, StructuredContent and IsError fields are the caller's +// responsibility. 
+// +// Most users should use the top-level function [AddTool], which handles all these +// responsibilities. +func (s *Server) AddTool(t *Tool, h ToolHandler) { + if err := validateToolName(t.Name); err != nil { + s.opts.Logger.Error(fmt.Sprintf("AddTool: invalid tool name %q: %v", t.Name, err)) + } + if t.InputSchema == nil { + // This prevents the tool author from forgetting to write a schema where + // one should be provided. If we papered over this by supplying the empty + // schema, then every input would be validated and the problem wouldn't be + // discovered until runtime, when the LLM sent bad data. + panic(fmt.Errorf("AddTool %q: missing input schema", t.Name)) + } + if s, ok := t.InputSchema.(*jsonschema.Schema); ok { + if s.Type != "object" { + panic(fmt.Errorf(`AddTool %q: input schema must have type "object"`, t.Name)) + } + } else { + var m map[string]any + if err := remarshal(t.InputSchema, &m); err != nil { + panic(fmt.Errorf("AddTool %q: can't marshal input schema to a JSON object: %v", t.Name, err)) + } + if typ := m["type"]; typ != "object" { + panic(fmt.Errorf(`AddTool %q: input schema must have type "object" (got %v)`, t.Name, typ)) + } + } + if t.OutputSchema != nil { + if s, ok := t.OutputSchema.(*jsonschema.Schema); ok { + if s.Type != "object" { + panic(fmt.Errorf(`AddTool %q: output schema must have type "object"`, t.Name)) + } + } else { + var m map[string]any + if err := remarshal(t.OutputSchema, &m); err != nil { + panic(fmt.Errorf("AddTool %q: can't marshal output schema to a JSON object: %v", t.Name, err)) + } + if typ := m["type"]; typ != "object" { + panic(fmt.Errorf(`AddTool %q: output schema must have type "object" (got %v)`, t.Name, typ)) + } + } + } + st := &serverTool{tool: t, handler: h} + // Assume there was a change, since add replaces existing tools. + // (It's possible a tool was replaced with an identical one, but not worth checking.) + // TODO: Batch these changes by size and time? The typescript SDK doesn't. 
+ // TODO: Surface notify error here? best not, in case we need to batch. + s.changeAndNotify(notificationToolListChanged, func() bool { s.tools.add(st); return true }) +} + +func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { + tt := *t + + // Special handling for an "any" input: treat as an empty object. + if reflect.TypeFor[In]() == reflect.TypeFor[any]() && t.InputSchema == nil { + tt.InputSchema = &jsonschema.Schema{Type: "object"} + } + + var inputResolved *jsonschema.Resolved + if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil { + return nil, nil, fmt.Errorf("input schema: %w", err) + } + + // Handling for zero values: + // + // If Out is a pointer type and we've derived the output schema from its + // element type, use the zero value of its element type in place of a typed + // nil. + var ( + elemZero any // only non-nil if Out is a pointer type + outputResolved *jsonschema.Resolved + ) + if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + var err error + elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved) + if err != nil { + return nil, nil, fmt.Errorf("output schema: %v", err) + } + } + + th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + var input json.RawMessage + if req.Params.Arguments != nil { + input = req.Params.Arguments + } + // Validate input and apply defaults. + var err error + input, err = applySchema(input, inputResolved) + if err != nil { + // TODO(#450): should this be considered a tool error? (and similar below) + return nil, fmt.Errorf("%w: validating \"arguments\": %v", jsonrpc2.ErrInvalidParams, err) + } + + // Unmarshal and validate args. + var in In + if input != nil { + if err := json.Unmarshal(input, &in); err != nil { + return nil, fmt.Errorf("%w: %v", jsonrpc2.ErrInvalidParams, err) + } + } + + // Call typed handler. 
+ res, out, err := h(ctx, req, in) + // Handle server errors appropriately: + // - If the handler returns a structured error (like jsonrpc.Error), return it directly + // - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true + // - This allows tools to distinguish between protocol errors and tool execution errors + if err != nil { + // Check if this is already a structured JSON-RPC error + if wireErr, ok := err.(*jsonrpc.Error); ok { + return nil, wireErr + } + // For regular errors, embed them in the tool result as per MCP spec + var errRes CallToolResult + errRes.setError(err) + return &errRes, nil + } + + if res == nil { + res = &CallToolResult{} + } + + // Marshal the output and put the RawMessage in the StructuredContent field. + var outval any = out + if elemZero != nil { + // Avoid typed nil, which will serialize as JSON null. + // Instead, use the zero value of the unpointered type. + var z Out + if any(out) == any(z) { // zero is only non-nil if Out is a pointer type + outval = elemZero + } + } + if outval != nil { + outbytes, err := json.Marshal(outval) + if err != nil { + return nil, fmt.Errorf("marshaling output: %w", err) + } + outJSON := json.RawMessage(outbytes) + // Validate the output JSON, and apply defaults. + // + // We validate against the JSON, rather than the output value, as + // some types may have custom JSON marshalling (issue #447). + outJSON, err = applySchema(outJSON, outputResolved) + if err != nil { + return nil, fmt.Errorf("validating tool output: %w", err) + } + res.StructuredContent = outJSON // avoid a second marshal over the wire + + // If the Content field isn't being used, return the serialized JSON in a + // TextContent block, as the spec suggests: + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. 
+ if res.Content == nil { + res.Content = []Content{&TextContent{ + Text: string(outJSON), + }} + } + } + return res, nil + } // end of handler + + return &tt, th, nil +} + +// setSchema sets the schema and resolved schema corresponding to the type T. +// +// If sfield is nil, the schema is derived from T. +// +// Pointers are treated equivalently to non-pointers when deriving the schema. +// If an indirection occurred to derive the schema, a non-nil zero value is +// returned to be used in place of the typed nil zero value. +// +// Note that if sfield already holds a schema, zero will be nil even if T is a +// pointer: if the user provided the schema, they may have intentionally +// derived it from the pointer type, and handling of zero values is up to them. +// +// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we +// should have a jsonschema.Zero(schema) helper? +func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) { + var internalSchema *jsonschema.Schema + if *sfield == nil { + rt := reflect.TypeFor[T]() + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + zero = reflect.Zero(rt).Interface() + } + // TODO: we should be able to pass nil opts here. + internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) + if err == nil { + *sfield = internalSchema + } + } else if err := remarshal(*sfield, &internalSchema); err != nil { + return zero, err + } + if err != nil { + return zero, err + } + *rfield, err = internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + return zero, err +} + +// AddTool adds a tool and typed tool handler to the server. +// +// If the tool's input schema is nil, it is set to the schema inferred from the +// In type parameter. Types are inferred from Go types, and property +// descriptions are read from the 'jsonschema' struct tag. Internally, the SDK +// uses the github.com/google/jsonschema-go package for inference and +// validation. 
The In type argument must be a map or a struct, so that its +// inferred JSON Schema has type "object", as required by the spec. As a +// special case, if the In type is 'any', the tool's input schema is set to an +// empty object schema value. +// +// If the tool's output schema is nil, and the Out type is not 'any', the +// output schema is set to the schema inferred from the Out type argument, +// which must also be a map or struct. If the Out type is 'any', the output +// schema is omitted. +// +// Unlike [Server.AddTool], AddTool does a lot automatically, and forces +// tools to conform to the MCP spec. See [ToolHandlerFor] for a detailed +// description of this automatic behavior. +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { + tt, hh, err := toolForErr(t, h) + if err != nil { + panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err)) + } + s.AddTool(tt, hh) +} + +// RemoveTools removes the tools with the given names. +// It is not an error to remove a nonexistent tool. +func (s *Server) RemoveTools(names ...string) { + s.changeAndNotify(notificationToolListChanged, func() bool { return s.tools.remove(names...) }) +} + +// AddResource adds a [Resource] to the server, or replaces one with the same URI. +// AddResource panics if the resource URI is invalid or not absolute (has an empty scheme). +func (s *Server) AddResource(r *Resource, h ResourceHandler) { + s.changeAndNotify(notificationResourceListChanged, + func() bool { + if _, err := url.Parse(r.URI); err != nil { + panic(err) // url.Parse includes the URI in the error + } + s.resources.add(&serverResource{r, h}) + return true + }) +} + +// RemoveResources removes the resources with the given URIs. +// It is not an error to remove a nonexistent resource. +func (s *Server) RemoveResources(uris ...string) { + s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resources.remove(uris...) 
}) +} + +// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces one with the same URI. +// AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme). +func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { + s.changeAndNotify(notificationResourceListChanged, + func() bool { + // Validate the URI template syntax + _, err := uritemplate.New(t.URITemplate) + if err != nil { + panic(fmt.Errorf("URI template %q is invalid: %w", t.URITemplate, err)) + } + s.resourceTemplates.add(&serverResourceTemplate{t, h}) + return true + }) +} + +// RemoveResourceTemplates removes the resource templates with the given URI templates. +// It is not an error to remove a nonexistent resource. +func (s *Server) RemoveResourceTemplates(uriTemplates ...string) { + s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resourceTemplates.remove(uriTemplates...) }) +} + +func (s *Server) capabilities() *ServerCapabilities { + s.mu.Lock() + defer s.mu.Unlock() + + // Start with user-provided capabilities as defaults, or use SDK defaults. + var caps *ServerCapabilities + if s.opts.Capabilities != nil { + // Deep copy the user-provided capabilities to avoid mutation. + caps = s.opts.Capabilities.clone() + } else { + // SDK defaults: only logging capability. + caps = &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + } + } + + // Augment with tools capability if tools exist or legacy HasTools is set. + if s.opts.HasTools || s.tools.len() > 0 { + if caps.Tools == nil { + caps.Tools = &ToolCapabilities{ListChanged: true} + } + } + + // Augment with prompts capability if prompts exist or legacy HasPrompts is set. + if s.opts.HasPrompts || s.prompts.len() > 0 { + if caps.Prompts == nil { + caps.Prompts = &PromptCapabilities{ListChanged: true} + } + } + + // Augment with resources capability if resources/templates exist or legacy HasResources is set. 
+ if s.opts.HasResources || s.resources.len() > 0 || s.resourceTemplates.len() > 0 { + if caps.Resources == nil { + caps.Resources = &ResourceCapabilities{ListChanged: true} + } + if s.opts.SubscribeHandler != nil { + caps.Resources.Subscribe = true + } + } + + // Augment with completions capability if handler is set. + if s.opts.CompletionHandler != nil { + if caps.Completions == nil { + caps.Completions = &CompletionCapabilities{} + } + } + + return caps +} + +func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteResult, error) { + if s.opts.CompletionHandler == nil { + return nil, jsonrpc2.ErrMethodNotFound + } + return s.opts.CompletionHandler(ctx, req) +} + +// Map from notification name to its corresponding params. The params have no fields, +// so a single struct can be reused. +var changeNotificationParams = map[string]Params{ + notificationToolListChanged: &ToolListChangedParams{}, + notificationPromptListChanged: &PromptListChangedParams{}, + notificationResourceListChanged: &ResourceListChangedParams{}, +} + +// How long to wait before sending a change notification. +const notificationDelay = 10 * time.Millisecond + +// changeAndNotify is called when a feature is added or removed. +// It calls change, which should do the work and report whether a change actually occurred. +// If there was a change, it sets a timer to send a notification. +// This debounces change notifications: a single notification is sent after +// multiple changes occur in close proximity. +func (s *Server) changeAndNotify(notification string, change func() bool) { + s.mu.Lock() + defer s.mu.Unlock() + if change() && s.shouldSendListChangedNotification(notification) { + // Reset the outstanding delayed call, if any. 
+ if t := s.pendingNotifications[notification]; t == nil { + s.pendingNotifications[notification] = time.AfterFunc(notificationDelay, func() { s.notifySessions(notification) }) + } else { + t.Reset(notificationDelay) + } + } +} + +// notifySessions sends the notification n to all existing sessions. +// It is called asynchronously by changeAndNotify. +func (s *Server) notifySessions(n string) { + s.mu.Lock() + sessions := slices.Clone(s.sessions) + s.pendingNotifications[n] = nil + s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock. + notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger) +} + +// shouldSendListChangedNotification checks if the server's capabilities allow +// sending the given list-changed notification. +func (s *Server) shouldSendListChangedNotification(notification string) bool { + // Get effective capabilities (considering user-provided defaults). + caps := s.opts.Capabilities + + switch notification { + case notificationToolListChanged: + // If user didn't specify capabilities, default behavior sends notifications. + if caps == nil || caps.Tools == nil { + return true + } + return caps.Tools.ListChanged + case notificationPromptListChanged: + if caps == nil || caps.Prompts == nil { + return true + } + return caps.Prompts.ListChanged + case notificationResourceListChanged: + if caps == nil || caps.Resources == nil { + return true + } + return caps.Resources.ListChanged + default: + // Unknown notification, allow by default. + return true + } +} + +// Sessions returns an iterator that yields the current set of server sessions. +// +// There is no guarantee that the iterator observes sessions that are added or +// removed during iteration. 
+func (s *Server) Sessions() iter.Seq[*ServerSession] { + s.mu.Lock() + clients := slices.Clone(s.sessions) + s.mu.Unlock() + return slices.Values(clients) +} + +func (s *Server) listPrompts(_ context.Context, req *ListPromptsRequest) (*ListPromptsResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListPromptsParams{} + } + return paginateList(s.prompts, s.opts.PageSize, req.Params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*serverPrompt) { + res.Prompts = []*Prompt{} // avoid JSON null + for _, p := range prompts { + res.Prompts = append(res.Prompts, p.prompt) + } + }) +} + +func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + s.mu.Lock() + prompt, ok := s.prompts.get(req.Params.Name) + s.mu.Unlock() + if !ok { + // Return a proper JSON-RPC error with the correct error code + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), + } + } + return prompt.handler(ctx, req) +} + +func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListToolsParams{} + } + return paginateList(s.tools, s.opts.PageSize, req.Params, &ListToolsResult{}, func(res *ListToolsResult, tools []*serverTool) { + res.Tools = []*Tool{} // avoid JSON null + for _, t := range tools { + res.Tools = append(res.Tools, t.tool) + } + }) +} + +func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + s.mu.Lock() + st, ok := s.tools.get(req.Params.Name) + s.mu.Unlock() + if !ok { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("unknown tool %q", req.Params.Name), + } + } + res, err := st.handler(ctx, req) + if err == nil && res != nil && res.Content == nil { + res2 := *res + res2.Content = []Content{} // avoid "null" + res = &res2 + } 
+ return res, err +} + +func (s *Server) listResources(_ context.Context, req *ListResourcesRequest) (*ListResourcesResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListResourcesParams{} + } + return paginateList(s.resources, s.opts.PageSize, req.Params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*serverResource) { + res.Resources = []*Resource{} // avoid JSON null + for _, r := range resources { + res.Resources = append(res.Resources, r.resource) + } + }) +} + +func (s *Server) listResourceTemplates(_ context.Context, req *ListResourceTemplatesRequest) (*ListResourceTemplatesResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + if req.Params == nil { + req.Params = &ListResourceTemplatesParams{} + } + return paginateList(s.resourceTemplates, s.opts.PageSize, req.Params, &ListResourceTemplatesResult{}, + func(res *ListResourceTemplatesResult, rts []*serverResourceTemplate) { + res.ResourceTemplates = []*ResourceTemplate{} // avoid JSON null + for _, rt := range rts { + res.ResourceTemplates = append(res.ResourceTemplates, rt.resourceTemplate) + } + }) +} + +func (s *Server) readResource(ctx context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + uri := req.Params.URI + // Look up the resource URI in the lists of resources and resource templates. + // This is a security check as well as an information lookup. + handler, mimeType, ok := s.lookupResourceHandler(uri) + if !ok { + // Don't expose the server configuration to the client. + // Treat an unregistered resource the same as a registered one that couldn't be found. + return nil, ResourceNotFoundError(uri) + } + res, err := handler(ctx, req) + if err != nil { + return nil, err + } + if res == nil || res.Contents == nil { + return nil, fmt.Errorf("reading resource %s: read handler returned nil information", uri) + } + // As a convenience, populate some fields. 
+ for _, c := range res.Contents { + if c.URI == "" { + c.URI = uri + } + if c.MIMEType == "" { + c.MIMEType = mimeType + } + } + return res, nil +} + +// lookupResourceHandler returns the resource handler and MIME type for the resource or +// resource template matching uri. If none, the last return value is false. +func (s *Server) lookupResourceHandler(uri string) (ResourceHandler, string, bool) { + s.mu.Lock() + defer s.mu.Unlock() + // Try resources first. + if r, ok := s.resources.get(uri); ok { + return r.handler, r.resource.MIMEType, true + } + // Look for matching template. + for rt := range s.resourceTemplates.all() { + if rt.Matches(uri) { + return rt.handler, rt.resourceTemplate.MIMEType, true + } + } + return nil, "", false +} + +// fileResourceHandler returns a ReadResourceHandler that reads paths using dir as +// a base directory. +// It honors client roots and protects against path traversal attacks. +// +// The dir argument should be a filesystem path. It need not be absolute, but +// that is recommended to avoid a dependency on the current working directory (the +// check against client roots is done with an absolute path). If dir is not absolute +// and the current working directory is unavailable, fileResourceHandler panics. +// +// Lexical path traversal attacks, where the path has ".." elements that escape dir, +// are always caught. Go 1.24 and above also protects against symlink-based attacks, +// where symlinks under dir lead out of the tree. +func fileResourceHandler(dir string) ResourceHandler { + // Convert dir to an absolute path. + dirFilepath, err := filepath.Abs(dir) + if err != nil { + panic(err) + } + return func(ctx context.Context, req *ReadResourceRequest) (_ *ReadResourceResult, err error) { + defer util.Wrapf(&err, "reading resource %s", req.Params.URI) + + // TODO(#25): use a memoizing API here. 
+ rootRes, err := req.Session.ListRoots(ctx, nil) + if err != nil { + return nil, fmt.Errorf("listing roots: %w", err) + } + roots, err := fileRoots(rootRes.Roots) + if err != nil { + return nil, err + } + data, err := readFileResource(req.Params.URI, dirFilepath, roots) + if err != nil { + return nil, err + } + // TODO(jba): figure out mime type. Omit for now: Server.readResource will fill it in. + return &ReadResourceResult{Contents: []*ResourceContents{ + {URI: req.Params.URI, Blob: data}, + }}, nil + } +} + +// ResourceUpdated sends a notification to all clients that have subscribed to the +// resource specified in params. This method is the primary way for a +// server author to signal that a resource has changed. +func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error { + s.mu.Lock() + subscribedSessions := s.resourceSubscriptions[params.URI] + sessions := slices.Collect(maps.Keys(subscribedSessions)) + s.mu.Unlock() + notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger) + s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions)) + return nil +} + +func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyResult, error) { + if s.opts.SubscribeHandler == nil { + return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) + } + if err := s.opts.SubscribeHandler(ctx, req); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if s.resourceSubscriptions[req.Params.URI] == nil { + s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool) + } + s.resourceSubscriptions[req.Params.URI][req.Session] = true + s.opts.Logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) + + return &emptyResult{}, nil +} + +func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emptyResult, error) { + if 
s.opts.UnsubscribeHandler == nil { + return nil, jsonrpc2.ErrMethodNotFound + } + + if err := s.opts.UnsubscribeHandler(ctx, req); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if subscribedSessions, ok := s.resourceSubscriptions[req.Params.URI]; ok { + delete(subscribedSessions, req.Session) + if len(subscribedSessions) == 0 { + delete(s.resourceSubscriptions, req.Params.URI) + } + } + s.opts.Logger.Info("resource unsubscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) + + return &emptyResult{}, nil +} + +// Run runs the server over the given transport, which must be persistent. +// +// Run blocks until the client terminates the connection or the provided +// context is cancelled. If the context is cancelled, Run closes the connection. +// +// If tools have been added to the server before this call, then the server will +// advertise the capability for tools, including the ability to send list-changed notifications. +// If no tools have been added, the server will not have the tool capability. +// The same goes for other features like prompts and resources. +// +// Run is a convenience for servers that handle a single session (or one session at a time). +// It need not be called on servers that are used for multiple concurrent connections, +// as with [StreamableHTTPHandler]. 
+func (s *Server) Run(ctx context.Context, t Transport) error { + s.opts.Logger.Info("server run start") + ss, err := s.Connect(ctx, t, nil) + if err != nil { + s.opts.Logger.Error("server connect failed", "error", err) + return err + } + + ssClosed := make(chan error) + go func() { + ssClosed <- ss.Wait() + }() + + select { + case <-ctx.Done(): + ss.Close() + <-ssClosed // wait until waiting go routine above actually completes + s.opts.Logger.Error("server run cancelled", "error", ctx.Err()) + return ctx.Err() + case err := <-ssClosed: + if err != nil { + s.opts.Logger.Error("server session ended with error", "error", err) + } else { + s.opts.Logger.Info("server session ended") + } + return err + } +} + +// bind implements the binder[*ServerSession] interface, so that Servers can +// be connected using [connect]. +func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState, onClose func()) *ServerSession { + assert(mcpConn != nil && conn != nil, "nil connection") + ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s, onClose: onClose} + if state != nil { + ss.state = *state + } + s.mu.Lock() + s.sessions = append(s.sessions, ss) + s.mu.Unlock() + s.opts.Logger.Info("server session connected", "session_id", ss.ID()) + return ss +} + +// disconnect implements the binder[*ServerSession] interface, so that +// Servers can be connected using [connect]. +func (s *Server) disconnect(cc *ServerSession) { + s.mu.Lock() + defer s.mu.Unlock() + s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool { + return cc2 == cc + }) + + for _, subscribedSessions := range s.resourceSubscriptions { + delete(subscribedSessions, cc) + } + s.opts.Logger.Info("server session disconnected", "session_id", cc.ID()) +} + +// ServerSessionOptions configures the server session. 
+type ServerSessionOptions struct { + State *ServerSessionState + + onClose func() // used to clean up associated resources +} + +// Connect connects the MCP server over the given transport and starts handling +// messages. +// +// It returns a connection object that may be used to terminate the connection +// (with [Connection.Close]), or await client termination (with +// [Connection.Wait]). +// +// If opts.State is non-nil, it is the initial state for the server. +func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) { + var state *ServerSessionState + var onClose func() + if opts != nil { + state = opts.State + onClose = opts.onClose + } + + s.opts.Logger.Info("server connecting") + ss, err := connect(ctx, t, s, state, onClose) + if err != nil { + s.opts.Logger.Error("server connect error", "error", err) + return nil, err + } + return ss, nil +} + +// TODO: (nit) move all ServerSession methods below the ServerSession declaration. +func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) { + if params == nil { + // Since we use nilness to signal 'initialized' state, we must ensure that + // params are non-nil. 
+ params = new(InitializedParams) + } + var wasInit, wasInitd bool + ss.updateState(func(state *ServerSessionState) { + wasInit = state.InitializeParams != nil + wasInitd = state.InitializedParams != nil + if wasInit && !wasInitd { + state.InitializedParams = params + } + }) + + if !wasInit { + ss.server.opts.Logger.Error("initialized before initialize") + return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) + } + if wasInitd { + ss.server.opts.Logger.Error("duplicate initialized notification") + return nil, fmt.Errorf("duplicate %q received", notificationInitialized) + } + if ss.server.opts.KeepAlive > 0 { + ss.startKeepalive(ss.server.opts.KeepAlive) + } + if h := ss.server.opts.InitializedHandler; h != nil { + h(ctx, serverRequestFor(ss, params)) + } + ss.server.opts.Logger.Info("session initialized") + return nil, nil +} + +func (s *Server) callRootsListChangedHandler(ctx context.Context, req *RootsListChangedRequest) (Result, error) { + if h := s.opts.RootsListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, p *ProgressNotificationParams) (Result, error) { + if h := ss.server.opts.ProgressNotificationHandler; h != nil { + h(ctx, serverRequestFor(ss, p)) + } + return nil, nil +} + +// NotifyProgress sends a progress notification from the server to the client +// associated with this session. +// This is typically used to report on the status of a long-running request +// that was initiated by the client. +func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNotificationParams) error { + return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params))) +} + +func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { + return &ServerRequest[P]{Session: ss, Params: params} +} + +// A ServerSession is a logical connection from a single MCP client. 
Its +// methods can be used to send requests or notifications to the client. Create +// a session by calling [Server.Connect]. +// +// Call [ServerSession.Close] to close the connection, or await client +// termination with [ServerSession.Wait]. +type ServerSession struct { + // Ensure that onClose is called at most once. + // We defensively use an atomic CompareAndSwap rather than a sync.Once, in case the + // onClose callback triggers a re-entrant call to Close. + calledOnClose atomic.Bool + onClose func() + + server *Server + conn *jsonrpc2.Connection + mcpConn Connection + keepaliveCancel context.CancelFunc // TODO: theory around why keepaliveCancel need not be guarded + + mu sync.Mutex + state ServerSessionState +} + +func (ss *ServerSession) updateState(mut func(*ServerSessionState)) { + ss.mu.Lock() + mut(&ss.state) + copy := ss.state + ss.mu.Unlock() + if c, ok := ss.mcpConn.(serverConnection); ok { + c.sessionUpdated(copy) + } +} + +// hasInitialized reports whether the server has received the initialized +// notification. +// +// TODO(findleyr): use this to prevent change notifications. +func (ss *ServerSession) hasInitialized() bool { + ss.mu.Lock() + defer ss.mu.Unlock() + return ss.state.InitializedParams != nil +} + +// checkInitialized returns a formatted error if the server has not yet +// received the initialized notification. +func (ss *ServerSession) checkInitialized(method string) error { + if !ss.hasInitialized() { + // TODO(rfindley): enable this check. + // Right now is is flaky, because server tests don't await the initialized notification. 
+ // Perhaps requests should simply block until they have received the initialized notification + + // if strings.HasPrefix(method, "notifications/") { + // return fmt.Errorf("must not send %q before %q is received", method, notificationInitialized) + // } else { + // return fmt.Errorf("cannot call %q before %q is received", method, notificationInitialized) + // } + } + return nil +} + +func (ss *ServerSession) ID() string { + if c, ok := ss.mcpConn.(hasSessionID); ok { + return c.SessionID() + } + return "" +} + +// Ping pings the client. +func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error { + _, err := handleSend[*emptyResult](ctx, methodPing, newServerRequest(ss, orZero[Params](params))) + return err +} + +// ListRoots lists the client roots. +func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { + if err := ss.checkInitialized(methodListRoots); err != nil { + return nil, err + } + return handleSend[*ListRootsResult](ctx, methodListRoots, newServerRequest(ss, orZero[Params](params))) +} + +// CreateMessage sends a sampling request to the client. +func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { + if err := ss.checkInitialized(methodCreateMessage); err != nil { + return nil, err + } + if params == nil { + params = &CreateMessageParams{Messages: []*SamplingMessage{}} + } + if params.Messages == nil { + p2 := *params + p2.Messages = []*SamplingMessage{} // avoid JSON "null" + params = &p2 + } + return handleSend[*CreateMessageResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) +} + +// Elicit sends an elicitation request to the client asking for user input. 
+func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*ElicitResult, error) {
+	if err := ss.checkInitialized(methodElicit); err != nil {
+		return nil, err
+	}
+	if params == nil {
+		return nil, fmt.Errorf("%w: params cannot be nil", jsonrpc2.ErrInvalidParams)
+	}
+
+	if params.Mode == "" {
+		params2 := *params
+		if params.URL != "" || params.ElicitationID != "" {
+			params2.Mode = "url"
+		} else {
+			params2.Mode = "form"
+		}
+		params = &params2
+	}
+
+	if iparams := ss.InitializeParams(); iparams == nil || iparams.Capabilities == nil || iparams.Capabilities.Elicitation == nil {
+		return nil, fmt.Errorf("client does not support elicitation")
+	}
+	caps := ss.InitializeParams().Capabilities.Elicitation
+	switch params.Mode {
+	case "form":
+		if caps.Form == nil && caps.URL != nil {
+			// Note: if both 'Form' and 'URL' are nil, we assume the client supports
+			// form elicitation for backward compatibility.
+			return nil, errors.New(`client does not support "form" elicitation`)
+		}
+	case "url":
+		if caps.URL == nil {
+			return nil, errors.New(`client does not support "url" elicitation`)
+		}
+	}
+
+	res, err := handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params)))
+	if err != nil {
+		return nil, err
+	}
+
+	if params.RequestedSchema == nil {
+		return res, nil
+	}
+	schema, err := validateElicitSchema(params.RequestedSchema)
+	if err != nil {
+		return nil, err
+	}
+	if schema == nil {
+		return res, nil
+	}
+
+	resolved, err := schema.Resolve(nil)
+	if err != nil {
+		return nil, err
+	}
+	if err := resolved.Validate(res.Content); err != nil {
+		return nil, fmt.Errorf("elicitation result content does not match requested schema: %v", err)
+	}
+	err = resolved.ApplyDefaults(&res.Content)
+	if err != nil {
+		return nil, fmt.Errorf("failed to apply schema defalts to elicitation result: %v", err)
+	}
+
+	return res, nil
+}
+
+// Log sends a log message to the client.
+// The message is not sent if the client has not called SetLevel, or if its level +// is below that of the last SetLevel. +func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error { + ss.mu.Lock() + logLevel := ss.state.LogLevel + ss.mu.Unlock() + if logLevel == "" { + // The spec is unclear, but seems to imply that no log messages are sent until the client + // sets the level. + // TODO(jba): read other SDKs, possibly file an issue. + return nil + } + if compareLevels(params.Level, logLevel) < 0 { + return nil + } + return handleNotify(ctx, notificationLoggingMessage, newServerRequest(ss, orZero[Params](params))) +} + +// AddSendingMiddleware wraps the current sending method handler using the provided +// middleware. Middleware is applied from right to left, so that the first one is +// executed first. +// +// For example, AddSendingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Sending middleware is called when a request is sent. It is useful for tasks +// such as tracing, metrics, and adding progress tokens. +func (s *Server) AddSendingMiddleware(middleware ...Middleware) { + s.mu.Lock() + defer s.mu.Unlock() + addMiddleware(&s.sendingMethodHandler_, middleware) +} + +// AddReceivingMiddleware wraps the current receiving method handler using +// the provided middleware. Middleware is applied from right to left, so that the +// first one is executed first. +// +// For example, AddReceivingMiddleware(m1, m2, m3) augments the method handler as +// m1(m2(m3(handler))). +// +// Receiving middleware is called when a request is received. It is useful for tasks +// such as authentication, request logging and metrics. +func (s *Server) AddReceivingMiddleware(middleware ...Middleware) { + s.mu.Lock() + defer s.mu.Unlock() + addMiddleware(&s.receivingMethodHandler_, middleware) +} + +// serverMethodInfos maps from the RPC method name to serverMethodInfos. 
+// +// The 'allowMissingParams' values are extracted from the protocol schema. +// TODO(rfindley): actually load and validate the protocol schema, rather than +// curating these method flags. +var serverMethodInfos = map[string]methodInfo{ + methodComplete: newServerMethodInfo(serverMethod((*Server).complete), 0), + methodInitialize: initializeMethodInfo(), + methodPing: newServerMethodInfo(serverSessionMethod((*ServerSession).ping), missingParamsOK), + methodListPrompts: newServerMethodInfo(serverMethod((*Server).listPrompts), missingParamsOK), + methodGetPrompt: newServerMethodInfo(serverMethod((*Server).getPrompt), 0), + methodListTools: newServerMethodInfo(serverMethod((*Server).listTools), missingParamsOK), + methodCallTool: newServerMethodInfo(serverMethod((*Server).callTool), 0), + methodListResources: newServerMethodInfo(serverMethod((*Server).listResources), missingParamsOK), + methodListResourceTemplates: newServerMethodInfo(serverMethod((*Server).listResourceTemplates), missingParamsOK), + methodReadResource: newServerMethodInfo(serverMethod((*Server).readResource), 0), + methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0), + methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), + methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), + notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK), + notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), + notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), + notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification), +} + +// initializeMethodInfo handles the workaround for #607: we must set +// params.Capabilities.RootsV2. 
+func initializeMethodInfo() methodInfo {
+	info := newServerMethodInfo(serverSessionMethod((*ServerSession).initialize), 0)
+	info.unmarshalParams = func(m json.RawMessage) (Params, error) {
+		var params *initializeParamsV2
+		if m != nil {
+			if err := json.Unmarshal(m, &params); err != nil {
+				return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, params, err)
+			}
+		}
+		if params == nil {
+			return nil, fmt.Errorf(`missing required "params"`)
+		}
+		return params.toV1(), nil
+	}
+	return info
+}
+
+func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos }
+
+func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos }
+
+func (ss *ServerSession) sendingMethodHandler() MethodHandler {
+	s := ss.server
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.sendingMethodHandler_
+}
+
+func (ss *ServerSession) receivingMethodHandler() MethodHandler {
+	s := ss.server
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.receivingMethodHandler_
+}
+
+// getConn implements [session.getConn].
+func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn }
+
+// handle invokes the method described by the given JSON RPC request.
+func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
+	ss.mu.Lock()
+	initialized := ss.state.InitializeParams != nil
+	ss.mu.Unlock()
+
+	// From the spec:
+	// "The client SHOULD NOT send requests other than pings before the server
+	// has responded to the initialize request."
+ switch req.Method { + case methodInitialize, methodPing, notificationInitialized: + default: + if !initialized { + ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method) + return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) + } + } + + // modelcontextprotocol/go-sdk#26: handle calls asynchronously, and + // notifications synchronously, except for 'initialize' which shouldn't be + // asynchronous to other + if req.IsCall() && req.Method != methodInitialize { + jsonrpc2.Async(ctx) + } + + // For the streamable transport, we need the request ID to correlate + // server->client calls and notifications to the incoming request from which + // they originated. See [idContextKey] for details. + ctx = context.WithValue(ctx, idContextKey{}, req.ID) + return handleReceive(ctx, ss, req) +} + +// InitializeParams returns the InitializeParams provided during the client's +// initial connection. +func (ss *ServerSession) InitializeParams() *InitializeParams { + ss.mu.Lock() + defer ss.mu.Unlock() + return ss.state.InitializeParams +} + +func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) { + if params == nil { + return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) + } + ss.updateState(func(state *ServerSessionState) { + state.InitializeParams = params + }) + + s := ss.server + return &InitializeResult{ + // TODO(rfindley): alter behavior when falling back to an older version: + // reject unsupported features. + ProtocolVersion: negotiatedVersion(params.ProtocolVersion), + Capabilities: s.capabilities(), + Instructions: s.opts.Instructions, + ServerInfo: s.impl, + }, nil +} + +func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error) { + return &emptyResult{}, nil +} + +// cancel is a placeholder: cancellation is handled the jsonrpc2 package. 
+// +// It should never be invoked in practice because cancellation is preempted, +// but having its signature here facilitates the construction of methodInfo +// that can be used to validate incoming cancellation notifications. +func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) { + return nil, nil +} + +func (ss *ServerSession) setLevel(_ context.Context, params *SetLoggingLevelParams) (*emptyResult, error) { + ss.updateState(func(state *ServerSessionState) { + state.LogLevel = params.Level + }) + ss.server.opts.Logger.Info("client log level set", "level", params.Level) + return &emptyResult{}, nil +} + +// Close performs a graceful shutdown of the connection, preventing new +// requests from being handled, and waiting for ongoing requests to return. +// Close then terminates the connection. +// +// Close is idempotent and concurrency safe. +func (ss *ServerSession) Close() error { + if ss.keepaliveCancel != nil { + // Note: keepaliveCancel access is safe without a mutex because: + // 1. keepaliveCancel is only written once during startKeepalive (happens-before all Close calls) + // 2. context.CancelFunc is safe to call multiple times and from multiple goroutines + // 3. The keepalive goroutine calls Close on ping failure, but this is safe since + // Close is idempotent and conn.Close() handles concurrent calls correctly + ss.keepaliveCancel() + } + err := ss.conn.Close() + + if ss.onClose != nil && ss.calledOnClose.CompareAndSwap(false, true) { + ss.onClose() + } + + return err +} + +// Wait waits for the connection to be closed by the client. +func (ss *ServerSession) Wait() error { + return ss.conn.Wait() +} + +// startKeepalive starts the keepalive mechanism for this server session. +func (ss *ServerSession) startKeepalive(interval time.Duration) { + startKeepalive(ss, interval, &ss.keepaliveCancel) +} + +// pageToken is the internal structure for the opaque pagination cursor. 
+// It will be Gob-encoded and then Base64-encoded for use as a string token. +type pageToken struct { + LastUID string // The unique ID of the last resource seen. +} + +// encodeCursor encodes a unique identifier (UID) into a opaque pagination cursor +// by serializing a pageToken struct. +func encodeCursor(uid string) (string, error) { + var buf bytes.Buffer + token := pageToken{LastUID: uid} + encoder := gob.NewEncoder(&buf) + if err := encoder.Encode(token); err != nil { + return "", fmt.Errorf("failed to encode page token: %w", err) + } + return base64.URLEncoding.EncodeToString(buf.Bytes()), nil +} + +// decodeCursor decodes an opaque pagination cursor into the original pageToken struct. +func decodeCursor(cursor string) (*pageToken, error) { + decodedBytes, err := base64.URLEncoding.DecodeString(cursor) + if err != nil { + return nil, fmt.Errorf("failed to decode cursor: %w", err) + } + + var token pageToken + buf := bytes.NewBuffer(decodedBytes) + decoder := gob.NewDecoder(buf) + if err := decoder.Decode(&token); err != nil { + return nil, fmt.Errorf("failed to decode page token: %w, cursor: %v", err, cursor) + } + return &token, nil +} + +// paginateList is a generic helper that returns a paginated slice of items +// from a featureSet. It populates the provided result res with the items +// and sets its next cursor for subsequent pages. +// If there are no more pages, the next cursor within the result will be an empty string. +func paginateList[P listParams, R listResult[T], T any](fs *featureSet[T], pageSize int, params P, res R, setFunc func(R, []T)) (R, error) { + var seq iter.Seq[T] + if params.cursorPtr() == nil || *params.cursorPtr() == "" { + seq = fs.all() + } else { + pageToken, err := decodeCursor(*params.cursorPtr()) + // According to the spec, invalid cursors should return Invalid params. 
+ if err != nil { + var zero R + return zero, jsonrpc2.ErrInvalidParams + } + seq = fs.above(pageToken.LastUID) + } + var count int + var features []T + for f := range seq { + count++ + // If we've seen pageSize + 1 elements, we've gathered enough info to determine + // if there's a next page. Stop processing the sequence. + if count == pageSize+1 { + break + } + features = append(features, f) + } + setFunc(res, features) + // No remaining pages. + if count < pageSize+1 { + return res, nil + } + nextCursor, err := encodeCursor(fs.uniqueID(features[len(features)-1])) + if err != nil { + var zero R + return zero, err + } + *res.nextCursorPtr() = nextCursor + return res, nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/session.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/session.go new file mode 100644 index 000000000..dcf9888cc --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/session.go @@ -0,0 +1,29 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +// hasSessionID is the interface which, if implemented by connections, informs +// the session about their session ID. +// +// TODO(rfindley): remove SessionID methods from connections, when it doesn't +// make sense. Or remove it from the Sessions entirely: why does it even need +// to be exposed? +type hasSessionID interface { + SessionID() string +} + +// ServerSessionState is the state of a session. +type ServerSessionState struct { + // InitializeParams are the parameters from 'initialize'. + InitializeParams *InitializeParams `json:"initializeParams"` + + // InitializedParams are the parameters from 'notifications/initialized'. + InitializedParams *InitializedParams `json:"initializedParams"` + + // LogLevel is the logging level for the session. 
+ LogLevel LoggingLevel `json:"logLevel"` + + // TODO: resource subscriptions +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go new file mode 100644 index 000000000..d83eae7da --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go @@ -0,0 +1,610 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains code shared between client and server, including +// method handler and middleware definitions. +// +// Much of this is here so that we can factor out commonalities using +// generics. If this becomes unwieldy, it can perhaps be simplified with +// reflection. + +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "reflect" + "slices" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const ( + // latestProtocolVersion is the latest protocol version that this version of + // the SDK supports. + // + // It is the version that the client sends in the initialization request, and + // the default version used by the server. + latestProtocolVersion = protocolVersion20250618 + protocolVersion20251125 = "2025-11-25" // not yet released + protocolVersion20250618 = "2025-06-18" + protocolVersion20250326 = "2025-03-26" + protocolVersion20241105 = "2024-11-05" +) + +var supportedProtocolVersions = []string{ + protocolVersion20251125, + protocolVersion20250618, + protocolVersion20250326, + protocolVersion20241105, +} + +// negotiatedVersion returns the effective protocol version to use, given a +// client version. 
+func negotiatedVersion(clientVersion string) string { + // In general, prefer to use the clientVersion, but if we don't support the + // client's version, use the latest version. + // + // This handles the case where a new spec version is released, and the SDK + // does not support it yet. + if !slices.Contains(supportedProtocolVersions, clientVersion) { + return latestProtocolVersion + } + return clientVersion +} + +// A MethodHandler handles MCP messages. +// For methods, exactly one of the return values must be nil. +// For notifications, both must be nil. +type MethodHandler func(ctx context.Context, method string, req Request) (result Result, err error) + +// A Session is either a [ClientSession] or a [ServerSession]. +type Session interface { + // ID returns the session ID, or the empty string if there is none. + ID() string + + sendingMethodInfos() map[string]methodInfo + receivingMethodInfos() map[string]methodInfo + sendingMethodHandler() MethodHandler + receivingMethodHandler() MethodHandler + getConn() *jsonrpc2.Connection +} + +// Middleware is a function from [MethodHandler] to [MethodHandler]. +type Middleware func(MethodHandler) MethodHandler + +// addMiddleware wraps the handler in the middleware functions. +func addMiddleware(handlerp *MethodHandler, middleware []Middleware) { + for _, m := range slices.Backward(middleware) { + *handlerp = m(*handlerp) + } +} + +func defaultSendingMethodHandler(ctx context.Context, method string, req Request) (Result, error) { + info, ok := req.GetSession().sendingMethodInfos()[method] + if !ok { + // This can be called from user code, with an arbitrary value for method. + return nil, jsonrpc2.ErrNotHandled + } + params := req.GetParams() + if initParams, ok := params.(*InitializeParams); ok { + // Fix the marshaling of initialize params, to work around #607. + // + // The initialize params we produce should never be nil, nor have nil + // capabilities, so any panic here is a bug. 
+ params = initParams.toV2() + } + // Notifications don't have results. + if strings.HasPrefix(method, "notifications/") { + return nil, req.GetSession().getConn().Notify(ctx, method, params) + } + // Create the result to unmarshal into. + // The concrete type of the result is the return type of the receiving function. + res := info.newResult() + if err := call(ctx, req.GetSession().getConn(), method, params, res); err != nil { + return nil, err + } + return res, nil +} + +// Helper method to avoid typed nil. +func orZero[T any, P *U, U any](p P) T { + if p == nil { + var zero T + return zero + } + return any(p).(T) +} + +func handleNotify(ctx context.Context, method string, req Request) error { + mh := req.GetSession().sendingMethodHandler() + _, err := mh(ctx, method, req) + return err +} + +func handleSend[R Result](ctx context.Context, method string, req Request) (R, error) { + mh := req.GetSession().sendingMethodHandler() + // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. + res, err := mh(ctx, method, req) + if err != nil { + var z R + return z, err + } + return res.(R), nil +} + +// defaultReceivingMethodHandler is the initial MethodHandler for servers and clients, before being wrapped by middleware. +func defaultReceivingMethodHandler[S Session](ctx context.Context, method string, req Request) (Result, error) { + info, ok := req.GetSession().receivingMethodInfos()[method] + if !ok { + // This can be called from user code, with an arbitrary value for method. 
+ return nil, jsonrpc2.ErrNotHandled + } + return info.handleMethod(ctx, method, req) +} + +func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Request) (Result, error) { + info, err := checkRequest(jreq, session.receivingMethodInfos()) + if err != nil { + return nil, err + } + params, err := info.unmarshalParams(jreq.Params) + if err != nil { + return nil, fmt.Errorf("handling '%s': %w", jreq.Method, err) + } + + mh := session.receivingMethodHandler() + re, _ := jreq.Extra.(*RequestExtra) + req := info.newRequest(session, params, re) + // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. + res, err := mh(ctx, jreq.Method, req) + if err != nil { + return nil, err + } + return res, nil +} + +// checkRequest checks the given request against the provided method info, to +// ensure it is a valid MCP request. +// +// If valid, the relevant method info is returned. Otherwise, a non-nil error +// is returned describing why the request is invalid. +// +// This is extracted from request handling so that it can be called in the +// transport layer to preemptively reject bad requests. +func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo, error) { + info, ok := infos[req.Method] + if !ok { + return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method) + } + if info.flags¬ification != 0 && req.IsCall() { + return methodInfo{}, fmt.Errorf("%w: unexpected id for %q", jsonrpc2.ErrInvalidRequest, req.Method) + } + if info.flags¬ification == 0 && !req.IsCall() { + return methodInfo{}, fmt.Errorf("%w: missing id for %q", jsonrpc2.ErrInvalidRequest, req.Method) + } + // missingParamsOK is checked here to catch the common case where "params" is + // missing entirely. + // + // However, it's checked again after unmarshalling to catch the rare but + // possible case where "params" is JSON null (see https://go.dev/issue/33835). 
+ if info.flags&missingParamsOK == 0 && len(req.Params) == 0 { + return methodInfo{}, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } + return info, nil +} + +// methodInfo is information about sending and receiving a method. +type methodInfo struct { + // flags is a collection of flags controlling how the JSONRPC method is + // handled. See individual flag values for documentation. + flags methodFlags + // Unmarshal params from the wire into a Params struct. + // Used on the receive side. + unmarshalParams func(json.RawMessage) (Params, error) + newRequest func(Session, Params, *RequestExtra) Request + // Run the code when a call to the method is received. + // Used on the receive side. + handleMethod MethodHandler + // Create a pointer to a Result struct. + // Used on the send side. + newResult func() Result +} + +// The following definitions support converting from typed to untyped method handlers. +// Type parameter meanings: +// - S: sessions +// - P: params +// - R: results + +// A typedMethodHandler is like a MethodHandler, but with type information. 
+type ( + typedClientMethodHandler[P Params, R Result] func(context.Context, *ClientRequest[P]) (R, error) + typedServerMethodHandler[P Params, R Result] func(context.Context, *ServerRequest[P]) (R, error) +) + +type paramsPtr[T any] interface { + *T + Params +} + +type methodFlags int + +const ( + notification methodFlags = 1 << iota // method is a notification, not request + missingParamsOK // params may be missing or null +) + +func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo { + mi := newMethodInfo[P, R](flags) + mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request { + r := &ClientRequest[P]{Session: s.(*ClientSession)} + if p != nil { + r.Params = p.(P) + } + return r + } + mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { + return d(ctx, req.(*ClientRequest[P])) + }) + return mi +} + +func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo { + mi := newMethodInfo[P, R](flags) + mi.newRequest = func(s Session, p Params, re *RequestExtra) Request { + r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: re} + if p != nil { + r.Params = p.(P) + } + return r + } + mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { + return d(ctx, req.(*ServerRequest[P])) + }) + return mi +} + +// newMethodInfo creates a methodInfo from a typedMethodHandler. +// +// If isRequest is set, the method is treated as a request rather than a +// notification. 
+func newMethodInfo[P paramsPtr[T], R Result, T any](flags methodFlags) methodInfo { + return methodInfo{ + flags: flags, + unmarshalParams: func(m json.RawMessage) (Params, error) { + var p P + if m != nil { + if err := json.Unmarshal(m, &p); err != nil { + return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) + } + } + // We must check missingParamsOK here, in addition to checkRequest, to + // catch the edge cases where "params" is set to JSON null. + // See also https://go.dev/issue/33835. + // + // We need to ensure that p is non-null to guard against crashes, as our + // internal code or externally provided handlers may assume that params + // is non-null. + if flags&missingParamsOK == 0 && p == nil { + return nil, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } + return orZero[Params](p), nil + }, + // newResult is used on the send side, to construct the value to unmarshal the result into. + // R is a pointer to a result struct. There is no way to "unpointer" it without reflection. + // TODO(jba): explore generic approaches to this, perhaps by treating R in + // the signature as the unpointered type. + newResult: func() Result { return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) }, + } +} + +// serverMethod is glue for creating a typedMethodHandler from a method on Server. +func serverMethod[P Params, R Result]( + f func(*Server, context.Context, *ServerRequest[P]) (R, error), +) typedServerMethodHandler[P, R] { + return func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return f(req.Session.server, ctx, req) + } +} + +// clientMethod is glue for creating a typedMethodHandler from a method on Client. 
+func clientMethod[P Params, R Result]( + f func(*Client, context.Context, *ClientRequest[P]) (R, error), +) typedClientMethodHandler[P, R] { + return func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return f(req.Session.client, ctx, req) + } +} + +// serverSessionMethod is glue for creating a typedServerMethodHandler from a method on ServerSession. +func serverSessionMethod[P Params, R Result](f func(*ServerSession, context.Context, P) (R, error)) typedServerMethodHandler[P, R] { + return func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return f(req.GetSession().(*ServerSession), ctx, req.Params) + } +} + +// clientSessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. +func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Context, P) (R, error)) typedClientMethodHandler[P, R] { + return func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return f(req.GetSession().(*ClientSession), ctx, req.Params) + } +} + +// MCP-specific error codes. +const ( + // CodeResourceNotFound indicates that a requested resource could not be found. + CodeResourceNotFound = -32002 + // CodeURLElicitationRequired indicates that the server requires URL elicitation + // before processing the request. The client should execute the elicitation handler + // with the elicitations provided in the error data. + CodeURLElicitationRequired = -32042 +) + +// URLElicitationRequiredError returns an error indicating that URL elicitation is required +// before the request can be processed. The elicitations parameter should contain the +// elicitation requests that must be completed. 
+func URLElicitationRequiredError(elicitations []*ElicitParams) error { + // Validate that all elicitations are URL mode + for _, elicit := range elicitations { + mode := elicit.Mode + if mode == "" { + mode = "form" // default mode + } + if mode != "url" { + panic(fmt.Sprintf("URLElicitationRequiredError requires all elicitations to be URL mode, got %q", mode)) + } + } + + data, err := json.Marshal(map[string]any{ + "elicitations": elicitations, + }) + if err != nil { + // This should never happen with valid ElicitParams + panic(fmt.Sprintf("failed to marshal elicitations: %v", err)) + } + return &jsonrpc.Error{ + Code: CodeURLElicitationRequired, + Message: "URL elicitation required", + Data: json.RawMessage(data), + } +} + +// Internal error codes +const ( + // The error code if the method exists and was called properly, but the peer does not support it. + // + // TODO(rfindley): this code is wrong, and we should fix it to be + // consistent with other SDKs. + codeUnsupportedMethod = -31001 +) + +// notifySessions calls Notify on all the sessions. +// Should be called on a copy of the peer sessions. +// The logger must be non-nil. +func notifySessions[S Session, P Params](sessions []S, method string, params P, logger *slog.Logger) { + if sessions == nil { + return + } + // Notify with the background context, so the messages are sent on the + // standalone stream. + // TODO: make this timeout configurable, or call handleNotify asynchronously. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // TODO: there's a potential spec violation here, when the feature list + // changes before the session (client or server) is initialized. 
+ for _, s := range sessions { + req := newRequest(s, params) + if err := handleNotify(ctx, method, req); err != nil { + logger.Warn(fmt.Sprintf("calling %s: %v", method, err)) + } + } +} + +func newRequest[S Session, P Params](s S, p P) Request { + switch s := any(s).(type) { + case *ClientSession: + return &ClientRequest[P]{Session: s, Params: p} + case *ServerSession: + return &ServerRequest[P]{Session: s, Params: p} + default: + panic("bad session") + } +} + +// Meta is additional metadata for requests, responses and other types. +type Meta map[string]any + +// GetMeta returns metadata from a value. +func (m Meta) GetMeta() map[string]any { return m } + +// SetMeta sets the metadata on a value. +func (m *Meta) SetMeta(x map[string]any) { *m = x } + +const progressTokenKey = "progressToken" + +func getProgressToken(p Params) any { + return p.GetMeta()[progressTokenKey] +} + +func setProgressToken(p Params, pt any) { + switch pt.(type) { + // Support int32 and int64 for atomic.IntNN. + case int, int32, int64, string: + default: + panic(fmt.Sprintf("progress token %v is of type %[1]T, not int or string", pt)) + } + m := p.GetMeta() + if m == nil { + m = map[string]any{} + } + m[progressTokenKey] = pt +} + +// A Request is a method request with parameters and additional information, such as the session. +// Request is implemented by [*ClientRequest] and [*ServerRequest]. +type Request interface { + isRequest() + GetSession() Session + GetParams() Params + // GetExtra returns the Extra field for ServerRequests, and nil for ClientRequests. + GetExtra() *RequestExtra +} + +// A ClientRequest is a request to a client. +type ClientRequest[P Params] struct { + Session *ClientSession + Params P +} + +// A ServerRequest is a request to a server. +type ServerRequest[P Params] struct { + Session *ServerSession + Params P + Extra *RequestExtra +} + +// RequestExtra is extra information included in requests, typically from +// the transport layer. 
+type RequestExtra struct { + TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any + Header http.Header // header from HTTP request, if any + + // If set, CloseSSEStream explicitly closes the current SSE request stream. + // + // [SEP-1699] introduced server-side SSE stream disconnection: for + // long-running requests, servers may opt to close the SSE stream and + // ask the client to retry at a later time. CloseSSEStream implements this + // feature; if RetryAfter is set, an event is sent with a `retry:` field + // to configure the reconnection delay. + // + // [SEP-1699]: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1699 + CloseSSEStream func(CloseSSEStreamArgs) +} + +// CloseSSEStreamArgs are arguments for [RequestExtra.CloseSSEStream]. +type CloseSSEStreamArgs struct { + // RetryAfter configures the reconnection delay sent to the client via the + // SSE retry field. If zero, no retry field is sent. + RetryAfter time.Duration +} + +func (*ClientRequest[P]) isRequest() {} +func (*ServerRequest[P]) isRequest() {} + +func (r *ClientRequest[P]) GetSession() Session { return r.Session } +func (r *ServerRequest[P]) GetSession() Session { return r.Session } + +func (r *ClientRequest[P]) GetParams() Params { return r.Params } +func (r *ServerRequest[P]) GetParams() Params { return r.Params } + +func (r *ClientRequest[P]) GetExtra() *RequestExtra { return nil } +func (r *ServerRequest[P]) GetExtra() *RequestExtra { return r.Extra } + +func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] { + return &ServerRequest[P]{Session: s, Params: p} +} + +func clientRequestFor[P Params](s *ClientSession, p P) *ClientRequest[P] { + return &ClientRequest[P]{Session: s, Params: p} +} + +// Params is a parameter (input) type for an MCP call or notification. +type Params interface { + // GetMeta returns metadata from a value. + GetMeta() map[string]any + // SetMeta sets the metadata on a value. 
+ SetMeta(map[string]any) + + // isParams discourages implementation of Params outside of this package. + isParams() +} + +// RequestParams is a parameter (input) type for an MCP request. +type RequestParams interface { + Params + + // GetProgressToken returns the progress token from the params' Meta field, or nil + // if there is none. + GetProgressToken() any + + // SetProgressToken sets the given progress token into the params' Meta field. + // It panics if its argument is not an int or a string. + SetProgressToken(any) +} + +// Result is a result of an MCP call. +type Result interface { + // isResult discourages implementation of Result outside of this package. + isResult() + + // GetMeta returns metadata from a value. + GetMeta() map[string]any + // SetMeta sets the metadata on a value. + SetMeta(map[string]any) +} + +// emptyResult is returned by methods that have no result, like ping. +// Those methods cannot return nil, because jsonrpc2 cannot handle nils. +type emptyResult struct{} + +func (*emptyResult) isResult() {} +func (*emptyResult) GetMeta() map[string]any { panic("should never be called") } +func (*emptyResult) SetMeta(map[string]any) { panic("should never be called") } + +type listParams interface { + // Returns a pointer to the param's Cursor field. + cursorPtr() *string +} + +type listResult[T any] interface { + // Returns a pointer to the param's NextCursor field. + nextCursorPtr() *string +} + +// keepaliveSession represents a session that supports keepalive functionality. +type keepaliveSession interface { + Ping(ctx context.Context, params *PingParams) error + Close() error +} + +// startKeepalive starts the keepalive mechanism for a session. +// It assigns the cancel function to the provided cancelPtr and starts a goroutine +// that sends ping messages at the specified interval. 
+func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + // Assign cancel function before starting goroutine to avoid race condition. + // We cannot return it because the caller may need to cancel during the + // window between goroutine scheduling and function return. + *cancelPtr = cancel + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2) + err := session.Ping(pingCtx, nil) + pingCancel() + if err != nil { + // Ping failed, close the session + _ = session.Close() + return + } + } + } + }() +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go new file mode 100644 index 000000000..7f644918b --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/sse.go @@ -0,0 +1,479 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "sync" + + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +// This file implements support for SSE (HTTP with server-sent events) +// transport server and client. +// https://modelcontextprotocol.io/specification/2024-11-05/basic/transports +// +// The transport is simple, at least relative to the new streamable transport +// introduced in the 2025-03-26 version of the spec. In short: +// +// 1. Sessions are initiated via a hanging GET request, which streams +// server->client messages as SSE 'message' events. +// 2. 
The first event in the SSE stream must be an 'endpoint' event that +// informs the client of the session endpoint. +// 3. The client POSTs client->server messages to the session endpoint. +// +// Therefore, the each new GET request hands off its responsewriter to an +// [SSEServerTransport] type that abstracts the transport as follows: +// - Write writes a new event to the responseWriter, or fails if the GET has +// exited. +// - Read reads off a message queue that is pushed to via POST requests. +// - Close causes the hanging GET to exit. + +// SSEHandler is an http.Handler that serves SSE-based MCP sessions as defined by +// the [2024-11-05 version] of the MCP spec. +// +// [2024-11-05 version]: https://modelcontextprotocol.io/specification/2024-11-05/basic/transports +type SSEHandler struct { + getServer func(request *http.Request) *Server + opts SSEOptions + onConnection func(*ServerSession) // for testing; must not block + + mu sync.Mutex + sessions map[string]*SSEServerTransport +} + +// SSEOptions specifies options for an [SSEHandler]. +// for now, it is empty, but may be extended in future. +// https://github.com/modelcontextprotocol/go-sdk/issues/507 +type SSEOptions struct{} + +// NewSSEHandler returns a new [SSEHandler] that creates and manages MCP +// sessions created via incoming HTTP requests. +// +// Sessions are created when the client issues a GET request to the server, +// which must accept text/event-stream responses (server-sent events). +// For each such request, a new [SSEServerTransport] is created with a distinct +// messages endpoint, and connected to the server returned by getServer. +// The SSEHandler also handles requests to the message endpoints, by +// delegating them to the relevant server transport. +// +// The getServer function may return a distinct [Server] for each new +// request, or reuse an existing server. If it returns nil, the handler +// will return a 400 Bad Request. 
+func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptions) *SSEHandler { + s := &SSEHandler{ + getServer: getServer, + sessions: make(map[string]*SSEServerTransport), + } + + if opts != nil { + s.opts = *opts + } + + return s +} + +// A SSEServerTransport is a logical SSE session created through a hanging GET +// request. +// +// Use [SSEServerTransport.Connect] to initiate the flow of messages. +// +// When connected, it returns the following [Connection] implementation: +// - Writes are SSE 'message' events to the GET response. +// - Reads are received from POSTs to the session endpoint, via +// [SSEServerTransport.ServeHTTP]. +// - Close terminates the hanging GET. +// +// The transport is itself an [http.Handler]. It is the caller's responsibility +// to ensure that the resulting transport serves HTTP requests on the given +// session endpoint. +// +// Each SSEServerTransport may be connected (via [Server.Connect]) at most +// once, since [SSEServerTransport.ServeHTTP] serves messages to the connected +// session. +// +// Most callers should instead use an [SSEHandler], which transparently handles +// the delegation to SSEServerTransports. +type SSEServerTransport struct { + // Endpoint is the endpoint for this session, where the client can POST + // messages. + Endpoint string + + // Response is the hanging response body to the incoming GET request. + Response http.ResponseWriter + + // incoming is the queue of incoming messages. + // It is never closed, and by convention, incoming is non-nil if and only if + // the transport is connected. + incoming chan jsonrpc.Message + + // We must guard both pushes to the incoming queue and writes to the response + // writer, because incoming POST requests are arbitrarily concurrent and we + // need to ensure we don't write push to the queue, or write to the + // ResponseWriter, after the session GET request exits. 
+ mu sync.Mutex // also guards writes to Response + closed bool // set when the stream is closed + done chan struct{} // closed when the connection is closed +} + +// ServeHTTP handles POST requests to the transport endpoint. +func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if t.incoming == nil { + http.Error(w, "session not connected", http.StatusInternalServerError) + return + } + + // Read and parse the message. + data, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + // Optionally, we could just push the data onto a channel, and let the + // message fail to parse when it is read. This failure seems a bit more + // useful + msg, err := jsonrpc2.DecodeMessage(data) + if err != nil { + http.Error(w, "failed to parse body", http.StatusBadRequest) + return + } + if req, ok := msg.(*jsonrpc.Request); ok { + if _, err := checkRequest(req, serverMethodInfos); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } + select { + case t.incoming <- msg: + w.WriteHeader(http.StatusAccepted) + case <-t.done: + http.Error(w, "session closed", http.StatusBadRequest) + } +} + +// Connect sends the 'endpoint' event to the client. +// See [SSEServerTransport] for more details on the [Connection] implementation. +func (t *SSEServerTransport) Connect(context.Context) (Connection, error) { + if t.incoming != nil { + return nil, fmt.Errorf("already connected") + } + t.incoming = make(chan jsonrpc.Message, 100) + t.done = make(chan struct{}) + _, err := writeEvent(t.Response, Event{ + Name: "endpoint", + Data: []byte(t.Endpoint), + }) + if err != nil { + return nil, err + } + return &sseServerConn{t: t}, nil +} + +func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + sessionID := req.URL.Query().Get("sessionid") + + // TODO: consider checking Content-Type here. For now, we are lax. 
+ + // For POST requests, the message body is a message to send to a session. + if req.Method == http.MethodPost { + // Look up the session. + if sessionID == "" { + http.Error(w, "sessionid must be provided", http.StatusBadRequest) + return + } + h.mu.Lock() + session := h.sessions[sessionID] + h.mu.Unlock() + if session == nil { + http.Error(w, "session not found", http.StatusNotFound) + return + } + + session.ServeHTTP(w, req) + return + } + + if req.Method != http.MethodGet { + http.Error(w, "invalid method", http.StatusMethodNotAllowed) + return + } + + // GET requests create a new session, and serve messages over SSE. + + // TODO: it's not entirely documented whether we should check Accept here. + // Let's again be lax and assume the client will accept SSE. + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + sessionID = randText() + endpoint, err := req.URL.Parse("?sessionid=" + sessionID) + if err != nil { + http.Error(w, "internal error: failed to create endpoint", http.StatusInternalServerError) + return + } + + transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w} + + // The session is terminated when the request exits. + h.mu.Lock() + h.sessions[sessionID] = transport + h.mu.Unlock() + defer func() { + h.mu.Lock() + delete(h.sessions, sessionID) + h.mu.Unlock() + }() + + server := h.getServer(req) + if server == nil { + // The getServer argument to NewSSEHandler returned nil. 
+ http.Error(w, "no server available", http.StatusBadRequest) + return + } + ss, err := server.Connect(req.Context(), transport, nil) + if err != nil { + http.Error(w, "connection failed", http.StatusInternalServerError) + return + } + if h.onConnection != nil { + h.onConnection(ss) + } + defer ss.Close() // close the transport when the GET exits + + select { + case <-req.Context().Done(): + case <-transport.done: + } +} + +// sseServerConn implements the [Connection] interface for a single [SSEServerTransport]. +// It hides the Connection interface from the SSEServerTransport API. +type sseServerConn struct { + t *SSEServerTransport +} + +// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) +func (s *sseServerConn) SessionID() string { return "" } + +// Read implements jsonrpc2.Reader. +func (s *sseServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case msg := <-s.t.incoming: + return msg, nil + case <-s.t.done: + return nil, io.EOF + } +} + +// Write implements jsonrpc2.Writer. +func (s *sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if ctx.Err() != nil { + return ctx.Err() + } + + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return err + } + + s.t.mu.Lock() + defer s.t.mu.Unlock() + + // Note that it is invalid to write to a ResponseWriter after ServeHTTP has + // exited, and so we must lock around this write and check isDone, which is + // set before the hanging GET exits. + if s.t.closed { + return io.EOF + } + + _, err = writeEvent(s.t.Response, Event{Name: "message", Data: data}) + return err +} + +// Close implements io.Closer, and closes the session. +// +// It must be safe to call Close more than once, as the close may +// asynchronously be initiated by either the server closing its connection, or +// by the hanging GET exiting. 
+func (s *sseServerConn) Close() error { + s.t.mu.Lock() + defer s.t.mu.Unlock() + if !s.t.closed { + s.t.closed = true + close(s.t.done) + } + return nil +} + +// An SSEClientTransport is a [Transport] that can communicate with an MCP +// endpoint serving the SSE transport defined by the 2024-11-05 version of the +// spec. +// +// https://modelcontextprotocol.io/specification/2024-11-05/basic/transports +type SSEClientTransport struct { + // Endpoint is the SSE endpoint to connect to. + Endpoint string + + // HTTPClient is the client to use for making HTTP requests. If nil, + // http.DefaultClient is used. + HTTPClient *http.Client +} + +// Connect connects through the client endpoint. +func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { + parsedURL, err := url.Parse(c.Endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint: %v", err) + } + req, err := http.NewRequestWithContext(ctx, "GET", c.Endpoint, nil) + if err != nil { + return nil, err + } + httpClient := c.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + req.Header.Set("Accept", "text/event-stream") + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + + msgEndpoint, err := func() (*url.URL, error) { + var evt Event + for evt, err = range scanEvents(resp.Body) { + break + } + if err != nil { + return nil, err + } + if evt.Name != "endpoint" { + return nil, fmt.Errorf("first event is %q, want %q", evt.Name, "endpoint") + } + raw := string(evt.Data) + return parsedURL.Parse(raw) + }() + if err != nil { + resp.Body.Close() + return nil, fmt.Errorf("missing endpoint: %v", err) + } + + // From here on, the stream takes ownership of resp.Body. 
+ s := &sseClientConn{ + client: httpClient, + msgEndpoint: msgEndpoint, + incoming: make(chan []byte, 100), + body: resp.Body, + done: make(chan struct{}), + } + + go func() { + defer s.Close() // close the transport when the GET exits + + for evt, err := range scanEvents(resp.Body) { + if err != nil { + return + } + select { + case s.incoming <- evt.Data: + case <-s.done: + return + } + } + }() + + return s, nil +} + +// An sseClientConn is a logical jsonrpc2 connection that implements the client +// half of the SSE protocol: +// - Writes are POSTS to the session endpoint. +// - Reads are SSE 'message' events, and pushes them onto a buffered channel. +// - Close terminates the GET request. +type sseClientConn struct { + client *http.Client // HTTP client to use for requests + msgEndpoint *url.URL // session endpoint for POSTs + incoming chan []byte // queue of incoming messages + + mu sync.Mutex + body io.ReadCloser // body of the hanging GET + closed bool // set when the stream is closed + done chan struct{} // closed when the stream is closed +} + +// TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) +func (c *sseClientConn) SessionID() string { return "" } + +func (c *sseClientConn) isDone() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case <-c.done: + return nil, io.EOF + + case data := <-c.incoming: + // TODO(rfindley): do we really need to check this? We receive from c.done above. 
+ if c.isDone() { + return nil, io.EOF + } + msg, err := jsonrpc2.DecodeMessage(data) + if err != nil { + return nil, err + } + return msg, nil + } +} + +func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return err + } + if c.isDone() { + return io.EOF + } + req, err := http.NewRequestWithContext(ctx, "POST", c.msgEndpoint.String(), bytes.NewReader(data)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("failed to write: %s", resp.Status) + } + return nil +} + +func (c *sseClientConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.closed { + c.closed = true + _ = c.body.Close() + close(c.done) + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go new file mode 100644 index 000000000..b4b2fa310 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go @@ -0,0 +1,2040 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// NOTE: see streamable_server.go and streamable_client.go for detailed +// documentation of the streamable server design. +// TODO: move the client and server logic into those files. 
+ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "maps" + "math" + "math/rand/v2" + "net/http" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/xcontext" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const ( + protocolVersionHeader = "Mcp-Protocol-Version" + sessionIDHeader = "Mcp-Session-Id" + lastEventIDHeader = "Last-Event-ID" +) + +// A StreamableHTTPHandler is an http.Handler that serves streamable MCP +// sessions, as defined by the [MCP spec]. +// +// [MCP spec]: https://modelcontextprotocol.io/2025/03/26/streamable-http-transport.html +type StreamableHTTPHandler struct { + getServer func(*http.Request) *Server + opts StreamableHTTPOptions + + onTransportDeletion func(sessionID string) // for testing + + mu sync.Mutex + sessions map[string]*sessionInfo // keyed by session ID +} + +type sessionInfo struct { + session *ServerSession + transport *StreamableServerTransport + // userID is the user ID from the TokenInfo when the session was created. + // If non-empty, subsequent requests must have the same user ID to prevent + // session hijacking. + userID string + + // If timeout is set, automatically close the session after an idle period. + timeout time.Duration + timerMu sync.Mutex + refs int // reference count + timer *time.Timer +} + +// startPOST signals that a POST request for this session is starting (which +// carries a client->server message), pausing the session timeout if it was +// running. +// +// TODO: we may want to also pause the timer when resuming non-standalone SSE +// streams, but that is tricy to implement. Clients should generally make +// keepalive pings if they want to keep the session live. 
+func (i *sessionInfo) startPOST() {
+	if i.timeout <= 0 {
+		return
+	}
+
+	i.timerMu.Lock()
+	defer i.timerMu.Unlock()
+
+	if i.timer == nil {
+		return // timer stopped permanently
+	}
+	if i.refs == 0 {
+		i.timer.Stop()
+	}
+	i.refs++
+}
+
+// endPOST signals that a request for this session is ending, starting the
+// timeout if there are no other requests running.
+func (i *sessionInfo) endPOST() {
+	if i.timeout <= 0 {
+		return
+	}
+
+	i.timerMu.Lock()
+	defer i.timerMu.Unlock()
+
+	if i.timer == nil {
+		return // timer stopped permanently
+	}
+
+	i.refs--
+	assert(i.refs >= 0, "negative ref count")
+	if i.refs == 0 {
+		i.timer.Reset(i.timeout)
+	}
+}
+
+// stopTimer stops the inactivity timer permanently.
+func (i *sessionInfo) stopTimer() {
+	i.timerMu.Lock()
+	defer i.timerMu.Unlock()
+	if i.timer != nil {
+		i.timer.Stop()
+		i.timer = nil
+	}
+}
+
+// StreamableHTTPOptions configures the StreamableHTTPHandler.
+type StreamableHTTPOptions struct {
+	// Stateless controls whether the session is 'stateless'.
+	//
+	// A stateless server does not validate the Mcp-Session-Id header, and uses a
+	// temporary session with default initialization parameters. Any
+	// server->client request is rejected immediately as there's no way for the
+	// client to respond. Server->Client notifications may reach the client if
+	// they are made in the context of an incoming request, as described in the
+	// documentation for [StreamableServerTransport].
+	Stateless bool
+
+	// TODO(#148): support session retention (?)
+
+	// JSONResponse causes streamable responses to return application/json rather
+	// than text/event-stream ([§2.1.5] of the spec).
+	//
+	// [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
+	JSONResponse bool
+
+	// Logger specifies the logger to use.
+	// If nil, do not log.
+	Logger *slog.Logger
+
+	// EventStore enables stream resumption.
+ // + // If set, EventStore will be used to persist stream events and replay them + // upon stream resumption. + EventStore EventStore + + // SessionTimeout configures a timeout for idle sessions. + // + // When sessions receive no new HTTP requests from the client for this + // duration, they are automatically closed. + // + // If SessionTimeout is the zero value, idle sessions are never closed. + SessionTimeout time.Duration +} + +// NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. +// +// The getServer function is used to create or look up servers for new +// sessions. It is OK for getServer to return the same server multiple times. +// If getServer returns nil, a 400 Bad Request will be served. +func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler { + h := &StreamableHTTPHandler{ + getServer: getServer, + sessions: make(map[string]*sessionInfo), + } + if opts != nil { + h.opts = *opts + } + + if h.opts.Logger == nil { // ensure we have a logger + h.opts.Logger = ensureLogger(nil) + } + + return h +} + +// closeAll closes all ongoing sessions, for tests. +// +// TODO(rfindley): investigate the best API for callers to configure their +// session lifecycle. (?) +// +// Should we allow passing in a session store? That would allow the handler to +// be stateless. +func (h *StreamableHTTPHandler) closeAll() { + // TODO: if we ever expose this outside of tests, we'll need to do better + // than simply collecting sessions while holding the lock: we need to prevent + // new sessions from being added. + // + // Currently, sessions remove themselves from h.sessions when closed, so we + // can't call Close while holding the lock. 
+	h.mu.Lock()
+	sessionInfos := slices.Collect(maps.Values(h.sessions))
+	h.sessions = nil
+	h.mu.Unlock()
+	for _, s := range sessionInfos {
+		s.session.Close()
+	}
+}
+
+func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	// Allow multiple 'Accept' headers.
+	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Accept#syntax
+	accept := strings.Split(strings.Join(req.Header.Values("Accept"), ","), ",")
+	var jsonOK, streamOK bool
+	for _, c := range accept {
+		switch strings.TrimSpace(c) {
+		case "application/json", "application/*":
+			jsonOK = true
+		case "text/event-stream", "text/*":
+			streamOK = true
+		case "*/*":
+			jsonOK = true
+			streamOK = true
+		}
+	}
+
+	if req.Method == http.MethodGet {
+		if !streamOK {
+			http.Error(w, "Accept must contain 'text/event-stream' for GET requests", http.StatusBadRequest)
+			return
+		}
+	} else if (!jsonOK || !streamOK) && req.Method != http.MethodDelete { // TODO: consolidate with handling of http method below.
+		http.Error(w, "Accept must contain both 'application/json' and 'text/event-stream'", http.StatusBadRequest)
+		return
+	}
+
+	sessionID := req.Header.Get(sessionIDHeader)
+	var sessInfo *sessionInfo
+	if sessionID != "" {
+		h.mu.Lock()
+		sessInfo = h.sessions[sessionID]
+		h.mu.Unlock()
+		if sessInfo == nil && !h.opts.Stateless {
+			// Unless we're in 'stateless' mode, which doesn't perform any Session-ID
+			// validation, we require that the session ID matches a known session.
+			//
+			// In stateless mode, a temporary transport is created below.
+			http.Error(w, "session not found", http.StatusNotFound)
+			return
+		}
+		// Prevent session hijacking: if the session was created with a user ID,
+		// verify that subsequent requests come from the same user.
+ if sessInfo != nil && sessInfo.userID != "" { + tokenInfo := auth.TokenInfoFromContext(req.Context()) + if tokenInfo == nil || tokenInfo.UserID != sessInfo.userID { + http.Error(w, "session user mismatch", http.StatusForbidden) + return + } + } + } + + if req.Method == http.MethodDelete { + if sessionID == "" { + http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) + return + } + if sessInfo != nil { // sessInfo may be nil in stateless mode + // Closing the session also removes it from h.sessions, due to the + // onClose callback. + sessInfo.session.Close() + } + w.WriteHeader(http.StatusNoContent) + return + } + + switch req.Method { + case http.MethodPost, http.MethodGet: + if req.Method == http.MethodGet && (h.opts.Stateless || sessionID == "") { + http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) + return + } + default: + w.Header().Set("Allow", "GET, POST, DELETE") + http.Error(w, "Method Not Allowed: streamable MCP servers support GET, POST, and DELETE requests", http.StatusMethodNotAllowed) + return + } + + // [§2.7] of the spec (2025-06-18) states: + // + // "If using HTTP, the client MUST include the MCP-Protocol-Version: + // HTTP header on all subsequent requests to the MCP + // server, allowing the MCP server to respond based on the MCP protocol + // version. + // + // For example: MCP-Protocol-Version: 2025-06-18 + // The protocol version sent by the client SHOULD be the one negotiated during + // initialization. + // + // For backwards compatibility, if the server does not receive an + // MCP-Protocol-Version header, and has no other way to identify the version - + // for example, by relying on the protocol version negotiated during + // initialization - the server SHOULD assume protocol version 2025-03-26. + // + // If the server receives a request with an invalid or unsupported + // MCP-Protocol-Version, it MUST respond with 400 Bad Request." 
+ // + // Since this wasn't present in the 2025-03-26 version of the spec, this + // effectively means: + // 1. IF the client provides a version header, it must be a supported + // version. + // 2. In stateless mode, where we've lost the state of the initialize + // request, we assume that whatever the client tells us is the truth (or + // assume 2025-03-26 if the client doesn't say anything). + // + // This logic matches the typescript SDK. + // + // [§2.7]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + if !slices.Contains(supportedProtocolVersions, protocolVersion) { + http.Error(w, fmt.Sprintf("Bad Request: Unsupported protocol version (supported versions: %s)", strings.Join(supportedProtocolVersions, ",")), http.StatusBadRequest) + return + } + + if sessInfo == nil { + server := h.getServer(req) + if server == nil { + // The getServer argument to NewStreamableHTTPHandler returned nil. + http.Error(w, "no server available", http.StatusBadRequest) + return + } + if sessionID == "" { + // In stateless mode, sessionID may be nonempty even if there's no + // existing transport. + sessionID = server.opts.GetSessionID() + } + transport := &StreamableServerTransport{ + SessionID: sessionID, + Stateless: h.opts.Stateless, + EventStore: h.opts.EventStore, + jsonResponse: h.opts.JSONResponse, + logger: h.opts.Logger, + } + + // Sessions without a session ID are also stateless: there's no way to + // address them. + stateless := h.opts.Stateless || sessionID == "" + // To support stateless mode, we initialize the session with a default + // state, so that it doesn't reject subsequent requests. + var connectOpts *ServerSessionOptions + if stateless { + // Peek at the body to see if it is initialize or initialized. + // We want those to be handled as usual. 
+ var hasInitialize, hasInitialized bool + { + // TODO: verify that this allows protocol version negotiation for + // stateless servers. + body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusInternalServerError) + return + } + req.Body.Close() + + // Reset the body so that it can be read later. + req.Body = io.NopCloser(bytes.NewBuffer(body)) + + msgs, _, err := readBatch(body) + if err == nil { + for _, msg := range msgs { + if req, ok := msg.(*jsonrpc.Request); ok { + switch req.Method { + case methodInitialize: + hasInitialize = true + case notificationInitialized: + hasInitialized = true + } + } + } + } + } + + // If we don't have InitializeParams or InitializedParams in the request, + // set the initial state to a default value. + state := new(ServerSessionState) + if !hasInitialize { + state.InitializeParams = &InitializeParams{ + ProtocolVersion: protocolVersion, + } + } + if !hasInitialized { + state.InitializedParams = new(InitializedParams) + } + state.LogLevel = "info" + connectOpts = &ServerSessionOptions{ + State: state, + } + } else { + // Cleanup is only required in stateful mode, as transportation is + // not stored in the map otherwise. + connectOpts = &ServerSessionOptions{ + onClose: func() { + h.mu.Lock() + defer h.mu.Unlock() + if info, ok := h.sessions[transport.SessionID]; ok { + info.stopTimer() + delete(h.sessions, transport.SessionID) + if h.onTransportDeletion != nil { + h.onTransportDeletion(transport.SessionID) + } + } + }, + } + } + + // Pass req.Context() here, to allow middleware to add context values. + // The context is detached in the jsonrpc2 library when handling the + // long-running stream. + session, err := server.Connect(req.Context(), transport, connectOpts) + if err != nil { + http.Error(w, "failed connection", http.StatusInternalServerError) + return + } + // Capture the user ID from the token info to enable session hijacking + // prevention on subsequent requests. 
+ var userID string + if tokenInfo := auth.TokenInfoFromContext(req.Context()); tokenInfo != nil { + userID = tokenInfo.UserID + } + sessInfo = &sessionInfo{ + session: session, + transport: transport, + userID: userID, + } + + if stateless { + // Stateless mode: close the session when the request exits. + defer session.Close() // close the fake session after handling the request + } else { + // Otherwise, save the transport so that it can be reused + + // Clean up the session when it times out. + // + // Note that the timer here may fire multiple times, but + // sessInfo.session.Close is idempotent. + if h.opts.SessionTimeout > 0 { + sessInfo.timeout = h.opts.SessionTimeout + sessInfo.timer = time.AfterFunc(sessInfo.timeout, func() { + sessInfo.session.Close() + }) + } + h.mu.Lock() + h.sessions[transport.SessionID] = sessInfo + h.mu.Unlock() + defer func() { + // If initialization failed, clean up the session (#578). + if session.InitializeParams() == nil { + // Initialization failed. + session.Close() + } + }() + } + } + + if req.Method == http.MethodPost { + sessInfo.startPOST() + defer sessInfo.endPOST() + } + + sessInfo.transport.ServeHTTP(w, req) +} + +// A StreamableServerTransport implements the server side of the MCP streamable +// transport. +// +// Each StreamableServerTransport must be connected (via [Server.Connect]) at +// most once, since [StreamableServerTransport.ServeHTTP] serves messages to +// the connected session. +// +// Reads from the streamable server connection receive messages from http POST +// requests from the client. Writes to the streamable server connection are +// sent either to the related stream, or to the standalone SSE stream, +// according to the following rules: +// - JSON-RPC responses to incoming requests are always routed to the +// appropriate HTTP response. 
+//   - Requests or notifications made with a context.Context value derived from
+//     an incoming request handler, are routed to the HTTP response
+//     corresponding to that request, unless it has already terminated, in
+//     which case they are routed to the standalone SSE stream.
+//   - Requests or notifications made with a detached context.Context value are
+//     routed to the standalone SSE stream.
+type StreamableServerTransport struct {
+	// SessionID is the ID of this session.
+	//
+	// If SessionID is the empty string, this is a 'stateless' session, which has
+	// limited ability to communicate with the client. Otherwise, the session ID
+	// must be globally unique, that is, different from any other session ID
+	// anywhere, past and future. (We recommend using a crypto random number
+	// generator to produce one, as with [crypto/rand.Text].)
+	SessionID string
+
+	// Stateless controls whether the transport is 'stateless'. Server sessions
+	// connected to a stateless transport are disallowed from making outgoing
+	// requests.
+	//
+	// See also [StreamableHTTPOptions.Stateless].
+	Stateless bool
+
+	// EventStore enables stream resumption.
+	//
+	// If set, EventStore will be used to persist stream events and replay them
+	// upon stream resumption.
+	EventStore EventStore
+
+	// jsonResponse, if set, tells the server to prefer to respond to requests
+	// using application/json responses rather than text/event-stream.
+	//
+	// Specifically, responses will be application/json whenever incoming POST
+	// request contain only a single message. In this case, notifications or
+	// requests made within the context of a server request will be sent to the
+	// standalone SSE stream, if any.
+	//
+	// TODO(rfindley): jsonResponse should be exported, since
+	// StreamableHTTPOptions.JSONResponse is exported, and we want to allow users
+	// to write their own streamable HTTP handler.
+ jsonResponse bool + + // optional logger provided through the [StreamableHTTPOptions.Logger]. + // + // TODO(rfindley): logger should be exported, since we want to allow users + // to write their own streamable HTTP handler. + logger *slog.Logger + + // connection is non-nil if and only if the transport has been connected. + connection *streamableServerConn +} + +// Connect implements the [Transport] interface. +func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, error) { + if t.connection != nil { + return nil, fmt.Errorf("transport already connected") + } + t.connection = &streamableServerConn{ + sessionID: t.SessionID, + stateless: t.Stateless, + eventStore: t.EventStore, + jsonResponse: t.jsonResponse, + logger: ensureLogger(t.logger), // see #556: must be non-nil + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + streams: make(map[string]*stream), + requestStreams: make(map[jsonrpc.ID]string), + } + // Stream 0 corresponds to the standalone SSE stream. + // + // It is always text/event-stream, since it must carry arbitrarily many + // messages. + var err error + t.connection.streams[""], err = t.connection.newStream(ctx, nil, "") + if err != nil { + return nil, err + } + return t.connection, nil +} + +type streamableServerConn struct { + sessionID string + stateless bool + jsonResponse bool + eventStore EventStore + + logger *slog.Logger + + incoming chan jsonrpc.Message // messages from the client to the server + + mu sync.Mutex // guards all fields below + + // Sessions are closed exactly once. + isDone bool + done chan struct{} + + // Sessions can have multiple logical connections (which we call streams), + // corresponding to HTTP requests. Additionally, streams may be resumed by + // subsequent HTTP requests, when the HTTP connection is terminated + // unexpectedly. 
+ // + // Therefore, we use a logical stream ID to key the stream state, and + // perform the accounting described below when incoming HTTP requests are + // handled. + + // streams holds the logical streams for this session, keyed by their ID. + // + // Lifecycle: streams persist until all of their responses are received from + // the server. + streams map[string]*stream + + // requestStreams maps incoming requests to their logical stream ID. + // + // Lifecycle: requestStreams persist until their response is received. + requestStreams map[jsonrpc.ID]string +} + +func (c *streamableServerConn) SessionID() string { + return c.sessionID +} + +// A stream is a single logical stream of SSE events within a server session. +// A stream begins with a client request, or with a client GET that has +// no Last-Event-ID header. +// +// A stream ends only when its session ends; we cannot determine its end otherwise, +// since a client may send a GET with a Last-Event-ID that references the stream +// at any time. +type stream struct { + // id is the logical ID for the stream, unique within a session. + // + // The standalone SSE stream has id "". + id string + + // logger is used for logging errors during stream operations. + logger *slog.Logger + + // mu guards the fields below, as well as storage of new messages in the + // connection's event store (if any). + mu sync.Mutex + + // If pendingJSONMessages is non-nil, this is a JSON stream and messages are + // collected here until the stream is complete, at which point they are + // flushed as a single JSON response. Note that the non-nilness of this field + // is significant, as it signals the expected content type. + // + // Note: if we remove support for batching, this could just be a bool. + pendingJSONMessages []json.RawMessage + + // w is the HTTP response writer for this stream. 
A non-nil w indicates + // that the stream is claimed by an HTTP request (the hanging POST or GET); + // it is set to nil when the request completes. + w http.ResponseWriter + + // done is closed to release the hanging HTTP request. + // + // Invariant: a non-nil done implies w is also non-nil, though the converse + // is not necessarily true: done is set to nil when it is closed, to avoid + // duplicate closure. + done chan struct{} + + // lastIdx is the index of the last written SSE event, for event ID generation. + // It starts at -1 since indices start at 0. + lastIdx int + + // protocolVersion is the protocol version for this stream. + protocolVersion string + + // requests is the set of unanswered incoming requests for the stream. + // + // Requests are removed when their response has been received. + // In practice, there is only one request, but in the 2025-03-26 version of + // the spec and earlier there was a concept of batching, in which POST + // payloads could hold multiple requests or responses. + requests map[jsonrpc.ID]struct{} +} + +// close sends a 'close' event to the client (if protocolVersion >= 2025-11-25 +// and reconnectAfter > 0) and closes the done channel. +// +// The done channel is set to nil after closing, so that done != nil implies +// the stream is active and done is open. This simplifies checks elsewhere. 
+func (s *stream) close(reconnectAfter time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done == nil { + return // stream not connected or already closed + } + if s.protocolVersion >= protocolVersion20251125 && reconnectAfter > 0 { + reconnectStr := strconv.FormatInt(reconnectAfter.Milliseconds(), 10) + if _, err := writeEvent(s.w, Event{ + Name: "close", + Retry: reconnectStr, + }); err != nil { + s.logger.Warn(fmt.Sprintf("Writing close event: %v", err)) + } + } + close(s.done) + s.done = nil +} + +// release releases the stream from its HTTP request, allowing it to be +// claimed by another request (e.g., for resumption). +func (s *stream) release() { + s.mu.Lock() + defer s.mu.Unlock() + s.w = nil + s.done = nil // may already be nil, if the stream is done or closed +} + +// deliverLocked writes data to the stream (for SSE) or stores it in +// pendingJSONMessages (for JSON mode). The eventID is used for SSE event ID; +// pass "" to omit. +// +// If responseTo is valid, it is removed from the requests map. When all +// requests have been responded to, the done channel is closed and set to nil. +// +// Returns true if the stream is now done (all requests have been responded to). +// The done value is always accurate, even if an error is returned. +// +// s.mu must be held when calling this method. +func (s *stream) deliverLocked(data []byte, eventID string, responseTo jsonrpc.ID) (done bool, err error) { + // First, record the response. We must do this *before* returning an error + // below, as even if the stream is disconnected we want to update our + // accounting. + if responseTo.IsValid() { + delete(s.requests, responseTo) + } + // Now, try to deliver the message to the client. + done = len(s.requests) == 0 && s.id != "" + if s.done == nil { + return done, fmt.Errorf("stream not connected or already closed") + } + if done { + defer func() { close(s.done); s.done = nil }() + } + // Try to write to the response. 
+ // + // If we get here, the request is still hanging (because s.done != nil + // implies s.w != nil), but may have been cancelled by the client/http layer: + // there's a brief race between request cancellation and releasing the + // stream. + if s.pendingJSONMessages != nil { + s.pendingJSONMessages = append(s.pendingJSONMessages, data) + if done { + // Flush all pending messages as JSON response. + var toWrite []byte + if len(s.pendingJSONMessages) == 1 { + toWrite = s.pendingJSONMessages[0] + } else { + toWrite, err = json.Marshal(s.pendingJSONMessages) + if err != nil { + return done, err + } + } + if _, err := s.w.Write(toWrite); err != nil { + return done, err + } + } + } else { + // SSE mode: write event to response writer. + s.lastIdx++ + if _, err := writeEvent(s.w, Event{Name: "message", Data: data, ID: eventID}); err != nil { + return done, err + } + } + return done, nil +} + +// doneLocked reports whether the stream is logically complete. +// +// s.requests was populated when reading the POST body, requests are deleted as +// they are responded to. Once all requests have been responded to, the stream +// is done. +// +// s.mu must be held while calling this function. 
+func (s *stream) doneLocked() bool { + return len(s.requests) == 0 && s.id != "" +} + +func (c *streamableServerConn) newStream(ctx context.Context, requests map[jsonrpc.ID]struct{}, id string) (*stream, error) { + if c.eventStore != nil { + if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { + return nil, err + } + } + return &stream{ + id: id, + requests: requests, + lastIdx: -1, // indices start at 0, incremented before each write + logger: c.logger, + }, nil +} + +// We track the incoming request ID inside the handler context using +// idContextValue, so that notifications and server->client calls that occur in +// the course of handling incoming requests are correlated with the incoming +// request that caused them, and can be dispatched as server-sent events to the +// correct HTTP request. +// +// Currently, this is implemented in [ServerSession.handle]. This is not ideal, +// because it means that a user of the MCP package couldn't implement the +// streamable transport, as they'd lack this privileged access. +// +// If we ever wanted to expose this mechanism, we have a few options: +// 1. Make ServerSession an interface, and provide an implementation of +// ServerSession to handlers that closes over the incoming request ID. +// 2. Expose a 'HandlerTransport' interface that allows transports to provide +// a handler middleware, so that we don't hard-code this behavior in +// ServerSession.handle. +// 3. Add a `func ForRequest(context.Context) jsonrpc.ID` accessor that lets +// any transport access the incoming request ID. +// +// For now, by giving only the StreamableServerTransport access to the request +// ID, we avoid having to make this API decision. +type idContextKey struct{} + +// ServeHTTP handles a single HTTP request for the session. 
+func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if t.connection == nil { + http.Error(w, "transport not connected", http.StatusInternalServerError) + return + } + switch req.Method { + case http.MethodGet: + t.connection.serveGET(w, req) + case http.MethodPost: + t.connection.servePOST(w, req) + default: + // Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP. + w.Header().Set("Allow", "GET, POST") + http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + return + } +} + +// serveGET streams messages to a hanging http GET, with stream ID and last +// message parsed from the Last-Event-ID header. +// +// It returns an HTTP status code and error message. +func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) { + // streamID "" corresponds to the default GET request. + streamID := "" + // By default, we haven't seen a last index. Since indices start at 0, we represent + // that by -1. This is incremented just before each event is written. + lastIdx := -1 + if len(req.Header.Values(lastEventIDHeader)) > 0 { + eid := req.Header.Get(lastEventIDHeader) + var ok bool + streamID, lastIdx, ok = parseEventID(eid) + if !ok { + http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest) + return + } + if c.eventStore == nil { + http.Error(w, "stream replay unsupported", http.StatusBadRequest) + return + } + } + + ctx := req.Context() + + // Read the protocol version from the header. For GET requests, this should + // always be present since GET only happens after initialization. 
+ protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + + stream, done := c.acquireStream(ctx, w, streamID, lastIdx, protocolVersion) + if stream == nil { + return + } + defer stream.release() + c.hangResponse(ctx, done) +} + +// hangResponse blocks the HTTP response until one of three conditions is met: +// - ctx is cancelled (the client disconnected or the request timed out) +// - done is closed (all responses have been sent, or the stream was explicitly closed) +// - the session is closed +// +// This keeps the HTTP connection open so that server-sent events can be +// written to the response. +func (c *streamableServerConn) hangResponse(ctx context.Context, done <-chan struct{}) { + select { + case <-ctx.Done(): + case <-done: + case <-c.done: + } +} + +// acquireStream replays all events since lastIdx, and acquires the ongoing +// stream, if any. If non-nil, the resulting stream will be registered for +// receiving new messages, and the stream's done channel will be closed when +// all related messages have been delivered. +// +// If any errors occur, they will be written to w and the resulting stream will +// be nil. The resulting stream may also be nil if the stream is complete. +// +// Importantly, this function must hold the stream mutex until done replaying +// all messages, so that no delivery or storage of new messages occurs while +// the stream is still replaying. +// +// protocolVersion is the protocol version for this stream, used to determine +// feature support (e.g. prime and close events were added in 2025-11-25). +func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx int, protocolVersion string) (*stream, chan struct{}) { + // if tempStream is set, the stream is done and we're just replaying messages. + // + // We record a temporary stream to claim exclusive replay rights. 
The spec + // (https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#resumability-and-redelivery) + // does not explicitly require exclusive replay, but we enforce it defensively. + tempStream := false + c.mu.Lock() + s, ok := c.streams[streamID] + if !ok { + // The stream is logically done, but claim exclusive rights to replay it by + // adding a temporary entry in the streams map. + // + // We create this entry with a non-nil w, to ensure it isn't claimed by + // another request before we lock it below. + tempStream = true + s = &stream{ + id: streamID, + w: w, + } + c.streams[streamID] = s + + // Since this stream is transient, we must clean up after replaying. + defer func() { + c.mu.Lock() + delete(c.streams, streamID) + c.mu.Unlock() + }() + } + c.mu.Unlock() + + s.mu.Lock() + defer s.mu.Unlock() + + // Check that this stream wasn't claimed by another request. + if !tempStream && s.w != nil { + http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict) + return nil, nil + } + + // Collect events to replay. Collect them all before writing, so that we + // have an opportunity to set the HTTP status code on an error. + // + // As indicated above, we must do that while holding stream.mu, so that no + // new messages are added to the eventstore until we've replayed all previous + // messages, and registered our delivery function. + var toReplay [][]byte + if c.eventStore != nil { + for data, err := range c.eventStore.After(ctx, c.SessionID(), s.id, lastIdx) { + if err != nil { + // We can't replay events, perhaps because the underlying event store + // has garbage collected its storage. + // + // We must be careful here: any 404 will signal to the client that the + // *session* is not found, rather than the stream. + // + // 400 is not really accurate, but should at least have no side effects. + // Other SDKs (typescript) do not have a mechanism for events to be purged. 
+ http.Error(w, "failed to replay events", http.StatusBadRequest) + return nil, nil + } + if len(data) > 0 { + toReplay = append(toReplay, data) + } + } + } + + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] + w.Header().Set("Connection", "keep-alive") + + if s.id == "" { + // Issue #410: the standalone SSE stream is likely not to receive messages + // for a long time. Ensure that headers are flushed. + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + + for _, data := range toReplay { + lastIdx++ + e := Event{Name: "message", Data: data} + if c.eventStore != nil { + e.ID = formatEventID(s.id, lastIdx) + } + if _, err := writeEvent(w, e); err != nil { + return nil, nil + } + } + + if tempStream || s.doneLocked() { + // Nothing more to do. + return nil, nil + } + + // The stream is not done: set up delivery state before the stream is + // unlocked, allowing the connection to write new events. + s.w = w + s.done = make(chan struct{}) + s.lastIdx = lastIdx + s.protocolVersion = protocolVersion + return s, s.done +} + +// servePOST handles an incoming message, and replies with either an outgoing +// message stream or single response object, depending on whether the +// jsonResponse option is set. +// +// It returns an HTTP status code and error message. +func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Request) { + if len(req.Header.Values(lastEventIDHeader)) > 0 { + http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest) + return + } + + // Read incoming messages. 
+ body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + if len(body) == 0 { + http.Error(w, "POST requires a non-empty body", http.StatusBadRequest) + return + } + // TODO(#674): once we've documented the support matrix for 2025-03-26 and + // earlier, drop support for matching entirely; that will simplify this + // logic. + incoming, isBatch, err := readBatch(body) + if err != nil { + http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) + return + } + + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + + if isBatch && protocolVersion >= protocolVersion20250618 { + http.Error(w, fmt.Sprintf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, protocolVersion), http.StatusBadRequest) + return + } + + // TODO(rfindley): no tests fail if we reject batch JSON requests entirely. + // We need to test this with older protocol versions. + // if isBatch && c.jsonResponse { + // http.Error(w, "server does not support batch requests", http.StatusBadRequest) + // return + // } + + calls := make(map[jsonrpc.ID]struct{}) + tokenInfo := auth.TokenInfoFromContext(req.Context()) + isInitialize := false + var initializeProtocolVersion string + for _, msg := range incoming { + if jreq, ok := msg.(*jsonrpc.Request); ok { + // Preemptively check that this is a valid request, so that we can fail + // the HTTP request. If we didn't do this, a request with a bad method or + // missing ID could be silently swallowed. + if _, err := checkRequest(jreq, serverMethodInfos); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if jreq.Method == methodInitialize { + isInitialize = true + // Extract the protocol version from InitializeParams. 
+ var params InitializeParams + if err := json.Unmarshal(jreq.Params, ¶ms); err == nil { + initializeProtocolVersion = params.ProtocolVersion + } + } + // Include metadata for all requests (including notifications). + jreq.Extra = &RequestExtra{ + TokenInfo: tokenInfo, + Header: req.Header, + } + if jreq.IsCall() { + calls[jreq.ID] = struct{}{} + // See the doc for CloseSSEStream: allow the request handler to + // explicitly close the ongoing stream. + jreq.Extra.(*RequestExtra).CloseSSEStream = func(args CloseSSEStreamArgs) { + c.mu.Lock() + streamID, ok := c.requestStreams[jreq.ID] + var stream *stream + if ok { + stream = c.streams[streamID] + } + c.mu.Unlock() + + if stream != nil { + stream.close(args.RetryAfter) + } + } + } + } + } + + // The prime and close events were added in protocol version 2025-11-25 (SEP-1699). + // Use the version from InitializeParams if this is an initialize request, + // otherwise use the protocol version header. + effectiveVersion := protocolVersion + if isInitialize && initializeProtocolVersion != "" { + effectiveVersion = initializeProtocolVersion + } + + // If we don't have any calls, we can just publish the incoming messages and return. + // No need to track a logical stream. + // + // See section [§2.1.4] of the spec: "If the server accepts the input, the + // server MUST return HTTP status code 202 Accepted with no body." + // + // [§2.1.4]: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server + if len(calls) == 0 { + for _, msg := range incoming { + select { + case c.incoming <- msg: + case <-c.done: + // The session is closing. Since we haven't yet written any data to the + // response, we can signal to the client that the session is gone. + http.Error(w, "session is closing", http.StatusNotFound) + return + } + } + w.WriteHeader(http.StatusAccepted) + return + } + + // Invariant: we have at least one call. + // + // Create a logical stream to track its responses. 
+ // Important: don't publish the incoming messages until the stream is + // registered, as the server may attempt to respond to imcoming messages as + // soon as they're published. + stream, err := c.newStream(req.Context(), calls, randText()) + if err != nil { + http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) + return + } + + // Set response headers. Accept was checked in [StreamableHTTPHandler]. + w.Header().Set("Cache-Control", "no-cache, no-transform") + if c.jsonResponse { + w.Header().Set("Content-Type", "application/json") + } else { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + } + if c.sessionID != "" && isInitialize { + w.Header().Set(sessionIDHeader, c.sessionID) + } + + // Set up stream delivery state. + stream.w = w + done := make(chan struct{}) + stream.done = done + stream.protocolVersion = effectiveVersion + if c.jsonResponse { + // JSON mode: collect messages in pendingJSONMessages until done. + // Set pendingJSONMessages to a non-nil value to signal that this is an + // application/json stream. + stream.pendingJSONMessages = []json.RawMessage{} + } else { + // SSE mode: write a priming event if supported. + if c.eventStore != nil && effectiveVersion >= protocolVersion20251125 { + // Write a priming event, as defined by [§2.1.6] of the spec. + // + // [§2.1.6]: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server + // + // We must also write it to the event store in order for indexes to + // align. 
+ if err := c.eventStore.Append(req.Context(), c.sessionID, stream.id, nil); err != nil { + c.logger.Warn(fmt.Sprintf("Storing priming event: %v", err)) + } + stream.lastIdx++ + e := Event{Name: "prime", ID: formatEventID(stream.id, stream.lastIdx)} + if _, err := writeEvent(w, e); err != nil { + c.logger.Warn(fmt.Sprintf("Writing priming event: %v", err)) + } + } + } + + // TODO(rfindley): if we have no event store, we should really cancel all + // remaining requests here, since the client will never get the results. + defer stream.release() + + // The stream is now set up to deliver messages. + // + // Register it before publishing incoming messages. + c.mu.Lock() + c.streams[stream.id] = stream + for reqID := range calls { + c.requestStreams[reqID] = stream.id + } + c.mu.Unlock() + + // Publish incoming messages. + for _, msg := range incoming { + select { + case c.incoming <- msg: + // Note: don't select on req.Context().Done() here, since we've already + // received the requests and may have already published a response message + // or notification. The client could resume the stream. + // + // In fact, this send could be in a separate goroutine. + case <-c.done: + // Session closed: we don't know if any data has been written, so it's + // too late to write a status code here. + return + } + } + + c.hangResponse(req.Context(), done) +} + +// Event IDs: encode both the logical connection ID and the index, as +// _, to be consistent with the typescript implementation. + +// formatEventID returns the event ID to use for the logical connection ID +// streamID and message index idx. +// +// See also [parseEventID]. +func formatEventID(sid string, idx int) string { + return fmt.Sprintf("%s_%d", sid, idx) +} + +// parseEventID parses a Last-Event-ID value into a logical stream id and +// index. +// +// See also [formatEventID]. 
+func parseEventID(eventID string) (streamID string, idx int, ok bool) {
+	parts := strings.Split(eventID, "_")
+	if len(parts) != 2 {
+		return "", 0, false
+	}
+	streamID = parts[0]
+	idx, err := strconv.Atoi(parts[1])
+	if err != nil || idx < 0 {
+		return "", 0, false
+	}
+	return streamID, idx, true
+}
+
+// Read implements the [Connection] interface.
+func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error) {
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case msg, ok := <-c.incoming:
+		if !ok {
+			return nil, io.EOF
+		}
+		return msg, nil
+	case <-c.done:
+		return nil, io.EOF
+	}
+}
+
+// Write implements the [Connection] interface.
+func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error {
+	// Throughout this function, note that any error that wraps ErrRejected
+	// indicates a rejected message and does not cause the connection to break.
+	//
+	// Most errors don't break the connection: unlike a true bidirectional
+	// stream, a failure to deliver to a stream is not an indication that the
+	// logical session is broken.
+	data, err := jsonrpc2.EncodeMessage(msg)
+	if err != nil {
+		return err
+	}
+
+	if req, ok := msg.(*jsonrpc.Request); ok && req.IsCall() && (c.stateless || c.sessionID == "") {
+		// Requests aren't possible with stateless servers, or when there's no session ID.
+		return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected)
+	}
+
+	// Find the incoming request that this write relates to, if any.
+	var (
+		relatedRequest jsonrpc.ID
+		responseTo     jsonrpc.ID // if valid, the message is a response to this request
+	)
+	if resp, ok := msg.(*jsonrpc.Response); ok {
+		// If the message is a response, it relates to its request (of course).
+		relatedRequest = resp.ID
+		responseTo = resp.ID
+	} else {
+		// Otherwise, we check to see if the request was made in the context of an
+		// ongoing request. This may not be the case if the request was made with
+		// an unrelated context.
+ if v := ctx.Value(idContextKey{}); v != nil { + relatedRequest = v.(jsonrpc.ID) + } + } + + // If the stream is application/json, but the message is not a response, we + // must send it out of band to the standalone SSE stream. + if c.jsonResponse && !responseTo.IsValid() { + relatedRequest = jsonrpc.ID{} + } + + // Write the message to the stream. + var s *stream + c.mu.Lock() + if relatedRequest.IsValid() { + if streamID, ok := c.requestStreams[relatedRequest]; ok { + s = c.streams[streamID] + } + } else { + s = c.streams[""] // standalone SSE stream + } + if responseTo.IsValid() { + // Once we've responded to a request, disallow related messages by removing + // the stream association. This also releases memory. + delete(c.requestStreams, responseTo) + } + sessionClosed := c.isDone + c.mu.Unlock() + + if s == nil { + // The request was made in the context of an ongoing request, but that + // request is complete. + // + // In the future, we could be less strict and allow the request to land on + // the standalone SSE stream. + return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected) + } + if sessionClosed { + return errors.New("session is closed") + } + + s.mu.Lock() + defer s.mu.Unlock() + + // Store in eventStore before delivering. + // TODO(rfindley): we should only append if the response is SSE, not JSON, by + // pushing down into the delivery layer. + delivered := false + var errs []error + if c.eventStore != nil { + if err := c.eventStore.Append(ctx, c.sessionID, s.id, data); err != nil { + errs = append(errs, err) + } else { + delivered = true + } + } + + // Compute eventID for SSE streams with event store. + // Use s.lastIdx + 1 because deliverLocked increments before writing. 
+ var eventID string + if c.eventStore != nil { + eventID = formatEventID(s.id, s.lastIdx+1) + } + + done, err := s.deliverLocked(data, eventID, responseTo) + if err != nil { + errs = append(errs, err) + } else { + delivered = true + } + + if done { + c.mu.Lock() + delete(c.streams, s.id) + c.mu.Unlock() + } + + if !delivered { + return fmt.Errorf("%w: undelivered message: %v", jsonrpc2.ErrRejected, errors.Join(errs...)) + } + return nil +} + +// Close implements the [Connection] interface. +func (c *streamableServerConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.isDone { + c.isDone = true + close(c.done) + if c.eventStore != nil { + // TODO: find a way to plumb a context here, or an event store with a long-running + // close operation can take arbitrary time. Alternative: impose a fixed timeout here. + return c.eventStore.SessionClosed(context.TODO(), c.sessionID) + } + } + return nil +} + +// A StreamableClientTransport is a [Transport] that can communicate with an MCP +// endpoint serving the streamable HTTP transport defined by the 2025-03-26 +// version of the spec. +type StreamableClientTransport struct { + Endpoint string + HTTPClient *http.Client + // MaxRetries is the maximum number of times to attempt a reconnect before giving up. + // It defaults to 5. To disable retries, use a negative number. + MaxRetries int + + // TODO(rfindley): propose exporting these. + // If strict is set, the transport is in 'strict mode', where any violation + // of the MCP spec causes a failure. + strict bool + // If logger is set, it is used to log aspects of the transport, such as spec + // violations that were ignored. + logger *slog.Logger +} + +// These settings are not (yet) exposed to the user in +// StreamableClientTransport. +const ( + // reconnectGrowFactor is the multiplicative factor by which the delay increases after each attempt. + // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. 
+ // It must be 1.0 or greater if MaxRetries is greater than 0. + reconnectGrowFactor = 1.5 + // reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely. + reconnectMaxDelay = 30 * time.Second +) + +var ( + // reconnectInitialDelay is the base delay for the first reconnect attempt. + // + // Mutable for testing. + reconnectInitialDelay = 1 * time.Second +) + +// Connect implements the [Transport] interface. +// +// The resulting [Connection] writes messages via POST requests to the +// transport URL with the Mcp-Session-Id header set, and reads messages from +// hanging requests. +// +// When closed, the connection issues a DELETE request to terminate the logical +// session. +func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, error) { + client := t.HTTPClient + if client == nil { + client = http.DefaultClient + } + maxRetries := t.MaxRetries + if maxRetries == 0 { + maxRetries = 5 + } else if maxRetries < 0 { + maxRetries = 0 + } + // Create a new cancellable context that will manage the connection's lifecycle. + // This is crucial for cleanly shutting down the background SSE listener by + // cancelling its blocking network operations, which prevents hangs on exit. + // + // This context should be detached from the incoming context: the standalone + // SSE request should not break when the connection context is done. + // + // For example, consider that the user may want to wait at most 5s to connect + // to the server, and therefore uses a context with a 5s timeout when calling + // client.Connect. Let's suppose that Connect returns after 1s, and the user + // starts using the resulting session. If we didn't detach here, the session + // would break after 4s, when the background SSE stream is terminated. 
+ // + // Instead, creating a cancellable context detached from the incoming context + // allows us to preserve context values (which may be necessary for auth + // middleware), yet only cancel the standalone stream when the connection is closed. + connCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) + conn := &streamableClientConn{ + url: t.Endpoint, + client: client, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + maxRetries: maxRetries, + strict: t.strict, + logger: ensureLogger(t.logger), // must be non-nil for safe logging + ctx: connCtx, + cancel: cancel, + failed: make(chan struct{}), + } + return conn, nil +} + +type streamableClientConn struct { + url string + client *http.Client + ctx context.Context // connection context, detached from Connect + cancel context.CancelFunc // cancels ctx + incoming chan jsonrpc.Message + maxRetries int + strict bool // from [StreamableClientTransport.strict] + logger *slog.Logger // from [StreamableClientTransport.logger] + + // Guard calls to Close, as it may be called multiple times. + closeOnce sync.Once + closeErr error + done chan struct{} // signal graceful termination + + // Logical reads are distributed across multiple http requests. Whenever any + // of them fails to process their response, we must break the connection, by + // failing the pending Read. + // + // Achieve this by storing the failure message, and signalling when reads are + // broken. See also [streamableClientConn.fail] and + // [streamableClientConn.failure]. + failOnce sync.Once + _failure error + failed chan struct{} // signal failure + + // Guard the initialization state. + mu sync.Mutex + initializedResult *InitializeResult + sessionID string +} + +// errSessionMissing distinguishes if the session is known to not be present on +// the server (see [streamableClientConn.fail]). +// +// TODO(rfindley): should we expose this error value (and its corresponding +// API) to the user? 
+//
+// The spec says that if the server returns 404, clients should reestablish
+// a session. For now, we delegate that to the user, but do they need a way to
+// differentiate a 'NotFound' error from other errors?
+var errSessionMissing = errors.New("session not found")
+
+var _ clientConnection = (*streamableClientConn)(nil)
+
+func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
+	c.mu.Lock()
+	c.initializedResult = state.InitializeResult
+	c.mu.Unlock()
+
+	// Start the standalone SSE stream as soon as we have the initialized
+	// result.
+	//
+	// § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. This can be
+	// used to open an SSE stream, allowing the server to communicate to the
+	// client, without the client first sending data via HTTP POST.
+	//
+	// We have to wait for initialized, because until we've received
+	// initialized, we don't know whether the server requires a sessionID.
+	//
+	// § 2.5: A server using the Streamable HTTP transport MAY assign a session
+	// ID at initialization time, by including it in an Mcp-Session-Id header
+	// on the HTTP response containing the InitializeResult.
+	c.connectStandaloneSSE()
+}
+
+func (c *streamableClientConn) connectStandaloneSSE() {
+	resp, err := c.connectSSE(c.ctx, "", 0, true)
+	if err != nil {
+		// If the client didn't cancel the request, any failure breaks the logical
+		// session.
+		if c.ctx.Err() == nil {
+			c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
+		}
+		return
+	}
+
+	// [§2.2.3]: "The server MUST either return Content-Type:
+	// text/event-stream in response to this HTTP GET, or else return HTTP
+	// 405 Method Not Allowed, indicating that the server does not offer an
+	// SSE stream at this endpoint."
+ // + // [§2.2.3]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server + if resp.StatusCode == http.StatusMethodNotAllowed { + // The server doesn't support the standalone SSE stream. + resp.Body.Close() + return + } + if resp.StatusCode >= 400 && resp.StatusCode < 500 && !c.strict { + // modelcontextprotocol/go-sdk#393,#610: some servers return NotFound or + // other status codes instead of MethodNotAllowed for the standalone SSE + // stream. + // + // Treat this like MethodNotAllowed in non-strict mode. + c.logger.Warn(fmt.Sprintf("got %d instead of 405 for standalone SSE stream", resp.StatusCode)) + resp.Body.Close() + return + } + summary := "standalone SSE stream" + if err := c.checkResponse(summary, resp); err != nil { + c.fail(err) + return + } + go c.handleSSE(c.ctx, summary, resp, nil) +} + +// fail handles an asynchronous error while reading. +// +// If err is non-nil, it is terminal, and subsequent (or pending) Reads will +// fail. +// +// If err wraps errSessionMissing, the failure indicates that the session is no +// longer present on the server, and no final DELETE will be performed when +// closing the connection. +func (c *streamableClientConn) fail(err error) { + if err != nil { + c.failOnce.Do(func() { + c._failure = err + close(c.failed) + }) + } +} + +func (c *streamableClientConn) failure() error { + select { + case <-c.failed: + return c._failure + default: + return nil + } +} + +func (c *streamableClientConn) SessionID() string { + c.mu.Lock() + defer c.mu.Unlock() + return c.sessionID +} + +// Read implements the [Connection] interface. 
+func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { + if err := c.failure(); err != nil { + return nil, err + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-c.failed: + return nil, c.failure() + case <-c.done: + return nil, io.EOF + case msg := <-c.incoming: + return msg, nil + } +} + +// Write implements the [Connection] interface. +func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if err := c.failure(); err != nil { + return err + } + + var requestSummary string + var forCall *jsonrpc.Request + switch msg := msg.(type) { + case *jsonrpc.Request: + requestSummary = fmt.Sprintf("sending %q", msg.Method) + if msg.IsCall() { + forCall = msg + } + case *jsonrpc.Response: + requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) + default: + panic("unreachable") + } + + data, err := jsonrpc.EncodeMessage(msg) + if err != nil { + return fmt.Errorf("%s: %v", requestSummary, err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + c.setMCPHeaders(req) + + resp, err := c.client.Do(req) + if err != nil { + // Any error from client.Do means the request didn't reach the server. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) + } + + if err := c.checkResponse(requestSummary, resp); err != nil { + // Only fail the connection for non-transient errors. + // Transient errors (wrapped with ErrRejected) should not break the connection. 
+ if !errors.Is(err, jsonrpc2.ErrRejected) { + c.fail(err) + } + return err + } + + if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { + c.mu.Lock() + hadSessionID := c.sessionID + if hadSessionID == "" { + c.sessionID = sessionID + } + c.mu.Unlock() + if hadSessionID != "" && hadSessionID != sessionID { + resp.Body.Close() + return fmt.Errorf("mismatching session IDs %q and %q", hadSessionID, sessionID) + } + } + + if forCall == nil { + resp.Body.Close() + + // [§2.1.4]: "If the input is a JSON-RPC response or notification: + // If the server accepts the input, the server MUST return HTTP status code 202 Accepted with no body." + // + // [§2.1.4]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusAccepted { + errMsg := fmt.Sprintf("unexpected status code %d from non-call", resp.StatusCode) + // Some servers return 200, even with an empty json body. + // + // In strict mode, return an error to the caller. + c.logger.Warn(errMsg) + if c.strict { + return errors.New(errMsg) + } + } + return nil + } + + contentType := strings.TrimSpace(strings.SplitN(resp.Header.Get("Content-Type"), ";", 2)[0]) + switch contentType { + case "application/json": + go c.handleJSON(requestSummary, resp) + + case "text/event-stream": + var forCall *jsonrpc.Request + if jsonReq, ok := msg.(*jsonrpc.Request); ok && jsonReq.IsCall() { + forCall = jsonReq + } + // Handle the resulting stream. 
Note that ctx comes from the call, and + // therefore is already cancelled when the JSON-RPC request is cancelled + // (or rather, context cancellation is what *triggers* JSON-RPC + // cancellation) + go c.handleSSE(ctx, requestSummary, resp, forCall) + + default: + resp.Body.Close() + return fmt.Errorf("%s: unsupported content type %q", requestSummary, contentType) + } + return nil +} + +// testAuth controls whether a fake Authorization header is added to outgoing requests. +// TODO: replace with a better mechanism when client-side auth is in place. +var testAuth atomic.Bool + +func (c *streamableClientConn) setMCPHeaders(req *http.Request) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.initializedResult != nil { + req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) + } + if c.sessionID != "" { + req.Header.Set(sessionIDHeader, c.sessionID) + } + if testAuth.Load() { + req.Header.Set("Authorization", "Bearer foo") + } +} + +func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + c.fail(fmt.Errorf("%s: failed to read body: %v", requestSummary, err)) + return + } + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + c.fail(fmt.Errorf("%s: failed to decode response: %v", requestSummary, err)) + return + } + select { + case c.incoming <- msg: + case <-c.done: + // The connection was closed by the client; exit gracefully. + } +} + +// handleSSE manages the lifecycle of an SSE connection. It can be either +// persistent (for the main GET listener) or temporary (for a POST response). +// +// If forCall is set, it is the call that initiated the stream, and the +// stream is complete when we receive its response. Otherwise, this is the +// standalone stream. +func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc2.Request) { + for { + // Connection was successful. 
Continue the loop with the new response.
+		//
+		// TODO(#679): we should set a reasonable limit on the number of times
+		// we'll try getting a response for a given request, or enforce that we
+		// actually make progress.
+		//
+		// Eventually, if we don't get the response, we should stop trying and
+		// fail the request.
+		lastEventID, reconnectDelay, clientClosed := c.processStream(ctx, requestSummary, resp, forCall)
+
+		// If the connection was closed by the client, we're done.
+		if clientClosed {
+			return
+		}
+		// If we don't have a last event ID, we can never get the call response, so
+		// there's nothing to resume. For the standalone stream, we can reconnect,
+		// but we may just miss messages.
+		if lastEventID == "" && forCall != nil {
+			return
+		}
+
+		// The stream was interrupted or ended by the server. Attempt to reconnect.
+		newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false)
+		if err != nil {
+			// If the client didn't cancel this request, any failure to execute it
+			// breaks the logical MCP session.
+			if ctx.Err() == nil {
+				// All reconnection attempts failed: fail the connection.
+				c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
+			}
+			return
+		}
+
+		resp = newResp
+		if err := c.checkResponse(requestSummary, resp); err != nil {
+			c.fail(err)
+			return
+		}
+	}
+}
+
+// checkResponse checks the status code of the provided response, and
+// translates it into an error if the request was unsuccessful.
+//
+// The response body is closed if a non-nil error is returned.
+func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.Response) (err error) {
+	defer func() {
+		if err != nil {
+			resp.Body.Close()
+		}
+	}()
+	// §2.5.3: "The server MAY terminate the session at any time, after
+	// which it MUST respond to requests containing that session ID with HTTP
+	// 404 Not Found."
+ if resp.StatusCode == http.StatusNotFound { + // Return an errSessionMissing to avoid sending a redundant DELETE when the + // session is already gone. + return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing) + } + // Transient server errors (502, 503, 504, 429) should not break the connection. + // Wrap them with ErrRejected so the jsonrpc2 layer doesn't set writeErr. + if isTransientHTTPStatus(resp.StatusCode) { + return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, http.StatusText(resp.StatusCode)) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode)) + } + return nil +} + +// processStream reads from a single response body, sending events to the +// incoming channel. It returns the ID of the last processed event and a flag +// indicating if the connection was closed by the client. If resp is nil, it +// returns "", false. +func (c *streamableClientConn) processStream(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) { + defer func() { + // Drain any remaining unprocessed body. This allows the connection to be re-used after closing. 
+ io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + for evt, err := range scanEvents(resp.Body) { + if err != nil { + if ctx.Err() != nil { + return "", 0, true // don't reconnect: client cancelled + } + break + } + + if evt.ID != "" { + lastEventID = evt.ID + } + + if evt.Retry != "" { + if n, err := strconv.ParseInt(evt.Retry, 10, 64); err == nil { + reconnectDelay = time.Duration(n) * time.Millisecond + } + } + // According to SSE spec, events with no name default to "message" + if evt.Name != "" && evt.Name != "message" { + continue + } + + msg, err := jsonrpc.DecodeMessage(evt.Data) + if err != nil { + c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err)) + return "", 0, true + } + + select { + case c.incoming <- msg: + // Check if this is the response to our call, which terminates the request. + // (it could also be a server->client request or notification). + if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil { + // TODO: we should never get a response when forReq is nil (the standalone SSE request). + // We should detect this case. + if jsonResp.ID == forCall.ID { + return "", 0, true + } + } + + case <-c.done: + // The connection was closed by the client; exit gracefully. + return "", 0, true + } + } + // The loop finished without an error, indicating the server closed the stream. + // + // If the lastEventID is "", the stream is not retryable and we should + // report a synthetic error for the call. + // + // Note that this is different from the cancellation case above, since the + // caller is still waiting for a response that will never come. + if lastEventID == "" && forCall != nil { + errmsg := &jsonrpc2.Response{ + ID: forCall.ID, + Error: fmt.Errorf("request terminated without response"), + } + select { + case c.incoming <- errmsg: + case <-c.done: + } + } + return lastEventID, reconnectDelay, false +} + +// connectSSE handles the logic of connecting a text/event-stream connection. 
+// +// If lastEventID is set, it is the last-event ID of a stream being resumed. +// +// If connection fails, connectSSE retries with an exponential backoff +// strategy. It returns a new, valid HTTP response if successful, or an error +// if all retries are exhausted. +// +// reconnectDelay is the delay set by the server using the SSE retry field, or +// 0. +// +// If initial is set, this is the initial attempt. +// +// If connectSSE exits due to context cancellation, the result is (nil, ctx.Err()). +func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID string, reconnectDelay time.Duration, initial bool) (*http.Response, error) { + var finalErr error + attempt := 0 + if !initial { + // We've already connected successfully once, so delay subsequent + // reconnections. Otherwise, if the server returns 200 but terminates the + // connection, we'll reconnect as fast as we can, ad infinitum. + // + // TODO: we should consider also setting a limit on total attempts for one + // logical request. + attempt = 1 + } + delay := calculateReconnectDelay(attempt) + if reconnectDelay > 0 { + delay = reconnectDelay // honor the server's requested initial delay + } + for ; attempt <= c.maxRetries; attempt++ { + select { + case <-c.done: + return nil, fmt.Errorf("connection closed by client during reconnect") + + case <-ctx.Done(): + // If the connection context is canceled, the request below will not + // succeed anyway. + return nil, ctx.Err() + + case <-time.After(delay): + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) + if err != nil { + return nil, err + } + c.setMCPHeaders(req) + if lastEventID != "" { + req.Header.Set(lastEventIDHeader, lastEventID) + } + req.Header.Set("Accept", "text/event-stream") + resp, err := c.client.Do(req) + if err != nil { + finalErr = err // Store the error and try again. 
+ delay = calculateReconnectDelay(attempt + 1) + continue + } + return resp, nil + } + } + // If the loop completes, all retries have failed, or the client is closing. + if finalErr != nil { + return nil, fmt.Errorf("connection failed after %d attempts: %w", c.maxRetries, finalErr) + } + return nil, fmt.Errorf("connection aborted after %d attempts", c.maxRetries) +} + +// Close implements the [Connection] interface. +func (c *streamableClientConn) Close() error { + c.closeOnce.Do(func() { + if errors.Is(c.failure(), errSessionMissing) { + // If the session is missing, no need to delete it. + } else { + req, err := http.NewRequestWithContext(c.ctx, http.MethodDelete, c.url, nil) + if err != nil { + c.closeErr = err + } else { + c.setMCPHeaders(req) + if _, err := c.client.Do(req); err != nil { + c.closeErr = err + } + } + } + + // Cancel any hanging network requests after cleanup. + c.cancel() + close(c.done) + }) + return c.closeErr +} + +// calculateReconnectDelay calculates a delay using exponential backoff with full jitter. +func calculateReconnectDelay(attempt int) time.Duration { + if attempt == 0 { + return 0 + } + // Calculate the exponential backoff using the grow factor. + backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt-1))) + // Cap the backoffDuration at maxDelay. + backoffDuration = min(backoffDuration, reconnectMaxDelay) + + // Use a full jitter using backoffDuration + jitter := rand.N(backoffDuration) + + return backoffDuration + jitter +} + +// isTransientHTTPStatus reports whether the HTTP status code indicates a +// transient server error that should not permanently break the connection. 
+func isTransientHTTPStatus(statusCode int) bool { + switch statusCode { + case http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 + http.StatusTooManyRequests: // 429 + return true + } + return false +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go new file mode 100644 index 000000000..41a100461 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go @@ -0,0 +1,226 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// TODO: move client-side streamable HTTP logic from streamable.go to this file. + +package mcp + +/* +Streamable HTTP Client Design + +This document describes the client-side implementation of the MCP streamable +HTTP transport, as defined by the MCP spec: +https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http + +# Overview + +The client-side streamable transport allows an MCP client to communicate with a +server over HTTP, sending messages via POST and receiving responses via either +JSON or server-sent events (SSE). 
The implementation consists of two main +components: + + ┌─────────────────────────────────────────────────────────────────┐ + │ [StreamableClientTransport] │ + │ Transport configuration; creates connections via Connect() │ + └─────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────────────────────────────────────┐ + │ [streamableClientConn] │ + │ Connection implementation; handles HTTP request/response │ + └─────────────────────────────────────────────────────────────────┘ + │ + ├──────────────────────────────────────┐ + ▼ ▼ + ┌─────────────────────────────────────────┐ ┌────────────────────────────────────┐ + │ POST request handlers │ │ Standalone SSE stream │ + │ (one per outgoing message/call) │ │ (server-initiated messages) │ + └─────────────────────────────────────────┘ └────────────────────────────────────┘ + +# Sessions + +The client maintains a session with the server, identified by a session ID +(Mcp-Session-Id header): + + - Session ID is received from the server after initialization + - Client includes the session ID in all subsequent requests + - Session ends when the client calls Close() (sends DELETE) or server returns 404 + +[streamableClientConn] stores the session state: + - [streamableClientConn.sessionID]: Server-assigned session identifier + - [streamableClientConn.initializedResult]: Protocol version and server capabilities + +# Connection Lifecycle + +1. Connect: [StreamableClientTransport.Connect] creates a [streamableClientConn] + with a detached context for the connection's lifetime. The context is detached + to prevent the standalone SSE stream from being cancelled when the original + Connect context times out. + +2. Initialize: The MCP client sends initialize/initialized messages. 
Upon + receiving [InitializeResult], the connection: + - Stores the negotiated protocol version for the Mcp-Protocol-Version header + - Captures the session ID from the Mcp-Session-Id response header + - Starts the standalone SSE stream via [streamableClientConn.connectStandaloneSSE] + +3. Operation: Messages are sent via POST, responses received via JSON or SSE. + +4. Close: [streamableClientConn.Close] sends a DELETE request to terminate + the session (unless the session is already gone), then cancels the connection + context to clean up the standalone SSE stream. + +# Sending Messages (Write) + +[streamableClientConn.Write] sends all outgoing messages via HTTP POST: + + POST /endpoint + Content-Type: application/json + Accept: application/json, text/event-stream + Mcp-Protocol-Version: + Mcp-Session-Id: + + + +The server may respond with: + - 202 Accepted: Message received, no response body (notifications/responses) + - 200 OK with application/json: Single JSON-RPC response + - 200 OK with text/event-stream: SSE stream of responses + +# Receiving Messages (Read) + +[streamableClientConn.Read] returns messages from the [streamableClientConn.incoming] +channel, which is populated by multiple concurrent goroutines: + +1. POST response handlers ([streamableClientConn.handleJSON] and + [streamableClientConn.handleSSE]): Process responses from POST requests + +2. 
Standalone SSE stream: Receives server-initiated requests and notifications + +The client handles both response formats: + - JSON: [streamableClientConn.handleJSON] reads body, decodes message + - SSE: [streamableClientConn.handleSSE] scans events, decodes each message + +# Standalone SSE Stream + +After initialization, [streamableClientConn.sessionUpdated] triggers +[streamableClientConn.connectStandaloneSSE] to open a GET request for +server-initiated messages: + + GET /endpoint + Accept: text/event-stream + Mcp-Session-Id: + +Stream behavior: + - Optional: Server may return 405 Method Not Allowed (spec-compliant) or + other 4xx errors (tolerated in non-strict mode for compatibility) + - Persistent: Runs for the connection lifetime in a background goroutine + - Resumable: Uses Last-Event-ID header on reconnection if server provides event IDs + - Reconnects: Automatic reconnection with exponential backoff on interruption + +# Stream Resumption + +When an SSE stream (standalone or POST response) is interrupted, the client +attempts to reconnect using [streamableClientConn.connectSSE]: + +Event ID tracking: + - [streamableClientConn.processStream] tracks the last received event ID + - On reconnection, the Last-Event-ID header is set to resume from that point + - Server replays missed events if it has an [EventStore] configured + +See [calculateReconnectDelay] for the reconnect delay details. + +Server-initiated reconnection (SEP-1699) + - SSE retry field: Sets the delay for the next reconnect attempt + - If server doesn't provide event IDs, non-standalone streams don't reconnect + +# Response Formats + +The client must handle two response formats from POST requests: + +1. application/json: Single JSON-RPC response + - Body contains one JSON-RPC message + - Handled by [streamableClientConn.handleJSON] + - Simpler but doesn't support streaming or server-initiated messages + +2. 
text/event-stream: SSE stream of messages + - Body contains SSE events with JSON-RPC messages + - Handled by [streamableClientConn.handleSSE] + - Supports multiple messages and server-initiated communication + - Stream completes when the response to the originating call is received + +# HTTP Methods + + - POST: Send JSON-RPC messages (requests, responses, notifications) + - Used by [streamableClientConn.Write] + - Response may be JSON or SSE + + - GET: Open or resume SSE stream for server-initiated messages + - Used by [streamableClientConn.connectSSE] + - Always expects text/event-stream response (or 405) + + - DELETE: Terminate the session + - Used by [streamableClientConn.Close] + - Skipped if session is already known to be gone ([errSessionMissing]) + +# Error Handling + +Errors are categorized and handled differently: + +1. Transient (recoverable via reconnection): + - Network interruption during SSE streaming + - Connection reset or timeout + - Triggers reconnection in [streamableClientConn.handleSSE] + +2. Terminal (breaks the connection): + - 404 Not Found: Session terminated by server ([errSessionMissing]) + - Message decode errors: Protocol violation + - Context cancellation: Client closed connection + - Mismatched session IDs: Protocol error + - See issue #683: our terminal errors are too strict. + +Terminal errors are stored via [streamableClientConn.fail] and returned by +subsequent [streamableClientConn.Read] calls. The [streamableClientConn.failed] +channel signals that the connection is broken. + +Special case: [errSessionMissing] indicates the server has terminated the session, +so [streamableClientConn.Close] skips the DELETE request. 
+ +# Protocol Version Header + +After initialization, all requests include: + + Mcp-Protocol-Version: + +This header (set by [streamableClientConn.setMCPHeaders]): + - Allows the server to handle requests per the negotiated protocol + - Is omitted before initialization completes + - Uses the version from [streamableClientConn.initializedResult] + +# Key Implementation Details + +[StreamableClientTransport] configuration: + - [StreamableClientTransport.Endpoint]: URL of the MCP server + - [StreamableClientTransport.HTTPClient]: Custom HTTP client (optional) + - [StreamableClientTransport.MaxRetries]: Reconnection attempts (default 5) + +[streamableClientConn] handles the [Connection] interface: + - [streamableClientConn.Read]: Returns messages from incoming channel + - [streamableClientConn.Write]: Sends messages via POST, starts response handlers + - [streamableClientConn.Close]: Sends DELETE, cancels context, closes done channel + +State management: + - [streamableClientConn.incoming]: Buffered channel for received messages + - [streamableClientConn.sessionID]: Server-assigned session identifier + - [streamableClientConn.initializedResult]: Cached for protocol version header + - [streamableClientConn.failed]: Channel closed on terminal error + - [streamableClientConn.done]: Channel closed on graceful shutdown + - [streamableClientConn.ctx]: Detached context for connection lifetime + - [streamableClientConn.cancel]: Cancels ctx to terminate SSE streams + +Context handling: + - Connection context is detached from [StreamableClientTransport.Connect] context + using [xcontext.Detach] to preserve context values (for auth middleware) while + preventing premature cancellation of the standalone SSE stream + - Individual POST requests use caller-provided contexts for cancellation +*/ diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_server.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_server.go new file mode 100644 index 
000000000..8a573e56a --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_server.go @@ -0,0 +1,160 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// TODO: move server-side streamable HTTP logic from streamable.go to this file. + +package mcp + +/* +Streamable HTTP Server Design + +This document describes the server-side implementation of the MCP streamable +HTTP transport, as defined by the MCP spec: +https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http + +# Overview + +The streamable HTTP transport enables MCP communication over HTTP, with +server-sent events (SSE) for server-to-client messages. The implementation +consists of several layered components: + + ┌─────────────────────────────────────────────────────────────────┐ + │ [StreamableHTTPHandler] │ + │ http.Handler that manages sessions and routes HTTP requests │ + └─────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────────────────────────────────────┐ + │ [StreamableServerTransport] │ + │ transport implementation, one per session; exposes ServeHTTP │ + └─────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────────────────────────────────────┐ + │ [streamableServerConn] │ + │ Connection implementation, handles message routing │ + └─────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────────────────────────────────────┐ + │ [stream] │ + │ Logical message channel within a session, may be resumed │ + └─────────────────────────────────────────────────────────────────┘ + +# Sessions + +As with other transports, a session represents a logical MCP connection between +a client and server. 
In the streamable transport, sessions are identified by a +unique session ID (Mcp-Session-Id header) and persist across multiple HTTP +requests. + +[StreamableHTTPHandler] maintains a map of active sessions ([sessionInfo]), +each containing: + - The [ServerSession] (MCP-level session state) + - The [StreamableServerTransport] (for message I/O) + - Optional timeout management for idle session cleanup + +Sessions are created on the first POST request (typically containing the +initialize request) and destroyed either by: + - Client sending a DELETE request + - Session timeout due to inactivity + - Server explicitly closing the session + +# Streams + +Within a session, there can be multiple concurrent "streams" - logical channels +for message delivery. This is distinct from HTTP streams; a single [stream] may +span multiple HTTP request/response cycles (via resumption). + +There are two types of streams: + +1. Optional standalone SSE stream (id = ""): + - Created when client sends a GET request to the endpoint + - Used for server-initiated messages (requests/notifications to client) + - Persists for the lifetime of the session + - Only one standalone stream per session + +2. 
Request streams (id = random string): + - Created for each POST request containing JSON-RPC calls + - Used to route responses back to the originating HTTP request + - Completed when all responses have been sent + - Can be resumed via GET with Last-Event-ID if interrupted + +# Message Routing + +When the server writes a message, it must be routed to the correct [stream]: + + - Responses: Routed to the stream that originated the request + - Requests/Notifications made during request handling: Routed to the same + stream as the triggering request (via context) + - Requests/Notifications made outside request handling: Routed to the + standalone SSE stream + +This routing is implemented using: + - [streamableServerConn.requestStreams] maps request IDs to stream IDs + - [idContextKey] is used to store the originating request ID in Context + - [streamableServerConn.streams] maps stream IDs to [stream] objects + +# Stream Resumption + +If an HTTP connection is interrupted (network issues, etc.), clients can +resume a stream by sending a GET request with the Last-Event-ID header. +This requires an [EventStore] to be configured on the server. + + - [EventStore.Open] is called when a new stream is created + - [EventStore.Append] is called for each message written to the stream + - [EventStore.After] is called to replay messages after a given index + - [EventStore.SessionClosed] is called when the session ends + +Event IDs are formatted as "_" to identify both the +stream and position within that stream (see [formatEventID] and [parseEventID]). + +# Stateless Mode + +For simpler deployments, the handler supports "stateless" mode +([StreamableHTTPOptions.Stateless]) where: + - No session ID validation is performed + - Each request creates a temporary session that's closed after the request + - Server-to-client requests are not supported (no way to receive response) + +This mode is useful for simple tool servers that don't need bidirectional +communication. 
+ +# Response Formats + +The server can respond to POST requests in two formats: + +1. text/event-stream (default): Messages sent as SSE events, supports + streaming multiple messages and server-initiated communication during + request handling. + +2. application/json ([StreamableHTTPOptions.JSONResponse]): Single JSON + response, simpler but doesn't support streaming. Server-initiated messages + during request handling go to the standalone SSE stream instead. + +# HTTP Methods + + - POST: Send JSON-RPC messages (requests, responses, notifications) + - GET: Open standalone SSE stream or resume an interrupted stream + - DELETE: Terminate the session + +# Key Implementation Details + +The [stream] struct manages delivery of messages to HTTP responses. + +Fields: + - [stream.w] is the ResponseWriter for the current HTTP response (non-nil indicates claimed) + - [stream.done] is closed to release the hanging HTTP request + - [stream.requests] tracks pending request IDs (stream completes when empty) + +Methods: + - [stream.deliverLocked] delivers a message to the stream + - [stream.close] sends a close event and releases the stream + - [stream.release] releases the stream from the HTTP request, allowing resumption + +[streamableServerConn] handles the [Connection] interface: + - [streamableServerConn.Read] receives messages from the incoming channel (fed by POST handlers) + - [streamableServerConn.Write] routes messages to appropriate streams + - [streamableServerConn.Close] terminates the session and notifies the [EventStore] +*/ diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go new file mode 100644 index 000000000..8aa7c3c0d --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go @@ -0,0 +1,139 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/google/jsonschema-go/jsonschema" +) + +// A ToolHandler handles a call to tools/call. +// +// This is a low-level API, for use with [Server.AddTool]. It does not do any +// pre- or post-processing of the request or result: the params contain raw +// arguments, no input validation is performed, and the result is returned to +// the user as-is, without any validation of the output. +// +// Most users will write a [ToolHandlerFor] and install it with the generic +// [AddTool] function. +// +// If ToolHandler returns an error, it is treated as a protocol error. By +// contrast, [ToolHandlerFor] automatically populates [CallToolResult.IsError] +// and [CallToolResult.Content] accordingly. +type ToolHandler func(context.Context, *CallToolRequest) (*CallToolResult, error) + +// A ToolHandlerFor handles a call to tools/call with typed arguments and results. +// +// Use [AddTool] to add a ToolHandlerFor to a server. +// +// Unlike [ToolHandler], [ToolHandlerFor] provides significant functionality +// out of the box, and enforces that the tool conforms to the MCP spec: +// - The In type provides a default input schema for the tool, though it may +// be overridden in [AddTool]. +// - The input value is automatically unmarshaled from req.Params.Arguments. +// - The input value is automatically validated against its input schema. +// Invalid input is rejected before getting to the handler. +// - If the Out type is not the empty interface [any], it provides the +// default output schema for the tool (which again may be overridden in +// [AddTool]). +// - The Out value is used to populate result.StructuredOutput. +// - If [CallToolResult.Content] is unset, it is populated with the JSON +// content of the output. +// - An error result is treated as a tool error, rather than a protocol +// error, and is therefore packed into CallToolResult.Content, with +// [IsError] set. 
+// +// For these reasons, most users can ignore the [CallToolRequest] argument and +// [CallToolResult] return values entirely. In fact, it is permissible to +// return a nil CallToolResult, if you only care about returning a output value +// or error. The effective result will be populated as described above. +type ToolHandlerFor[In, Out any] func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error) + +// A serverTool is a tool definition that is bound to a tool handler. +type serverTool struct { + tool *Tool + handler ToolHandler +} + +// applySchema validates whether data is valid JSON according to the provided +// schema, after applying schema defaults. +// +// Returns the JSON value augmented with defaults. +func applySchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) { + // TODO: use reflection to create the struct type to unmarshal into. + // Separate validation from assignment. + + // Use default JSON marshalling for validation. + // + // This avoids inconsistent representation due to custom marshallers, such as + // time.Time (issue #449). + // + // Additionally, unmarshalling into a map ensures that the resulting JSON is + // at least {}, even if data is empty. For example, arguments is technically + // an optional property of callToolParams, and we still want to apply the + // defaults in this case. + // + // TODO(rfindley): in which cases can resolved be nil? + if resolved != nil { + v := make(map[string]any) + if len(data) > 0 { + if err := json.Unmarshal(data, &v); err != nil { + return nil, fmt.Errorf("unmarshaling arguments: %w", err) + } + } + if err := resolved.ApplyDefaults(&v); err != nil { + return nil, fmt.Errorf("applying schema defaults:\n%w", err) + } + if err := resolved.Validate(&v); err != nil { + return nil, err + } + // We must re-marshal with the default values applied. 
+ var err error + data, err = json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("marshalling with defaults: %v", err) + } + } + return data, nil +} + +// validateToolName checks whether name is a valid tool name, reporting a +// non-nil error if not. +func validateToolName(name string) error { + if name == "" { + return fmt.Errorf("tool name cannot be empty") + } + if len(name) > 128 { + return fmt.Errorf("tool name exceeds maximum length of 128 characters (current: %d)", len(name)) + } + // For consistency with other SDKs, report characters in the order the appear + // in the name. + var invalidChars []string + seen := make(map[rune]bool) + for _, r := range name { + if !validToolNameRune(r) { + if !seen[r] { + invalidChars = append(invalidChars, fmt.Sprintf("%q", string(r))) + seen[r] = true + } + } + } + if len(invalidChars) > 0 { + return fmt.Errorf("tool name contains invalid characters: %s", strings.Join(invalidChars, ", ")) + } + return nil +} + +// validToolNameRune reports whether r is valid within tool names. +func validToolNameRune(r rune) bool { + return (r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || + r == '_' || r == '-' || r == '.' +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go new file mode 100644 index 000000000..25f1d5d05 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go @@ -0,0 +1,655 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. 
+ +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "sync" + + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/xcontext" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +// ErrConnectionClosed is returned when sending a message to a connection that +// is closed or in the process of closing. +var ErrConnectionClosed = errors.New("connection closed") + +// A Transport is used to create a bidirectional connection between MCP client +// and server. +// +// Transports should be used for at most one call to [Server.Connect] or +// [Client.Connect]. +type Transport interface { + // Connect returns the logical JSON-RPC connection.. + // + // It is called exactly once by [Server.Connect] or [Client.Connect]. + Connect(ctx context.Context) (Connection, error) +} + +// A Connection is a logical bidirectional JSON-RPC connection. +type Connection interface { + // Read reads the next message to process off the connection. + // + // Connections must allow Read to be called concurrently with Close. In + // particular, calling Close should unblock a Read waiting for input. + Read(context.Context) (jsonrpc.Message, error) + + // Write writes a new message to the connection. + // + // Write may be called concurrently, as calls or responses may occur + // concurrently in user code. + Write(context.Context, jsonrpc.Message) error + + // Close closes the connection. It is implicitly called whenever a Read or + // Write fails. + // + // Close may be called multiple times, potentially concurrently. + Close() error + + // TODO(#148): remove SessionID from this interface. + SessionID() string +} + +// A ClientConnection is a [Connection] that is specific to the MCP client. +// +// If client connections implement this interface, they may receive information +// about changes to the client session. +// +// TODO: should this interface be exported? 
+type clientConnection interface { + Connection + + // sessionUpdated is called whenever the client session state changes. + sessionUpdated(clientSessionState) +} + +// A serverConnection is a Connection that is specific to the MCP server. +// +// If server connections implement this interface, they receive information +// about changes to the server session. +// +// TODO: should this interface be exported? +type serverConnection interface { + Connection + sessionUpdated(ServerSessionState) +} + +// A StdioTransport is a [Transport] that communicates over stdin/stdout using +// newline-delimited JSON. +type StdioTransport struct{} + +// Connect implements the [Transport] interface. +func (*StdioTransport) Connect(context.Context) (Connection, error) { + return newIOConn(rwc{os.Stdin, nopCloserWriter{os.Stdout}}), nil +} + +// nopCloserWriter is an io.WriteCloser with a trivial Close method. +type nopCloserWriter struct { + io.Writer +} + +func (nopCloserWriter) Close() error { return nil } + +// An IOTransport is a [Transport] that communicates over separate +// io.ReadCloser and io.WriteCloser using newline-delimited JSON. +type IOTransport struct { + Reader io.ReadCloser + Writer io.WriteCloser +} + +// Connect implements the [Transport] interface. +func (t *IOTransport) Connect(context.Context) (Connection, error) { + return newIOConn(rwc{t.Reader, t.Writer}), nil +} + +// An InMemoryTransport is a [Transport] that communicates over an in-memory +// network connection, using newline-delimited JSON. +// +// InMemoryTransports should be constructed using [NewInMemoryTransports], +// which returns two transports connected to each other. +type InMemoryTransport struct { + rwc io.ReadWriteCloser +} + +// Connect implements the [Transport] interface. +func (t *InMemoryTransport) Connect(context.Context) (Connection, error) { + return newIOConn(t.rwc), nil +} + +// NewInMemoryTransports returns two [InMemoryTransport] objects that connect +// to each other. 
+// +// The resulting transports are symmetrical: use either to connect to a server, +// and then the other to connect to a client. Servers must be connected before +// clients, as the client initializes the MCP session during connection. +func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) { + c1, c2 := net.Pipe() + return &InMemoryTransport{c1}, &InMemoryTransport{c2} +} + +type binder[T handler, State any] interface { + // TODO(rfindley): the bind API has gotten too complicated. Simplify. + bind(Connection, *jsonrpc2.Connection, State, func()) T + disconnect(T) +} + +type handler interface { + handle(ctx context.Context, req *jsonrpc.Request) (any, error) +} + +func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) { + var zero H + mcpConn, err := t.Connect(ctx) + if err != nil { + return zero, err + } + // If logging is configured, write message logs. + reader, writer := jsonrpc2.Reader(mcpConn), jsonrpc2.Writer(mcpConn) + var ( + h H + preempter canceller + ) + bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler { + h = b.bind(mcpConn, conn, s, onClose) + preempter.conn = conn + return jsonrpc2.HandlerFunc(h.handle) + } + _ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{ + Reader: reader, + Writer: writer, + Closer: mcpConn, + Bind: bind, + Preempter: &preempter, + OnDone: func() { + b.disconnect(h) + }, + OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) }, + }) + assert(preempter.conn != nil, "unbound preempter") + return h, nil +} + +// A canceller is a jsonrpc2.Preempter that cancels in-flight requests on MCP +// cancelled notifications. +type canceller struct { + conn *jsonrpc2.Connection +} + +// Preempt implements [jsonrpc2.Preempter]. 
+func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) {
+	if req.Method == notificationCancelled {
+		var params CancelledParams
+		if err := json.Unmarshal(req.Params, &params); err != nil {
+			return nil, err
+		}
+		id, err := jsonrpc2.MakeID(params.RequestID)
+		if err != nil {
+			return nil, err
+		}
+		go c.conn.Cancel(id)
+	}
+	return nil, jsonrpc2.ErrNotHandled
+}
+
+// call executes and awaits a jsonrpc2 call on the given connection,
+// translating errors into the mcp domain.
+func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params Params, result Result) error {
+	// The "%w"s in this function expose jsonrpc.Error as part of the API.
+	call := conn.Call(ctx, method, params)
+	err := call.Await(ctx, result)
+	switch {
+	case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing):
+		return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err)
+	case ctx.Err() != nil:
+		// Notify the peer of cancellation.
+		err := conn.Notify(xcontext.Detach(ctx), notificationCancelled, &CancelledParams{
+			Reason:    ctx.Err().Error(),
+			RequestID: call.ID().Raw(),
+		})
+		// By default, the jsonrpc2 library waits for graceful shutdown when the
+		// connection is closed, meaning it expects all outgoing and incoming
+		// requests to complete. However, for MCP this expectation is unrealistic,
+		// and can lead to hanging shutdown. For example, if a streamable client is
+		// killed, the server will not be able to detect this event, except via
+		// keepalive pings (if they are configured), and so outgoing calls may hang
+		// indefinitely.
+		//
+		// Therefore, we choose to eagerly retire calls, removing them from the
+		// outgoingCalls map, when the caller context is cancelled: if the caller
+		// will never receive the response, there's no need to track it.
+ conn.Retire(call, ctx.Err()) + return errors.Join(ctx.Err(), err) + case err != nil: + return fmt.Errorf("calling %q: %w", method, err) + } + return nil +} + +// A LoggingTransport is a [Transport] that delegates to another transport, +// writing RPC logs to an io.Writer. +type LoggingTransport struct { + Transport Transport + Writer io.Writer +} + +// Connect connects the underlying transport, returning a [Connection] that writes +// logs to the configured destination. +func (t *LoggingTransport) Connect(ctx context.Context) (Connection, error) { + delegate, err := t.Transport.Connect(ctx) + if err != nil { + return nil, err + } + return &loggingConn{delegate: delegate, w: t.Writer}, nil +} + +type loggingConn struct { + delegate Connection + + mu sync.Mutex + w io.Writer +} + +func (c *loggingConn) SessionID() string { return c.delegate.SessionID() } + +// Read is a stream middleware that logs incoming messages. +func (s *loggingConn) Read(ctx context.Context) (jsonrpc.Message, error) { + msg, err := s.delegate.Read(ctx) + + if err != nil { + s.mu.Lock() + fmt.Fprintf(s.w, "read error: %v\n", err) + s.mu.Unlock() + } else { + data, err := jsonrpc2.EncodeMessage(msg) + s.mu.Lock() + if err != nil { + fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) + } + fmt.Fprintf(s.w, "read: %s\n", string(data)) + s.mu.Unlock() + } + + return msg, err +} + +// Write is a stream middleware that logs outgoing messages. 
+func (s *loggingConn) Write(ctx context.Context, msg jsonrpc.Message) error { + err := s.delegate.Write(ctx, msg) + if err != nil { + s.mu.Lock() + fmt.Fprintf(s.w, "write error: %v\n", err) + s.mu.Unlock() + } else { + data, err := jsonrpc2.EncodeMessage(msg) + s.mu.Lock() + if err != nil { + fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) + } + fmt.Fprintf(s.w, "write: %s\n", string(data)) + s.mu.Unlock() + } + return err +} + +func (s *loggingConn) Close() error { + return s.delegate.Close() +} + +// A rwc binds an io.ReadCloser and io.WriteCloser together to create an +// io.ReadWriteCloser. +type rwc struct { + rc io.ReadCloser + wc io.WriteCloser +} + +func (r rwc) Read(p []byte) (n int, err error) { + return r.rc.Read(p) +} + +func (r rwc) Write(p []byte) (n int, err error) { + return r.wc.Write(p) +} + +func (r rwc) Close() error { + rcErr := r.rc.Close() + + var wcErr error + if r.wc != nil { // we only allow a nil writer in unit tests + wcErr = r.wc.Close() + } + + return errors.Join(rcErr, wcErr) +} + +// An ioConn is a transport that delimits messages with newlines across +// a bidirectional stream, and supports jsonrpc.2 message batching. +// +// See https://github.com/ndjson/ndjson-spec for discussion of newline +// delimited JSON. +// +// See [msgBatch] for more discussion of message batching. +type ioConn struct { + protocolVersion string // negotiated version, set during session initialization. + + writeMu sync.Mutex // guards Write, which must be concurrency safe. + rwc io.ReadWriteCloser // the underlying stream + + // incoming receives messages from the read loop started in [newIOConn]. + incoming <-chan msgOrErr + + // If outgoiBatch has a positive capacity, it will be used to batch requests + // and notifications before sending. + outgoingBatch []jsonrpc.Message + + // Unread messages in the last batch. Since reads are serialized, there is no + // need to guard here. 
+ queue []jsonrpc.Message + + // batches correlate incoming requests to the batch in which they arrived. + // Since writes may be concurrent to reads, we need to guard this with a mutex. + batchMu sync.Mutex + batches map[jsonrpc2.ID]*msgBatch // lazily allocated + + closeOnce sync.Once + closed chan struct{} + closeErr error +} + +type msgOrErr struct { + msg json.RawMessage + err error +} + +func newIOConn(rwc io.ReadWriteCloser) *ioConn { + var ( + incoming = make(chan msgOrErr) + closed = make(chan struct{}) + ) + // Start a goroutine for reads, so that we can select on the incoming channel + // in [ioConn.Read] and unblock the read as soon as Close is called (see #224). + // + // This leaks a goroutine if rwc.Read does not unblock after it is closed, + // but that is unavoidable since AFAIK there is no (easy and portable) way to + // guarantee that reads of stdin are unblocked when closed. + go func() { + dec := json.NewDecoder(rwc) + for { + var raw json.RawMessage + err := dec.Decode(&raw) + // If decoding was successful, check for trailing data at the end of the stream. + if err == nil { + // Read the next byte to check if there is trailing data. + var tr [1]byte + if n, readErr := dec.Buffered().Read(tr[:]); n > 0 { + // If read byte is not a newline, it is an error. + // Support both Unix (\n) and Windows (\r\n) line endings. 
+ if tr[0] != '\n' && tr[0] != '\r' { + err = fmt.Errorf("invalid trailing data at the end of stream") + } + } else if readErr != nil && readErr != io.EOF { + err = readErr + } + } + select { + case incoming <- msgOrErr{msg: raw, err: err}: + case <-closed: + return + } + if err != nil { + return + } + } + }() + return &ioConn{ + rwc: rwc, + incoming: incoming, + closed: closed, + } +} + +func (c *ioConn) SessionID() string { return "" } + +func (c *ioConn) sessionUpdated(state ServerSessionState) { + protocolVersion := "" + if state.InitializeParams != nil { + protocolVersion = state.InitializeParams.ProtocolVersion + } + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + c.protocolVersion = negotiatedVersion(protocolVersion) +} + +// addBatch records a msgBatch for an incoming batch payload. +// It returns an error if batch is malformed, containing previously seen IDs. +// +// See [msgBatch] for more. +func (t *ioConn) addBatch(batch *msgBatch) error { + t.batchMu.Lock() + defer t.batchMu.Unlock() + for id := range batch.unresolved { + if _, ok := t.batches[id]; ok { + return fmt.Errorf("%w: batch contains previously seen request %v", jsonrpc2.ErrInvalidRequest, id.Raw()) + } + } + for id := range batch.unresolved { + if t.batches == nil { + t.batches = make(map[jsonrpc2.ID]*msgBatch) + } + t.batches[id] = batch + } + return nil +} + +// updateBatch records a response in the message batch tracking the +// corresponding incoming call, if any. +// +// The second result reports whether resp was part of a batch. If this is true, +// the first result is nil if the batch is still incomplete, or the full set of +// batch responses if resp completed the batch. 
+func (t *ioConn) updateBatch(resp *jsonrpc.Response) ([]*jsonrpc.Response, bool) { + t.batchMu.Lock() + defer t.batchMu.Unlock() + + if batch, ok := t.batches[resp.ID]; ok { + idx, ok := batch.unresolved[resp.ID] + if !ok { + panic("internal error: inconsistent batches") + } + batch.responses[idx] = resp + delete(batch.unresolved, resp.ID) + delete(t.batches, resp.ID) + if len(batch.unresolved) == 0 { + return batch.responses, true + } + return nil, true + } + return nil, false +} + +// A msgBatch records information about an incoming batch of jsonrpc.2 calls. +// +// The jsonrpc.2 spec (https://www.jsonrpc.org/specification#batch) says: +// +// "The Server should respond with an Array containing the corresponding +// Response objects, after all of the batch Request objects have been +// processed. A Response object SHOULD exist for each Request object, except +// that there SHOULD NOT be any Response objects for notifications. The Server +// MAY process a batch rpc call as a set of concurrent tasks, processing them +// in any order and with any width of parallelism." +// +// Therefore, a msgBatch keeps track of outstanding calls and their responses. +// When there are no unresolved calls, the response payload is sent. +type msgBatch struct { + unresolved map[jsonrpc2.ID]int + responses []*jsonrpc.Response +} + +func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { + // As a matter of principle, enforce that reads on a closed context return an + // error. 
+ select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if len(t.queue) > 0 { + next := t.queue[0] + t.queue = t.queue[1:] + return next, nil + } + + var raw json.RawMessage + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case v := <-t.incoming: + if v.err != nil { + return nil, v.err + } + raw = v.msg + + case <-t.closed: + return nil, io.EOF + } + + msgs, batch, err := readBatch(raw) + if err != nil { + return nil, err + } + if batch && t.protocolVersion >= protocolVersion20250618 { + return nil, fmt.Errorf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, t.protocolVersion) + } + + t.queue = msgs[1:] + + if batch { + var respBatch *msgBatch // track incoming requests in the batch + for _, msg := range msgs { + if req, ok := msg.(*jsonrpc.Request); ok { + if respBatch == nil { + respBatch = &msgBatch{ + unresolved: make(map[jsonrpc2.ID]int), + } + } + if _, ok := respBatch.unresolved[req.ID]; ok { + return nil, fmt.Errorf("duplicate message ID %q", req.ID) + } + respBatch.unresolved[req.ID] = len(respBatch.responses) + respBatch.responses = append(respBatch.responses, nil) + } + } + if respBatch != nil { + // The batch contains one or more incoming requests to track. + if err := t.addBatch(respBatch); err != nil { + return nil, err + } + } + } + return msgs[0], err +} + +// readBatch reads batch data, which may be either a single JSON-RPC message, +// or an array of JSON-RPC messages. +func readBatch(data []byte) (msgs []jsonrpc.Message, isBatch bool, _ error) { + // Try to read an array of messages first. + var rawBatch []json.RawMessage + if err := json.Unmarshal(data, &rawBatch); err == nil { + if len(rawBatch) == 0 { + return nil, true, fmt.Errorf("empty batch") + } + for _, raw := range rawBatch { + msg, err := jsonrpc2.DecodeMessage(raw) + if err != nil { + return nil, true, err + } + msgs = append(msgs, msg) + } + return msgs, true, nil + } + // Try again with a single message. 
+ msg, err := jsonrpc2.DecodeMessage(data) + return []jsonrpc.Message{msg}, false, err +} + +func (t *ioConn) Write(ctx context.Context, msg jsonrpc.Message) error { + // As in [ioConn.Read], enforce that Writes on a closed context are an error. + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + t.writeMu.Lock() + defer t.writeMu.Unlock() + + // Batching support: if msg is a Response, it may have completed a batch, so + // check that first. Otherwise, it is a request or notification, and we may + // want to collect it into a batch before sending, if we're configured to use + // outgoing batches. + if resp, ok := msg.(*jsonrpc.Response); ok { + if batch, ok := t.updateBatch(resp); ok { + if len(batch) > 0 { + data, err := marshalMessages(batch) + if err != nil { + return err + } + data = append(data, '\n') + _, err = t.rwc.Write(data) + return err + } + return nil + } + } else if len(t.outgoingBatch) < cap(t.outgoingBatch) { + t.outgoingBatch = append(t.outgoingBatch, msg) + if len(t.outgoingBatch) == cap(t.outgoingBatch) { + data, err := marshalMessages(t.outgoingBatch) + t.outgoingBatch = t.outgoingBatch[:0] + if err != nil { + return err + } + data = append(data, '\n') + _, err = t.rwc.Write(data) + return err + } + return nil + } + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return fmt.Errorf("marshaling message: %v", err) + } + data = append(data, '\n') // newline delimited + _, err = t.rwc.Write(data) + return err +} + +func (t *ioConn) Close() error { + t.closeOnce.Do(func() { + t.closeErr = t.rwc.Close() + close(t.closed) + }) + return t.closeErr +} + +func marshalMessages[T jsonrpc.Message](msgs []T) ([]byte, error) { + var rawMsgs []json.RawMessage + for _, msg := range msgs { + raw, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + return nil, fmt.Errorf("encoding batch message: %w", err) + } + rawMsgs = append(rawMsgs, raw) + } + return json.Marshal(rawMsgs) +} diff --git 
a/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go new file mode 100644 index 000000000..5ada466e5 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/mcp/util.go @@ -0,0 +1,43 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "crypto/rand" + "encoding/json" +) + +func assert(cond bool, msg string) { + if !cond { + panic(msg) + } +} + +// Copied from crypto/rand. +// TODO: once 1.24 is assured, just use crypto/rand. +const base32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" + +func randText() string { + // ⌈log₃₂ 2¹²⁸⌉ = 26 chars + src := make([]byte, 26) + rand.Read(src) + for i := range src { + src[i] = base32alphabet[src[i]%32] + } + return string(src) +} + +// remarshal marshals from to JSON, and then unmarshals into to, which must be +// a pointer type. +func remarshal(from, to any) error { + data, err := json.Marshal(from) + if err != nil { + return err + } + if err := json.Unmarshal(data, to); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go new file mode 100644 index 000000000..9aa0c8d7d --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go @@ -0,0 +1,187 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Authorization Server Metadata. +// See https://www.rfc-editor.org/rfc/rfc8414.html. 
+ +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "errors" + "fmt" + "net/http" +) + +// AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, +// as defined in [RFC 8414]. +// +// Not supported: +// - signed metadata +// +// Note: URL fields in this struct are validated by validateAuthServerMetaURLs to +// prevent XSS attacks. If you add a new URL field, you must also add it to that +// function. +// +// [RFC 8414]: https://tools.ietf.org/html/rfc8414) +type AuthServerMeta struct { + // GENERATED BY GEMINI 2.5. + + // Issuer is the REQUIRED URL identifying the authorization server. + Issuer string `json:"issuer"` + + // AuthorizationEndpoint is the REQUIRED URL of the server's OAuth 2.0 authorization endpoint. + AuthorizationEndpoint string `json:"authorization_endpoint"` + + // TokenEndpoint is the REQUIRED URL of the server's OAuth 2.0 token endpoint. + TokenEndpoint string `json:"token_endpoint"` + + // JWKSURI is the REQUIRED URL of the server's JSON Web Key Set [JWK] document. + JWKSURI string `json:"jwks_uri"` + + // RegistrationEndpoint is the RECOMMENDED URL of the server's OAuth 2.0 Dynamic Client Registration endpoint. + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + + // ScopesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // "scope" values that this server supports. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // ResponseTypesSupported is a REQUIRED JSON array of strings containing a list of the OAuth 2.0 + // "response_type" values that this server supports. + ResponseTypesSupported []string `json:"response_types_supported"` + + // ResponseModesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // "response_mode" values that this server supports. 
+ ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + + // GrantTypesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // grant type values that this server supports. + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + + // TokenEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing a list of + // client authentication methods supported by this token endpoint. + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` + + // TokenEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings containing + // a list of the JWS signing algorithms ("alg" values) supported by the token endpoint for + // the signature on the JWT used to authenticate the client. + TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + + // ServiceDocumentation is a RECOMMENDED URL of a page containing human-readable documentation + // for the service. + ServiceDocumentation string `json:"service_documentation,omitempty"` + + // UILocalesSupported is a RECOMMENDED JSON array of strings representing supported + // BCP47 [RFC5646] language tag values for display in the user interface. + UILocalesSupported []string `json:"ui_locales_supported,omitempty"` + + // OpPolicyURI is a RECOMMENDED URL that the server provides to the person registering + // the client to read about the server's operator policies. + OpPolicyURI string `json:"op_policy_uri,omitempty"` + + // OpTOSURI is a RECOMMENDED URL that the server provides to the person registering the + // client to read about the server's terms of service. + OpTOSURI string `json:"op_tos_uri,omitempty"` + + // RevocationEndpoint is a RECOMMENDED URL of the server's OAuth 2.0 revocation endpoint. 
+ RevocationEndpoint string `json:"revocation_endpoint,omitempty"` + + // RevocationEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing + // a list of client authentication methods supported by this revocation endpoint. + RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"` + + // RevocationEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings + // containing a list of the JWS signing algorithms ("alg" values) supported by the revocation + // endpoint for the signature on the JWT used to authenticate the client. + RevocationEndpointAuthSigningAlgValuesSupported []string `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"` + + // IntrospectionEndpoint is a RECOMMENDED URL of the server's OAuth 2.0 introspection endpoint. + IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` + + // IntrospectionEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing + // a list of client authentication methods supported by this introspection endpoint. + IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"` + + // IntrospectionEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings + // containing a list of the JWS signing algorithms ("alg" values) supported by the introspection + // endpoint for the signature on the JWT used to authenticate the client. + IntrospectionEndpointAuthSigningAlgValuesSupported []string `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"` + + // CodeChallengeMethodsSupported is a RECOMMENDED JSON array of strings containing a list of + // PKCE code challenge methods supported by this authorization server. 
+ CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` +} + +var wellKnownPaths = []string{ + "/.well-known/oauth-authorization-server", + "/.well-known/openid-configuration", +} + +// GetAuthServerMeta issues a GET request to retrieve authorization server metadata +// from an OAuth authorization server with the given issuerURL. +// +// It follows [RFC 8414]: +// - The well-known paths specified there are inserted into the URL's path, one at time. +// The first to succeed is used. +// - The Issuer field is checked against issuerURL. +// +// [RFC 8414]: https://tools.ietf.org/html/rfc8414 +func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (*AuthServerMeta, error) { + var errs []error + for _, p := range wellKnownPaths { + u, err := prependToPath(issuerURL, p) + if err != nil { + // issuerURL is bad; no point in continuing. + return nil, err + } + asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) + if err == nil { + if asm.Issuer != issuerURL { // section 3.3 + // Security violation; don't keep trying. + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) + } + + if len(asm.CodeChallengeMethodsSupported) == 0 { + return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) + } + + // Validate endpoint URLs to prevent XSS attacks (see #526). + if err := validateAuthServerMetaURLs(asm); err != nil { + return nil, err + } + + return asm, nil + } + errs = append(errs, err) + } + return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) +} + +// validateAuthServerMetaURLs validates all URL fields in AuthServerMeta +// to ensure they don't use dangerous schemes that could enable XSS attacks. 
+func validateAuthServerMetaURLs(asm *AuthServerMeta) error { + urls := []struct { + name string + value string + }{ + {"authorization_endpoint", asm.AuthorizationEndpoint}, + {"token_endpoint", asm.TokenEndpoint}, + {"jwks_uri", asm.JWKSURI}, + {"registration_endpoint", asm.RegistrationEndpoint}, + {"service_documentation", asm.ServiceDocumentation}, + {"op_policy_uri", asm.OpPolicyURI}, + {"op_tos_uri", asm.OpTOSURI}, + {"revocation_endpoint", asm.RevocationEndpoint}, + {"introspection_endpoint", asm.IntrospectionEndpoint}, + } + + for _, u := range urls { + if err := checkURLScheme(u.value); err != nil { + return fmt.Errorf("%s: %w", u.name, err) + } + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/dcr.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/dcr.go new file mode 100644 index 000000000..c64cb8cd4 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/dcr.go @@ -0,0 +1,261 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Authorization Server Metadata. +// See https://www.rfc-editor.org/rfc/rfc8414.html. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// ClientRegistrationMetadata represents the client metadata fields for the DCR POST request (RFC 7591). +// +// Note: URL fields in this struct are validated by validateClientRegistrationURLs +// to prevent XSS attacks. If you add a new URL field, you must also add it to +// that function. +type ClientRegistrationMetadata struct { + // RedirectURIs is a REQUIRED JSON array of redirection URI strings for use in + // redirect-based flows (such as the authorization code grant). 
+ RedirectURIs []string `json:"redirect_uris"` + + // TokenEndpointAuthMethod is an OPTIONAL string indicator of the requested + // authentication method for the token endpoint. + // If omitted, the default is "client_secret_basic". + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + + // GrantTypes is an OPTIONAL JSON array of OAuth 2.0 grant type strings + // that the client will restrict itself to using. + // If omitted, the default is ["authorization_code"]. + GrantTypes []string `json:"grant_types,omitempty"` + + // ResponseTypes is an OPTIONAL JSON array of OAuth 2.0 response type strings + // that the client will restrict itself to using. + // If omitted, the default is ["code"]. + ResponseTypes []string `json:"response_types,omitempty"` + + // ClientName is a RECOMMENDED human-readable name of the client to be presented + // to the end-user. + ClientName string `json:"client_name,omitempty"` + + // ClientURI is a RECOMMENDED URL of a web page providing information about the client. + ClientURI string `json:"client_uri,omitempty"` + + // LogoURI is an OPTIONAL URL of a logo for the client, which may be displayed + // to the end-user. + LogoURI string `json:"logo_uri,omitempty"` + + // Scope is an OPTIONAL string containing a space-separated list of scope values + // that the client will restrict itself to using. + Scope string `json:"scope,omitempty"` + + // Contacts is an OPTIONAL JSON array of strings representing ways to contact + // people responsible for this client (e.g., email addresses). + Contacts []string `json:"contacts,omitempty"` + + // TOSURI is an OPTIONAL URL that the client provides to the end-user + // to read about the client's terms of service. + TOSURI string `json:"tos_uri,omitempty"` + + // PolicyURI is an OPTIONAL URL that the client provides to the end-user + // to read about the client's privacy policy. 
+ PolicyURI string `json:"policy_uri,omitempty"` + + // JWKSURI is an OPTIONAL URL for the client's JSON Web Key Set [JWK] document. + // This is preferred over the 'jwks' parameter. + JWKSURI string `json:"jwks_uri,omitempty"` + + // JWKS is an OPTIONAL client's JSON Web Key Set [JWK] document, passed by value. + // This is an alternative to providing a JWKSURI. + JWKS string `json:"jwks,omitempty"` + + // SoftwareID is an OPTIONAL unique identifier string for the client software, + // constant across all instances and versions. + SoftwareID string `json:"software_id,omitempty"` + + // SoftwareVersion is an OPTIONAL version identifier string for the client software. + SoftwareVersion string `json:"software_version,omitempty"` + + // SoftwareStatement is an OPTIONAL JWT that asserts client metadata values. + // Values in the software statement take precedence over other metadata values. + SoftwareStatement string `json:"software_statement,omitempty"` +} + +// ClientRegistrationResponse represents the fields returned by the Authorization Server +// (RFC 7591, Section 3.2.1 and 3.2.2). +type ClientRegistrationResponse struct { + // ClientRegistrationMetadata contains all registered client metadata, returned by the + // server on success, potentially with modified or defaulted values. + ClientRegistrationMetadata + + // ClientID is the REQUIRED newly issued OAuth 2.0 client identifier. + ClientID string `json:"client_id"` + + // ClientSecret is an OPTIONAL client secret string. + ClientSecret string `json:"client_secret,omitempty"` + + // ClientIDIssuedAt is an OPTIONAL Unix timestamp when the ClientID was issued. + ClientIDIssuedAt time.Time `json:"client_id_issued_at,omitempty"` + + // ClientSecretExpiresAt is the REQUIRED (if client_secret is issued) Unix + // timestamp when the secret expires, or 0 if it never expires. 
+ ClientSecretExpiresAt time.Time `json:"client_secret_expires_at,omitempty"` +} + +func (r *ClientRegistrationResponse) MarshalJSON() ([]byte, error) { + type alias ClientRegistrationResponse + var clientIDIssuedAt int64 + var clientSecretExpiresAt int64 + + if !r.ClientIDIssuedAt.IsZero() { + clientIDIssuedAt = r.ClientIDIssuedAt.Unix() + } + if !r.ClientSecretExpiresAt.IsZero() { + clientSecretExpiresAt = r.ClientSecretExpiresAt.Unix() + } + + return json.Marshal(&struct { + ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + *alias + }{ + ClientIDIssuedAt: clientIDIssuedAt, + ClientSecretExpiresAt: clientSecretExpiresAt, + alias: (*alias)(r), + }) +} + +func (r *ClientRegistrationResponse) UnmarshalJSON(data []byte) error { + type alias ClientRegistrationResponse + aux := &struct { + ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + *alias + }{ + alias: (*alias)(r), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + if aux.ClientIDIssuedAt != 0 { + r.ClientIDIssuedAt = time.Unix(aux.ClientIDIssuedAt, 0) + } + if aux.ClientSecretExpiresAt != 0 { + r.ClientSecretExpiresAt = time.Unix(aux.ClientSecretExpiresAt, 0) + } + return nil +} + +// ClientRegistrationError is the error response from the Authorization Server +// for a failed registration attempt (RFC 7591, Section 3.2.2). +type ClientRegistrationError struct { + // ErrorCode is the REQUIRED error code if registration failed (RFC 7591, 3.2.2). + ErrorCode string `json:"error"` + + // ErrorDescription is an OPTIONAL human-readable error message. 
+ ErrorDescription string `json:"error_description,omitempty"` +} + +func (e *ClientRegistrationError) Error() string { + return fmt.Sprintf("registration failed: %s (%s)", e.ErrorCode, e.ErrorDescription) +} + +// RegisterClient performs Dynamic Client Registration according to RFC 7591. +func RegisterClient(ctx context.Context, registrationEndpoint string, clientMeta *ClientRegistrationMetadata, c *http.Client) (*ClientRegistrationResponse, error) { + if registrationEndpoint == "" { + return nil, fmt.Errorf("registration_endpoint is required") + } + + if c == nil { + c = http.DefaultClient + } + + payload, err := json.Marshal(clientMeta) + if err != nil { + return nil, fmt.Errorf("failed to marshal client metadata: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", registrationEndpoint, bytes.NewBuffer(payload)) + if err != nil { + return nil, fmt.Errorf("failed to create registration request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.Do(req) + if err != nil { + return nil, fmt.Errorf("registration request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read registration response body: %w", err) + } + + if resp.StatusCode == http.StatusCreated { + var regResponse ClientRegistrationResponse + if err := json.Unmarshal(body, ®Response); err != nil { + return nil, fmt.Errorf("failed to decode successful registration response: %w (%s)", err, string(body)) + } + if regResponse.ClientID == "" { + return nil, fmt.Errorf("registration response is missing required 'client_id' field") + } + // Validate URL fields to prevent XSS attacks (see #526). 
+ if err := validateClientRegistrationURLs(®Response.ClientRegistrationMetadata); err != nil { + return nil, err + } + return ®Response, nil + } + + if resp.StatusCode == http.StatusBadRequest { + var regError ClientRegistrationError + if err := json.Unmarshal(body, ®Error); err != nil { + return nil, fmt.Errorf("failed to decode registration error response: %w (%s)", err, string(body)) + } + return nil, ®Error + } + + return nil, fmt.Errorf("registration failed with status %s: %s", resp.Status, string(body)) +} + +// validateClientRegistrationURLs validates all URL fields in ClientRegistrationMetadata +// to ensure they don't use dangerous schemes that could enable XSS attacks. +func validateClientRegistrationURLs(meta *ClientRegistrationMetadata) error { + // Validate redirect URIs + for i, uri := range meta.RedirectURIs { + if err := checkURLScheme(uri); err != nil { + return fmt.Errorf("redirect_uris[%d]: %w", i, err) + } + } + + // Validate other URL fields + urls := []struct { + name string + value string + }{ + {"client_uri", meta.ClientURI}, + {"logo_uri", meta.LogoURI}, + {"tos_uri", meta.TOSURI}, + {"policy_uri", meta.PolicyURI}, + {"jwks_uri", meta.JWKSURI}, + } + + for _, u := range urls { + if err := checkURLScheme(u.value); err != nil { + return fmt.Errorf("%s: %w", u.name, err) + } + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go new file mode 100644 index 000000000..cdda695b7 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go @@ -0,0 +1,91 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package oauthex implements extensions to OAuth2. 
+ +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strings" +) + +// prependToPath prepends pre to the path of urlStr. +// When pre is the well-known path, this is the algorithm specified in both RFC 9728 +// section 3.1 and RFC 8414 section 3.1. +func prependToPath(urlStr, pre string) (string, error) { + u, err := url.Parse(urlStr) + if err != nil { + return "", err + } + p := "/" + strings.Trim(pre, "/") + if u.Path != "" { + p += "/" + } + + u.Path = p + strings.TrimLeft(u.Path, "/") + return u.String(), nil +} + +// getJSON retrieves JSON and unmarshals JSON from the URL, as specified in both +// RFC 9728 and RFC 8414. +// It will not read more than limit bytes from the body. +func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64) (*T, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + if c == nil { + c = http.DefaultClient + } + res, err := c.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // Specs require a 200. + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bad status %s", res.Status) + } + // Specs require application/json. + ct := res.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(ct) + if err != nil || mediaType != "application/json" { + return nil, fmt.Errorf("bad content type %q", ct) + } + + var t T + dec := json.NewDecoder(io.LimitReader(res.Body, limit)) + if err := dec.Decode(&t); err != nil { + return nil, err + } + return &t, nil +} + +// checkURLScheme ensures that its argument is a valid URL with a scheme +// that prevents XSS attacks. +// See #526. 
+func checkURLScheme(u string) error { + if u == "" { + return nil + } + uu, err := url.Parse(u) + if err != nil { + return err + } + scheme := strings.ToLower(uu.Scheme) + if scheme == "javascript" || scheme == "data" || scheme == "vbscript" { + return fmt.Errorf("URL has disallowed scheme %q", scheme) + } + return nil +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go new file mode 100644 index 000000000..34ed55b59 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go @@ -0,0 +1,92 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package oauthex implements extensions to OAuth2. +package oauthex + +// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, +// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. +// +// The following features are not supported: +// - additional keys (§2, last sentence) +// - human-readable metadata (§2.1) +// - signed metadata (§2.2) +type ProtectedResourceMetadata struct { + // GENERATED BY GEMINI 2.5. + + // Resource (resource) is the protected resource's resource identifier. + // Required. + Resource string `json:"resource"` + + // AuthorizationServers (authorization_servers) is an optional slice containing a list of + // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be + // used with this protected resource. + AuthorizationServers []string `json:"authorization_servers,omitempty"` + + // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set + // document. This contains public keys belonging to the protected resource, such as + // signing key(s) that the resource server uses to sign resource responses. 
+ JWKSURI string `json:"jwks_uri,omitempty"` + + // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope + // values (as defined in RFC 6749) used in authorization requests to request access + // to this protected resource. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing + // a list of the supported methods of sending an OAuth 2.0 bearer token to the + // protected resource. Defined values are "header", "body", and "query". + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + + // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms (alg values) supported by the protected + // resource for signing resource responses. + ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` + + // ResourceName (resource_name) is a human-readable name of the protected resource + // intended for display to the end user. It is RECOMMENDED that this field be included. + // This value may be internationalized. + ResourceName string `json:"resource_name,omitempty"` + + // ResourceDocumentation (resource_documentation) is an optional URL of a page containing + // human-readable information for developers using the protected resource. + // This value may be internationalized. + ResourceDocumentation string `json:"resource_documentation,omitempty"` + + // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing + // human-readable policy information on how a client can use the data provided. + // This value may be internationalized. + ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` + + // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected + // resource's human-readable terms of service. This value may be internationalized. 
+ ResourceTOSURI string `json:"resource_tos_uri,omitempty"` + + // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an + // optional boolean indicating support for mutual-TLS client certificate-bound + // access tokens (RFC 8705). Defaults to false if omitted. + TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` + + // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional + // slice of 'type' values supported by the resource server for the + // 'authorization_details' parameter (RFC 9396). + AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` + + // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms supported by the resource server for validating + // DPoP proof JWTs (RFC 9449). + DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` + + // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean + // specifying whether the protected resource always requires the use of DPoP-bound + // access tokens (RFC 9449). Defaults to false if omitted. + DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` + + // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters + // about the protected resource as claims. If present, these values take precedence + // over values conveyed in plain JSON. + // TODO:implement. + // Note that §2.2 says it's okay to ignore this. 
+ // SignedMetadata string `json:"signed_metadata,omitempty"` +} diff --git a/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go new file mode 100644 index 000000000..bb61f7974 --- /dev/null +++ b/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go @@ -0,0 +1,281 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Protected Resource Metadata. +// See https://www.rfc-editor.org/rfc/rfc9728.html. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "path" + "strings" + "unicode" + + "github.com/modelcontextprotocol/go-sdk/internal/util" +) + +const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resource" + +// GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource +// metadata from a resource server by its ID. +// The resource ID is an HTTPS URL, typically with a host:port and possibly a path. +// For example: +// +// https://example.com/server +// +// This function, following the spec (§3), inserts the default well-known path into the +// URL. In our example, the result would be +// +// https://example.com/.well-known/oauth-protected-resource/server +// +// It then retrieves the metadata at that location using the given client (or the +// default client if nil) and validates its resource field against resourceID. +func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID) + + u, err := url.Parse(resourceID) + if err != nil { + return nil, err + } + // Insert well-known URI into URL. 
+ u.Path = path.Join(defaultProtectedResourceMetadataURI, u.Path) + return getPRM(ctx, u.String(), c, resourceID) +} + +// GetProtectedResourceMetadataFromHeader retrieves protected resource metadata +// using information in the given header, using the given client (or the default +// client if nil). +// It issues a GET request to a URL discovered by parsing the WWW-Authenticate headers in the given request. +// Per RFC 9728 section 3.3, it validates that the resource field of the resulting metadata +// matches the serverURL (the URL that the client used to make the original request to the resource server). +// If there is no metadata URL in the header, it returns nil, nil. +func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL string, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader") + headers := header[http.CanonicalHeaderKey("WWW-Authenticate")] + if len(headers) == 0 { + return nil, nil + } + cs, err := ParseWWWAuthenticate(headers) + if err != nil { + return nil, err + } + metadataURL := ResourceMetadataURL(cs) + if metadataURL == "" { + return nil, nil + } + return getPRM(ctx, metadataURL, c, serverURL) +} + +// getPRM makes a GET request to the given URL, and validates the response. +// As part of the validation, it compares the returned resource field to wantResource. +func getPRM(ctx context.Context, purl string, c *http.Client, wantResource string) (*ProtectedResourceMetadata, error) { + if !strings.HasPrefix(strings.ToUpper(purl), "HTTPS://") { + return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl) + } + prm, err := getJSON[ProtectedResourceMetadata](ctx, c, purl, 1<<20) + if err != nil { + return nil, err + } + // Validate the Resource field (see RFC 9728, section 3.3). 
+ if prm.Resource != wantResource { + return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource) + } + // Validate the authorization server URLs to prevent XSS attacks (see #526). + for _, u := range prm.AuthorizationServers { + if err := checkURLScheme(u); err != nil { + return nil, err + } + } + return prm, nil +} + +// challenge represents a single authentication challenge from a WWW-Authenticate header. +// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. +type challenge struct { + // GENERATED BY GEMINI 2.5. + // + // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). + // It is case-insensitive. A parsed value will always be lower-case. + Scheme string + // Params is a map of authentication parameters. + // Keys are case-insensitive. Parsed keys are always lower-case. + Params map[string]string +} + +// ResourceMetadataURL returns a resource metadata URL from the given challenges, +// or the empty string if there is none. +func ResourceMetadataURL(cs []challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// ParseWWWAuthenticate parses a WWW-Authenticate header string. +// The header format is defined in RFC 9110, Section 11.6.1, and can contain +// one or more challenges, separated by commas. +// It returns a slice of challenges or an error if one of the headers is malformed. 
+func ParseWWWAuthenticate(headers []string) ([]challenge, error) { + // GENERATED BY GEMINI 2.5 (human-tweaked) + var challenges []challenge + for _, h := range headers { + challengeStrings, err := splitChallenges(h) + if err != nil { + return nil, err + } + for _, cs := range challengeStrings { + if strings.TrimSpace(cs) == "" { + continue + } + challenge, err := parseSingleChallenge(cs) + if err != nil { + return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) + } + challenges = append(challenges, challenge) + } + } + return challenges, nil +} + +// splitChallenges splits a header value containing one or more challenges. +// It correctly handles commas within quoted strings and distinguishes between +// commas separating auth-params and commas separating challenges. +func splitChallenges(header string) ([]string, error) { + // GENERATED BY GEMINI 2.5. + var challenges []string + inQuotes := false + start := 0 + for i, r := range header { + if r == '"' { + if i > 0 && header[i-1] != '\\' { + inQuotes = !inQuotes + } else if i == 0 { + // A challenge begins with an auth-scheme, which is a token, which cannot contain + // a quote. + return nil, errors.New(`challenge begins with '"'`) + } + } else if r == ',' && !inQuotes { + // This is a potential challenge separator. + // A new challenge does not start with `key=value`. + // We check if the part after the comma looks like a parameter. + lookahead := strings.TrimSpace(header[i+1:]) + eqPos := strings.Index(lookahead, "=") + + isParam := false + if eqPos > 0 { + // Check if the part before '=' is a single token (no spaces). + token := lookahead[:eqPos] + if strings.IndexFunc(token, unicode.IsSpace) == -1 { + isParam = true + } + } + + if !isParam { + // The part after the comma does not look like a parameter, + // so this comma separates challenges. + challenges = append(challenges, header[start:i]) + start = i + 1 + } + } + } + // Add the last (or only) challenge to the list. 
+ challenges = append(challenges, header[start:]) + return challenges, nil +} + +// parseSingleChallenge parses a string containing exactly one challenge. +// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] +func parseSingleChallenge(s string) (challenge, error) { + // GENERATED BY GEMINI 2.5, human-tweaked. + s = strings.TrimSpace(s) + if s == "" { + return challenge{}, errors.New("empty challenge string") + } + + scheme, paramsStr, found := strings.Cut(s, " ") + c := challenge{Scheme: strings.ToLower(scheme)} + if !found { + return c, nil + } + + params := make(map[string]string) + + // Parse the key-value parameters. + for paramsStr != "" { + // Find the end of the parameter key. + keyEnd := strings.Index(paramsStr, "=") + if keyEnd <= 0 { + return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + } + key := strings.TrimSpace(paramsStr[:keyEnd]) + + // Move the string past the key and the '='. + paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) + + var value string + if strings.HasPrefix(paramsStr, "\"") { + // The value is a quoted string. + paramsStr = paramsStr[1:] // Consume the opening quote. + var valBuilder strings.Builder + i := 0 + for ; i < len(paramsStr); i++ { + // Handle escaped characters. + if paramsStr[i] == '\\' && i+1 < len(paramsStr) { + valBuilder.WriteByte(paramsStr[i+1]) + i++ // We've consumed two characters. + } else if paramsStr[i] == '"' { + // End of the quoted string. + break + } else { + valBuilder.WriteByte(paramsStr[i]) + } + } + + // A quoted string must be terminated. + if i == len(paramsStr) { + return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + } + + value = valBuilder.String() + // Move the string past the value and the closing quote. + paramsStr = strings.TrimSpace(paramsStr[i+1:]) + } else { + // The value is a token. It ends at the next comma or the end of the string. 
+ commaPos := strings.Index(paramsStr, ",") + if commaPos == -1 { + value = paramsStr + paramsStr = "" + } else { + value = strings.TrimSpace(paramsStr[:commaPos]) + paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check + } + } + if value == "" { + return challenge{}, fmt.Errorf("no value for auth param %q", key) + } + + // Per RFC 9110, parameter keys are case-insensitive. + params[strings.ToLower(key)] = value + + // If there is a comma, consume it and continue to the next parameter. + if strings.HasPrefix(paramsStr, ",") { + paramsStr = strings.TrimSpace(paramsStr[1:]) + } else if paramsStr != "" { + // If there's content but it's not a new parameter, the format is wrong. + return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + } + } + + // Per RFC 9110, the scheme is case-insensitive. + return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil +} diff --git a/vendor/github.com/wk8/go-ordered-map/v2/.gitignore b/vendor/github.com/wk8/go-ordered-map/v2/.gitignore deleted file mode 100644 index 57872d0f1..000000000 --- a/vendor/github.com/wk8/go-ordered-map/v2/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/vendor/ diff --git a/vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml b/vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml deleted file mode 100644 index 2417df10d..000000000 --- a/vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml +++ /dev/null @@ -1,80 +0,0 @@ -run: - tests: false - -linters: - disable-all: true - enable: - - asciicheck - - bidichk - - bodyclose - - containedctx - - contextcheck - - decorder - - depguard - - dogsled - - dupl - - durationcheck - - errcheck - - errchkjson - # FIXME: commented out as it crashes with 1.18 for now - # - errname - - errorlint - - exportloopref - - forbidigo - - funlen - - gci - - gochecknoglobals - - gochecknoinits - - gocognit - - goconst - - gocritic - - gocyclo - - godox - - gofmt - - gofumpt - - goheader - - 
goimports - - gomnd - - gomoddirectives - - gomodguard - - goprintffuncname - - gosec - - gosimple - - govet - - grouper - - ifshort - - importas - - ineffassign - - lll - - maintidx - - makezero - - misspell - - nakedret - - nilerr - - nilnil - - noctx - - nolintlint - - paralleltest - - prealloc - - predeclared - - promlinter - # FIXME: doesn't support 1.18 yet - # - revive - - rowserrcheck - - sqlclosecheck - - staticcheck - - structcheck - - stylecheck - - tagliatelle - - tenv - - testpackage - - thelper - - tparallel - - typecheck - - unconvert - - unparam - - unused - - varcheck - - varnamelen - - wastedassign - - whitespace diff --git a/vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md b/vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md deleted file mode 100644 index f27126f84..000000000 --- a/vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md +++ /dev/null @@ -1,38 +0,0 @@ -# Changelog - -[comment]: # (Changes since last release go here) - -## 2.1.8 - Jun 27th 2023 - -* Added support for YAML serialization/deserialization - -## 2.1.7 - Apr 13th 2023 - -* Renamed test_utils.go to utils_test.go - -## 2.1.6 - Feb 15th 2023 - -* Added `GetAndMoveToBack()` and `GetAndMoveToFront()` methods - -## 2.1.5 - Dec 13th 2022 - -* Added `Value()` method - -## 2.1.4 - Dec 12th 2022 - -* Fixed a bug with UTF-8 special characters in JSON keys - -## 2.1.3 - Dec 11th 2022 - -* Added support for JSON marshalling/unmarshalling of wrapper of primitive types - -## 2.1.2 - Dec 10th 2022 -* Allowing to pass options to `New`, to give a capacity hint, or initial data -* Allowing to deserialize nested ordered maps from JSON without having to explicitly instantiate them -* Added the `AddPairs` method - -## 2.1.1 - Dec 9th 2022 -* Fixing a bug with JSON marshalling - -## 2.1.0 - Dec 7th 2022 -* Added support for JSON serialization/deserialization diff --git a/vendor/github.com/wk8/go-ordered-map/v2/LICENSE b/vendor/github.com/wk8/go-ordered-map/v2/LICENSE deleted file mode 
100644 index 8dada3eda..000000000 --- a/vendor/github.com/wk8/go-ordered-map/v2/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. 
Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative 
Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/vendor/github.com/wk8/go-ordered-map/v2/Makefile b/vendor/github.com/wk8/go-ordered-map/v2/Makefile deleted file mode 100644 index 6e0e18a1b..000000000 --- a/vendor/github.com/wk8/go-ordered-map/v2/Makefile +++ /dev/null @@ -1,32 +0,0 @@ -.DEFAULT_GOAL := all - -.PHONY: all -all: test_with_fuzz lint - -# the TEST_FLAGS env var can be set to eg run only specific tests -TEST_COMMAND = go test -v -count=1 -race -cover $(TEST_FLAGS) - -.PHONY: test -test: - $(TEST_COMMAND) - -.PHONY: bench -bench: - go test -bench=. 
- -FUZZ_TIME ?= 10s - -# see https://github.com/golang/go/issues/46312 -# and https://stackoverflow.com/a/72673487/4867444 -# if we end up having more fuzz tests -.PHONY: test_with_fuzz -test_with_fuzz: - $(TEST_COMMAND) -fuzz=FuzzRoundTripJSON -fuzztime=$(FUZZ_TIME) - $(TEST_COMMAND) -fuzz=FuzzRoundTripYAML -fuzztime=$(FUZZ_TIME) - -.PHONY: fuzz -fuzz: test_with_fuzz - -.PHONY: lint -lint: - golangci-lint run diff --git a/vendor/github.com/wk8/go-ordered-map/v2/README.md b/vendor/github.com/wk8/go-ordered-map/v2/README.md deleted file mode 100644 index b02894443..000000000 --- a/vendor/github.com/wk8/go-ordered-map/v2/README.md +++ /dev/null @@ -1,154 +0,0 @@ -[![Go Reference](https://pkg.go.dev/badge/github.com/wk8/go-ordered-map/v2.svg)](https://pkg.go.dev/github.com/wk8/go-ordered-map/v2) -[![Build Status](https://circleci.com/gh/wk8/go-ordered-map.svg?style=svg)](https://app.circleci.com/pipelines/github/wk8/go-ordered-map) - -# Golang Ordered Maps - -Same as regular maps, but also remembers the order in which keys were inserted, akin to [Python's `collections.OrderedDict`s](https://docs.python.org/3.7/library/collections.html#ordereddict-objects). - -It offers the following features: -* optimal runtime performance (all operations are constant time) -* optimal memory usage (only one copy of values, no unnecessary memory allocation) -* allows iterating from newest or oldest keys indifferently, without memory copy, allowing to `break` the iteration, and in time linear to the number of keys iterated over rather than the total length of the ordered map -* supports any generic types for both keys and values. 
If you're running go < 1.18, you can use [version 1](https://github.com/wk8/go-ordered-map/tree/v1) that takes and returns generic `interface{}`s instead of using generics -* idiomatic API, akin to that of [`container/list`](https://golang.org/pkg/container/list) -* support for JSON and YAML marshalling - -## Documentation - -[The full documentation is available on pkg.go.dev](https://pkg.go.dev/github.com/wk8/go-ordered-map/v2). - -## Installation -```bash -go get -u github.com/wk8/go-ordered-map/v2 -``` - -Or use your favorite golang vendoring tool! - -## Supported go versions - -Go >= 1.18 is required to use version >= 2 of this library, as it uses generics. - -If you're running go < 1.18, you can use [version 1](https://github.com/wk8/go-ordered-map/tree/v1) instead. - -## Example / usage - -```go -package main - -import ( - "fmt" - - "github.com/wk8/go-ordered-map/v2" -) - -func main() { - om := orderedmap.New[string, string]() - - om.Set("foo", "bar") - om.Set("bar", "baz") - om.Set("coucou", "toi") - - fmt.Println(om.Get("foo")) // => "bar", true - fmt.Println(om.Get("i dont exist")) // => "", false - - // iterating pairs from oldest to newest: - for pair := om.Oldest(); pair != nil; pair = pair.Next() { - fmt.Printf("%s => %s\n", pair.Key, pair.Value) - } // prints: - // foo => bar - // bar => baz - // coucou => toi - - // iterating over the 2 newest pairs: - i := 0 - for pair := om.Newest(); pair != nil; pair = pair.Prev() { - fmt.Printf("%s => %s\n", pair.Key, pair.Value) - i++ - if i >= 2 { - break - } - } // prints: - // coucou => toi - // bar => baz -} -``` - -An `OrderedMap`'s keys must implement `comparable`, and its values can be anything, for example: - -```go -type myStruct struct { - payload string -} - -func main() { - om := orderedmap.New[int, *myStruct]() - - om.Set(12, &myStruct{"foo"}) - om.Set(1, &myStruct{"bar"}) - - value, present := om.Get(12) - if !present { - panic("should be there!") - } - fmt.Println(value.payload) // => foo - - for 
pair := om.Oldest(); pair != nil; pair = pair.Next() { - fmt.Printf("%d => %s\n", pair.Key, pair.Value.payload) - } // prints: - // 12 => foo - // 1 => bar -} -``` - -Also worth noting that you can provision ordered maps with a capacity hint, as you would do by passing an optional hint to `make(map[K]V, capacity`): -```go -om := orderedmap.New[int, *myStruct](28) -``` - -You can also pass in some initial data to store in the map: -```go -om := orderedmap.New[int, string](orderedmap.WithInitialData[int, string]( - orderedmap.Pair[int, string]{ - Key: 12, - Value: "foo", - }, - orderedmap.Pair[int, string]{ - Key: 28, - Value: "bar", - }, -)) -``` - -`OrderedMap`s also support JSON serialization/deserialization, and preserves order: - -```go -// serialization -data, err := json.Marshal(om) -... - -// deserialization -om := orderedmap.New[string, string]() // or orderedmap.New[int, any](), or any type you expect -err := json.Unmarshal(data, &om) -... -``` - -Similarly, it also supports YAML serialization/deserialization using the yaml.v3 package, which also preserves order: - -```go -// serialization -data, err := yaml.Marshal(om) -... - -// deserialization -om := orderedmap.New[string, string]() // or orderedmap.New[int, any](), or any type you expect -err := yaml.Unmarshal(data, &om) -... 
-``` - -## Alternatives - -There are several other ordered map golang implementations out there, but I believe that at the time of writing none of them offer the same functionality as this library; more specifically: -* [iancoleman/orderedmap](https://github.com/iancoleman/orderedmap) only accepts `string` keys, its `Delete` operations are linear -* [cevaris/ordered_map](https://github.com/cevaris/ordered_map) uses a channel for iterations, and leaks goroutines if the iteration is interrupted before fully traversing the map -* [mantyr/iterator](https://github.com/mantyr/iterator) also uses a channel for iterations, and its `Delete` operations are linear -* [samdolan/go-ordered-map](https://github.com/samdolan/go-ordered-map) adds unnecessary locking (users should add their own locking instead if they need it), its `Delete` and `Get` operations are linear, iterations trigger a linear memory allocation diff --git a/vendor/github.com/wk8/go-ordered-map/v2/json.go b/vendor/github.com/wk8/go-ordered-map/v2/json.go deleted file mode 100644 index a545b536b..000000000 --- a/vendor/github.com/wk8/go-ordered-map/v2/json.go +++ /dev/null @@ -1,182 +0,0 @@ -package orderedmap - -import ( - "bytes" - "encoding" - "encoding/json" - "fmt" - "reflect" - "unicode/utf8" - - "github.com/buger/jsonparser" - "github.com/mailru/easyjson/jwriter" -) - -var ( - _ json.Marshaler = &OrderedMap[int, any]{} - _ json.Unmarshaler = &OrderedMap[int, any]{} -) - -// MarshalJSON implements the json.Marshaler interface. 
-func (om *OrderedMap[K, V]) MarshalJSON() ([]byte, error) { //nolint:funlen - if om == nil || om.list == nil { - return []byte("null"), nil - } - - writer := jwriter.Writer{} - writer.RawByte('{') - - for pair, firstIteration := om.Oldest(), true; pair != nil; pair = pair.Next() { - if firstIteration { - firstIteration = false - } else { - writer.RawByte(',') - } - - switch key := any(pair.Key).(type) { - case string: - writer.String(key) - case encoding.TextMarshaler: - writer.RawByte('"') - writer.Raw(key.MarshalText()) - writer.RawByte('"') - case int: - writer.IntStr(key) - case int8: - writer.Int8Str(key) - case int16: - writer.Int16Str(key) - case int32: - writer.Int32Str(key) - case int64: - writer.Int64Str(key) - case uint: - writer.UintStr(key) - case uint8: - writer.Uint8Str(key) - case uint16: - writer.Uint16Str(key) - case uint32: - writer.Uint32Str(key) - case uint64: - writer.Uint64Str(key) - default: - - // this switch takes care of wrapper types around primitive types, such as - // type myType string - switch keyValue := reflect.ValueOf(key); keyValue.Type().Kind() { - case reflect.String: - writer.String(keyValue.String()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - writer.Int64Str(keyValue.Int()) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - writer.Uint64Str(keyValue.Uint()) - default: - return nil, fmt.Errorf("unsupported key type: %T", key) - } - } - - writer.RawByte(':') - // the error is checked at the end of the function - writer.Raw(json.Marshal(pair.Value)) //nolint:errchkjson - } - - writer.RawByte('}') - - return dumpWriter(&writer) -} - -func dumpWriter(writer *jwriter.Writer) ([]byte, error) { - if writer.Error != nil { - return nil, writer.Error - } - - var buf bytes.Buffer - buf.Grow(writer.Size()) - if _, err := writer.DumpTo(&buf); err != nil { - return nil, err - } - - return buf.Bytes(), nil -} - -// UnmarshalJSON implements the json.Unmarshaler 
interface. -func (om *OrderedMap[K, V]) UnmarshalJSON(data []byte) error { - if om.list == nil { - om.initialize(0) - } - - return jsonparser.ObjectEach( - data, - func(keyData []byte, valueData []byte, dataType jsonparser.ValueType, offset int) error { - if dataType == jsonparser.String { - // jsonparser removes the enclosing quotes; we need to restore them to make a valid JSON - valueData = data[offset-len(valueData)-2 : offset] - } - - var key K - var value V - - switch typedKey := any(&key).(type) { - case *string: - s, err := decodeUTF8(keyData) - if err != nil { - return err - } - *typedKey = s - case encoding.TextUnmarshaler: - if err := typedKey.UnmarshalText(keyData); err != nil { - return err - } - case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64: - if err := json.Unmarshal(keyData, typedKey); err != nil { - return err - } - default: - // this switch takes care of wrapper types around primitive types, such as - // type myType string - switch reflect.TypeOf(key).Kind() { - case reflect.String: - s, err := decodeUTF8(keyData) - if err != nil { - return err - } - - convertedKeyData := reflect.ValueOf(s).Convert(reflect.TypeOf(key)) - reflect.ValueOf(&key).Elem().Set(convertedKeyData) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if err := json.Unmarshal(keyData, &key); err != nil { - return err - } - default: - return fmt.Errorf("unsupported key type: %T", key) - } - } - - if err := json.Unmarshal(valueData, &value); err != nil { - return err - } - - om.Set(key, value) - return nil - }) -} - -func decodeUTF8(input []byte) (string, error) { - remaining, offset := input, 0 - runes := make([]rune, 0, len(remaining)) - - for len(remaining) > 0 { - r, size := utf8.DecodeRune(remaining) - if r == utf8.RuneError && size <= 1 { - return "", fmt.Errorf("not a valid UTF-8 string (at position %d): %s", offset, 
string(input)) - } - - runes = append(runes, r) - remaining = remaining[size:] - offset += size - } - - return string(runes), nil -} diff --git a/vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go b/vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go deleted file mode 100644 index 064714191..000000000 --- a/vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go +++ /dev/null @@ -1,296 +0,0 @@ -// Package orderedmap implements an ordered map, i.e. a map that also keeps track of -// the order in which keys were inserted. -// -// All operations are constant-time. -// -// Github repo: https://github.com/wk8/go-ordered-map -// -package orderedmap - -import ( - "fmt" - - list "github.com/bahlo/generic-list-go" -) - -type Pair[K comparable, V any] struct { - Key K - Value V - - element *list.Element[*Pair[K, V]] -} - -type OrderedMap[K comparable, V any] struct { - pairs map[K]*Pair[K, V] - list *list.List[*Pair[K, V]] -} - -type initConfig[K comparable, V any] struct { - capacity int - initialData []Pair[K, V] -} - -type InitOption[K comparable, V any] func(config *initConfig[K, V]) - -// WithCapacity allows giving a capacity hint for the map, akin to the standard make(map[K]V, capacity). -func WithCapacity[K comparable, V any](capacity int) InitOption[K, V] { - return func(c *initConfig[K, V]) { - c.capacity = capacity - } -} - -// WithInitialData allows passing in initial data for the map. -func WithInitialData[K comparable, V any](initialData ...Pair[K, V]) InitOption[K, V] { - return func(c *initConfig[K, V]) { - c.initialData = initialData - if c.capacity < len(initialData) { - c.capacity = len(initialData) - } - } -} - -// New creates a new OrderedMap. -// options can either be one or several InitOption[K, V], or a single integer, -// which is then interpreted as a capacity hint, à la make(map[K]V, capacity). 
-func New[K comparable, V any](options ...any) *OrderedMap[K, V] { //nolint:varnamelen - orderedMap := &OrderedMap[K, V]{} - - var config initConfig[K, V] - for _, untypedOption := range options { - switch option := untypedOption.(type) { - case int: - if len(options) != 1 { - invalidOption() - } - config.capacity = option - - case InitOption[K, V]: - option(&config) - - default: - invalidOption() - } - } - - orderedMap.initialize(config.capacity) - orderedMap.AddPairs(config.initialData...) - - return orderedMap -} - -const invalidOptionMessage = `when using orderedmap.New[K,V]() with options, either provide one or several InitOption[K, V]; or a single integer which is then interpreted as a capacity hint, à la make(map[K]V, capacity).` //nolint:lll - -func invalidOption() { panic(invalidOptionMessage) } - -func (om *OrderedMap[K, V]) initialize(capacity int) { - om.pairs = make(map[K]*Pair[K, V], capacity) - om.list = list.New[*Pair[K, V]]() -} - -// Get looks for the given key, and returns the value associated with it, -// or V's nil value if not found. The boolean it returns says whether the key is present in the map. -func (om *OrderedMap[K, V]) Get(key K) (val V, present bool) { - if pair, present := om.pairs[key]; present { - return pair.Value, true - } - - return -} - -// Load is an alias for Get, mostly to present an API similar to `sync.Map`'s. -func (om *OrderedMap[K, V]) Load(key K) (V, bool) { - return om.Get(key) -} - -// Value returns the value associated with the given key or the zero value. -func (om *OrderedMap[K, V]) Value(key K) (val V) { - if pair, present := om.pairs[key]; present { - val = pair.Value - } - return -} - -// GetPair looks for the given key, and returns the pair associated with it, -// or nil if not found. The Pair struct can then be used to iterate over the ordered map -// from that point, either forward or backward. 
-func (om *OrderedMap[K, V]) GetPair(key K) *Pair[K, V] { - return om.pairs[key] -} - -// Set sets the key-value pair, and returns what `Get` would have returned -// on that key prior to the call to `Set`. -func (om *OrderedMap[K, V]) Set(key K, value V) (val V, present bool) { - if pair, present := om.pairs[key]; present { - oldValue := pair.Value - pair.Value = value - return oldValue, true - } - - pair := &Pair[K, V]{ - Key: key, - Value: value, - } - pair.element = om.list.PushBack(pair) - om.pairs[key] = pair - - return -} - -// AddPairs allows setting multiple pairs at a time. It's equivalent to calling -// Set on each pair sequentially. -func (om *OrderedMap[K, V]) AddPairs(pairs ...Pair[K, V]) { - for _, pair := range pairs { - om.Set(pair.Key, pair.Value) - } -} - -// Store is an alias for Set, mostly to present an API similar to `sync.Map`'s. -func (om *OrderedMap[K, V]) Store(key K, value V) (V, bool) { - return om.Set(key, value) -} - -// Delete removes the key-value pair, and returns what `Get` would have returned -// on that key prior to the call to `Delete`. -func (om *OrderedMap[K, V]) Delete(key K) (val V, present bool) { - if pair, present := om.pairs[key]; present { - om.list.Remove(pair.element) - delete(om.pairs, key) - return pair.Value, true - } - return -} - -// Len returns the length of the ordered map. -func (om *OrderedMap[K, V]) Len() int { - if om == nil || om.pairs == nil { - return 0 - } - return len(om.pairs) -} - -// Oldest returns a pointer to the oldest pair. It's meant to be used to iterate on the ordered map's -// pairs from the oldest to the newest, e.g.: -// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) } -func (om *OrderedMap[K, V]) Oldest() *Pair[K, V] { - if om == nil || om.list == nil { - return nil - } - return listElementToPair(om.list.Front()) -} - -// Newest returns a pointer to the newest pair. 
It's meant to be used to iterate on the ordered map's -// pairs from the newest to the oldest, e.g.: -// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) } -func (om *OrderedMap[K, V]) Newest() *Pair[K, V] { - if om == nil || om.list == nil { - return nil - } - return listElementToPair(om.list.Back()) -} - -// Next returns a pointer to the next pair. -func (p *Pair[K, V]) Next() *Pair[K, V] { - return listElementToPair(p.element.Next()) -} - -// Prev returns a pointer to the previous pair. -func (p *Pair[K, V]) Prev() *Pair[K, V] { - return listElementToPair(p.element.Prev()) -} - -func listElementToPair[K comparable, V any](element *list.Element[*Pair[K, V]]) *Pair[K, V] { - if element == nil { - return nil - } - return element.Value -} - -// KeyNotFoundError may be returned by functions in this package when they're called with keys that are not present -// in the map. -type KeyNotFoundError[K comparable] struct { - MissingKey K -} - -func (e *KeyNotFoundError[K]) Error() string { - return fmt.Sprintf("missing key: %v", e.MissingKey) -} - -// MoveAfter moves the value associated with key to its new position after the one associated with markKey. -// Returns an error iff key or markKey are not present in the map. If an error is returned, -// it will be a KeyNotFoundError. -func (om *OrderedMap[K, V]) MoveAfter(key, markKey K) error { - elements, err := om.getElements(key, markKey) - if err != nil { - return err - } - om.list.MoveAfter(elements[0], elements[1]) - return nil -} - -// MoveBefore moves the value associated with key to its new position before the one associated with markKey. -// Returns an error iff key or markKey are not present in the map. If an error is returned, -// it will be a KeyNotFoundError. 
-func (om *OrderedMap[K, V]) MoveBefore(key, markKey K) error { - elements, err := om.getElements(key, markKey) - if err != nil { - return err - } - om.list.MoveBefore(elements[0], elements[1]) - return nil -} - -func (om *OrderedMap[K, V]) getElements(keys ...K) ([]*list.Element[*Pair[K, V]], error) { - elements := make([]*list.Element[*Pair[K, V]], len(keys)) - for i, k := range keys { - pair, present := om.pairs[k] - if !present { - return nil, &KeyNotFoundError[K]{k} - } - elements[i] = pair.element - } - return elements, nil -} - -// MoveToBack moves the value associated with key to the back of the ordered map, -// i.e. makes it the newest pair in the map. -// Returns an error iff key is not present in the map. If an error is returned, -// it will be a KeyNotFoundError. -func (om *OrderedMap[K, V]) MoveToBack(key K) error { - _, err := om.GetAndMoveToBack(key) - return err -} - -// MoveToFront moves the value associated with key to the front of the ordered map, -// i.e. makes it the oldest pair in the map. -// Returns an error iff key is not present in the map. If an error is returned, -// it will be a KeyNotFoundError. -func (om *OrderedMap[K, V]) MoveToFront(key K) error { - _, err := om.GetAndMoveToFront(key) - return err -} - -// GetAndMoveToBack combines Get and MoveToBack in the same call. If an error is returned, -// it will be a KeyNotFoundError. -func (om *OrderedMap[K, V]) GetAndMoveToBack(key K) (val V, err error) { - if pair, present := om.pairs[key]; present { - val = pair.Value - om.list.MoveToBack(pair.element) - } else { - err = &KeyNotFoundError[K]{key} - } - - return -} - -// GetAndMoveToFront combines Get and MoveToFront in the same call. If an error is returned, -// it will be a KeyNotFoundError. 
-func (om *OrderedMap[K, V]) GetAndMoveToFront(key K) (val V, err error) { - if pair, present := om.pairs[key]; present { - val = pair.Value - om.list.MoveToFront(pair.element) - } else { - err = &KeyNotFoundError[K]{key} - } - - return -} diff --git a/vendor/github.com/wk8/go-ordered-map/v2/yaml.go b/vendor/github.com/wk8/go-ordered-map/v2/yaml.go deleted file mode 100644 index 602247128..000000000 --- a/vendor/github.com/wk8/go-ordered-map/v2/yaml.go +++ /dev/null @@ -1,71 +0,0 @@ -package orderedmap - -import ( - "fmt" - - "gopkg.in/yaml.v3" -) - -var ( - _ yaml.Marshaler = &OrderedMap[int, any]{} - _ yaml.Unmarshaler = &OrderedMap[int, any]{} -) - -// MarshalYAML implements the yaml.Marshaler interface. -func (om *OrderedMap[K, V]) MarshalYAML() (interface{}, error) { - if om == nil { - return []byte("null"), nil - } - - node := yaml.Node{ - Kind: yaml.MappingNode, - } - - for pair := om.Oldest(); pair != nil; pair = pair.Next() { - key, value := pair.Key, pair.Value - - keyNode := &yaml.Node{} - - // serialize key to yaml, then deserialize it back into the node - // this is a hack to get the correct tag for the key - if err := keyNode.Encode(key); err != nil { - return nil, err - } - - valueNode := &yaml.Node{} - if err := valueNode.Encode(value); err != nil { - return nil, err - } - - node.Content = append(node.Content, keyNode, valueNode) - } - - return &node, nil -} - -// UnmarshalYAML implements the yaml.Unmarshaler interface. 
-func (om *OrderedMap[K, V]) UnmarshalYAML(value *yaml.Node) error { - if value.Kind != yaml.MappingNode { - return fmt.Errorf("pipeline must contain YAML mapping, has %v", value.Kind) - } - - if om.list == nil { - om.initialize(0) - } - - for index := 0; index < len(value.Content); index += 2 { - var key K - var val V - - if err := value.Content[index].Decode(&key); err != nil { - return err - } - if err := value.Content[index+1].Decode(&val); err != nil { - return err - } - - om.Set(key, val) - } - - return nil -} diff --git a/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go b/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go index 51121a3d5..e86346e8b 100644 --- a/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go +++ b/vendor/golang.org/x/oauth2/clientcredentials/clientcredentials.go @@ -55,7 +55,7 @@ type Config struct { // Token uses client credentials to retrieve a token. // -// The provided context optionally controls which HTTP client is used. See the oauth2.HTTPClient variable. +// The provided context optionally controls which HTTP client is used. See the [oauth2.HTTPClient] variable. func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { return c.TokenSource(ctx).Token() } @@ -64,18 +64,18 @@ func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { // The token will auto-refresh as necessary. // // The provided context optionally controls which HTTP client -// is returned. See the oauth2.HTTPClient variable. +// is returned. See the [oauth2.HTTPClient] variable. // -// The returned Client and its Transport should not be modified. +// The returned [http.Client] and its Transport should not be modified. 
func (c *Config) Client(ctx context.Context) *http.Client { return oauth2.NewClient(ctx, c.TokenSource(ctx)) } -// TokenSource returns a TokenSource that returns t until t expires, +// TokenSource returns a [oauth2.TokenSource] that returns t until t expires, // automatically refreshing it as necessary using the provided context and the // client ID and client secret. // -// Most users will use Config.Client instead. +// Most users will use [Config.Client] instead. func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { source := &tokenSource{ ctx: ctx, diff --git a/vendor/golang.org/x/oauth2/internal/doc.go b/vendor/golang.org/x/oauth2/internal/doc.go index 03265e888..8c7c475f2 100644 --- a/vendor/golang.org/x/oauth2/internal/doc.go +++ b/vendor/golang.org/x/oauth2/internal/doc.go @@ -2,5 +2,5 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package internal contains support packages for oauth2 package. +// Package internal contains support packages for [golang.org/x/oauth2]. package internal diff --git a/vendor/golang.org/x/oauth2/internal/oauth2.go b/vendor/golang.org/x/oauth2/internal/oauth2.go index 14989beaf..71ea6ad1f 100644 --- a/vendor/golang.org/x/oauth2/internal/oauth2.go +++ b/vendor/golang.org/x/oauth2/internal/oauth2.go @@ -13,7 +13,7 @@ import ( ) // ParseKey converts the binary contents of a private key file -// to an *rsa.PrivateKey. It detects whether the private key is in a +// to an [*rsa.PrivateKey]. It detects whether the private key is in a // PEM container or not. If so, it extracts the private key // from PEM container before conversion. It only supports PEM // containers with no passphrase. 
diff --git a/vendor/golang.org/x/oauth2/internal/token.go b/vendor/golang.org/x/oauth2/internal/token.go index e83ddeef0..8389f2462 100644 --- a/vendor/golang.org/x/oauth2/internal/token.go +++ b/vendor/golang.org/x/oauth2/internal/token.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "math" "mime" "net/http" @@ -26,9 +25,9 @@ import ( // the requests to access protected resources on the OAuth 2.0 // provider's backend. // -// This type is a mirror of oauth2.Token and exists to break +// This type is a mirror of [golang.org/x/oauth2.Token] and exists to break // an otherwise-circular dependency. Other internal packages -// should convert this Token into an oauth2.Token before use. +// should convert this Token into an [golang.org/x/oauth2.Token] before use. type Token struct { // AccessToken is the token that authorizes and authenticates // the requests. @@ -50,9 +49,16 @@ type Token struct { // mechanisms for that TokenSource will not be used. Expiry time.Time + // ExpiresIn is the OAuth2 wire format "expires_in" field, + // which specifies how many seconds later the token expires, + // relative to an unknown time base approximately around "now". + // It is the application's responsibility to populate + // `Expiry` from `ExpiresIn` when required. + ExpiresIn int64 `json:"expires_in,omitempty"` + // Raw optionally contains extra metadata from the server // when updating a token. - Raw interface{} + Raw any } // tokenJSON is the struct representing the HTTP response from OAuth2 @@ -99,14 +105,6 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error { return nil } -// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. -// -// Deprecated: this function no longer does anything. Caller code that -// wants to avoid potential extra HTTP requests made during -// auto-probing of the provider's auth style should set -// Endpoint.AuthStyle. 
-func RegisterBrokenAuthHeaderProvider(tokenURL string) {} - // AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type. type AuthStyle int @@ -143,6 +141,11 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache { return c } +type authStyleCacheKey struct { + url string + clientID string +} + // AuthStyleCache is the set of tokenURLs we've successfully used via // RetrieveToken and which style auth we ended up using. // It's called a cache, but it doesn't (yet?) shrink. It's expected that @@ -150,26 +153,26 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache { // small. type AuthStyleCache struct { mu sync.Mutex - m map[string]AuthStyle // keyed by tokenURL + m map[authStyleCacheKey]AuthStyle } // lookupAuthStyle reports which auth style we last used with tokenURL // when calling RetrieveToken and whether we have ever done so. -func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { +func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) { c.mu.Lock() defer c.mu.Unlock() - style, ok = c.m[tokenURL] + style, ok = c.m[authStyleCacheKey{tokenURL, clientID}] return } // setAuthStyle adds an entry to authStyleCache, documented above. 
-func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) { +func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) { c.mu.Lock() defer c.mu.Unlock() if c.m == nil { - c.m = make(map[string]AuthStyle) + c.m = make(map[authStyleCacheKey]AuthStyle) } - c.m[tokenURL] = v + c.m[authStyleCacheKey{tokenURL, clientID}] = v } // newTokenRequest returns a new *http.Request to retrieve a new token @@ -210,9 +213,9 @@ func cloneURLValues(v url.Values) url.Values { } func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) { - needsAuthStyleProbe := authStyle == 0 + needsAuthStyleProbe := authStyle == AuthStyleUnknown if needsAuthStyleProbe { - if style, ok := styleCache.lookupAuthStyle(tokenURL); ok { + if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok { authStyle = style needsAuthStyleProbe = false } else { @@ -242,7 +245,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, token, err = doTokenRoundTrip(ctx, req) } if needsAuthStyleProbe && err == nil { - styleCache.setAuthStyle(tokenURL, authStyle) + styleCache.setAuthStyle(tokenURL, clientID, authStyle) } // Don't overwrite `RefreshToken` with an empty value // if this was a token refreshing request. 
@@ -257,7 +260,7 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { if err != nil { return nil, err } - body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) r.Body.Close() if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) @@ -312,7 +315,8 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { TokenType: tj.TokenType, RefreshToken: tj.RefreshToken, Expiry: tj.expiry(), - Raw: make(map[string]interface{}), + ExpiresIn: int64(tj.ExpiresIn), + Raw: make(map[string]any), } json.Unmarshal(body, &token.Raw) // no error checks for optional fields } diff --git a/vendor/golang.org/x/oauth2/internal/transport.go b/vendor/golang.org/x/oauth2/internal/transport.go index b9db01ddf..afc0aeb27 100644 --- a/vendor/golang.org/x/oauth2/internal/transport.go +++ b/vendor/golang.org/x/oauth2/internal/transport.go @@ -9,8 +9,8 @@ import ( "net/http" ) -// HTTPClient is the context key to use with golang.org/x/net/context's -// WithValue function to associate an *http.Client value with a context. +// HTTPClient is the context key to use with [context.WithValue] +// to associate an [*http.Client] value with a context. var HTTPClient ContextKey // ContextKey is just an empty struct. It exists so HTTPClient can be diff --git a/vendor/golang.org/x/oauth2/jws/jws.go b/vendor/golang.org/x/oauth2/jws/jws.go index 95015648b..9bc484406 100644 --- a/vendor/golang.org/x/oauth2/jws/jws.go +++ b/vendor/golang.org/x/oauth2/jws/jws.go @@ -4,7 +4,7 @@ // Package jws provides a partial implementation // of JSON Web Signature encoding and decoding. -// It exists to support the golang.org/x/oauth2 package. +// It exists to support the [golang.org/x/oauth2] package. // // See RFC 7515. 
// @@ -48,7 +48,7 @@ type ClaimSet struct { // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 // This array is marshalled using custom code (see (c *ClaimSet) encode()). - PrivateClaims map[string]interface{} `json:"-"` + PrivateClaims map[string]any `json:"-"` } func (c *ClaimSet) encode() (string, error) { @@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) { // Decode decodes a claim set from a JWS payload. func Decode(payload string) (*ClaimSet, error) { // decode returned id token to get expiry - s := strings.Split(payload, ".") - if len(s) < 2 { + _, claims, _, ok := parseToken(payload) + if !ok { // TODO(jbd): Provide more context about the error. return nil, errors.New("jws: invalid token received") } - decoded, err := base64.RawURLEncoding.DecodeString(s[1]) + decoded, err := base64.RawURLEncoding.DecodeString(claims) if err != nil { return nil, err } @@ -152,7 +152,7 @@ func EncodeWithSigner(header *Header, c *ClaimSet, sg Signer) (string, error) { } // Encode encodes a signed JWS with provided header and claim set. -// This invokes EncodeWithSigner using crypto/rsa.SignPKCS1v15 with the given RSA private key. +// This invokes [EncodeWithSigner] using [crypto/rsa.SignPKCS1v15] with the given RSA private key. func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) { sg := func(data []byte) (sig []byte, err error) { h := sha256.New() @@ -165,18 +165,34 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) { // Verify tests whether the provided JWT token's signature was produced by the private key // associated with the supplied public key. func Verify(token string, key *rsa.PublicKey) error { - parts := strings.Split(token, ".") - if len(parts) != 3 { + header, claims, sig, ok := parseToken(token) + if !ok { return errors.New("jws: invalid token received, token must have 3 parts") } - - signedContent := parts[0] + "." 
+ parts[1] - signatureString, err := base64.RawURLEncoding.DecodeString(parts[2]) + signatureString, err := base64.RawURLEncoding.DecodeString(sig) if err != nil { return err } h := sha256.New() - h.Write([]byte(signedContent)) + h.Write([]byte(header + tokenDelim + claims)) return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString) } + +func parseToken(s string) (header, claims, sig string, ok bool) { + header, s, ok = strings.Cut(s, tokenDelim) + if !ok { // no period found + return "", "", "", false + } + claims, s, ok = strings.Cut(s, tokenDelim) + if !ok { // only one period found + return "", "", "", false + } + sig, _, ok = strings.Cut(s, tokenDelim) + if ok { // three periods found + return "", "", "", false + } + return header, claims, sig, true +} + +const tokenDelim = "." diff --git a/vendor/golang.org/x/oauth2/jwt/jwt.go b/vendor/golang.org/x/oauth2/jwt/jwt.go index b2bf18298..38a92daca 100644 --- a/vendor/golang.org/x/oauth2/jwt/jwt.go +++ b/vendor/golang.org/x/oauth2/jwt/jwt.go @@ -13,7 +13,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/url" "strings" @@ -69,7 +68,7 @@ type Config struct { // PrivateClaims optionally specifies custom private claims in the JWT. // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 - PrivateClaims map[string]interface{} + PrivateClaims map[string]any // UseIDToken optionally specifies whether ID token should be used instead // of access token when the server returns both. @@ -136,7 +135,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } defer resp.Body.Close() - body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } @@ -148,10 +147,8 @@ func (js jwtSource) Token() (*oauth2.Token, error) { } // tokenRes is the JSON response body. 
var tokenRes struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - IDToken string `json:"id_token"` - ExpiresIn int64 `json:"expires_in"` // relative seconds from now + oauth2.Token + IDToken string `json:"id_token"` } if err := json.Unmarshal(body, &tokenRes); err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) @@ -160,7 +157,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) { AccessToken: tokenRes.AccessToken, TokenType: tokenRes.TokenType, } - raw := make(map[string]interface{}) + raw := make(map[string]any) json.Unmarshal(body, &raw) // no error checks for optional fields token = token.WithExtra(raw) diff --git a/vendor/golang.org/x/oauth2/oauth2.go b/vendor/golang.org/x/oauth2/oauth2.go index 74f052aa9..de34feb84 100644 --- a/vendor/golang.org/x/oauth2/oauth2.go +++ b/vendor/golang.org/x/oauth2/oauth2.go @@ -22,9 +22,9 @@ import ( ) // NoContext is the default context you should supply if not using -// your own context.Context (see https://golang.org/x/net/context). +// your own [context.Context]. // -// Deprecated: Use context.Background() or context.TODO() instead. +// Deprecated: Use [context.Background] or [context.TODO] instead. var NoContext = context.TODO() // RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. @@ -37,8 +37,8 @@ func RegisterBrokenAuthHeaderProvider(tokenURL string) {} // Config describes a typical 3-legged OAuth2 flow, with both the // client application information and the server's endpoint URLs. -// For the client credentials 2-legged OAuth2 flow, see the clientcredentials -// package (https://golang.org/x/oauth2/clientcredentials). +// For the client credentials 2-legged OAuth2 flow, see the +// [golang.org/x/oauth2/clientcredentials] package. type Config struct { // ClientID is the application's ID. ClientID string @@ -46,7 +46,7 @@ type Config struct { // ClientSecret is the application's secret. 
ClientSecret string - // Endpoint contains the resource server's token endpoint + // Endpoint contains the authorization server's token endpoint // URLs. These are constants specific to each server and are // often available via site-specific packages, such as // google.Endpoint or github.Endpoint. @@ -135,7 +135,7 @@ type setParam struct{ k, v string } func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) } -// SetAuthURLParam builds an AuthCodeOption which passes key/value parameters +// SetAuthURLParam builds an [AuthCodeOption] which passes key/value parameters // to a provider's authorization endpoint. func SetAuthURLParam(key, value string) AuthCodeOption { return setParam{key, value} @@ -148,8 +148,8 @@ func SetAuthURLParam(key, value string) AuthCodeOption { // request and callback. The authorization server includes this value when // redirecting the user agent back to the client. // -// Opts may include AccessTypeOnline or AccessTypeOffline, as well -// as ApprovalForce. +// Opts may include [AccessTypeOnline] or [AccessTypeOffline], as well +// as [ApprovalForce]. // // To protect against CSRF attacks, opts should include a PKCE challenge // (S256ChallengeOption). Not all servers support PKCE. An alternative is to @@ -194,7 +194,7 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { // and when other authorization grant types are not available." // See https://tools.ietf.org/html/rfc6749#section-4.3 for more info. // -// The provided context optionally controls which HTTP client is used. See the HTTPClient variable. +// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable. 
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) { v := url.Values{ "grant_type": {"password"}, @@ -212,10 +212,10 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor // It is used after a resource provider redirects the user back // to the Redirect URI (the URL obtained from AuthCodeURL). // -// The provided context optionally controls which HTTP client is used. See the HTTPClient variable. +// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable. // -// The code will be in the *http.Request.FormValue("code"). Before -// calling Exchange, be sure to validate FormValue("state") if you are +// The code will be in the [http.Request.FormValue]("code"). Before +// calling Exchange, be sure to validate [http.Request.FormValue]("state") if you are // using it to protect against CSRF attacks. // // If using PKCE to protect against CSRF attacks, opts should include a @@ -242,10 +242,10 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client { return NewClient(ctx, c.TokenSource(ctx, t)) } -// TokenSource returns a TokenSource that returns t until t expires, +// TokenSource returns a [TokenSource] that returns t until t expires, // automatically refreshing it as necessary using the provided context. // -// Most users will use Config.Client instead. +// Most users will use [Config.Client] instead. func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { tkr := &tokenRefresher{ ctx: ctx, @@ -260,7 +260,7 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { } } -// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token" +// tokenRefresher is a TokenSource that makes "grant_type=refresh_token" // HTTP requests to renew a token using a RefreshToken. 
type tokenRefresher struct { ctx context.Context // used to get HTTP requests @@ -288,7 +288,7 @@ func (tf *tokenRefresher) Token() (*Token, error) { if tf.refreshToken != tk.RefreshToken { tf.refreshToken = tk.RefreshToken } - return tk, err + return tk, nil } // reuseTokenSource is a TokenSource that holds a single token in memory @@ -305,8 +305,7 @@ type reuseTokenSource struct { } // Token returns the current token if it's still valid, else will -// refresh the current token (using r.Context for HTTP client -// information) and return the new one. +// refresh the current token and return the new one. func (s *reuseTokenSource) Token() (*Token, error) { s.mu.Lock() defer s.mu.Unlock() @@ -322,7 +321,7 @@ func (s *reuseTokenSource) Token() (*Token, error) { return t, nil } -// StaticTokenSource returns a TokenSource that always returns the same token. +// StaticTokenSource returns a [TokenSource] that always returns the same token. // Because the provided token t is never refreshed, StaticTokenSource is only // useful for tokens that never expire. func StaticTokenSource(t *Token) TokenSource { @@ -338,16 +337,16 @@ func (s staticTokenSource) Token() (*Token, error) { return s.t, nil } -// HTTPClient is the context key to use with golang.org/x/net/context's -// WithValue function to associate an *http.Client value with a context. +// HTTPClient is the context key to use with [context.WithValue] +// to associate a [*http.Client] value with a context. var HTTPClient internal.ContextKey -// NewClient creates an *http.Client from a Context and TokenSource. +// NewClient creates an [*http.Client] from a [context.Context] and [TokenSource]. // The returned client is not valid beyond the lifetime of the context. 
// -// Note that if a custom *http.Client is provided via the Context it +// Note that if a custom [*http.Client] is provided via the [context.Context] it // is used only for token acquisition and is not used to configure the -// *http.Client returned from NewClient. +// [*http.Client] returned from NewClient. // // As a special case, if src is nil, a non-OAuth2 client is returned // using the provided context. This exists to support related OAuth2 @@ -356,15 +355,19 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client { if src == nil { return internal.ContextClient(ctx) } + cc := internal.ContextClient(ctx) return &http.Client{ Transport: &Transport{ - Base: internal.ContextClient(ctx).Transport, + Base: cc.Transport, Source: ReuseTokenSource(nil, src), }, + CheckRedirect: cc.CheckRedirect, + Jar: cc.Jar, + Timeout: cc.Timeout, } } -// ReuseTokenSource returns a TokenSource which repeatedly returns the +// ReuseTokenSource returns a [TokenSource] which repeatedly returns the // same token as long as it's valid, starting with t. // When its cached token is invalid, a new token is obtained from src. // @@ -372,10 +375,10 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client { // (such as a file on disk) between runs of a program, rather than // obtaining new tokens unnecessarily. // -// The initial token t may be nil, in which case the TokenSource is +// The initial token t may be nil, in which case the [TokenSource] is // wrapped in a caching version if it isn't one already. This also // means it's always safe to wrap ReuseTokenSource around any other -// TokenSource without adverse effects. +// [TokenSource] without adverse effects. func ReuseTokenSource(t *Token, src TokenSource) TokenSource { // Don't wrap a reuseTokenSource in itself. That would work, // but cause an unnecessary number of mutex operations. 
@@ -393,8 +396,8 @@ func ReuseTokenSource(t *Token, src TokenSource) TokenSource { } } -// ReuseTokenSourceWithExpiry returns a TokenSource that acts in the same manner as the -// TokenSource returned by ReuseTokenSource, except the expiry buffer is +// ReuseTokenSourceWithExpiry returns a [TokenSource] that acts in the same manner as the +// [TokenSource] returned by [ReuseTokenSource], except the expiry buffer is // configurable. The expiration time of a token is calculated as // t.Expiry.Add(-earlyExpiry). func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource { diff --git a/vendor/golang.org/x/oauth2/pkce.go b/vendor/golang.org/x/oauth2/pkce.go index 50593b6df..cea8374d5 100644 --- a/vendor/golang.org/x/oauth2/pkce.go +++ b/vendor/golang.org/x/oauth2/pkce.go @@ -1,6 +1,7 @@ // Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. + package oauth2 import ( @@ -20,9 +21,9 @@ const ( // This follows recommendations in RFC 7636. // // A fresh verifier should be generated for each authorization. -// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL -// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange -// (or Config.DeviceAccessToken). +// The resulting verifier should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth] +// with [S256ChallengeOption], and to [Config.Exchange] or [Config.DeviceAccessToken] +// with [VerifierOption]. func GenerateVerifier() string { // "RECOMMENDED that the output of a suitable random number generator be // used to create a 32-octet sequence. The octet sequence is then @@ -36,22 +37,22 @@ func GenerateVerifier() string { return base64.RawURLEncoding.EncodeToString(data) } -// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be -// passed to Config.Exchange or Config.DeviceAccessToken only. 
+// VerifierOption returns a PKCE code verifier [AuthCodeOption]. It should only be +// passed to [Config.Exchange] or [Config.DeviceAccessToken]. func VerifierOption(verifier string) AuthCodeOption { return setParam{k: codeVerifierKey, v: verifier} } // S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256. // -// Prefer to use S256ChallengeOption where possible. +// Prefer to use [S256ChallengeOption] where possible. func S256ChallengeFromVerifier(verifier string) string { sha := sha256.Sum256([]byte(verifier)) return base64.RawURLEncoding.EncodeToString(sha[:]) } // S256ChallengeOption derives a PKCE code challenge derived from verifier with -// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess +// method S256. It should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth] // only. func S256ChallengeOption(verifier string) AuthCodeOption { return challengeOption{ diff --git a/vendor/golang.org/x/oauth2/token.go b/vendor/golang.org/x/oauth2/token.go index 109997d77..239ec3296 100644 --- a/vendor/golang.org/x/oauth2/token.go +++ b/vendor/golang.org/x/oauth2/token.go @@ -44,7 +44,7 @@ type Token struct { // Expiry is the optional expiration time of the access token. // - // If zero, TokenSource implementations will reuse the same + // If zero, [TokenSource] implementations will reuse the same // token forever and RefreshToken or equivalent // mechanisms for that TokenSource will not be used. Expiry time.Time `json:"expiry,omitempty"` @@ -58,7 +58,7 @@ type Token struct { // raw optionally contains extra metadata from the server // when updating a token. - raw interface{} + raw any // expiryDelta is used to calculate when a token is considered // expired, by subtracting from Expiry. If zero, defaultExpiryDelta @@ -86,16 +86,16 @@ func (t *Token) Type() string { // SetAuthHeader sets the Authorization header to r using the access // token in t. 
// -// This method is unnecessary when using Transport or an HTTP Client +// This method is unnecessary when using [Transport] or an HTTP Client // returned by this package. func (t *Token) SetAuthHeader(r *http.Request) { r.Header.Set("Authorization", t.Type()+" "+t.AccessToken) } -// WithExtra returns a new Token that's a clone of t, but using the +// WithExtra returns a new [Token] that's a clone of t, but using the // provided raw extra map. This is only intended for use by packages // implementing derivative OAuth2 flows. -func (t *Token) WithExtra(extra interface{}) *Token { +func (t *Token) WithExtra(extra any) *Token { t2 := new(Token) *t2 = *t t2.raw = extra @@ -105,8 +105,8 @@ func (t *Token) WithExtra(extra interface{}) *Token { // Extra returns an extra field. // Extra fields are key-value pairs returned by the server as a // part of the token retrieval response. -func (t *Token) Extra(key string) interface{} { - if raw, ok := t.raw.(map[string]interface{}); ok { +func (t *Token) Extra(key string) any { + if raw, ok := t.raw.(map[string]any); ok { return raw[key] } @@ -163,13 +163,14 @@ func tokenFromInternal(t *internal.Token) *Token { TokenType: t.TokenType, RefreshToken: t.RefreshToken, Expiry: t.Expiry, + ExpiresIn: t.ExpiresIn, raw: t.Raw, } } // retrieveToken takes a *Config and uses that to retrieve an *internal.Token. // This token is then mapped from *internal.Token into an *oauth2.Token which is returned along -// with an error.. +// with an error. 
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()) if err != nil { diff --git a/vendor/golang.org/x/oauth2/transport.go b/vendor/golang.org/x/oauth2/transport.go index 90657915f..8bbebbac9 100644 --- a/vendor/golang.org/x/oauth2/transport.go +++ b/vendor/golang.org/x/oauth2/transport.go @@ -11,12 +11,12 @@ import ( "sync" ) -// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests, -// wrapping a base RoundTripper and adding an Authorization header -// with a token from the supplied Sources. +// Transport is an [http.RoundTripper] that makes OAuth 2.0 HTTP requests, +// wrapping a base [http.RoundTripper] and adding an Authorization header +// with a token from the supplied [TokenSource]. // // Transport is a low-level mechanism. Most code will use the -// higher-level Config.Client method instead. +// higher-level [Config.Client] method instead. type Transport struct { // Source supplies the token to add to outgoing requests' // Authorization headers. @@ -47,7 +47,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } - req2 := cloneRequest(req) // per RoundTripper contract + req2 := req.Clone(req.Context()) token.SetAuthHeader(req2) // req.Body is assumed to be closed by the base RoundTripper. @@ -73,17 +73,3 @@ func (t *Transport) base() http.RoundTripper { } return http.DefaultTransport } - -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map. -func cloneRequest(r *http.Request) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - // deep copy of the Header - r2.Header = make(http.Header, len(r.Header)) - for k, s := range r.Header { - r2.Header[k] = append([]string(nil), s...) 
- } - return r2 -} diff --git a/vendor/modules.txt b/vendor/modules.txt index 5a9a01d33..0068ffeec 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -153,15 +153,9 @@ github.com/aws/smithy-go/tracing github.com/aws/smithy-go/transport/http github.com/aws/smithy-go/transport/http/internal/io github.com/aws/smithy-go/waiter -# github.com/bahlo/generic-list-go v0.2.0 -## explicit; go 1.18 -github.com/bahlo/generic-list-go # github.com/benbjohnson/clock v1.3.5 ## explicit; go 1.15 github.com/benbjohnson/clock -# github.com/buger/jsonparser v1.1.1 -## explicit; go 1.13 -github.com/buger/jsonparser # github.com/cenkalti/backoff/v4 v4.3.0 ## explicit; go 1.18 github.com/cenkalti/backoff/v4 @@ -239,6 +233,9 @@ github.com/google/go-cmp/cmp/internal/diff github.com/google/go-cmp/cmp/internal/flags github.com/google/go-cmp/cmp/internal/function github.com/google/go-cmp/cmp/internal/value +# github.com/google/jsonschema-go v0.3.0 +## explicit; go 1.23.0 +github.com/google/jsonschema-go/jsonschema # github.com/google/uuid v1.6.0 ## explicit github.com/google/uuid @@ -272,9 +269,6 @@ github.com/hashicorp/hcl/json/token # github.com/inconshreveable/mousetrap v1.1.0 ## explicit; go 1.18 github.com/inconshreveable/mousetrap -# github.com/invopop/jsonschema v0.13.0 -## explicit; go 1.18 -github.com/invopop/jsonschema # github.com/jellydator/ttlcache/v3 v3.3.0 ## explicit; go 1.18 github.com/jellydator/ttlcache/v3 @@ -294,15 +288,6 @@ github.com/lufia/plan9stats # github.com/magiconair/properties v1.8.9 ## explicit; go 1.19 github.com/magiconair/properties -# github.com/mailru/easyjson v0.7.7 -## explicit; go 1.12 -github.com/mailru/easyjson/buffer -github.com/mailru/easyjson/jwriter -# github.com/mark3labs/mcp-go v0.43.2 -## explicit; go 1.23.0 -github.com/mark3labs/mcp-go/mcp -github.com/mark3labs/mcp-go/server -github.com/mark3labs/mcp-go/util # github.com/mattn/go-isatty v0.0.20 ## explicit; go 1.15 github.com/mattn/go-isatty @@ -325,6 +310,15 @@ 
github.com/maypok86/otter/v2/stats # github.com/mitchellh/mapstructure v1.5.0 ## explicit; go 1.14 github.com/mitchellh/mapstructure +# github.com/modelcontextprotocol/go-sdk v1.2.0 +## explicit; go 1.23.0 +github.com/modelcontextprotocol/go-sdk/auth +github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2 +github.com/modelcontextprotocol/go-sdk/internal/util +github.com/modelcontextprotocol/go-sdk/internal/xcontext +github.com/modelcontextprotocol/go-sdk/jsonrpc +github.com/modelcontextprotocol/go-sdk/mcp +github.com/modelcontextprotocol/go-sdk/oauthex # github.com/ncruces/go-strftime v0.1.9 ## explicit; go 1.17 github.com/ncruces/go-strftime @@ -419,9 +413,6 @@ github.com/tklauser/go-sysconf # github.com/tklauser/numcpus v0.11.0 ## explicit; go 1.24.0 github.com/tklauser/numcpus -# github.com/wk8/go-ordered-map/v2 v2.1.8 -## explicit; go 1.18 -github.com/wk8/go-ordered-map/v2 # github.com/yosida95/uritemplate/v3 v3.0.2 ## explicit; go 1.14 github.com/yosida95/uritemplate/v3 @@ -566,8 +557,8 @@ golang.org/x/net/idna golang.org/x/net/internal/httpcommon golang.org/x/net/internal/timeseries golang.org/x/net/trace -# golang.org/x/oauth2 v0.26.0 -## explicit; go 1.18 +# golang.org/x/oauth2 v0.30.0 +## explicit; go 1.23.0 golang.org/x/oauth2 golang.org/x/oauth2/clientcredentials golang.org/x/oauth2/internal