Skip to content

Commit eb83772

Browse files
committed
support ordinal parameters
1 parent 73d2259 commit eb83772

File tree

3 files changed

+43
-24
lines changed

3 files changed

+43
-24
lines changed

conn.go

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,8 @@ func (conn *redshiftDataConn) Begin() (driver.Tx, error) {
5353

5454
func (conn *redshiftDataConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
5555
params := &redshiftdata.ExecuteStatementInput{
56-
Sql: nullif(query),
57-
}
58-
if len(args) > 0 {
59-
params.Parameters = make([]types.SqlParameter, 0, len(args))
60-
for _, arg := range args {
61-
params.Parameters = append(params.Parameters, types.SqlParameter{
62-
Name: aws.String(arg.Name),
63-
Value: aws.String(fmt.Sprintf("%v", arg.Value)),
64-
})
65-
}
56+
Sql: nullif(query),
57+
Parameters: convertArgsToParameters(args),
6658
}
6759
p, output, err := conn.executeStatement(ctx, params)
6860
if err != nil {
@@ -74,16 +66,8 @@ func (conn *redshiftDataConn) QueryContext(ctx context.Context, query string, ar
7466

7567
func (conn *redshiftDataConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
7668
params := &redshiftdata.ExecuteStatementInput{
77-
Sql: nullif(query),
78-
}
79-
if len(args) > 0 {
80-
params.Parameters = make([]types.SqlParameter, 0, len(args))
81-
for _, arg := range args {
82-
params.Parameters = append(params.Parameters, types.SqlParameter{
83-
Name: aws.String(arg.Name),
84-
Value: aws.String(fmt.Sprintf("%v", arg.Value)),
85-
})
86-
}
69+
Sql: nullif(query),
70+
Parameters: convertArgsToParameters(args),
8771
}
8872
_, output, err := conn.executeStatement(ctx, params)
8973
if err != nil {
@@ -92,6 +76,20 @@ func (conn *redshiftDataConn) ExecContext(ctx context.Context, query string, arg
9276
return newResult(output), nil
9377
}
9478

79+
func convertArgsToParameters(args []driver.NamedValue) []types.SqlParameter {
80+
if len(args) == 0 {
81+
return nil
82+
}
83+
params := make([]types.SqlParameter, 0, len(args))
84+
for _, arg := range args {
85+
params = append(params, types.SqlParameter{
86+
Name: aws.String(coalesce(nullif(arg.Name), aws.String(fmt.Sprintf("%d", arg.Ordinal)))),
87+
Value: aws.String(fmt.Sprintf("%v", arg.Value)),
88+
})
89+
}
90+
return params
91+
}
92+
9593
func (conn *redshiftDataConn) executeStatement(ctx context.Context, params *redshiftdata.ExecuteStatementInput) (*redshiftdata.GetStatementResultPaginator, *redshiftdata.DescribeStatementOutput, error) {
9694
debugLogger.Printf("query: %s", coalesce(params.Sql))
9795
params.ClusterIdentifier = conn.cfg.ClusterIdentifier

driver_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,25 @@ func TestTimestampQuery(t *testing.T) {
9898
})
9999
}
100100

101+
func TestOrdinalParameterQuery(t *testing.T) {
102+
runTestsWithDB(t, dsn, func(t *testing.T, db *sql.DB) {
103+
restore := requireNoErrorLog(t)
104+
defer restore()
105+
query := `SELECT usesysid, usename FROM pg_user WHERE usename = :1`
106+
rows, err := db.QueryContext(context.Background(), query, "rdsdb")
107+
require.NoError(t, err)
108+
defer func() {
109+
require.NoError(t, rows.Close())
110+
}()
111+
require.True(t, rows.Next())
112+
var userID int64
113+
var userName string
114+
require.NoError(t, rows.Scan(&userID, &userName))
115+
require.Equal(t, int64(1), userID)
116+
require.Equal(t, "rdsdb", userName)
117+
})
118+
}
119+
101120
func TestSimpleExec(t *testing.T) {
102121
runTestsWithDB(t, dsn, func(t *testing.T, db *sql.DB) {
103122
restore := requireNoErrorLog(t)

utils.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ func nullif(str string) *string {
77
return &str
88
}
99

10-
func coalesce(str *string) string {
11-
if str == nil {
12-
return ""
10+
func coalesce(strs ...*string) string {
11+
for _, str := range strs {
12+
if str != nil {
13+
return *str
14+
}
1315
}
14-
return *str
16+
return ""
1517
}

0 commit comments

Comments
 (0)