From 4db411edca0827cc80236f7ee5ef444fd9a8753b Mon Sep 17 00:00:00 2001 From: Tejaswini Duggaraju Date: Fri, 4 Apr 2025 13:42:09 -0700 Subject: [PATCH 1/2] Added support of taking catalog for the database --- conn.go | 13 +++++++++---- db_test.go | 28 +++++++++++++++++++++++++++- driver.go | 3 +++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/conn.go b/conn.go index b131e4c..176ff6a 100644 --- a/conn.go +++ b/conn.go @@ -13,6 +13,7 @@ import ( type conn struct { athena athenaAPI + catalog string db string OutputLocation string @@ -57,11 +58,15 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) // startQuery starts an Athena query and returns its ID. func (c *conn) startQuery(ctx context.Context, query string) (string, error) { + queryCtx := &types.QueryExecutionContext{ + Database: aws.String(c.db), + } + if c.catalog != "" { + queryCtx.Catalog = aws.String(c.catalog) + } resp, err := c.athena.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{ - QueryString: aws.String(query), - QueryExecutionContext: &types.QueryExecutionContext{ - Database: aws.String(c.db), - }, + QueryString: aws.String(query), + QueryExecutionContext: queryCtx, ResultConfiguration: &types.ResultConfiguration{ OutputLocation: aws.String(c.OutputLocation), }, diff --git a/db_test.go b/db_test.go index a784853..fa98afa 100644 --- a/db_test.go +++ b/db_test.go @@ -119,7 +119,7 @@ func TestOpen(t *testing.T) { awsConfig, err := config.LoadDefaultConfig(context.Background()) require.NoError(t, err, "LoadDefaultConfig") db, err := Open(DriverConfig{ - Config: &awsConfig, + Config: &awsConfig, Database: AthenaDatabase, OutputLocation: fmt.Sprintf("s3://%s/noop", S3Bucket), }) @@ -129,6 +129,28 @@ func TestOpen(t *testing.T) { require.NoError(t, err, "Query") } +func TestDriverWithDBCatalog(t *testing.T) { + ctx := context.Background() + catalogName := os.Getenv("ATHENA_CATALOG") + if catalogName == "" { + t.Skip("ATHENA_CATALOG not set") + } + + tableName := os.Getenv("ATHENA_TABLE") + if tableName == "" { + tableName = "catalog_test_table" + } + connStr := fmt.Sprintf("catalog=%s&db=%s&output_location=s3://%s/output", getDBCatalog(), AthenaDatabase, S3Bucket) + db, err := sql.Open("athena", connStr) + require.NoError(t, err, "Open") + defer db.Close() + + harness := &athenaHarness{t: t, db: db, table: tableName} + defer harness.teardown(ctx) + harness.mustExec(ctx, `CREATE TABLE %s ( value string )`, tableName) + harness.mustExec(ctx, `INSERT INTO %s VALUES ('foo')`, tableName) +} + type dummyRow struct { NullValue *struct{} `json:"nullValue"` SmallintType int `json:"smallintType"` @@ -248,3 +270,7 @@ func (t athenaDate) String() string { func (t athenaDate) Equal(t2 athenaDate) bool { return time.Time(t).Equal(time.Time(t2)) } + +func getDBCatalog() string { + return os.Getenv("ATHENA_CATALOG") +} diff --git a/driver.go b/driver.go index 03dceca..c311117 100644 --- a/driver.go +++ b/driver.go @@ -80,6 +80,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) { return &conn{ athena: athena.NewFromConfig(*cfg.Config), db: cfg.Database, + catalog: cfg.Catalog, OutputLocation: cfg.OutputLocation, pollFrequency: cfg.PollFrequency, }, nil @@ -116,6 +117,7 @@ func Open(cfg DriverConfig) (*sql.DB, error) { type DriverConfig struct { Config *aws.Config Database string + Catalog string OutputLocation string PollFrequency time.Duration @@ -139,6 +141,7 @@ func configFromConnectionString(ctx context.Context, connStr string) (*DriverCon cfg.Config = &awsConfig cfg.Database = args.Get("db") + cfg.Catalog = args.Get("catalog") cfg.OutputLocation = args.Get("output_location") frequencyStr := args.Get("poll_frequency") From 153d1fb6b245f284c5ad29abdbb94da0cc7da95f Mon Sep 17 00:00:00 2001 From: Tejaswini Duggaraju Date: Fri, 4 Apr 2025 13:52:06 -0700 Subject: [PATCH 2/2] Removed getDBCatalog() method in the test file --- db_test.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/db_test.go b/db_test.go index fa98afa..48ade7d 100644 --- a/db_test.go +++ b/db_test.go @@ -140,7 +140,7 @@ func TestDriverWithDBCatalog(t *testing.T) { if tableName == "" { tableName = "catalog_test_table" } - connStr := fmt.Sprintf("catalog=%s&db=%s&output_location=s3://%s/output", getDBCatalog(), AthenaDatabase, S3Bucket) + connStr := fmt.Sprintf("catalog=%s&db=%s&output_location=s3://%s/output", catalogName, AthenaDatabase, S3Bucket) db, err := sql.Open("athena", connStr) require.NoError(t, err, "Open") defer db.Close() @@ -270,7 +270,3 @@ func (t athenaDate) String() string { func (t athenaDate) Equal(t2 athenaDate) bool { return time.Time(t).Equal(time.Time(t2)) } - -func getDBCatalog() string { - return os.Getenv("ATHENA_CATALOG") -}