diff --git a/core/Azure.Mcp.Core/src/Areas/Server/Commands/Discovery/RegistryDiscoveryStrategy.cs b/core/Azure.Mcp.Core/src/Areas/Server/Commands/Discovery/RegistryDiscoveryStrategy.cs index f660392048..0591da3a55 100644 --- a/core/Azure.Mcp.Core/src/Areas/Server/Commands/Discovery/RegistryDiscoveryStrategy.cs +++ b/core/Azure.Mcp.Core/src/Areas/Server/Commands/Discovery/RegistryDiscoveryStrategy.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using System.Reflection; using Azure.Mcp.Core.Areas.Server.Models; using Azure.Mcp.Core.Areas.Server.Options; using Microsoft.Extensions.Logging; @@ -15,14 +14,16 @@ namespace Azure.Mcp.Core.Areas.Server.Commands.Discovery; /// /// Options for configuring the service behavior. /// Logger instance for this discovery strategy. -public sealed class RegistryDiscoveryStrategy(IOptions options, ILogger logger) : BaseDiscoveryStrategy(logger) +/// Factory that can create HttpClient objects. +/// Manifest of all the MCP server registries. +public sealed class RegistryDiscoveryStrategy(IOptions options, ILogger logger, IHttpClientFactory httpClientFactory, IRegistryRoot registryRoot) : BaseDiscoveryStrategy(logger) { private readonly IOptions _options = options; + private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; /// public override async Task> DiscoverServersAsync(CancellationToken cancellationToken) { - var registryRoot = await LoadRegistryAsync(); if (registryRoot == null) { return []; @@ -33,40 +34,7 @@ public override async Task> DiscoverServersAsync .Where(s => _options.Value.Namespace == null || _options.Value.Namespace.Length == 0 || _options.Value.Namespace.Contains(s.Key, StringComparer.OrdinalIgnoreCase)) - .Select(s => new RegistryServerProvider(s.Key, s.Value)) + .Select(s => new RegistryServerProvider(s.Key, s.Value, _httpClientFactory)) .Cast(); } - - /// - /// Loads the registry configuration from the embedded resource file. - /// - /// The deserialized registry root containing server configurations, or null if not found. - private async Task LoadRegistryAsync() - { - var assembly = Assembly.GetExecutingAssembly(); - var resourceName = assembly - .GetManifestResourceNames() - .FirstOrDefault(n => n.EndsWith("registry.json", StringComparison.OrdinalIgnoreCase)); - - if (resourceName is null) - { - return null; - } - - await using var stream = assembly.GetManifestResourceStream(resourceName)!; - var registry = await JsonSerializer.DeserializeAsync(stream, ServerJsonContext.Default.RegistryRoot); - - if (registry?.Servers != null) - { - foreach (var kvp in registry.Servers) - { - if (kvp.Value != null) - { - kvp.Value.Name = kvp.Key; - } - } - } - - return registry; - } } diff --git a/core/Azure.Mcp.Core/src/Areas/Server/Commands/Discovery/RegistryServerProvider.cs b/core/Azure.Mcp.Core/src/Areas/Server/Commands/Discovery/RegistryServerProvider.cs index dac744e633..4047c34c28 100644 --- a/core/Azure.Mcp.Core/src/Areas/Server/Commands/Discovery/RegistryServerProvider.cs +++ b/core/Azure.Mcp.Core/src/Areas/Server/Commands/Discovery/RegistryServerProvider.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using Azure.Mcp.Core.Areas.Server.Models; +using Azure.Mcp.Core.Helpers; using ModelContextProtocol.Client; namespace Azure.Mcp.Core.Areas.Server.Commands.Discovery; @@ -12,10 +13,13 @@ namespace Azure.Mcp.Core.Areas.Server.Commands.Discovery; /// /// The unique identifier for the server. /// Configuration information for the server. -public sealed class RegistryServerProvider(string id, RegistryServerInfo serverInfo) : IMcpServerProvider +/// Factory for creating HTTP clients. +/// The token credential provider for OAuth authentication. +public sealed class RegistryServerProvider(string id, RegistryServerInfo serverInfo, IHttpClientFactory httpClientFactory) : IMcpServerProvider { private readonly string _id = id; private readonly RegistryServerInfo _serverInfo = serverInfo; + private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; /// /// Creates metadata that describes this registry-based server. @@ -37,7 +41,6 @@ public async Task CreateClientAsync(McpClientOptions clientOptions, C { Func>? clientFactory = null; - // Determine which factory function to use based on configuration if (!string.IsNullOrWhiteSpace(_serverInfo.Url)) { clientFactory = CreateHttpClientAsync; @@ -88,8 +91,24 @@ private async Task CreateHttpClientAsync(McpClientOptions clientOptio Name = _id, Endpoint = new Uri(_serverInfo.Url!), TransportMode = HttpTransportMode.AutoDetect, + // HttpClientTransportOptions offers an OAuth property to configure client side OAuth parameters, such as RedirectUri and ClientId. + // When OAuth property is set, the MCP client will attempt to complete the Auth flow following the MCP protocol. + // However, there is a gap between what MCP protocol requires the OAuth provider to implement and what Entra supports. This MCP client will always send a resource parameter to the token endpoint because it is required by the MCP protocol but Entra doesn't support it. More details in issue #939 and related discussions in modelcontextprotocol/csharp-sdk GitHub repo. }; - var clientTransport = new HttpClientTransport(transportOptions); + + HttpClientTransport clientTransport; + if (_serverInfo.OAuthScopes is not null) + { + // Registry servers with OAuthScopes must create HttpClient with this key to create an HttpClient that knows how to fetch its access tokens. + // The HttpClients are registered in RegistryServerServiceCollectionExtensions.cs. + var client = _httpClientFactory.CreateClient(RegistryServerHelper.GetRegistryServerHttpClientName(_serverInfo.Name!)); + clientTransport = new HttpClientTransport(transportOptions, client, ownsHttpClient: true); + } + else + { + clientTransport = new HttpClientTransport(transportOptions); + } + return await McpClient.CreateAsync(clientTransport, clientOptions, cancellationToken: cancellationToken); } diff --git a/core/Azure.Mcp.Core/src/Areas/Server/Models/IRegistryRoot.cs b/core/Azure.Mcp.Core/src/Areas/Server/Models/IRegistryRoot.cs new file mode 100644 index 0000000000..8c633d8516 --- /dev/null +++ b/core/Azure.Mcp.Core/src/Areas/Server/Models/IRegistryRoot.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Core.Areas.Server.Models; + +/// +/// Represents the root structure of the MCP server registry JSON file. +/// Contains a collection of server configurations keyed by server name. +/// +public interface IRegistryRoot +{ + /// + /// Gets the dictionary of server configurations, keyed by server name. + /// + public Dictionary? Servers { get; init; } +} diff --git a/core/Azure.Mcp.Core/src/Areas/Server/Models/RegistryRoot.cs b/core/Azure.Mcp.Core/src/Areas/Server/Models/RegistryRoot.cs index 9116486065..65395109ed 100644 --- a/core/Azure.Mcp.Core/src/Areas/Server/Models/RegistryRoot.cs +++ b/core/Azure.Mcp.Core/src/Areas/Server/Models/RegistryRoot.cs @@ -9,7 +9,7 @@ namespace Azure.Mcp.Core.Areas.Server.Models; /// Represents the root structure of the MCP server registry JSON file. /// Contains a collection of server configurations keyed by server name. /// -public sealed class RegistryRoot +public sealed class RegistryRoot : IRegistryRoot { /// /// Gets the dictionary of server configurations, keyed by server name. diff --git a/core/Azure.Mcp.Core/src/Areas/Server/Models/RegistryServerInfo.cs b/core/Azure.Mcp.Core/src/Areas/Server/Models/RegistryServerInfo.cs index c17519da55..6a2375c84e 100644 --- a/core/Azure.Mcp.Core/src/Areas/Server/Models/RegistryServerInfo.cs +++ b/core/Azure.Mcp.Core/src/Areas/Server/Models/RegistryServerInfo.cs @@ -19,11 +19,19 @@ public sealed class RegistryServerInfo public string? Name { get; set; } /// - /// Gets the URL endpoint (deprecated - no longer used). + /// Gets the URL of the remote server. + /// This should be if is "stdio". /// [JsonPropertyName("url")] public string? Url { get; init; } + /// + /// Gets OAuth scopes to request in the access token. + /// Used for remote MCP servers protected by OAuth. + /// + [JsonPropertyName("oauthScopes")] + public string[]? OAuthScopes { get; init; } + /// /// Gets a description of the server's purpose or capabilities. /// @@ -38,6 +46,7 @@ public sealed class RegistryServerInfo /// /// Gets the transport type, e.g., "stdio". + /// This should be if is non-. /// [JsonPropertyName("type")] public string? Type { get; init; } diff --git a/core/Azure.Mcp.Core/src/Areas/Server/RegistryServerServiceCollectionExtensions.cs b/core/Azure.Mcp.Core/src/Areas/Server/RegistryServerServiceCollectionExtensions.cs new file mode 100644 index 0000000000..2a3ab39d7f --- /dev/null +++ b/core/Azure.Mcp.Core/src/Areas/Server/RegistryServerServiceCollectionExtensions.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Mcp.Core.Areas.Server.Models; +using Azure.Mcp.Core.Helpers; +using Azure.Mcp.Core.Services.Azure.Authentication; +using Microsoft.Extensions.DependencyInjection; + +namespace Azure.Mcp.Core.Areas.Server; + +/// +/// Extension methods for configuring RegistryServer services. +/// +public static class RegistryServerServiceCollectionExtensions +{ + /// + /// Add HttpClient for each registry server with OAuthScopes that knows how to fetch its access token. + /// + public static IServiceCollection AddRegistryRoot(this IServiceCollection services) + { + var registry = RegistryServerHelper.GetRegistryRoot(); + if (registry?.Servers is null) + { + return services; + } + + foreach (var kvp in registry.Servers) + { + if (kvp.Value is not null) + { + // Set the name of the server for easier access + kvp.Value.Name = kvp.Key; + } + + if (kvp.Value is null || string.IsNullOrWhiteSpace(kvp.Value.Url) || kvp.Value.OAuthScopes is null) + { + continue; + } + + var serverName = kvp.Key; + var serverUrl = kvp.Value.Url; + var oauthScopes = kvp.Value.OAuthScopes; + if (oauthScopes.Length == 0) + { + continue; + } + + services.AddHttpClient(RegistryServerHelper.GetRegistryServerHttpClientName(serverName)) + .AddHttpMessageHandler((services) => + { + var tokenCredentialProvider = services.GetRequiredService(); + return new AccessTokenHandler(tokenCredentialProvider, oauthScopes); + }); + } + + services.AddSingleton(registry); + + return services; + } +} diff --git a/core/Azure.Mcp.Core/src/Helpers/RegistryServerHelper.cs b/core/Azure.Mcp.Core/src/Helpers/RegistryServerHelper.cs new file mode 100644 index 0000000000..5d8c477881 --- /dev/null +++ b/core/Azure.Mcp.Core/src/Helpers/RegistryServerHelper.cs @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net.Http.Headers; +using System.Reflection; +using Azure.Core; +using Azure.Mcp.Core.Areas.Server; +using Azure.Mcp.Core.Areas.Server.Commands.Discovery; +using Azure.Mcp.Core.Areas.Server.Models; +using Azure.Mcp.Core.Services.Azure.Authentication; + +namespace Azure.Mcp.Core.Helpers +{ + /// + /// DelegatingHandler that adds a Bearer access token to each outgoing request. + /// + public sealed class AccessTokenHandler : DelegatingHandler + { + private readonly IAzureTokenCredentialProvider _tokenCredentialProvider; + private readonly string[] _oauthScopes; + + public AccessTokenHandler(IAzureTokenCredentialProvider tokenCredentialProvider, string[] oauthScopes) + { + _tokenCredentialProvider = tokenCredentialProvider; + _oauthScopes = oauthScopes; + } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + var credential = await _tokenCredentialProvider.GetTokenCredentialAsync(tenantId: null, cancellationToken); + var tokenContext = new TokenRequestContext(_oauthScopes); + var token = await credential.GetTokenAsync(tokenContext, cancellationToken); + if (!string.IsNullOrEmpty(token.Token)) + { + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token.Token); + } + return await base.SendAsync(request, cancellationToken); + } + } + + public sealed class RegistryServerHelper + { + public static string GetRegistryServerHttpClientName(string serverName) + { + return $"azmcp.{nameof(RegistryServerProvider)}.{serverName}"; + } + + public static IRegistryRoot? GetRegistryRoot() + { + var assembly = Assembly.GetExecutingAssembly(); + var resourceName = assembly + .GetManifestResourceNames() + .FirstOrDefault(n => n.EndsWith("registry.json", StringComparison.OrdinalIgnoreCase)); + if (resourceName is null) + { + return null; + } + + using var stream = assembly.GetManifestResourceStream(resourceName); + if (stream is null) + { + return null; + } + var registry = JsonSerializer.Deserialize(stream, ServerJsonContext.Default.RegistryRoot); + if (registry?.Servers is null) + { + return null; + } + + return registry; + } + } +} diff --git a/core/Azure.Mcp.Core/src/Services/Http/HttpClientFactoryConfigurator.cs b/core/Azure.Mcp.Core/src/Services/Http/HttpClientFactoryConfigurator.cs index 802fa0074b..442ec4ed05 100644 --- a/core/Azure.Mcp.Core/src/Services/Http/HttpClientFactoryConfigurator.cs +++ b/core/Azure.Mcp.Core/src/Services/Http/HttpClientFactoryConfigurator.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using System; -using System.Linq; using System.Net; using System.Reflection; using System.Runtime.InteropServices; diff --git a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/Discovery/RegistryDiscoveryStrategyTests.cs b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/Discovery/RegistryDiscoveryStrategyTests.cs index ab622f9de8..1a53e94c46 100644 --- a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/Discovery/RegistryDiscoveryStrategyTests.cs +++ b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/Discovery/RegistryDiscoveryStrategyTests.cs @@ -3,6 +3,7 @@ using Azure.Mcp.Core.Areas.Server.Commands.Discovery; using Azure.Mcp.Core.Areas.Server.Options; +using Azure.Mcp.Core.Helpers; using Xunit; namespace Azure.Mcp.Core.UnitTests.Areas.Server.Commands.Discovery; @@ -13,7 +14,9 @@ private static RegistryDiscoveryStrategy CreateStrategy(ServiceStartOptions? opt { var serviceOptions = Microsoft.Extensions.Options.Options.Create(options ?? new ServiceStartOptions()); var logger = NSubstitute.Substitute.For>(); - return new RegistryDiscoveryStrategy(serviceOptions, logger); + var httpClientFactory = NSubstitute.Substitute.For(); + var registryRoot = RegistryServerHelper.GetRegistryRoot(); + return new RegistryDiscoveryStrategy(serviceOptions, logger, httpClientFactory, registryRoot!); } [Fact] diff --git a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/Discovery/RegistryServerProviderTests.cs b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/Discovery/RegistryServerProviderTests.cs index ec3129e792..27b68ee830 100644 --- a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/Discovery/RegistryServerProviderTests.cs +++ b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/Discovery/RegistryServerProviderTests.cs @@ -6,12 +6,20 @@ using Azure.Mcp.Core.Areas.Server.Commands.Discovery; using Azure.Mcp.Core.Areas.Server.Models; using ModelContextProtocol.Client; +using NSubstitute; using Xunit; namespace Azure.Mcp.Core.UnitTests.Areas.Server.Commands.Discovery; public class RegistryServerProviderTests { + private static RegistryServerProvider CreateServerProvider(string id, RegistryServerInfo serverInfo) + { + var httpClientFactory = Substitute.For(); + httpClientFactory.CreateClient(Arg.Any()) + .Returns(Substitute.For()); + return new RegistryServerProvider(id, serverInfo, httpClientFactory); + } [Fact] public void Constructor_InitializesCorrectly() { @@ -23,7 +31,7 @@ public void Constructor_InitializesCorrectly() }; // Act - var provider = new RegistryServerProvider(testId, serverInfo); + var provider = CreateServerProvider(testId, serverInfo); // Assert Assert.NotNull(provider); @@ -39,7 +47,7 @@ public void CreateMetadata_ReturnsExpectedMetadata() { Description = "Test Description" }; - var provider = new RegistryServerProvider(testId, serverInfo); + var provider = CreateServerProvider(testId, serverInfo); // Act var metadata = provider.CreateMetadata(); @@ -61,7 +69,7 @@ public void CreateMetadata_EmptyDescription_ReturnsEmptyString() { Description = null }; - var provider = new RegistryServerProvider(testId, serverInfo); + var provider = CreateServerProvider(testId, serverInfo); // Act var metadata = provider.CreateMetadata(); @@ -85,7 +93,7 @@ public void CreateMetadata_WithTitle_ReturnsTitleInMetadata() Title = testTitle, Description = "Test Description" }; - var provider = new RegistryServerProvider(testId, serverInfo); + var provider = CreateServerProvider(testId, serverInfo); // Act var metadata = provider.CreateMetadata(); @@ -109,7 +117,7 @@ public void CreateMetadata_WithTitle_ReturnsTitleInMetadata() // Description = "Test SSE Provider", // Url = $"{server.Endpoint}/mcp" // }; - // var provider = new RegistryServerProvider(testId, serverInfo); + // var provider = CreateServerProvider(testId, serverInfo); // // Act & Assert // var exception = await Assert.ThrowsAsync( @@ -130,7 +138,7 @@ public async Task CreateClientAsync_WithStdioType_CreatesStdioClient() Command = "echo", Args = ["hello world"] }; - var provider = new RegistryServerProvider(testId, serverInfo); + var provider = CreateServerProvider(testId, serverInfo); // Act & Assert - Should throw InvalidOperationException for subprocess startup failure // since configuration is valid but external process fails to start properly @@ -156,7 +164,7 @@ public async Task CreateClientAsync_WithEnvVariables_MergesWithSystemEnvironment { "TEST_VAR", "test value" } } }; - var provider = new RegistryServerProvider(testId, serverInfo); + var provider = CreateServerProvider(testId, serverInfo); // Act & Assert - Should throw InvalidOperationException for subprocess startup failure // since configuration is valid but external process fails to start properly @@ -176,7 +184,7 @@ public async Task CreateClientAsync_NoUrlOrType_ThrowsArgumentException() Description = "Invalid Provider - No Transport" // No Url or Type specified }; - var provider = new RegistryServerProvider(testId, serverInfo); + var provider = CreateServerProvider(testId, serverInfo); // Act & Assert var exception = await Assert.ThrowsAsync( @@ -197,7 +205,7 @@ public async Task CreateClientAsync_StdioWithoutCommand_ThrowsInvalidOperationEx Type = "stdio" // No Command specified }; - var provider = new RegistryServerProvider(testId, serverInfo); + var provider = CreateServerProvider(testId, serverInfo); // Act & Assert var exception = await Assert.ThrowsAsync( @@ -221,7 +229,7 @@ public async Task CreateClientAsync_WithInstallInstructions_IncludesInstructions Args = ["--serve"], InstallInstructions = installInstructions }; - var provider = new RegistryServerProvider(testId, serverInfo); + var provider = CreateServerProvider(testId, serverInfo); // Act & Assert - Should throw InvalidOperationException with install instructions var exception = await Assert.ThrowsAsync( diff --git a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ServiceCollectionExtensionsTests.cs b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ServiceCollectionExtensionsTests.cs index ad16dcdf40..51d8014695 100644 --- a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ServiceCollectionExtensionsTests.cs +++ b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ServiceCollectionExtensionsTests.cs @@ -1,12 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using Azure.Mcp.Core.Areas.Server; using Azure.Mcp.Core.Areas.Server.Commands; using Azure.Mcp.Core.Areas.Server.Commands.Discovery; using Azure.Mcp.Core.Areas.Server.Commands.Runtime; using Azure.Mcp.Core.Areas.Server.Commands.ToolLoading; using Azure.Mcp.Core.Areas.Server.Options; using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Services.Azure.Authentication; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; using ModelContextProtocol.Server; @@ -22,7 +24,8 @@ private IServiceCollection SetupBaseServices() { var services = CommandFactoryHelpers.SetupCommonServices(); services.AddSingleton(sp => CommandFactoryHelpers.CreateCommandFactory(sp)); - + services.AddSingleIdentityTokenCredentialProvider(); + services.AddRegistryRoot(); return services; } diff --git a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ToolLoading/ServerToolLoaderTests.cs b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ToolLoading/ServerToolLoaderTests.cs index 027ea726cc..0e44e41311 100644 --- a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ToolLoading/ServerToolLoaderTests.cs +++ b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ToolLoading/ServerToolLoaderTests.cs @@ -5,6 +5,7 @@ using Azure.Mcp.Core.Areas.Server.Commands.Discovery; using Azure.Mcp.Core.Areas.Server.Commands.ToolLoading; using Azure.Mcp.Core.Areas.Server.Options; +using Azure.Mcp.Core.Helpers; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; @@ -49,6 +50,14 @@ private static ModelContextProtocol.Server.RequestContext }; } + private static RegistryDiscoveryStrategy CreateStrategy(ServiceStartOptions options, ILogger logger) + { + var serviceOptions = Microsoft.Extensions.Options.Options.Create(options ?? new ServiceStartOptions()); + var httpClientFactory = Substitute.For(); + var registryRoot = RegistryServerHelper.GetRegistryRoot(); + return new RegistryDiscoveryStrategy(serviceOptions, logger, httpClientFactory, registryRoot!); + } + [Fact] public async Task CallToolHandler_WithoutListToolsFirst_ShouldSucceed() { @@ -58,7 +67,7 @@ public async Task CallToolHandler_WithoutListToolsFirst_ShouldSucceed() var serviceStartOptions = Microsoft.Extensions.Options.Options.Create(new ServiceStartOptions()); var toolLoaderOptions = Microsoft.Extensions.Options.Options.Create(new ToolLoaderOptions()); var discoveryLogger = loggerFactory.CreateLogger(); - var discoveryStrategy = new RegistryDiscoveryStrategy(serviceStartOptions, discoveryLogger); + var discoveryStrategy = CreateStrategy(serviceStartOptions.Value, discoveryLogger); var logger = loggerFactory.CreateLogger(); var toolLoader = new ServerToolLoader(discoveryStrategy, toolLoaderOptions, logger); @@ -112,7 +121,7 @@ public async Task ListToolsHandler_WithRealRegistryDiscovery_ReturnsExpectedStru var serviceStartOptions = Microsoft.Extensions.Options.Options.Create(new ServiceStartOptions()); var toolLoaderOptions = Microsoft.Extensions.Options.Options.Create(new ToolLoaderOptions()); var discoveryLogger = loggerFactory.CreateLogger(); - var discoveryStrategy = new RegistryDiscoveryStrategy(serviceStartOptions, discoveryLogger); + var discoveryStrategy = CreateStrategy(serviceStartOptions.Value, discoveryLogger); var logger = loggerFactory.CreateLogger(); var toolLoader = new ServerToolLoader(discoveryStrategy, toolLoaderOptions, logger); diff --git a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ToolLoading/SingleProxyToolLoaderTests.cs b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ToolLoading/SingleProxyToolLoaderTests.cs index 8d69dbe4ed..f5281320f6 100644 --- a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ToolLoading/SingleProxyToolLoaderTests.cs +++ b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/Commands/ToolLoading/SingleProxyToolLoaderTests.cs @@ -5,6 +5,7 @@ using Azure.Mcp.Core.Areas.Server.Commands.Discovery; using Azure.Mcp.Core.Areas.Server.Commands.ToolLoading; using Azure.Mcp.Core.Areas.Server.Options; +using Azure.Mcp.Core.Helpers; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; @@ -15,6 +16,14 @@ namespace Azure.Mcp.Core.UnitTests.Areas.Server.Commands.ToolLoading; public class SingleProxyToolLoaderTests { + private static RegistryDiscoveryStrategy CreateStrategy(ServiceStartOptions options, ILogger logger) + { + var serviceOptions = Microsoft.Extensions.Options.Options.Create(options ?? new ServiceStartOptions()); + var httpClientFactory = Substitute.For(); + var registryRoot = RegistryServerHelper.GetRegistryRoot(); + return new RegistryDiscoveryStrategy(serviceOptions, logger, httpClientFactory, registryRoot!); + } + private static (SingleProxyToolLoader toolLoader, IMcpDiscoveryStrategy discoveryStrategy) CreateToolLoader(bool useRealDiscovery = true) { var serviceProvider = CommandFactoryHelpers.CreateDefaultServiceProvider(); @@ -31,7 +40,7 @@ private static (SingleProxyToolLoader toolLoader, IMcpDiscoveryStrategy discover commandGroupLogger ); var registryLogger = serviceProvider.GetRequiredService>(); - var registryDiscoveryStrategy = new RegistryDiscoveryStrategy(options, registryLogger); + var registryDiscoveryStrategy = CreateStrategy(options.Value, registryLogger); var compositeLogger = serviceProvider.GetRequiredService>(); var compositeDiscoveryStrategy = new CompositeDiscoveryStrategy([ commandGroupDiscoveryStrategy, diff --git a/servers/Azure.Mcp.Server/src/Program.cs b/servers/Azure.Mcp.Server/src/Program.cs index cc3ddb8d52..3b73936c0c 100644 --- a/servers/Azure.Mcp.Server/src/Program.cs +++ b/servers/Azure.Mcp.Server/src/Program.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System.Net; +using Azure.Mcp.Core.Areas.Server; using Azure.Mcp.Core.Areas.Server.Commands; using Azure.Mcp.Core.Commands; using Azure.Mcp.Core.Extensions; @@ -214,6 +215,8 @@ internal static void ConfigureServices(IServiceCollection services) services.AddSingleton(area); area.ConfigureServices(services); } + + services.AddRegistryRoot(); } internal static async Task InitializeServicesAsync(IServiceProvider serviceProvider)