diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/Albums.cs b/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/Albums.cs index 943e30c2..bd2cce6b 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/Albums.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/Albums.cs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System; using System.Collections.Generic; using Google.Cloud.EntityFrameworkCore.Spanner.Storage; @@ -26,7 +27,7 @@ public Albums() public long AlbumId { get; set; } public string Title { get; set; } - public SpannerDate? ReleaseDate { get; set; } + public DateOnly? ReleaseDate { get; set; } public long SingerId { get; set; } public HashSet Awards { get; set; } = new (); diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/Singers.cs b/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/Singers.cs index e0a12164..3f46bdfe 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/Singers.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/Singers.cs @@ -31,7 +31,7 @@ public Singers() public string FirstName { get; set; } public string LastName { get; set; } public string FullName { get; set; } - public DateOnly? BirthDate { get; set; } + public SpannerDate? BirthDate { get; set; } public byte[] Picture { get; set; } public virtual ICollection Albums { get; set; } diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/TableWithAllColumnTypes.cs b/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/TableWithAllColumnTypes.cs index b8318c45..07a97887 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/TableWithAllColumnTypes.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.IntegrationTests/Model/TableWithAllColumnTypes.cs @@ -14,7 +14,6 @@ using System; using System.Collections.Generic; -using Google.Cloud.EntityFrameworkCore.Spanner.Storage; using Google.Cloud.Spanner.V1; using System.Text.Json; using SpannerDate = Google.Cloud.EntityFrameworkCore.Spanner.Storage.SpannerDate; diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/AutoGeneratedPrimaryKeyTests/AutoGeneratedPrimaryKeyMockServerTests.cs b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/AutoGeneratedPrimaryKeyTests/AutoGeneratedPrimaryKeyMockServerTests.cs index 9677846c..98921a05 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/AutoGeneratedPrimaryKeyTests/AutoGeneratedPrimaryKeyMockServerTests.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/AutoGeneratedPrimaryKeyTests/AutoGeneratedPrimaryKeyMockServerTests.cs @@ -56,7 +56,7 @@ public AutoGeneratedPrimaryKeyMockServerTests(SpannerMockServerFixture service) service.SpannerMock.Reset(); } - private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port}"; + private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port};UsePlainText=true"; [Fact] public async Task FindInvoiceAsync_ReturnsNull_IfNotFound() diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkMockServerTests.cs b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkMockServerTests.cs index b3a1efb3..346ca289 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkMockServerTests.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkMockServerTests.cs @@ -33,8 +33,14 @@ using System.Text.Json; using System.Text.RegularExpressions; using System.Threading.Tasks; +using Google.Cloud.Spanner.Admin.Database.V1; +using Google.Cloud.Spanner.DataProvider; +using Google.Rpc; using Xunit; +using SpannerConnection = Google.Cloud.Spanner.Data.SpannerConnection; using SpannerDate = Google.Cloud.EntityFrameworkCore.Spanner.Storage.SpannerDate; +using SpannerParameter = Google.Cloud.Spanner.Data.SpannerParameter; +using Status = Google.Rpc.Status; using V1 = Google.Cloud.Spanner.V1; #pragma warning disable EF1001 @@ -96,7 +102,13 @@ public EntityFrameworkMockServerTests(SpannerMockServerFixture service) service.SpannerMock.Reset(); } - private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port}"; + private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port};UsePlainText=true"; + //private string ConnectionString => $"{_fixture.Host}:{_fixture.Port}/projects/p1/instances/i1/databases/d1;usePlainText=true"; + + bool UsesClientLib() + { + return Environment.GetEnvironmentVariable("USE_CLIENT_LIB") == "true"; + } [Fact] public async Task FindSingerAsync_ReturnsNull_IfNotFound() @@ -137,10 +149,11 @@ public async Task FindSingersUsingListOfIds_UsesParameterizedQuery() v => Assert.Equal("2", v.StringValue), v => Assert.Equal("3", v.StringValue) ); - Assert.Single(request.ParamTypes); - var type = request.ParamTypes["__singerIds_0"]; - Assert.Equal(V1.TypeCode.Array, type.Code); - Assert.Equal(V1.TypeCode.Int64, type.ArrayElementType.Code); + Assert.Empty(request.ParamTypes); + // Assert.Single(request.ParamTypes); + // var type = request.ParamTypes["singerIds_0"]; + // Assert.Equal(V1.TypeCode.Array, type.Code); + // Assert.Equal(V1.TypeCode.Int64, type.ArrayElementType.Code); } ); } @@ -169,10 +182,11 @@ public async Task FindSingersUsingListOfIntegers_UsesParameterizedQuery() v => Assert.Equal("9", v.StringValue), v => Assert.Equal("10", v.StringValue) ); - Assert.Single(request.ParamTypes); - var type = request.ParamTypes["__singerIds_0"]; - Assert.Equal(V1.TypeCode.Array, type.Code); - Assert.Equal(V1.TypeCode.Int64, type.ArrayElementType.Code); + Assert.Empty(request.ParamTypes); + // Assert.Single(request.ParamTypes); + // var type = request.ParamTypes["singerIds_0"]; + // Assert.Equal(V1.TypeCode.Array, type.Code); + // Assert.Equal(V1.TypeCode.Int64, type.ArrayElementType.Code); } ); } @@ -213,9 +227,10 @@ public async Task FindPerformancesByType_UsesParameterizedQuery() Assert.Single(request.Params.Fields); var fields = request.Params.Fields; Assert.Equal("0", fields["__type_0"].StringValue); - Assert.Single(request.ParamTypes); - var requestType = request.ParamTypes["__type_0"]; - Assert.Equal(V1.TypeCode.Int64, requestType.Code); + Assert.Empty(request.ParamTypes); + // Assert.Single(request.ParamTypes); + // var requestType = request.ParamTypes["__type_0"]; + // Assert.Equal(V1.TypeCode.Int64, requestType.Code); } ); } @@ -259,10 +274,11 @@ public async Task FindPerformancesByCollectionOfTypes_UsesParameterizedQuery() Assert.Equal(2, fields["__typesAsInts_0"].ListValue.Values.Count); Assert.Equal("0", fields["__typesAsInts_0"].ListValue.Values[0].StringValue); Assert.Equal("1", fields["__typesAsInts_0"].ListValue.Values[1].StringValue); - Assert.Single(request.ParamTypes); - var type = request.ParamTypes["__typesAsInts_0"]; - Assert.Equal(V1.TypeCode.Array, type.Code); - Assert.Equal(V1.TypeCode.Int64, type.ArrayElementType.Code); + Assert.Empty(request.ParamTypes); + // Assert.Single(request.ParamTypes); + // var type = request.ParamTypes["__typesAsInts_0"]; + // Assert.Equal(V1.TypeCode.Array, type.Code); + // Assert.Equal(V1.TypeCode.Int64, type.ArrayElementType.Code); } ); } @@ -289,7 +305,8 @@ public async Task FindSingerAsync_ReturnsInstance_IfFound() request => { Assert.Equal(sql, request.Sql); - Assert.Null(request.Transaction); + Assert.Equal(new TransactionOptions{ReadOnly = new TransactionOptions.Types.ReadOnly{Strong = true, ReturnReadTimestamp = true}}, request.Transaction.SingleUse); + //Assert.Null(request.Transaction); } ); // A read-only operation should not initiate and commit a transaction. @@ -315,18 +332,36 @@ public async Task InsertSinger_SelectsFullName() var updateCount = await db.SaveChangesAsync(); Assert.Equal(1L, updateCount); - Assert.Empty(_fixture.SpannerMock.Requests.OfType()); - Assert.Empty(_fixture.SpannerMock.Requests.OfType()); - Assert.Collection( - _fixture.SpannerMock.Requests.OfType(), - request => - { - Assert.Equal(insertSql, request.Sql); - Assert.False(request.Transaction.HasId); - Assert.Equal(TransactionSelector.SelectorOneofCase.Begin, request.Transaction.SelectorCase); - Assert.Equal(TransactionOptions.ModeOneofCase.ReadWrite, request.Transaction.Begin.ModeCase); - } - ); + var useInlineBegin = true; + if (useInlineBegin) + { + Assert.Empty(_fixture.SpannerMock.Requests.OfType()); + Assert.Empty(_fixture.SpannerMock.Requests.OfType()); + Assert.Collection( + _fixture.SpannerMock.Requests.OfType(), + request => + { + Assert.Equal(insertSql, request.Sql); + Assert.False(request.Transaction.HasId); + Assert.Equal(TransactionSelector.SelectorOneofCase.Begin, request.Transaction.SelectorCase); + Assert.Equal(TransactionOptions.ModeOneofCase.ReadWrite, request.Transaction.Begin.ModeCase); + } + ); + } + else + { + Assert.Single(_fixture.SpannerMock.Requests.OfType()); + Assert.Empty(_fixture.SpannerMock.Requests.OfType()); + Assert.Collection( + _fixture.SpannerMock.Requests.OfType(), + request => + { + Assert.Equal(insertSql, request.Sql); + Assert.True(request.Transaction.HasId); + } + ); + } + Assert.Single(_fixture.SpannerMock.Requests, request => request is CommitRequest); Assert.Collection(_fixture.SpannerMock.Requests @@ -367,13 +402,16 @@ public async Task InsertTicketSale_ReturnsId() request => { Assert.Equal(insertSql, request.Sql); - Assert.Collection(request.ParamTypes, pair => + if (UsesClientLib()) { - Assert.Equal(V1.TypeCode.String, pair.Value.Code); - }, pair => + Assert.Collection(request.ParamTypes, + pair => { Assert.Equal(V1.TypeCode.String, pair.Value.Code); }, + pair => { Assert.Equal(V1.TypeCode.Json, pair.Value.Code); }); + } + else { - Assert.Equal(V1.TypeCode.Json, pair.Value.Code); - }); + Assert.Collection(request.ParamTypes, pair => { Assert.Equal(V1.TypeCode.String, pair.Value.Code); }); + } Assert.NotNull(request.Transaction?.Id); } ); @@ -452,12 +490,20 @@ public async Task UpdateSinger_SelectsFullName() Assert.Equal(1L, updateCount); Assert.Empty(_fixture.SpannerMock.Requests.OfType()); + var useInlineBegin = false; Assert.Collection( _fixture.SpannerMock.Requests.OfType(), request => { Assert.Equal(selectSingerSql, request.Sql); - Assert.Null(request.Transaction?.Id); + if (useInlineBegin) + { + Assert.Null(request.Transaction?.Id); + } + else + { + Assert.NotNull(request.Transaction?.Id); + } }, request => { @@ -480,8 +526,11 @@ public async Task DeleteSinger_DoesNotSelectFullName() var updateCount = await db.SaveChangesAsync(); Assert.Equal(1L, updateCount); + var beginRequests = _fixture.SpannerMock.Requests.OfType(); + Assert.Empty(beginRequests); + var requests = _fixture.SpannerMock.Requests.OfType(); Assert.Collection( - _fixture.SpannerMock.Requests.OfType(), + requests, request => { Assert.Single(request.Statements); @@ -538,23 +587,43 @@ public async Task CanUseReadOnlyTransaction() $" `s`.`LastName`, `s`.`Picture`{Environment.NewLine}FROM `Singers` AS `s`{Environment.NewLine}" + $"WHERE `s`.`SingerId` = @__p_0{Environment.NewLine}LIMIT 1"); await using var db = new MockServerSampleDbContext(ConnectionString); + await db.Database.OpenConnectionAsync(); await using var transaction = await db.Database.BeginReadOnlyTransactionAsync(); Assert.NotNull(await db.Singers.FindAsync(1L)); + var useInlineBegin = true; Assert.Collection( _fixture.SpannerMock.Requests.OfType(), request => { Assert.Equal(sql, request.Sql); Assert.NotNull(request.Transaction); - Assert.False(request.Transaction.HasId); - Assert.Equal(TransactionSelector.SelectorOneofCase.Begin, request.Transaction.SelectorCase); - Assert.Equal(TransactionOptions.ModeOneofCase.ReadOnly, request.Transaction.Begin.ModeCase); - Assert.True(request.Transaction.Begin.ReadOnly.HasStrong); + if (useInlineBegin) + { + Assert.False(request.Transaction.HasId); + Assert.Equal(TransactionSelector.SelectorOneofCase.Begin, request.Transaction.SelectorCase); + Assert.Equal(TransactionOptions.ModeOneofCase.ReadOnly, request.Transaction.Begin.ModeCase); + Assert.True(request.Transaction.Begin.ReadOnly.HasStrong); + } + else + { + Assert.True(request.Transaction.HasId); + } } ); - Assert.Empty(_fixture.SpannerMock.Requests.OfType()); + if (useInlineBegin) + { + Assert.Empty(_fixture.SpannerMock.Requests.OfType()); + } + else + { + Assert.Collection(_fixture.SpannerMock.Requests.OfType(), + request => { + Assert.Equal(TransactionOptions.ModeOneofCase.ReadOnly, request.Options.ModeCase); + Assert.True(request.Options.ReadOnly.HasStrong); + }); + } } [Fact] @@ -564,24 +633,48 @@ public async Task CanUseReadOnlyTransactionWithTimestampBound() $" `s`.`LastName`, `s`.`Picture`{Environment.NewLine}FROM `Singers` AS `s`{Environment.NewLine}" + $"WHERE `s`.`SingerId` = @__p_0{Environment.NewLine}LIMIT 1"); await using var db = new MockServerSampleDbContext(ConnectionString); + await db.Database.OpenConnectionAsync(); await using var transaction = await db.Database.BeginReadOnlyTransactionAsync(TimestampBound.OfExactStaleness(TimeSpan.FromSeconds(10))); Assert.NotNull(await db.Singers.FindAsync(1L)); + var useInlineBegin = true; Assert.Collection( _fixture.SpannerMock.Requests.OfType(), request => { Assert.Equal(sql, request.Sql); Assert.NotNull(request.Transaction); - Assert.False(request.Transaction.HasId); - Assert.Equal(TransactionSelector.SelectorOneofCase.Begin, request.Transaction.SelectorCase); - Assert.Equal(TransactionOptions.ModeOneofCase.ReadOnly, request.Transaction.Begin.ModeCase); - Assert.Equal(TransactionOptions.Types.ReadOnly.TimestampBoundOneofCase.ExactStaleness, request.Transaction.Begin.ReadOnly.TimestampBoundCase); - Assert.Equal(10, request.Transaction.Begin.ReadOnly.ExactStaleness.Seconds); + if (useInlineBegin) + { + Assert.False(request.Transaction.HasId); + Assert.Equal(TransactionSelector.SelectorOneofCase.Begin, request.Transaction.SelectorCase); + Assert.Equal(TransactionOptions.ModeOneofCase.ReadOnly, request.Transaction.Begin.ModeCase); + Assert.Equal(TransactionOptions.Types.ReadOnly.TimestampBoundOneofCase.ExactStaleness, + request.Transaction.Begin.ReadOnly.TimestampBoundCase); + Assert.Equal(10, request.Transaction.Begin.ReadOnly.ExactStaleness.Seconds); + } + else + { + Assert.True(request.Transaction.HasId); + } } ); - Assert.Empty(_fixture.SpannerMock.Requests.OfType()); + if (useInlineBegin) + { + Assert.Empty(_fixture.SpannerMock.Requests.OfType()); + } + else + { + Assert.Collection(_fixture.SpannerMock.Requests.OfType(), + request => + { + Assert.Equal(TransactionOptions.ModeOneofCase.ReadOnly, request.Options.ModeCase); + Assert.Equal(TransactionOptions.Types.ReadOnly.TimestampBoundOneofCase.ExactStaleness, + request.Options.ReadOnly.TimestampBoundCase); + Assert.Equal(10, request.Options.ReadOnly.ExactStaleness.Seconds); + }); + } } [Fact] @@ -601,7 +694,8 @@ public async Task CanReadWithMaxStaleness() request => { Assert.Equal(sql, request.Sql); - Assert.Equal(Duration.FromTimeSpan(TimeSpan.FromSeconds(10)), request.Transaction?.SingleUse?.ReadOnly?.MaxStaleness); + Assert.Equal(Duration.FromTimeSpan(TimeSpan.FromSeconds(10)), + request.Transaction?.SingleUse?.ReadOnly?.MaxStaleness); } ); } @@ -720,30 +814,70 @@ public async Task InsertUsingRawSqlReturnsUpdateCountWithoutAdditionalSelectComm ColJson = JsonDocument.Parse("{\"key1\": \"value1\", \"key2\": \"value2\"}"), ColJsonArray = new List{ JsonDocument.Parse("{\"key1\": \"value1\", \"key2\": \"value2\"}"), JsonDocument.Parse("{\"key1\": \"value3\", \"key2\": \"value4\"}") }, }; - var updateCount = await db.Database.ExecuteSqlRawAsync(rawSql, - new SpannerParameter("ColBool", SpannerDbType.Bool, row.ColBool), - new SpannerParameter("ColBoolArray", SpannerDbType.ArrayOf(SpannerDbType.Bool), row.ColBoolArray), - new SpannerParameter("ColBytes", SpannerDbType.Bytes, row.ColBytes), - new SpannerParameter("ColBytesMax", SpannerDbType.Bytes, row.ColBytesMax), - new SpannerParameter("ColBytesArray", SpannerDbType.ArrayOf(SpannerDbType.Bytes), row.ColBytesArray), - new SpannerParameter("ColBytesMaxArray", SpannerDbType.ArrayOf(SpannerDbType.Bytes), row.ColBytesMaxArray), - new SpannerParameter("ColDate", SpannerDbType.Date, row.ColDate), - new SpannerParameter("ColDateArray", SpannerDbType.ArrayOf(SpannerDbType.Date), row.ColDateArray), - new SpannerParameter("ColFloat64", SpannerDbType.Float64, row.ColFloat64), - new SpannerParameter("ColFloat64Array", SpannerDbType.ArrayOf(SpannerDbType.Float64), row.ColFloat64Array), - new SpannerParameter("ColInt64", SpannerDbType.Int64, row.ColInt64), - new SpannerParameter("ColInt64Array", SpannerDbType.ArrayOf(SpannerDbType.Int64), row.ColInt64Array), - new SpannerParameter("ColNumeric", SpannerDbType.Numeric, row.ColNumeric), - new SpannerParameter("ColNumericArray", SpannerDbType.ArrayOf(SpannerDbType.Numeric), row.ColNumericArray), - new SpannerParameter("ColString", SpannerDbType.String, row.ColString), - new SpannerParameter("ColStringArray", SpannerDbType.ArrayOf(SpannerDbType.String), row.ColStringArray), - new SpannerParameter("ColStringMax", SpannerDbType.String, row.ColStringMax), - new SpannerParameter("ColStringMaxArray", SpannerDbType.ArrayOf(SpannerDbType.String), row.ColStringMaxArray), - new SpannerParameter("ColTimestamp", SpannerDbType.Timestamp, row.ColTimestamp), - new SpannerParameter("ColTimestampArray", SpannerDbType.ArrayOf(SpannerDbType.Timestamp), row.ColTimestampArray), - new SpannerParameter("ColJson", SpannerDbType.Json, row.ColJson?.ToString()), - new SpannerParameter("ColJsonArray", SpannerDbType.ArrayOf(SpannerDbType.Json), row.ColJsonArray?.Select(d => d?.ToString())) - ); + int updateCount; + if (UsesClientLib()) + { + updateCount = await db.Database.ExecuteSqlRawAsync(rawSql, + new SpannerParameter("ColBool", SpannerDbType.Bool, row.ColBool), + new SpannerParameter("ColBoolArray", SpannerDbType.ArrayOf(SpannerDbType.Bool), row.ColBoolArray), + new SpannerParameter("ColBytes", SpannerDbType.Bytes, row.ColBytes), + new SpannerParameter("ColBytesMax", SpannerDbType.Bytes, row.ColBytesMax), + new SpannerParameter("ColBytesArray", SpannerDbType.ArrayOf(SpannerDbType.Bytes), + row.ColBytesArray), + new SpannerParameter("ColBytesMaxArray", SpannerDbType.ArrayOf(SpannerDbType.Bytes), + row.ColBytesMaxArray), + new SpannerParameter("ColDate", SpannerDbType.Date, row.ColDate), + new SpannerParameter("ColDateArray", SpannerDbType.ArrayOf(SpannerDbType.Date), row.ColDateArray), + new SpannerParameter("ColFloat64", SpannerDbType.Float64, row.ColFloat64), + new SpannerParameter("ColFloat64Array", SpannerDbType.ArrayOf(SpannerDbType.Float64), + row.ColFloat64Array), + new SpannerParameter("ColInt64", SpannerDbType.Int64, row.ColInt64), + new SpannerParameter("ColInt64Array", SpannerDbType.ArrayOf(SpannerDbType.Int64), + row.ColInt64Array), + new SpannerParameter("ColNumeric", SpannerDbType.Numeric, row.ColNumeric), + new SpannerParameter("ColNumericArray", SpannerDbType.ArrayOf(SpannerDbType.Numeric), + row.ColNumericArray), + new SpannerParameter("ColString", SpannerDbType.String, row.ColString), + new SpannerParameter("ColStringArray", SpannerDbType.ArrayOf(SpannerDbType.String), + row.ColStringArray), + new SpannerParameter("ColStringMax", SpannerDbType.String, row.ColStringMax), + new SpannerParameter("ColStringMaxArray", SpannerDbType.ArrayOf(SpannerDbType.String), + row.ColStringMaxArray), + new SpannerParameter("ColTimestamp", SpannerDbType.Timestamp, row.ColTimestamp), + new SpannerParameter("ColTimestampArray", SpannerDbType.ArrayOf(SpannerDbType.Timestamp), + row.ColTimestampArray), + new SpannerParameter("ColJson", SpannerDbType.Json, row.ColJson?.ToString()), + new SpannerParameter("ColJsonArray", SpannerDbType.ArrayOf(SpannerDbType.Json), + row.ColJsonArray?.Select(d => d?.ToString())) + ); + } + else + { + updateCount = await db.Database.ExecuteSqlRawAsync(rawSql, + row.ColBool, + row.ColBoolArray, + row.ColBytes, + row.ColBytesMax, + row.ColBytesArray, + row.ColBytesMaxArray, + row.ColDate, + row.ColDateArray, + row.ColFloat64, + row.ColFloat64Array, + row.ColInt64, + row.ColInt64Array, + row.ColNumeric.Value.ToDecimal(LossOfPrecisionHandling.Truncate), + row.ColNumericArray.Select(d => d.Value.ToDecimal(LossOfPrecisionHandling.Truncate)).ToList(), + row.ColString, + row.ColStringArray, + row.ColStringMax, + row.ColStringMaxArray, + row.ColTimestamp, + row.ColTimestampArray, + row.ColJson?.ToString(), + row.ColJsonArray?.Select(d => d?.ToString()).ToList() + ); + } Assert.Equal(1, updateCount); // Verify that the INSERT statement is the only one on the mock server. @@ -872,8 +1006,16 @@ public async Task ExplicitAndImplicitTransactionIsRetried(bool disableInternalRe await cmd.ExecuteScalarAsync(); // Abort the next statement that is executed on the mock server. _fixture.SpannerMock.AbortNextStatement(); - var e = await Assert.ThrowsAsync(() => db.SaveChangesAsync()); - Assert.Equal(ErrorCode.Aborted, e.ErrorCode); + if (UsesClientLib()) + { + var e = await Assert.ThrowsAsync(() => db.SaveChangesAsync()); + Assert.Equal(ErrorCode.Aborted, e.ErrorCode); + } + else + { + var e = await Assert.ThrowsAsync(() => db.SaveChangesAsync()); + Assert.Equal(Code.Aborted, (Code) e.Status.Code); + } } else { @@ -940,26 +1082,44 @@ public async Task ExplicitAndImplicitTransactionIsRetried_WhenUsingRawSql(bool d await cmd.ExecuteScalarAsync(); // Abort the next statement that is executed on the mock server. _fixture.SpannerMock.AbortNextStatement(); - var e = await Assert.ThrowsAsync(() => db.Database.ExecuteSqlRawAsync(insertSql, - new SpannerParameter("p0", SpannerDbType.String, "C1"), - new SpannerParameter("p1", SpannerDbType.Bool, true), - new SpannerParameter("p2", SpannerDbType.Int64, 1000L), - new SpannerParameter("p3", SpannerDbType.String, "Concert Hall"), - new SpannerParameter("p4", SpannerDbType.ArrayOf(SpannerDbType.Float64)) - )); - Assert.Equal(ErrorCode.Aborted, e.ErrorCode); + if (UsesClientLib()) + { + var e = await Assert.ThrowsAsync(() => db.Database.ExecuteSqlRawAsync(insertSql, + new SpannerParameter("p0", SpannerDbType.String, "C1"), + new SpannerParameter("p1", SpannerDbType.Bool, true), + new SpannerParameter("p2", SpannerDbType.Int64, 1000L), + new SpannerParameter("p3", SpannerDbType.String, "Concert Hall"), + new SpannerParameter("p4", SpannerDbType.ArrayOf(SpannerDbType.Float64)) + )); + Assert.Equal(ErrorCode.Aborted, e.ErrorCode); + } + else + { + var e = await Assert.ThrowsAsync( + () => db.Database.ExecuteSqlRawAsync(insertSql, "C1", true, 1000L, "Concert Hall", null) + ); + Assert.Equal((int) Code.Aborted, e.Status.Code); + } } else { // Abort the next statement that is executed on the mock server. _fixture.SpannerMock.AbortNextStatement(); - var updateCount = await db.Database.ExecuteSqlRawAsync(insertSql, - new SpannerParameter("p0", SpannerDbType.String, "C1"), - new SpannerParameter("p1", SpannerDbType.Bool, true), - new SpannerParameter("p2", SpannerDbType.Int64, 1000L), - new SpannerParameter("p3", SpannerDbType.String, "Concert Hall"), - new SpannerParameter("p4", SpannerDbType.ArrayOf(SpannerDbType.Float64)) - ); + int updateCount; + if (UsesClientLib()) + { + updateCount = await db.Database.ExecuteSqlRawAsync(insertSql, + new SpannerParameter("p0", SpannerDbType.String, "C1"), + new SpannerParameter("p1", SpannerDbType.Bool, true), + new SpannerParameter("p2", SpannerDbType.Int64, 1000L), + new SpannerParameter("p3", SpannerDbType.String, "Concert Hall"), + new SpannerParameter("p4", SpannerDbType.ArrayOf(SpannerDbType.Float64)) + ); + } + else + { + updateCount = await db.Database.ExecuteSqlRawAsync(insertSql, "C1", true, 1000L, "Concert Hall", null); + } Assert.Equal(1L, updateCount); if (useExplicitTransaction) { @@ -2444,45 +2604,68 @@ public async Task CanInsertAllTypes() request => { var types = request.ParamTypes; - var index = -1; - Assert.Equal(V1.TypeCode.Int64, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.String, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Bool, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Bool, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.Bytes, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Bytes, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.Bytes, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Bytes, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.Date, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Date, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.Float32, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Float32, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.Float64, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Float64, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Int64, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.Json, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Json, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.Numeric, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Numeric, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.String, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.String, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.String, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.String, types["p" + index].ArrayElementType.Code); - Assert.Equal(V1.TypeCode.Timestamp, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); - Assert.Equal(V1.TypeCode.Timestamp, types["p" + index].ArrayElementType.Code); - Assert.Equal(24, index); + if (UsesClientLib()) + { + var index = -1; + Assert.Equal(24, types.Count); + Assert.Equal(V1.TypeCode.Int64, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.String, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Bool, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Bool, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Bytes, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Bytes, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Bytes, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Bytes, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Date, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Date, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Float32, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Float32, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Float64, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Float64, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Int64, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Json, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Json, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Numeric, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Numeric, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.String, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.String, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.String, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.String, types["p" + index].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Timestamp, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Array, types["p" + ++index].Code); + Assert.Equal(V1.TypeCode.Timestamp, types["p" + index].ArrayElementType.Code); + Assert.Equal(24, index); + } + else + { + // SpannerLib only includes a type code if one has explicitly been set for the parameter. + Assert.Equal(13, types.Count); + Assert.Equal(V1.TypeCode.Int64, types["p0"].Code); + Assert.Equal(V1.TypeCode.String, types["p1"].Code); + Assert.Equal(V1.TypeCode.Bytes, types["p4"].Code); + Assert.Equal(V1.TypeCode.Bytes, types["p6"].Code); + Assert.Equal(V1.TypeCode.Date, types["p8"].Code); + Assert.Equal(V1.TypeCode.Array, types["p9"].Code); + Assert.Equal(V1.TypeCode.Date, types["p9"].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Float32, types["p10"].Code); + Assert.Equal(V1.TypeCode.Float64, types["p12"].Code); + Assert.Equal(V1.TypeCode.Array, types["p16"].Code); + Assert.Equal(V1.TypeCode.Json, types["p16"].ArrayElementType.Code); + Assert.Equal(V1.TypeCode.Numeric, types["p17"].Code); + Assert.Equal(V1.TypeCode.String, types["p21"].Code); + Assert.Equal(V1.TypeCode.Timestamp, types["p23"].Code); + } } ); @@ -2554,11 +2737,11 @@ public async Task CanInsertAllTypes() ); Assert.Equal("", fields["p" + ++index].StringValue); Assert.Empty(fields["p" + ++index].ListValue.Values); - Assert.Equal("2000-01-01T00:00:00Z", fields["p" + ++index].StringValue); + Assert.Equal("2000-01-01T00:00:00.0000000Z", fields["p" + ++index].StringValue); Assert.Collection(fields["p" + ++index].ListValue.Values, - v => Assert.Equal("2000-01-01T00:00:00.001Z", v.StringValue), + v => Assert.Equal("2000-01-01T00:00:00.0010000Z", v.StringValue), v => Assert.Equal(Value.KindOneofCase.NullValue, v.KindCase), - v => Assert.Equal("2000-01-01T00:00:00.002Z", v.StringValue) + v => Assert.Equal("2000-01-01T00:00:00.0020000Z", v.StringValue) ); Assert.Equal(24, index); } @@ -2754,12 +2937,14 @@ internal static StatementResult CreateTableWithAllColumnTypesResultSet() ); } - [Fact] + [SkippableFact] public async Task RequestIncludesEfCoreClientHeader() { + Skip.IfNot(UsesClientLib()); + var sql = $"SELECT `s`.`SingerId`, `s`.`BirthDate`, `s`.`FirstName`, `s`.`FullName`, `s`.`LastName`, " + $"`s`.`Picture`{Environment.NewLine}FROM `Singers` AS `s`{Environment.NewLine}" + - $"WHERE `s`.`SingerId` = @__p_0{Environment.NewLine}LIMIT 1"; + $"WHERE `s`.`SingerId` = @p_0{Environment.NewLine}LIMIT 1"; _fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( new List>(), new List())); diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkSessionLeakMockServerTests.cs b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkSessionLeakMockServerTests.cs index 6373213e..bfb9f0f0 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkSessionLeakMockServerTests.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkSessionLeakMockServerTests.cs @@ -30,8 +30,13 @@ using System.Text.Json; using System.Text.RegularExpressions; using System.Threading.Tasks; +using Google.Cloud.Spanner.DataProvider; +using Google.Rpc; using Xunit; +using SpannerConnection = Google.Cloud.Spanner.Data.SpannerConnection; +using SpannerConnectionStringBuilder = Google.Cloud.Spanner.Data.SpannerConnectionStringBuilder; using SpannerDate = Google.Cloud.EntityFrameworkCore.Spanner.Storage.SpannerDate; +using SpannerParameter = Google.Cloud.Spanner.Data.SpannerParameter; using V1 = Google.Cloud.Spanner.V1; #pragma warning disable EF1001 @@ -47,6 +52,11 @@ internal LimitedSessionsSampleDbContext(string connectionString, SessionPoolMana _connectionString = connectionString; _manager = manager; } + + bool UsesClientLib() + { + return Environment.GetEnvironmentVariable("USE_CLIENT_LIB") == "true"; + } protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) { @@ -54,14 +64,27 @@ protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) { return; } - var builder = new SpannerConnectionStringBuilder(_connectionString, ChannelCredentials.Insecure) + + if (UsesClientLib()) { - SessionPoolManager = _manager - }; - optionsBuilder - .UseSpanner(new SpannerRetriableConnection(new SpannerConnection(builder)), _ => SpannerModelValidationConnectionProvider.Instance.EnableDatabaseModelValidation(false), ChannelCredentials.Insecure) - .UseMutations(MutationUsage.Never) - .UseLazyLoadingProxies(); + var builder = new SpannerConnectionStringBuilder(_connectionString, ChannelCredentials.Insecure) + { + SessionPoolManager = _manager + }; + optionsBuilder + .UseSpanner(new SpannerRetriableConnection(new SpannerConnection(builder)), + _ => SpannerModelValidationConnectionProvider.Instance.EnableDatabaseModelValidation(false), + ChannelCredentials.Insecure) + .UseMutations(MutationUsage.Never) + .UseLazyLoadingProxies(); + } + else + { + optionsBuilder + .UseSpanner(_connectionString, _ => SpannerModelValidationConnectionProvider.Instance.EnableDatabaseModelValidation(false), ChannelCredentials.Insecure) + .UseMutations(MutationUsage.Never) + .UseLazyLoadingProxies(); + } } } @@ -87,7 +110,13 @@ public EntityFrameworkSessionLeakMockServerTests(SpannerMockServerFixture servic _manager = SessionPoolManager.Create(options); } - private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port}"; + private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port};UsePlainText=true"; + // private string ConnectionString => $"{_fixture.Host}:{_fixture.Port}/projects/p1/instances/i1/databases/d1;usePlainText=true"; + + bool UsesClientLib() + { + return Environment.GetEnvironmentVariable("USE_CLIENT_LIB") == "true"; + } private static async Task Repeat(int count, Func action) { @@ -359,40 +388,74 @@ public async Task InsertUsingRawSqlReturnsUpdateCountWithoutAdditionalSelectComm using var db = CreateContext(); await Repeat(async () => { - var updateCount = await db.Database.ExecuteSqlRawAsync(rawSql, - new SpannerParameter("ColBool", SpannerDbType.Bool, row.ColBool), - new SpannerParameter("ColBoolArray", SpannerDbType.ArrayOf(SpannerDbType.Bool), row.ColBoolArray), - new SpannerParameter("ColBytes", SpannerDbType.Bytes, row.ColBytes), - new SpannerParameter("ColBytesMax", SpannerDbType.Bytes, row.ColBytesMax), - new SpannerParameter("ColBytesArray", SpannerDbType.ArrayOf(SpannerDbType.Bytes), - row.ColBytesArray), - new SpannerParameter("ColBytesMaxArray", SpannerDbType.ArrayOf(SpannerDbType.Bytes), - row.ColBytesMaxArray), - new SpannerParameter("ColDate", SpannerDbType.Date, row.ColDate), - new SpannerParameter("ColDateArray", SpannerDbType.ArrayOf(SpannerDbType.Date), row.ColDateArray), - new SpannerParameter("ColFloat64", SpannerDbType.Float64, row.ColFloat64), - new SpannerParameter("ColFloat64Array", SpannerDbType.ArrayOf(SpannerDbType.Float64), - row.ColFloat64Array), - new SpannerParameter("ColInt64", SpannerDbType.Int64, row.ColInt64), - new SpannerParameter("ColInt64Array", SpannerDbType.ArrayOf(SpannerDbType.Int64), - row.ColInt64Array), - new SpannerParameter("ColNumeric", SpannerDbType.Numeric, row.ColNumeric), - new SpannerParameter("ColNumericArray", SpannerDbType.ArrayOf(SpannerDbType.Numeric), - row.ColNumericArray), - new SpannerParameter("ColString", SpannerDbType.String, row.ColString), - new SpannerParameter("ColStringArray", SpannerDbType.ArrayOf(SpannerDbType.String), - row.ColStringArray), - new SpannerParameter("ColStringMax", SpannerDbType.String, row.ColStringMax), - new SpannerParameter("ColStringMaxArray", SpannerDbType.ArrayOf(SpannerDbType.String), - row.ColStringMaxArray), - new SpannerParameter("ColTimestamp", SpannerDbType.Timestamp, row.ColTimestamp), - new SpannerParameter("ColTimestampArray", SpannerDbType.ArrayOf(SpannerDbType.Timestamp), - row.ColTimestampArray), - new SpannerParameter("ColJson", SpannerDbType.Json, row.ColJson?.ToString()), - new SpannerParameter("ColJsonArray", SpannerDbType.ArrayOf(SpannerDbType.Json), - row.ColJsonArray?.Select(d => d?.ToString())) - ); - Assert.Equal(1, updateCount); + if (UsesClientLib()) + { + var updateCount = await db.Database.ExecuteSqlRawAsync(rawSql, + new SpannerParameter("ColBool", SpannerDbType.Bool, row.ColBool), + new SpannerParameter("ColBoolArray", SpannerDbType.ArrayOf(SpannerDbType.Bool), + row.ColBoolArray), + new SpannerParameter("ColBytes", SpannerDbType.Bytes, row.ColBytes), + new SpannerParameter("ColBytesMax", SpannerDbType.Bytes, row.ColBytesMax), + new SpannerParameter("ColBytesArray", SpannerDbType.ArrayOf(SpannerDbType.Bytes), + row.ColBytesArray), + new SpannerParameter("ColBytesMaxArray", SpannerDbType.ArrayOf(SpannerDbType.Bytes), + row.ColBytesMaxArray), + new SpannerParameter("ColDate", SpannerDbType.Date, row.ColDate), + new SpannerParameter("ColDateArray", SpannerDbType.ArrayOf(SpannerDbType.Date), + row.ColDateArray), + new SpannerParameter("ColFloat64", SpannerDbType.Float64, row.ColFloat64), + new SpannerParameter("ColFloat64Array", SpannerDbType.ArrayOf(SpannerDbType.Float64), + row.ColFloat64Array), + new SpannerParameter("ColInt64", SpannerDbType.Int64, row.ColInt64), + new SpannerParameter("ColInt64Array", SpannerDbType.ArrayOf(SpannerDbType.Int64), + row.ColInt64Array), + new SpannerParameter("ColNumeric", SpannerDbType.Numeric, row.ColNumeric), + new SpannerParameter("ColNumericArray", SpannerDbType.ArrayOf(SpannerDbType.Numeric), + row.ColNumericArray), + new SpannerParameter("ColString", SpannerDbType.String, row.ColString), + new SpannerParameter("ColStringArray", SpannerDbType.ArrayOf(SpannerDbType.String), + row.ColStringArray), + new SpannerParameter("ColStringMax", SpannerDbType.String, row.ColStringMax), + new SpannerParameter("ColStringMaxArray", SpannerDbType.ArrayOf(SpannerDbType.String), + row.ColStringMaxArray), + new SpannerParameter("ColTimestamp", SpannerDbType.Timestamp, row.ColTimestamp), + new SpannerParameter("ColTimestampArray", SpannerDbType.ArrayOf(SpannerDbType.Timestamp), + row.ColTimestampArray), + new SpannerParameter("ColJson", SpannerDbType.Json, row.ColJson?.ToString()), + new SpannerParameter("ColJsonArray", SpannerDbType.ArrayOf(SpannerDbType.Json), + row.ColJsonArray?.Select(d => d?.ToString())) + ); + Assert.Equal(1, updateCount); + } + else + { + var updateCount = await db.Database.ExecuteSqlRawAsync( + rawSql, + row.ColBool, + row.ColBoolArray, + row.ColBytes, + row.ColBytesMax, + row.ColBytesArray, + row.ColBytesMaxArray, + row.ColDate, + row.ColDateArray, + row.ColFloat64, + row.ColFloat64Array, + row.ColInt64, + row.ColInt64Array, + row.ColNumeric, + row.ColNumericArray, + row.ColString, + row.ColStringArray, + row.ColStringMax, + row.ColStringMaxArray, + row.ColTimestamp, + row.ColTimestampArray, + row.ColJson?.ToString(), + row.ColJsonArray?.Select(d => d?.ToString()) + ); + Assert.Equal(1, updateCount); + } }); } @@ -440,8 +503,16 @@ await Repeat(async () => await cmd.ExecuteScalarAsync(); // Abort the next statement that is executed on the mock server. _fixture.SpannerMock.AbortNextStatement(); - var e = await Assert.ThrowsAsync(() => db.SaveChangesAsync()); - Assert.Equal(ErrorCode.Aborted, e.ErrorCode); + if (UsesClientLib()) + { + var e = await Assert.ThrowsAsync(() => db.SaveChangesAsync()); + Assert.Equal(ErrorCode.Aborted, e.ErrorCode); + } + else + { + var e = await Assert.ThrowsAsync(() => db.SaveChangesAsync()); + Assert.Equal((int) Code.Aborted, e.Status.Code); + } } else { @@ -496,27 +567,47 @@ await Repeat(async () => await cmd.ExecuteScalarAsync(); // Abort the next statement that is executed on the mock server. _fixture.SpannerMock.AbortNextStatement(); - var e = await Assert.ThrowsAsync(() => db.Database.ExecuteSqlRawAsync(insertSql, - new SpannerParameter("p0", SpannerDbType.String, "C1"), - new SpannerParameter("p1", SpannerDbType.Bool, true), - new SpannerParameter("p2", SpannerDbType.Int64, 1000L), - new SpannerParameter("p3", SpannerDbType.String, "Concert Hall"), - new SpannerParameter("p4", SpannerDbType.ArrayOf(SpannerDbType.Float64)) - )); - Assert.Equal(ErrorCode.Aborted, e.ErrorCode); + if (UsesClientLib()) + { + var f = () => db.Database.ExecuteSqlRawAsync(insertSql, + new SpannerParameter("p0", SpannerDbType.String, "C1"), + new SpannerParameter("p1", SpannerDbType.Bool, true), + new SpannerParameter("p2", SpannerDbType.Int64, 1000L), + new SpannerParameter("p3", SpannerDbType.String, "Concert Hall"), + new SpannerParameter("p4", SpannerDbType.ArrayOf(SpannerDbType.Float64)) + ); + var e = await Assert.ThrowsAsync(f); + Assert.Equal(ErrorCode.Aborted, e.ErrorCode); + } + else + { + var f = () => db.Database.ExecuteSqlRawAsync( + insertSql, "C1", true, 1000L, "Concert Hall", null); + var e = await Assert.ThrowsAsync(f); + Assert.Equal((int) Code.Aborted, e.Status.Code); + } } else { // Abort the next statement that is executed on the mock server. _fixture.SpannerMock.AbortNextStatement(); - var updateCount = await db.Database.ExecuteSqlRawAsync(insertSql, - new SpannerParameter("p0", SpannerDbType.String, "C1"), - new SpannerParameter("p1", SpannerDbType.Bool, true), - new SpannerParameter("p2", SpannerDbType.Int64, 1000L), - new SpannerParameter("p3", SpannerDbType.String, "Concert Hall"), - new SpannerParameter("p4", SpannerDbType.ArrayOf(SpannerDbType.Float64)) - ); - Assert.Equal(1L, updateCount); + if (UsesClientLib()) + { + var updateCount = await db.Database.ExecuteSqlRawAsync(insertSql, + new SpannerParameter("p0", SpannerDbType.String, "C1"), + new SpannerParameter("p1", SpannerDbType.Bool, true), + new SpannerParameter("p2", SpannerDbType.Int64, 1000L), + new SpannerParameter("p3", SpannerDbType.String, "Concert Hall"), + new SpannerParameter("p4", SpannerDbType.ArrayOf(SpannerDbType.Float64)) + ); + Assert.Equal(1L, updateCount); + } + else + { + var updateCount = await db.Database.ExecuteSqlRawAsync( + insertSql, "C1", true, 1000L, "Concert Hall", null); + Assert.Equal(1L, updateCount); + } if (useExplicitTransaction) { await transaction.CommitAsync(); @@ -2106,9 +2197,12 @@ public async Task MultipleReadWriteTransactionsWithUsingBlocks_DoesNotHoldOnToSe await transaction3.CommitAsync(); } - [Fact] + [SkippableFact] public async Task NestedTransactionsStartNewTransactions() { + // SpannerLib uses multiplexed sessions, so the pool is not exhausted. + Skip.IfNot(UsesClientLib()); + AddFindSingerResult($"SELECT `s`.`SingerId`, `s`.`BirthDate`, `s`.`FirstName`, `s`.`FullName`," + $" `s`.`LastName`, `s`.`Picture`{Environment.NewLine}FROM `Singers` AS `s`{Environment.NewLine}" + $"WHERE `s`.`SingerId` = @__p_0{Environment.NewLine}LIMIT 1"); @@ -2200,9 +2294,12 @@ await Repeat(() => }); } - [Fact] + [SkippableFact] public async Task OnlyDisposingReadOnlyTransactionWithoutCommitting_LeaksSession() { + // SpannerLib uses multiplexed sessions, so the session pool will not be exhausted. + Skip.IfNot(UsesClientLib()); + AddFindSingerResult($"SELECT `s`.`SingerId`, `s`.`BirthDate`, `s`.`FirstName`, `s`.`FullName`," + $" `s`.`LastName`, `s`.`Picture`{Environment.NewLine}FROM `Singers` AS `s`{Environment.NewLine}" + $"WHERE `s`.`SingerId` = @__p_0{Environment.NewLine}LIMIT 1"); diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkUsingMutationsMockServerTests.cs b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkUsingMutationsMockServerTests.cs index 544812ff..c68c6eb5 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkUsingMutationsMockServerTests.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/EntityFrameworkUsingMutationsMockServerTests.cs @@ -28,6 +28,8 @@ using System.Linq; using System.Text.Json; using System.Threading.Tasks; +using Google.Cloud.Spanner.DataProvider; +using Google.Rpc; using Xunit; using SpannerDate = Google.Cloud.EntityFrameworkCore.Spanner.Storage.SpannerDate; using V1 = Google.Cloud.Spanner.V1; @@ -92,7 +94,13 @@ public EntityFrameworkMockUsingMutationsServerTests(SpannerMockServerFixture ser service.SpannerMock.Reset(); } - private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port}"; + private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port};UsePlainText=true"; + //private string ConnectionString => $"{_fixture.Host}:{_fixture.Port}/projects/p1/instances/i1/databases/d1;usePlainText=true"; + + bool UsesClientLib() + { + return Environment.GetEnvironmentVariable("USE_CLIENT_LIB") == "true"; + } [Fact] public async Task InsertAlbum() @@ -132,6 +140,7 @@ public async Task InsertSinger() var selectFullNameSql = AddSelectSingerFullNameResult("Alice Morrison", 0); await using var db = new MockServerSampleDbContextUsingMutations(ConnectionString); + await db.Database.OpenConnectionAsync(); db.Singers.Add(new Singers { SingerId = 1L, @@ -251,6 +260,7 @@ public async Task UpdateSinger_SelectsFullName() var selectFullNameSql = AddSelectSingerFullNameResult("Alice Pieterson-Morrison", 0); await using var db = new MockServerSampleDbContextUsingMutations(ConnectionString); + await db.Database.OpenConnectionAsync(); var singer = await db.Singers.FindAsync(1L); Assert.NotNull(singer); singer.LastName = "Pieterson-Morrison"; @@ -270,12 +280,12 @@ public async Task UpdateSinger_SelectsFullName() request => { Assert.Equal(selectSingerSql.Trim(), request.Sql.Trim()); - Assert.Null(request.Transaction?.Id); + Assert.True(request.Transaction?.Id?.IsEmpty); }, request => { Assert.Equal(selectFullNameSql.Trim(), request.Sql.Trim()); - Assert.Null(request.Transaction?.Id); + Assert.True(request.Transaction?.Id?.IsEmpty); } ); Assert.Collection( @@ -481,8 +491,16 @@ public async Task ExplicitAndImplicitTransactionIsRetried(bool disableInternalRe if (disableInternalRetries && useExplicitTransaction) { await db.SaveChangesAsync(); - var e = await Assert.ThrowsAsync(() => transaction.CommitAsync()); - Assert.Equal(ErrorCode.Aborted, e.ErrorCode); + if (UsesClientLib()) + { + var e = await Assert.ThrowsAsync(() => transaction.CommitAsync()); + Assert.Equal(ErrorCode.Aborted, e.ErrorCode); + } + else + { + var e = await Assert.ThrowsAsync(() => transaction.CommitAsync()); + Assert.Equal((int) Code.Aborted, e.Status.Code); + } } else { @@ -516,6 +534,7 @@ public async Task ExplicitAndImplicitTransactionIsRetried(bool disableInternalRe public async Task CanInsertCommitTimestamp() { await using var db = new MockServerSampleDbContextUsingMutations(ConnectionString); + await db.Database.OpenConnectionAsync(); _fixture.SpannerMock.AddOrUpdateStatementResult($"{Environment.NewLine}SELECT `ColComputed`" + $"{Environment.NewLine}FROM `TableWithAllColumnTypes`{Environment.NewLine}WHERE TRUE AND `ColInt64` = @p0", StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = V1.TypeCode.String }, "FOO")); @@ -541,6 +560,7 @@ public async Task CanInsertCommitTimestamp() public async Task CanUpdateCommitTimestamp() { await using var db = new MockServerSampleDbContextUsingMutations(ConnectionString); + await db.Database.OpenConnectionAsync(); _fixture.SpannerMock.AddOrUpdateStatementResult($"{Environment.NewLine}SELECT `ColComputed`{Environment.NewLine}FROM `TableWithAllColumnTypes`{Environment.NewLine}WHERE TRUE AND `ColInt64` = @p0", StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = V1.TypeCode.String }, "FOO")); var row = new TableWithAllColumnTypes { ColInt64 = 1L }; @@ -567,6 +587,7 @@ public async Task CanUpdateCommitTimestamp() public async Task CanInsertRowWithCommitTimestampAndComputedColumn() { await using var db = new MockServerSampleDbContextUsingMutations(ConnectionString); + await db.Database.OpenConnectionAsync(); var selectSql = $"{Environment.NewLine}SELECT `ColComputed`{Environment.NewLine}FROM `TableWithAllColumnTypes`{Environment.NewLine}WHERE TRUE AND `ColInt64` = @p0"; _fixture.SpannerMock.AddOrUpdateStatementResult(selectSql, StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = V1.TypeCode.String }, "FOO")); @@ -607,6 +628,7 @@ public async Task CanInsertRowWithCommitTimestampAndComputedColumn() public async Task CanInsertAllTypes() { await using var db = new MockServerSampleDbContextUsingMutations(ConnectionString); + await db.Database.OpenConnectionAsync(); _fixture.SpannerMock.AddOrUpdateStatementResult($"{Environment.NewLine}SELECT `ColComputed`" + $"{Environment.NewLine}FROM `TableWithAllColumnTypes`{Environment.NewLine}WHERE TRUE AND `ColInt64` = @p0", StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = V1.TypeCode.String }, "FOO")); @@ -764,13 +786,13 @@ public async Task CanInsertAllTypes() Assert.Empty(values[index++].ListValue.Values); Assert.Equal("ColTimestamp", columns[index]); Assert.Equal(Value.KindOneofCase.StringValue, values[index].KindCase); - Assert.Equal("2000-01-01T00:00:00Z", values[index++].StringValue); + Assert.Equal("2000-01-01T00:00:00.0000000Z", values[index++].StringValue); Assert.Equal("ColTimestampArray", columns[index]); Assert.Equal(Value.KindOneofCase.ListValue, values[index].KindCase); Assert.Collection(values[index].ListValue.Values, - v => Assert.Equal("2000-01-01T00:00:00.001Z", v.StringValue), + v => Assert.Equal("2000-01-01T00:00:00.0010000Z", v.StringValue), v => Assert.Equal(Value.KindOneofCase.NullValue, v.KindCase), - v => Assert.Equal("2000-01-01T00:00:00.002Z", v.StringValue) + v => Assert.Equal("2000-01-01T00:00:00.0020000Z", v.StringValue) ); } ); diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/MigrationTests/MigrationMockServerTests.cs b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/MigrationTests/MigrationMockServerTests.cs index 001f07aa..5ff844c6 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/MigrationTests/MigrationMockServerTests.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/MigrationTests/MigrationMockServerTests.cs @@ -35,7 +35,8 @@ public MigrationMockServerTests(SpannerMockServerFixture service) service.SpannerMock.Reset(); } - private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port}"; + private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port};UsePlainText=true"; + //private string ConnectionString => $"{_fixture.Host}:{_fixture.Port}/projects/p1/instances/i1/databases/d1;usePlainText=true"; [Fact] public void TestMigrateUsesDdlBatch() @@ -52,6 +53,7 @@ public void TestMigrateUsesDdlBatch() StatementResult.CreateUpdateCount(1) ); using var db = new MockMigrationSampleDbContext(ConnectionString); + db.Database.OpenConnection(); db.Database.Migrate(); Assert.Collection(_fixture.DatabaseAdminMock.Requests, diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/MockSpannerServer.cs b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/MockSpannerServer.cs index c3a5885e..3005c7cb 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/MockSpannerServer.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/MockSpannerServer.cs @@ -309,6 +309,8 @@ public void Dispose() } } + private static readonly string s_dialect_query = + "select option_value from information_schema.database_options where option_name='database_dialect'"; private static readonly Empty s_empty = new (); private static readonly TransactionOptions s_singleUse = new() { ReadOnly = new TransactionOptions.Types.ReadOnly { Strong = true, ReturnReadTimestamp = false } }; @@ -326,6 +328,11 @@ public void Dispose() private bool _abortNextStatement; private readonly ConcurrentDictionary _executionTimes = new(); + public MockSpannerService() + { + AddDialectResult(); + } + public void AddOrUpdateStatementResult(string sql, StatementResult result) { _results.AddOrUpdate(sql.Trim(), @@ -342,6 +349,20 @@ public void AddOrUpdateExecutionTime(string method, ExecutionTime executionTime) ); } + private void AddDialectResult() + { + AddOrUpdateStatementResult(s_dialect_query, + StatementResult.CreateResultSet( + new List> + { + Tuple.Create(V1.TypeCode.String, "option_value"), + }, + new List + { + new object[] { "GOOGLE_STANDARD_SQL" }, + })); + } + internal void AbortTransaction(TransactionId transactionId) { var prop = transactionId.GetType().GetProperty("Id", BindingFlags.Instance | BindingFlags.NonPublic); @@ -372,6 +393,7 @@ public void Reset() _results.Clear(); _abortedTransactions.Clear(); _abortNextStatement = false; + AddDialectResult(); } public override Task BeginTransaction(BeginTransactionRequest request, ServerCallContext context) @@ -413,10 +435,14 @@ public override Task Rollback(RollbackRequest request, ServerCallContext return Task.FromResult(s_empty); } - private Session CreateSession(DatabaseName database) + private Session CreateSession(DatabaseName database, bool multiplexed) { var id = Interlocked.Increment(ref _sessionCounter); - Session session = new Session { SessionName = new SessionName(database.ProjectId, database.InstanceId, database.DatabaseId, $"session-{id}") }; + Session session = new Session + { + SessionName = new SessionName(database.ProjectId, database.InstanceId, database.DatabaseId, $"session-{id}"), + Multiplexed = multiplexed, + }; if (!_sessions.TryAdd(session.SessionName, session)) { throw new RpcException(new Grpc.Core.Status(StatusCode.AlreadyExists, $"Session with id session-{id} already exists")); @@ -494,7 +520,7 @@ private Transaction BeginTransaction(SessionName session, TransactionOptions opt tx.Id = ByteString.CopyFromUtf8($"{session}/transactions/{id}"); if (options.ModeCase == TransactionOptions.ModeOneofCase.ReadOnly && options.ReadOnly.ReturnReadTimestamp) { - tx.ReadTimestamp = Timestamp.FromDateTime(DateTime.Now); + tx.ReadTimestamp = Timestamp.FromDateTime(DateTime.UtcNow); } if (!singleUse) { @@ -515,7 +541,7 @@ public override Task BatchCreateSessions(BatchCreat BatchCreateSessionsResponse response = new BatchCreateSessionsResponse(); for (int i = 0; i < request.SessionCount; i++) { - response.Session.Add(CreateSession(database)); + response.Session.Add(CreateSession(database, false)); } return Task.FromResult(response); } @@ -526,7 +552,7 @@ public override Task CreateSession(CreateSessionRequest request, Server _contexts.Enqueue(context); _headers.Enqueue(context.RequestHeaders); var database = request.DatabaseAsDatabaseName; - return Task.FromResult(CreateSession(database)); + return Task.FromResult(CreateSession(database, request.Session?.Multiplexed ?? false)); } public override Task GetSession(GetSessionRequest request, ServerCallContext context) @@ -576,7 +602,7 @@ public override Task ExecuteBatchDml(ExecuteBatchDmlReq _executionTimes.TryGetValue(nameof(ExecuteBatchDml), out ExecutionTime executionTime); executionTime?.SimulateExecutionTime(); _ = TryFindSession(request.SessionAsSessionName); - _ = FindOrBeginTransaction(request.SessionAsSessionName, request.Transaction); + var tx = FindOrBeginTransaction(request.SessionAsSessionName, request.Transaction); var response = new ExecuteBatchDmlResponse { // TODO: Return other statuses based on the mocked results. @@ -603,7 +629,12 @@ public override Task ExecuteBatchDml(ExecuteBatchDmlReq { executionTime.SimulateExecutionTime(); } - response.ResultSets.Add(CreateUpdateCountResultSet(result.UpdateCount)); + var resultSet = CreateUpdateCountResultSet(result.UpdateCount); + if (index == 0 && request.Transaction?.Begin != null && tx != null) + { + resultSet.Metadata.Transaction = tx; + } + response.ResultSets.Add(resultSet); break; case StatementResult.StatementResultType.Exception: if (index == 0) @@ -636,7 +667,10 @@ private Status StatusFromException(Exception e) public override async Task ExecuteStreamingSql(ExecuteSqlRequest request, IServerStreamWriter responseStream, ServerCallContext context) { - _requests.Enqueue(request); + if (!request.Sql.Equals(s_dialect_query)) + { + _requests.Enqueue(request); + } _contexts.Enqueue(context); _headers.Enqueue(context.RequestHeaders); _executionTimes.TryGetValue(nameof(ExecuteStreamingSql) + request.Sql, out ExecutionTime executionTime); @@ -700,7 +734,7 @@ private async Task WriteUpdateCount(Transaction transaction, long updateCount, I { PartialResultSet prs = new PartialResultSet { - Metadata = new ResultSetMetadata { Transaction = transaction }, + Metadata = new ResultSetMetadata { Transaction = transaction, RowType = new StructType()}, Stats = new ResultSetStats { RowCountExact = updateCount } }; await responseStream.WriteAsync(prs); @@ -710,6 +744,7 @@ private ResultSet CreateUpdateCountResultSet(long updateCount) { ResultSet rs = new ResultSet { + Metadata = new ResultSetMetadata { RowType = new StructType()}, Stats = new ResultSetStats { RowCountExact = updateCount } }; return rs; diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/SpannerMockServerFixture.cs b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/SpannerMockServerFixture.cs index 4af9169f..63d4d0d9 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/SpannerMockServerFixture.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/SpannerMockServerFixture.cs @@ -17,8 +17,10 @@ using Microsoft.AspNetCore.Hosting.Server.Features; using Microsoft.AspNetCore.Server.Kestrel.Core; using System; +using System.Collections.Generic; using System.Linq; using System.Net; +using Google.Cloud.Spanner.Admin.Database.V1; namespace Google.Cloud.EntityFrameworkCore.Spanner.Tests; @@ -38,6 +40,17 @@ public class SpannerMockServerFixture : IDisposable public SpannerMockServerFixture() { SpannerMock = new MockSpannerService(); + SpannerMock.AddOrUpdateStatementResult( + "select option_value from information_schema.database_options where option_name='database_dialect'", + StatementResult.CreateResultSet( + new List> + { + Tuple.Create(Cloud.Spanner.V1.TypeCode.String, "option_value"), + }, + new List + { + new object[] { nameof(DatabaseDialect.GoogleStandardSql) }, + })); DatabaseAdminMock = new MockDatabaseAdminService(); var endpoint = IPEndPoint.Parse("127.0.0.1:0"); diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/TypeConversionTests.cs b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/TypeConversionTests.cs index 1288c743..ed8bfd5c 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.Tests/TypeConversionTests.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner.Tests/TypeConversionTests.cs @@ -69,7 +69,12 @@ public TypeConversionTests(SpannerMockServerFixture service) service.SpannerMock.Reset(); } - private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port}"; + private string ConnectionString => $"Data Source=projects/p1/instances/i1/databases/d1;Host={_fixture.Host};Port={_fixture.Port};UsePlainText=true"; + + bool UsesClientLib() + { + return Environment.GetEnvironmentVariable("USE_CLIENT_LIB") == "true"; + } [Fact] public async Task TestEntity_ConvertValuesWithoutPrecisionLossOrOverflow_Succeeds() @@ -171,7 +176,14 @@ public async Task TestEntity_ConvertValuesWithDecimalOverflow_Fails() )); using var db = new TypeConversionDbContext(ConnectionString); - await Assert.ThrowsAsync(() => db.TestEntities.FindAsync(1L).AsTask()); + if (UsesClientLib()) + { + await Assert.ThrowsAsync(() => db.TestEntities.FindAsync(1L).AsTask()); + } + else + { + await Assert.ThrowsAsync(() => db.TestEntities.FindAsync(1L).AsTask()); + } } } } diff --git a/Google.Cloud.EntityFrameworkCore.Spanner.sln b/Google.Cloud.EntityFrameworkCore.Spanner.sln index 1573d331..d724e4cf 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner.sln +++ b/Google.Cloud.EntityFrameworkCore.Spanner.sln @@ -13,6 +13,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Google.Cloud.EntityFramewor EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Google.Cloud.EntityFrameworkCore.Spanner.Benchmarks", "Google.Cloud.EntityFrameworkCore.Spanner.Benchmarks\Google.Cloud.EntityFrameworkCore.Spanner.Benchmarks.csproj", "{A1721F06-FE32-408E-BFAF-A6C94E28E9B0}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spanner-ado-net", "..\..\GolandProjects\go-sql-spanner\drivers\spanner-ado-net\spanner-ado-net\spanner-ado-net.csproj", "{FE047135-AD08-4942-981D-BF4D730E1A94}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -39,6 +41,10 @@ Global {A1721F06-FE32-408E-BFAF-A6C94E28E9B0}.Debug|Any CPU.Build.0 = Debug|Any CPU {A1721F06-FE32-408E-BFAF-A6C94E28E9B0}.Release|Any CPU.ActiveCfg = Release|Any CPU {A1721F06-FE32-408E-BFAF-A6C94E28E9B0}.Release|Any CPU.Build.0 = Release|Any CPU + {FE047135-AD08-4942-981D-BF4D730E1A94}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {FE047135-AD08-4942-981D-BF4D730E1A94}.Debug|Any CPU.Build.0 = Debug|Any CPU + {FE047135-AD08-4942-981D-BF4D730E1A94}.Release|Any CPU.ActiveCfg = Release|Any CPU + {FE047135-AD08-4942-981D-BF4D730E1A94}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/QueryableExtensions.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/QueryableExtensions.cs index 6bf71b09..54d76c0f 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/QueryableExtensions.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/QueryableExtensions.cs @@ -23,6 +23,9 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Google.Cloud.Spanner.V1; +using Google.Protobuf.WellKnownTypes; +using SpannerCommand = Google.Cloud.Spanner.DataProvider.SpannerCommand; namespace Google.Cloud.EntityFrameworkCore.Spanner.Extensions { @@ -97,6 +100,10 @@ public override InterceptionResult ReaderExecuting( { ManipulateCommand(cmd); } + else if (command is SpannerCommand spannerCommand) + { + ManipulateCommand(spannerCommand); + } return result; } @@ -126,6 +133,46 @@ private static void ManipulateCommand(SpannerRetriableCommand command) } } } + + private static void ManipulateCommand(SpannerCommand command) + { + var hint = s_supportedHints.FirstOrDefault(hint => hint.IsHint(command.CommandText)); + if (hint != null) + { + try + { + var timestampBound = hint.CreateTimestampBound(command.CommandText); + command.SingleUseReadOnlyTransactionOptions = new TransactionOptions.Types.ReadOnly + { + ReturnReadTimestamp = true, + }; + switch (timestampBound.Mode) + { + case TimestampBoundMode.Strong: + command.SingleUseReadOnlyTransactionOptions.Strong = true; + break; + case TimestampBoundMode.ExactStaleness: + command.SingleUseReadOnlyTransactionOptions.ExactStaleness = Duration.FromTimeSpan(timestampBound.Staleness); + break; + case TimestampBoundMode.MaxStaleness: + command.SingleUseReadOnlyTransactionOptions.MaxStaleness = Duration.FromTimeSpan(timestampBound.Staleness); + break; + case TimestampBoundMode.ReadTimestamp: + command.SingleUseReadOnlyTransactionOptions.ReadTimestamp = Timestamp.FromDateTime(timestampBound.Timestamp); + break; + case TimestampBoundMode.MinReadTimestamp: + command.SingleUseReadOnlyTransactionOptions.MinReadTimestamp = Timestamp.FromDateTime(timestampBound.Timestamp); + break; + } + } + catch (Exception) + { + // Ignore any invalid timestamp bound in the comment. + // That could happen if someone by chance happened to manually add a comment that is the same + // as a timestamp bound hint, but with an invalid value. + } + } + } } internal abstract class TimestampBoundHint diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/SpannerDbContextOptionsExtensions.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/SpannerDbContextOptionsExtensions.cs index 961dfaff..c26bdfc5 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/SpannerDbContextOptionsExtensions.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/SpannerDbContextOptionsExtensions.cs @@ -24,6 +24,7 @@ using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; using System; +using System.Data.Common; using System.Threading; namespace Google.Cloud.EntityFrameworkCore.Spanner.Extensions @@ -106,7 +107,7 @@ public static DbContextOptionsBuilder UseSpanner( /// The optionsBuilder for chaining internal static DbContextOptionsBuilder UseSpanner( this DbContextOptionsBuilder optionsBuilder, - SpannerRetriableConnection connection, + DbConnection connection, Action spannerOptionsAction = null, ChannelCredentials channelCredentials = null) { diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/SpannerIDbContextTransactionExtensions.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/SpannerIDbContextTransactionExtensions.cs index 3906368d..2c0b71d6 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/SpannerIDbContextTransactionExtensions.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Extensions/SpannerIDbContextTransactionExtensions.cs @@ -17,6 +17,7 @@ using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Storage; using System; +using Google.Cloud.Spanner.DataProvider; namespace Google.Cloud.EntityFrameworkCore.Spanner.Extensions { @@ -33,8 +34,21 @@ public static class SpannerIDbContextTransactionExtensions /// If the transaction is not a read/write Spanner transaction public static void DisableInternalRetries([NotNull] this IDbContextTransaction dbContextTransaction) { - GaxPreconditions.CheckArgument(dbContextTransaction.GetDbTransaction() is SpannerRetriableTransaction, nameof(dbContextTransaction), "Must be a read/write Spanner transaction"); - ((SpannerRetriableTransaction) dbContextTransaction.GetDbTransaction()).EnableInternalRetries = false; + var dbTx = dbContextTransaction.GetDbTransaction(); + if (dbTx is SpannerRetriableTransaction retriableTransaction) + { + retriableTransaction.EnableInternalRetries = false; + } + else if (dbTx is SpannerTransaction spannerTransaction) + { + var cmd = spannerTransaction.Connection!.CreateCommand(); + cmd.CommandText = "set local retry_aborts_internally = false"; + cmd.ExecuteNonQuery(); + } + else + { + GaxPreconditions.CheckArgument(dbContextTransaction.GetDbTransaction() is SpannerRetriableTransaction, nameof(dbContextTransaction), "Must be a read/write Spanner transaction"); + } } /// diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Google.Cloud.EntityFrameworkCore.Spanner.csproj b/Google.Cloud.EntityFrameworkCore.Spanner/Google.Cloud.EntityFrameworkCore.Spanner.csproj index f962fab6..a01785b9 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Google.Cloud.EntityFrameworkCore.Spanner.csproj +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Google.Cloud.EntityFrameworkCore.Spanner.csproj @@ -39,4 +39,8 @@ + + + + \ No newline at end of file diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Migrations/Internal/SpannerMigrationCommandExecutor.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Migrations/Internal/SpannerMigrationCommandExecutor.cs index 88cc1926..5ddeeefd 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Migrations/Internal/SpannerMigrationCommandExecutor.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Migrations/Internal/SpannerMigrationCommandExecutor.cs @@ -41,23 +41,54 @@ public async Task ExecuteNonQueryAsync(IEnumerable migrationCo } var ddlStatements = statements.Where(IsDdlStatement).ToArray(); var otherStatements = statements.Where(x => !IsDdlStatement(x)); - var spannerConnection = ((SpannerRelationalConnection) connection).DbConnection as SpannerRetriableConnection; - if (ddlStatements.Any()) + var dbConnection = ((SpannerRelationalConnection)connection).DbConnection; + if (dbConnection is SpannerRetriableConnection spannerConnection) { - var cmd = spannerConnection.CreateDdlCommand(ddlStatements[0], ddlStatements.Skip(1).ToArray()); - await cmd.ExecuteNonQueryAsync(cancellationToken); + if (ddlStatements.Any()) + { + var cmd = spannerConnection.CreateDdlCommand(ddlStatements[0], ddlStatements.Skip(1).ToArray()); + await cmd.ExecuteNonQueryAsync(cancellationToken); + } + if (otherStatements.Any()) + { + using var transaction = await spannerConnection.BeginTransactionAsync(cancellationToken); + var cmd = spannerConnection.CreateBatchDmlCommand(); + cmd.Transaction = transaction; + foreach (var statement in otherStatements) + { + cmd.Add(statement); + } + await cmd.ExecuteNonQueryAsync(cancellationToken); + await transaction.CommitAsync(cancellationToken); + } } - if (otherStatements.Any()) + else { - using var transaction = await spannerConnection.BeginTransactionAsync(cancellationToken); - var cmd = spannerConnection.CreateBatchDmlCommand(); - cmd.Transaction = transaction; - foreach (var statement in otherStatements) + if (ddlStatements.Any()) + { + var batch = dbConnection.CreateBatch(); + foreach (var statement in ddlStatements) + { + var cmd = batch.CreateBatchCommand(); + cmd.CommandText = statement; + batch.BatchCommands.Add(cmd); + } + await batch.ExecuteNonQueryAsync(cancellationToken); + } + if (otherStatements.Any()) { - cmd.Add(statement); + using var transaction = await dbConnection.BeginTransactionAsync(cancellationToken); + var batch = dbConnection.CreateBatch(); + batch.Transaction = transaction; + foreach (var statement in otherStatements) + { + var cmd = batch.CreateBatchCommand(); + cmd.CommandText = statement; + batch.BatchCommands.Add(cmd); + } + await batch.ExecuteNonQueryAsync(cancellationToken); + await transaction.CommitAsync(cancellationToken); } - await cmd.ExecuteNonQueryAsync(cancellationToken); - await transaction.CommitAsync(cancellationToken); } } diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerArrayTypes.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerArrayTypes.cs new file mode 100644 index 00000000..7934f3fd --- /dev/null +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerArrayTypes.cs @@ -0,0 +1,10 @@ +using Google.Cloud.Spanner.V1; + +namespace Google.Cloud.EntityFrameworkCore.Spanner.Storage.Internal; + +internal static class SpannerArrayTypes +{ + internal static readonly Cloud.Spanner.V1.Type SArrayOfDateType = new() { Code = TypeCode.Array, ArrayElementType = new Cloud.Spanner.V1.Type{Code = TypeCode.Date}}; + internal static readonly Cloud.Spanner.V1.Type SArrayOfJsonType = new() { Code = TypeCode.Array, ArrayElementType = new Cloud.Spanner.V1.Type{Code = TypeCode.Json}}; + +} \ No newline at end of file diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerComplexTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerComplexTypeMapping.cs index 22bf8038..e7d0988e 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerComplexTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerComplexTypeMapping.cs @@ -69,15 +69,24 @@ protected override void ConfigureParameter(DbParameter parameter) // This key step will configure our SpannerParameter with this complex type, which will result in // the proper type conversions when the requests go out. - if (!(parameter is SpannerParameter spannerParameter)) - throw new ArgumentException($"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); - - base.ConfigureParameter(parameter); - if (!IsArrayType && Size.HasValue && Size.Value > 0) + if (parameter is SpannerParameter spannerParameter) + { + base.ConfigureParameter(parameter); + if (!IsArrayType && Size.HasValue && Size.Value > 0) + { + parameter.Size = Size.Value; + } + spannerParameter.SpannerDbType = _complexType; + } + else if (parameter is Google.Cloud.Spanner.DataProvider.SpannerParameter) + { + base.ConfigureParameter(parameter); + } + else { - parameter.Size = Size.Value; + throw new ArgumentException( + $"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); } - spannerParameter.SpannerDbType = _complexType; } } } diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDatabaseCreator.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDatabaseCreator.cs index 0e92f920..060c9132 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDatabaseCreator.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDatabaseCreator.cs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System; using Google.Api.Gax; using Google.Cloud.EntityFrameworkCore.Spanner.Migrations.Operations; using Google.Cloud.Spanner.Data; @@ -22,6 +23,8 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Google.Cloud.Spanner.DataProvider; +using Google.Rpc; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; @@ -55,6 +58,7 @@ public SpannerDatabaseCreator( public override void Create() { using var masterConnection = _connection.CreateMasterConnection(); + masterConnection.Open(); Dependencies.MigrationCommandExecutor .ExecuteNonQuery(CreateCreateOperations(), masterConnection); } @@ -63,7 +67,8 @@ public override void Create() /// public override async Task CreateAsync(CancellationToken cancellationToken = default) { - using var masterConnection = _connection.CreateMasterConnection(); + await using var masterConnection = _connection.CreateMasterConnection(); + await masterConnection.OpenAsync(cancellationToken); await Dependencies.MigrationCommandExecutor .ExecuteNonQueryAsync(CreateCreateOperations(), masterConnection, cancellationToken) .ConfigureAwait(false); @@ -130,6 +135,10 @@ public override async Task ExistsAsync(CancellationToken cancellationToken { return false; } + catch (SpannerDbException e) when (e.Status.Code == (int) Code.NotFound) + { + return false; + } return true; } @@ -137,6 +146,7 @@ public override async Task ExistsAsync(CancellationToken cancellationToken public override void Delete() { using var masterConnection = _connection.CreateMasterConnection(); + masterConnection.Open(); Dependencies.MigrationCommandExecutor .ExecuteNonQuery(CreateDropCommands(), masterConnection); } @@ -144,7 +154,8 @@ public override void Delete() /// public override async Task DeleteAsync(CancellationToken cancellationToken = default) { - using var masterConnection = _connection.CreateMasterConnection(); + await using var masterConnection = _connection.CreateMasterConnection(); + await masterConnection.OpenAsync(cancellationToken); await Dependencies.MigrationCommandExecutor .ExecuteNonQueryAsync(CreateDropCommands(), masterConnection, cancellationToken) .ConfigureAwait(false); diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDateArrayTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDateArrayTypeMapping.cs index fdd6db7a..f435f88c 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDateArrayTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDateArrayTypeMapping.cs @@ -44,11 +44,20 @@ protected override void ConfigureParameter(DbParameter parameter) // This key step will configure our SpannerParameter with this complex type, which will result in // the proper type conversions when the requests go out. - if (!(parameter is SpannerParameter spannerParameter)) - throw new ArgumentException($"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); + if (parameter is Google.Cloud.Spanner.DataProvider.SpannerParameter spannerDriverParameter) + { + base.ConfigureParameter(parameter); + spannerDriverParameter.SpannerParameterType = SpannerArrayTypes.SArrayOfDateType; + } + else + { + if (!(parameter is SpannerParameter spannerParameter)) + throw new ArgumentException( + $"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); - base.ConfigureParameter(parameter); - spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Date); + base.ConfigureParameter(parameter); + spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Date); + } } protected override string GenerateNonNullSqlLiteral(object value) diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDateListTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDateListTypeMapping.cs index 1ab50058..d3ff6a41 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDateListTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDateListTypeMapping.cs @@ -45,11 +45,20 @@ protected override void ConfigureParameter(DbParameter parameter) // This key step will configure our SpannerParameter with this complex type, which will result in // the proper type conversions when the requests go out. - if (!(parameter is SpannerParameter spannerParameter)) - throw new ArgumentException($"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); + if (parameter is Google.Cloud.Spanner.DataProvider.SpannerParameter spannerDriverParameter) + { + base.ConfigureParameter(parameter); + spannerDriverParameter.SpannerParameterType = SpannerArrayTypes.SArrayOfDateType; + } + else + { + if (!(parameter is SpannerParameter spannerParameter)) + throw new ArgumentException( + $"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); - base.ConfigureParameter(parameter); - spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Date); + base.ConfigureParameter(parameter); + spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Date); + } } protected override string GenerateNonNullSqlLiteral(object value) diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDecimalTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDecimalTypeMapping.cs index d62fe47f..3b8265d4 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDecimalTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerDecimalTypeMapping.cs @@ -26,22 +26,22 @@ internal class SpannerDecimalTypeMapping() : RelationalTypeMapping(new Relationa protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters) => new SpannerDecimalTypeMapping(); - public override DbParameter CreateParameter( - DbCommand command, - string name, -#nullable enable - object? value, -#nullable disable - bool? nullable = null, - ParameterDirection direction = ParameterDirection.Input) - { - return new SpannerParameter(name, SpannerDbType.Numeric, value); - } - - protected override void ConfigureParameter(DbParameter parameter) - { - ((SpannerParameter)parameter).SpannerDbType = SpannerDbType.Numeric; - base.ConfigureParameter(parameter); - } +// public override DbParameter CreateParameter( +// DbCommand command, +// string name, +// #nullable enable +// object? value, +// #nullable disable +// bool? nullable = null, +// ParameterDirection direction = ParameterDirection.Input) +// { +// return new SpannerParameter(name, SpannerDbType.Numeric, value); +// } +// +// protected override void ConfigureParameter(DbParameter parameter) +// { +// ((SpannerParameter)parameter).SpannerDbType = SpannerDbType.Numeric; +// base.ConfigureParameter(parameter); +// } } } diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonArrayTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonArrayTypeMapping.cs index 95c97826..8bc8bb52 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonArrayTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonArrayTypeMapping.cs @@ -46,11 +46,20 @@ protected override void ConfigureParameter(DbParameter parameter) // This key step will configure our SpannerParameter with this complex type, which will result in // the proper type conversions when the requests go out. - if (!(parameter is SpannerParameter spannerParameter)) - throw new ArgumentException($"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); + if (parameter is Google.Cloud.Spanner.DataProvider.SpannerParameter spannerDriverParameter) + { + base.ConfigureParameter(parameter); + spannerDriverParameter.SpannerParameterType = SpannerArrayTypes.SArrayOfJsonType; + } + else + { + if (!(parameter is SpannerParameter spannerParameter)) + throw new ArgumentException( + $"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); - base.ConfigureParameter(parameter); - spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Json); + base.ConfigureParameter(parameter); + spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Json); + } } protected override string GenerateNonNullSqlLiteral(object value) diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonListTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonListTypeMapping.cs index c67c1615..f251d41f 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonListTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonListTypeMapping.cs @@ -46,11 +46,23 @@ protected override void ConfigureParameter(DbParameter parameter) // This key step will configure our SpannerParameter with this complex type, which will result in // the proper type conversions when the requests go out. - if (!(parameter is SpannerParameter spannerParameter)) - throw new ArgumentException($"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); + if (parameter is Google.Cloud.Spanner.DataProvider.SpannerParameter spannerDriverParameter) + { + base.ConfigureParameter(parameter); + spannerDriverParameter.SpannerParameterType = SpannerArrayTypes.SArrayOfJsonType; + } + else + { + if (!(parameter is SpannerParameter)) + throw new ArgumentException( + $"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); - base.ConfigureParameter(parameter); - spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Json); + base.ConfigureParameter(parameter); + if (parameter is SpannerParameter spannerParameter) + { + spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Json); + } + } } protected override string GenerateNonNullSqlLiteral(object value) diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonTypeMapping.cs index 053f6678..6fc39d33 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerJsonTypeMapping.cs @@ -44,7 +44,10 @@ protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters p protected override void ConfigureParameter(DbParameter parameter) { - ((SpannerParameter)parameter).SpannerDbType = SpannerDbType.Json; + if (parameter is SpannerParameter spannerParameter) + { + spannerParameter.SpannerDbType = SpannerDbType.Json; + } base.ConfigureParameter(parameter); } diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNullableDateArrayTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNullableDateArrayTypeMapping.cs index 136e6180..10b7b7a6 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNullableDateArrayTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNullableDateArrayTypeMapping.cs @@ -44,11 +44,20 @@ protected override void ConfigureParameter(DbParameter parameter) // This key step will configure our SpannerParameter with this complex type, which will result in // the proper type conversions when the requests go out. - if (!(parameter is SpannerParameter spannerParameter)) - throw new ArgumentException($"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); + if (parameter is Google.Cloud.Spanner.DataProvider.SpannerParameter spannerDriverParameter) + { + base.ConfigureParameter(parameter); + spannerDriverParameter.SpannerParameterType = SpannerArrayTypes.SArrayOfDateType; + } + else + { + if (!(parameter is SpannerParameter spannerParameter)) + throw new ArgumentException( + $"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); - base.ConfigureParameter(parameter); - spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Date); + base.ConfigureParameter(parameter); + spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Date); + } } protected override string GenerateNonNullSqlLiteral(object value) diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNullableDateListTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNullableDateListTypeMapping.cs index 88cf9d6b..83250e5c 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNullableDateListTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNullableDateListTypeMapping.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Data.Common; using System.Linq; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; namespace Google.Cloud.EntityFrameworkCore.Spanner.Storage.Internal { @@ -45,11 +46,19 @@ protected override void ConfigureParameter(DbParameter parameter) // This key step will configure our SpannerParameter with this complex type, which will result in // the proper type conversions when the requests go out. - if (!(parameter is SpannerParameter spannerParameter)) - throw new ArgumentException($"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); + if (parameter is Google.Cloud.Spanner.DataProvider.SpannerParameter spannerDriverParameter) + { + base.ConfigureParameter(parameter); + spannerDriverParameter.SpannerParameterType = SpannerArrayTypes.SArrayOfDateType; + } + else + { + if (!(parameter is SpannerParameter spannerParameter)) + throw new ArgumentException($"Spanner-specific type mapping {GetType().Name} being used with non-Spanner parameter type {parameter.GetType().Name}"); - base.ConfigureParameter(parameter); - spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Date); + base.ConfigureParameter(parameter); + spannerParameter.SpannerDbType = SpannerDbType.ArrayOf(SpannerDbType.Date); + } } protected override string GenerateNonNullSqlLiteral(object value) diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNumericTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNumericTypeMapping.cs index a95334ca..04343914 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNumericTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerNumericTypeMapping.cs @@ -29,23 +29,23 @@ public SpannerNumericTypeMapping() protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters) => new SpannerNumericTypeMapping(); - public override DbParameter CreateParameter( - DbCommand command, - string name, -#nullable enable - object? value, -#nullable disable - bool? nullable = null, - ParameterDirection direction = ParameterDirection.Input) - { - // TODO: Remove once the default mapping of type NUMERIC has been added to the client library. - return new SpannerParameter(name, SpannerDbType.Numeric, value); - } - - protected override void ConfigureParameter(DbParameter parameter) - { - ((SpannerParameter)parameter).SpannerDbType = SpannerDbType.Numeric; - base.ConfigureParameter(parameter); - } +// public override DbParameter CreateParameter( +// DbCommand command, +// string name, +// #nullable enable +// object? value, +// #nullable disable +// bool? nullable = null, +// ParameterDirection direction = ParameterDirection.Input) +// { +// // TODO: Remove once the default mapping of type NUMERIC has been added to the client library. +// return new SpannerParameter(name, SpannerDbType.Numeric, value); +// } +// +// protected override void ConfigureParameter(DbParameter parameter) +// { +// ((SpannerParameter)parameter).SpannerDbType = SpannerDbType.Numeric; +// base.ConfigureParameter(parameter); +// } } } diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerRelationalConnection.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerRelationalConnection.cs index ae5f3f7c..5bc32d23 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerRelationalConnection.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerRelationalConnection.cs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System; using Google.Cloud.EntityFrameworkCore.Spanner.Extensions; using Google.Cloud.EntityFrameworkCore.Spanner.Infrastructure; using Google.Cloud.EntityFrameworkCore.Spanner.Infrastructure.Internal; @@ -22,6 +23,8 @@ using System.Data.Common; using System.Threading; using System.Threading.Tasks; +using Google.Cloud.Spanner.V1; +using Google.Protobuf.WellKnownTypes; namespace Google.Cloud.EntityFrameworkCore.Spanner.Storage.Internal { @@ -42,7 +45,7 @@ public SpannerRelationalConnection(RelationalConnectionDependencies dependencies ConnectionStringBuilder = relationalOptions.ConnectionStringBuilder; } - private SpannerRetriableConnection Connection => DbConnection as SpannerRetriableConnection; + //private SpannerRetriableConnection Connection => DbConnection as SpannerRetriableConnection; public MutationUsage MutationUsage { get; } @@ -56,8 +59,17 @@ protected override DbConnection CreateDbConnection() ConnectionString = ConnectionString, SessionPoolManager = SpannerDbContextOptionsExtensions.SessionPoolManager }; - var con = new SpannerConnection(builder); - return new SpannerRetriableConnection(con); + // if (ConnectionString!.StartsWith("Data Source=", StringComparison.OrdinalIgnoreCase)) + // { + // var con = new SpannerConnection(builder); + // return new SpannerRetriableConnection(con); + // } + // else + // { + var con = new Google.Cloud.Spanner.DataProvider.SpannerConnection(); + con.ConnectionString = builder.ConnectionString; + return con; + // } } /// @@ -71,8 +83,50 @@ protected override DbConnection CreateDbConnection() /// /// The read timestamp to use for the transaction /// A read-only transaction that uses the specified - public IDbContextTransaction BeginReadOnlyTransaction(TimestampBound timestampBound) => - UseTransaction(Connection.BeginReadOnlyTransaction(timestampBound)); + public IDbContextTransaction BeginReadOnlyTransaction(TimestampBound timestampBound) + { + if (DbConnection is SpannerRetriableConnection connection) + { + return UseTransaction(connection.BeginReadOnlyTransaction(timestampBound)); + } + if (DbConnection is Google.Cloud.Spanner.DataProvider.SpannerConnection spannerConnection) + { + return UseTransaction(spannerConnection.BeginTransaction(CreateTransactionOptions(timestampBound))); + } + throw new ArgumentException("Not a Spanner connection"); + } + + private static TransactionOptions CreateTransactionOptions(TimestampBound timestampBound) + { + TransactionOptions options = new TransactionOptions + { + ReadOnly = new TransactionOptions.Types.ReadOnly + { + ReturnReadTimestamp = timestampBound.ReturnReadTimestamp, + } + }; + switch (timestampBound.Mode) + { + case TimestampBoundMode.Strong: + options.ReadOnly.Strong = true; + break; + case TimestampBoundMode.ReadTimestamp: + options.ReadOnly.ReadTimestamp = Timestamp.FromDateTime(timestampBound.Timestamp); + break; + case TimestampBoundMode.MinReadTimestamp: + options.ReadOnly.MinReadTimestamp = Timestamp.FromDateTime(timestampBound.Timestamp); + break; + case TimestampBoundMode.ExactStaleness: + options.ReadOnly.ExactStaleness = Duration.FromTimeSpan(timestampBound.Staleness); + break; + case TimestampBoundMode.MaxStaleness: + options.ReadOnly.MaxStaleness = Duration.FromTimeSpan(timestampBound.Staleness); + break; + default: + throw new ArgumentOutOfRangeException(nameof(timestampBound.Mode), $"unknown timestampBound mode: {timestampBound.Mode}"); + } + return options; + } /// /// Begins a read-only transaction on this connection. @@ -86,8 +140,20 @@ public IDbContextTransaction BeginReadOnlyTransaction(TimestampBound timestampBo /// The read timestamp to use for the transaction /// A cancellation token to monitor for the asynchronous operation. /// A read-only transaction that uses the specified - public async Task BeginReadOnlyTransactionAsync(TimestampBound timestampBound, CancellationToken cancellationToken = default) => - await UseTransactionAsync(await Connection.BeginReadOnlyTransactionAsync(timestampBound, cancellationToken)); + public async Task BeginReadOnlyTransactionAsync(TimestampBound timestampBound, + CancellationToken cancellationToken = default) + { + if (DbConnection is SpannerRetriableConnection connection) + { + return await UseTransactionAsync( + await connection.BeginReadOnlyTransactionAsync(timestampBound, cancellationToken), cancellationToken); + } + if (DbConnection is Google.Cloud.Spanner.DataProvider.SpannerConnection spannerConnection) + { + return await UseTransactionAsync(spannerConnection.BeginTransaction(CreateTransactionOptions(timestampBound)), cancellationToken); + } + throw new ArgumentException("Not a Spanner connection"); + } /// /// Creates a connection to the Cloud Spanner instance that is referenced by . @@ -98,9 +164,16 @@ public ISpannerRelationalConnection CreateMasterConnection() // Spanner does not have anything like a master database, so we just return a new instance of a // RelationalConnection with the same options and dependencies. This ensures that all settings of the // underlying connection are carried over to the new RelationalConnection, such as credentials and host. - var masterConn = (SpannerRetriableConnection) CreateDbConnection(); + var masterConn = CreateDbConnection(); var optionsBuilder = new DbContextOptionsBuilder(); - optionsBuilder.UseSpanner(masterConn); + if (masterConn is SpannerRetriableConnection spannerRetriableConnection) + { + optionsBuilder.UseSpanner(spannerRetriableConnection); + } + else if (masterConn is Google.Cloud.Spanner.DataProvider.SpannerConnection spannerConnection) + { + optionsBuilder.UseSpanner(spannerConnection); + } #pragma warning disable EF1001 var dependencies = new RelationalConnectionDependencies( diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerRetriableConnection.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerRetriableConnection.cs index 74a64669..1a79519f 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerRetriableConnection.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerRetriableConnection.cs @@ -19,6 +19,9 @@ using System.Data.Common; using System.Threading; using System.Threading.Tasks; +using SpannerCommand = Google.Cloud.Spanner.Data.SpannerCommand; +using SpannerConnection = Google.Cloud.Spanner.Data.SpannerConnection; +using SpannerParameterCollection = Google.Cloud.Spanner.Data.SpannerParameterCollection; namespace Google.Cloud.EntityFrameworkCore.Spanner.Storage.Internal { diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerSqlGenerationHelper.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerSqlGenerationHelper.cs index 322ed595..0b4e9222 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerSqlGenerationHelper.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerSqlGenerationHelper.cs @@ -15,6 +15,7 @@ using Google.Api.Gax; using Microsoft.EntityFrameworkCore.Storage; using System.Text; +using Microsoft.EntityFrameworkCore.Query; namespace Google.Cloud.EntityFrameworkCore.Spanner.Storage.Internal { diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerStructuralJsonTypeMapping.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerStructuralJsonTypeMapping.cs index 3a59e30a..6b60ad78 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerStructuralJsonTypeMapping.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerStructuralJsonTypeMapping.cs @@ -70,6 +70,10 @@ protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters p protected override void ConfigureParameter(DbParameter parameter) { - ((SpannerParameter)parameter).SpannerDbType = SpannerDbType.Json; + if (parameter is SpannerParameter spannerParameter) + { + spannerParameter.SpannerDbType = SpannerDbType.Json; + } base.ConfigureParameter(parameter); - }} \ No newline at end of file + } +} \ No newline at end of file diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerTypeMappingSource.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerTypeMappingSource.cs index 9b00beeb..8dba6632 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerTypeMappingSource.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Storage/Internal/SpannerTypeMappingSource.cs @@ -39,7 +39,8 @@ private static readonly BoolTypeMapping s_bool private static readonly SpannerDateTypeMapping s_date = new SpannerDateTypeMapping(); - private static readonly SpannerDateOnlyTypeMapping s_dateonly = new (); + //private static readonly SpannerDateOnlyTypeMapping s_dateonly = new (); + private static readonly DateOnlyTypeMapping s_dateonly = new ("DATE"); private static readonly SpannerTimestampTypeMapping s_datetime = new SpannerTimestampTypeMapping(); @@ -161,6 +162,16 @@ private static readonly SpannerComplexTypeMapping s_intList private static readonly SpannerNullableDateListTypeMapping s_nullableDateList = new SpannerNullableDateListTypeMapping(); private static readonly SpannerDateListTypeMapping s_dateList = new SpannerDateListTypeMapping(); + + private static readonly SpannerComplexTypeMapping s_nullableDateOnlyArray + = new SpannerComplexTypeMapping(SpannerDbType.ArrayOf(SpannerDbType.Date), typeof(DateOnly?[])); + private static readonly SpannerComplexTypeMapping s_dateOnlyArray + = new SpannerComplexTypeMapping(SpannerDbType.ArrayOf(SpannerDbType.Date), typeof(DateOnly[])); + + private static readonly SpannerComplexTypeMapping s_nullableDateOnlyList + = new SpannerComplexTypeMapping(SpannerDbType.ArrayOf(SpannerDbType.Date), typeof(List)); + private static readonly SpannerComplexTypeMapping s_dateOnlyList + = new SpannerComplexTypeMapping(SpannerDbType.ArrayOf(SpannerDbType.Date), typeof(List)); private static readonly SpannerComplexTypeMapping s_nullableTimestampArray = new SpannerComplexTypeMapping(SpannerDbType.ArrayOf(SpannerDbType.Timestamp), typeof(DateTime?[])); @@ -259,6 +270,10 @@ public SpannerTypeMappingSource( {typeof(SpannerDate?[]), s_nullableDateArray}, {typeof(List), s_dateList}, {typeof(List), s_nullableDateList}, + {typeof(List), s_dateOnlyList}, + {typeof(List), s_nullableDateOnlyList}, + {typeof(DateOnly[]), s_dateOnlyArray}, + {typeof(DateOnly?[]), s_nullableDateOnlyArray}, {typeof(List), s_timestampList}, {typeof(List), s_nullableTimestampList}, {typeof(DateTime[]), s_timestampArray}, diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Update/Internal/SpannerModificationCommandBatch .cs b/Google.Cloud.EntityFrameworkCore.Spanner/Update/Internal/SpannerModificationCommandBatch .cs index 0ea2ab37..9310c5ec 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Update/Internal/SpannerModificationCommandBatch .cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Update/Internal/SpannerModificationCommandBatch .cs @@ -27,6 +27,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using SpannerConnection = Google.Cloud.Spanner.DataProvider.SpannerConnection; namespace Google.Cloud.EntityFrameworkCore.Spanner.Update.Internal { @@ -49,7 +50,7 @@ internal sealed class SpannerModificationCommandBatch : ModificationCommandBatch { private readonly IRelationalTypeMappingSource _typeMapper; private readonly List _modificationCommands = new(); - private readonly List _propagateResultsCommands = new(); + private readonly List _propagateResultsCommands = new(); private readonly char _statementTerminator; private readonly bool _hasExplicitTransaction; private bool _areMoreBatchesExpected; @@ -137,7 +138,22 @@ public override async Task ExecuteAsync(IRelationalConnection connection, CancellationToken cancellationToken = default) { var spannerRelationalConnection = (SpannerRelationalConnection)connection; - var spannerConnection = (SpannerRetriableConnection)connection.DbConnection; + if (connection.DbConnection is SpannerRetriableConnection spannerRetriableConnection) + { + await ExecuteAsync(connection, spannerRelationalConnection, spannerRetriableConnection, cancellationToken); + } + else if (connection.DbConnection is SpannerConnection spannerDriverConnection) + { + await ExecuteAsync(connection, spannerRelationalConnection, spannerDriverConnection, cancellationToken); + } + else + { + throw new ArgumentException("Not a Spanner connection"); + } + } + + private async Task ExecuteAsync(IRelationalConnection connection, SpannerRelationalConnection spannerRelationalConnection, SpannerRetriableConnection spannerConnection, CancellationToken cancellationToken = default) + { // There should always be a transaction: // 1. Implicit: A transaction is automatically started by Entity Framework when SaveChanges() is called. // 2. Explicit: The client application has called BeginTransaction() on the database. @@ -148,9 +164,39 @@ public override async Task ExecuteAsync(IRelationalConnection connection, var containsReads = _modificationCommands.Any(c => c.ColumnModifications.Any(cm => cm.IsRead)); var useMutations = spannerRelationalConnection.MutationUsage == Infrastructure.MutationUsage.Always - || (!_hasExplicitTransaction - && !containsReads - && spannerRelationalConnection.MutationUsage == Infrastructure.MutationUsage.ImplicitTransactions); + || (!_hasExplicitTransaction + && !containsReads + && spannerRelationalConnection.MutationUsage == Infrastructure.MutationUsage.ImplicitTransactions); + if (useMutations) + { + await ExecuteMutationsAsync(spannerConnection, transaction, cancellationToken); + } + else if (containsReads) + { + await ExecuteDmlAsync(connection, spannerConnection, transaction, cancellationToken); + } + else + { + await ExecuteBatchDmlAsync(spannerConnection, transaction, cancellationToken); + } + } + + private async Task ExecuteAsync(IRelationalConnection connection, SpannerRelationalConnection spannerRelationalConnection, SpannerConnection spannerConnection, CancellationToken cancellationToken = default) + { + // There should always be a transaction: + // 1. Implicit: A transaction is automatically started by Entity Framework when SaveChanges() is called. + // 2. Explicit: The client application has called BeginTransaction() on the database. + if (connection.CurrentTransaction?.GetDbTransaction() == null) + { + throw new InvalidOperationException("There is no active transaction. Cloud Spanner does not support executing updates without a transaction."); + } + var transaction = connection.CurrentTransaction.GetDbTransaction(); + + var containsReads = _modificationCommands.Any(c => c.ColumnModifications.Any(cm => cm.IsRead)); + var useMutations = spannerRelationalConnection.MutationUsage == Infrastructure.MutationUsage.Always + || (!_hasExplicitTransaction + && !containsReads + && spannerRelationalConnection.MutationUsage == Infrastructure.MutationUsage.ImplicitTransactions); if (useMutations) { await ExecuteMutationsAsync(spannerConnection, transaction, cancellationToken); @@ -247,6 +293,85 @@ private async Task ExecuteMutationsAsync( } } + private async Task ExecuteMutationsAsync( + SpannerConnection spannerConnection, DbTransaction transaction, CancellationToken cancellationToken) + { + int index = 0; + foreach (var modificationCommand in _modificationCommands) + { + // We assume that each mutation will affect exactly one row. This assumption always holds for INSERT + // and UPDATE mutations (unless they return an error). DELETE mutations could affect zero rows if the + // row had already been deleted, and more than one row if the deleted row is in a table with one or + // more INTERLEAVED tables that are defined with ON DELETE CASCADE. + // + // This can be changed if a concurrency token check fails. + var updateCount = 1L; + + // Concurrency token checks cannot be included in mutations. Instead, we need to do manual select to check + // that the concurrency token is still the same as what we expect. This select is executed in the same + // transaction as the mutations, so it is guaranteed that the value that we read here will still be valid + // when the mutations are committed. + var operations = modificationCommand.ColumnModifications; + var hasConcurrencyCondition = operations.Any(o => o.IsCondition && (o.Property?.IsConcurrencyToken ?? false)); + if (hasConcurrencyCondition) + { + var conditionOperations = operations.Where(o => o.IsCondition).ToList(); + var concurrencySql = ((SpannerUpdateSqlGenerator)Dependencies.UpdateSqlGenerator).GenerateSelectConcurrencyCheckSql(modificationCommand.TableName, conditionOperations); + var concurrencyCommand = spannerConnection.CreateCommand(); + concurrencyCommand.CommandText = concurrencySql; + concurrencyCommand.Transaction = transaction; + foreach (var columnModification in conditionOperations) + { + concurrencyCommand.Parameters.Add(CreateParameter(columnModification, concurrencyCommand, UseValue.Original, false)); + } + // Execute the concurrency check query in the read/write transaction and check whether the expected row exists. + using var reader = await concurrencyCommand.ExecuteReaderAsync(cancellationToken); + if (!await reader.ReadAsync(cancellationToken)) + { + // Set the update count to 0 to trigger a concurrency exception. + // We do not throw the exception here already, as there might be more concurrency problems, + // and we want to be able to report all in the exception. + updateCount = 0L; + } + } + + // Mutation commands must use a specific TIMESTAMP constant for pending commit timestamps instead of the + // placeholder string PENDING_COMMIT_TIMESTAMP(). This instructs any pending commit timestamp modifications + // to use the mutation constant instead. + if (modificationCommand is SpannerPendingCommitTimestampModificationCommand commitTimestampModificationCommand) + { + commitTimestampModificationCommand.MarkAsMutationCommand(); + // TODO: Support pending commit timestamp modification commands for SpannerDriver. + // transaction.AddSpannerPendingCommitTimestampModificationCommand(commitTimestampModificationCommand); + } + // Create the mutation command and execute it. + var cmd = CreateSpannerMutationCommand(spannerConnection, transaction, modificationCommand); + // Note: The following line does not actually execute any command on the backend, it only buffers + // the mutation locally to be sent with the next Commit statement. + await cmd.ExecuteNonQueryAsync(cancellationToken); + UpdateCounts.Add(updateCount); + + // Check whether we need to generate a SELECT command to propagate computed values back to the context. + // This SELECT command will be executed outside of the current implicit transaction. + // The propagation query is skipped if the batch uses an explicit transaction, as it will not be able + // to read the new value anyways. + if (modificationCommand.ColumnModifications.Any(o => o.IsRead) && !_hasExplicitTransaction) + { + var keyOperations = operations.Where(o => o.IsKey).ToList(); + var readOperations = operations.Where(o => o.IsRead).ToList(); + var sql = ((SpannerUpdateSqlGenerator)Dependencies.UpdateSqlGenerator).GenerateSelectAffectedSql( + modificationCommand.TableName, modificationCommand.Schema, readOperations, keyOperations, index); + _propagateResultsCommands.Add(CreateSelectedAffectedCommand(spannerConnection, modificationCommand, sql)); + } + index++; + } + // Check that there were no concurrency problems detected. + if (RowsAffected != _modificationCommands.Count) + { + ThrowAggregateUpdateConcurrencyException(); + } + } + private async Task ExecuteDmlAsync(IRelationalConnection connection, SpannerRetriableConnection spannerConnection, SpannerRetriableTransaction transaction, CancellationToken cancellationToken) { var commands = CreateSpannerDmlCommands(spannerConnection, transaction); @@ -268,6 +393,28 @@ private async Task ExecuteDmlAsync(IRelationalConnection connection, SpannerRetr } UpdateCounts = updateCounts; } + + private async Task ExecuteDmlAsync(IRelationalConnection connection, SpannerConnection spannerConnection, DbTransaction transaction, CancellationToken cancellationToken) + { + var commands = CreateSpannerDmlCommands(spannerConnection, transaction); + var index = 0; + var updateCounts = new List(commands.Count); + foreach (var command in commands) + { + var reader = await command.ExecuteReaderAsync(cancellationToken); + var relationalReader = CreateRelationalDataReader(connection, command, reader); + var modificationCommand = _modificationCommands[index]; + var rowsAffected = 0L; + while (await relationalReader.ReadAsync(cancellationToken)) + { + modificationCommand.PropagateResults(relationalReader); + rowsAffected++; + } + updateCounts.Add(rowsAffected); + index++; + } + UpdateCounts = updateCounts; + } /// /// Executes the command batch using DML. DML is less efficient than mutations, but do allow applications @@ -290,6 +437,23 @@ private async Task ExecuteBatchDmlAsync(SpannerRetriableConnection spannerConnec _propagateResultsCommands.AddRange(cmd.Item2); } } + + private async Task ExecuteBatchDmlAsync(SpannerConnection spannerConnection, DbTransaction transaction, CancellationToken cancellationToken) + { + // Create a Batch DML command that contains all the updates in this batch. + // The update statements will include any concurrency token checks that might be needed. + var cmd = CreateSpannerBatchDmlCommand(spannerConnection, transaction); + UpdateCounts = spannerConnection.ExecuteBatchDml(cmd.Item1).ToList(); + if (RowsAffected != _modificationCommands.Count) + { + ThrowAggregateUpdateConcurrencyException(); + } + // Add any select commands that were generated by the batch for updates that need to propagate results. + if (cmd.Item2.Count > 0) + { + _propagateResultsCommands.AddRange(cmd.Item2); + } + } /// /// Constructs and throws a DbUpdateConcurrencyException for this batch based on the UpdateCounts. @@ -382,6 +546,29 @@ private List CreateSpannerDmlCommands(SpannerRetriableC } return commands; } + + private List CreateSpannerDmlCommands(SpannerConnection connection, DbTransaction transaction) + { + var commands = new List(); + var commandPosition = 0; + foreach (var modificationCommand in _modificationCommands) + { + var command = CreateSpannerDmlCommand(_thenReturnSqlGenerator, connection, modificationCommand, commandPosition); + command.Item1.Transaction = transaction; + commands.Add(command.Item1); + if (command.Item2 != null) + { + throw new ArgumentException(); + } + if (modificationCommand is SpannerPendingCommitTimestampModificationCommand commitTimestampModificationCommand) + { + // TODO: Support pending commit timestamps + // transaction.AddSpannerPendingCommitTimestampModificationCommand(commitTimestampModificationCommand); + } + commandPosition++; + } + return commands; + } /// /// Generates a Batch DML command for the modifications in this batch and SELECT statements for any @@ -413,6 +600,34 @@ private Tuple> Creat } return Tuple.Create(cmd, selectCommands); } + + private Tuple, List> CreateSpannerBatchDmlCommand(SpannerConnection connection, DbTransaction transaction) + { + var dmlCommands = new List(); + var selectCommands = new List(); + var commandPosition = 0; + foreach (var modificationCommand in _modificationCommands) + { + var commands = CreateSpannerDmlCommand(Dependencies.UpdateSqlGenerator, connection, modificationCommand, commandPosition); + commands.Item1.Transaction = transaction; + dmlCommands.Add(commands.Item1); + if (commands.Item2 != null) + { + if (_hasExplicitTransaction) + { + commands.Item2.Transaction = transaction; + } + selectCommands.Add(commands.Item2); + } + if (modificationCommand is SpannerPendingCommitTimestampModificationCommand commitTimestampModificationCommand) + { + // TODO: support pending commit timestamps + // transaction.AddSpannerPendingCommitTimestampModificationCommand(commitTimestampModificationCommand); + } + commandPosition++; + } + return Tuple.Create(dmlCommands, selectCommands); + } private Tuple CreateSpannerDmlCommand( IUpdateSqlGenerator updateSqlGenerator, @@ -460,6 +675,53 @@ private Tuple CreateSpannerDmlCommand( return Tuple.Create(cmd, selectCommand); } + private Tuple CreateSpannerDmlCommand( + IUpdateSqlGenerator updateSqlGenerator, + SpannerConnection connection, + IReadOnlyModificationCommand modificationCommand, + int commandPosition) + { + var builder = new StringBuilder(); + ResultSetMapping res; + switch (modificationCommand.EntityState) + { + case EntityState.Deleted: + res = updateSqlGenerator.AppendDeleteOperation(builder, modificationCommand, commandPosition); + break; + case EntityState.Modified: + res = updateSqlGenerator.AppendUpdateOperation(builder, modificationCommand, commandPosition); + break; + case EntityState.Added: + res = updateSqlGenerator.AppendInsertOperation(builder, modificationCommand, commandPosition); + break; + default: + throw new NotSupportedException( + $"Modification type {modificationCommand.EntityState} is not supported."); + } + string dml; + DbCommand selectCommand = null; + if (res != ResultSetMapping.NoResults) + { + var commandTexts = builder.ToString().Split(_statementTerminator); + dml = commandTexts[0]; + if (commandTexts.Length > 1) + { + selectCommand = CreateSelectedAffectedCommand(connection, modificationCommand, commandTexts[1]); + } + } + else + { + dml = builder.ToString(); + dml = dml.TrimEnd('\r', '\n', _statementTerminator); + } + // This intentionally uses a SpannerCommand instead of the internal SpannerRetriableCommand, because the command + // could eventually be added to a BatchCommand. + var cmd = connection.CreateCommand(); + cmd.CommandText = dml; + AppendWriteParameters(modificationCommand, cmd, false, true); + return Tuple.Create(cmd, selectCommand); + } + private SpannerRetriableCommand CreateSpannerMutationCommand( SpannerRetriableConnection spannerConnection, SpannerRetriableTransaction transaction, @@ -477,6 +739,23 @@ private SpannerRetriableCommand CreateSpannerMutationCommand( return cmd; } + private DbCommand CreateSpannerMutationCommand( + SpannerConnection spannerConnection, + DbTransaction transaction, + IReadOnlyModificationCommand modificationCommand) + { + var cmd = modificationCommand.EntityState switch + { + EntityState.Deleted => spannerConnection.CreateDeleteCommand(modificationCommand.TableName), + EntityState.Modified => spannerConnection.CreateUpdateCommand(modificationCommand.TableName), + EntityState.Added => spannerConnection.CreateInsertCommand(modificationCommand.TableName), + _ => throw new NotSupportedException($"Modification type {modificationCommand.EntityState} is not supported."), + }; + cmd.Transaction = transaction; + AppendWriteParameters(modificationCommand, cmd, true, false); + return cmd; + } + /// /// Adds the parameters that need to be written for an update command. This can be both a DML and a mutation command. /// @@ -530,6 +809,20 @@ private SpannerRetriableCommand CreateSelectedAffectedCommand(SpannerRetriableCo } return selectCommand; } + + private DbCommand CreateSelectedAffectedCommand(SpannerConnection connection, IReadOnlyModificationCommand modificationCommand, string sql) + { + var selectCommand = connection.CreateCommand(); + selectCommand.CommandText = sql; + foreach (var columnModification in modificationCommand.ColumnModifications) + { + if (columnModification.IsKey && (columnModification.UseOriginalValueParameter || columnModification.UseCurrentValueParameter)) + { + selectCommand.Parameters.Add(CreateParameter(columnModification, selectCommand, columnModification.UseOriginalValueParameter ? UseValue.Original : UseValue.Current, false)); + } + } + return selectCommand; + } /// /// Creates a SpannerParameter for a command and sets the correct type. diff --git a/Google.Cloud.EntityFrameworkCore.Spanner/Update/Internal/SpannerPendingCommitTimestampModification.cs b/Google.Cloud.EntityFrameworkCore.Spanner/Update/Internal/SpannerPendingCommitTimestampModification.cs index cb447f98..4fd7c6be 100644 --- a/Google.Cloud.EntityFrameworkCore.Spanner/Update/Internal/SpannerPendingCommitTimestampModification.cs +++ b/Google.Cloud.EntityFrameworkCore.Spanner/Update/Internal/SpannerPendingCommitTimestampModification.cs @@ -60,7 +60,7 @@ internal SpannerPendingCommitTimestampColumnModification(IUpdateEntry entry, IPr public override object Value { - get => IsMutationColumnModification ? SpannerParameter.CommitTimestamp : PendingCommitTimestampValue; + get => IsMutationColumnModification ? "spanner.commit_timestamp()" : PendingCommitTimestampValue; set => base.Value = value; } }