diff --git a/.gitignore b/.gitignore
index 3ec2e7b8..4d1ab2f1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -235,3 +235,4 @@ Paket.Restore.targets
.paket
/docs/output
docs/output/**/*.*
+*.orig
diff --git a/src/SqlClient.Tests/Lib/Lib.fsproj b/src/SqlClient.Tests/Lib/Lib.fsproj
index 2c7d6195..ad6a20c6 100644
--- a/src/SqlClient.Tests/Lib/Lib.fsproj
+++ b/src/SqlClient.Tests/Lib/Lib.fsproj
@@ -53,6 +53,7 @@
+
diff --git a/src/SqlClient.Tests/SqlClient.Tests.NET40/SqlClient.Tests.NET40.fsproj b/src/SqlClient.Tests/SqlClient.Tests.NET40/SqlClient.Tests.NET40.fsproj
index aff864cf..3df52d8a 100644
--- a/src/SqlClient.Tests/SqlClient.Tests.NET40/SqlClient.Tests.NET40.fsproj
+++ b/src/SqlClient.Tests/SqlClient.Tests.NET40/SqlClient.Tests.NET40.fsproj
@@ -58,6 +58,7 @@
+
diff --git a/src/SqlClient.Tests/SqlClient.Tests.fsproj b/src/SqlClient.Tests/SqlClient.Tests.fsproj
index 3a995590..531c7154 100644
--- a/src/SqlClient.Tests/SqlClient.Tests.fsproj
+++ b/src/SqlClient.Tests/SqlClient.Tests.fsproj
@@ -87,6 +87,7 @@
+
diff --git a/src/SqlClient.Tests/TVPTests.fs b/src/SqlClient.Tests/TVPTests.fs
index 7885530d..52751f0d 100644
--- a/src/SqlClient.Tests/TVPTests.fs
+++ b/src/SqlClient.Tests/TVPTests.fs
@@ -145,3 +145,23 @@ let UsingTVPInQuery() =
|> Seq.toList
Assert.Equal<_ list>(expected, actual)
+
+type MappedTVP =
+ SqlCommandProvider<"
+ SELECT myId, myName from @input
+ ", ConnectionStrings.AdventureWorksLiteral, TableVarMapping = "@input=dbo.MyTableType">
+[]
+let UsingMappedTVPInQuery() =
+ printfn "%s" ConnectionStrings.AdventureWorksLiteral
+ use cmd = new MappedTVP(ConnectionStrings.AdventureWorksLiteral)
+ let expected = [
+ 1, Some "monkey"
+ 2, Some "donkey"
+ ]
+
+ let actual =
+ cmd.Execute(input = [ for id, name in expected -> MappedTVP.MyTableType(id, name) ])
+ |> Seq.map(fun x -> x.myId, x.myName)
+ |> Seq.toList
+
+ Assert.Equal<_ list>(expected, actual)
diff --git a/src/SqlClient.Tests/TempTableTests.fs b/src/SqlClient.Tests/TempTableTests.fs
new file mode 100644
index 00000000..afe39753
--- /dev/null
+++ b/src/SqlClient.Tests/TempTableTests.fs
@@ -0,0 +1,106 @@
+module FSharp.Data.TempTableTests
+
+open FSharp.Data
+open Xunit
+open System.Data.SqlClient
+
+type TempTable =
+ SqlCommandProvider<
+ TempTableDefinitions = "
+ CREATE TABLE #Temp (
+ Id INT NOT NULL,
+ Name NVARCHAR(100) NULL)",
+ CommandText = "
+ SELECT Id, Name FROM #Temp",
+ ConnectionStringOrName =
+ ConnectionStrings.AdventureWorksLiteral>
+
+[]
+let usingTempTable() =
+ use conn = new SqlConnection(ConnectionStrings.AdventureWorksLiteral)
+ conn.Open()
+
+ use cmd = new TempTable(conn)
+
+ cmd.LoadTempTables(
+ Temp =
+ [ TempTable.Temp(Id = 1, Name = Some "monkey")
+ TempTable.Temp(Id = 2, Name = Some "donkey") ])
+
+ let actual =
+ cmd.Execute()
+ |> Seq.map(fun x -> x.Id, x.Name)
+ |> Seq.toList
+
+ let expected = [
+ 1, Some "monkey"
+ 2, Some "donkey"
+ ]
+
+ Assert.Equal<_ list>(expected, actual)
+
+[]
+let queryWithHash() =
+ // We shouldn't mangle the statement when it's run
+ use cmd =
+ new SqlCommandProvider<
+ CommandText = "
+ SELECT Id, Name
+ FROM
+ (
+ SELECT 1 AS Id, '#name' AS Name UNION
+ SELECT 2, 'some other value'
+ ) AS a
+ WHERE Name = '#name'",
+ ConnectionStringOrName =
+ ConnectionStrings.AdventureWorksLiteral>(ConnectionStrings.AdventureWorksLiteral)
+
+ let actual =
+ cmd.Execute()
+ |> Seq.map(fun x -> x.Id, x.Name)
+ |> Seq.toList
+
+ let expected = [
+ 1, "#name"
+ ]
+
+ Assert.Equal<_ list>(expected, actual)
+
+type TempTableHash =
+ SqlCommandProvider<
+ TempTableDefinitions = "
+ CREATE TABLE #Temp (
+ Id INT NOT NULL)",
+ CommandText = "
+ SELECT a.Id, a.Name
+ FROM
+ (
+ SELECT 1 AS Id, '#Temp' AS Name UNION
+ SELECT 2, 'some other value'
+ ) AS a
+ INNER JOIN #Temp t ON t.Id = a.Id",
+ ConnectionStringOrName =
+ ConnectionStrings.AdventureWorksLiteral>
+
+[]
+let queryWithHashAndTempTable() =
+ // We shouldn't mangle the statement when it's run
+ use conn = new SqlConnection(ConnectionStrings.AdventureWorksLiteral)
+ conn.Open()
+
+ use cmd = new TempTableHash(conn)
+
+ cmd.LoadTempTables(
+ Temp =
+ [ TempTableHash.Temp(Id = 1) ])
+
+ let actual =
+ cmd.Execute()
+ |> Seq.map(fun x -> x.Id, x.Name)
+ |> Seq.toList
+
+ let expected = [
+ 1, "#Temp"
+ ]
+
+ Assert.Equal<_ list>(expected, actual)
\ No newline at end of file
diff --git a/src/SqlClient/AssemblyInfo.fs b/src/SqlClient/AssemblyInfo.fs
index a61e0367..9cc5187c 100644
--- a/src/SqlClient/AssemblyInfo.fs
+++ b/src/SqlClient/AssemblyInfo.fs
@@ -1,4 +1,5 @@
-namespace System
+// Auto-Generated by FAKE; do not edit
+namespace System
open System.Reflection
open System.Runtime.CompilerServices
@@ -11,4 +12,9 @@ open System.Runtime.CompilerServices
do ()
module internal AssemblyVersionInformation =
- let [] Version = "1.8.4"
+ let [] AssemblyTitle = "SqlClient"
+ let [] AssemblyProduct = "FSharp.Data.SqlClient"
+ let [] AssemblyDescription = "SqlClient F# type providers"
+ let [] AssemblyVersion = "1.8.4"
+ let [] AssemblyFileVersion = "1.8.4"
+ let [] InternalsVisibleTo = "SqlClient.Tests"
diff --git a/src/SqlClient/DesignTime.fs b/src/SqlClient/DesignTime.fs
index 85e20d51..faaa5449 100644
--- a/src/SqlClient/DesignTime.fs
+++ b/src/SqlClient/DesignTime.fs
@@ -10,6 +10,7 @@ open System.Diagnostics
open Microsoft.FSharp.Quotations
open ProviderImplementation.ProvidedTypes
open FSharp.Data
+open System.Text.RegularExpressions
type internal RowType = {
Provided: Type
@@ -40,7 +41,52 @@ module internal SharedLogic =
// add .Table
returnType.Single |> cmdProvidedType.AddMember
-type DesignTime private() =
+module Prefixes =
+ let tempTable = "##SQLCOMMANDPROVIDER_"
+ let tableVar = "@SQLCOMMANDPROVIDER_"
+
+type TempTableLoader(fieldCount, items: obj seq) =
+ let enumerator = items.GetEnumerator()
+
+ interface IDataReader with
+ member this.FieldCount: int = fieldCount
+ member this.Read(): bool = enumerator.MoveNext()
+ member this.GetValue(i: int): obj =
+ let row : obj[] = unbox enumerator.Current
+ row.[i]
+ member this.Dispose(): unit = ()
+
+ member __.Close(): unit = invalidOp "NotImplementedException"
+ member __.Depth: int = invalidOp "NotImplementedException"
+ member __.GetBoolean(_: int): bool = invalidOp "NotImplementedException"
+ member __.GetByte(_ : int): byte = invalidOp "NotImplementedException"
+ member __.GetBytes(_ : int, _ : int64, _ : byte [], _ : int, _ : int): int64 = invalidOp "NotImplementedException"
+ member __.GetChar(_ : int): char = invalidOp "NotImplementedException"
+ member __.GetChars(_ : int, _ : int64, _ : char [], _ : int, _ : int): int64 = invalidOp "NotImplementedException"
+ member __.GetData(_ : int): IDataReader = invalidOp "NotImplementedException"
+ member __.GetDataTypeName(_ : int): string = invalidOp "NotImplementedException"
+ member __.GetDateTime(_ : int): System.DateTime = invalidOp "NotImplementedException"
+ member __.GetDecimal(_ : int): decimal = invalidOp "NotImplementedException"
+ member __.GetDouble(_ : int): float = invalidOp "NotImplementedException"
+ member __.GetFieldType(_ : int): System.Type = invalidOp "NotImplementedException"
+ member __.GetFloat(_ : int): float32 = invalidOp "NotImplementedException"
+ member __.GetGuid(_ : int): System.Guid = invalidOp "NotImplementedException"
+ member __.GetInt16(_ : int): int16 = invalidOp "NotImplementedException"
+ member __.GetInt32(_ : int): int = invalidOp "NotImplementedException"
+ member __.GetInt64(_ : int): int64 = invalidOp "NotImplementedException"
+ member __.GetName(_ : int): string = invalidOp "NotImplementedException"
+ member __.GetOrdinal(_ : string): int = invalidOp "NotImplementedException"
+ member __.GetSchemaTable(): DataTable = invalidOp "NotImplementedException"
+ member __.GetString(_ : int): string = invalidOp "NotImplementedException"
+ member __.GetValues(_ : obj []): int = invalidOp "NotImplementedException"
+ member __.IsClosed: bool = invalidOp "NotImplementedException"
+ member __.IsDBNull(_ : int): bool = invalidOp "NotImplementedException"
+ member __.Item with get (_ : int): obj = invalidOp "NotImplementedException"
+ member __.Item with get (_ : string): obj = invalidOp "NotImplementedException"
+ member __.NextResult(): bool = invalidOp "NotImplementedException"
+ member __.RecordsAffected: int = invalidOp "NotImplementedException"
+
+type DesignTime private() =
static member internal AddGeneratedMethod
(sqlParameters: Parameter list, hasOutputParameters, executeArgs: ProvidedParameter list, erasedType, providedOutputType, name) =
@@ -632,3 +678,133 @@ type DesignTime private() =
then
yield upcast ProvidedMethod(factoryMethodName.Value, parameters2, returnType = cmdProvidedType, IsStaticMethod = true, InvokeCode = body2)
]
+
+ static member private CreateTempTableRecord(name, cols) =
+ let rowType = ProvidedTypeDefinition(name, Some typeof, HideObjectMethods = true)
+
+ let parameters =
+ [
+ for (p : Column) in cols do
+ let name = p.Name
+ let param = ProvidedParameter( name, p.GetProvidedType(), ?optionalValue = if p.Nullable then Some null else None)
+ yield param
+ ]
+
+ let ctor = ProvidedConstructor( parameters)
+ ctor.InvokeCode <- fun args ->
+ let optionsToNulls = QuotationsFactory.MapArrayNullableItems(cols, "MapArrayOptionItemToObj")
+
+ <@@ let values: obj[] = %%Expr.NewArray(typeof, [ for a in args -> Expr.Coerce(a, typeof) ])
+ (%%optionsToNulls) values
+ values @@>
+
+ rowType.AddMember ctor
+ rowType.AddXmlDoc "Type Table Type"
+
+ rowType
+
+ // Changes any temp tables in to a global temp table (##name) then creates them on the open connection.
+ static member internal SubstituteTempTables(connection, commandText: string, tempTableDefinitions : string, connectionId) =
+ // Extract and temp tables
+ let tempTableRegex = Regex("#([a-z0-9\-_]+)", RegexOptions.IgnoreCase)
+ let tempTableNames =
+ tempTableRegex.Matches(tempTableDefinitions)
+ |> Seq.cast
+ |> Seq.map (fun m -> m.Groups.[1].Value)
+ |> Seq.toList
+
+ match tempTableNames with
+ | [] -> commandText, None
+ | _ ->
+ // Create temp table(s), extracts the columns then drop it.
+ let tableTypes =
+ use create = new SqlCommand(tempTableDefinitions, connection)
+ create.ExecuteScalar() |> ignore
+
+ tempTableNames
+ |> List.map(fun name ->
+ let cols = DesignTime.GetOutputColumns(connection, "SELECT * FROM #"+name, [], isStoredProcedure = false)
+ use drop = new SqlCommand("DROP TABLE #"+name, connection)
+ drop.ExecuteScalar() |> ignore
+ DesignTime.CreateTempTableRecord(name, cols), cols)
+
+ let parameters =
+ tableTypes
+ |> List.map (fun (typ, _) ->
+ ProvidedParameter(typ.Name, parameterType = ProvidedTypeBuilder.MakeGenericType(typedefof<_ seq>, [ typ ])))
+
+ // Build the values load method.
+ let loadValues (exprArgs: Expr list) (connection) =
+ (exprArgs.Tail, tableTypes)
+ ||> List.map2 (fun expr (typ, cols) ->
+ let destinationTableName = typ.Name
+ let colsLength = cols.Length
+
+ <@@
+ let items = (%%expr : obj seq)
+ use reader = new TempTableLoader(colsLength, items)
+
+ use bulkCopy = new SqlBulkCopy((%%connection : SqlConnection))
+ bulkCopy.BulkCopyTimeout <- 0
+ bulkCopy.BatchSize <- 5000
+ bulkCopy.DestinationTableName <- "#" + destinationTableName
+ bulkCopy.WriteToServer(reader)
+
+ @@>
+ )
+ |> List.fold (fun acc x -> Expr.Sequential(acc, x)) <@@ () @@>
+
+ let loadTempTablesMethod = ProvidedMethod("LoadTempTables", parameters, typeof)
+
+ loadTempTablesMethod.InvokeCode <- fun exprArgs ->
+
+ let command = Expr.Coerce(exprArgs.[0], typedefof)
+
+ let connection =
+ <@@ let cmd = (%%command : ISqlCommand)
+ cmd.Raw.Connection @@>
+
+ <@@ do
+ use create = new SqlCommand(tempTableDefinitions, (%%connection : SqlConnection))
+ create.ExecuteNonQuery() |> ignore
+
+ (%%loadValues exprArgs connection)
+ ignore() @@>
+
+ // Create the temp table(s) but as a global temp table with a unique name. This can be used later down stream on the open connection.
+ use cmd = new SqlCommand(tempTableRegex.Replace(tempTableDefinitions, Prefixes.tempTable+connectionId+"$1"), connection)
+ cmd.ExecuteScalar() |> ignore
+
+ // Only replace temp tables we find in our list.
+ tempTableRegex.Replace(commandText, MatchEvaluator(fun m ->
+ match tempTableNames |> List.tryFind((=) m.Groups.[1].Value) with
+ | Some name -> Prefixes.tempTable + connectionId + name
+ | None -> m.Groups.[0].Value)),
+
+ Some(loadTempTablesMethod, tableTypes |> List.unzip |> fst)
+
+ static member internal RemoveSubstitutedTempTables(connection, tempTables : ProvidedTypeDefinition list, connectionId) =
+ if not tempTables.IsEmpty then
+ use cmd = new SqlCommand(tempTables |> List.map(fun tempTable -> sprintf "DROP TABLE [%s%s%s]" Prefixes.tempTable connectionId tempTable.Name) |> String.concat ";", connection)
+ cmd.ExecuteScalar() |> ignore
+
+ // tableVarMapping(s) is converted into DECLARE statements then prepended to the command text.
+ static member internal SubstituteTableVar(commandText: string, tableVarMapping : string) =
+ let varRegex = Regex("@([a-z0-9_]+)", RegexOptions.IgnoreCase)
+
+ let vars =
+ tableVarMapping.Split([|';'|], System.StringSplitOptions.RemoveEmptyEntries)
+ |> Array.choose(fun (x : string) ->
+ match x.Split([|'='|]) with
+ | [|name;typ|] -> Some(name.TrimStart('@'), typ)
+ | _ -> None)
+
+ // Only replace table vars we find in our list.
+ let commandText =
+ varRegex.Replace(commandText, MatchEvaluator(fun m ->
+ match vars |> Array.tryFind(fun (n,_) -> n = m.Groups.[1].Value) with
+ | Some (name, _) -> Prefixes.tableVar + name
+ | None -> m.Groups.[0].Value))
+
+ (vars |> Array.map(fun (name,typ) -> sprintf "DECLARE %s%s %s = @%s" Prefixes.tableVar name typ name) |> String.concat "; ") + "; " + commandText
+
diff --git a/src/SqlClient/SqlCommandProvider.fs b/src/SqlClient/SqlCommandProvider.fs
index 41320d25..22935f51 100644
--- a/src/SqlClient/SqlCommandProvider.fs
+++ b/src/SqlClient/SqlCommandProvider.fs
@@ -53,9 +53,11 @@ type SqlCommandProvider(config : TypeProviderConfig) as this =
ProvidedStaticParameter("ConfigFile", typeof, "")
ProvidedStaticParameter("AllParametersOptional", typeof, false)
ProvidedStaticParameter("DataDirectory", typeof, "")
- ],
+ ProvidedStaticParameter("TempTableDefinitions", typeof, "")
+ ProvidedStaticParameter("TableVarMapping", typeof, "")
+ ],
instantiationFunction = (fun typeName args ->
- let value = lazy this.CreateRootType(typeName, unbox args.[0], unbox args.[1], unbox args.[2], unbox args.[3], unbox args.[4], unbox args.[5], unbox args.[6])
+ let value = lazy this.CreateRootType(typeName, unbox args.[0], unbox args.[1], unbox args.[2], unbox args.[3], unbox args.[4], unbox args.[5], unbox args.[6], unbox args.[7], unbox args.[8])
cache.GetOrAdd(typeName, value)
)
)
@@ -70,6 +72,8 @@ type SqlCommandProvider(config : TypeProviderConfig) as this =
If set all parameters become optional. NULL input values must be handled inside T-SQL.
A folder to be used to resolve relative file paths to *.sql script files at compile time. The default value is the folder that contains the project or script.
The name of the data directory that replaces |DataDirectory| in connection strings. The default value is the project or script directory.
+Temp tables create command.
+List table-valued parameters in the format of "@tvp1=[dbo].[TVP_IDs]; @tvp2=[dbo].[TVP_IDs]"
"""
this.AddNamespace(nameSpace, [ providerType ])
@@ -81,7 +85,7 @@ type SqlCommandProvider(config : TypeProviderConfig) as this =
|> defaultArg
<| base.ResolveAssembly args
- member internal this.CreateRootType(typeName, sqlStatement, connectionStringOrName: string, resultType, singleRow, configFile, allParametersOptional, dataDirectory) =
+ member internal this.CreateRootType(typeName, sqlStatement, connectionStringOrName: string, resultType, singleRow, configFile, allParametersOptional, dataDirectory, tempTableDefinitions, tableVarMapping) =
if singleRow && not (resultType = ResultType.Records || resultType = ResultType.Tuples)
then
@@ -104,11 +108,25 @@ type SqlCommandProvider(config : TypeProviderConfig) as this =
conn.CheckVersion()
conn.LoadDataTypesMap()
- let parameters = DesignTime.ExtractParameters(conn, sqlStatement, allParametersOptional)
+ let connectionId = Guid.NewGuid().ToString().Substring(0, 8)
+
+ let designTimeSqlStatement, tempTableTypes =
+ if String.IsNullOrWhiteSpace(tempTableDefinitions) then
+ sqlStatement, None
+ else
+ DesignTime.SubstituteTempTables(conn, sqlStatement, tempTableDefinitions, connectionId)
+
+ let designTimeSqlStatement =
+ if String.IsNullOrWhiteSpace(tableVarMapping) then
+ designTimeSqlStatement
+ else
+ DesignTime.SubstituteTableVar(designTimeSqlStatement, tableVarMapping)
+
+ let parameters = DesignTime.ExtractParameters(conn, designTimeSqlStatement, allParametersOptional)
let outputColumns =
if resultType <> ResultType.DataReader
- then DesignTime.GetOutputColumns(conn, sqlStatement, parameters, isStoredProcedure = false)
+ then DesignTime.GetOutputColumns(conn, designTimeSqlStatement, parameters, isStoredProcedure = false)
else []
let rank = if singleRow then ResultRank.SingleRow else ResultRank.Sequence
@@ -117,6 +135,14 @@ type SqlCommandProvider(config : TypeProviderConfig) as this =
let cmdProvidedType = ProvidedTypeDefinition(assembly, nameSpace, typeName, Some typeof<``ISqlCommand Implementation``>, HideObjectMethods = true)
do
+ match tempTableTypes with
+ | Some (loadTempTables, types) ->
+ DesignTime.RemoveSubstitutedTempTables(conn, types, connectionId)
+ cmdProvidedType.AddMember(loadTempTables)
+ types |> List.iter(fun t -> cmdProvidedType.AddMember(t))
+ | _ -> ()
+
+ do
cmdProvidedType.AddMember(ProvidedProperty("ConnectionStringOrName", typeof, [], IsStatic = true, GetterCode = fun _ -> <@@ connectionStringOrName @@>))
do