From e1773254db6c72f217b7000d5724fb99c141f8c9 Mon Sep 17 00:00:00 2001 From: mackee Date: Thu, 19 Jun 2025 14:47:41 +0900 Subject: [PATCH 1/3] feat(postgresql): use RETURNING to fetch entire row on INSERT instead of PK re-select BREAKING CHANGE: PostgreSQL INSERT statements now use `RETURNING` with all columns and return full row values directly, rather than performing a second SELECT by primary key. This affects result structures for single and bulk inserts, including on conflict operations. Code generation template updated accordingly. --- _example/group.gen.go | 23 +++--- _example/postgresql/account_test.go | 14 ++-- _example/postgresql/accounts.gen.go | 92 ++++++++++++----------- _example/postgresql/groups.gen.go | 92 ++++++++++++----------- _example/postgresql/identities.gen.go | 92 ++++++++++++----------- _example/user.gen.go | 23 +++--- _example/user_external.gen.go | 23 +++--- _example/user_item.gen.go | 23 +++--- _example/user_sns.gen.go | 23 +++--- template/insert.tmpl | 102 +++++++++++++------------- template/insert_postgresql.tmpl | 74 ++++++------------- 11 files changed, 280 insertions(+), 301 deletions(-) diff --git a/_example/group.gen.go b/_example/group.gen.go index fb09442..2eb16ec 100644 --- a/_example/group.gen.go +++ b/_example/group.gen.go @@ -787,6 +787,10 @@ func (q groupInsertSQL) ToSql() (string, []any, error) { return query + ";", vs, nil } +func (q groupInsertSQL) rowsNum() int { + return 1 +} + func (q groupInsertSQL) groupInsertSQLToSql() (string, []any, error) { var err error var s interface{} = Group{} @@ -806,19 +810,7 @@ func (q groupInsertSQL) groupInsertSQLToSql() (string, []any, error) { } func (q groupInsertSQL) Exec(db sqlla.DB) (Group, error) { - query, args, err := q.ToSql() - if err != nil { - return Group{}, err - } - result, err := db.Exec(query, args...) - if err != nil { - return Group{}, err - } - id, err := result.LastInsertId() - if err != nil { - return Group{}, err - } - return NewGroupSQL().Select().PkColumn(id).Single(db) + return q.ExecContext(context.Background(), db) } func (q groupInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { @@ -851,6 +843,7 @@ type groupDefaultInsertHooker interface { } type groupInsertSQLToSqler interface { + rowsNum() int groupInsertSQLToSql() (string, []any, error) } @@ -868,6 +861,10 @@ func (q *groupBulkInsertSQL) Append(iqs ...groupInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *groupBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *groupBulkInsertSQL) groupInsertSQLToSql() (string, []any, error) { if len(q.insertSQLs) == 0 { return "", []any{}, fmt.Errorf("sqlla: This groupBulkInsertSQL's InsertSQL was empty") diff --git a/_example/postgresql/account_test.go b/_example/postgresql/account_test.go index 2487b56..59c9703 100644 --- a/_example/postgresql/account_test.go +++ b/_example/postgresql/account_test.go @@ -413,7 +413,7 @@ func testCases() []testCaseWithToQueryTestCase { query: postgresql.NewAccountSQL().Insert(). ValueName("foo"). ValueEmbedding(sampleVector), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) RETURNING "id", "name", "embedding", "created_at", "updated_at";`, vs: []any{sampleDate, sampleVector, "foo", sampleDate}, expectedResult: postgresql.Account{ ID: 1, @@ -432,7 +432,7 @@ func testCases() []testCaseWithToQueryTestCase { ValueChildGroupIDIsNull(). ValueCreatedAt(sampleDate). ValueUpdatedAt(sampleDate), - expected: `INSERT INTO "groups" ("child_group_id","created_at","leader_account_id","name","sub_leader_account_id","updated_at") VALUES ($1,$2,$3,$4,$5,$6) RETURNING "id";`, + expected: `INSERT INTO "groups" ("child_group_id","created_at","leader_account_id","name","sub_leader_account_id","updated_at") VALUES ($1,$2,$3,$4,$5,$6) RETURNING "id", "name", "leader_account_id", "sub_leader_account_id", "child_group_id", "created_at", "updated_at";`, vs: []any{sql.Null[int64]{}, sampleDate, int64(42), "foo", int64(28), sampleDate}, expectedResult: postgresql.Group{ ID: 1, @@ -450,7 +450,7 @@ func testCases() []testCaseWithToQueryTestCase { ValueName("foo"). ValueEmbedding(sampleVector). OnConflictDoNothing(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT DO NOTHING RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT DO NOTHING RETURNING "id", "name", "embedding", "created_at", "updated_at";`, vs: []any{sampleDate, sampleVector, "foo", sampleDate}, expectedResult: postgresql.Account{ ID: 1, @@ -468,7 +468,7 @@ func testCases() []testCaseWithToQueryTestCase { OnConflictDoUpdate("id"). ValueOnUpdateName("powawa"). SameOnUpdateEmbedding(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $5 RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $5 RETURNING "id", "name", "embedding", "created_at", "updated_at";`, vs: []any{sampleDate, sampleVector, "foo", sampleDate, "powawa"}, expectedResult: postgresql.Account{ ID: 1, @@ -490,7 +490,7 @@ func testCases() []testCaseWithToQueryTestCase { } return bi }(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12) RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12) RETURNING "id", "name", "embedding", "created_at", "updated_at";`, vs: []any{ sampleDate, sampleVector, @@ -523,7 +523,7 @@ func testCases() []testCaseWithToQueryTestCase { } return bi.OnConflictDoNothing() }(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12) ON CONFLICT DO NOTHING RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12) ON CONFLICT DO NOTHING RETURNING "id", "name", "embedding", "created_at", "updated_at";`, vs: []any{ sampleDate, sampleVector, @@ -559,7 +559,7 @@ func testCases() []testCaseWithToQueryTestCase { ValueOnUpdateName("powawa"). SameOnUpdateEmbedding() }(), - expected: `INSERT INTO "accounts" ("created_at","embedding","id","name","updated_at") VALUES ($1,$2,$3,$4,$5),($6,$7,$8,$9,$10),($11,$12,$13,$14,$15) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $16 RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","id","name","updated_at") VALUES ($1,$2,$3,$4,$5),($6,$7,$8,$9,$10),($11,$12,$13,$14,$15) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $16 RETURNING "id", "name", "embedding", "created_at", "updated_at";`, vs: []any{ sampleDate, sampleVector, diff --git a/_example/postgresql/accounts.gen.go b/_example/postgresql/accounts.gen.go index e88a18b..1050146 100644 --- a/_example/postgresql/accounts.gen.go +++ b/_example/postgresql/accounts.gen.go @@ -607,7 +607,12 @@ func (q accountInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + columns := strings.Join(accountAllColumns, ", ") + return query + " RETURNING " + columns + ";", vs, nil +} + +func (q accountInsertSQL) rowsNum() int { + return 1 } func (q accountInsertSQL) accountInsertSQLToSqlPg(offset int) (string, int, []any, error) { @@ -629,16 +634,7 @@ func (q accountInsertSQL) accountInsertSQLToSqlPg(offset int) (string, int, []an } func (q accountInsertSQL) Exec(db sqlla.DB) (Account, error) { - query, args, err := q.ToSql() - if err != nil { - return Account{}, err - } - row := db.QueryRow(query, args...) - var pk AccountID - if err := row.Scan(&pk); err != nil { - return Account{}, err - } - return NewAccountSQL().Select().ID(pk).Single(db) + return q.ExecContext(context.Background(), db) } func (q accountInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { @@ -647,11 +643,11 @@ func (q accountInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account return Account{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk AccountID - if err := row.Scan(&pk); err != nil { + result, err := NewAccountSQL().Select().Scan(row) + if err != nil { return Account{}, err } - return NewAccountSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } func (q accountInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { @@ -668,6 +664,7 @@ type accountDefaultInsertHooker interface { } type accountInsertSQLToSqler interface { + rowsNum() int accountInsertSQLToSqlPg(offset int) (string, int, []any, error) } @@ -685,6 +682,10 @@ func (q *accountBulkInsertSQL) Append(iqs ...accountInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *accountBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *accountBulkInsertSQL) accountInsertSQLToSqlPg(offset int) (string, int, []any, error) { if len(q.insertSQLs) == 0 { return "", 0, []any{}, fmt.Errorf("sqlla: This accountBulkInsertSQL's InsertSQL was empty") @@ -721,7 +722,8 @@ func (q *accountBulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + columns := strings.Join(accountAllColumns, ", ") + return query + " RETURNING " + columns + ";", vs, nil } func (q *accountBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { query, args, err := q.ToSql() @@ -733,16 +735,18 @@ func (q *accountBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([] return nil, err } defer rows.Close() - pks := make([]AccountID, 0, len(q.insertSQLs)) + results := make([]Account, 0, len(q.insertSQLs)) + sel := NewAccountSQL().Select() for rows.Next() { - var pk AccountID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewAccountSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } + func (q *accountBulkInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { @@ -767,8 +771,8 @@ func (q accountInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { if err != nil { return "", nil, err } - query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" + columns := strings.Join(accountAllColumns, ", ") + query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns return query + ";", vs, nil } @@ -779,11 +783,11 @@ func (q accountInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db return Account{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk AccountID - if err := row.Scan(&pk); err != nil { + result, err := NewAccountSQL().Select().Scan(row) + if err != nil { return Account{}, err } - return NewAccountSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -908,7 +912,8 @@ func (q accountInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" + columns := strings.Join(accountAllColumns, ", ") + query += " RETURNING " + columns return query + ";", vs, nil } @@ -919,11 +924,11 @@ func (q accountInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db return Account{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk AccountID - if err := row.Scan(&pk); err != nil { + result, err := NewAccountSQL().Select().Scan(row) + if err != nil { return Account{}, err } - return NewAccountSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -957,8 +962,8 @@ func (q accountBulkInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) if err != nil { return "", nil, err } - query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" + columns := strings.Join(accountAllColumns, ", ") + query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns return query + ";", vs, nil } @@ -973,16 +978,17 @@ func (q accountBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context return nil, err } defer rows.Close() - pks := make([]AccountID, 0) + results := make([]Account, 0, q.insertSQL.rowsNum()) + sel := NewAccountSQL().Select() for rows.Next() { - var pk AccountID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewAccountSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } @@ -1114,7 +1120,8 @@ func (q accountBulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" + columns := strings.Join(accountAllColumns, ", ") + query += " RETURNING " + columns return query + ";", vs, nil } @@ -1129,16 +1136,17 @@ func (q accountBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, return nil, err } defer rows.Close() - pks := make([]AccountID, 0) + results := make([]Account, 0, q.insertSQL.rowsNum()) + sel := NewAccountSQL().Select() for rows.Next() { - var pk AccountID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewAccountSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } diff --git a/_example/postgresql/groups.gen.go b/_example/postgresql/groups.gen.go index 29d194e..184cc07 100644 --- a/_example/postgresql/groups.gen.go +++ b/_example/postgresql/groups.gen.go @@ -777,7 +777,12 @@ func (q groupInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + columns := strings.Join(groupAllColumns, ", ") + return query + " RETURNING " + columns + ";", vs, nil +} + +func (q groupInsertSQL) rowsNum() int { + return 1 } func (q groupInsertSQL) groupInsertSQLToSqlPg(offset int) (string, int, []any, error) { @@ -799,16 +804,7 @@ func (q groupInsertSQL) groupInsertSQLToSqlPg(offset int) (string, int, []any, e } func (q groupInsertSQL) Exec(db sqlla.DB) (Group, error) { - query, args, err := q.ToSql() - if err != nil { - return Group{}, err - } - row := db.QueryRow(query, args...) - var pk GroupID - if err := row.Scan(&pk); err != nil { - return Group{}, err - } - return NewGroupSQL().Select().ID(pk).Single(db) + return q.ExecContext(context.Background(), db) } func (q groupInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { @@ -817,11 +813,11 @@ func (q groupInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, er return Group{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk GroupID - if err := row.Scan(&pk); err != nil { + result, err := NewGroupSQL().Select().Scan(row) + if err != nil { return Group{}, err } - return NewGroupSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } func (q groupInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { @@ -838,6 +834,7 @@ type groupDefaultInsertHooker interface { } type groupInsertSQLToSqler interface { + rowsNum() int groupInsertSQLToSqlPg(offset int) (string, int, []any, error) } @@ -855,6 +852,10 @@ func (q *groupBulkInsertSQL) Append(iqs ...groupInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *groupBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *groupBulkInsertSQL) groupInsertSQLToSqlPg(offset int) (string, int, []any, error) { if len(q.insertSQLs) == 0 { return "", 0, []any{}, fmt.Errorf("sqlla: This groupBulkInsertSQL's InsertSQL was empty") @@ -891,7 +892,8 @@ func (q *groupBulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + columns := strings.Join(groupAllColumns, ", ") + return query + " RETURNING " + columns + ";", vs, nil } func (q *groupBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { query, args, err := q.ToSql() @@ -903,16 +905,18 @@ func (q *groupBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Gr return nil, err } defer rows.Close() - pks := make([]GroupID, 0, len(q.insertSQLs)) + results := make([]Group, 0, len(q.insertSQLs)) + sel := NewGroupSQL().Select() for rows.Next() { - var pk GroupID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewGroupSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } + func (q *groupBulkInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { @@ -937,8 +941,8 @@ func (q groupInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { if err != nil { return "", nil, err } - query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" + columns := strings.Join(groupAllColumns, ", ") + query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns return query + ";", vs, nil } @@ -949,11 +953,11 @@ func (q groupInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db s return Group{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk GroupID - if err := row.Scan(&pk); err != nil { + result, err := NewGroupSQL().Select().Scan(row) + if err != nil { return Group{}, err } - return NewGroupSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -1118,7 +1122,8 @@ func (q groupInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" + columns := strings.Join(groupAllColumns, ", ") + query += " RETURNING " + columns return query + ";", vs, nil } @@ -1129,11 +1134,11 @@ func (q groupInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sq return Group{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk GroupID - if err := row.Scan(&pk); err != nil { + result, err := NewGroupSQL().Select().Scan(row) + if err != nil { return Group{}, err } - return NewGroupSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -1167,8 +1172,8 @@ func (q groupBulkInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { if err != nil { return "", nil, err } - query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" + columns := strings.Join(groupAllColumns, ", ") + query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns return query + ";", vs, nil } @@ -1183,16 +1188,17 @@ func (q groupBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, return nil, err } defer rows.Close() - pks := make([]GroupID, 0) + results := make([]Group, 0, q.insertSQL.rowsNum()) + sel := NewGroupSQL().Select() for rows.Next() { - var pk GroupID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewGroupSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } @@ -1364,7 +1370,8 @@ func (q groupBulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" + columns := strings.Join(groupAllColumns, ", ") + query += " RETURNING " + columns return query + ";", vs, nil } @@ -1379,16 +1386,17 @@ func (q groupBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, d return nil, err } defer rows.Close() - pks := make([]GroupID, 0) + results := make([]Group, 0, q.insertSQL.rowsNum()) + sel := NewGroupSQL().Select() for rows.Next() { - var pk GroupID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewGroupSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } diff --git a/_example/postgresql/identities.gen.go b/_example/postgresql/identities.gen.go index b0e8bdb..6386748 100644 --- a/_example/postgresql/identities.gen.go +++ b/_example/postgresql/identities.gen.go @@ -613,7 +613,12 @@ func (q identityInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + columns := strings.Join(identityAllColumns, ", ") + return query + " RETURNING " + columns + ";", vs, nil +} + +func (q identityInsertSQL) rowsNum() int { + return 1 } func (q identityInsertSQL) identityInsertSQLToSqlPg(offset int) (string, int, []any, error) { @@ -635,16 +640,7 @@ func (q identityInsertSQL) identityInsertSQLToSqlPg(offset int) (string, int, [] } func (q identityInsertSQL) Exec(db sqlla.DB) (Identity, error) { - query, args, err := q.ToSql() - if err != nil { - return Identity{}, err - } - row := db.QueryRow(query, args...) - var pk IdentityID - if err := row.Scan(&pk); err != nil { - return Identity{}, err - } - return NewIdentitySQL().Select().ID(pk).Single(db) + return q.ExecContext(context.Background(), db) } func (q identityInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { @@ -653,11 +649,11 @@ func (q identityInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identi return Identity{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk IdentityID - if err := row.Scan(&pk); err != nil { + result, err := NewIdentitySQL().Select().Scan(row) + if err != nil { return Identity{}, err } - return NewIdentitySQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } func (q identityInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { @@ -674,6 +670,7 @@ type identityDefaultInsertHooker interface { } type identityInsertSQLToSqler interface { + rowsNum() int identityInsertSQLToSqlPg(offset int) (string, int, []any, error) } @@ -691,6 +688,10 @@ func (q *identityBulkInsertSQL) Append(iqs ...identityInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *identityBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *identityBulkInsertSQL) identityInsertSQLToSqlPg(offset int) (string, int, []any, error) { if len(q.insertSQLs) == 0 { return "", 0, []any{}, fmt.Errorf("sqlla: This identityBulkInsertSQL's InsertSQL was empty") @@ -727,7 +728,8 @@ func (q *identityBulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + columns := strings.Join(identityAllColumns, ", ") + return query + " RETURNING " + columns + ";", vs, nil } func (q *identityBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { query, args, err := q.ToSql() @@ -739,16 +741,18 @@ func (q *identityBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([ return nil, err } defer rows.Close() - pks := make([]IdentityID, 0, len(q.insertSQLs)) + results := make([]Identity, 0, len(q.insertSQLs)) + sel := NewIdentitySQL().Select() for rows.Next() { - var pk IdentityID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewIdentitySQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } + func (q *identityBulkInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { @@ -773,8 +777,8 @@ func (q identityInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { if err != nil { return "", nil, err } - query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" + columns := strings.Join(identityAllColumns, ", ") + query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns return query + ";", vs, nil } @@ -785,11 +789,11 @@ func (q identityInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, d return Identity{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk IdentityID - if err := row.Scan(&pk); err != nil { + result, err := NewIdentitySQL().Select().Scan(row) + if err != nil { return Identity{}, err } - return NewIdentitySQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -914,7 +918,8 @@ func (q identityInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" + columns := strings.Join(identityAllColumns, ", ") + query += " RETURNING " + columns return query + ";", vs, nil } @@ -925,11 +930,11 @@ func (q identityInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db return Identity{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk IdentityID - if err := row.Scan(&pk); err != nil { + result, err := NewIdentitySQL().Select().Scan(row) + if err != nil { return Identity{}, err } - return NewIdentitySQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -963,8 +968,8 @@ func (q identityBulkInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) if err != nil { return "", nil, err } - query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" + columns := strings.Join(identityAllColumns, ", ") + query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns return query + ";", vs, nil } @@ -979,16 +984,17 @@ func (q identityBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Contex return nil, err } defer rows.Close() - pks := make([]IdentityID, 0) + results := make([]Identity, 0, q.insertSQL.rowsNum()) + sel := NewIdentitySQL().Select() for rows.Next() { - var pk IdentityID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewIdentitySQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } @@ -1120,7 +1126,8 @@ func (q identityBulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" + columns := strings.Join(identityAllColumns, ", ") + query += " RETURNING " + columns return query + ";", vs, nil } @@ -1135,16 +1142,17 @@ func (q identityBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context return nil, err } defer rows.Close() - pks := make([]IdentityID, 0) + results := make([]Identity, 0, q.insertSQL.rowsNum()) + sel := NewIdentitySQL().Select() for rows.Next() { - var pk IdentityID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewIdentitySQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } diff --git a/_example/user.gen.go b/_example/user.gen.go index b57430f..6d2c3da 100644 --- a/_example/user.gen.go +++ b/_example/user.gen.go @@ -697,6 +697,10 @@ func (q userInsertSQL) ToSql() (string, []any, error) { return query + ";", vs, nil } +func (q userInsertSQL) rowsNum() int { + return 1 +} + func (q userInsertSQL) userInsertSQLToSql() (string, []any, error) { var err error var s interface{} = User{} @@ -716,19 +720,7 @@ func (q userInsertSQL) userInsertSQLToSql() (string, []any, error) { } func (q userInsertSQL) Exec(db sqlla.DB) (User, error) { - query, args, err := q.ToSql() - if err != nil { - return User{}, err - } - result, err := db.Exec(query, args...) - if err != nil { - return User{}, err - } - id, err := result.LastInsertId() - if err != nil { - return User{}, err - } - return NewUserSQL().Select().PkColumn(id).Single(db) + return q.ExecContext(context.Background(), db) } func (q userInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (User, error) { @@ -761,6 +753,7 @@ type userDefaultInsertHooker interface { } type userInsertSQLToSqler interface { + rowsNum() int userInsertSQLToSql() (string, []any, error) } @@ -778,6 +771,10 @@ func (q *userBulkInsertSQL) Append(iqs ...userInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *userBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *userBulkInsertSQL) userInsertSQLToSql() (string, []any, error) { if len(q.insertSQLs) == 0 { return "", []any{}, fmt.Errorf("sqlla: This userBulkInsertSQL's InsertSQL was empty") diff --git a/_example/user_external.gen.go b/_example/user_external.gen.go index 48009eb..0ede046 100644 --- a/_example/user_external.gen.go +++ b/_example/user_external.gen.go @@ -607,6 +607,10 @@ func (q userExternalInsertSQL) ToSql() (string, []any, error) { return query + ";", vs, nil } +func (q userExternalInsertSQL) rowsNum() int { + return 1 +} + func (q userExternalInsertSQL) userExternalInsertSQLToSql() (string, []any, error) { var err error var s interface{} = UserExternal{} @@ -626,19 +630,7 @@ func (q userExternalInsertSQL) userExternalInsertSQLToSql() (string, []any, erro } func (q userExternalInsertSQL) Exec(db sqlla.DB) (UserExternal, error) { - query, args, err := q.ToSql() - if err != nil { - return UserExternal{}, err - } - result, err := db.Exec(query, args...) - if err != nil { - return UserExternal{}, err - } - id, err := result.LastInsertId() - if err != nil { - return UserExternal{}, err - } - return NewUserExternalSQL().Select().PkColumn(id).Single(db) + return q.ExecContext(context.Background(), db) } func (q userExternalInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (UserExternal, error) { @@ -671,6 +663,7 @@ type userExternalDefaultInsertHooker interface { } type userExternalInsertSQLToSqler interface { + rowsNum() int userExternalInsertSQLToSql() (string, []any, error) } @@ -688,6 +681,10 @@ func (q *userExternalBulkInsertSQL) Append(iqs ...userExternalInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *userExternalBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *userExternalBulkInsertSQL) userExternalInsertSQLToSql() (string, []any, error) { if len(q.insertSQLs) == 0 { return "", []any{}, fmt.Errorf("sqlla: This userExternalBulkInsertSQL's InsertSQL was empty") diff --git a/_example/user_item.gen.go b/_example/user_item.gen.go index 8cc81d1..89dcde8 100644 --- a/_example/user_item.gen.go +++ b/_example/user_item.gen.go @@ -647,6 +647,10 @@ func (q userItemInsertSQL) ToSql() (string, []any, error) { return query + ";", vs, nil } +func (q userItemInsertSQL) rowsNum() int { + return 1 +} + func (q userItemInsertSQL) userItemInsertSQLToSql() (string, []any, error) { var err error var s interface{} = UserItem{} @@ -666,19 +670,7 @@ func (q userItemInsertSQL) userItemInsertSQLToSql() (string, []any, error) { } func (q userItemInsertSQL) Exec(db sqlla.DB) (UserItem, error) { - query, args, err := q.ToSql() - if err != nil { - return UserItem{}, err - } - result, err := db.Exec(query, args...) - if err != nil { - return UserItem{}, err - } - id, err := result.LastInsertId() - if err != nil { - return UserItem{}, err - } - return NewUserItemSQL().Select().PkColumn(id).Single(db) + return q.ExecContext(context.Background(), db) } func (q userItemInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (UserItem, error) { @@ -711,6 +703,7 @@ type userItemDefaultInsertHooker interface { } type userItemInsertSQLToSqler interface { + rowsNum() int userItemInsertSQLToSql() (string, []any, error) } @@ -728,6 +721,10 @@ func (q *userItemBulkInsertSQL) Append(iqs ...userItemInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *userItemBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *userItemBulkInsertSQL) userItemInsertSQLToSql() (string, []any, error) { if len(q.insertSQLs) == 0 { return "", []any{}, fmt.Errorf("sqlla: This userItemBulkInsertSQL's InsertSQL was empty") diff --git a/_example/user_sns.gen.go b/_example/user_sns.gen.go index f58959f..8b4427c 100644 --- a/_example/user_sns.gen.go +++ b/_example/user_sns.gen.go @@ -567,6 +567,10 @@ func (q userSNSInsertSQL) ToSql() (string, []any, error) { return query + ";", vs, nil } +func (q userSNSInsertSQL) rowsNum() int { + return 1 +} + func (q userSNSInsertSQL) userSNSInsertSQLToSql() (string, []any, error) { var err error var s interface{} = UserSNS{} @@ -586,19 +590,7 @@ func (q userSNSInsertSQL) userSNSInsertSQLToSql() (string, []any, error) { } func (q userSNSInsertSQL) Exec(db sqlla.DB) (UserSNS, error) { - query, args, err := q.ToSql() - if err != nil { - return UserSNS{}, err - } - result, err := db.Exec(query, args...) - if err != nil { - return UserSNS{}, err - } - id, err := result.LastInsertId() - if err != nil { - return UserSNS{}, err - } - return NewUserSNSSQL().Select().PkColumn(id).Single(db) + return q.ExecContext(context.Background(), db) } func (q userSNSInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (UserSNS, error) { @@ -631,6 +623,7 @@ type userSNSDefaultInsertHooker interface { } type userSNSInsertSQLToSqler interface { + rowsNum() int userSNSInsertSQLToSql() (string, []any, error) } @@ -648,6 +641,10 @@ func (q *userSNSBulkInsertSQL) Append(iqs ...userSNSInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *userSNSBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *userSNSBulkInsertSQL) userSNSInsertSQLToSql() (string, []any, error) { if len(q.insertSQLs) == 0 { return "", []any{}, fmt.Errorf("sqlla: This userSNSBulkInsertSQL's InsertSQL was empty") diff --git a/template/insert.tmpl b/template/insert.tmpl index 4c53cb3..daedf0f 100644 --- a/template/insert.tmpl +++ b/template/insert.tmpl @@ -25,13 +25,18 @@ func (q {{ $camelName }}InsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - {{- if and .HasPk (eq (dialect) "postgresql") }} - return query + " RETURNING " + {{ cquoteby .PkColumn.Name }} + ";", vs, nil + {{- if eq (dialect) "postgresql" }} + columns := strings.Join({{ $camelName }}AllColumns, ", ") + return query + " RETURNING " + columns + ";", vs, nil {{- else }} return query + ";", vs, nil {{- end }} } +func (q {{ $camelName }}InsertSQL) rowsNum() int { + return 1 +} + {{ if eq (dialect) "mysql" }} func (q {{ $camelName }}InsertSQL) {{ $camelName }}InsertSQLToSql() (string, []any, error) { {{- end }} @@ -81,40 +86,32 @@ func (q {{ $camelName }}InsertSQL) Exec(db sqlla.DB) ({{ .StructName }}, error) {{- else -}} func (q {{ $camelName }}InsertSQL) Exec(db sqlla.DB) (sql.Result, error) { {{- end }} + return q.ExecContext(context.Background(), db) +} + +{{ if eq (dialect) "postgresql" }} +func (q {{ $camelName }}InsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { query, args, err := q.ToSql() if err != nil { - {{ if .HasPk -}} return {{ .StructName }}{}, err - {{- else }} - return nil, err - {{- end }} } - {{- if not .HasPk }} - result, err := db.Exec(query, args...) - return result, err - {{- else }} - {{- if eq (dialect) "mysql" }} - result, err := db.Exec(query, args...) + row := db.QueryRowContext(ctx, query, args...) + result, err := {{ $constructor }}().Select().Scan(row) if err != nil { return {{ .StructName }}{}, err } - id, err := result.LastInsertId() + return result, nil +} + +func (q {{ $camelName }}InsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { + query, args, err := q.ToSql() if err != nil { - return {{ .StructName }}{}, err - } - return {{ $constructor }}().Select().PkColumn(id).Single(db) - {{- end }} - {{- if eq (dialect) "postgresql" }} - row := db.QueryRow(query, args...) - var pk {{ .PkColumn.TypeName }} - if err := row.Scan(&pk); err != nil { - return {{ .StructName }}{}, err + return nil, err } - return {{ $constructor }}().Select().{{ .PkColumn.MethodName }}(pk).Single(db) - {{- end }} - {{- end }} + result, err := db.ExecContext(ctx, query, args...) + return result, err } - +{{ else }} {{ if .HasPk -}} func (q {{ $camelName }}InsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { {{- else -}} @@ -132,7 +129,6 @@ func (q {{ $camelName }}InsertSQL) ExecContext(ctx context.Context, db sqlla.DB) result, err := db.ExecContext(ctx, query, args...) return result, err {{- else }} - {{- if eq (dialect) "mysql" }} result, err := db.ExecContext(ctx, query, args...) if err != nil { return {{ .StructName }}{}, err @@ -143,15 +139,6 @@ func (q {{ $camelName }}InsertSQL) ExecContext(ctx context.Context, db sqlla.DB) } return {{ $constructor }}().Select().PkColumn(id).SingleContext(ctx, db) {{- end }} - {{- if eq (dialect) "postgresql" }} - row := db.QueryRowContext(ctx, query, args...) - var pk {{ .PkColumn.TypeName }} - if err := row.Scan(&pk); err != nil { - return {{ .StructName }}{}, err - } - return {{ $constructor }}().Select().{{ .PkColumn.MethodName }}(pk).SingleContext(ctx, db) - {{- end }} - {{- end }} } {{ if .HasPk -}} @@ -164,12 +151,14 @@ func (q {{ $camelName }}InsertSQL) ExecContextWithoutSelect(ctx context.Context, return result, err } {{- end }} +{{- end }} type {{ $camelName }}DefaultInsertHooker interface { DefaultInsertHook({{ $camelName }}InsertSQL) ({{ $camelName }}InsertSQL, error) } type {{ $camelName }}InsertSQLToSqler interface { + rowsNum() int {{- if eq (dialect) "mysql" }} {{ $camelName }}InsertSQLToSql() (string, []any, error) {{- end }} @@ -192,6 +181,10 @@ func (q *{{ $camelName }}BulkInsertSQL) Append(iqs ...{{ $camelName }}InsertSQL) q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *{{ $camelName }}BulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + {{ if eq (dialect) "mysql" }} func (q *{{ $camelName }}BulkInsertSQL) {{ $camelName }}InsertSQLToSql() (string, []any, error) { {{- end }} @@ -264,45 +257,49 @@ func (q *{{ $camelName }}BulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - {{- if and .HasPk (eq (dialect) "postgresql") }} - return query + " RETURNING " + {{ cquoteby .PkColumn.Name }} + ";", vs, nil + {{- if eq (dialect) "postgresql" }} + columns := strings.Join({{ $camelName }}AllColumns, ", ") + return query + " RETURNING " + columns + ";", vs, nil {{- else }} return query + ";", vs, nil {{- end }} } -{{- if and .HasPk (eq (dialect) "postgresql") }} +{{- if eq (dialect) "postgresql" }} func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { -{{- else }} -func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{- end }} query, args, err := q.ToSql() if err != nil { return nil, err } - {{- if and .HasPk (eq (dialect) "postgresql") }} rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - pks := make([]{{ .PkColumn.TypeName }}, 0, len(q.insertSQLs)) + results := make([]{{ .StructName }}, 0, len(q.insertSQLs)) + sel := {{ $constructor }}().Select() for rows.Next() { - var pk {{ .PkColumn.TypeName }} - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) + } + return results, nil +} + +func (q *{{ $camelName }}BulkInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { + query, args, err := q.ToSql() + if err != nil { + return nil, err } - return {{ $constructor }}().Select().{{ .PkColumn.MethodName }}In(pks...).AllContext(ctx, db) - {{- else }} result, err := db.ExecContext(ctx, query, args...) return result, err - {{- end }} } +{{- end }} -{{- if and .HasPk (eq (dialect) "postgresql") }} -func (q *{{ $camelName }}BulkInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { +{{- if eq (dialect) "mysql" }} +func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { return nil, err @@ -312,6 +309,7 @@ func (q *{{ $camelName }}BulkInsertSQL) ExecContextWithoutSelect(ctx context.Con } {{- end }} + {{- if eq (dialect) "mysql" }} {{ template "InsertMySQL" . }} {{- end }} diff --git a/template/insert_postgresql.tmpl b/template/insert_postgresql.tmpl index 7ea59c3..7aea215 100644 --- a/template/insert_postgresql.tmpl +++ b/template/insert_postgresql.tmpl @@ -1,17 +1,17 @@ -{{ define "InsertPostgreSQL.ExecContextHasPkSingle" }} +{{ define "InsertPostgreSQL.ExecContextSingle" }} {{- $constructor := printf "New%sSQL" (.Name | toCamel) -}} query, args, err := q.ToSql() if err != nil { return {{ .StructName }}{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk {{ .PkColumn.TypeName }} - if err := row.Scan(&pk); err != nil { + result, err := {{ $constructor }}().Select().Scan(row) + if err != nil { return {{ .StructName }}{}, err } - return {{ $constructor }}().Select().{{ .PkColumn.MethodName }}(pk).SingleContext(ctx, db) + return result, nil {{ end }} -{{ define "InsertPostgreSQL.ExecContextHasPkAll" }} +{{ define "InsertPostgreSQL.ExecContextAll" }} {{- $constructor := printf "New%sSQL" (.Name | toCamel) -}} query, args, err := q.ToSql() if err != nil { @@ -22,16 +22,17 @@ return nil, err } defer rows.Close() - pks := make([]{{ .PkColumn.TypeName }}, 0) + results := make([]{{ .StructName }}, 0, q.insertSQL.rowsNum()) + sel := {{ $constructor }}().Select() for rows.Next() { - var pk {{ .PkColumn.TypeName }} - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return {{ $constructor }}().Select().{{ .PkColumn.MethodName }}In(pks...).AllContext(ctx, db) + return results, nil {{ end }} {{ define "InsertPostgreSQL.ExecContextWithoutSelect" }} query, args, err := q.ToSql() @@ -47,10 +48,8 @@ if err != nil { return "", nil, err } - query += " ON CONFLICT DO NOTHING" - {{- if .HasPk }} - query += " RETURNING " + {{ cquoteby .PkColumn.Name }} - {{- end }} + columns := strings.Join({{ $camelName }}AllColumns, ", ") + query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns return query + ";", vs, nil {{ end }} {{ define "InsertPostgreSQL.DoUpdateToSql" }} @@ -75,9 +74,8 @@ } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - {{- if .HasPk }} - query += " RETURNING " + {{ cquoteby .PkColumn.Name }} - {{- end }} + columns := strings.Join({{ $camelName }}AllColumns, ", ") + query += " RETURNING " + columns return query + ";", vs, nil {{ end }} @@ -100,19 +98,13 @@ func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ToSql() (string, []any, er {{ template "InsertPostgreSQL.DoNothingToSql" . }} } -{{ if .HasPk -}} func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { -{{ template "InsertPostgreSQL.ExecContextHasPkSingle" . }} +{{ template "InsertPostgreSQL.ExecContextSingle" . }} } func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { {{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} } -{{- else -}} -func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} -} -{{- end }} type {{ $camelName }}InsertOnConflictDoUpdateSQL struct { insertSQL {{ $camelName }}InsertSQLToSqler @@ -151,26 +143,19 @@ func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ToSql() (string, []any, err } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - {{- if .HasPk }} - query += " RETURNING " + {{ cquoteby .PkColumn.Name }} - {{- end }} + columns := strings.Join({{ $camelName }}AllColumns, ", ") + query += " RETURNING " + columns return query + ";", vs, nil } -{{ if .HasPk -}} func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { -{{ template "InsertPostgreSQL.ExecContextHasPkSingle" . }} +{{ template "InsertPostgreSQL.ExecContextSingle" . }} } func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { {{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} } -{{- else -}} -func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} -} -{{- end }} type {{ $camelName }}DefaultInsertOnConflictDoUpdateHooker interface { DefaultInsertOnConflictDoUpdateHook({{ $camelName }}InsertOnConflictDoUpdateSQL) ({{ $camelName }}InsertOnConflictDoUpdateSQL, error) @@ -190,19 +175,13 @@ func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ToSql() (string, []any {{ template "InsertPostgreSQL.DoNothingToSql" . }} } -{{ if .HasPk -}} func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { -{{ template "InsertPostgreSQL.ExecContextHasPkAll" . }} +{{ template "InsertPostgreSQL.ExecContextAll" . }} } func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { {{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} } -{{- else -}} -func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} -} -{{- end }} type {{ $camelName }}BulkInsertOnConflictDoUpdateSQL struct { insertSQL {{ $camelName }}InsertSQLToSqler @@ -248,25 +227,18 @@ func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - {{- if .HasPk }} - query += " RETURNING " + {{ cquoteby .PkColumn.Name }} - {{- end }} + columns := strings.Join({{ $camelName }}AllColumns, ", ") + query += " RETURNING " + columns return query + ";", vs, nil } -{{ if .HasPk -}} func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { -{{ template "InsertPostgreSQL.ExecContextHasPkAll" . }} +{{ template "InsertPostgreSQL.ExecContextAll" . }} } func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { {{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} } -{{- else -}} -func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} -} -{{- end }} {{ end }} From 3d7cfdb1b639720dbb1e53079eed62f519238d0e Mon Sep 17 00:00:00 2001 From: mackee Date: Thu, 19 Jun 2025 15:23:10 +0900 Subject: [PATCH 2/3] feat(postgresql): use RETURNING clause in UPDATE to return updated rows Previously, UPDATE methods for PostgreSQL selected updated rows using a separate SELECT by PK after executing the update statement. This change updates the templates and generated code to use the RETURNING clause in PostgreSQL UPDATE statements, allowing the updated rows to be fetched directly within the update query. This approach improves performance and accuracy by returning the actual modified records without requiring an additional SELECT. Test cases have been adjusted accordingly. --- _example/group.gen.go | 12 +------ _example/postgresql/account_test.go | 34 ++++++++++++++++++ _example/postgresql/accounts.gen.go | 39 +++++++++++++-------- _example/postgresql/groups.gen.go | 39 +++++++++++++-------- _example/postgresql/identities.gen.go | 39 +++++++++++++-------- _example/user.gen.go | 12 +------ _example/user_external.gen.go | 12 +------ _example/user_item.gen.go | 12 +------ _example/user_sns.gen.go | 12 +------ template/update.tmpl | 50 +++++++++++++++++++++------ 10 files changed, 151 insertions(+), 110 deletions(-) diff --git a/_example/group.gen.go b/_example/group.gen.go index 2eb16ec..0234792 100644 --- a/_example/group.gen.go +++ b/_example/group.gen.go @@ -690,17 +690,7 @@ func (s Group) Update() groupUpdateSQL { } func (q groupUpdateSQL) Exec(db sqlla.DB) ([]Group, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.groupSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q groupUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { diff --git a/_example/postgresql/account_test.go b/_example/postgresql/account_test.go index 59c9703..c15b68c 100644 --- a/_example/postgresql/account_test.go +++ b/_example/postgresql/account_test.go @@ -205,6 +205,7 @@ type updateQueryTestCase[T any] struct { query updateQuery[T] expected string vs []any + setup func(t *testing.T, db sqlla.DB) expectedResult []T } @@ -225,6 +226,9 @@ func (u updateQueryTestCase[T]) assert(t *testing.T, opts ...cmp.Option) { t.Helper() ctx := context.Background() + if u.setup != nil { + u.setup(t, db) + } got, err := u.query.ExecContext(ctx, db) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -586,15 +590,42 @@ func testCases() []testCaseWithToQueryTestCase { }, updateQueryTestCase[postgresql.Account]{ name: "update", + setup: func(t *testing.T, db sqlla.DB) { + if _, err := postgresql.NewAccountSQL().Insert(). + ValueID(42). + ValueName("foo"). + ValueEmbedding(sampleVector). + ValueCreatedAt(sampleDate). + ValueUpdatedAt(sampleDate). + ExecContextWithoutSelect(t.Context(), db); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }, query: postgresql.NewAccountSQL().Update(). SetName("bar"). SetEmbedding(sampleVector). WhereID(42), expected: `UPDATE "accounts" SET "embedding" = $1, "name" = $2, "updated_at" = $3 WHERE "id" = $4;`, vs: []any{sampleVector, "bar", sampleDate, int64(42)}, + expectedResult: []postgresql.Account{ + {ID: 42, Name: "bar", Embedding: sampleVector, CreatedAt: sampleDate, UpdatedAt: sampleDate}, + }, }, updateQueryTestCase[postgresql.Group]{ name: "update with set null", + setup: func(t *testing.T, db sqlla.DB) { + if _, err := postgresql.NewGroupSQL().Insert(). + ValueID(111). + ValueName("foo"). + ValueLeaderAccountID(42). + ValueSubLeaderAccountID(28). + ValueChildGroupID(43). + ValueCreatedAt(sampleDate). + ValueUpdatedAt(sampleDate). + ExecContextWithoutSelect(t.Context(), db); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }, query: postgresql.NewGroupSQL().Update(). SetLeaderAccountID(42). SetSubLeaderAccountID(28). @@ -602,6 +633,9 @@ func testCases() []testCaseWithToQueryTestCase { WhereID(111), expected: `UPDATE "groups" SET "child_group_id" = $1, "leader_account_id" = $2, "sub_leader_account_id" = $3 WHERE "id" = $4;`, vs: []any{sql.Null[int64]{}, int64(42), int64(28), int64(111)}, + expectedResult: []postgresql.Group{ + {ID: 111, Name: "foo", LeaderAccountID: 42, SubLeaderAccountID: sql.Null[postgresql.AccountID]{Valid: true, V: 28}, ChildGroupID: sql.Null[postgresql.GroupID]{}, CreatedAt: sampleDate, UpdatedAt: sampleDate}, + }, }, deleteQueryTestCase{ name: "delete with where", diff --git a/_example/postgresql/accounts.gen.go b/_example/postgresql/accounts.gen.go index 1050146..0d1a4cf 100644 --- a/_example/postgresql/accounts.gen.go +++ b/_example/postgresql/accounts.gen.go @@ -528,36 +528,45 @@ func (q accountUpdateSQL) ToSql() (string, []interface{}, error) { return query + ";", append(svs, wvs...), nil } +func (q accountUpdateSQL) ToSqlWithReturning() (string, []interface{}, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []interface{}{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil +} + func (s Account) Update() accountUpdateSQL { return NewAccountSQL().Update().WhereID(s.ID) } func (q accountUpdateSQL) Exec(db sqlla.DB) ([]Account, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.accountSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q accountUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } - _, err = db.ExecContext(ctx, query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - qq := q.accountSQL + results := make([]Account, 0, 1) + defer rows.Close() + sel := NewAccountSQL().Select() + for rows.Next() { + result, err := sel.Scan(rows) + if err != nil { + return nil, err + } + results = append(results, result) + } - return qq.Select().AllContext(ctx, db) + return results, nil } type accountDefaultUpdateHooker interface { diff --git a/_example/postgresql/groups.gen.go b/_example/postgresql/groups.gen.go index 184cc07..4c23b19 100644 --- a/_example/postgresql/groups.gen.go +++ b/_example/postgresql/groups.gen.go @@ -678,36 +678,45 @@ func (q groupUpdateSQL) ToSql() (string, []interface{}, error) { return query + ";", append(svs, wvs...), nil } +func (q groupUpdateSQL) ToSqlWithReturning() (string, []interface{}, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []interface{}{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil +} + func (s Group) Update() groupUpdateSQL { return NewGroupSQL().Update().WhereID(s.ID) } func (q groupUpdateSQL) Exec(db sqlla.DB) ([]Group, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.groupSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q groupUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } - _, err = db.ExecContext(ctx, query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - qq := q.groupSQL + results := make([]Group, 0, 1) + defer rows.Close() + sel := NewGroupSQL().Select() + for rows.Next() { + result, err := sel.Scan(rows) + if err != nil { + return nil, err + } + results = append(results, result) + } - return qq.Select().AllContext(ctx, db) + return results, nil } type groupDefaultUpdateHooker interface { diff --git a/_example/postgresql/identities.gen.go b/_example/postgresql/identities.gen.go index 6386748..772b24d 100644 --- a/_example/postgresql/identities.gen.go +++ b/_example/postgresql/identities.gen.go @@ -534,36 +534,45 @@ func (q identityUpdateSQL) ToSql() (string, []interface{}, error) { return query + ";", append(svs, wvs...), nil } +func (q identityUpdateSQL) ToSqlWithReturning() (string, []interface{}, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []interface{}{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil +} + func (s Identity) Update() identityUpdateSQL { return NewIdentitySQL().Update().WhereID(s.ID) } func (q identityUpdateSQL) Exec(db sqlla.DB) ([]Identity, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.identitySQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q identityUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } - _, err = db.ExecContext(ctx, query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - qq := q.identitySQL + results := make([]Identity, 0, 1) + defer rows.Close() + sel := NewIdentitySQL().Select() + for rows.Next() { + result, err := sel.Scan(rows) + if err != nil { + return nil, err + } + results = append(results, result) + } - return qq.Select().AllContext(ctx, db) + return results, nil } type identityDefaultUpdateHooker interface { diff --git a/_example/user.gen.go b/_example/user.gen.go index 6d2c3da..0e33f86 100644 --- a/_example/user.gen.go +++ b/_example/user.gen.go @@ -610,17 +610,7 @@ func (s User) Update() userUpdateSQL { } func (q userUpdateSQL) Exec(db sqlla.DB) ([]User, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.userSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q userUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]User, error) { diff --git a/_example/user_external.gen.go b/_example/user_external.gen.go index 0ede046..0c3b440 100644 --- a/_example/user_external.gen.go +++ b/_example/user_external.gen.go @@ -530,17 +530,7 @@ func (s UserExternal) Update() userExternalUpdateSQL { } func (q userExternalUpdateSQL) Exec(db sqlla.DB) ([]UserExternal, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.userExternalSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q userExternalUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]UserExternal, error) { diff --git a/_example/user_item.gen.go b/_example/user_item.gen.go index 89dcde8..1c7e374 100644 --- a/_example/user_item.gen.go +++ b/_example/user_item.gen.go @@ -565,17 +565,7 @@ func (s UserItem) Update() userItemUpdateSQL { } func (q userItemUpdateSQL) Exec(db sqlla.DB) ([]UserItem, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.userItemSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q userItemUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]UserItem, error) { diff --git a/_example/user_sns.gen.go b/_example/user_sns.gen.go index 8b4427c..2ce5a32 100644 --- a/_example/user_sns.gen.go +++ b/_example/user_sns.gen.go @@ -495,17 +495,7 @@ func (s UserSNS) Update() userSNSUpdateSQL { } func (q userSNSUpdateSQL) Exec(db sqlla.DB) ([]UserSNS, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.userSNSSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q userSNSUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]UserSNS, error) { diff --git a/template/update.tmpl b/template/update.tmpl index 7ddb2b1..df3332f 100644 --- a/template/update.tmpl +++ b/template/update.tmpl @@ -50,23 +50,56 @@ func (q {{ $camelName }}UpdateSQL) ToSql() (string, []interface{}, error) { return query + ";", append(svs, wvs...), nil } -{{- if .HasPk }} +{{- if eq (dialect) "postgresql" }} +func (q {{ $camelName }}UpdateSQL) ToSqlWithReturning() (string, []interface{}, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []interface{}{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join({{ $camelName }}AllColumns, ", ") + return query + ";", args, nil +} + func (s {{ .StructName }}) Update() {{ $camelName }}UpdateSQL { return {{ $constructor }}().Update().Where{{ .PkColumn.Name | toCamel | title }}(s.{{ .PkColumn.FieldName }}) } func (q {{ $camelName }}UpdateSQL) Exec(db sqlla.DB) ([]{{ .StructName }}, error) { - query, args, err := q.ToSql() + return q.ExecContext(context.Background(), db) +} + +func (q {{ $camelName }}UpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } - _, err = db.Exec(query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - qq := q.{{ $camelName }}SQL + results := make([]{{ .StructName }}, 0, 1) + defer rows.Close() + sel := {{ $constructor }}().Select() + for rows.Next() { + result, err := sel.Scan(rows) + if err != nil { + return nil, err + } + results = append(results, result) + } + + return results, nil +} +{{- end }} +{{- if eq (dialect) "mysql" }} +{{- if .HasPk }} +func (s {{ .StructName }}) Update() {{ $camelName }}UpdateSQL { + return {{ $constructor }}().Update().Where{{ .PkColumn.Name | toCamel | title }}(s.{{ .PkColumn.FieldName }}) +} - return qq.Select().All(db) +func (q {{ $camelName }}UpdateSQL) Exec(db sqlla.DB) ([]{{ .StructName }}, error) { + return q.ExecContext(context.Background(), db) } func (q {{ $camelName }}UpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { @@ -84,11 +117,7 @@ func (q {{ $camelName }}UpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) } {{- else }} func (q {{ $camelName }}UpdateSQL) Exec(db sqlla.DB) (sql.Result, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - return db.Exec(query, args...) + return q.ExecContext(context.Background(), db) } func (q {{ $camelName }}UpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { @@ -99,6 +128,7 @@ func (q {{ $camelName }}UpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) return db.ExecContext(ctx, query, args...) } {{- end }} +{{- end }} type {{ $camelName }}DefaultUpdateHooker interface { DefaultUpdateHook({{ $camelName }}UpdateSQL) ({{ $camelName }}UpdateSQL, error) From dbc30fd435766c50d6b1e1484fb6a3c531707d38 Mon Sep 17 00:00:00 2001 From: mackee Date: Thu, 19 Jun 2025 16:51:02 +0900 Subject: [PATCH 3/3] feat(postgresql): separate RETURNING clause from ToSql to ToSqlWithReturning for INSERT BREAKING CHANGE: The ToSql method for INSERT queries in the PostgreSQL dialect no longer automatically appends the RETURNING clause. A new ToSqlWithReturning method has been introduced which appends the RETURNING clause as before. Users who require the RETURNING clause should use ToSqlWithReturning, or explicitly add it themselves when using ToSql. This provides more control over the generated SQL and aligns with user responsibility for query output. --- _example/postgresql/account_test.go | 14 ++-- _example/postgresql/accounts.gen.go | 92 +++++++++++++++++++++------ _example/postgresql/groups.gen.go | 92 +++++++++++++++++++++------ _example/postgresql/identities.gen.go | 92 +++++++++++++++++++++------ template/insert.tmpl | 36 +++++++---- template/insert_postgresql.tmpl | 39 +++++++++--- 6 files changed, 282 insertions(+), 83 deletions(-) diff --git a/_example/postgresql/account_test.go b/_example/postgresql/account_test.go index c15b68c..2681a83 100644 --- a/_example/postgresql/account_test.go +++ b/_example/postgresql/account_test.go @@ -417,7 +417,7 @@ func testCases() []testCaseWithToQueryTestCase { query: postgresql.NewAccountSQL().Insert(). ValueName("foo"). ValueEmbedding(sampleVector), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) RETURNING "id", "name", "embedding", "created_at", "updated_at";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4);`, vs: []any{sampleDate, sampleVector, "foo", sampleDate}, expectedResult: postgresql.Account{ ID: 1, @@ -436,7 +436,7 @@ func testCases() []testCaseWithToQueryTestCase { ValueChildGroupIDIsNull(). ValueCreatedAt(sampleDate). ValueUpdatedAt(sampleDate), - expected: `INSERT INTO "groups" ("child_group_id","created_at","leader_account_id","name","sub_leader_account_id","updated_at") VALUES ($1,$2,$3,$4,$5,$6) RETURNING "id", "name", "leader_account_id", "sub_leader_account_id", "child_group_id", "created_at", "updated_at";`, + expected: `INSERT INTO "groups" ("child_group_id","created_at","leader_account_id","name","sub_leader_account_id","updated_at") VALUES ($1,$2,$3,$4,$5,$6);`, vs: []any{sql.Null[int64]{}, sampleDate, int64(42), "foo", int64(28), sampleDate}, expectedResult: postgresql.Group{ ID: 1, @@ -454,7 +454,7 @@ func testCases() []testCaseWithToQueryTestCase { ValueName("foo"). ValueEmbedding(sampleVector). OnConflictDoNothing(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT DO NOTHING RETURNING "id", "name", "embedding", "created_at", "updated_at";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT DO NOTHING;`, vs: []any{sampleDate, sampleVector, "foo", sampleDate}, expectedResult: postgresql.Account{ ID: 1, @@ -472,7 +472,7 @@ func testCases() []testCaseWithToQueryTestCase { OnConflictDoUpdate("id"). ValueOnUpdateName("powawa"). SameOnUpdateEmbedding(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $5 RETURNING "id", "name", "embedding", "created_at", "updated_at";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $5;`, vs: []any{sampleDate, sampleVector, "foo", sampleDate, "powawa"}, expectedResult: postgresql.Account{ ID: 1, @@ -494,7 +494,7 @@ func testCases() []testCaseWithToQueryTestCase { } return bi }(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12) RETURNING "id", "name", "embedding", "created_at", "updated_at";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12);`, vs: []any{ sampleDate, sampleVector, @@ -527,7 +527,7 @@ func testCases() []testCaseWithToQueryTestCase { } return bi.OnConflictDoNothing() }(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12) ON CONFLICT DO NOTHING RETURNING "id", "name", "embedding", "created_at", "updated_at";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12) ON CONFLICT DO NOTHING;`, vs: []any{ sampleDate, sampleVector, @@ -563,7 +563,7 @@ func testCases() []testCaseWithToQueryTestCase { ValueOnUpdateName("powawa"). SameOnUpdateEmbedding() }(), - expected: `INSERT INTO "accounts" ("created_at","embedding","id","name","updated_at") VALUES ($1,$2,$3,$4,$5),($6,$7,$8,$9,$10),($11,$12,$13,$14,$15) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $16 RETURNING "id", "name", "embedding", "created_at", "updated_at";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","id","name","updated_at") VALUES ($1,$2,$3,$4,$5),($6,$7,$8,$9,$10),($11,$12,$13,$14,$15) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $16;`, vs: []any{ sampleDate, sampleVector, diff --git a/_example/postgresql/accounts.gen.go b/_example/postgresql/accounts.gen.go index 0d1a4cf..422ca2e 100644 --- a/_example/postgresql/accounts.gen.go +++ b/_example/postgresql/accounts.gen.go @@ -616,8 +616,17 @@ func (q accountInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - columns := strings.Join(accountAllColumns, ", ") - return query + " RETURNING " + columns + ";", vs, nil + return query + ";", vs, nil +} + +func (q accountInsertSQL) ToSqlWithReturning() (string, []any, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query, args, nil } func (q accountInsertSQL) rowsNum() int { @@ -647,7 +656,7 @@ func (q accountInsertSQL) Exec(db sqlla.DB) (Account, error) { } func (q accountInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return Account{}, err } @@ -731,11 +740,20 @@ func (q *accountBulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - columns := strings.Join(accountAllColumns, ", ") - return query + " RETURNING " + columns + ";", vs, nil + return query + ";", vs, nil } -func (q *accountBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { +func (q *accountBulkInsertSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil +} + +func (q *accountBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -780,14 +798,24 @@ func (q accountInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { if err != nil { return "", nil, err } - columns := strings.Join(accountAllColumns, ", ") - query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns + query += " ON CONFLICT DO NOTHING" return query + ";", vs, nil } -func (q accountInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { +func (q accountInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil + +} + +func (q accountInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Account{}, err } @@ -921,14 +949,23 @@ func (q accountInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - columns := strings.Join(accountAllColumns, ", ") - query += " RETURNING " + columns return query + ";", vs, nil } -func (q accountInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { +func (q accountInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil + +} + +func (q accountInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Account{}, err } @@ -971,14 +1008,24 @@ func (q accountBulkInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) if err != nil { return "", nil, err } - columns := strings.Join(accountAllColumns, ", ") - query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns + query += " ON CONFLICT DO NOTHING" return query + ";", vs, nil } -func (q accountBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { +func (q accountBulkInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil + +} + +func (q accountBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -1129,14 +1176,23 @@ func (q accountBulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - columns := strings.Join(accountAllColumns, ", ") - query += " RETURNING " + columns return query + ";", vs, nil } -func (q accountBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { +func (q accountBulkInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil + +} + +func (q accountBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } diff --git a/_example/postgresql/groups.gen.go b/_example/postgresql/groups.gen.go index 4c23b19..4c6cea4 100644 --- a/_example/postgresql/groups.gen.go +++ b/_example/postgresql/groups.gen.go @@ -786,8 +786,17 @@ func (q groupInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - columns := strings.Join(groupAllColumns, ", ") - return query + " RETURNING " + columns + ";", vs, nil + return query + ";", vs, nil +} + +func (q groupInsertSQL) ToSqlWithReturning() (string, []any, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query, args, nil } func (q groupInsertSQL) rowsNum() int { @@ -817,7 +826,7 @@ func (q groupInsertSQL) Exec(db sqlla.DB) (Group, error) { } func (q groupInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return Group{}, err } @@ -901,11 +910,20 @@ func (q *groupBulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - columns := strings.Join(groupAllColumns, ", ") - return query + " RETURNING " + columns + ";", vs, nil + return query + ";", vs, nil } -func (q *groupBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { +func (q *groupBulkInsertSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil +} + +func (q *groupBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -950,14 +968,24 @@ func (q groupInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { if err != nil { return "", nil, err } - columns := strings.Join(groupAllColumns, ", ") - query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns + query += " ON CONFLICT DO NOTHING" return query + ";", vs, nil } -func (q groupInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { +func (q groupInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil + +} + +func (q groupInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Group{}, err } @@ -1131,14 +1159,23 @@ func (q groupInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - columns := strings.Join(groupAllColumns, ", ") - query += " RETURNING " + columns return query + ";", vs, nil } -func (q groupInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { +func (q groupInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil + +} + +func (q groupInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Group{}, err } @@ -1181,14 +1218,24 @@ func (q groupBulkInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { if err != nil { return "", nil, err } - columns := strings.Join(groupAllColumns, ", ") - query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns + query += " ON CONFLICT DO NOTHING" return query + ";", vs, nil } -func (q groupBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { +func (q groupBulkInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil + +} + +func (q groupBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -1379,14 +1426,23 @@ func (q groupBulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - columns := strings.Join(groupAllColumns, ", ") - query += " RETURNING " + columns return query + ";", vs, nil } -func (q groupBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { +func (q groupBulkInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil + +} + +func (q groupBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } diff --git a/_example/postgresql/identities.gen.go b/_example/postgresql/identities.gen.go index 772b24d..e7b4fdb 100644 --- a/_example/postgresql/identities.gen.go +++ b/_example/postgresql/identities.gen.go @@ -622,8 +622,17 @@ func (q identityInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - columns := strings.Join(identityAllColumns, ", ") - return query + " RETURNING " + columns + ";", vs, nil + return query + ";", vs, nil +} + +func (q identityInsertSQL) ToSqlWithReturning() (string, []any, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query, args, nil } func (q identityInsertSQL) rowsNum() int { @@ -653,7 +662,7 @@ func (q identityInsertSQL) Exec(db sqlla.DB) (Identity, error) { } func (q identityInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return Identity{}, err } @@ -737,11 +746,20 @@ func (q *identityBulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - columns := strings.Join(identityAllColumns, ", ") - return query + " RETURNING " + columns + ";", vs, nil + return query + ";", vs, nil } -func (q *identityBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { +func (q *identityBulkInsertSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil +} + +func (q *identityBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -786,14 +804,24 @@ func (q identityInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { if err != nil { return "", nil, err } - columns := strings.Join(identityAllColumns, ", ") - query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns + query += " ON CONFLICT DO NOTHING" return query + ";", vs, nil } -func (q identityInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { +func (q identityInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil + +} + +func (q identityInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Identity{}, err } @@ -927,14 +955,23 @@ func (q identityInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - columns := strings.Join(identityAllColumns, ", ") - query += " RETURNING " + columns return query + ";", vs, nil } -func (q identityInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { +func (q identityInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil + +} + +func (q identityInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Identity{}, err } @@ -977,14 +1014,24 @@ func (q identityBulkInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) if err != nil { return "", nil, err } - columns := strings.Join(identityAllColumns, ", ") - query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns + query += " ON CONFLICT DO NOTHING" return query + ";", vs, nil } -func (q identityBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { +func (q identityBulkInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil + +} + +func (q identityBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -1135,14 +1182,23 @@ func (q identityBulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - columns := strings.Join(identityAllColumns, ", ") - query += " RETURNING " + columns return query + ";", vs, nil } -func (q identityBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { +func (q identityBulkInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil + +} + +func (q identityBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } diff --git a/template/insert.tmpl b/template/insert.tmpl index daedf0f..84b2c6b 100644 --- a/template/insert.tmpl +++ b/template/insert.tmpl @@ -25,14 +25,21 @@ func (q {{ $camelName }}InsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - {{- if eq (dialect) "postgresql" }} - columns := strings.Join({{ $camelName }}AllColumns, ", ") - return query + " RETURNING " + columns + ";", vs, nil - {{- else }} return query + ";", vs, nil - {{- end }} } +{{ if eq (dialect) "postgresql" }} +func (q {{ $camelName }}InsertSQL) ToSqlWithReturning() (string, []any, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join({{ $camelName }}AllColumns, ", ") + return query, args, nil +} +{{ end }} + func (q {{ $camelName }}InsertSQL) rowsNum() int { return 1 } @@ -91,7 +98,7 @@ func (q {{ $camelName }}InsertSQL) Exec(db sqlla.DB) (sql.Result, error) { {{ if eq (dialect) "postgresql" }} func (q {{ $camelName }}InsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return {{ .StructName }}{}, err } @@ -257,17 +264,22 @@ func (q *{{ $camelName }}BulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - {{- if eq (dialect) "postgresql" }} - columns := strings.Join({{ $camelName }}AllColumns, ", ") - return query + " RETURNING " + columns + ";", vs, nil - {{- else }} return query + ";", vs, nil - {{- end }} } {{- if eq (dialect) "postgresql" }} -func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { +func (q *{{ $camelName }}BulkInsertSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join({{ $camelName }}AllColumns, ", ") + return query + ";", args, nil +} + +func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } diff --git a/template/insert_postgresql.tmpl b/template/insert_postgresql.tmpl index 7aea215..7337bdc 100644 --- a/template/insert_postgresql.tmpl +++ b/template/insert_postgresql.tmpl @@ -1,6 +1,6 @@ {{ define "InsertPostgreSQL.ExecContextSingle" }} {{- $constructor := printf "New%sSQL" (.Name | toCamel) -}} - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return {{ .StructName }}{}, err } @@ -13,7 +13,7 @@ {{ end }} {{ define "InsertPostgreSQL.ExecContextAll" }} {{- $constructor := printf "New%sSQL" (.Name | toCamel) -}} - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -48,8 +48,7 @@ if err != nil { return "", nil, err } - columns := strings.Join({{ $camelName }}AllColumns, ", ") - query += " ON CONFLICT DO NOTHING" + " RETURNING " + columns + query += " ON CONFLICT DO NOTHING" return query + ";", vs, nil {{ end }} {{ define "InsertPostgreSQL.DoUpdateToSql" }} @@ -74,11 +73,19 @@ } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - columns := strings.Join({{ $camelName }}AllColumns, ", ") - query += " RETURNING " + columns return query + ";", vs, nil {{ end }} +{{ define "InsertPostgreSQL.ToSqlWithReturning" }} +{{- $camelName := .Name | toCamel | untitle -}} + query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join({{ $camelName }}AllColumns, ", ") + return query + ";", args, nil +{{ end }} {{ define "InsertPostgreSQL" }} {{- $camelName := .Name | toCamel | untitle -}} @@ -98,6 +105,10 @@ func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ToSql() (string, []any, er {{ template "InsertPostgreSQL.DoNothingToSql" . }} } +func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { +{{ template "InsertPostgreSQL.ToSqlWithReturning" . }} +} + func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { {{ template "InsertPostgreSQL.ExecContextSingle" . }} } @@ -143,12 +154,14 @@ func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ToSql() (string, []any, err } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - columns := strings.Join({{ $camelName }}AllColumns, ", ") - query += " RETURNING " + columns return query + ";", vs, nil } +func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { +{{ template "InsertPostgreSQL.ToSqlWithReturning" . }} +} + func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { {{ template "InsertPostgreSQL.ExecContextSingle" . }} } @@ -175,6 +188,10 @@ func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ToSql() (string, []any {{ template "InsertPostgreSQL.DoNothingToSql" . }} } +func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { +{{ template "InsertPostgreSQL.ToSqlWithReturning" . }} +} + func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { {{ template "InsertPostgreSQL.ExecContextAll" . }} } @@ -227,12 +244,14 @@ func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - columns := strings.Join({{ $camelName }}AllColumns, ", ") - query += " RETURNING " + columns return query + ";", vs, nil } +func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { +{{ template "InsertPostgreSQL.ToSqlWithReturning" . }} +} + func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { {{ template "InsertPostgreSQL.ExecContextAll" . }} }