Skip to content

Commit 72fa342

Browse files
committed
inintal package
1 parent cd8fbda commit 72fa342

File tree

16 files changed

+1235
-0
lines changed

16 files changed

+1235
-0
lines changed

client.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package redshiftdatasqldriver
2+
3+
import (
4+
"context"
5+
6+
"github.com/aws/aws-sdk-go-v2/config"
7+
"github.com/aws/aws-sdk-go-v2/service/redshiftdata"
8+
)
9+
10+
type RedshiftDataClient interface {
11+
ExecuteStatement(ctx context.Context, params *redshiftdata.ExecuteStatementInput, optFns ...func(*redshiftdata.Options)) (*redshiftdata.ExecuteStatementOutput, error)
12+
DescribeStatement(ctx context.Context, params *redshiftdata.DescribeStatementInput, optFns ...func(*redshiftdata.Options)) (*redshiftdata.DescribeStatementOutput, error)
13+
CancelStatement(ctx context.Context, params *redshiftdata.CancelStatementInput, optFns ...func(*redshiftdata.Options)) (*redshiftdata.CancelStatementOutput, error)
14+
redshiftdata.GetStatementResultAPIClient
15+
}
16+
17+
var RedshiftDataClientConstructor func(ctx context.Context, cfg *RedshiftDataConfig) (RedshiftDataClient, error)
18+
19+
func newRedshiftDataClient(ctx context.Context, cfg *RedshiftDataConfig) (RedshiftDataClient, error) {
20+
if RedshiftDataClientConstructor != nil {
21+
return RedshiftDataClientConstructor(ctx, cfg)
22+
}
23+
return DefaultRedshiftDataClientConstructor(ctx, cfg)
24+
}
25+
26+
func DefaultRedshiftDataClientConstructor(ctx context.Context, cfg *RedshiftDataConfig) (RedshiftDataClient, error) {
27+
awsCfg, err := config.LoadDefaultConfig(ctx)
28+
if err != nil {
29+
return nil, err
30+
}
31+
client := redshiftdata.NewFromConfig(awsCfg, cfg.RedshiftDataOptFns...)
32+
return client, nil
33+
}

