diff --git a/.gitignore b/.gitignore index 3ec2e7b8..4d1ab2f1 100644 --- a/.gitignore +++ b/.gitignore @@ -235,3 +235,4 @@ Paket.Restore.targets .paket /docs/output docs/output/**/*.* +*.orig diff --git a/src/SqlClient.Tests/Lib/Lib.fsproj b/src/SqlClient.Tests/Lib/Lib.fsproj index 2c7d6195..ad6a20c6 100644 --- a/src/SqlClient.Tests/Lib/Lib.fsproj +++ b/src/SqlClient.Tests/Lib/Lib.fsproj @@ -53,6 +53,7 @@ + diff --git a/src/SqlClient.Tests/SqlClient.Tests.NET40/SqlClient.Tests.NET40.fsproj b/src/SqlClient.Tests/SqlClient.Tests.NET40/SqlClient.Tests.NET40.fsproj index aff864cf..3df52d8a 100644 --- a/src/SqlClient.Tests/SqlClient.Tests.NET40/SqlClient.Tests.NET40.fsproj +++ b/src/SqlClient.Tests/SqlClient.Tests.NET40/SqlClient.Tests.NET40.fsproj @@ -58,6 +58,7 @@ + diff --git a/src/SqlClient.Tests/SqlClient.Tests.fsproj b/src/SqlClient.Tests/SqlClient.Tests.fsproj index 3a995590..531c7154 100644 --- a/src/SqlClient.Tests/SqlClient.Tests.fsproj +++ b/src/SqlClient.Tests/SqlClient.Tests.fsproj @@ -87,6 +87,7 @@ + diff --git a/src/SqlClient.Tests/TVPTests.fs b/src/SqlClient.Tests/TVPTests.fs index 7885530d..52751f0d 100644 --- a/src/SqlClient.Tests/TVPTests.fs +++ b/src/SqlClient.Tests/TVPTests.fs @@ -145,3 +145,23 @@ let UsingTVPInQuery() = |> Seq.toList Assert.Equal<_ list>(expected, actual) + +type MappedTVP = + SqlCommandProvider<" + SELECT myId, myName from @input + ", ConnectionStrings.AdventureWorksLiteral, TableVarMapping = "@input=dbo.MyTableType"> +[] +let UsingMappedTVPInQuery() = + printfn "%s" ConnectionStrings.AdventureWorksLiteral + use cmd = new MappedTVP(ConnectionStrings.AdventureWorksLiteral) + let expected = [ + 1, Some "monkey" + 2, Some "donkey" + ] + + let actual = + cmd.Execute(input = [ for id, name in expected -> MappedTVP.MyTableType(id, name) ]) + |> Seq.map(fun x -> x.myId, x.myName) + |> Seq.toList + + Assert.Equal<_ list>(expected, actual) diff --git a/src/SqlClient.Tests/TempTableTests.fs b/src/SqlClient.Tests/TempTableTests.fs new file mode 100644 index 00000000..afe39753 --- /dev/null +++ b/src/SqlClient.Tests/TempTableTests.fs @@ -0,0 +1,106 @@ +module FSharp.Data.TempTableTests + +open FSharp.Data +open Xunit +open System.Data.SqlClient + +type TempTable = + SqlCommandProvider< + TempTableDefinitions = " + CREATE TABLE #Temp ( + Id INT NOT NULL, + Name NVARCHAR(100) NULL)", + CommandText = " + SELECT Id, Name FROM #Temp", + ConnectionStringOrName = + ConnectionStrings.AdventureWorksLiteral> + +[] +let usingTempTable() = + use conn = new SqlConnection(ConnectionStrings.AdventureWorksLiteral) + conn.Open() + + use cmd = new TempTable(conn) + + cmd.LoadTempTables( + Temp = + [ TempTable.Temp(Id = 1, Name = Some "monkey") + TempTable.Temp(Id = 2, Name = Some "donkey") ]) + + let actual = + cmd.Execute() + |> Seq.map(fun x -> x.Id, x.Name) + |> Seq.toList + + let expected = [ + 1, Some "monkey" + 2, Some "donkey" + ] + + Assert.Equal<_ list>(expected, actual) + +[] +let queryWithHash() = + // We shouldn't mangle the statement when it's run + use cmd = + new SqlCommandProvider< + CommandText = " + SELECT Id, Name + FROM + ( + SELECT 1 AS Id, '#name' AS Name UNION + SELECT 2, 'some other value' + ) AS a + WHERE Name = '#name'", + ConnectionStringOrName = + ConnectionStrings.AdventureWorksLiteral>(ConnectionStrings.AdventureWorksLiteral) + + let actual = + cmd.Execute() + |> Seq.map(fun x -> x.Id, x.Name) + |> Seq.toList + + let expected = [ + 1, "#name" + ] + + Assert.Equal<_ list>(expected, actual) + +type TempTableHash = + SqlCommandProvider< + TempTableDefinitions = " + CREATE TABLE #Temp ( + Id INT NOT NULL)", + CommandText = " + SELECT a.Id, a.Name + FROM + ( + SELECT 1 AS Id, '#Temp' AS Name UNION + SELECT 2, 'some other value' + ) AS a + INNER JOIN #Temp t ON t.Id = a.Id", + ConnectionStringOrName = + ConnectionStrings.AdventureWorksLiteral> + +[] +let queryWithHashAndTempTable() = + // We shouldn't mangle the statement when it's run + use conn = new SqlConnection(ConnectionStrings.AdventureWorksLiteral) + conn.Open() + + use cmd = new TempTableHash(conn) + + cmd.LoadTempTables( + Temp = + [ TempTableHash.Temp(Id = 1) ]) + + let actual = + cmd.Execute() + |> Seq.map(fun x -> x.Id, x.Name) + |> Seq.toList + + let expected = [ + 1, "#Temp" + ] + + Assert.Equal<_ list>(expected, actual) \ No newline at end of file diff --git a/src/SqlClient/AssemblyInfo.fs b/src/SqlClient/AssemblyInfo.fs index a61e0367..9cc5187c 100644 --- a/src/SqlClient/AssemblyInfo.fs +++ b/src/SqlClient/AssemblyInfo.fs @@ -1,4 +1,5 @@ -namespace System +// Auto-Generated by FAKE; do not edit +namespace System open System.Reflection open System.Runtime.CompilerServices @@ -11,4 +12,9 @@ open System.Runtime.CompilerServices do () module internal AssemblyVersionInformation = - let [] Version = "1.8.4" + let [] AssemblyTitle = "SqlClient" + let [] AssemblyProduct = "FSharp.Data.SqlClient" + let [] AssemblyDescription = "SqlClient F# type providers" + let [] AssemblyVersion = "1.8.4" + let [] AssemblyFileVersion = "1.8.4" + let [] InternalsVisibleTo = "SqlClient.Tests" diff --git a/src/SqlClient/DesignTime.fs b/src/SqlClient/DesignTime.fs index 85e20d51..faaa5449 100644 --- a/src/SqlClient/DesignTime.fs +++ b/src/SqlClient/DesignTime.fs @@ -10,6 +10,7 @@ open System.Diagnostics open Microsoft.FSharp.Quotations open ProviderImplementation.ProvidedTypes open FSharp.Data +open System.Text.RegularExpressions type internal RowType = { Provided: Type @@ -40,7 +41,52 @@ module internal SharedLogic = // add .Table returnType.Single |> cmdProvidedType.AddMember -type DesignTime private() = +module Prefixes = + let tempTable = "##SQLCOMMANDPROVIDER_" + let tableVar = "@SQLCOMMANDPROVIDER_" + +type TempTableLoader(fieldCount, items: obj seq) = + let enumerator = items.GetEnumerator() + + interface IDataReader with + member this.FieldCount: int = fieldCount + member this.Read(): bool = enumerator.MoveNext() + member this.GetValue(i: int): obj = + let row : obj[] = unbox enumerator.Current + row.[i] + member this.Dispose(): unit = () + + member __.Close(): unit = invalidOp "NotImplementedException" + member __.Depth: int = invalidOp "NotImplementedException" + member __.GetBoolean(_: int): bool = invalidOp "NotImplementedException" + member __.GetByte(_ : int): byte = invalidOp "NotImplementedException" + member __.GetBytes(_ : int, _ : int64, _ : byte [], _ : int, _ : int): int64 = invalidOp "NotImplementedException" + member __.GetChar(_ : int): char = invalidOp "NotImplementedException" + member __.GetChars(_ : int, _ : int64, _ : char [], _ : int, _ : int): int64 = invalidOp "NotImplementedException" + member __.GetData(_ : int): IDataReader = invalidOp "NotImplementedException" + member __.GetDataTypeName(_ : int): string = invalidOp "NotImplementedException" + member __.GetDateTime(_ : int): System.DateTime = invalidOp "NotImplementedException" + member __.GetDecimal(_ : int): decimal = invalidOp "NotImplementedException" + member __.GetDouble(_ : int): float = invalidOp "NotImplementedException" + member __.GetFieldType(_ : int): System.Type = invalidOp "NotImplementedException" + member __.GetFloat(_ : int): float32 = invalidOp "NotImplementedException" + member __.GetGuid(_ : int): System.Guid = invalidOp "NotImplementedException" + member __.GetInt16(_ : int): int16 = invalidOp "NotImplementedException" + member __.GetInt32(_ : int): int = invalidOp "NotImplementedException" + member __.GetInt64(_ : int): int64 = invalidOp "NotImplementedException" + member __.GetName(_ : int): string = invalidOp "NotImplementedException" + member __.GetOrdinal(_ : string): int = invalidOp "NotImplementedException" + member __.GetSchemaTable(): DataTable = invalidOp "NotImplementedException" + member __.GetString(_ : int): string = invalidOp "NotImplementedException" + member __.GetValues(_ : obj []): int = invalidOp "NotImplementedException" + member __.IsClosed: bool = invalidOp "NotImplementedException" + member __.IsDBNull(_ : int): bool = invalidOp "NotImplementedException" + member __.Item with get (_ : int): obj = invalidOp "NotImplementedException" + member __.Item with get (_ : string): obj = invalidOp "NotImplementedException" + member __.NextResult(): bool = invalidOp "NotImplementedException" + member __.RecordsAffected: int = invalidOp "NotImplementedException" + +type DesignTime private() = static member internal AddGeneratedMethod (sqlParameters: Parameter list, hasOutputParameters, executeArgs: ProvidedParameter list, erasedType, providedOutputType, name) = @@ -632,3 +678,133 @@ type DesignTime private() = then yield upcast ProvidedMethod(factoryMethodName.Value, parameters2, returnType = cmdProvidedType, IsStaticMethod = true, InvokeCode = body2) ] + + static member private CreateTempTableRecord(name, cols) = + let rowType = ProvidedTypeDefinition(name, Some typeof, HideObjectMethods = true) + + let parameters = + [ + for (p : Column) in cols do + let name = p.Name + let param = ProvidedParameter( name, p.GetProvidedType(), ?optionalValue = if p.Nullable then Some null else None) + yield param + ] + + let ctor = ProvidedConstructor( parameters) + ctor.InvokeCode <- fun args -> + let optionsToNulls = QuotationsFactory.MapArrayNullableItems(cols, "MapArrayOptionItemToObj") + + <@@ let values: obj[] = %%Expr.NewArray(typeof, [ for a in args -> Expr.Coerce(a, typeof) ]) + (%%optionsToNulls) values + values @@> + + rowType.AddMember ctor + rowType.AddXmlDoc "Type Table Type" + + rowType + + // Changes any temp tables in to a global temp table (##name) then creates them on the open connection. + static member internal SubstituteTempTables(connection, commandText: string, tempTableDefinitions : string, connectionId) = + // Extract and temp tables + let tempTableRegex = Regex("#([a-z0-9\-_]+)", RegexOptions.IgnoreCase) + let tempTableNames = + tempTableRegex.Matches(tempTableDefinitions) + |> Seq.cast + |> Seq.map (fun m -> m.Groups.[1].Value) + |> Seq.toList + + match tempTableNames with + | [] -> commandText, None + | _ -> + // Create temp table(s), extracts the columns then drop it. + let tableTypes = + use create = new SqlCommand(tempTableDefinitions, connection) + create.ExecuteScalar() |> ignore + + tempTableNames + |> List.map(fun name -> + let cols = DesignTime.GetOutputColumns(connection, "SELECT * FROM #"+name, [], isStoredProcedure = false) + use drop = new SqlCommand("DROP TABLE #"+name, connection) + drop.ExecuteScalar() |> ignore + DesignTime.CreateTempTableRecord(name, cols), cols) + + let parameters = + tableTypes + |> List.map (fun (typ, _) -> + ProvidedParameter(typ.Name, parameterType = ProvidedTypeBuilder.MakeGenericType(typedefof<_ seq>, [ typ ]))) + + // Build the values load method. + let loadValues (exprArgs: Expr list) (connection) = + (exprArgs.Tail, tableTypes) + ||> List.map2 (fun expr (typ, cols) -> + let destinationTableName = typ.Name + let colsLength = cols.Length + + <@@ + let items = (%%expr : obj seq) + use reader = new TempTableLoader(colsLength, items) + + use bulkCopy = new SqlBulkCopy((%%connection : SqlConnection)) + bulkCopy.BulkCopyTimeout <- 0 + bulkCopy.BatchSize <- 5000 + bulkCopy.DestinationTableName <- "#" + destinationTableName + bulkCopy.WriteToServer(reader) + + @@> + ) + |> List.fold (fun acc x -> Expr.Sequential(acc, x)) <@@ () @@> + + let loadTempTablesMethod = ProvidedMethod("LoadTempTables", parameters, typeof) + + loadTempTablesMethod.InvokeCode <- fun exprArgs -> + + let command = Expr.Coerce(exprArgs.[0], typedefof) + + let connection = + <@@ let cmd = (%%command : ISqlCommand) + cmd.Raw.Connection @@> + + <@@ do + use create = new SqlCommand(tempTableDefinitions, (%%connection : SqlConnection)) + create.ExecuteNonQuery() |> ignore + + (%%loadValues exprArgs connection) + ignore() @@> + + // Create the temp table(s) but as a global temp table with a unique name. This can be used later down stream on the open connection. + use cmd = new SqlCommand(tempTableRegex.Replace(tempTableDefinitions, Prefixes.tempTable+connectionId+"$1"), connection) + cmd.ExecuteScalar() |> ignore + + // Only replace temp tables we find in our list. + tempTableRegex.Replace(commandText, MatchEvaluator(fun m -> + match tempTableNames |> List.tryFind((=) m.Groups.[1].Value) with + | Some name -> Prefixes.tempTable + connectionId + name + | None -> m.Groups.[0].Value)), + + Some(loadTempTablesMethod, tableTypes |> List.unzip |> fst) + + static member internal RemoveSubstitutedTempTables(connection, tempTables : ProvidedTypeDefinition list, connectionId) = + if not tempTables.IsEmpty then + use cmd = new SqlCommand(tempTables |> List.map(fun tempTable -> sprintf "DROP TABLE [%s%s%s]" Prefixes.tempTable connectionId tempTable.Name) |> String.concat ";", connection) + cmd.ExecuteScalar() |> ignore + + // tableVarMapping(s) is converted into DECLARE statements then prepended to the command text. + static member internal SubstituteTableVar(commandText: string, tableVarMapping : string) = + let varRegex = Regex("@([a-z0-9_]+)", RegexOptions.IgnoreCase) + + let vars = + tableVarMapping.Split([|';'|], System.StringSplitOptions.RemoveEmptyEntries) + |> Array.choose(fun (x : string) -> + match x.Split([|'='|]) with + | [|name;typ|] -> Some(name.TrimStart('@'), typ) + | _ -> None) + + // Only replace table vars we find in our list. + let commandText = + varRegex.Replace(commandText, MatchEvaluator(fun m -> + match vars |> Array.tryFind(fun (n,_) -> n = m.Groups.[1].Value) with + | Some (name, _) -> Prefixes.tableVar + name + | None -> m.Groups.[0].Value)) + + (vars |> Array.map(fun (name,typ) -> sprintf "DECLARE %s%s %s = @%s" Prefixes.tableVar name typ name) |> String.concat "; ") + "; " + commandText + diff --git a/src/SqlClient/SqlCommandProvider.fs b/src/SqlClient/SqlCommandProvider.fs index 41320d25..22935f51 100644 --- a/src/SqlClient/SqlCommandProvider.fs +++ b/src/SqlClient/SqlCommandProvider.fs @@ -53,9 +53,11 @@ type SqlCommandProvider(config : TypeProviderConfig) as this = ProvidedStaticParameter("ConfigFile", typeof, "") ProvidedStaticParameter("AllParametersOptional", typeof, false) ProvidedStaticParameter("DataDirectory", typeof, "") - ], + ProvidedStaticParameter("TempTableDefinitions", typeof, "") + ProvidedStaticParameter("TableVarMapping", typeof, "") + ], instantiationFunction = (fun typeName args -> - let value = lazy this.CreateRootType(typeName, unbox args.[0], unbox args.[1], unbox args.[2], unbox args.[3], unbox args.[4], unbox args.[5], unbox args.[6]) + let value = lazy this.CreateRootType(typeName, unbox args.[0], unbox args.[1], unbox args.[2], unbox args.[3], unbox args.[4], unbox args.[5], unbox args.[6], unbox args.[7], unbox args.[8]) cache.GetOrAdd(typeName, value) ) ) @@ -70,6 +72,8 @@ type SqlCommandProvider(config : TypeProviderConfig) as this = If set all parameters become optional. NULL input values must be handled inside T-SQL. A folder to be used to resolve relative file paths to *.sql script files at compile time. The default value is the folder that contains the project or script. The name of the data directory that replaces |DataDirectory| in connection strings. The default value is the project or script directory. +Temp tables create command. +List table-valued parameters in the format of "@tvp1=[dbo].[TVP_IDs]; @tvp2=[dbo].[TVP_IDs]" """ this.AddNamespace(nameSpace, [ providerType ]) @@ -81,7 +85,7 @@ type SqlCommandProvider(config : TypeProviderConfig) as this = |> defaultArg <| base.ResolveAssembly args - member internal this.CreateRootType(typeName, sqlStatement, connectionStringOrName: string, resultType, singleRow, configFile, allParametersOptional, dataDirectory) = + member internal this.CreateRootType(typeName, sqlStatement, connectionStringOrName: string, resultType, singleRow, configFile, allParametersOptional, dataDirectory, tempTableDefinitions, tableVarMapping) = if singleRow && not (resultType = ResultType.Records || resultType = ResultType.Tuples) then @@ -104,11 +108,25 @@ type SqlCommandProvider(config : TypeProviderConfig) as this = conn.CheckVersion() conn.LoadDataTypesMap() - let parameters = DesignTime.ExtractParameters(conn, sqlStatement, allParametersOptional) + let connectionId = Guid.NewGuid().ToString().Substring(0, 8) + + let designTimeSqlStatement, tempTableTypes = + if String.IsNullOrWhiteSpace(tempTableDefinitions) then + sqlStatement, None + else + DesignTime.SubstituteTempTables(conn, sqlStatement, tempTableDefinitions, connectionId) + + let designTimeSqlStatement = + if String.IsNullOrWhiteSpace(tableVarMapping) then + designTimeSqlStatement + else + DesignTime.SubstituteTableVar(designTimeSqlStatement, tableVarMapping) + + let parameters = DesignTime.ExtractParameters(conn, designTimeSqlStatement, allParametersOptional) let outputColumns = if resultType <> ResultType.DataReader - then DesignTime.GetOutputColumns(conn, sqlStatement, parameters, isStoredProcedure = false) + then DesignTime.GetOutputColumns(conn, designTimeSqlStatement, parameters, isStoredProcedure = false) else [] let rank = if singleRow then ResultRank.SingleRow else ResultRank.Sequence @@ -117,6 +135,14 @@ type SqlCommandProvider(config : TypeProviderConfig) as this = let cmdProvidedType = ProvidedTypeDefinition(assembly, nameSpace, typeName, Some typeof<``ISqlCommand Implementation``>, HideObjectMethods = true) do + match tempTableTypes with + | Some (loadTempTables, types) -> + DesignTime.RemoveSubstitutedTempTables(conn, types, connectionId) + cmdProvidedType.AddMember(loadTempTables) + types |> List.iter(fun t -> cmdProvidedType.AddMember(t)) + | _ -> () + + do cmdProvidedType.AddMember(ProvidedProperty("ConnectionStringOrName", typeof, [], IsStatic = true, GetterCode = fun _ -> <@@ connectionStringOrName @@>)) do