From cdaf28a6f6dc7eb52806140a2c3998e4c69f0232 Mon Sep 17 00:00:00 2001 From: amandoabreu Date: Tue, 9 May 2023 11:15:11 +0200 Subject: [PATCH] Add Oauth authentication for Azure AD SSO --- .../Models/TestConfiguration.cs | 5 ++ .../Snowflake.Client.Tests.csproj | 2 + .../UnitTests/AzureAdAuthInfoTest.cs | 34 +++++++++++++ .../UnitTests/AzureAdTokenProviderTest.cs | 32 ++++++++++++ Snowflake.Client/AzureAdTokenProvider.cs | 41 +++++++++++++++ Snowflake.Client/IAzureAdTokenProvider.cs | 11 ++++ Snowflake.Client/Model/AuthInfo.cs | 2 +- Snowflake.Client/Model/AzureAdAuthInfo.cs | 29 +++++++++++ Snowflake.Client/Model/IAuthInfo.cs | 11 ++++ Snowflake.Client/RequestBuilder.cs | 28 ++++++---- Snowflake.Client/Snowflake.Client.csproj | 4 ++ Snowflake.Client/SnowflakeClient.cs | 51 +++++++++++++++++++ 12 files changed, 239 insertions(+), 11 deletions(-) create mode 100644 Snowflake.Client.Tests/UnitTests/AzureAdAuthInfoTest.cs create mode 100644 Snowflake.Client.Tests/UnitTests/AzureAdTokenProviderTest.cs create mode 100644 Snowflake.Client/AzureAdTokenProvider.cs create mode 100644 Snowflake.Client/IAzureAdTokenProvider.cs create mode 100644 Snowflake.Client/Model/AzureAdAuthInfo.cs create mode 100644 Snowflake.Client/Model/IAuthInfo.cs diff --git a/Snowflake.Client.Tests/Models/TestConfiguration.cs b/Snowflake.Client.Tests/Models/TestConfiguration.cs index 1fda5df..f8c052e 100644 --- a/Snowflake.Client.Tests/Models/TestConfiguration.cs +++ b/Snowflake.Client.Tests/Models/TestConfiguration.cs @@ -3,5 +3,10 @@ public class TestConfiguration { public SnowflakeConnectionInfo Connection { get; set; } + public string AdClientId { get; set; } + public string AdClientSecret { get; set; } + public string AdServicePrincipalObjectId { get; set; } + public string AdTenantId { get; set; } + public string AdScope { get; set; } } } diff --git a/Snowflake.Client.Tests/Snowflake.Client.Tests.csproj b/Snowflake.Client.Tests/Snowflake.Client.Tests.csproj index e1ce1f6..c7f4ed3 100644 --- a/Snowflake.Client.Tests/Snowflake.Client.Tests.csproj +++ b/Snowflake.Client.Tests/Snowflake.Client.Tests.csproj @@ -7,6 +7,8 @@ + + diff --git a/Snowflake.Client.Tests/UnitTests/AzureAdAuthInfoTest.cs b/Snowflake.Client.Tests/UnitTests/AzureAdAuthInfoTest.cs new file mode 100644 index 0000000..2100e91 --- /dev/null +++ b/Snowflake.Client.Tests/UnitTests/AzureAdAuthInfoTest.cs @@ -0,0 +1,34 @@ +using System; +using NUnit.Framework; +using Snowflake.Client.Tests.Models; +using Snowflake.Client.Model; +using System.IO; +using System.Text.Json; + +namespace Snowflake.Client.Tests.IntegrationTests +{ + [TestFixture] + public class AzureAdAuthInfoTests + { + protected readonly AzureAdAuthInfo _azureAdAuthInfo; + + public AzureAdAuthInfoTests() + { + var configJson = File.ReadAllText("testconfig.json"); + var testParameters = JsonSerializer.Deserialize(configJson, new JsonSerializerOptions() { PropertyNameCaseInsensitive = true }); + var connectionInfo = testParameters.Connection; + + _azureAdAuthInfo = new AzureAdAuthInfo( + testParameters.AdClientId, + testParameters.AdClientSecret, + testParameters.AdServicePrincipalObjectId, + testParameters.AdTenantId, + testParameters.AdScope, + connectionInfo.Region, + connectionInfo.Account, + connectionInfo.User, + connectionInfo.Host, + connectionInfo.Role); + } + } +} diff --git a/Snowflake.Client.Tests/UnitTests/AzureAdTokenProviderTest.cs b/Snowflake.Client.Tests/UnitTests/AzureAdTokenProviderTest.cs new file mode 100644 index 0000000..55d3050 --- /dev/null +++ b/Snowflake.Client.Tests/UnitTests/AzureAdTokenProviderTest.cs @@ -0,0 +1,32 @@ +using Microsoft.Identity.Client; +using Moq; +using NUnit.Framework; +using Snowflake.Client; +using Snowflake.Client.Model; +using Snowflake.Client.Tests.IntegrationTests; +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Snowflake.Client.Tests +{ + public class AzureAdTokenProviderTests : AzureAdAuthInfoTests + { + [Test] + public async Task GetAzureAdAccessTokenAsync_ReturnsAccessToken() + { + var expectedAccessToken = "accessToken"; + var mockTokenProvider = new Mock(); + + mockTokenProvider + .Setup(provider => provider.GetAzureAdAccessTokenAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(expectedAccessToken); + + // Act + string actualAccessToken = await mockTokenProvider.Object.GetAzureAdAccessTokenAsync(_azureAdAuthInfo); + + // Assert + Assert.AreEqual(expectedAccessToken, actualAccessToken); + } + } +} diff --git a/Snowflake.Client/AzureAdTokenProvider.cs b/Snowflake.Client/AzureAdTokenProvider.cs new file mode 100644 index 0000000..20b7196 --- /dev/null +++ b/Snowflake.Client/AzureAdTokenProvider.cs @@ -0,0 +1,41 @@ +using Microsoft.Identity.Client; +using System; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Client.Model; + +namespace Snowflake.Client +{ + public class AzureAdTokenProvider : IAzureAdTokenProvider + { + public async Task GetAzureAdAccessTokenAsync(AzureAdAuthInfo authInfo, CancellationToken ct = default) + { + try + { + if (authInfo.ClientId == null || authInfo.ClientSecret == null || authInfo.ServicePrincipalObjectId == null || authInfo.TenantId == null || authInfo.Scope == null) + { + throw new SnowflakeException("Error: One or more required environment variables are missing.", 400); + } + + return await GetAccessTokenAsync(authInfo.ClientId, authInfo.ClientSecret, authInfo.ServicePrincipalObjectId, authInfo.TenantId, authInfo.Scope); + } + catch (Exception ex) + { + throw new SnowflakeException($"Failed getting the Azure Token. Message: {ex.Message}", ex); + } + } + + private async Task GetAccessTokenAsync(string clientId, string clientSecret, string servicePrincipalObjectId, string tenantId, string scope) + { + IConfidentialClientApplication app = ConfidentialClientApplicationBuilder.Create(clientId) + .WithClientSecret(clientSecret) + .WithAuthority(new Uri($"https://login.microsoftonline.com/{tenantId}/")) + .Build(); + + var scopes = new[] { scope }; + + AuthenticationResult result = await app.AcquireTokenForClient(scopes).ExecuteAsync(); + return result.AccessToken; + } + } +} \ No newline at end of file diff --git a/Snowflake.Client/IAzureAdTokenProvider.cs b/Snowflake.Client/IAzureAdTokenProvider.cs new file mode 100644 index 0000000..d5a4386 --- /dev/null +++ b/Snowflake.Client/IAzureAdTokenProvider.cs @@ -0,0 +1,11 @@ +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Client.Model; + +namespace Snowflake.Client +{ + public interface IAzureAdTokenProvider + { + Task GetAzureAdAccessTokenAsync(AzureAdAuthInfo authInfo, CancellationToken ct = default); + } +} \ No newline at end of file diff --git a/Snowflake.Client/Model/AuthInfo.cs b/Snowflake.Client/Model/AuthInfo.cs index d0e0ba2..fec3d73 100644 --- a/Snowflake.Client/Model/AuthInfo.cs +++ b/Snowflake.Client/Model/AuthInfo.cs @@ -3,7 +3,7 @@ /// /// Snowflake Authentication information. /// - public class AuthInfo + public class AuthInfo : IAuthInfo { /// /// Your Snowflake account name diff --git a/Snowflake.Client/Model/AzureAdAuthInfo.cs b/Snowflake.Client/Model/AzureAdAuthInfo.cs new file mode 100644 index 0000000..45eb805 --- /dev/null +++ b/Snowflake.Client/Model/AzureAdAuthInfo.cs @@ -0,0 +1,29 @@ +namespace Snowflake.Client.Model +{ + public class AzureAdAuthInfo : AuthInfo + { + public string ClientId { get; set; } + public string ClientSecret { get; set; } + public string ServicePrincipalObjectId { get; set; } + public string TenantId { get; set; } + public string Scope { get; set; } + public string Host {get; set; } + public string Role {get; set; } + + + public AzureAdAuthInfo(string clientId, string clientSecret, string servicePrincipalObjectId, string tenantId, string scope, string region, string account, string user, string host, string role) + : base(user, account, region) + { + ClientId = clientId; + ClientSecret = clientSecret; + ServicePrincipalObjectId = servicePrincipalObjectId; + TenantId = tenantId; + Scope = scope; + Region = region; + Account = account; + User = user; + Host = host; + Role = role; + } + } +} \ No newline at end of file diff --git a/Snowflake.Client/Model/IAuthInfo.cs b/Snowflake.Client/Model/IAuthInfo.cs new file mode 100644 index 0000000..8976247 --- /dev/null +++ b/Snowflake.Client/Model/IAuthInfo.cs @@ -0,0 +1,11 @@ +namespace Snowflake.Client.Model +{ + public interface IAuthInfo + { + string Account { get; set; } + string User { get; set; } + string Region { get; set; } + + string ToString(); + } +} \ No newline at end of file diff --git a/Snowflake.Client/RequestBuilder.cs b/Snowflake.Client/RequestBuilder.cs index 04fff78..96bccbb 100644 --- a/Snowflake.Client/RequestBuilder.cs +++ b/Snowflake.Client/RequestBuilder.cs @@ -51,19 +51,27 @@ internal void ClearSessionTokens() _masterToken = null; } - internal HttpRequestMessage BuildLoginRequest(AuthInfo authInfo, SessionInfo sessionInfo) + internal HttpRequestMessage BuildLoginRequest(AuthInfo authInfo, SessionInfo sessionInfo, String azureAdAccessToken = null) { var requestUri = BuildLoginUrl(sessionInfo); + var data = new LoginRequestData(); + + if (authInfo is AzureAdAuthInfo azureAdAuthInfo) { + data = new LoginRequestData() { + Authenticator = "OAUTH", + Token = azureAdAccessToken, + }; + } else { + data = new LoginRequestData() { + Password = authInfo.Password, + }; + } - var data = new LoginRequestData() - { - LoginName = authInfo.User, - Password = authInfo.Password, - AccountName = authInfo.Account, - ClientAppId = _clientInfo.DriverName, - ClientAppVersion = _clientInfo.DriverVersion, - ClientEnvironment = _clientInfo.Environment - }; + data.LoginName = authInfo.User; + data.AccountName = authInfo.Account; + data.ClientAppId = _clientInfo.DriverName; + data.ClientAppVersion = _clientInfo.DriverVersion; + data.ClientEnvironment = _clientInfo.Environment; var requestBody = new LoginRequest() { Data = data }; var jsonBody = JsonSerializer.Serialize(requestBody, _jsonSerializerOptions); diff --git a/Snowflake.Client/Snowflake.Client.csproj b/Snowflake.Client/Snowflake.Client.csproj index 1b235d2..4ceab60 100644 --- a/Snowflake.Client/Snowflake.Client.csproj +++ b/Snowflake.Client/Snowflake.Client.csproj @@ -32,4 +32,8 @@ Provides straightforward and efficient way to execute SQL queries in Snowflake a + + + + diff --git a/Snowflake.Client/SnowflakeClient.cs b/Snowflake.Client/SnowflakeClient.cs index b1b6272..616dc00 100644 --- a/Snowflake.Client/SnowflakeClient.cs +++ b/Snowflake.Client/SnowflakeClient.cs @@ -8,6 +8,7 @@ using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using Microsoft.Identity.Client; namespace Snowflake.Client { @@ -23,11 +24,41 @@ public class SnowflakeClient : ISnowflakeClient /// public SnowflakeClientSettings Settings => _clientSettings; + /// + /// Azure AD Token Provider + /// + private readonly AzureAdTokenProvider _azureAdTokenProvider; + private SnowflakeSession _snowflakeSession; private readonly RestClient _restClient; private readonly RequestBuilder _requestBuilder; private readonly SnowflakeClientSettings _clientSettings; + /// + /// Creates new Snowflake client. + /// + /// Client ID + /// Client Secret + /// Service Principal Object ID + /// Tenant ID + /// Scope + /// Region: "us-east-1", etc. Required for all except for US West Oregon (us-west-2). + /// Account + /// Username + /// Host + /// Role + public SnowflakeClient(string clientId, string clientSecret, string servicePrincipalObjectId, string tenantId, string scope, string region, string account, string user, string host, string role) + : this(new AzureAdAuthInfo(clientId, clientSecret, servicePrincipalObjectId, tenantId, scope, region, account, user, host, role), urlInfo: new UrlInfo + { + Host = host, + }, + sessionInfo: new SessionInfo + { + Role = role, + }) + { + } + /// /// Creates new Snowflake client. /// @@ -52,6 +83,11 @@ public SnowflakeClient(AuthInfo authInfo, SessionInfo sessionInfo = null, UrlInf { } + public SnowflakeClient(AzureAdAuthInfo authInfo, SessionInfo sessionInfo = null, UrlInfo urlInfo = null, JsonSerializerOptions jsonMapperOptions = null) + : this(new SnowflakeClientSettings(authInfo, sessionInfo, urlInfo, jsonMapperOptions)) + { + } + /// /// Creates new Snowflake client. /// @@ -63,6 +99,7 @@ public SnowflakeClient(SnowflakeClientSettings settings) _clientSettings = settings; _restClient = new RestClient(); _requestBuilder = new RequestBuilder(settings.UrlInfo); + _azureAdTokenProvider = new AzureAdTokenProvider(); SnowflakeDataMapper.Configure(settings.JsonMapperOptions); ChunksDownloader.Configure(settings.ChunksDownloaderOptions); @@ -104,10 +141,24 @@ public async Task InitNewSessionAsync(CancellationToken ct = default) return true; } + /// + /// Authenticates user and returns new Snowflake session. + /// + /// New Snowflake session private async Task AuthenticateAsync(AuthInfo authInfo, SessionInfo sessionInfo, CancellationToken ct) { var loginRequest = _requestBuilder.BuildLoginRequest(authInfo, sessionInfo); + if(authInfo is AzureAdAuthInfo azureAdAuthInfo) + { + var azureAdAccessToken = await _azureAdTokenProvider.GetAzureAdAccessTokenAsync(azureAdAuthInfo, ct).ConfigureAwait(false); + loginRequest = _requestBuilder.BuildLoginRequest(authInfo, sessionInfo, azureAdAccessToken); + } + else + { + loginRequest = _requestBuilder.BuildLoginRequest(authInfo, sessionInfo); + } + var response = await _restClient.SendAsync(loginRequest, ct).ConfigureAwait(false); if (!response.Success)