diff --git a/_example/group.gen.go b/_example/group.gen.go index fb09442..0234792 100644 --- a/_example/group.gen.go +++ b/_example/group.gen.go @@ -690,17 +690,7 @@ func (s Group) Update() groupUpdateSQL { } func (q groupUpdateSQL) Exec(db sqlla.DB) ([]Group, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.groupSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q groupUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { @@ -787,6 +777,10 @@ func (q groupInsertSQL) ToSql() (string, []any, error) { return query + ";", vs, nil } +func (q groupInsertSQL) rowsNum() int { + return 1 +} + func (q groupInsertSQL) groupInsertSQLToSql() (string, []any, error) { var err error var s interface{} = Group{} @@ -806,19 +800,7 @@ func (q groupInsertSQL) groupInsertSQLToSql() (string, []any, error) { } func (q groupInsertSQL) Exec(db sqlla.DB) (Group, error) { - query, args, err := q.ToSql() - if err != nil { - return Group{}, err - } - result, err := db.Exec(query, args...) - if err != nil { - return Group{}, err - } - id, err := result.LastInsertId() - if err != nil { - return Group{}, err - } - return NewGroupSQL().Select().PkColumn(id).Single(db) + return q.ExecContext(context.Background(), db) } func (q groupInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { @@ -851,6 +833,7 @@ type groupDefaultInsertHooker interface { } type groupInsertSQLToSqler interface { + rowsNum() int groupInsertSQLToSql() (string, []any, error) } @@ -868,6 +851,10 @@ func (q *groupBulkInsertSQL) Append(iqs ...groupInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *groupBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *groupBulkInsertSQL) groupInsertSQLToSql() (string, []any, error) { if len(q.insertSQLs) == 0 { return "", []any{}, fmt.Errorf("sqlla: This groupBulkInsertSQL's InsertSQL was empty") diff --git a/_example/postgresql/account_test.go b/_example/postgresql/account_test.go index 2487b56..2681a83 100644 --- a/_example/postgresql/account_test.go +++ b/_example/postgresql/account_test.go @@ -205,6 +205,7 @@ type updateQueryTestCase[T any] struct { query updateQuery[T] expected string vs []any + setup func(t *testing.T, db sqlla.DB) expectedResult []T } @@ -225,6 +226,9 @@ func (u updateQueryTestCase[T]) assert(t *testing.T, opts ...cmp.Option) { t.Helper() ctx := context.Background() + if u.setup != nil { + u.setup(t, db) + } got, err := u.query.ExecContext(ctx, db) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -413,7 +417,7 @@ func testCases() []testCaseWithToQueryTestCase { query: postgresql.NewAccountSQL().Insert(). ValueName("foo"). ValueEmbedding(sampleVector), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4);`, vs: []any{sampleDate, sampleVector, "foo", sampleDate}, expectedResult: postgresql.Account{ ID: 1, @@ -432,7 +436,7 @@ func testCases() []testCaseWithToQueryTestCase { ValueChildGroupIDIsNull(). ValueCreatedAt(sampleDate). ValueUpdatedAt(sampleDate), - expected: `INSERT INTO "groups" ("child_group_id","created_at","leader_account_id","name","sub_leader_account_id","updated_at") VALUES ($1,$2,$3,$4,$5,$6) RETURNING "id";`, + expected: `INSERT INTO "groups" ("child_group_id","created_at","leader_account_id","name","sub_leader_account_id","updated_at") VALUES ($1,$2,$3,$4,$5,$6);`, vs: []any{sql.Null[int64]{}, sampleDate, int64(42), "foo", int64(28), sampleDate}, expectedResult: postgresql.Group{ ID: 1, @@ -450,7 +454,7 @@ func testCases() []testCaseWithToQueryTestCase { ValueName("foo"). ValueEmbedding(sampleVector). OnConflictDoNothing(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT DO NOTHING RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT DO NOTHING;`, vs: []any{sampleDate, sampleVector, "foo", sampleDate}, expectedResult: postgresql.Account{ ID: 1, @@ -468,7 +472,7 @@ func testCases() []testCaseWithToQueryTestCase { OnConflictDoUpdate("id"). ValueOnUpdateName("powawa"). SameOnUpdateEmbedding(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $5 RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $5;`, vs: []any{sampleDate, sampleVector, "foo", sampleDate, "powawa"}, expectedResult: postgresql.Account{ ID: 1, @@ -490,7 +494,7 @@ func testCases() []testCaseWithToQueryTestCase { } return bi }(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12) RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12);`, vs: []any{ sampleDate, sampleVector, @@ -523,7 +527,7 @@ func testCases() []testCaseWithToQueryTestCase { } return bi.OnConflictDoNothing() }(), - expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12) ON CONFLICT DO NOTHING RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","name","updated_at") VALUES ($1,$2,$3,$4),($5,$6,$7,$8),($9,$10,$11,$12) ON CONFLICT DO NOTHING;`, vs: []any{ sampleDate, sampleVector, @@ -559,7 +563,7 @@ func testCases() []testCaseWithToQueryTestCase { ValueOnUpdateName("powawa"). SameOnUpdateEmbedding() }(), - expected: `INSERT INTO "accounts" ("created_at","embedding","id","name","updated_at") VALUES ($1,$2,$3,$4,$5),($6,$7,$8,$9,$10),($11,$12,$13,$14,$15) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $16 RETURNING "id";`, + expected: `INSERT INTO "accounts" ("created_at","embedding","id","name","updated_at") VALUES ($1,$2,$3,$4,$5),($6,$7,$8,$9,$10),($11,$12,$13,$14,$15) ON CONFLICT (id) DO UPDATE SET "embedding" = "excluded"."embedding", "name" = $16;`, vs: []any{ sampleDate, sampleVector, @@ -586,15 +590,42 @@ func testCases() []testCaseWithToQueryTestCase { }, updateQueryTestCase[postgresql.Account]{ name: "update", + setup: func(t *testing.T, db sqlla.DB) { + if _, err := postgresql.NewAccountSQL().Insert(). + ValueID(42). + ValueName("foo"). + ValueEmbedding(sampleVector). + ValueCreatedAt(sampleDate). + ValueUpdatedAt(sampleDate). + ExecContextWithoutSelect(t.Context(), db); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }, query: postgresql.NewAccountSQL().Update(). SetName("bar"). SetEmbedding(sampleVector). WhereID(42), expected: `UPDATE "accounts" SET "embedding" = $1, "name" = $2, "updated_at" = $3 WHERE "id" = $4;`, vs: []any{sampleVector, "bar", sampleDate, int64(42)}, + expectedResult: []postgresql.Account{ + {ID: 42, Name: "bar", Embedding: sampleVector, CreatedAt: sampleDate, UpdatedAt: sampleDate}, + }, }, updateQueryTestCase[postgresql.Group]{ name: "update with set null", + setup: func(t *testing.T, db sqlla.DB) { + if _, err := postgresql.NewGroupSQL().Insert(). + ValueID(111). + ValueName("foo"). + ValueLeaderAccountID(42). + ValueSubLeaderAccountID(28). + ValueChildGroupID(43). + ValueCreatedAt(sampleDate). + ValueUpdatedAt(sampleDate). + ExecContextWithoutSelect(t.Context(), db); err != nil { + t.Fatalf("unexpected error: %v", err) + } + }, query: postgresql.NewGroupSQL().Update(). SetLeaderAccountID(42). SetSubLeaderAccountID(28). @@ -602,6 +633,9 @@ func testCases() []testCaseWithToQueryTestCase { WhereID(111), expected: `UPDATE "groups" SET "child_group_id" = $1, "leader_account_id" = $2, "sub_leader_account_id" = $3 WHERE "id" = $4;`, vs: []any{sql.Null[int64]{}, int64(42), int64(28), int64(111)}, + expectedResult: []postgresql.Group{ + {ID: 111, Name: "foo", LeaderAccountID: 42, SubLeaderAccountID: sql.Null[postgresql.AccountID]{Valid: true, V: 28}, ChildGroupID: sql.Null[postgresql.GroupID]{}, CreatedAt: sampleDate, UpdatedAt: sampleDate}, + }, }, deleteQueryTestCase{ name: "delete with where", diff --git a/_example/postgresql/accounts.gen.go b/_example/postgresql/accounts.gen.go index e88a18b..422ca2e 100644 --- a/_example/postgresql/accounts.gen.go +++ b/_example/postgresql/accounts.gen.go @@ -528,36 +528,45 @@ func (q accountUpdateSQL) ToSql() (string, []interface{}, error) { return query + ";", append(svs, wvs...), nil } +func (q accountUpdateSQL) ToSqlWithReturning() (string, []interface{}, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []interface{}{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil +} + func (s Account) Update() accountUpdateSQL { return NewAccountSQL().Update().WhereID(s.ID) } func (q accountUpdateSQL) Exec(db sqlla.DB) ([]Account, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.accountSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q accountUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } - _, err = db.ExecContext(ctx, query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - qq := q.accountSQL + results := make([]Account, 0, 1) + defer rows.Close() + sel := NewAccountSQL().Select() + for rows.Next() { + result, err := sel.Scan(rows) + if err != nil { + return nil, err + } + results = append(results, result) + } - return qq.Select().AllContext(ctx, db) + return results, nil } type accountDefaultUpdateHooker interface { @@ -607,7 +616,21 @@ func (q accountInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + return query + ";", vs, nil +} + +func (q accountInsertSQL) ToSqlWithReturning() (string, []any, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query, args, nil +} + +func (q accountInsertSQL) rowsNum() int { + return 1 } func (q accountInsertSQL) accountInsertSQLToSqlPg(offset int) (string, int, []any, error) { @@ -629,29 +652,20 @@ func (q accountInsertSQL) accountInsertSQLToSqlPg(offset int) (string, int, []an } func (q accountInsertSQL) Exec(db sqlla.DB) (Account, error) { - query, args, err := q.ToSql() - if err != nil { - return Account{}, err - } - row := db.QueryRow(query, args...) - var pk AccountID - if err := row.Scan(&pk); err != nil { - return Account{}, err - } - return NewAccountSQL().Select().ID(pk).Single(db) + return q.ExecContext(context.Background(), db) } func (q accountInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return Account{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk AccountID - if err := row.Scan(&pk); err != nil { + result, err := NewAccountSQL().Select().Scan(row) + if err != nil { return Account{}, err } - return NewAccountSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } func (q accountInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { @@ -668,6 +682,7 @@ type accountDefaultInsertHooker interface { } type accountInsertSQLToSqler interface { + rowsNum() int accountInsertSQLToSqlPg(offset int) (string, int, []any, error) } @@ -685,6 +700,10 @@ func (q *accountBulkInsertSQL) Append(iqs ...accountInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *accountBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *accountBulkInsertSQL) accountInsertSQLToSqlPg(offset int) (string, int, []any, error) { if len(q.insertSQLs) == 0 { return "", 0, []any{}, fmt.Errorf("sqlla: This accountBulkInsertSQL's InsertSQL was empty") @@ -721,10 +740,20 @@ func (q *accountBulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + return query + ";", vs, nil } -func (q *accountBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { +func (q *accountBulkInsertSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil +} + +func (q *accountBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -733,16 +762,18 @@ func (q *accountBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([] return nil, err } defer rows.Close() - pks := make([]AccountID, 0, len(q.insertSQLs)) + results := make([]Account, 0, len(q.insertSQLs)) + sel := NewAccountSQL().Select() for rows.Next() { - var pk AccountID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewAccountSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } + func (q *accountBulkInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { @@ -768,22 +799,32 @@ func (q accountInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { return "", nil, err } query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q accountInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { +func (q accountInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil + +} + +func (q accountInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Account{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk AccountID - if err := row.Scan(&pk); err != nil { + result, err := NewAccountSQL().Select().Scan(row) + if err != nil { return Account{}, err } - return NewAccountSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -908,22 +949,32 @@ func (q accountInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q accountInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { +func (q accountInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil + +} + +func (q accountInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Account, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Account{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk AccountID - if err := row.Scan(&pk); err != nil { + result, err := NewAccountSQL().Select().Scan(row) + if err != nil { return Account{}, err } - return NewAccountSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -958,13 +1009,23 @@ func (q accountBulkInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) return "", nil, err } query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q accountBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { +func (q accountBulkInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil + +} + +func (q accountBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -973,16 +1034,17 @@ func (q accountBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context return nil, err } defer rows.Close() - pks := make([]AccountID, 0) + results := make([]Account, 0, q.insertSQL.rowsNum()) + sel := NewAccountSQL().Select() for rows.Next() { - var pk AccountID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewAccountSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } @@ -1114,13 +1176,23 @@ func (q accountBulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q accountBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { +func (q accountBulkInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(accountAllColumns, ", ") + return query + ";", args, nil + +} + +func (q accountBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Account, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -1129,16 +1201,17 @@ func (q accountBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, return nil, err } defer rows.Close() - pks := make([]AccountID, 0) + results := make([]Account, 0, q.insertSQL.rowsNum()) + sel := NewAccountSQL().Select() for rows.Next() { - var pk AccountID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewAccountSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } diff --git a/_example/postgresql/groups.gen.go b/_example/postgresql/groups.gen.go index 29d194e..4c6cea4 100644 --- a/_example/postgresql/groups.gen.go +++ b/_example/postgresql/groups.gen.go @@ -678,36 +678,45 @@ func (q groupUpdateSQL) ToSql() (string, []interface{}, error) { return query + ";", append(svs, wvs...), nil } +func (q groupUpdateSQL) ToSqlWithReturning() (string, []interface{}, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []interface{}{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil +} + func (s Group) Update() groupUpdateSQL { return NewGroupSQL().Update().WhereID(s.ID) } func (q groupUpdateSQL) Exec(db sqlla.DB) ([]Group, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.groupSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q groupUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } - _, err = db.ExecContext(ctx, query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - qq := q.groupSQL + results := make([]Group, 0, 1) + defer rows.Close() + sel := NewGroupSQL().Select() + for rows.Next() { + result, err := sel.Scan(rows) + if err != nil { + return nil, err + } + results = append(results, result) + } - return qq.Select().AllContext(ctx, db) + return results, nil } type groupDefaultUpdateHooker interface { @@ -777,7 +786,21 @@ func (q groupInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + return query + ";", vs, nil +} + +func (q groupInsertSQL) ToSqlWithReturning() (string, []any, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query, args, nil +} + +func (q groupInsertSQL) rowsNum() int { + return 1 } func (q groupInsertSQL) groupInsertSQLToSqlPg(offset int) (string, int, []any, error) { @@ -799,29 +822,20 @@ func (q groupInsertSQL) groupInsertSQLToSqlPg(offset int) (string, int, []any, e } func (q groupInsertSQL) Exec(db sqlla.DB) (Group, error) { - query, args, err := q.ToSql() - if err != nil { - return Group{}, err - } - row := db.QueryRow(query, args...) - var pk GroupID - if err := row.Scan(&pk); err != nil { - return Group{}, err - } - return NewGroupSQL().Select().ID(pk).Single(db) + return q.ExecContext(context.Background(), db) } func (q groupInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return Group{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk GroupID - if err := row.Scan(&pk); err != nil { + result, err := NewGroupSQL().Select().Scan(row) + if err != nil { return Group{}, err } - return NewGroupSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } func (q groupInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { @@ -838,6 +852,7 @@ type groupDefaultInsertHooker interface { } type groupInsertSQLToSqler interface { + rowsNum() int groupInsertSQLToSqlPg(offset int) (string, int, []any, error) } @@ -855,6 +870,10 @@ func (q *groupBulkInsertSQL) Append(iqs ...groupInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *groupBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *groupBulkInsertSQL) groupInsertSQLToSqlPg(offset int) (string, int, []any, error) { if len(q.insertSQLs) == 0 { return "", 0, []any{}, fmt.Errorf("sqlla: This groupBulkInsertSQL's InsertSQL was empty") @@ -891,10 +910,20 @@ func (q *groupBulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + return query + ";", vs, nil } -func (q *groupBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { +func (q *groupBulkInsertSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil +} + +func (q *groupBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -903,16 +932,18 @@ func (q *groupBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Gr return nil, err } defer rows.Close() - pks := make([]GroupID, 0, len(q.insertSQLs)) + results := make([]Group, 0, len(q.insertSQLs)) + sel := NewGroupSQL().Select() for rows.Next() { - var pk GroupID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewGroupSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } + func (q *groupBulkInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { @@ -938,22 +969,32 @@ func (q groupInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { return "", nil, err } query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q groupInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { +func (q groupInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil + +} + +func (q groupInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Group{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk GroupID - if err := row.Scan(&pk); err != nil { + result, err := NewGroupSQL().Select().Scan(row) + if err != nil { return Group{}, err } - return NewGroupSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -1118,22 +1159,32 @@ func (q groupInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q groupInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { +func (q groupInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil + +} + +func (q groupInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Group, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Group{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk GroupID - if err := row.Scan(&pk); err != nil { + result, err := NewGroupSQL().Select().Scan(row) + if err != nil { return Group{}, err } - return NewGroupSQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -1168,13 +1219,23 @@ func (q groupBulkInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { return "", nil, err } query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q groupBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { +func (q groupBulkInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil + +} + +func (q groupBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -1183,16 +1244,17 @@ func (q groupBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, return nil, err } defer rows.Close() - pks := make([]GroupID, 0) + results := make([]Group, 0, q.insertSQL.rowsNum()) + sel := NewGroupSQL().Select() for rows.Next() { - var pk GroupID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewGroupSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } @@ -1364,13 +1426,23 @@ func (q groupBulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q groupBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { +func (q groupBulkInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(groupAllColumns, ", ") + return query + ";", args, nil + +} + +func (q groupBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Group, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -1379,16 +1451,17 @@ func (q groupBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, d return nil, err } defer rows.Close() - pks := make([]GroupID, 0) + results := make([]Group, 0, q.insertSQL.rowsNum()) + sel := NewGroupSQL().Select() for rows.Next() { - var pk GroupID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewGroupSQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } diff --git a/_example/postgresql/identities.gen.go b/_example/postgresql/identities.gen.go index b0e8bdb..e7b4fdb 100644 --- a/_example/postgresql/identities.gen.go +++ b/_example/postgresql/identities.gen.go @@ -534,36 +534,45 @@ func (q identityUpdateSQL) ToSql() (string, []interface{}, error) { return query + ";", append(svs, wvs...), nil } +func (q identityUpdateSQL) ToSqlWithReturning() (string, []interface{}, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []interface{}{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil +} + func (s Identity) Update() identityUpdateSQL { return NewIdentitySQL().Update().WhereID(s.ID) } func (q identityUpdateSQL) Exec(db sqlla.DB) ([]Identity, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.identitySQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q identityUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } - _, err = db.ExecContext(ctx, query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - qq := q.identitySQL + results := make([]Identity, 0, 1) + defer rows.Close() + sel := NewIdentitySQL().Select() + for rows.Next() { + result, err := sel.Scan(rows) + if err != nil { + return nil, err + } + results = append(results, result) + } - return qq.Select().AllContext(ctx, db) + return results, nil } type identityDefaultUpdateHooker interface { @@ -613,7 +622,21 @@ func (q identityInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + return query + ";", vs, nil +} + +func (q identityInsertSQL) ToSqlWithReturning() (string, []any, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query, args, nil +} + +func (q identityInsertSQL) rowsNum() int { + return 1 } func (q identityInsertSQL) identityInsertSQLToSqlPg(offset int) (string, int, []any, error) { @@ -635,29 +658,20 @@ func (q identityInsertSQL) identityInsertSQLToSqlPg(offset int) (string, int, [] } func (q identityInsertSQL) Exec(db sqlla.DB) (Identity, error) { - query, args, err := q.ToSql() - if err != nil { - return Identity{}, err - } - row := db.QueryRow(query, args...) - var pk IdentityID - if err := row.Scan(&pk); err != nil { - return Identity{}, err - } - return NewIdentitySQL().Select().ID(pk).Single(db) + return q.ExecContext(context.Background(), db) } func (q identityInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return Identity{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk IdentityID - if err := row.Scan(&pk); err != nil { + result, err := NewIdentitySQL().Select().Scan(row) + if err != nil { return Identity{}, err } - return NewIdentitySQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } func (q identityInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { @@ -674,6 +688,7 @@ type identityDefaultInsertHooker interface { } type identityInsertSQLToSqler interface { + rowsNum() int identityInsertSQLToSqlPg(offset int) (string, int, []any, error) } @@ -691,6 +706,10 @@ func (q *identityBulkInsertSQL) Append(iqs ...identityInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *identityBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *identityBulkInsertSQL) identityInsertSQLToSqlPg(offset int) (string, int, []any, error) { if len(q.insertSQLs) == 0 { return "", 0, []any{}, fmt.Errorf("sqlla: This identityBulkInsertSQL's InsertSQL was empty") @@ -727,10 +746,20 @@ func (q *identityBulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - return query + " RETURNING " + "\"id\"" + ";", vs, nil + return query + ";", vs, nil } -func (q *identityBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { +func (q *identityBulkInsertSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil +} + +func (q *identityBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -739,16 +768,18 @@ func (q *identityBulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([ return nil, err } defer rows.Close() - pks := make([]IdentityID, 0, len(q.insertSQLs)) + results := make([]Identity, 0, len(q.insertSQLs)) + sel := NewIdentitySQL().Select() for rows.Next() { - var pk IdentityID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewIdentitySQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } + func (q *identityBulkInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { @@ -774,22 +805,32 @@ func (q identityInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) { return "", nil, err } query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q identityInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { +func (q identityInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil + +} + +func (q identityInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Identity{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk IdentityID - if err := row.Scan(&pk); err != nil { + result, err := NewIdentitySQL().Select().Scan(row) + if err != nil { return Identity{}, err } - return NewIdentitySQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -914,22 +955,32 @@ func (q identityInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) { } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q identityInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { +func (q identityInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil + +} + +func (q identityInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (Identity, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return Identity{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk IdentityID - if err := row.Scan(&pk); err != nil { + result, err := NewIdentitySQL().Select().Scan(row) + if err != nil { return Identity{}, err } - return NewIdentitySQL().Select().ID(pk).SingleContext(ctx, db) + return result, nil } @@ -964,13 +1015,23 @@ func (q identityBulkInsertOnConflictDoNothingSQL) ToSql() (string, []any, error) return "", nil, err } query += " ON CONFLICT DO NOTHING" - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q identityBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { +func (q identityBulkInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil + +} + +func (q identityBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -979,16 +1040,17 @@ func (q identityBulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Contex return nil, err } defer rows.Close() - pks := make([]IdentityID, 0) + results := make([]Identity, 0, q.insertSQL.rowsNum()) + sel := NewIdentitySQL().Select() for rows.Next() { - var pk IdentityID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewIdentitySQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } @@ -1120,13 +1182,23 @@ func (q identityBulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, error) } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - query += " RETURNING " + "\"id\"" return query + ";", vs, nil } -func (q identityBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { +func (q identityBulkInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join(identityAllColumns, ", ") + return query + ";", args, nil + +} + +func (q identityBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]Identity, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -1135,16 +1207,17 @@ func (q identityBulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context return nil, err } defer rows.Close() - pks := make([]IdentityID, 0) + results := make([]Identity, 0, q.insertSQL.rowsNum()) + sel := NewIdentitySQL().Select() for rows.Next() { - var pk IdentityID - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return NewIdentitySQL().Select().IDIn(pks...).AllContext(ctx, db) + return results, nil } diff --git a/_example/user.gen.go b/_example/user.gen.go index b57430f..0e33f86 100644 --- a/_example/user.gen.go +++ b/_example/user.gen.go @@ -610,17 +610,7 @@ func (s User) Update() userUpdateSQL { } func (q userUpdateSQL) Exec(db sqlla.DB) ([]User, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.userSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q userUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]User, error) { @@ -697,6 +687,10 @@ func (q userInsertSQL) ToSql() (string, []any, error) { return query + ";", vs, nil } +func (q userInsertSQL) rowsNum() int { + return 1 +} + func (q userInsertSQL) userInsertSQLToSql() (string, []any, error) { var err error var s interface{} = User{} @@ -716,19 +710,7 @@ func (q userInsertSQL) userInsertSQLToSql() (string, []any, error) { } func (q userInsertSQL) Exec(db sqlla.DB) (User, error) { - query, args, err := q.ToSql() - if err != nil { - return User{}, err - } - result, err := db.Exec(query, args...) - if err != nil { - return User{}, err - } - id, err := result.LastInsertId() - if err != nil { - return User{}, err - } - return NewUserSQL().Select().PkColumn(id).Single(db) + return q.ExecContext(context.Background(), db) } func (q userInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (User, error) { @@ -761,6 +743,7 @@ type userDefaultInsertHooker interface { } type userInsertSQLToSqler interface { + rowsNum() int userInsertSQLToSql() (string, []any, error) } @@ -778,6 +761,10 @@ func (q *userBulkInsertSQL) Append(iqs ...userInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *userBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *userBulkInsertSQL) userInsertSQLToSql() (string, []any, error) { if len(q.insertSQLs) == 0 { return "", []any{}, fmt.Errorf("sqlla: This userBulkInsertSQL's InsertSQL was empty") diff --git a/_example/user_external.gen.go b/_example/user_external.gen.go index 48009eb..0c3b440 100644 --- a/_example/user_external.gen.go +++ b/_example/user_external.gen.go @@ -530,17 +530,7 @@ func (s UserExternal) Update() userExternalUpdateSQL { } func (q userExternalUpdateSQL) Exec(db sqlla.DB) ([]UserExternal, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.userExternalSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q userExternalUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]UserExternal, error) { @@ -607,6 +597,10 @@ func (q userExternalInsertSQL) ToSql() (string, []any, error) { return query + ";", vs, nil } +func (q userExternalInsertSQL) rowsNum() int { + return 1 +} + func (q userExternalInsertSQL) userExternalInsertSQLToSql() (string, []any, error) { var err error var s interface{} = UserExternal{} @@ -626,19 +620,7 @@ func (q userExternalInsertSQL) userExternalInsertSQLToSql() (string, []any, erro } func (q userExternalInsertSQL) Exec(db sqlla.DB) (UserExternal, error) { - query, args, err := q.ToSql() - if err != nil { - return UserExternal{}, err - } - result, err := db.Exec(query, args...) - if err != nil { - return UserExternal{}, err - } - id, err := result.LastInsertId() - if err != nil { - return UserExternal{}, err - } - return NewUserExternalSQL().Select().PkColumn(id).Single(db) + return q.ExecContext(context.Background(), db) } func (q userExternalInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (UserExternal, error) { @@ -671,6 +653,7 @@ type userExternalDefaultInsertHooker interface { } type userExternalInsertSQLToSqler interface { + rowsNum() int userExternalInsertSQLToSql() (string, []any, error) } @@ -688,6 +671,10 @@ func (q *userExternalBulkInsertSQL) Append(iqs ...userExternalInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *userExternalBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *userExternalBulkInsertSQL) userExternalInsertSQLToSql() (string, []any, error) { if len(q.insertSQLs) == 0 { return "", []any{}, fmt.Errorf("sqlla: This userExternalBulkInsertSQL's InsertSQL was empty") diff --git a/_example/user_item.gen.go b/_example/user_item.gen.go index 8cc81d1..1c7e374 100644 --- a/_example/user_item.gen.go +++ b/_example/user_item.gen.go @@ -565,17 +565,7 @@ func (s UserItem) Update() userItemUpdateSQL { } func (q userItemUpdateSQL) Exec(db sqlla.DB) ([]UserItem, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.userItemSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q userItemUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]UserItem, error) { @@ -647,6 +637,10 @@ func (q userItemInsertSQL) ToSql() (string, []any, error) { return query + ";", vs, nil } +func (q userItemInsertSQL) rowsNum() int { + return 1 +} + func (q userItemInsertSQL) userItemInsertSQLToSql() (string, []any, error) { var err error var s interface{} = UserItem{} @@ -666,19 +660,7 @@ func (q userItemInsertSQL) userItemInsertSQLToSql() (string, []any, error) { } func (q userItemInsertSQL) Exec(db sqlla.DB) (UserItem, error) { - query, args, err := q.ToSql() - if err != nil { - return UserItem{}, err - } - result, err := db.Exec(query, args...) - if err != nil { - return UserItem{}, err - } - id, err := result.LastInsertId() - if err != nil { - return UserItem{}, err - } - return NewUserItemSQL().Select().PkColumn(id).Single(db) + return q.ExecContext(context.Background(), db) } func (q userItemInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (UserItem, error) { @@ -711,6 +693,7 @@ type userItemDefaultInsertHooker interface { } type userItemInsertSQLToSqler interface { + rowsNum() int userItemInsertSQLToSql() (string, []any, error) } @@ -728,6 +711,10 @@ func (q *userItemBulkInsertSQL) Append(iqs ...userItemInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *userItemBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *userItemBulkInsertSQL) userItemInsertSQLToSql() (string, []any, error) { if len(q.insertSQLs) == 0 { return "", []any{}, fmt.Errorf("sqlla: This userItemBulkInsertSQL's InsertSQL was empty") diff --git a/_example/user_sns.gen.go b/_example/user_sns.gen.go index f58959f..2ce5a32 100644 --- a/_example/user_sns.gen.go +++ b/_example/user_sns.gen.go @@ -495,17 +495,7 @@ func (s UserSNS) Update() userSNSUpdateSQL { } func (q userSNSUpdateSQL) Exec(db sqlla.DB) ([]UserSNS, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - _, err = db.Exec(query, args...) - if err != nil { - return nil, err - } - qq := q.userSNSSQL - - return qq.Select().All(db) + return q.ExecContext(context.Background(), db) } func (q userSNSUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]UserSNS, error) { @@ -567,6 +557,10 @@ func (q userSNSInsertSQL) ToSql() (string, []any, error) { return query + ";", vs, nil } +func (q userSNSInsertSQL) rowsNum() int { + return 1 +} + func (q userSNSInsertSQL) userSNSInsertSQLToSql() (string, []any, error) { var err error var s interface{} = UserSNS{} @@ -586,19 +580,7 @@ func (q userSNSInsertSQL) userSNSInsertSQLToSql() (string, []any, error) { } func (q userSNSInsertSQL) Exec(db sqlla.DB) (UserSNS, error) { - query, args, err := q.ToSql() - if err != nil { - return UserSNS{}, err - } - result, err := db.Exec(query, args...) - if err != nil { - return UserSNS{}, err - } - id, err := result.LastInsertId() - if err != nil { - return UserSNS{}, err - } - return NewUserSNSSQL().Select().PkColumn(id).Single(db) + return q.ExecContext(context.Background(), db) } func (q userSNSInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (UserSNS, error) { @@ -631,6 +613,7 @@ type userSNSDefaultInsertHooker interface { } type userSNSInsertSQLToSqler interface { + rowsNum() int userSNSInsertSQLToSql() (string, []any, error) } @@ -648,6 +631,10 @@ func (q *userSNSBulkInsertSQL) Append(iqs ...userSNSInsertSQL) { q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *userSNSBulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + func (q *userSNSBulkInsertSQL) userSNSInsertSQLToSql() (string, []any, error) { if len(q.insertSQLs) == 0 { return "", []any{}, fmt.Errorf("sqlla: This userSNSBulkInsertSQL's InsertSQL was empty") diff --git a/template/insert.tmpl b/template/insert.tmpl index 4c53cb3..84b2c6b 100644 --- a/template/insert.tmpl +++ b/template/insert.tmpl @@ -25,11 +25,23 @@ func (q {{ $camelName }}InsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - {{- if and .HasPk (eq (dialect) "postgresql") }} - return query + " RETURNING " + {{ cquoteby .PkColumn.Name }} + ";", vs, nil - {{- else }} return query + ";", vs, nil - {{- end }} +} + +{{ if eq (dialect) "postgresql" }} +func (q {{ $camelName }}InsertSQL) ToSqlWithReturning() (string, []any, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join({{ $camelName }}AllColumns, ", ") + return query, args, nil +} +{{ end }} + +func (q {{ $camelName }}InsertSQL) rowsNum() int { + return 1 } {{ if eq (dialect) "mysql" }} @@ -81,40 +93,32 @@ func (q {{ $camelName }}InsertSQL) Exec(db sqlla.DB) ({{ .StructName }}, error) {{- else -}} func (q {{ $camelName }}InsertSQL) Exec(db sqlla.DB) (sql.Result, error) { {{- end }} - query, args, err := q.ToSql() + return q.ExecContext(context.Background(), db) +} + +{{ if eq (dialect) "postgresql" }} +func (q {{ $camelName }}InsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { - {{ if .HasPk -}} return {{ .StructName }}{}, err - {{- else }} - return nil, err - {{- end }} } - {{- if not .HasPk }} - result, err := db.Exec(query, args...) - return result, err - {{- else }} - {{- if eq (dialect) "mysql" }} - result, err := db.Exec(query, args...) + row := db.QueryRowContext(ctx, query, args...) + result, err := {{ $constructor }}().Select().Scan(row) if err != nil { return {{ .StructName }}{}, err } - id, err := result.LastInsertId() + return result, nil +} + +func (q {{ $camelName }}InsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { + query, args, err := q.ToSql() if err != nil { - return {{ .StructName }}{}, err - } - return {{ $constructor }}().Select().PkColumn(id).Single(db) - {{- end }} - {{- if eq (dialect) "postgresql" }} - row := db.QueryRow(query, args...) - var pk {{ .PkColumn.TypeName }} - if err := row.Scan(&pk); err != nil { - return {{ .StructName }}{}, err + return nil, err } - return {{ $constructor }}().Select().{{ .PkColumn.MethodName }}(pk).Single(db) - {{- end }} - {{- end }} + result, err := db.ExecContext(ctx, query, args...) + return result, err } - +{{ else }} {{ if .HasPk -}} func (q {{ $camelName }}InsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { {{- else -}} @@ -132,7 +136,6 @@ func (q {{ $camelName }}InsertSQL) ExecContext(ctx context.Context, db sqlla.DB) result, err := db.ExecContext(ctx, query, args...) return result, err {{- else }} - {{- if eq (dialect) "mysql" }} result, err := db.ExecContext(ctx, query, args...) if err != nil { return {{ .StructName }}{}, err @@ -143,15 +146,6 @@ func (q {{ $camelName }}InsertSQL) ExecContext(ctx context.Context, db sqlla.DB) } return {{ $constructor }}().Select().PkColumn(id).SingleContext(ctx, db) {{- end }} - {{- if eq (dialect) "postgresql" }} - row := db.QueryRowContext(ctx, query, args...) - var pk {{ .PkColumn.TypeName }} - if err := row.Scan(&pk); err != nil { - return {{ .StructName }}{}, err - } - return {{ $constructor }}().Select().{{ .PkColumn.MethodName }}(pk).SingleContext(ctx, db) - {{- end }} - {{- end }} } {{ if .HasPk -}} @@ -164,12 +158,14 @@ func (q {{ $camelName }}InsertSQL) ExecContextWithoutSelect(ctx context.Context, return result, err } {{- end }} +{{- end }} type {{ $camelName }}DefaultInsertHooker interface { DefaultInsertHook({{ $camelName }}InsertSQL) ({{ $camelName }}InsertSQL, error) } type {{ $camelName }}InsertSQLToSqler interface { + rowsNum() int {{- if eq (dialect) "mysql" }} {{ $camelName }}InsertSQLToSql() (string, []any, error) {{- end }} @@ -192,6 +188,10 @@ func (q *{{ $camelName }}BulkInsertSQL) Append(iqs ...{{ $camelName }}InsertSQL) q.insertSQLs = append(q.insertSQLs, iqs...) } +func (q *{{ $camelName }}BulkInsertSQL) rowsNum() int { + return len(q.insertSQLs) +} + {{ if eq (dialect) "mysql" }} func (q *{{ $camelName }}BulkInsertSQL) {{ $camelName }}InsertSQLToSql() (string, []any, error) { {{- end }} @@ -264,45 +264,54 @@ func (q *{{ $camelName }}BulkInsertSQL) ToSql() (string, []any, error) { if err != nil { return "", []any{}, err } - {{- if and .HasPk (eq (dialect) "postgresql") }} - return query + " RETURNING " + {{ cquoteby .PkColumn.Name }} + ";", vs, nil - {{- else }} return query + ";", vs, nil - {{- end }} } -{{- if and .HasPk (eq (dialect) "postgresql") }} -func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { -{{- else }} -func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{- end }} +{{- if eq (dialect) "postgresql" }} +func (q *{{ $camelName }}BulkInsertSQL) ToSqlWithReturning() (string, []any, error) { query, args, err := q.ToSql() + if err != nil { + return "", []any{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join({{ $camelName }}AllColumns, ", ") + return query + ";", args, nil +} + +func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } - {{- if and .HasPk (eq (dialect) "postgresql") }} rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - pks := make([]{{ .PkColumn.TypeName }}, 0, len(q.insertSQLs)) + results := make([]{{ .StructName }}, 0, len(q.insertSQLs)) + sel := {{ $constructor }}().Select() for rows.Next() { - var pk {{ .PkColumn.TypeName }} - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) + } + return results, nil +} + +func (q *{{ $camelName }}BulkInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { + query, args, err := q.ToSql() + if err != nil { + return nil, err } - return {{ $constructor }}().Select().{{ .PkColumn.MethodName }}In(pks...).AllContext(ctx, db) - {{- else }} result, err := db.ExecContext(ctx, query, args...) return result, err - {{- end }} } +{{- end }} -{{- if and .HasPk (eq (dialect) "postgresql") }} -func (q *{{ $camelName }}BulkInsertSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { +{{- if eq (dialect) "mysql" }} +func (q *{{ $camelName }}BulkInsertSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { query, args, err := q.ToSql() if err != nil { return nil, err @@ -312,6 +321,7 @@ func (q *{{ $camelName }}BulkInsertSQL) ExecContextWithoutSelect(ctx context.Con } {{- end }} + {{- if eq (dialect) "mysql" }} {{ template "InsertMySQL" . }} {{- end }} diff --git a/template/insert_postgresql.tmpl b/template/insert_postgresql.tmpl index 7ea59c3..7337bdc 100644 --- a/template/insert_postgresql.tmpl +++ b/template/insert_postgresql.tmpl @@ -1,19 +1,19 @@ -{{ define "InsertPostgreSQL.ExecContextHasPkSingle" }} +{{ define "InsertPostgreSQL.ExecContextSingle" }} {{- $constructor := printf "New%sSQL" (.Name | toCamel) -}} - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return {{ .StructName }}{}, err } row := db.QueryRowContext(ctx, query, args...) - var pk {{ .PkColumn.TypeName }} - if err := row.Scan(&pk); err != nil { + result, err := {{ $constructor }}().Select().Scan(row) + if err != nil { return {{ .StructName }}{}, err } - return {{ $constructor }}().Select().{{ .PkColumn.MethodName }}(pk).SingleContext(ctx, db) + return result, nil {{ end }} -{{ define "InsertPostgreSQL.ExecContextHasPkAll" }} +{{ define "InsertPostgreSQL.ExecContextAll" }} {{- $constructor := printf "New%sSQL" (.Name | toCamel) -}} - query, args, err := q.ToSql() + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } @@ -22,16 +22,17 @@ return nil, err } defer rows.Close() - pks := make([]{{ .PkColumn.TypeName }}, 0) + results := make([]{{ .StructName }}, 0, q.insertSQL.rowsNum()) + sel := {{ $constructor }}().Select() for rows.Next() { - var pk {{ .PkColumn.TypeName }} - if err := rows.Scan(&pk); err != nil { + result, err := sel.Scan(rows) + if err != nil { return nil, err } - pks = append(pks, pk) + results = append(results, result) } - return {{ $constructor }}().Select().{{ .PkColumn.MethodName }}In(pks...).AllContext(ctx, db) + return results, nil {{ end }} {{ define "InsertPostgreSQL.ExecContextWithoutSelect" }} query, args, err := q.ToSql() @@ -48,9 +49,6 @@ return "", nil, err } query += " ON CONFLICT DO NOTHING" - {{- if .HasPk }} - query += " RETURNING " + {{ cquoteby .PkColumn.Name }} - {{- end }} return query + ";", vs, nil {{ end }} {{ define "InsertPostgreSQL.DoUpdateToSql" }} @@ -75,12 +73,19 @@ } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - {{- if .HasPk }} - query += " RETURNING " + {{ cquoteby .PkColumn.Name }} - {{- end }} return query + ";", vs, nil {{ end }} +{{ define "InsertPostgreSQL.ToSqlWithReturning" }} +{{- $camelName := .Name | toCamel | untitle -}} + query, args, err := q.ToSql() + if err != nil { + return "", nil, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join({{ $camelName }}AllColumns, ", ") + return query + ";", args, nil +{{ end }} {{ define "InsertPostgreSQL" }} {{- $camelName := .Name | toCamel | untitle -}} @@ -100,19 +105,17 @@ func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ToSql() (string, []any, er {{ template "InsertPostgreSQL.DoNothingToSql" . }} } -{{ if .HasPk -}} +func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { +{{ template "InsertPostgreSQL.ToSqlWithReturning" . }} +} + func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { -{{ template "InsertPostgreSQL.ExecContextHasPkSingle" . }} +{{ template "InsertPostgreSQL.ExecContextSingle" . }} } func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { {{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} } -{{- else -}} -func (q {{ $camelName }}InsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} -} -{{- end }} type {{ $camelName }}InsertOnConflictDoUpdateSQL struct { insertSQL {{ $camelName }}InsertSQLToSqler @@ -151,26 +154,21 @@ func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ToSql() (string, []any, err } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - {{- if .HasPk }} - query += " RETURNING " + {{ cquoteby .PkColumn.Name }} - {{- end }} return query + ";", vs, nil } -{{ if .HasPk -}} +func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { +{{ template "InsertPostgreSQL.ToSqlWithReturning" . }} +} + func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ({{ .StructName }}, error) { -{{ template "InsertPostgreSQL.ExecContextHasPkSingle" . }} +{{ template "InsertPostgreSQL.ExecContextSingle" . }} } func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { {{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} } -{{- else -}} -func (q {{ $camelName }}InsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} -} -{{- end }} type {{ $camelName }}DefaultInsertOnConflictDoUpdateHooker interface { DefaultInsertOnConflictDoUpdateHook({{ $camelName }}InsertOnConflictDoUpdateSQL) ({{ $camelName }}InsertOnConflictDoUpdateSQL, error) @@ -190,19 +188,17 @@ func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ToSql() (string, []any {{ template "InsertPostgreSQL.DoNothingToSql" . }} } -{{ if .HasPk -}} +func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ToSqlWithReturning() (string, []any, error) { +{{ template "InsertPostgreSQL.ToSqlWithReturning" . }} +} + func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { -{{ template "InsertPostgreSQL.ExecContextHasPkAll" . }} +{{ template "InsertPostgreSQL.ExecContextAll" . }} } func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { {{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} } -{{- else -}} -func (q {{ $camelName }}BulkInsertOnConflictDoNothingSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} -} -{{- end }} type {{ $camelName }}BulkInsertOnConflictDoUpdateSQL struct { insertSQL {{ $camelName }}InsertSQLToSqler @@ -248,25 +244,20 @@ func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ToSql() (string, []any, } query += " ON CONFLICT (" + q.target + ") DO UPDATE SET" + os vs = append(vs, ovs...) - {{- if .HasPk }} - query += " RETURNING " + {{ cquoteby .PkColumn.Name }} - {{- end }} return query + ";", vs, nil } -{{ if .HasPk -}} +func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ToSqlWithReturning() (string, []any, error) { +{{ template "InsertPostgreSQL.ToSqlWithReturning" . }} +} + func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { -{{ template "InsertPostgreSQL.ExecContextHasPkAll" . }} +{{ template "InsertPostgreSQL.ExecContextAll" . }} } func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ExecContextWithoutSelect(ctx context.Context, db sqlla.DB) (sql.Result, error) { {{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} } -{{- else -}} -func (q {{ $camelName }}BulkInsertOnConflictDoUpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { -{{ template "InsertPostgreSQL.ExecContextWithoutSelect" . }} -} -{{- end }} {{ end }} diff --git a/template/update.tmpl b/template/update.tmpl index 7ddb2b1..df3332f 100644 --- a/template/update.tmpl +++ b/template/update.tmpl @@ -50,23 +50,56 @@ func (q {{ $camelName }}UpdateSQL) ToSql() (string, []interface{}, error) { return query + ";", append(svs, wvs...), nil } -{{- if .HasPk }} +{{- if eq (dialect) "postgresql" }} +func (q {{ $camelName }}UpdateSQL) ToSqlWithReturning() (string, []interface{}, error) { + query, args, err := q.ToSql() + if err != nil { + return "", []interface{}{}, err + } + query = strings.TrimSuffix(query, ";") + query += " RETURNING " + strings.Join({{ $camelName }}AllColumns, ", ") + return query + ";", args, nil +} + func (s {{ .StructName }}) Update() {{ $camelName }}UpdateSQL { return {{ $constructor }}().Update().Where{{ .PkColumn.Name | toCamel | title }}(s.{{ .PkColumn.FieldName }}) } func (q {{ $camelName }}UpdateSQL) Exec(db sqlla.DB) ([]{{ .StructName }}, error) { - query, args, err := q.ToSql() + return q.ExecContext(context.Background(), db) +} + +func (q {{ $camelName }}UpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { + query, args, err := q.ToSqlWithReturning() if err != nil { return nil, err } - _, err = db.Exec(query, args...) + rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - qq := q.{{ $camelName }}SQL + results := make([]{{ .StructName }}, 0, 1) + defer rows.Close() + sel := {{ $constructor }}().Select() + for rows.Next() { + result, err := sel.Scan(rows) + if err != nil { + return nil, err + } + results = append(results, result) + } + + return results, nil +} +{{- end }} +{{- if eq (dialect) "mysql" }} +{{- if .HasPk }} +func (s {{ .StructName }}) Update() {{ $camelName }}UpdateSQL { + return {{ $constructor }}().Update().Where{{ .PkColumn.Name | toCamel | title }}(s.{{ .PkColumn.FieldName }}) +} - return qq.Select().All(db) +func (q {{ $camelName }}UpdateSQL) Exec(db sqlla.DB) ([]{{ .StructName }}, error) { + return q.ExecContext(context.Background(), db) } func (q {{ $camelName }}UpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) ([]{{ .StructName }}, error) { @@ -84,11 +117,7 @@ func (q {{ $camelName }}UpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) } {{- else }} func (q {{ $camelName }}UpdateSQL) Exec(db sqlla.DB) (sql.Result, error) { - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - return db.Exec(query, args...) + return q.ExecContext(context.Background(), db) } func (q {{ $camelName }}UpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) (sql.Result, error) { @@ -99,6 +128,7 @@ func (q {{ $camelName }}UpdateSQL) ExecContext(ctx context.Context, db sqlla.DB) return db.ExecContext(ctx, query, args...) } {{- end }} +{{- end }} type {{ $camelName }}DefaultUpdateHooker interface { DefaultUpdateHook({{ $camelName }}UpdateSQL) ({{ $camelName }}UpdateSQL, error)