conn.go

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
package redshiftdatasqldriver
2+
3+
import (
4+
"context"
5+
"database/sql/driver"
6+
"fmt"
7+
"time"
8+
9+
"github.com/aws/aws-sdk-go-v2/aws"
10+
"github.com/aws/aws-sdk-go-v2/service/redshiftdata"
11+
"github.com/aws/aws-sdk-go-v2/service/redshiftdata/types"
12+
)
13+
14+
type redshiftDataConn struct {
15+
client RedshiftDataClient
16+
cfg *RedshiftDataConfig
17+
aliveCh chan struct{}
18+
isClosed bool
19+
}
20+
21+
func newConn(client RedshiftDataClient, cfg *RedshiftDataConfig) *redshiftDataConn {
22+
return &redshiftDataConn{
23+
client: client,
24+
cfg: cfg,
25+
aliveCh: make(chan struct{}),
26+
}
27+
}
28+
29+
func (conn *redshiftDataConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
30+
return nil, fmt.Errorf("prepared statment %w", ErrNotSupported)
31+
}
32+
33+
func (conn *redshiftDataConn) Prepare(query string) (driver.Stmt, error) {
34+
return conn.PrepareContext(context.Background(), query)
35+
}
36+
37+
func (conn *redshiftDataConn) Close() error {
38+
if conn.isClosed {
39+
return nil
40+
}
41+
conn.isClosed = true
42+
close(conn.aliveCh)
43+
return nil
44+
}
45+
46+
func (conn *redshiftDataConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
47+
return nil, fmt.Errorf("transaction %w", ErrNotSupported)
48+
}
49+
50+
func (conn *redshiftDataConn) Begin() (driver.Tx, error) {
51+
return conn.BeginTx(context.Background(), driver.TxOptions{})
52+
}
53+
54+
func (conn *redshiftDataConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
55+
params := &redshiftdata.ExecuteStatementInput{
56+
Sql: nullif(query),
57+
}
58+
if len(args) > 0 {
59+
params.Parameters = make([]types.SqlParameter, 0, len(args))
60+
for _, arg := range args {
61+
params.Parameters = append(params.Parameters, types.SqlParameter{
62+
Name: aws.String(arg.Name),
63+
Value: aws.String(fmt.Sprintf("%v", arg.Value)),
64+
})
65+
}
66+
}
67+
p, output, err := conn.executeStatement(ctx, params)
68+
if err != nil {
69+
return nil, err
70+
}
71+
rows := newRows(coalesce(output.Id), p)
72+
return rows, nil
73+
}
74+
75+
func (conn *redshiftDataConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
76+
params := &redshiftdata.ExecuteStatementInput{
77+
Sql: nullif(query),
78+
}
79+
if len(args) > 0 {
80+
params.Parameters = make([]types.SqlParameter, 0, len(args))
81+
for _, arg := range args {
82+
params.Parameters = append(params.Parameters, types.SqlParameter{
83+
Name: aws.String(arg.Name),
84+
Value: aws.String(fmt.Sprintf("%v", arg.Value)),
85+
})
86+
}
87+
}
88+
_, output, err := conn.executeStatement(ctx, params)
89+
if err != nil {
90+
return nil, err
91+
}
92+
return newResult(output), nil
93+
}
94+
95+
func (conn *redshiftDataConn) executeStatement(ctx context.Context, params *redshiftdata.ExecuteStatementInput) (*redshiftdata.GetStatementResultPaginator, *redshiftdata.DescribeStatementOutput, error) {
96+
debugLogger.Printf("query: %s", coalesce(params.Sql))
97+
params.ClusterIdentifier = conn.cfg.ClusterIdentifier
98+
params.Database = conn.cfg.Database
99+
params.DbUser = conn.cfg.DbUser
100+
params.WorkgroupName = conn.cfg.WorkgroupName
101+
params.SecretArn = conn.cfg.SecretsARN
102+
103+
if conn.cfg.Timeout == 0 {
104+
conn.cfg.Timeout = 15 * time.Minute
105+
}
106+
if conn.cfg.Polling == 0 {
107+
conn.cfg.Polling = 10 * time.Millisecond
108+
}
109+
ectx, cancel := context.WithTimeout(ctx, conn.cfg.Timeout)
110+
defer cancel()
111+
executeOutput, err := conn.client.ExecuteStatement(ectx, params)
112+
if err != nil {
113+
return nil, nil, fmt.Errorf("execute statement:%w", err)
114+
}
115+
queryStart := time.Now()
116+
debugLogger.Printf("[%s] sucess execute statement: %s", *executeOutput.Id, coalesce(params.Sql))
117+
describeOutput, err := conn.client.DescribeStatement(ectx, &redshiftdata.DescribeStatementInput{
118+
Id: executeOutput.Id,
119+
})
120+
if err != nil {
121+
return nil, nil, fmt.Errorf("describe statement:%w", err)
122+
}
123+
debugLogger.Printf("[%s] describe statement: status=%s pid=%d query_id=%d", *executeOutput.Id, describeOutput.Status, describeOutput.RedshiftPid, describeOutput.RedshiftQueryId)
124+
125+
var isFinished bool
126+
defer func() {
127+
if !isFinished {
128+
describeOutput, err := conn.client.DescribeStatement(ctx, &redshiftdata.DescribeStatementInput{
129+
Id: executeOutput.Id,
130+
})
131+
if err != nil {
132+
errLogger.Printf("[%s] failed describe statement: %v", *executeOutput.Id, err)
133+
return
134+
}
135+
if describeOutput.Status == types.StatusStringFinished ||
136+
describeOutput.Status == types.StatusStringFailed ||
137+
describeOutput.Status == types.StatusStringAborted {
138+
return
139+
}
140+
debugLogger.Printf("[%s] try cancel statement", *executeOutput.Id)
141+
output, err := conn.client.CancelStatement(ctx, &redshiftdata.CancelStatementInput{
142+
Id: executeOutput.Id,
143+
})
144+
if err != nil {
145+
146+
errLogger.Printf("[%s] failed cancel statement: %v", *executeOutput.Id, err)
147+
return
148+
}
149+
if !*output.Status {
150+
debugLogger.Printf("[%s] cancel statement status is false", *executeOutput.Id)
151+
}
152+
}
153+
}()
154+
delay := time.NewTimer(conn.cfg.Polling)
155+
for {
156+
if describeOutput.Status == types.StatusStringAborted {
157+
return nil, nil, fmt.Errorf("query aborted: %s", *describeOutput.Error)
158+
}
159+
if describeOutput.Status == types.StatusStringFailed {
160+
return nil, nil, fmt.Errorf("query failed: %s", *describeOutput.Error)
161+
}
162+
if describeOutput.Status == types.StatusStringFinished {
163+
break
164+
}
165+
debugLogger.Printf("[%s] wating finsih query: elapsed_time=%s", *executeOutput.Id, time.Since(queryStart))
166+
delay.Reset(conn.cfg.Polling)
167+
select {
168+
case <-ectx.Done():
169+
if !delay.Stop() {
170+
<-delay.C
171+
}
172+
return nil, nil, ectx.Err()
173+
case <-delay.C:
174+
case <-conn.aliveCh:
175+
if !delay.Stop() {
176+
<-delay.C
177+
}
178+
return nil, nil, ErrConnClosed
179+
}
180+
describeOutput, err = conn.client.DescribeStatement(ctx, &redshiftdata.DescribeStatementInput{
181+
Id: executeOutput.Id,
182+
})
183+
if err != nil {
184+
return nil, nil, fmt.Errorf("describe statement:%w", err)
185+
}
186+
}
187+
isFinished = true
188+
debugLogger.Printf("[%s] success query: elapsed_time=%s", *executeOutput.Id, time.Since(queryStart))
189+
if !*describeOutput.HasResultSet {
190+
return nil, describeOutput, nil
191+
}
192+
debugLogger.Printf("[%s] query has result set: result_rows=%d", *executeOutput.Id, describeOutput.ResultRows)
193+
p := redshiftdata.NewGetStatementResultPaginator(conn.client, &redshiftdata.GetStatementResultInput{
194+
Id: executeOutput.Id,
195+
})
196+
return p, describeOutput, nil
197+
}

connector.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package redshiftdatasqldriver
2+
3+
import (
4+
"context"
5+
"database/sql/driver"
6+
)
7+
8+
type redshiftDataConnector struct {
9+
d *redshiftDataDriver
10+
cfg *RedshiftDataConfig
11+
}
12+
13+
func (c *redshiftDataConnector) Connect(ctx context.Context) (driver.Conn, error) {
14+
client, err := newRedshiftDataClient(ctx, c.cfg)
15+
if err != nil {
16+
return nil, err
17+
}
18+
return newConn(client, c.cfg), nil
19+
}
20+
21+
func (c *redshiftDataConnector) Driver() driver.Driver {
22+
return c.d
23+
}

driver.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package redshiftdatasqldriver
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"database/sql/driver"
7+
)
8+
9+
func init() {
10+
sql.Register("redshift-data", &redshiftDataDriver{})
11+
}
12+
13+
type redshiftDataDriver struct{}
14+
15+
func (d *redshiftDataDriver) Open(dsn string) (driver.Conn, error) {
16+
connector, err := d.OpenConnector(dsn)
17+
if err != nil {
18+
return nil, err
19+
}
20+
return connector.Connect(context.Background())
21+
}
22+
23+
func (d *redshiftDataDriver) OpenConnector(dsn string) (driver.Connector, error) {
24+
cfg, err := ParseDSN(dsn)
25+
if err != nil {
26+
return nil, err
27+
}
28+
return &redshiftDataConnector{
29+
d: d,
30+
cfg: cfg,
31+
}, nil
32+
}

0 commit comments

Comments
 (0)