From f00bc7595bd3b59b4888e25e7d7b0fc492da6ed8 Mon Sep 17 00:00:00 2001 From: Rene Cannao Date: Tue, 24 Mar 2026 11:55:39 +0000 Subject: [PATCH 1/2] Add PostgreSQL replication provider Implement the DatabaseProvider interface for PostgreSQL streaming replication topologies. The provider uses lib/pq to connect to PostgreSQL instances and maps WAL-based replication state to the common ReplicationStatus struct. - GetReplicationStatus: queries pg_stat_wal_receiver (standby) or pg_current_wal_lsn (primary) - IsReplicaRunning: checks WAL receiver status - SetReadOnly/IsReadOnly: manages default_transaction_read_only - StartReplication: resumes WAL replay (streaming is automatic) - StopReplication: pauses WAL replay via pg_wal_replay_pause() Also adds PostgreSQLTopologyUser and PostgreSQLTopologyPassword config fields, unit tests, and updated documentation. Closes #53 --- docs/database-providers.md | 57 +- go.mod | 1 + go.sum | 2 + go/config/config.go | 2 + go/inst/provider_postgresql.go | 241 +++ go/inst/provider_postgresql_test.go | 37 + vendor/github.com/lib/pq/.gitattributes | 1 + vendor/github.com/lib/pq/.gitignore | 6 + vendor/github.com/lib/pq/CHANGELOG.md | 242 +++ vendor/github.com/lib/pq/LICENSE | 21 + vendor/github.com/lib/pq/README.md | 312 +++ vendor/github.com/lib/pq/array.go | 903 ++++++++ vendor/github.com/lib/pq/as.go | 26 + vendor/github.com/lib/pq/as_go126.go | 23 + vendor/github.com/lib/pq/buf.go | 100 + vendor/github.com/lib/pq/compose.yaml | 81 + vendor/github.com/lib/pq/conn.go | 1817 +++++++++++++++++ vendor/github.com/lib/pq/conn_go18.go | 226 ++ vendor/github.com/lib/pq/connector.go | 1157 +++++++++++ vendor/github.com/lib/pq/copy.go | 337 +++ vendor/github.com/lib/pq/deprecated.go | 133 ++ vendor/github.com/lib/pq/doc.go | 137 ++ vendor/github.com/lib/pq/encode.go | 400 ++++ vendor/github.com/lib/pq/error.go | 324 +++ .../lib/pq/internal/pgpass/pgpass.go | 71 + .../lib/pq/internal/pgservice/pgservice.go | 70 + 
.../github.com/lib/pq/internal/pqsql/copy.go | 37 + .../github.com/lib/pq/internal/pqtime/loc.go | 37 + .../lib/pq/internal/pqtime/pqtime.go | 190 ++ .../github.com/lib/pq/internal/pqutil/path.go | 86 + .../github.com/lib/pq/internal/pqutil/perm.go | 64 + .../pq/internal/pqutil/perm_unsupported.go | 12 + .../lib/pq/internal/pqutil/pqutil.go | 32 + .../lib/pq/internal/pqutil/user_other.go | 9 + .../lib/pq/internal/pqutil/user_posix.go | 25 + .../lib/pq/internal/pqutil/user_windows.go | 28 + .../github.com/lib/pq/internal/proto/proto.go | 186 ++ .../github.com/lib/pq/internal/proto/sz_32.go | 7 + .../github.com/lib/pq/internal/proto/sz_64.go | 7 + vendor/github.com/lib/pq/krb.go | 27 + vendor/github.com/lib/pq/notice.go | 69 + vendor/github.com/lib/pq/notify.go | 834 ++++++++ vendor/github.com/lib/pq/oid/doc.go | 7 + vendor/github.com/lib/pq/oid/types.go | 343 ++++ vendor/github.com/lib/pq/pqerror/codes.go | 581 ++++++ vendor/github.com/lib/pq/pqerror/pqerror.go | 35 + vendor/github.com/lib/pq/quote.go | 71 + vendor/github.com/lib/pq/rows.go | 245 +++ vendor/github.com/lib/pq/scram/scram.go | 261 +++ vendor/github.com/lib/pq/ssl.go | 312 +++ vendor/github.com/lib/pq/staticcheck.conf | 5 + vendor/github.com/lib/pq/stmt.go | 150 ++ vendor/modules.txt | 12 + 53 files changed, 10397 insertions(+), 2 deletions(-) create mode 100644 go/inst/provider_postgresql.go create mode 100644 go/inst/provider_postgresql_test.go create mode 100644 vendor/github.com/lib/pq/.gitattributes create mode 100644 vendor/github.com/lib/pq/.gitignore create mode 100644 vendor/github.com/lib/pq/CHANGELOG.md create mode 100644 vendor/github.com/lib/pq/LICENSE create mode 100644 vendor/github.com/lib/pq/README.md create mode 100644 vendor/github.com/lib/pq/array.go create mode 100644 vendor/github.com/lib/pq/as.go create mode 100644 vendor/github.com/lib/pq/as_go126.go create mode 100644 vendor/github.com/lib/pq/buf.go create mode 100644 vendor/github.com/lib/pq/compose.yaml create mode 100644 
vendor/github.com/lib/pq/conn.go create mode 100644 vendor/github.com/lib/pq/conn_go18.go create mode 100644 vendor/github.com/lib/pq/connector.go create mode 100644 vendor/github.com/lib/pq/copy.go create mode 100644 vendor/github.com/lib/pq/deprecated.go create mode 100644 vendor/github.com/lib/pq/doc.go create mode 100644 vendor/github.com/lib/pq/encode.go create mode 100644 vendor/github.com/lib/pq/error.go create mode 100644 vendor/github.com/lib/pq/internal/pgpass/pgpass.go create mode 100644 vendor/github.com/lib/pq/internal/pgservice/pgservice.go create mode 100644 vendor/github.com/lib/pq/internal/pqsql/copy.go create mode 100644 vendor/github.com/lib/pq/internal/pqtime/loc.go create mode 100644 vendor/github.com/lib/pq/internal/pqtime/pqtime.go create mode 100644 vendor/github.com/lib/pq/internal/pqutil/path.go create mode 100644 vendor/github.com/lib/pq/internal/pqutil/perm.go create mode 100644 vendor/github.com/lib/pq/internal/pqutil/perm_unsupported.go create mode 100644 vendor/github.com/lib/pq/internal/pqutil/pqutil.go create mode 100644 vendor/github.com/lib/pq/internal/pqutil/user_other.go create mode 100644 vendor/github.com/lib/pq/internal/pqutil/user_posix.go create mode 100644 vendor/github.com/lib/pq/internal/pqutil/user_windows.go create mode 100644 vendor/github.com/lib/pq/internal/proto/proto.go create mode 100644 vendor/github.com/lib/pq/internal/proto/sz_32.go create mode 100644 vendor/github.com/lib/pq/internal/proto/sz_64.go create mode 100644 vendor/github.com/lib/pq/krb.go create mode 100644 vendor/github.com/lib/pq/notice.go create mode 100644 vendor/github.com/lib/pq/notify.go create mode 100644 vendor/github.com/lib/pq/oid/doc.go create mode 100644 vendor/github.com/lib/pq/oid/types.go create mode 100644 vendor/github.com/lib/pq/pqerror/codes.go create mode 100644 vendor/github.com/lib/pq/pqerror/pqerror.go create mode 100644 vendor/github.com/lib/pq/quote.go create mode 100644 vendor/github.com/lib/pq/rows.go create mode 100644 
vendor/github.com/lib/pq/scram/scram.go create mode 100644 vendor/github.com/lib/pq/ssl.go create mode 100644 vendor/github.com/lib/pq/staticcheck.conf create mode 100644 vendor/github.com/lib/pq/stmt.go diff --git a/docs/database-providers.md b/docs/database-providers.md index c6ea6b0f..55cc7b02 100644 --- a/docs/database-providers.md +++ b/docs/database-providers.md @@ -6,8 +6,9 @@ Orchestrator supports a database provider abstraction layer that decouples core orchestration logic from database-specific operations. This allows orchestrator to manage different database engines through a common interface. -MySQL is the default (and currently only) provider. The abstraction layer is -designed to support future providers such as PostgreSQL. +MySQL is the default provider. PostgreSQL is also supported for streaming +replication topologies. The abstraction layer is designed to support additional +providers in the future. ## Architecture @@ -82,6 +83,58 @@ preserved. The MySQL provider is automatically registered at init time. No configuration is needed to use it. +## PostgreSQL Provider + +The PostgreSQL provider (`PostgreSQLProvider`) supports PostgreSQL streaming +replication topologies. It uses the `lib/pq` driver to connect to PostgreSQL +instances. + +### Configuration + +Add the following fields to your orchestrator configuration JSON: + +```json +{ + "PostgreSQLTopologyUser": "orchestrator", + "PostgreSQLTopologyPassword": "secret" +} +``` + +These credentials are used to connect to PostgreSQL topology instances for +discovery and replication management operations. 
+ +### Activating the Provider + +To use PostgreSQL instead of MySQL, register the provider during startup: + +```go +import "github.com/proxysql/orchestrator/go/inst" + +inst.SetProvider(inst.NewPostgreSQLProvider()) +``` + +### How It Works + +| Operation | PostgreSQL Implementation | +|---------------------|---------------------------------------------------------------| +| GetReplicationStatus | Queries `pg_stat_wal_receiver` (standby) or `pg_current_wal_lsn()` (primary). Reports WAL LSN as position and `replay_lag` as lag. | +| IsReplicaRunning | Checks `pg_stat_wal_receiver` for an active WAL receiver with `status = 'streaming'`. | +| SetReadOnly | Runs `ALTER SYSTEM SET default_transaction_read_only = on/off` followed by `SELECT pg_reload_conf()`. | +| IsReadOnly | Queries `SHOW default_transaction_read_only`. | +| StartReplication | Calls `SELECT pg_wal_replay_resume()`. Streaming replication itself starts automatically when the standby connects. | +| StopReplication | Calls `SELECT pg_wal_replay_pause()` to pause WAL replay. The WAL receiver remains connected. | + +### Differences from MySQL + +- **No separate IO/SQL threads.** PostgreSQL does not have the concept of + separate IO and SQL threads. The `IOThreadRunning` and `SQLThreadRunning` + fields in `ReplicationStatus` both mirror the WAL receiver state. +- **Streaming replication is automatic.** `StartReplication` resumes WAL replay + but cannot start the WAL receiver itself -- that is controlled by PostgreSQL's + `primary_conninfo` configuration. +- **StopReplication pauses replay only.** The WAL receiver continues to receive + WAL segments; only application (replay) is paused. 
+ ## Implementing a New Provider To add support for a new database engine: diff --git a/go.mod b/go.mod index 5728ada5..20b6c69b 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/hashicorp/consul/api v1.33.4 github.com/hashicorp/raft v1.7.3 github.com/howeyc/gopass v0.0.0-20210920133722-c8aef6fb66ef + github.com/lib/pq v1.12.0 github.com/mattn/go-sqlite3 v1.14.37 github.com/montanaflynn/stats v0.8.2 github.com/outbrain/zookeepercli v1.0.12 diff --git a/go.sum b/go.sum index eb1b73dc..2b0dc91d 100644 --- a/go.sum +++ b/go.sum @@ -128,6 +128,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lib/pq v1.12.0 h1:mC1zeiNamwKBecjHarAr26c/+d8V5w/u4J0I/yASbJo= +github.com/lib/pq v1.12.0/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= diff --git a/go/config/config.go b/go/config/config.go index 339ccc0d..28461d62 100644 --- a/go/config/config.go +++ b/go/config/config.go @@ -96,6 +96,8 @@ type Configuration struct { AgentsServerPort string // port orchestrator agents talk back to MySQLTopologyUser string MySQLTopologyPassword string + PostgreSQLTopologyUser string // Username for connecting to PostgreSQL topology instances + PostgreSQLTopologyPassword string // Password for connecting to PostgreSQL topology instances MySQLTopologyCredentialsConfigFile string // my.cnf style configuration file from where to pick credentials. 
Expecting `user`, `password` under `[client]` section MySQLTopologySSLPrivateKeyFile string // Private key file used to authenticate with a Topology mysql instance with TLS MySQLTopologySSLCertFile string // Certificate PEM file used to authenticate with a Topology mysql instance with TLS diff --git a/go/inst/provider_postgresql.go b/go/inst/provider_postgresql.go new file mode 100644 index 00000000..8f6ad253 --- /dev/null +++ b/go/inst/provider_postgresql.go @@ -0,0 +1,241 @@ +/* + Copyright 2024 Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package inst + +import ( + "database/sql" + "fmt" + + _ "github.com/lib/pq" + "github.com/proxysql/golib/log" + "github.com/proxysql/orchestrator/go/config" +) + +// PostgreSQLProvider implements DatabaseProvider for PostgreSQL streaming +// replication topologies. +type PostgreSQLProvider struct{} + +// NewPostgreSQLProvider creates a new PostgreSQL database provider. +func NewPostgreSQLProvider() *PostgreSQLProvider { + return &PostgreSQLProvider{} +} + +// ProviderName returns "postgresql". +func (p *PostgreSQLProvider) ProviderName() string { + return "postgresql" +} + +// openPostgreSQLTopology opens a connection to a PostgreSQL instance using +// credentials from the orchestrator configuration. 
+func openPostgreSQLTopology(hostname string, port int) (*sql.DB, error) { + cfg := config.Config + connStr := fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable connect_timeout=5", + hostname, port, cfg.PostgreSQLTopologyUser, cfg.PostgreSQLTopologyPassword, + ) + db, err := sql.Open("postgres", connStr) + if err != nil { + return nil, err + } + db.SetMaxOpenConns(3) + db.SetMaxIdleConns(1) + return db, nil +} + +// GetReplicationStatus retrieves the replication state for a PostgreSQL instance. +// On a standby it queries pg_stat_wal_receiver; on a primary it queries +// pg_current_wal_lsn(). +func (p *PostgreSQLProvider) GetReplicationStatus(key InstanceKey) (*ReplicationStatus, error) { + db, err := openPostgreSQLTopology(key.Hostname, key.Port) + if err != nil { + return nil, log.Errore(err) + } + defer db.Close() + + // Check whether this instance is in recovery (i.e. is a standby). + var inRecovery bool + if err := db.QueryRow("SELECT pg_is_in_recovery()").Scan(&inRecovery); err != nil { + return nil, log.Errore(err) + } + + if inRecovery { + return p.getStandbyReplicationStatus(db) + } + return p.getPrimaryReplicationStatus(db) +} + +// getStandbyReplicationStatus reads replication state from a PostgreSQL standby +// via pg_stat_wal_receiver and pg_last_wal_replay_lsn(). +func (p *PostgreSQLProvider) getStandbyReplicationStatus(db *sql.DB) (*ReplicationStatus, error) { + var status, lsn sql.NullString + var lagSeconds sql.NullFloat64 + + err := db.QueryRow(` + SELECT + w.status, + pg_last_wal_replay_lsn()::text, + EXTRACT(EPOCH FROM replay_lag) + FROM pg_stat_wal_receiver w + LEFT JOIN pg_stat_replication r ON true + LIMIT 1 + `).Scan(&status, &lsn, &lagSeconds) + + if err == sql.ErrNoRows { + // No WAL receiver row means replication is not running. 
+ return &ReplicationStatus{ + ReplicaRunning: false, + IOThreadRunning: false, + SQLThreadRunning: false, + Position: "", + Lag: -1, + }, nil + } + if err != nil { + return nil, log.Errore(err) + } + + ioRunning := status.Valid && status.String == "streaming" + lag := int64(-1) + if lagSeconds.Valid { + lag = int64(lagSeconds.Float64) + } + + position := "" + if lsn.Valid { + position = lsn.String + } + + return &ReplicationStatus{ + ReplicaRunning: ioRunning, + IOThreadRunning: ioRunning, + SQLThreadRunning: ioRunning, // PG does not separate IO/SQL threads; mirror IO state + Position: position, + Lag: lag, + }, nil +} + +// getPrimaryReplicationStatus returns a ReplicationStatus for a primary server. +// A primary is not itself a replica, so ReplicaRunning is false, and we report +// the current WAL insert position. +func (p *PostgreSQLProvider) getPrimaryReplicationStatus(db *sql.DB) (*ReplicationStatus, error) { + var lsn string + if err := db.QueryRow("SELECT pg_current_wal_lsn()::text").Scan(&lsn); err != nil { + return nil, log.Errore(err) + } + return &ReplicationStatus{ + ReplicaRunning: false, + IOThreadRunning: false, + SQLThreadRunning: false, + Position: lsn, + Lag: 0, + }, nil +} + +// IsReplicaRunning checks whether the WAL receiver is active on a PostgreSQL +// standby instance. +func (p *PostgreSQLProvider) IsReplicaRunning(key InstanceKey) (bool, error) { + db, err := openPostgreSQLTopology(key.Hostname, key.Port) + if err != nil { + return false, log.Errore(err) + } + defer db.Close() + + var status sql.NullString + err = db.QueryRow("SELECT status FROM pg_stat_wal_receiver LIMIT 1").Scan(&status) + if err == sql.ErrNoRows { + return false, nil + } + if err != nil { + return false, log.Errore(err) + } + return status.Valid && status.String == "streaming", nil +} + +// SetReadOnly sets or clears the default_transaction_read_only parameter on +// a PostgreSQL instance and reloads the configuration. 
+func (p *PostgreSQLProvider) SetReadOnly(key InstanceKey, readOnly bool) error { + db, err := openPostgreSQLTopology(key.Hostname, key.Port) + if err != nil { + return log.Errore(err) + } + defer db.Close() + + value := "off" + if readOnly { + value = "on" + } + if _, err := db.Exec(fmt.Sprintf("ALTER SYSTEM SET default_transaction_read_only = %s", value)); err != nil { + return log.Errore(err) + } + if _, err := db.Exec("SELECT pg_reload_conf()"); err != nil { + return log.Errore(err) + } + return nil +} + +// IsReadOnly checks whether default_transaction_read_only is enabled on a +// PostgreSQL instance. +func (p *PostgreSQLProvider) IsReadOnly(key InstanceKey) (bool, error) { + db, err := openPostgreSQLTopology(key.Hostname, key.Port) + if err != nil { + return false, log.Errore(err) + } + defer db.Close() + + var value string + if err := db.QueryRow("SHOW default_transaction_read_only").Scan(&value); err != nil { + return false, log.Errore(err) + } + return value == "on", nil +} + +// StartReplication resumes WAL replay on a PostgreSQL standby. Streaming +// replication starts automatically when a standby connects to its primary. +// WAL replay is resumed if it was previously paused. +func (p *PostgreSQLProvider) StartReplication(key InstanceKey) error { + log.Infof("PostgreSQL streaming replication on %s:%d starts automatically; resuming WAL replay if paused", key.Hostname, key.Port) + + db, err := openPostgreSQLTopology(key.Hostname, key.Port) + if err != nil { + return log.Errore(err) + } + defer db.Close() + + if _, err := db.Exec("SELECT pg_wal_replay_resume()"); err != nil { + return log.Errore(err) + } + return nil +} + +// StopReplication pauses WAL replay on a PostgreSQL standby. This is the +// closest equivalent to stopping replication in MySQL. Note that the WAL +// receiver (IO thread equivalent) remains connected; only replay is paused. 
+func (p *PostgreSQLProvider) StopReplication(key InstanceKey) error { + db, err := openPostgreSQLTopology(key.Hostname, key.Port) + if err != nil { + return log.Errore(err) + } + defer db.Close() + + if _, err := db.Exec("SELECT pg_wal_replay_pause()"); err != nil { + return log.Errore(err) + } + return nil +} + +// Compile-time check that PostgreSQLProvider implements DatabaseProvider. +var _ DatabaseProvider = (*PostgreSQLProvider)(nil) diff --git a/go/inst/provider_postgresql_test.go b/go/inst/provider_postgresql_test.go new file mode 100644 index 00000000..c119a35f --- /dev/null +++ b/go/inst/provider_postgresql_test.go @@ -0,0 +1,37 @@ +/* + Copyright 2024 Orchestrator Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package inst + +import "testing" + +func TestPostgreSQLProviderName(t *testing.T) { + p := NewPostgreSQLProvider() + if p.ProviderName() != "postgresql" { + t.Errorf("expected 'postgresql', got %q", p.ProviderName()) + } +} + +func TestPostgreSQLProviderImplementsInterface(t *testing.T) { + var _ DatabaseProvider = (*PostgreSQLProvider)(nil) +} + +func TestPostgreSQLProviderNewReturnsNonNil(t *testing.T) { + p := NewPostgreSQLProvider() + if p == nil { + t.Error("expected non-nil provider") + } +} diff --git a/vendor/github.com/lib/pq/.gitattributes b/vendor/github.com/lib/pq/.gitattributes new file mode 100644 index 00000000..dfdb8b77 --- /dev/null +++ b/vendor/github.com/lib/pq/.gitattributes @@ -0,0 +1 @@ +*.sh text eol=lf diff --git a/vendor/github.com/lib/pq/.gitignore b/vendor/github.com/lib/pq/.gitignore new file mode 100644 index 00000000..3243952a --- /dev/null +++ b/vendor/github.com/lib/pq/.gitignore @@ -0,0 +1,6 @@ +.db +*.test +*~ +*.swp +.idea +.vscode \ No newline at end of file diff --git a/vendor/github.com/lib/pq/CHANGELOG.md b/vendor/github.com/lib/pq/CHANGELOG.md new file mode 100644 index 00000000..3f80d12b --- /dev/null +++ b/vendor/github.com/lib/pq/CHANGELOG.md @@ -0,0 +1,242 @@ +unreleased +---------- + +- This release changes the default `sslmode` from `require` to `prefer`, which + is the default used by libpq and the rest of the PostgreSQL ecosystem. See + [#1271] for some background. + +v1.12.0 (2026-03-18) +-------------------- + +- The next release may change the default sslmode from `require` to `prefer`. + See [#1271] for details. + +- `CopyIn()` and `CopyInToSchema()` have been marked as deprecated. These are + simple query builders and not needed for `COPY [..] FROM STDIN` support (which + is *not* deprecated). 
([#1279]) + + // Old + tx.Prepare(CopyIn("temp", "num", "text", "blob", "nothing")) + + // Replacement + tx.Prepare(`copy temp (num, text, blob, nothing) from stdin`) + +### Features + +- Support protocol 3.2, and the `min_protocol_version` and + `max_protocol_version` DSN parameters ([#1258]). + +- Support `sslmode=prefer` and `sslmode=allow` ([#1270]). + +- Support `ssl_min_protocol_version` and `ssl_max_protocol_version` ([#1277]). + +- Support connection service file to load connection details ([#1285]). + +- Support `sslrootcert=system` and use `~/.postgresql/root.crt` as the default + value of sslrootcert ([#1280], [#1281]). + +- Add a new `pqerror` package with PostgreSQL error codes ([#1275]). + + For example, to test if an error is a UNIQUE constraint violation: + + if pqErr, ok := errors.AsType[*pq.Error](err); ok && pqErr.Code == pqerror.UniqueViolation { + log.Fatalf("email %q already exsts", email) + } + + To make this a bit more convenient, it also adds a `pq.As()` function: + + pqErr := pq.As(err, pqerror.UniqueViolation) + if pqErr != nil { + log.Fatalf("email %q already exsts", email) + } + +### Fixes + +- Fix SSL key permission check to allow modes stricter than 0600/0640#1265 ([#1265]). + +- Fix Hstore to work with binary parameters ([#1278]). + +- Clearer error when starting a new query while pq is still processing another + query ([#1272]). + +- Send intermediate CAs with client certificates, so they can be signed by an + intermediate CA ([#1267]). + +- Use `time.UTC` for UTC aliases such as `Etc/UTC` ([#1282]). 
+ +[#1258]: https://github.com/lib/pq/pull/1258 +[#1265]: https://github.com/lib/pq/pull/1265 +[#1267]: https://github.com/lib/pq/pull/1267 +[#1270]: https://github.com/lib/pq/pull/1270 +[#1271]: https://github.com/lib/pq/pull/1271 +[#1272]: https://github.com/lib/pq/pull/1272 +[#1275]: https://github.com/lib/pq/pull/1275 +[#1277]: https://github.com/lib/pq/pull/1277 +[#1278]: https://github.com/lib/pq/pull/1278 +[#1279]: https://github.com/lib/pq/pull/1279 +[#1280]: https://github.com/lib/pq/pull/1280 +[#1281]: https://github.com/lib/pq/pull/1281 +[#1282]: https://github.com/lib/pq/pull/1282 +[#1283]: https://github.com/lib/pq/pull/1283 +[#1285]: https://github.com/lib/pq/pull/1285 + +v1.11.2 (2026-02-10) +-------------------- +This fixes two regressions: + +- Don't send startup parameters if there is no value, improving compatibility + with Supavisor ([#1260]). + +- Don't send `dbname` as a startup parameter if `database=[..]` is used in the + connection string. It's recommended to use dbname=, as database= is not a + libpq option, and only worked by accident previously. ([#1261]) + +[#1260]: https://github.com/lib/pq/pull/1260 +[#1261]: https://github.com/lib/pq/pull/1261 + +v1.11.1 (2026-01-29) +-------------------- +This fixes two regressions present in the v1.11.0 release: + +- Fix build on 32bit systems, Windows, and Plan 9 ([#1253]). + +- Named []byte types and pointers to []byte (e.g. `*[]byte`, `json.RawMessage`) + would be treated as an array instead of bytea ([#1252]). + +[#1252]: https://github.com/lib/pq/pull/1252 +[#1253]: https://github.com/lib/pq/pull/1253 + +v1.11.0 (2026-01-28) +-------------------- +This version of pq requires Go 1.21 or newer. + +pq now supports only maintained PostgreSQL releases, which is PostgreSQL 14 and +newer. Previously PostgreSQL 8.4 and newer were supported. 
+ +### Features + +- The `pq.Error.Error()` text includes the position of the error (if reported + by PostgreSQL) and SQLSTATE code ([#1219], [#1224]): + + pq: column "columndoesntexist" does not exist at column 8 (42703) + pq: syntax error at or near ")" at position 2:71 (42601) + +- The `pq.Error.ErrorWithDetail()` method prints a more detailed multiline + message, with the Detail, Hint, and error position (if any) ([#1219]): + + ERROR: syntax error at or near ")" (42601) + CONTEXT: line 12, column 1: + + 10 | name varchar, + 11 | version varchar, + 12 | ); + ^ + +- Add `Config`, `NewConfig()`, and `NewConnectorConfig()` to supply connection + details in a more structured way ([#1240]). + +- Support `hostaddr` and `$PGHOSTADDR` ([#1243]). + +- Support multiple values in `host`, `port`, and `hostaddr`, which are each + tried in order, or randomly if `load_balance_hosts=random` is set ([#1246]). + +- Support `target_session_attrs` connection parameter ([#1246]). + +- Support [`sslnegotiation`] to use SSL without negotiation ([#1180]). + +- Allow using a custom `tls.Config`, for example for encrypted keys ([#1228]). + +- Add `PQGO_DEBUG=1` print the communication with PostgreSQL to stderr, to aid + in debugging, testing, and bug reports ([#1223]). + +- Add support for NamedValueChecker interface ([#1125], [#1238]). + + +### Fixes + +- Match HOME directory lookup logic with libpq: prefer $HOME over /etc/passwd, + ignore ENOTDIR errors, and use APPDATA on Windows ([#1214]). + +- Fix `sslmode=verify-ca` verifying the hostname anyway when connecting to a DNS + name (rather than IP) ([#1226]). + +- Correctly detect pre-protocol errors such as the server not being able to fork + or running out of memory ([#1248]). + +- Fix build with wasm ([#1184]), appengine ([#745]), and Plan 9 ([#1133]). + +- Deprecate and type alias `pq.NullTime` to `sql.NullTime` ([#1211]). + +- Enforce integer limits of the Postgres wire protocol ([#1161]). 
+ +- Accept the `passfile` connection parameter to override `PGPASSFILE` ([#1129]). + +- Fix connecting to socket on Windows systems ([#1179]). + +- Don't perform a permission check on the .pgpass file on Windows ([#595]). + +- Warn about incorrect .pgpass permissions ([#595]). + +- Don't set extra_float_digits ([#1212]). + +- Decode bpchar into a string ([#949]). + +- Fix panic in Ping() by not requiring CommandComplete or EmptyQueryResponse in + simpleQuery() ([#1234]) + +- Recognize bit/varbit ([#743]) and float types ([#1166]) in ColumnTypeScanType(). + +- Accept `PGGSSLIB` and `PGKRBSRVNAME` environment variables ([#1143]). + +- Handle ErrorResponse in readReadyForQuery and return proper error ([#1136]). + +- Detect COPY even if the query starts with whitespace or comments ([#1198]). + +- CopyIn() and CopyInSchema() now work if the list of columns is empty, in which + case it will copy all columns ([#1239]). + +- Treat nil []byte in query parameters as nil/NULL rather than `""` ([#838]). + +- Accept multiple authentication methods before checking AuthOk, which improves + compatibility with PgPool-II ([#1188]). 
+ +[`sslnegotiation`]: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLNEGOTIATION +[#595]: https://github.com/lib/pq/pull/595 +[#745]: https://github.com/lib/pq/pull/745 +[#743]: https://github.com/lib/pq/pull/743 +[#838]: https://github.com/lib/pq/pull/838 +[#949]: https://github.com/lib/pq/pull/949 +[#1125]: https://github.com/lib/pq/pull/1125 +[#1129]: https://github.com/lib/pq/pull/1129 +[#1133]: https://github.com/lib/pq/pull/1133 +[#1136]: https://github.com/lib/pq/pull/1136 +[#1143]: https://github.com/lib/pq/pull/1143 +[#1161]: https://github.com/lib/pq/pull/1161 +[#1166]: https://github.com/lib/pq/pull/1166 +[#1179]: https://github.com/lib/pq/pull/1179 +[#1180]: https://github.com/lib/pq/pull/1180 +[#1184]: https://github.com/lib/pq/pull/1184 +[#1188]: https://github.com/lib/pq/pull/1188 +[#1198]: https://github.com/lib/pq/pull/1198 +[#1211]: https://github.com/lib/pq/pull/1211 +[#1212]: https://github.com/lib/pq/pull/1212 +[#1214]: https://github.com/lib/pq/pull/1214 +[#1219]: https://github.com/lib/pq/pull/1219 +[#1223]: https://github.com/lib/pq/pull/1223 +[#1224]: https://github.com/lib/pq/pull/1224 +[#1226]: https://github.com/lib/pq/pull/1226 +[#1228]: https://github.com/lib/pq/pull/1228 +[#1234]: https://github.com/lib/pq/pull/1234 +[#1238]: https://github.com/lib/pq/pull/1238 +[#1239]: https://github.com/lib/pq/pull/1239 +[#1240]: https://github.com/lib/pq/pull/1240 +[#1243]: https://github.com/lib/pq/pull/1243 +[#1246]: https://github.com/lib/pq/pull/1246 +[#1248]: https://github.com/lib/pq/pull/1248 + + +v1.10.9 (2023-04-26) +-------------------- +- Fixes backwards incompat bug with 1.13. + +- Fixes pgpass issue diff --git a/vendor/github.com/lib/pq/LICENSE b/vendor/github.com/lib/pq/LICENSE new file mode 100644 index 00000000..6a77dc4f --- /dev/null +++ b/vendor/github.com/lib/pq/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2011-2013, 'pq' Contributors. 
Portions Copyright (c) 2011 Blake Mizerany + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/lib/pq/README.md b/vendor/github.com/lib/pq/README.md new file mode 100644 index 00000000..5ca523d4 --- /dev/null +++ b/vendor/github.com/lib/pq/README.md @@ -0,0 +1,312 @@ +pq is a Go PostgreSQL driver for database/sql. + +All [maintained versions of PostgreSQL] are supported. Older versions may work, +but this is not tested. [API docs]. + +[maintained versions of PostgreSQL]: https://www.postgresql.org/support/versioning +[API docs]: https://pkg.go.dev/github.com/lib/pq + +Connecting +---------- +Use the `postgres` driver name in the `sql.Open()` call: + +```go +package main + +import ( + "database/sql" + "log" + + _ "github.com/lib/pq" // To register the driver. 
+) + +func main() { + // Or as URL: postgresql://localhost/pqgo + db, err := sql.Open("postgres", "host=localhost dbname=pqgo") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + // db.Open() only creates a connection pool, and doesn't actually establish + // a connection. To ensure the connection works you need to do *something* + // with a connection. + err = db.Ping() + if err != nil { + log.Fatal(err) + } +} +``` + +You can also use the `pq.Config` struct: + +```go +cfg := pq.Config{ + Host: "localhost", + Port: 5432, + User: "pqgo", +} +// Or: create a new Config from the defaults, environment, and DSN. +// cfg, err := pq.NewConfig("host=postgres dbname=pqgo") +// if err != nil { +// log.Fatal(err) +// } + +c, err := pq.NewConnectorConfig(cfg) +if err != nil { + log.Fatal(err) +} + +// Create connection pool. +db := sql.OpenDB(c) +defer db.Close() + +// Make sure it works. +err = db.Ping() +if err != nil { + log.Fatal(err) +} +``` + +The DSN is identical to PostgreSQL's libpq; most parameters are supported and +should behave identical. Both key=value and postgres:// URL-style connection +strings are supported. See the doc comments on the [Config struct] for the full +list and documentation. + +The most notable difference is that you can use any [run-time parameter] such as +`search_path` or `work_mem` in the connection string. This is different from +libpq, which uses the `options` parameter for this (which also works in pq). 
+ +For example: + + sql.Open("postgres", "dbname=pqgo work_mem=100kB search_path=xyz") + +The libpq way (which also works in pq) is to use `options='-c k=v'` like so: + + sql.Open("postgres", "dbname=pqgo options='-c work_mem=100kB -c search_path=xyz'") + +[Config struct]: https://pkg.go.dev/github.com/lib/pq#Config +[run-time parameter]: http://www.postgresql.org/docs/current/static/runtime-config.html + +Errors +------ +Errors from PostgreSQL are returned as [pq.Error]; [pq.As] can be used to +convert an error to `pq.Error`: + +```go +pqErr := pq.As(err, pqerror.UniqueViolation) +if pqErr != nil { + return fmt.Errorf("email %q already exsts", email) +} +``` + +the Error() string contains the error message and code: + + pq: duplicate key value violates unique constraint "users_lower_idx" (23505) + +The ErrorWithDetail() string also contains the DETAIL and CONTEXT fields, if +present. For example for the above error this helpfully contains the duplicate +value: + + ERROR: duplicate key value violates unique constraint "users_lower_idx" (23505) + DETAIL: Key (lower(email))=(a@example.com) already exists. + +Or for an invalid syntax error like this: + + pq: invalid input syntax for type json (22P02) + +It contains the context where this error occurred: + + ERROR: invalid input syntax for type json (22P02) + DETAIL: Token "asd" is invalid. + CONTEXT: line 5, column 8: + + 3 | 'def', + 4 | 123, + 5 | 'foo', 'asd'::jsonb + ^ + +[pq.Error]: https://pkg.go.dev/github.com/lib/pq#Error +[pq.As]: https://pkg.go.dev/github.com/lib/pq#As + +PostgreSQL features +------------------- + +### Authentication +pq supports PASSWORD, MD5, and SCRAM-SHA256 authentication out of the box. 
If +you need GSS/Kerberos authentication you'll need to import the `auth/kerberos` +module: package: + + import "github.com/lib/pq/auth/kerberos" + + func init() { + pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() }) + } + +This is in a separate module so that users who don't need Kerberos (i.e. most +users) don't have to add unnecessary dependencies. + +Reading a [password file] (pgpass) is also supported. + +[password file]: http://www.postgresql.org/docs/current/static/libpq-pgpass.html + +### Bulk imports with `COPY [..] FROM STDIN` +You can perform bulk imports by preparing a `COPY [..] FROM STDIN` statement +inside a transaction. The returned `sql.Stmt` can then be repeatedly executed to +copy data. After all data has been processed you should call Exec() once with no +arguments to flush all buffered data. + +[Further documentation][copy-doc] and [example][copy-ex]. + +[copy-doc]: https://pkg.go.dev/github.com/lib/pq#hdr-Bulk_imports +[copy-ex]: https://pkg.go.dev/github.com/lib/pq#example-package-CopyFromStdin + +### NOTICE errors +PostgreSQL has "NOTICE" errors for informational messages. For example from the +psql CLI: + + pqgo=# drop table if exists doesnotexist; + NOTICE: table "doesnotexist" does not exist, skipping + DROP TABLE + +These errors are not returned because they're not really errors but, well, +notices. + +You can register a callback for these notices with [ConnectorWithNoticeHandler] + +[ConnectorWithNoticeHandler]: https://pkg.go.dev/github.com/lib/pq#ConnectorWithNoticeHandler + +### Using `LISTEN`/`NOTIFY` +With [pq.Listener] notifications are send on a channel. 
For example: + +```go +l := pq.NewListener("dbname=pqgo", time.Second, time.Minute, nil) +defer l.Close() + +err := l.Listen("coconut") +if err != nil { + log.Fatal(err) +} + +for { + n := <-l.Notify: + if n == nil { + fmt.Println("nil notify: closing Listener") + return + } + fmt.Printf("notification on %q with data %q\n", n.Channel, n.Extra) +} +``` + +And you'll get a notification for every `notify coconut`. + +See the API docs for a more complete example. + +[pq.Listener]: https://pkg.go.dev/github.com/lib/pq#Listener + + +Caveats +------- +### LastInsertId +sql.Result.LastInsertId() is not supported, because the PostgreSQL protocol does +not have this facility. Use `insert [..] returning [cols]` instead: + + db.QueryRow(`insert into tbl [..] returning id_col`).Scan(..) + // Or multiple rows: + db.Query(`insert into tbl (row1), (row2) returning id_col`) + +This will also work in SQLite and MariaDB with the same syntax. MS-SQL and +Oracle have a similar facility (with a different syntax). + +### timestamps +For timestamps with a timezone (`timestamptz`/`timestamp with time zone`), pq +uses the timezone configured in the server, as libpq. You can change this with +`timestamp=[..]` in the connection string. It's generally recommended to use +UTC. + +For timestamps without a timezone (`timestamp`/`timestamp without time zone`), +pq always uses `time.FixedZone("", 0)` as the timezone; the timestamp parameter +has no effect here. This is intentionally not equal to time.UTC, as it's not a +UTC time: it's a time without a timezone. Go's time package does not really +support this concept, so this is the best we can do This will print `+0000` +twice (e.g. `2026-03-15 17:45:47 +0000 +0000`; having a clearer name would have +been better, but is not compatible change). See [this comment][ts] for some +options on how to deal with this. 
+ +Also see the examples for [timestamptz] and [timestamp] + +[ts]: https://github.com/lib/pq/issues/329#issuecomment-4025733506 +[timestamptz]: https://pkg.go.dev/github.com/lib/pq#example-package-TimestampWithTimezone +[timestamp]: https://pkg.go.dev/github.com/lib/pq#example-package-TimestampWithoutTimezone + +### bytea with copy +All `[]byte` parameters are encoded as `bytea` when using `copy [..] from +stdin`, which may result in errors for e.g. `jsonb` columns. The solution is to +use a string instead of []byte. See #1023 + +Development +----------- +### Running tests +Tests need to be run against a PostgreSQL database; you can use Docker compose +to start one: + + docker compose up -d + +This starts the latest PostgreSQL; use `docker compose up -d pg«v»` to start a +different version. + +In addition, your `/etc/hosts` needs an entry: + + 127.0.0.1 postgres postgres-invalid + +Or you can use any other PostgreSQL instance; see +`testdata/init/docker-entrypoint-initdb.d` for the required setup. You can use +the standard `PG*` environment variables to control the connection details; it +uses the following defaults: + + PGHOST=localhost + PGDATABASE=pqgo + PGUSER=pqgo + PGSSLMODE=disable + PGCONNECT_TIMEOUT=20 + +`PQTEST_BINARY_PARAMETERS` can be used to add `binary_parameters=yes` to all +connection strings: + + PQTEST_BINARY_PARAMETERS=1 go test + +Tests can be run against pgbouncer with: + + docker compose up -d pgbouncer pg18 + PGPORT=6432 go test ./... + +and pgpool with: + + docker compose up -d pgpool pg18 + PGPORT=7432 go test ./... + +### Protocol debug output +You can use PQGO_DEBUG=1 to make the driver print the communication with +PostgreSQL to stderr; this works anywhere (test or applications) and can be +useful to debug protocol problems. 
+ +For example: + + % PQGO_DEBUG=1 go test -run TestSimpleQuery + CLIENT → Startup 69 "\x00\x03\x00\x00database\x00pqgo\x00user [..]" + SERVER ← (R) AuthRequest 4 "\x00\x00\x00\x00" + SERVER ← (S) ParamStatus 19 "in_hot_standby\x00off\x00" + [..] + SERVER ← (Z) ReadyForQuery 1 "I" + START conn.query + START conn.simpleQuery + CLIENT → (Q) Query 9 "select 1\x00" + SERVER ← (T) RowDescription 29 "\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x17\x00\x04\xff\xff\xff\xff\x00\x00" + SERVER ← (D) DataRow 7 "\x00\x01\x00\x00\x00\x011" + END conn.simpleQuery + END conn.query + SERVER ← (C) CommandComplete 9 "SELECT 1\x00" + SERVER ← (Z) ReadyForQuery 1 "I" + CLIENT → (X) Terminate 0 "" + PASS + ok github.com/lib/pq 0.010s diff --git a/vendor/github.com/lib/pq/array.go b/vendor/github.com/lib/pq/array.go new file mode 100644 index 00000000..4a532868 --- /dev/null +++ b/vendor/github.com/lib/pq/array.go @@ -0,0 +1,903 @@ +package pq + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/hex" + "fmt" + "reflect" + "strconv" + "strings" +) + +var typeByteSlice = reflect.TypeOf([]byte{}) +var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() +var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +// Array returns the optimal driver.Valuer and sql.Scanner for an array or +// slice of any dimension. +// +// For example: +// +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// +// var x []sql.NullInt64 +// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) +// +// Scanning multi-dimensional arrays is not supported. Arrays where the lower +// bound is not one (such as `[0:0]={1}') are not supported. 
+func Array(a any) interface { + driver.Valuer + sql.Scanner +} { + switch a := a.(type) { + case []bool: + return (*BoolArray)(&a) + case []float64: + return (*Float64Array)(&a) + case []float32: + return (*Float32Array)(&a) + case []int64: + return (*Int64Array)(&a) + case []int32: + return (*Int32Array)(&a) + case []string: + return (*StringArray)(&a) + case [][]byte: + return (*ByteaArray)(&a) + + case *[]bool: + return (*BoolArray)(a) + case *[]float64: + return (*Float64Array)(a) + case *[]float32: + return (*Float32Array)(a) + case *[]int64: + return (*Int64Array)(a) + case *[]int32: + return (*Int32Array)(a) + case *[]string: + return (*StringArray)(a) + case *[][]byte: + return (*ByteaArray)(a) + } + + return GenericArray{a} +} + +// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner +// to override the array delimiter used by GenericArray. +type ArrayDelimiter interface { + // ArrayDelimiter returns the delimiter character(s) for this element's type. + ArrayDelimiter() string +} + +// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. +type BoolArray []bool + +// Scan implements the sql.Scanner interface. 
+func (a *BoolArray) Scan(src any) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to BoolArray", src) +} + +func (a *BoolArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "BoolArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(BoolArray, len(elems)) + for i, v := range elems { + if len(v) != 1 { + return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + } + switch v[0] { + case 't': + b[i] = true + case 'f': + b[i] = false + default: + return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a BoolArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be exactly two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1+2*n) + + for i := 0; i < n; i++ { + b[2*i] = ',' + if a[i] { + b[1+2*i] = 't' + } else { + b[1+2*i] = 'f' + } + } + + b[0] = '{' + b[2*n] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. +type ByteaArray [][]byte + +// Scan implements the sql.Scanner interface. 
+func (a *ByteaArray) Scan(src any) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) +} + +func (a *ByteaArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(ByteaArray, len(elems)) + for i, v := range elems { + b[i], err = parseBytea(v) + if err != nil { + return fmt.Errorf("could not parse bytea array index %d: %w", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. It uses the "hex" format which +// is only supported on PostgreSQL 9.0 or newer. +func (a ByteaArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // 3*N bytes of hex formatting, and N-1 bytes of delimiters. + size := 1 + 6*n + for _, x := range a { + size += hex.EncodedLen(len(x)) + } + + b := make([]byte, size) + + for i, s := 0, b; i < n; i++ { + o := copy(s, `,"\\x`) + o += hex.Encode(s[o:], a[i]) + s[o] = '"' + s = s[o+1:] + } + + b[0] = '{' + b[size-1] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// Float64Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float64Array []float64 + +// Scan implements the sql.Scanner interface. 
+func (a *Float64Array) Scan(src any) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float64Array", src) +} + +func (a *Float64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float64Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float64Array, len(elems)) + for i, v := range elems { + b[i], err = strconv.ParseFloat(string(v), 64) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, a[0], 'f', -1, 64) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, a[i], 'f', -1, 64) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// Float32Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float32Array []float32 + +// Scan implements the sql.Scanner interface. 
+func (a *Float32Array) Scan(src any) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float32Array", src) +} + +func (a *Float32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float32Array, len(elems)) + for i, v := range elems { + x, err := strconv.ParseFloat(string(v), 32) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) + } + b[i] = float32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// GenericArray implements the driver.Valuer and sql.Scanner interfaces for +// an array or slice of any dimension. +type GenericArray struct{ A any } + +func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { + var assign func([]byte, reflect.Value) error + var del = "," + + // TODO calculate the assign function for other types + // TODO repeat this section on the element type of arrays or slices (multidimensional) + { + if reflect.PointerTo(rt).Implements(typeSQLScanner) { + // dest is always addressable because it is an element of a slice. 
+ assign = func(src []byte, dest reflect.Value) (err error) { + ss := dest.Addr().Interface().(sql.Scanner) + if src == nil { + err = ss.Scan(nil) + } else { + err = ss.Scan(src) + } + return + } + goto FoundType + } + + assign = func([]byte, reflect.Value) error { + return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) + } + } + +FoundType: + + if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + return rt, assign, del +} + +// Scan implements the sql.Scanner interface. +func (a GenericArray) Scan(src any) error { + dpv := reflect.ValueOf(a.A) + switch { + case dpv.Kind() != reflect.Pointer: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + case dpv.IsNil(): + return fmt.Errorf("pq: destination %T is nil", a.A) + } + + dv := dpv.Elem() + switch dv.Kind() { + case reflect.Slice: + case reflect.Array: + default: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + } + + switch src := src.(type) { + case []byte: + return a.scanBytes(src, dv) + case string: + return a.scanBytes([]byte(src), dv) + case nil: + if dv.Kind() == reflect.Slice { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + } + + return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) +} + +func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { + dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) + dims, elems, err := parseArray(src, []byte(del)) + if err != nil { + return err + } + + // TODO allow multidimensional + + if len(dims) > 1 { + return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", + strings.Replace(fmt.Sprint(dims), " ", "][", -1)) + } + + // Treat a zero-dimensional array like an array with a single dimension of zero. 
+ if len(dims) == 0 { + dims = append(dims, 0) + } + + for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { + switch rt.Kind() { + case reflect.Slice: + case reflect.Array: + if rt.Len() != dims[i] { + return fmt.Errorf("pq: cannot convert ARRAY%s to %s", + strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) + } + default: + // TODO handle multidimensional + } + } + + values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) + for i, e := range elems { + err := assign(e, values.Index(i)) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) + } + } + + // TODO handle multidimensional + + switch dv.Kind() { + case reflect.Slice: + dv.Set(values.Slice(0, dims[0])) + case reflect.Array: + for i := 0; i < dims[0]; i++ { + dv.Index(i).Set(values.Index(i)) + } + } + + return nil +} + +// Value implements the driver.Valuer interface. +func (a GenericArray) Value() (driver.Value, error) { + if a.A == nil { + return nil, nil + } + + rv := reflect.ValueOf(a.A) + + switch rv.Kind() { + case reflect.Slice: + if rv.IsNil() { + return nil, nil + } + case reflect.Array: + default: + return nil, fmt.Errorf("pq: unable to convert %T to array", a.A) + } + + if n := rv.Len(); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 0, 1+2*n) + + b, _, err := appendArray(b, rv, n) + return string(b), err + } + + return "{}", nil +} + +// Int64Array represents a one-dimensional array of the PostgreSQL integer types. +type Int64Array []int64 + +// Scan implements the sql.Scanner interface. 
+func (a *Int64Array) Scan(src any) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int64Array", src) +} + +func (a *Int64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int64Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int64Array, len(elems)) + for i, v := range elems { + b[i], err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, a[0], 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, a[i], 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// Int32Array represents a one-dimensional array of the PostgreSQL integer types. +type Int32Array []int32 + +// Scan implements the sql.Scanner interface. 
+func (a *Int32Array) Scan(src any) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int32Array", src) +} + +func (a *Int32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int32Array, len(elems)) + for i, v := range elems { + x, err := strconv.ParseInt(string(v), 10, 32) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) + } + b[i] = int32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, int64(a[0]), 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, int64(a[i]), 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// StringArray represents a one-dimensional array of the PostgreSQL character types. +type StringArray []string + +// Scan implements the sql.Scanner interface. 
+func (a *StringArray) Scan(src any) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to StringArray", src) +} + +func (a *StringArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "StringArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(StringArray, len(elems)) + for i, v := range elems { + if b[i] = string(v); v == nil { + return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a StringArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+3*n) + b[0] = '{' + + b = appendArrayQuotedBytes(b, []byte(a[0])) + for i := 1; i < n; i++ { + b = append(b, ',') + b = appendArrayQuotedBytes(b, []byte(a[i])) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// appendArray appends rv to the buffer, returning the extended buffer and the +// delimiter used between elements. +// +// Returns an error when n <= 0 or rv is not a reflect.Array or reflect.Slice. +func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { + var del string + var err error + + b = append(b, '{') + + if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { + return b, del, err + } + + for i := 1; i < n; i++ { + b = append(b, del...) 
+ if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { + return b, del, err + } + } + + return append(b, '}'), del, nil +} + +// appendArrayElement appends rv to the buffer, returning the extended buffer +// and the delimiter to use before the next element. +// +// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted +// using driver.DefaultParameterConverter and the resulting []byte or string +// is double-quoted. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { + if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { + if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { + if n := rv.Len(); n > 0 { + return appendArray(b, rv, n) + } + + return b, "", nil + } + } + + var del = "," + var err error + var iv = rv.Interface() + + if ad, ok := iv.(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { + return b, del, err + } + + switch v := iv.(type) { + case nil: + return append(b, "NULL"...), del, nil + case []byte: + return appendArrayQuotedBytes(b, v), del, nil + case string: + return appendArrayQuotedBytes(b, []byte(v)), del, nil + } + + b, err = appendValue(b, iv) + return b, del, err +} + +func appendArrayQuotedBytes(b, v []byte) []byte { + b = append(b, '"') + for { + i := bytes.IndexAny(v, `"\`) + if i < 0 { + b = append(b, v...) + break + } + if i > 0 { + b = append(b, v[:i]...) + } + b = append(b, '\\', v[i]) + v = v[i+1:] + } + return append(b, '"') +} + +func appendValue(b []byte, v driver.Value) ([]byte, error) { + enc, err := encode(v, 0) + if err != nil { + return nil, err + } + return append(b, enc...), nil +} + +// parseArray extracts the dimensions and elements of an array represented in +// text format. Only representations emitted by the backend are supported. 
+// Notably, whitespace around brackets and delimiters is significant, and NULL +// is case-sensitive. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { + var depth, i int + + if len(src) < 1 || src[0] != '{' { + return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0) + } + +Open: + for i < len(src) { + switch src[i] { + case '{': + depth++ + i++ + case '}': + elems = make([][]byte, 0) + goto Close + default: + break Open + } + } + dims = make([]int, i) + +Element: + for i < len(src) { + switch src[i] { + case '{': + if depth == len(dims) { + break Element + } + depth++ + dims[depth-1] = 0 + i++ + case '"': + var elem = []byte{} + var escape bool + for i++; i < len(src); i++ { + if escape { + elem = append(elem, src[i]) + escape = false + } else { + switch src[i] { + default: + elem = append(elem, src[i]) + case '\\': + escape = true + case '"': + elems = append(elems, elem) + i++ + break Element + } + } + } + default: + for start := i; i < len(src); i++ { + if bytes.HasPrefix(src[i:], del) || src[i] == '}' { + elem := src[start:i] + if len(elem) == 0 { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + if bytes.Equal(elem, []byte("NULL")) { + elem = nil + } + elems = append(elems, elem) + break Element + } + } + } + } + + for i < len(src) { + if bytes.HasPrefix(src[i:], del) && depth > 0 { + dims[depth-1]++ + i += len(del) + goto Element + } else if src[i] == '}' && depth > 0 { + dims[depth-1]++ + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + +Close: + for i < len(src) { + if src[i] == '}' && depth > 0 { + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + if depth > 0 { + err = fmt.Errorf("pq: unable to 
parse array; expected %q at offset %d", '}', i) + } + if err == nil { + for _, d := range dims { + if (len(elems) % d) != 0 { + err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions") + } + } + } + return +} + +func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { + dims, elems, err := parseArray(src, del) + if err != nil { + return nil, err + } + if len(dims) > 1 { + return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) + } + return elems, err +} diff --git a/vendor/github.com/lib/pq/as.go b/vendor/github.com/lib/pq/as.go new file mode 100644 index 00000000..1ea0ee5b --- /dev/null +++ b/vendor/github.com/lib/pq/as.go @@ -0,0 +1,26 @@ +//go:build !go1.26 + +package pq + +import ( + "errors" + "slices" +) + +// As asserts that the given error is [pq.Error] and returns it, returning nil +// if it's not a pq.Error. +// +// It will return nil if the pq.Error is not one of the given error codes. If no +// codes are given it will always return the Error. +// +// This is safe to call with a nil error. +func As(err error, codes ...ErrorCode) *Error { + if err == nil { // Not strictly needed, but prevents alloc for nil errors. + return nil + } + pqErr := new(Error) + if errors.As(err, &pqErr) && (len(codes) == 0 || slices.Contains(codes, pqErr.Code)) { + return pqErr + } + return nil +} diff --git a/vendor/github.com/lib/pq/as_go126.go b/vendor/github.com/lib/pq/as_go126.go new file mode 100644 index 00000000..18ffbc37 --- /dev/null +++ b/vendor/github.com/lib/pq/as_go126.go @@ -0,0 +1,23 @@ +//go:build go1.26 + +package pq + +import ( + "errors" + "github.com/lib/pq/pqerror" + "slices" +) + +// As asserts that the given error is [pq.Error] and returns it, returning nil +// if it's not a pq.Error. +// +// It will return nil if the pq.Error is not one of the given error codes. If no +// codes are given it will always return the Error. 
+// +// This is safe to call with a nil error. +func As(err error, codes ...pqerror.Code) *Error { + if pqErr, ok := errors.AsType[*Error](err); ok && (len(codes) == 0 || slices.Contains(codes, pqErr.Code)) { + return pqErr + } + return nil +} diff --git a/vendor/github.com/lib/pq/buf.go b/vendor/github.com/lib/pq/buf.go new file mode 100644 index 00000000..67ca60cc --- /dev/null +++ b/vendor/github.com/lib/pq/buf.go @@ -0,0 +1,100 @@ +package pq + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/lib/pq/internal/proto" + "github.com/lib/pq/oid" +) + +type readBuf []byte + +func (b *readBuf) int32() (n int) { + n = int(int32(binary.BigEndian.Uint32(*b))) + *b = (*b)[4:] + return +} + +func (b *readBuf) oid() (n oid.Oid) { + n = oid.Oid(binary.BigEndian.Uint32(*b)) + *b = (*b)[4:] + return +} + +// N.B: this is actually an unsigned 16-bit integer, unlike int32 +func (b *readBuf) int16() (n int) { + n = int(binary.BigEndian.Uint16(*b)) + *b = (*b)[2:] + return +} + +func (b *readBuf) string() string { + i := bytes.IndexByte(*b, 0) + if i < 0 { + panic(errors.New("pq: invalid message format; expected string terminator")) + } + s := (*b)[:i] + *b = (*b)[i+1:] + return string(s) +} + +func (b *readBuf) next(n int) (v []byte) { + v = (*b)[:n] + *b = (*b)[n:] + return +} + +func (b *readBuf) byte() byte { + return b.next(1)[0] +} + +type writeBuf struct { + buf []byte + pos int +} + +func (b *writeBuf) int32(n int) { + x := make([]byte, 4) + binary.BigEndian.PutUint32(x, uint32(n)) + b.buf = append(b.buf, x...) +} + +func (b *writeBuf) int16(n int) { + x := make([]byte, 2) + binary.BigEndian.PutUint16(x, uint16(n)) + b.buf = append(b.buf, x...) +} + +func (b *writeBuf) string(s string) { + b.buf = append(append(b.buf, s...), '\000') +} + +func (b *writeBuf) byte(c proto.RequestCode) { + b.buf = append(b.buf, byte(c)) +} + +func (b *writeBuf) bytes(v []byte) { + b.buf = append(b.buf, v...) 
+} + +func (b *writeBuf) wrap() []byte { + p := b.buf[b.pos:] + if len(p) > proto.MaxUint32 { + panic(fmt.Errorf("pq: message too large (%d > math.MaxUint32)", len(p))) + } + binary.BigEndian.PutUint32(p, uint32(len(p))) + return b.buf +} + +func (b *writeBuf) next(c proto.RequestCode) { + p := b.buf[b.pos:] + if len(p) > proto.MaxUint32 { + panic(fmt.Errorf("pq: message too large (%d > math.MaxUint32)", len(p))) + } + binary.BigEndian.PutUint32(p, uint32(len(p))) + b.pos = len(b.buf) + 1 + b.buf = append(b.buf, byte(c), 0, 0, 0, 0) +} diff --git a/vendor/github.com/lib/pq/compose.yaml b/vendor/github.com/lib/pq/compose.yaml new file mode 100644 index 00000000..254344ed --- /dev/null +++ b/vendor/github.com/lib/pq/compose.yaml @@ -0,0 +1,81 @@ +name: 'pqgo' + +services: + pgbouncer: + profiles: ['pgbouncer'] + image: 'cleanstart/pgbouncer:latest' + ports: ['127.0.0.1:6432:6432'] + command: ['/init/pgbouncer.ini'] + volumes: ['./testdata/init:/init'] + environment: + 'PGBOUNCER_DATABASE': 'pqgo' + + pgpool: + profiles: ['pgpool'] + image: 'pgpool/pgpool:4.4.3' + ports: ['127.0.0.1:7432:7432'] + volumes: ['./testdata/init:/init'] + entrypoint: '/init/entry-pgpool.sh' + environment: + 'PGPOOL_PARAMS_PORT': '7432' + 'PGPOOL_PARAMS_BACKEND_HOSTNAME0': 'pg18' + + pg18: + image: 'postgres:18' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/init:/init'] + shm_size: '128mb' + healthcheck: {test: ['CMD-SHELL', 'pg_isready', '-U', 'pqgo', '-d', 'pqgo'], start_period: '30s', start_interval: '1s'} + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg17: + profiles: ['pg17'] + image: 'postgres:17' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/init:/init'] + shm_size: '128mb' + healthcheck: {test: ['CMD-SHELL', 'pg_isready', '-U', 'pqgo', '-d', 'pqgo'], start_period: '30s', start_interval: '1s'} + environment: + 'POSTGRES_DATABASE': 'pqgo' + 
'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg16: + profiles: ['pg16'] + image: 'postgres:16' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/init:/init'] + shm_size: '128mb' + healthcheck: {test: ['CMD-SHELL', 'pg_isready', '-U', 'pqgo', '-d', 'pqgo'], start_period: '30s', start_interval: '1s'} + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg15: + profiles: ['pg15'] + image: 'postgres:15' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/init:/init'] + shm_size: '128mb' + healthcheck: {test: ['CMD-SHELL', 'pg_isready', '-U', 'pqgo', '-d', 'pqgo'], start_period: '30s', start_interval: '1s'} + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg14: + profiles: ['pg14'] + image: 'postgres:14' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/init:/init'] + shm_size: '128mb' + healthcheck: {test: ['CMD-SHELL', 'pg_isready', '-U', 'pqgo', '-d', 'pqgo'], start_period: '30s', start_interval: '1s'} + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go new file mode 100644 index 00000000..d9adc31c --- /dev/null +++ b/vendor/github.com/lib/pq/conn.go @@ -0,0 +1,1817 @@ +package pq + +import ( + "bufio" + "context" + "crypto/md5" + "crypto/sha256" + "database/sql" + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "net" + "os" + "reflect" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/lib/pq/internal/pgpass" + "github.com/lib/pq/internal/pqsql" + "github.com/lib/pq/internal/pqutil" + "github.com/lib/pq/internal/proto" + "github.com/lib/pq/oid" + "github.com/lib/pq/scram" +) + +// Common error types +var ( + ErrNotSupported = errors.New("pq: unsupported 
command") + ErrInFailedTransaction = errors.New("pq: could not complete operation in a failed transaction") + ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") + ErrCouldNotDetectUsername = errors.New("pq: could not detect default username; please provide one explicitly") + ErrSSLKeyUnknownOwnership = pqutil.ErrSSLKeyUnknownOwnership + ErrSSLKeyHasWorldPermissions = pqutil.ErrSSLKeyHasWorldPermissions + + errQueryInProgress = errors.New("pq: there is already a query being processed on this connection") + errUnexpectedReady = errors.New("unexpected ReadyForQuery") + errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") + errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") +) + +// Compile time validation that our types implement the expected interfaces +var ( + _ driver.Driver = Driver{} + _ driver.ConnBeginTx = (*conn)(nil) + _ driver.ConnPrepareContext = (*conn)(nil) + _ driver.Execer = (*conn)(nil) //lint:ignore SA1019 x + _ driver.ExecerContext = (*conn)(nil) + _ driver.NamedValueChecker = (*conn)(nil) + _ driver.Pinger = (*conn)(nil) + _ driver.Queryer = (*conn)(nil) //lint:ignore SA1019 x + _ driver.QueryerContext = (*conn)(nil) + _ driver.SessionResetter = (*conn)(nil) + _ driver.Validator = (*conn)(nil) + _ driver.StmtExecContext = (*stmt)(nil) + _ driver.StmtQueryContext = (*stmt)(nil) +) + +func init() { + sql.Register("postgres", &Driver{}) +} + +var debugProto = func() bool { + // Check for exactly "1" (rather than mere existence) so we can add + // options/flags in the future. I don't know if we ever want that, but it's + // nice to leave the option open. + return os.Getenv("PQGO_DEBUG") == "1" +}() + +// Driver is the Postgres database driver. +type Driver struct{} + +// Open opens a new connection to the database. name is a connection string. +// Most users should only use it through database/sql package from the standard +// library. 
+func (d Driver) Open(name string) (driver.Conn, error) { + return Open(name) +} + +// Parameters sent by PostgreSQL on startup. +type parameterStatus struct { + serverVersion int + currentLocation *time.Location + inHotStandby, defaultTransactionReadOnly sql.NullBool +} + +type format int + +const ( + formatText format = 0 + formatBinary format = 1 +) + +var ( + // One result-column format code with the value 1 (i.e. all binary). + colFmtDataAllBinary = []byte{0, 1, 0, 1} + + // No result-column format codes (i.e. all text). + colFmtDataAllText = []byte{0, 0} +) + +type transactionStatus byte + +const ( + txnStatusIdle transactionStatus = 'I' + txnStatusIdleInTransaction transactionStatus = 'T' + txnStatusInFailedTransaction transactionStatus = 'E' +) + +func (s transactionStatus) String() string { + switch s { + case txnStatusIdle: + return "idle" + case txnStatusIdleInTransaction: + return "idle in transaction" + case txnStatusInFailedTransaction: + return "in a failed transaction" + default: + panic(fmt.Sprintf("pq: unknown transactionStatus %d", s)) + } +} + +// Dialer is the dialer interface. It can be used to obtain more control over +// how pq creates network connections. +type Dialer interface { + Dial(network, address string) (net.Conn, error) + DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) +} + +// DialerContext is the context-aware dialer interface. 
+type DialerContext interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +type defaultDialer struct { + d net.Dialer +} + +func (d defaultDialer) Dial(network, address string) (net.Conn, error) { + return d.d.Dial(network, address) +} + +func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return d.DialContext(ctx, network, address) +} + +func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.d.DialContext(ctx, network, address) +} + +type conn struct { + c net.Conn + buf *bufio.Reader + namei int + scratch [512]byte + txnStatus transactionStatus + txnFinish func() + + // Save connection arguments to use during CancelRequest. + dialer Dialer + cfg Config + parameterStatus parameterStatus + + saveMessageType proto.ResponseCode + saveMessageBuffer []byte + + // If an error is set this connection is bad and all public-facing + // functions should return the appropriate error by calling get() + // (ErrBadConn) or getForNext(). + err syncErr + + secretKey []byte // Cancellation key for CancelRequest messages. + pid int // Cancellation PID. + inProgress atomic.Bool // This connection is in the middle of a processing a request. + noticeHandler func(*Error) // If not nil, notices will be synchronously sent here + notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here + gss GSS // GSSAPI context +} + +type syncErr struct { + err error + sync.Mutex +} + +// Return ErrBadConn if connection is bad. +func (e *syncErr) get() error { + e.Lock() + defer e.Unlock() + if e.err != nil { + return driver.ErrBadConn + } + return nil +} + +// Return the error set on the connection. Currently only used by rows.Next. 
+func (e *syncErr) getForNext() error { + e.Lock() + defer e.Unlock() + return e.err +} + +// Set error, only if it isn't set yet. +func (e *syncErr) set(err error) { + if err == nil { + panic("attempt to set nil err") + } + e.Lock() + defer e.Unlock() + if e.err == nil { + e.err = err + } +} + +func (cn *conn) writeBuf(b proto.RequestCode) *writeBuf { + cn.scratch[0] = byte(b) + return &writeBuf{ + buf: cn.scratch[:5], + pos: 1, + } +} + +// Open opens a new connection to the database. dsn is a connection string. Most +// users should only use it through database/sql package from the standard +// library. +func Open(dsn string) (_ driver.Conn, err error) { + return DialOpen(defaultDialer{}, dsn) +} + +// DialOpen opens a new connection to the database using a dialer. +func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { + c, err := NewConnector(dsn) + if err != nil { + return nil, err + } + c.Dialer(d) + return c.open(context.Background()) +} + +func (c *Connector) open(ctx context.Context) (*conn, error) { + tsa := c.cfg.TargetSessionAttrs +restartAll: + var ( + errs []error + app = func(err error, cfg Config) bool { + if err != nil { + if debugProto { + fmt.Fprintln(os.Stderr, "CONNECT (error)", err) + } + errs = append(errs, fmt.Errorf("connecting to %s:%d: %w", cfg.Host, cfg.Port, err)) + } + return err != nil + } + ) + for _, cfg := range c.cfg.hosts() { + mode := cfg.SSLMode + if mode == "" { + mode = SSLModePrefer + } + restartHost: + if debugProto { + fmt.Fprintln(os.Stderr, "CONNECT ", cfg.string()) + } + + cfg.SSLMode = mode + cn := &conn{cfg: cfg, dialer: c.dialer} + cn.cfg.Password = pgpass.PasswordFromPgpass(cn.cfg.Passfile, cn.cfg.User, cn.cfg.Password, + cn.cfg.Host, strconv.Itoa(int(cn.cfg.Port)), cn.cfg.Database, cn.cfg.isset("password")) + + var err error + cn.c, err = dial(ctx, c.dialer, cn.cfg) + if app(err, cfg) { + continue + } + + err = cn.ssl(cn.cfg, mode) + if err != nil && mode == SSLModePrefer { + mode = SSLModeDisable + goto 
restartHost + } + if app(err, cfg) { + if cn.c != nil { + _ = cn.c.Close() + } + continue + } + + cn.buf = bufio.NewReader(cn.c) + err = cn.startup(cn.cfg) + if err != nil && mode == SSLModeAllow { + mode = SSLModeRequire + goto restartHost + } + if app(err, cfg) { + _ = cn.c.Close() + continue + } + + // Reset the deadline, in case one was set (see dial) + if cn.cfg.ConnectTimeout > 0 { + err := cn.c.SetDeadline(time.Time{}) + if app(err, cfg) { + _ = cn.c.Close() + continue + } + } + + err = cn.checkTSA(tsa) + if app(err, cfg) { + _ = cn.c.Close() + continue + } + + return cn, nil + } + + // target_session_attrs=prefer-standby is treated as standby in checkTSA; we + // ran out of hosts so none are on standby. Clear the setting and try again. + if c.cfg.TargetSessionAttrs == TargetSessionAttrsPreferStandby { + tsa = TargetSessionAttrsAny + goto restartAll + } + + if len(c.cfg.Multi) == 0 { + // Remove the "connecting to [..]" when we have just one host, so the + // error is identical to what we had before. 
+ return nil, errors.Unwrap(errs[0]) + } + return nil, fmt.Errorf("pq: could not connect to any of the hosts:\n%w", errors.Join(errs...)) +} + +func (cn *conn) getBool(query string) (bool, error) { + res, err := cn.simpleQuery(query) + if err != nil { + return false, err + } + defer res.Close() + + v := make([]driver.Value, 1) + err = res.Next(v) + if err != nil { + return false, err + } + + switch vv := v[0].(type) { + default: + return false, fmt.Errorf("parseBool: unknown type %T: %[1]v", v[0]) + case bool: + return vv, nil + case string: + vv, ok := v[0].(string) + if !ok { + return false, err + } + return vv == "on", nil + } +} + +func (cn *conn) checkTSA(tsa TargetSessionAttrs) error { + var ( + geths = func() (hs bool, err error) { + hs = cn.parameterStatus.inHotStandby.Bool + if !cn.parameterStatus.inHotStandby.Valid { + hs, err = cn.getBool("select pg_catalog.pg_is_in_recovery()") + } + return hs, err + } + getro = func() (ro bool, err error) { + ro = cn.parameterStatus.defaultTransactionReadOnly.Bool + if !cn.parameterStatus.defaultTransactionReadOnly.Valid { + ro, err = cn.getBool("show transaction_read_only") + } + return ro, err + } + ) + + switch tsa { + default: + panic("unreachable") + case "", TargetSessionAttrsAny: + return nil + case TargetSessionAttrsReadWrite, TargetSessionAttrsReadOnly: + readonly, err := getro() + if err != nil { + return err + } + if !cn.parameterStatus.defaultTransactionReadOnly.Valid { + var err error + readonly, err = cn.getBool("show transaction_read_only") + if err != nil { + return err + } + } + switch { + case tsa == TargetSessionAttrsReadOnly && !readonly: + return errors.New("session is not read-only") + case tsa == TargetSessionAttrsReadWrite: + if readonly { + return errors.New("session is read-only") + } + hs, err := geths() + if err != nil { + return err + } + if hs { + return errors.New("server is in hot standby mode") + } + return nil + default: + return nil + } + case TargetSessionAttrsPrimary, 
TargetSessionAttrsStandby, TargetSessionAttrsPreferStandby: + hs, err := geths() + if err != nil { + return err + } + switch { + case (tsa == TargetSessionAttrsStandby || tsa == TargetSessionAttrsPreferStandby) && !hs: + return errors.New("server is not in hot standby mode") + case tsa == TargetSessionAttrsPrimary && hs: + return errors.New("server is in hot standby mode") + default: + return nil + } + } +} + +func dial(ctx context.Context, d Dialer, cfg Config) (net.Conn, error) { + network, address := cfg.network() + + // Zero or not specified means wait indefinitely. + if cfg.ConnectTimeout > 0 { + // connect_timeout should apply to the entire connection establishment + // procedure, so we both use a timeout for the TCP connection + // establishment and set a deadline for doing the initial handshake. The + // deadline is then reset after startup() is done. + var ( + deadline = time.Now().Add(cfg.ConnectTimeout) + conn net.Conn + err error + ) + if dctx, ok := d.(DialerContext); ok { + ctx, cancel := context.WithTimeout(ctx, cfg.ConnectTimeout) + defer cancel() + conn, err = dctx.DialContext(ctx, network, address) + } else { + conn, err = d.DialTimeout(network, address, cfg.ConnectTimeout) + } + if err != nil { + return nil, err + } + err = conn.SetDeadline(deadline) + return conn, err + } + if dctx, ok := d.(DialerContext); ok { + return dctx.DialContext(ctx, network, address) + } + return d.Dial(network, address) +} + +func (cn *conn) isInTransaction() bool { + return cn.txnStatus == txnStatusIdleInTransaction || + cn.txnStatus == txnStatusInFailedTransaction +} + +func (cn *conn) checkIsInTransaction(intxn bool) error { + if cn.isInTransaction() != intxn { + cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected transaction status %v", cn.txnStatus) + } + return nil +} + +func (cn *conn) Begin() (_ driver.Tx, err error) { + return cn.begin("") +} + +func (cn *conn) begin(mode string) (_ driver.Tx, err error) { + if err := cn.err.get(); err != nil { 
+ return nil, err + } + if err := cn.checkIsInTransaction(false); err != nil { + return nil, err + } + + _, commandTag, err := cn.simpleExec("BEGIN" + mode) + if err != nil { + return nil, cn.handleError(err) + } + if commandTag != "BEGIN" { + cn.err.set(driver.ErrBadConn) + return nil, fmt.Errorf("unexpected command tag %s", commandTag) + } + if cn.txnStatus != txnStatusIdleInTransaction { + cn.err.set(driver.ErrBadConn) + return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) + } + return cn, nil +} + +func (cn *conn) closeTxn() { + if finish := cn.txnFinish; finish != nil { + finish() + } +} + +func (cn *conn) Commit() error { + defer cn.closeTxn() + if err := cn.err.get(); err != nil { + return err + } + if err := cn.checkIsInTransaction(true); err != nil { + return err + } + + // We don't want the client to think that everything is okay if it tries + // to commit a failed transaction. However, no matter what we return, + // database/sql will release this connection back into the free connection + // pool so we have to abort the current transaction here. Note that you + // would get the same behaviour if you issued a COMMIT in a failed + // transaction, so it's also the least surprising thing to do here. 
+ if cn.txnStatus == txnStatusInFailedTransaction { + if err := cn.rollback(); err != nil { + return err + } + return ErrInFailedTransaction + } + + _, commandTag, err := cn.simpleExec("COMMIT") + if err != nil { + if cn.isInTransaction() { + cn.err.set(driver.ErrBadConn) + } + return cn.handleError(err) + } + if commandTag != "COMMIT" { + cn.err.set(driver.ErrBadConn) + return fmt.Errorf("unexpected command tag %s", commandTag) + } + return cn.checkIsInTransaction(false) +} + +func (cn *conn) Rollback() error { + defer cn.closeTxn() + if err := cn.err.get(); err != nil { + return err + } + + err := cn.rollback() + if err != nil { + return cn.handleError(err) + } + return nil +} + +func (cn *conn) rollback() (err error) { + if err := cn.checkIsInTransaction(true); err != nil { + return err + } + + _, commandTag, err := cn.simpleExec("ROLLBACK") + if err != nil { + if cn.isInTransaction() { + cn.err.set(driver.ErrBadConn) + } + return err + } + if commandTag != "ROLLBACK" { + return fmt.Errorf("unexpected command tag %s", commandTag) + } + return cn.checkIsInTransaction(false) +} + +func (cn *conn) gname() string { + cn.namei++ + return strconv.FormatInt(int64(cn.namei), 10) +} + +func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, resErr error) { + if debugProto { + fmt.Fprintln(os.Stderr, " START conn.simpleExec") + defer fmt.Fprintln(os.Stderr, " END conn.simpleExec") + } + + b := cn.writeBuf(proto.Query) + b.string(q) + err := cn.send(b) + if err != nil { + return nil, "", err + } + + for { + t, r, err := cn.recv1() + if err != nil { + return nil, "", err + } + switch t { + case proto.CommandComplete: + res, commandTag, err = cn.parseComplete(r.string()) + if err != nil { + return nil, "", err + } + case proto.ReadyForQuery: + cn.processReadyForQuery(r) + if res == nil && resErr == nil { + resErr = errUnexpectedReady + } + return res, commandTag, resErr + case proto.ErrorResponse: + resErr = parseError(r, q) + case 
proto.EmptyQueryResponse: + res = emptyRows + case proto.RowDescription, proto.DataRow: + // ignore any results + default: + cn.err.set(driver.ErrBadConn) + return nil, "", fmt.Errorf("pq: unknown response for simple query: %q", t) + } + } +} + +func (cn *conn) simpleQuery(q string) (*rows, error) { + if debugProto { + fmt.Fprintln(os.Stderr, " START conn.simpleQuery") + defer fmt.Fprintln(os.Stderr, " END conn.simpleQuery") + } + + b := cn.writeBuf(proto.Query) + b.string(q) + err := cn.send(b) + if err != nil { + return nil, cn.handleError(err, q) + } + + var ( + res *rows + resErr error + ) + for { + t, r, err := cn.recv1() + if err != nil { + return nil, cn.handleError(err, q) + } + switch t { + case proto.CommandComplete, proto.EmptyQueryResponse: + // We allow queries which don't return any results through Query as + // well as Exec. We still have to give database/sql a rows object + // the user can close, though, to avoid connections from being + // leaked. A "rows" with done=true works fine for that purpose. + if resErr != nil { + cn.err.set(driver.ErrBadConn) + return nil, fmt.Errorf("pq: unexpected message %q in simple query execution", t) + } + if res == nil { + res = &rows{cn: cn} + } + // Set the result and tag to the last command complete if there wasn't a + // query already run. Although queries usually return from here and cede + // control to Next, a query with zero results does not. 
+ if t == proto.CommandComplete { + res.result, res.tag, err = cn.parseComplete(r.string()) + if err != nil { + return nil, cn.handleError(err, q) + } + if res.colNames != nil { + return res, cn.handleError(resErr, q) + } + } + res.done = true + case proto.ReadyForQuery: + cn.processReadyForQuery(r) + if err == nil && res == nil { + res = &rows{done: true} + } + return res, cn.handleError(resErr, q) // done + case proto.ErrorResponse: + res = nil + resErr = parseError(r, q) + case proto.DataRow: + if res == nil { + cn.err.set(driver.ErrBadConn) + return nil, fmt.Errorf("pq: unexpected DataRow in simple query execution") + } + return res, cn.saveMessage(t, r) // The query didn't fail; kick off to Next + case proto.RowDescription: + // res might be non-nil here if we received a previous + // CommandComplete, but that's fine and just overwrite it. + res = &rows{cn: cn, rowsHeader: parsePortalRowDescribe(r)} + + // To work around a bug in QueryRow in Go 1.2 and earlier, wait + // until the first DataRow has been received. + default: + cn.err.set(driver.ErrBadConn) + return nil, fmt.Errorf("pq: unknown response for simple query: %q", t) + } + } +} + +// Decides which column formats to use for a prepared statement. The input is +// an array of type oids, one element per result column. +func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte, _ error) { + if len(colTyps) == 0 { + return nil, colFmtDataAllText, nil + } + + colFmts = make([]format, len(colTyps)) + if forceText { + return colFmts, colFmtDataAllText, nil + } + + allBinary := true + allText := true + for i, t := range colTyps { + switch t.OID { + // This is the list of types to use binary mode for when receiving them + // through a prepared statement. If a type appears in this list, it + // must also be implemented in binaryDecode in encode.go. 
+ case oid.T_bytea: + fallthrough + case oid.T_int8: + fallthrough + case oid.T_int4: + fallthrough + case oid.T_int2: + fallthrough + case oid.T_uuid: + colFmts[i] = formatBinary + allText = false + default: + allBinary = false + } + } + + if allBinary { + return colFmts, colFmtDataAllBinary, nil + } else if allText { + return colFmts, colFmtDataAllText, nil + } else { + colFmtData = make([]byte, 2+len(colFmts)*2) + if len(colFmts) > math.MaxUint16 { + return nil, nil, fmt.Errorf("pq: too many columns (%d > math.MaxUint16)", len(colFmts)) + } + binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) + for i, v := range colFmts { + binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) + } + return colFmts, colFmtData, nil + } +} + +func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) { + if debugProto { + fmt.Fprintln(os.Stderr, " START conn.prepareTo") + defer fmt.Fprintln(os.Stderr, " END conn.prepareTo") + } + + st := &stmt{cn: cn, name: stmtName} + + b := cn.writeBuf(proto.Parse) + b.string(st.name) + b.string(q) + b.int16(0) + + b.next(proto.Describe) + b.byte(proto.Sync) + b.string(st.name) + + b.next(proto.Sync) + err := cn.send(b) + if err != nil { + return nil, err + } + + err = cn.readParseResponse() + if err != nil { + return nil, err + } + st.paramTyps, st.colNames, st.colTyps, err = cn.readStatementDescribeResponse() + if err != nil { + return nil, err + } + st.colFmts, st.colFmtData, err = decideColumnFormats(st.colTyps, cn.cfg.DisablePreparedBinaryResult) + if err != nil { + return nil, err + } + + err = cn.readReadyForQuery() + if err != nil { + return nil, err + } + return st, nil +} + +func (cn *conn) Prepare(q string) (driver.Stmt, error) { + if err := cn.err.get(); err != nil { + return nil, err + } + + if pqsql.StartsWithCopy(q) { + s, err := cn.prepareCopyIn(q) + if err == nil { + cn.inProgress.Store(true) + } + return s, cn.handleError(err, q) + } + s, err := cn.prepareTo(q, cn.gname()) + if err != nil { + return nil, 
cn.handleError(err, q) + } + return s, nil +} + +func (cn *conn) Close() error { + // Don't go through send(); ListenerConn relies on us not scribbling on the + // scratch buffer of this connection. + err := cn.sendSimpleMessage(proto.Terminate) + if err != nil { + _ = cn.c.Close() // Ensure that cn.c.Close is always run. + return cn.handleError(err) + } + return cn.c.Close() +} + +func toNamedValue(v []driver.Value) []driver.NamedValue { + v2 := make([]driver.NamedValue, len(v)) + for i := range v { + v2[i] = driver.NamedValue{Ordinal: i + 1, Value: v[i]} + } + return v2 +} + +// CheckNamedValue implements [driver.NamedValueChecker]. +func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error { + if cn.cfg.BinaryParameters { + if bin, ok := nv.Value.(interface{ BinaryValue() ([]byte, error) }); ok { + var err error + nv.Value, err = bin.BinaryValue() + return err + } + } + + // Ignore Valuer, for backward compatibility with pq.Array(). + if _, ok := nv.Value.(driver.Valuer); ok { + return driver.ErrSkip + } + + v := reflect.ValueOf(nv.Value) + if !v.IsValid() { + return driver.ErrSkip + } + t := v.Type() + for t.Kind() == reflect.Pointer { + t, v = t.Elem(), v.Elem() + } + + // Ignore []byte and related types: *[]byte, json.RawMessage, etc. 
+ if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 { + return driver.ErrSkip + } + + switch v.Kind() { + default: + return driver.ErrSkip + case reflect.Slice: + var err error + nv.Value, err = Array(v.Interface()).Value() + return err + case reflect.Uint64: + value := v.Uint() + if value >= math.MaxInt64 { + nv.Value = strconv.FormatUint(value, 10) + } else { + nv.Value = int64(value) + } + return nil + } +} + +// Implement the "Queryer" interface +func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { + return cn.query(query, toNamedValue(args)) +} + +func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) { + if debugProto { + fmt.Fprintln(os.Stderr, " START conn.query") + defer fmt.Fprintln(os.Stderr, " END conn.query") + } + if err := cn.err.get(); err != nil { + return nil, err + } + if !cn.inProgress.CompareAndSwap(false, true) { + return nil, errQueryInProgress + } + + // Check to see if we can use the "simpleQuery" interface, which is + // *much* faster than going through prepare/exec + if len(args) == 0 { + return cn.simpleQuery(query) + } + + if cn.cfg.BinaryParameters { + err := cn.sendBinaryModeQuery(query, args) + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readParseResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readBindResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + + rows := &rows{cn: cn} + rows.rowsHeader, err = cn.readPortalDescribeResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.postExecuteWorkaround() + if err != nil { + return nil, cn.handleError(err, query) + } + return rows, nil + } + + st, err := cn.prepareTo(query, "") + if err != nil { + return nil, cn.handleError(err, query) + } + err = st.exec(args) + if err != nil { + return nil, cn.handleError(err, query) + } + return &rows{ + cn: cn, + rowsHeader: st.rowsHeader, + }, nil +} + +// Implement the 
optional "Execer" interface for one-shot queries +func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { + if err := cn.err.get(); err != nil { + return nil, err + } + if !cn.inProgress.CompareAndSwap(false, true) { + return nil, errQueryInProgress + } + + // Check to see if we can use the "simpleExec" interface, which is *much* + // faster than going through prepare/exec + if len(args) == 0 { + // ignore commandTag, our caller doesn't care + r, _, err := cn.simpleExec(query) + return r, cn.handleError(err, query) + } + + if cn.cfg.BinaryParameters { + err := cn.sendBinaryModeQuery(query, toNamedValue(args)) + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readParseResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readBindResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + + _, err = cn.readPortalDescribeResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.postExecuteWorkaround() + if err != nil { + return nil, cn.handleError(err, query) + } + res, _, err := cn.readExecuteResponse("Execute") + return res, cn.handleError(err, query) + } + + // Use the unnamed statement to defer planning until bind time, or else + // value-based selectivity estimates cannot be used. + st, err := cn.prepareTo(query, "") + if err != nil { + return nil, cn.handleError(err, query) + } + r, err := st.Exec(args) + if err != nil { + return nil, cn.handleError(err, query) + } + return r, nil +} + +type safeRetryError struct{ Err error } + +func (se *safeRetryError) Error() string { return se.Err.Error() } + +func (cn *conn) send(m *writeBuf) error { + if debugProto { + w := m.wrap() + for len(w) > 0 { // Can contain multiple messages. 
+ c := proto.RequestCode(w[0]) + l := int(binary.BigEndian.Uint32(w[1:5])) - 4 + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", c, l, w[5:l+5]) + w = w[l+5:] + } + } + + n, err := cn.c.Write(m.wrap()) + if err != nil && n == 0 { + err = &safeRetryError{Err: err} + } + return err +} + +func (cn *conn) sendStartupPacket(m *writeBuf) error { + if debugProto { + w := m.wrap() + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", "Startup", int(binary.BigEndian.Uint32(w[1:5]))-4, w[5:]) + } + _, err := cn.c.Write((m.wrap())[1:]) + return err +} + +// Send a message of type typ to the server on the other end of cn. The message +// should have no payload. This method does not use the scratch buffer. +func (cn *conn) sendSimpleMessage(typ proto.RequestCode) error { + if debugProto { + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", typ, 0, []byte{}) + } + _, err := cn.c.Write([]byte{byte(typ), '\x00', '\x00', '\x00', '\x04'}) + return err +} + +// saveMessage memorizes a message and its buffer in the conn struct. +// recvMessage will then return these values on the next call to it. This +// method is useful in cases where you have to see what the next message is +// going to be (e.g. to see whether it's an error or not) but you can't handle +// the message yourself. +func (cn *conn) saveMessage(typ proto.ResponseCode, buf *readBuf) error { + if cn.saveMessageType != 0 { + cn.err.set(driver.ErrBadConn) + return fmt.Errorf("unexpected saveMessageType %d", cn.saveMessageType) + } + cn.saveMessageType = typ + cn.saveMessageBuffer = *buf + return nil +} + +// recvMessage receives any message from the backend, or returns an error if +// a problem occurred while reading the message. 
+func (cn *conn) recvMessage(r *readBuf) (proto.ResponseCode, error) { + // workaround for a QueryRow bug, see exec + if cn.saveMessageType != 0 { + t := cn.saveMessageType + *r = cn.saveMessageBuffer + cn.saveMessageType = 0 + cn.saveMessageBuffer = nil + return t, nil + } + + x := cn.scratch[:5] + _, err := io.ReadFull(cn.buf, x) + if err != nil { + return 0, err + } + + // Read the type and length of the message that follows. + t := proto.ResponseCode(x[0]) + n := int(binary.BigEndian.Uint32(x[1:])) - 4 + + if proto.ResponseCode(t) == proto.ReadyForQuery { + cn.inProgress.Store(false) + } + + // When PostgreSQL cannot start a backend (e.g., an external process limit), + // it sends plain text like "Ecould not fork new process [..]", which + // doesn't use the standard encoding for the Error message. + // + // libpq checks "if ErrorResponse && (msgLength < 8 || msgLength > MAX_ERRLEN)", + // but check < 4 since n represents bytes remaining to be read after length. + if t == proto.ErrorResponse && (n < 4 || n > proto.MaxErrlen) { + msg, _ := cn.buf.ReadString('\x00') + return 0, fmt.Errorf("pq: server error: %s%s", string(x[1:]), strings.TrimSuffix(msg, "\x00")) + } + + var y []byte + if n <= len(cn.scratch) { + y = cn.scratch[:n] + } else { + y = make([]byte, n) + } + _, err = io.ReadFull(cn.buf, y) + if err != nil { + return 0, err + } + *r = y + if debugProto { + fmt.Fprintf(os.Stderr, "SERVER ← %-20s %5d %q\n", t, n, y) + } + return t, nil +} + +// recv receives a message from the backend, returning an error if an error +// happened while reading the message or the received message an ErrorResponse. +// NoticeResponses are ignored. This function should generally be used only +// during the startup sequence. 
+func (cn *conn) recv() (proto.ResponseCode, *readBuf, error) { + for { + r := new(readBuf) + t, err := cn.recvMessage(r) + if err != nil { + return 0, nil, err + } + switch t { + case proto.ErrorResponse: + return 0, nil, parseError(r, "") + case proto.NoticeResponse: + if n := cn.noticeHandler; n != nil { + n(parseError(r, "")) + } + case proto.NotificationResponse: + if n := cn.notificationHandler; n != nil { + n(recvNotification(r)) + } + default: + return t, r, nil + } + } +} + +// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by +// the caller to avoid an allocation. +func (cn *conn) recv1Buf(r *readBuf) (proto.ResponseCode, error) { + for { + t, err := cn.recvMessage(r) + if err != nil { + return 0, err + } + + switch t { + case proto.NotificationResponse: + if n := cn.notificationHandler; n != nil { + n(recvNotification(r)) + } + case proto.NoticeResponse: + if n := cn.noticeHandler; n != nil { + n(parseError(r, "")) + } + case proto.ParameterStatus: + cn.processParameterStatus(r) + default: + return t, nil + } + } +} + +// recv1 receives a message from the backend, returning an error if an error +// happened while reading the message or the received message an ErrorResponse. +// All asynchronous messages are ignored, with the exception of ErrorResponse. +func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) { + r := new(readBuf) + t, err := cn.recv1Buf(r) + if err != nil { + return 0, nil, err + } + return t, r, nil +} + +// Don't refer to Config.SSLMode here, as the mode in arguments may be different +// in case of sslmode=allow or prefer. +func (cn *conn) ssl(cfg Config, mode SSLMode) error { + upgrade, err := ssl(cfg, mode) + if err != nil { + return err + } + if upgrade == nil { + return nil // Nothing to do + } + + // Only negotiate the ssl handshake if requested (which is the default). + // sslnegotiation=direct is supported by pg17 and above. 
+ if cfg.SSLNegotiation != SSLNegotiationDirect { + w := cn.writeBuf(0) + w.int32(proto.NegotiateSSLCode) + if err = cn.sendStartupPacket(w); err != nil { + return err + } + + b := cn.scratch[:1] + _, err = io.ReadFull(cn.c, b) + if err != nil { + return err + } + + if b[0] != 'S' { + return ErrSSLNotSupported + } + } + + cn.c, err = upgrade(cn.c) + return err +} + +func (cn *conn) startup(cfg Config) error { + w := cn.writeBuf(0) + // Send maximum protocol version in startup; if the server doesn't support + // this version it responds with NegotiateProtocolVersion and the maximum + // version it supports (and will use). + w.int32(cfg.MaxProtocolVersion.proto()) + + if cfg.User != "" { + w.string("user") + w.string(cfg.User) + } + if cfg.Database != "" { + w.string("database") + w.string(cfg.Database) + } + // w.string("replication") // Sent by libpq, but we don't support that. + if cfg.Options != "" { + w.string("options") + w.string(cfg.Options) + } + if cfg.ApplicationName != "" { + w.string("application_name") + w.string(cfg.ApplicationName) + } + if cfg.ClientEncoding != "" { + w.string("client_encoding") + w.string(cfg.ClientEncoding) + } + for k, v := range cfg.Runtime { + w.string(k) + w.string(v) + } + + w.string("") + if err := cn.sendStartupPacket(w); err != nil { + return err + } + + for { + t, r, err := cn.recv() + if err != nil { + return err + } + switch t { + case proto.BackendKeyData: + cn.pid = r.int32() + if len(*r) > 256 { + return fmt.Errorf("pq: cancellation key longer than 256 bytes: %d bytes", len(*r)) + } + cn.secretKey = make([]byte, len(*r)) + copy(cn.secretKey, *r) + case proto.ParameterStatus: + cn.processParameterStatus(r) + case proto.AuthenticationRequest: + err := cn.auth(r, cfg) + if err != nil { + return err + } + case proto.NegotiateProtocolVersion: + newestMinor := r.int32() + serverVersion := proto.ProtocolVersion30&0xFFFF0000 | newestMinor + if serverVersion < cfg.MinProtocolVersion.proto() { + return fmt.Errorf("pq: protocol 
version mismatch: min_protocol_version=%s; server supports up to 3.%d", cfg.MinProtocolVersion, newestMinor) + } + case proto.ReadyForQuery: + cn.processReadyForQuery(r) + return nil + default: + return fmt.Errorf("pq: unknown response for startup: %q", t) + } + } +} + +func (cn *conn) auth(r *readBuf, cfg Config) error { + switch code := proto.AuthCode(r.int32()); code { + default: + return fmt.Errorf("pq: unknown authentication response: %s", code) + case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI: + return fmt.Errorf("pq: unsupported authentication method: %s", code) + case proto.AuthReqOk: + return nil + + case proto.AuthReqPassword: + w := cn.writeBuf(proto.PasswordMessage) + w.string(cfg.Password) + // Don't need to check AuthOk response here; auth() is called in a loop, + // which catches the errors and AuthReqOk responses. + return cn.send(w) + + case proto.AuthReqMD5: + s := string(r.next(4)) + w := cn.writeBuf(proto.PasswordMessage) + w.string("md5" + md5s(md5s(cfg.Password+cfg.User)+s)) + // Same here. + return cn.send(w) + + case proto.AuthReqGSS: // GSSAPI, startup + if newGss == nil { + return fmt.Errorf("pq: kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos)") + } + cli, err := newGss() + if err != nil { + return fmt.Errorf("pq: kerberos error: %w", err) + } + + var token []byte + if cfg.isset("krbspn") { + // Use the supplied SPN if provided.. 
+ token, err = cli.GetInitTokenFromSpn(cfg.KrbSpn) + } else { + // Allow the kerberos service name to be overridden + service := "postgres" + if cfg.isset("krbsrvname") { + service = cfg.KrbSrvname + } + token, err = cli.GetInitToken(cfg.Host, service) + } + if err != nil { + return fmt.Errorf("pq: failed to get Kerberos ticket: %w", err) + } + + w := cn.writeBuf(proto.GSSResponse) + w.bytes(token) + err = cn.send(w) + if err != nil { + return err + } + + // Store for GSSAPI continue message + cn.gss = cli + return nil + + case proto.AuthReqGSSCont: // GSSAPI continue + if cn.gss == nil { + return errors.New("pq: GSSAPI protocol error") + } + + done, tokOut, err := cn.gss.Continue([]byte(*r)) + if err == nil && !done { + w := cn.writeBuf(proto.SASLInitialResponse) + w.bytes(tokOut) + err = cn.send(w) + if err != nil { + return err + } + } + + // Errors fall through and read the more detailed message from the + // server. + return nil + + case proto.AuthReqSASL: + sc := scram.NewClient(sha256.New, cfg.User, cfg.Password) + sc.Step(nil) + if sc.Err() != nil { + return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) + } + scOut := sc.Out() + + w := cn.writeBuf(proto.SASLResponse) + w.string("SCRAM-SHA-256") + w.int32(len(scOut)) + w.bytes(scOut) + err := cn.send(w) + if err != nil { + return err + } + + t, r, err := cn.recv() + if err != nil { + return err + } + if t != proto.AuthenticationRequest { + return fmt.Errorf("pq: unexpected password response: %q", t) + } + + if r.int32() != int(proto.AuthReqSASLCont) { + return fmt.Errorf("pq: unexpected authentication response: %q", t) + } + + nextStep := r.next(len(*r)) + sc.Step(nextStep) + if sc.Err() != nil { + return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) + } + + scOut = sc.Out() + w = cn.writeBuf(proto.SASLResponse) + w.bytes(scOut) + err = cn.send(w) + if err != nil { + return err + } + + t, r, err = cn.recv() + if err != nil { + return err + } + if t != proto.AuthenticationRequest { + return 
fmt.Errorf("pq: unexpected password response: %q", t) + } + + if r.int32() != int(proto.AuthReqSASLFin) { + return fmt.Errorf("pq: unexpected authentication response: %q", t) + } + + nextStep = r.next(len(*r)) + sc.Step(nextStep) + if sc.Err() != nil { + return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) + } + + return nil + } +} + +// parseComplete parses the "command tag" from a CommandComplete message, and +// returns the number of rows affected (if applicable) and a string identifying +// only the command that was executed, e.g. "ALTER TABLE". Returns an error if +// the command can cannot be parsed. +func (cn *conn) parseComplete(commandTag string) (driver.Result, string, error) { + commandsWithAffectedRows := []string{ + "SELECT ", + // INSERT is handled below + "UPDATE ", + "DELETE ", + "FETCH ", + "MOVE ", + "COPY ", + } + + var affectedRows *string + for _, tag := range commandsWithAffectedRows { + if strings.HasPrefix(commandTag, tag) { + t := commandTag[len(tag):] + affectedRows = &t + commandTag = tag[:len(tag)-1] + break + } + } + // INSERT also includes the oid of the inserted row in its command tag. Oids + // in user tables are deprecated, and the oid is only returned when exactly + // one row is inserted, so it's unlikely to be of value to any real-world + // application and we can ignore it. 
+ if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { + parts := strings.Split(commandTag, " ") + if len(parts) != 3 { + cn.err.set(driver.ErrBadConn) + return nil, "", fmt.Errorf("pq: unexpected INSERT command tag %s", commandTag) + } + affectedRows = &parts[len(parts)-1] + commandTag = "INSERT" + } + // There should be no affected rows attached to the tag, just return it + if affectedRows == nil { + return driver.RowsAffected(0), commandTag, nil + } + n, err := strconv.ParseInt(*affectedRows, 10, 64) + if err != nil { + cn.err.set(driver.ErrBadConn) + return nil, "", fmt.Errorf("pq: could not parse commandTag: %w", err) + } + return driver.RowsAffected(n), commandTag, nil +} + +func md5s(s string) string { + h := md5.New() + h.Write([]byte(s)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.NamedValue) error { + // Do one pass over the parameters to see if we're going to send any of them + // over in binary. If we are, create a paramFormats array at the same time. 
+ var paramFormats []int + for i, x := range args { + _, ok := x.Value.([]byte) + if ok { + if paramFormats == nil { + paramFormats = make([]int, len(args)) + } + paramFormats[i] = 1 + } + } + if paramFormats == nil { + b.int16(0) + } else { + b.int16(len(paramFormats)) + for _, x := range paramFormats { + b.int16(x) + } + } + + b.int16(len(args)) + for _, x := range args { + if x.Value == nil { + b.int32(-1) + } else if xx, ok := x.Value.([]byte); ok && xx == nil { + b.int32(-1) + } else { + datum, err := binaryEncode(x.Value) + if err != nil { + return err + } + b.int32(len(datum)) + b.bytes(datum) + } + } + return nil +} + +func (cn *conn) sendBinaryModeQuery(query string, args []driver.NamedValue) error { + if len(args) >= 65536 { + return fmt.Errorf("pq: got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) + } + + b := cn.writeBuf(proto.Parse) + b.byte(0) // unnamed statement + b.string(query) + b.int16(0) + + b.next(proto.Bind) + b.int16(0) // unnamed portal and statement + err := cn.sendBinaryParameters(b, args) + if err != nil { + return err + } + b.bytes(colFmtDataAllText) + + b.next(proto.Describe) + b.byte(proto.Parse) + b.byte(0) // unnamed portal + + b.next(proto.Execute) + b.byte(0) + b.int32(0) + + b.next(proto.Sync) + return cn.send(b) +} + +func (cn *conn) processParameterStatus(r *readBuf) { + switch r.string() { + default: + // ignore + case "server_version": + var major1, major2 int + _, err := fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) + if err == nil { + cn.parameterStatus.serverVersion = major1*10000 + major2*100 + } + case "TimeZone": + switch tz := r.string(); tz { + case "UTC", "Etc/UTC", "Etc/Universal", "Etc/Zulu", "Etc/UCT": + cn.parameterStatus.currentLocation = time.UTC + default: + var err error + cn.parameterStatus.currentLocation, err = time.LoadLocation(tz) + if err != nil { + cn.parameterStatus.currentLocation = nil + } + } + // Use sql.NullBool so we can distinguish between false and not sent. 
If + // it's not sent we use a query to get the value – I don't know when these + // parameters are not sent, but this is what libpq does. + case "in_hot_standby": + b, err := pqutil.ParseBool(r.string()) + if err == nil { + cn.parameterStatus.inHotStandby = sql.NullBool{Valid: true, Bool: b} + } + case "default_transaction_read_only": + b, err := pqutil.ParseBool(r.string()) + if err == nil { + cn.parameterStatus.defaultTransactionReadOnly = sql.NullBool{Valid: true, Bool: b} + } + } +} + +func (cn *conn) processReadyForQuery(r *readBuf) { + cn.txnStatus = transactionStatus(r.byte()) +} + +func (cn *conn) readReadyForQuery() error { + t, r, err := cn.recv1() + if err != nil { + return err + } + switch t { + case proto.ReadyForQuery: + cn.processReadyForQuery(r) + return nil + case proto.ErrorResponse: + err := parseError(r, "") + cn.err.set(driver.ErrBadConn) + return err + default: + cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected message %q; expected ReadyForQuery", t) + } +} + +func (cn *conn) readParseResponse() error { + t, r, err := cn.recv1() + if err != nil { + return err + } + switch t { + case proto.ParseComplete: + return nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return err + default: + cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected Parse response %q", t) + } +} + +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc, _ error) { + for { + t, r, err := cn.recv1() + if err != nil { + return nil, nil, nil, err + } + switch t { + case proto.ParameterDescription: + nparams := r.int16() + paramTyps = make([]oid.Oid, nparams) + for i := range paramTyps { + paramTyps[i] = r.oid() + } + case proto.NoData: + return paramTyps, nil, nil, nil + case proto.RowDescription: + colNames, colTyps = parseStatementRowDescribe(r) + return paramTyps, colNames, colTyps, nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = 
cn.readReadyForQuery() + return nil, nil, nil, err + default: + cn.err.set(driver.ErrBadConn) + return nil, nil, nil, fmt.Errorf("pq: unexpected Describe statement response %q", t) + } + } +} + +func (cn *conn) readPortalDescribeResponse() (rowsHeader, error) { + t, r, err := cn.recv1() + if err != nil { + return rowsHeader{}, err + } + switch t { + case proto.RowDescription: + return parsePortalRowDescribe(r), nil + case proto.NoData: + return rowsHeader{}, nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return rowsHeader{}, err + default: + cn.err.set(driver.ErrBadConn) + return rowsHeader{}, fmt.Errorf("pq: unexpected Describe response %q", t) + } +} + +func (cn *conn) readBindResponse() error { + t, r, err := cn.recv1() + if err != nil { + return err + } + switch t { + case proto.BindComplete: + return nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return err + default: + cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected Bind response %q", t) + } +} + +func (cn *conn) postExecuteWorkaround() error { + // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores + // any errors from rows.Next, which masks errors that happened during the + // execution of the query. To avoid the problem in common cases, we wait + // here for one more message from the database. If it's not an error the + // query will likely succeed (or perhaps has already, if it's a + // CommandComplete), so we push the message into the conn struct; recv1 + // will return it as the next message for rows.Next or rows.Close. + // However, if it's an error, we wait until ReadyForQuery and then return + // the error to our caller. 
+ for { + t, r, err := cn.recv1() + if err != nil { + return err + } + switch t { + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return err + case proto.CommandComplete, proto.DataRow, proto.EmptyQueryResponse: + // the query didn't fail, but we can't process this message + return cn.saveMessage(t, r) + default: + cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected message during extended query execution: %q", t) + } + } +} + +// Only for Exec(), since we ignore the returned data +func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, resErr error) { + for { + t, r, err := cn.recv1() + if err != nil { + return nil, "", err + } + switch t { + case proto.CommandComplete: + if resErr != nil { + cn.err.set(driver.ErrBadConn) + return nil, "", fmt.Errorf("pq: unexpected CommandComplete after error %s", resErr) + } + res, commandTag, err = cn.parseComplete(r.string()) + if err != nil { + return nil, "", err + } + case proto.ReadyForQuery: + cn.processReadyForQuery(r) + if res == nil && resErr == nil { + resErr = errUnexpectedReady + } + return res, commandTag, resErr + case proto.ErrorResponse: + resErr = parseError(r, "") + case proto.RowDescription, proto.DataRow, proto.EmptyQueryResponse: + if resErr != nil { + cn.err.set(driver.ErrBadConn) + return nil, "", fmt.Errorf("pq: unexpected %q after error %s", t, resErr) + } + if t == proto.EmptyQueryResponse { + res = emptyRows + } + // ignore any results + default: + cn.err.set(driver.ErrBadConn) + return nil, "", fmt.Errorf("pq: unknown %s response: %q", protocolState, t) + } + } +} + +func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { + n := r.int16() + colNames = make([]string, n) + colTyps = make([]fieldDesc, n) + for i := range colNames { + colNames[i] = r.string() + r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() + // format code not known 
when describing a statement; always 0 + r.next(2) + } + return +} + +func parsePortalRowDescribe(r *readBuf) rowsHeader { + n := r.int16() + colNames := make([]string, n) + colFmts := make([]format, n) + colTyps := make([]fieldDesc, n) + for i := range colNames { + colNames[i] = r.string() + r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() + colFmts[i] = format(r.int16()) + } + return rowsHeader{ + colNames: colNames, + colFmts: colFmts, + colTyps: colTyps, + } +} + +func (cn *conn) ResetSession(ctx context.Context) error { + // Ensure bad connections are reported: From database/sql/driver: + // If a connection is never returned to the connection pool but immediately reused, then + // ResetSession is called prior to reuse but IsValid is not called. + return cn.err.get() +} + +func (cn *conn) IsValid() bool { + return cn.err.get() == nil +} diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go new file mode 100644 index 00000000..16de38eb --- /dev/null +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -0,0 +1,226 @@ +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "time" + + "github.com/lib/pq/internal/proto" +) + +const watchCancelDialContextTimeout = 10 * time.Second + +// Implement the "QueryerContext" interface +func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + finish := cn.watchCancel(ctx) + r, err := cn.query(query, args) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil +} + +// Implement the "ExecerContext" interface +func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + + return 
cn.Exec(query, list) +} + +// Implement the "ConnPrepareContext" interface +func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + return cn.Prepare(query) +} + +// Implement the "ConnBeginTx" interface +func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var mode string + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + // Don't touch mode: use the server's default + case sql.LevelReadUncommitted: + mode = " ISOLATION LEVEL READ UNCOMMITTED" + case sql.LevelReadCommitted: + mode = " ISOLATION LEVEL READ COMMITTED" + case sql.LevelRepeatableRead: + mode = " ISOLATION LEVEL REPEATABLE READ" + case sql.LevelSerializable: + mode = " ISOLATION LEVEL SERIALIZABLE" + default: + return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) + } + if opts.ReadOnly { + mode += " READ ONLY" + } else { + mode += " READ WRITE" + } + + tx, err := cn.begin(mode) + if err != nil { + return nil, err + } + cn.txnFinish = cn.watchCancel(ctx) + return tx, nil +} + +func (cn *conn) Ping(ctx context.Context) error { + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + rows, err := cn.simpleQuery(";") + if err != nil { + return driver.ErrBadConn + } + _ = rows.Close() + return nil +} + +func (cn *conn) watchCancel(ctx context.Context) func() { + if done := ctx.Done(); done != nil { + finished := make(chan struct{}, 1) + go func() { + select { + case <-done: + select { + case finished <- struct{}{}: + default: + // We raced with the finish func, let the next query handle this with the + // context. + return + } + + // Set the connection state to bad so it does not get reused. + cn.err.set(ctx.Err()) + + // At this point the function level context is canceled, + // so it must not be used for the additional network + // request to cancel the query. 
+ // Create a new context to pass into the dial. + ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) + defer cancel() + + _ = cn.cancel(ctxCancel) + case <-finished: + } + }() + return func() { + select { + case <-finished: + cn.err.set(ctx.Err()) + _ = cn.Close() + case finished <- struct{}{}: + } + } + } + return nil +} + +func (cn *conn) cancel(ctx context.Context) error { + // Use a copy since a new connection is created here. This is necessary + // because cancel is called from a goroutine in watchCancel. + cfg := cn.cfg.Clone() + + c, err := dial(ctx, cn.dialer, cfg) + if err != nil { + return err + } + defer func() { _ = c.Close() }() + + cn2 := conn{c: c} + err = cn2.ssl(cfg, cfg.SSLMode) + if err != nil { + return err + } + + w := cn2.writeBuf(0) + w.int32(proto.CancelRequestCode) + w.int32(cn.pid) + w.bytes(cn.secretKey) + if err := cn2.sendStartupPacket(w); err != nil { + return err + } + + // Read until EOF to ensure that the server received the cancel. 
+ _, err = io.Copy(io.Discard, c) + return err +} + +// Implement the "StmtQueryContext" interface +func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + finish := st.watchCancel(ctx) + r, err := st.query(args) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil +} + +// Implement the "StmtExecContext" interface +func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if finish := st.watchCancel(ctx); finish != nil { + defer finish() + } + if err := st.cn.err.get(); err != nil { + return nil, err + } + + err := st.exec(args) + if err != nil { + return nil, st.cn.handleError(err) + } + res, _, err := st.cn.readExecuteResponse("simple query") + return res, st.cn.handleError(err) +} + +// watchCancel is implemented on stmt in order to not mark the parent conn as bad +func (st *stmt) watchCancel(ctx context.Context) func() { + if done := ctx.Done(); done != nil { + finished := make(chan struct{}) + go func() { + select { + case <-done: + // At this point the function level context is canceled, so it + // must not be used for the additional network request to cancel + // the query. Create a new context to pass into the dial. 
+ ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) + defer cancel() + + _ = st.cancel(ctxCancel) + finished <- struct{}{} + case <-finished: + } + }() + return func() { + select { + case <-finished: + case finished <- struct{}{}: + } + } + } + return nil +} + +func (st *stmt) cancel(ctx context.Context) error { + return st.cn.cancel(ctx) +} diff --git a/vendor/github.com/lib/pq/connector.go b/vendor/github.com/lib/pq/connector.go new file mode 100644 index 00000000..27a268ec --- /dev/null +++ b/vendor/github.com/lib/pq/connector.go @@ -0,0 +1,1157 @@ +package pq + +import ( + "context" + "crypto/tls" + "database/sql/driver" + "fmt" + "math/rand" + "net" + "net/netip" + neturl "net/url" + "os" + "path/filepath" + "reflect" + "runtime" + "slices" + "sort" + "strconv" + "strings" + "time" + "unicode" + + "github.com/lib/pq/internal/pgservice" + "github.com/lib/pq/internal/pqutil" + "github.com/lib/pq/internal/proto" +) + +type ( + // SSLMode is a sslmode setting. + SSLMode string + + // SSLNegotiation is a sslnegotiation setting. + SSLNegotiation string + + // TargetSessionAttrs is a target_session_attrs setting. + TargetSessionAttrs string + + // LoadBalanceHosts is a load_balance_hosts setting. + LoadBalanceHosts string + + // ProtocolVersion is a min_protocol_version or max_protocol_version + // setting. + ProtocolVersion string + + // SSLProtocolVersion is a ssl_min_protocol_version or + // ssl_max_protocol_version setting. + SSLProtocolVersion string +) + +// Values for [SSLMode] that pq supports. +const ( + // No SSL + SSLModeDisable = SSLMode("disable") + + // First try a non-SSL connection and if that fails try an SSL connection. + SSLModeAllow = SSLMode("allow") + + // First try an SSL connection and if that fails try a non-SSL connection. + SSLModePrefer = SSLMode("prefer") + + // Require SSL, but skip verification. This is the default. 
+ SSLModeRequire = SSLMode("require") + + // Require SSL and verify that the certificate was signed by a trusted CA. + SSLModeVerifyCA = SSLMode("verify-ca") + + // Require SSL and verify that the certificate was signed by a trusted CA + // and the server host name matches the one in the certificate. + SSLModeVerifyFull = SSLMode("verify-full") +) + +var sslModes = []SSLMode{SSLModeDisable, SSLModeAllow, SSLModePrefer, SSLModeRequire, + SSLModeVerifyFull, SSLModeVerifyCA} + +func (s SSLMode) useSSL() bool { + switch s { + case SSLModePrefer, SSLModeRequire, SSLModeVerifyCA, SSLModeVerifyFull: + return true + } + return false +} + +// Values for [SSLNegotiation] that pq supports. +const ( + // Negotiate whether SSL should be used. This is the default. + SSLNegotiationPostgres = SSLNegotiation("postgres") + + // Always use SSL, don't try to negotiate. + SSLNegotiationDirect = SSLNegotiation("direct") +) + +var sslNegotiations = []SSLNegotiation{SSLNegotiationPostgres, SSLNegotiationDirect} + +// Values for [TargetSessionAttrs] that pq supports. +const ( + // Any successful connection is acceptable. This is the default. + TargetSessionAttrsAny = TargetSessionAttrs("any") + + // Session must accept read-write transactions by default: the server must + // not be in hot standby mode and default_transaction_read_only must be + // off. + TargetSessionAttrsReadWrite = TargetSessionAttrs("read-write") + + // Session must not accept read-write transactions by default. + TargetSessionAttrsReadOnly = TargetSessionAttrs("read-only") + + // Server must not be in hot standby mode. + TargetSessionAttrsPrimary = TargetSessionAttrs("primary") + + // Server must be in hot standby mode. + TargetSessionAttrsStandby = TargetSessionAttrs("standby") + + // First try to find a standby server, but if none of the listed hosts is a + // standby server, try again in any mode. 
+ TargetSessionAttrsPreferStandby = TargetSessionAttrs("prefer-standby") +) + +var targetSessionAttrs = []TargetSessionAttrs{TargetSessionAttrsAny, + TargetSessionAttrsReadWrite, TargetSessionAttrsReadOnly, TargetSessionAttrsPrimary, + TargetSessionAttrsStandby, TargetSessionAttrsPreferStandby} + +// Values for [LoadBalanceHosts] that pq supports. +const ( + // Don't load balance; try hosts in the order in which they're provided. + // This is the default. + LoadBalanceHostsDisable = LoadBalanceHosts("disable") + + // Hosts are tried in random order to balance connections across multiple + // PostgreSQL servers. + // + // When using this value it's recommended to also configure a reasonable + // value for connect_timeout. Because then, if one of the nodes that are + // used for load balancing is not responding, a new node will be tried. + LoadBalanceHostsRandom = LoadBalanceHosts("random") +) + +var loadBalanceHosts = []LoadBalanceHosts{LoadBalanceHostsDisable, LoadBalanceHostsRandom} + +// Values for [ProtocolVersion] that pq supports. +const ( + // ProtocolVersion30 is the default protocol version, supported in + // PostgreSQL 3.0 and newer. + ProtocolVersion30 = ProtocolVersion("3.0") + + // ProtocolVersion32 uses a longer secret key length for query cancellation, + // supported in PostgreSQL 18 and newer. + ProtocolVersion32 = ProtocolVersion("3.2") + + // ProtocolVersionLatest is the latest protocol version that pq supports + // (which may not be supported by the server). + ProtocolVersionLatest = ProtocolVersion("latest") +) + +var protocolVersions = []ProtocolVersion{ProtocolVersion30, ProtocolVersion32, ProtocolVersionLatest} + +// Values for [SSLProtocolVersion] that pq supports. 
+const ( + SSLProtocolVersionTLS10 = SSLProtocolVersion("TLSv1.0") + SSLProtocolVersionTLS11 = SSLProtocolVersion("TLSv1.1") + SSLProtocolVersionTLS12 = SSLProtocolVersion("TLSv1.2") + SSLProtocolVersionTLS13 = SSLProtocolVersion("TLSv1.3") +) + +var sslProtocolVersions = []SSLProtocolVersion{SSLProtocolVersionTLS10, SSLProtocolVersionTLS11, + SSLProtocolVersionTLS12, SSLProtocolVersionTLS13} + +func (s SSLProtocolVersion) tlsconf() uint16 { + switch s { + case SSLProtocolVersionTLS10: + return tls.VersionTLS10 + case SSLProtocolVersionTLS11: + return tls.VersionTLS11 + case SSLProtocolVersionTLS12: + return tls.VersionTLS12 + case SSLProtocolVersionTLS13: + return tls.VersionTLS13 + default: + return 0 + } +} + +// Connector represents a fixed configuration for the pq driver with a given +// dsn. Connector satisfies the [database/sql/driver.Connector] interface and +// can be used to create any number of DB Conn's via [sql.OpenDB]. +type Connector struct { + cfg Config + dialer Dialer +} + +// NewConnector returns a connector for the pq driver in a fixed configuration +// with the given dsn. The returned connector can be used to create any number +// of equivalent Conn's. The returned connector is intended to be used with +// [sql.OpenDB]. +func NewConnector(dsn string) (*Connector, error) { + cfg, err := NewConfig(dsn) + if err != nil { + return nil, err + } + return NewConnectorConfig(cfg) +} + +// NewConnectorConfig returns a connector for the pq driver in a fixed +// configuration with the given [Config]. The returned connector can be used to +// create any number of equivalent Conn's. The returned connector is intended to +// be used with [sql.OpenDB]. +func NewConnectorConfig(cfg Config) (*Connector, error) { + return &Connector{cfg: cfg, dialer: defaultDialer{}}, nil +} + +// Connect returns a connection to the database using the fixed configuration of +// this Connector. Context is not used. 
+func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { return c.open(ctx) } + +// Dialer allows change the dialer used to open connections. +func (c *Connector) Dialer(dialer Dialer) { c.dialer = dialer } + +// Driver returns the underlying driver of this Connector. +func (c *Connector) Driver() driver.Driver { return &Driver{} } + +func (p ProtocolVersion) proto() int { + switch p { + default: + return proto.ProtocolVersion30 + case ProtocolVersion32, ProtocolVersionLatest: + return proto.ProtocolVersion32 + } +} + +// Config holds options pq supports when connecting to PostgreSQL. +// +// The postgres struct tag is used for the value from the DSN (e.g. +// "dbname=abc"), and the env struct tag is used for the environment variable +// (e.g. "PGDATABASE=abc") +type Config struct { + // The host to connect to. Absolute paths and values that start with @ are + // for unix domain sockets. Defaults to localhost. + // + // A comma-separated list of host names is also accepted, in which case each + // host name in the list is tried in order or randomly if load_balance_hosts + // is set. An empty item selects the default of localhost. The + // target_session_attrs option controls properties the host must have to be + // considered acceptable. + Host string `postgres:"host" env:"PGHOST"` + + // IPv4 or IPv6 address to connect to. Using hostaddr allows the application + // to avoid a host name lookup, which might be important in applications + // with time constraints. A hostname is required for sslmode=verify-full and + // the GSSAPI or SSPI authentication methods. + // + // The following rules are used: + // + // - If host is given without hostaddr, a host name lookup occurs. + // + // - If hostaddr is given without host, the value for hostaddr gives the + // server network address. The connection attempt will fail if the + // authentication method requires a host name. 
+ // + // - If both host and hostaddr are given, the value for hostaddr gives the + // server network address. The value for host is ignored unless the + // authentication method requires it, in which case it will be used as the + // host name. + // + // A comma-separated list of hostaddr values is also accepted, in which case + // each host in the list is tried in order or randomly if load_balance_hosts + // is set. An empty item causes the corresponding host name to be used, or + // the default host name if that is empty as well. The target_session_attrs + // option controls properties the host must have to be considered + // acceptable. + Hostaddr netip.Addr `postgres:"hostaddr" env:"PGHOSTADDR"` + + // The port to connect to. Defaults to 5432. + // + // If multiple hosts were given in the host or hostaddr parameters, this + // parameter may specify a comma-separated list of ports of the same length + // as the host list, or it may specify a single port number to be used for + // all hosts. An empty string, or an empty item in a comma-separated list, + // specifies the default of 5432. + Port uint16 `postgres:"port" env:"PGPORT"` + + // The name of the database to connect to. + Database string `postgres:"dbname" env:"PGDATABASE"` + + // The user to sign in as. Defaults to the current user. + User string `postgres:"user" env:"PGUSER"` + + // The user's password. + Password string `postgres:"password" env:"PGPASSWORD"` + + // Path to [pgpass] file to store passwords; overrides Password. + // + // [pgpass]: http://www.postgresql.org/docs/current/static/libpq-pgpass.html + Passfile string `postgres:"passfile" env:"PGPASSFILE"` + + // Commandline options to send to the server at connection start. + Options string `postgres:"options" env:"PGOPTIONS"` + + // Application name, displayed in pg_stat_activity and log entries. + ApplicationName string `postgres:"application_name" env:"PGAPPNAME"` + + // Used if application_name is not given.
Specifying a fallback name is + // useful in generic utility programs that wish to set a default application + // name but allow it to be overridden by the user. + FallbackApplicationName string `postgres:"fallback_application_name" env:"-"` + + // Whether to use SSL. Defaults to "require" (different from libpq's default + // of "prefer"). + // + // [RegisterTLSConfig] can be used to register a custom [tls.Config], which + // can be used by setting sslmode=pqgo-«key» in the connection string. + SSLMode SSLMode `postgres:"sslmode" env:"PGSSLMODE"` + + // When set to "direct" it will use SSL without negotiation (PostgreSQL ≥17 only). + SSLNegotiation SSLNegotiation `postgres:"sslnegotiation" env:"PGSSLNEGOTIATION"` + + // Path to client SSL certificate. The file must contain PEM encoded data. + // + // Defaults to ~/.postgresql/postgresql.crt + SSLCert string `postgres:"sslcert" env:"PGSSLCERT"` + + // Path to secret key for sslcert. The file must contain PEM encoded data. + // + // Defaults to ~/.postgresql/postgresql.key + SSLKey string `postgres:"sslkey" env:"PGSSLKEY"` + + // Path to root certificate. The file must contain PEM encoded data. + // + // The special value "system" can be used to load the system's root + // certificates ([x509.SystemCertPool]). This will change the default + // sslmode to verify-full and issue an error if a lower setting is used – as + // anyone can register a valid certificate hostname verification becomes + // essential. + // + // Defaults to ~/.postgresql/root.crt. + SSLRootCert string `postgres:"sslrootcert" env:"PGSSLROOTCERT"` + + // By default SNI is on, any value which is not starting with "1" disables + // SNI. + SSLSNI bool `postgres:"sslsni" env:"PGSSLSNI"` + + // Minimum SSL/TLS protocol version to allow for the connection. + // + // The default is determined by [tls.Config.MinVersion], which is TLSv1.2 at + // the time of writing.
+ SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"SSLPGMINPROTOCOLVERSION"` + + // Maximum SSL/TLS protocol version to allow for the connection. If not set, + // this parameter is ignored and the connection will use the maximum bound + // defined by the backend, if set. Setting the maximum protocol version is + // mainly useful for testing or if some component has issues working with a + // newer protocol. + SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"SSLPGMAXPROTOCOLVERSION"` + + // Interpret sslcert and sslkey as PEM encoded data, rather than a path to a + // PEM file. This is a pq extension, not supported in libpq. + SSLInline bool `postgres:"sslinline" env:"-"` + + // GSS (Kerberos) service name when constructing the SPN (default is + // postgres). This will be combined with the host to form the full SPN: + // krbsrvname/host. + KrbSrvname string `postgres:"krbsrvname" env:"PGKRBSRVNAME"` + + // GSS (Kerberos) SPN. This takes priority over krbsrvname if present. This + // is a pq extension, not supported in libpq. + KrbSpn string `postgres:"krbspn" env:"-"` + + // Maximum time to wait while connecting, in seconds. Zero, negative, or not + // specified means wait indefinitely. + ConnectTimeout time.Duration `postgres:"connect_timeout" env:"PGCONNECT_TIMEOUT"` + + // Whether to always send []byte parameters over as binary. Enables single + // round-trip mode for non-prepared Query calls. This is a pq extension, not + // supported in libpq. + BinaryParameters bool `postgres:"binary_parameters" env:"-"` + + // This connection should never use the binary format when receiving query + // results from prepared statements. Only provided for debugging. This is a + // pq extension, not supported in libpq. + DisablePreparedBinaryResult bool `postgres:"disable_prepared_binary_result" env:"-"` + + // Client encoding; pq only supports UTF8 and this must be blank or "UTF8".
+ ClientEncoding string `postgres:"client_encoding" env:"PGCLIENTENCODING"` + + // Date/time representation to use; pq only supports "ISO, MDY" and this + // must be blank or "ISO, MDY". + Datestyle string `postgres:"datestyle" env:"PGDATESTYLE"` + + // Default time zone. + TZ string `postgres:"tz" env:"PGTZ"` + + // Default mode for the genetic query optimizer. + Geqo string `postgres:"geqo" env:"PGGEQO"` + + // Determine whether the session must have certain properties to be + // acceptable. It's typically used in combination with multiple host names + // to select the first acceptable alternative among several hosts. + TargetSessionAttrs TargetSessionAttrs `postgres:"target_session_attrs" env:"PGTARGETSESSIONATTRS"` + + // Controls the order in which the client tries to connect to the available + // hosts. Once a connection attempt is successful no other hosts will be + // tried. This parameter is typically used in combination with multiple host + // names. + // + // This parameter can be used in combination with target_session_attrs to, + // for example, load balance over standby servers only. Once successfully + // connected, subsequent queries on the returned connection will all be sent + // to the same server. + LoadBalanceHosts LoadBalanceHosts `postgres:"load_balance_hosts" env:"PGLOADBALANCEHOSTS"` + + // Minimum acceptable PostgreSQL protocol version. If the server does not + // support at least this version, the connection will fail. Defaults to + // "3.0". + MinProtocolVersion ProtocolVersion `postgres:"min_protocol_version" env:"PGMINPROTOCOLVERSION"` + + // Maximum PostgreSQL protocol version to request from the server. Defaults to "3.0". + MaxProtocolVersion ProtocolVersion `postgres:"max_protocol_version" env:"PGMAXPROTOCOLVERSION"` + + // Load connection parameters from the service file at ~/.pg_service.conf + // (which can be configured with PGSERVICEFILE). 
+ // + // The service file is an INI-like file to configure connection parameters: + // + // [servicename] + // # Comment + // dbname=foo + // + // Unlike libpq, this does not look at the system-wide service file, as the + // location of this is a compile-time value that is not easy for pq to + // retrieve. + Service string `postgres:"service" env:"PGSERVICE"` + + // Path to connection service file. Defaults to ~/.pg_service.conf. + ServiceFile string `postgres:"-" env:"PGSERVICEFILE"` + + // Runtime parameters: any unrecognized parameter in the DSN will be added + // to this and sent to PostgreSQL during startup. + Runtime map[string]string `postgres:"-" env:"-"` + + // Multi contains additional connection details. The first value is + // available in [Config.Host], [Config.Hostaddr], and [Config.Port], and + // additional ones (if any) are available here. + Multi []ConfigMultihost + + // Record which parameters were given, so we can distinguish between an + // empty string and "not given at all". + // + // The alternative is to use pointers or sql.Null[..], but that's more + // awkward to use. + set []string `env:"set"` + + multiHost []string + multiHostaddr []netip.Addr + multiPort []uint16 +} + +// ConfigMultihost specifies an additional server to try to connect to. +type ConfigMultihost struct { + Host string + Hostaddr netip.Addr + Port uint16 +} + +// NewConfig creates a new [Config] from the defaults, environment, service +// file, and DSN, in that order. That is: a service overrides any value from the +// environment, which in turn gets overridden by the same parameter in the +// connection string. +// +// Most connection parameters supported by PostgreSQL are supported; see the +// [Config] struct for supported parameters. pq also lets you specify any +// [run-time parameter] such as search_path or work_mem in the connection +// string. This is different from libpq, which uses the "options" parameter for +// this (which also works in pq).
+// +// # key=value connection strings +// +// For key=value strings, use single quotes for values that contain whitespace +// or empty values. A backslash will escape the next character: +// +// "user=pqgo password='with spaces'" +// "user=''" +// "user=space\ man password='it\'s valid'" +// +// # URL connection strings +// +// pq supports URL-style postgres:// or postgresql:// connection strings in the +// form: +// +// postgres[ql]://[user[:pwd]@][net-location][:port][/dbname][?param1=value1&...] +// +// Go's [net/url.Parse] is more strict than PostgreSQL's URL parser and will +// (correctly) reject %2F in the host part. This means that unix-socket URLs: +// +// postgres://[user[:pwd]@][unix-socket][:port[/dbname]][?param1=value1&...] +// postgres://%2Ftmp%2Fpostgres/db +// +// will not work. You will need to use "host=/tmp/postgres dbname=db". +// +// Similarly, multiple ports also won't work, but ?port= will: +// +// postgres://host1,host2:5432,6543/dbname Doesn't work +// postgres://host1,host2/dbname?port=5432,6543 Works +// +// # Environment +// +// Most [PostgreSQL environment variables] are supported by pq. Environment +// variables have a lower precedence than explicitly provided connection +// parameters. pq will return an error if environment variables it does not +// support are set. Environment variables have a lower precedence than +// explicitly provided connection parameters. +// +// [PostgreSQL environment variables]: http://www.postgresql.org/docs/current/static/libpq-envars.html +// [run-time parameter]: http://www.postgresql.org/docs/current/static/runtime-config.html +func NewConfig(dsn string) (Config, error) { + return newConfig(dsn, os.Environ()) +} + +// Clone returns a copy of the [Config]. +func (cfg Config) Clone() Config { + rt := make(map[string]string) + for k, v := range cfg.Runtime { + rt[k] = v + } + c := cfg + c.Runtime = rt + c.set = append([]string{}, cfg.set...) 
+ return c +} + +// hosts returns a slice of copies of this config, one for each host. +func (cfg Config) hosts() []Config { + cfgs := make([]Config, 1, len(cfg.Multi)+1) + cfgs[0] = cfg.Clone() + for _, m := range cfg.Multi { + c := cfg.Clone() + c.Host, c.Hostaddr, c.Port = m.Host, m.Hostaddr, m.Port + cfgs = append(cfgs, c) + } + + if cfg.LoadBalanceHosts == LoadBalanceHostsRandom { + rand.Shuffle(len(cfgs), func(i, j int) { cfgs[i], cfgs[j] = cfgs[j], cfgs[i] }) + } + + return cfgs +} + +func newConfig(dsn string, env []string) (Config, error) { + cfg := Config{ + Host: "localhost", + Port: 5432, + SSLSNI: true, + SSLMode: SSLModePrefer, + MinProtocolVersion: "3.0", + MaxProtocolVersion: "3.0", + } + if err := cfg.fromEnv(env); err != nil { + return Config{}, err + } + if err := cfg.fromDSN(dsn); err != nil { + return Config{}, err + } + if err := cfg.fromService(); err != nil { + return Config{}, err + } + + // Need to have exactly the same number of host and hostaddr, or only specify one. + if cfg.isset("host") && cfg.Host != "" && cfg.Hostaddr != (netip.Addr{}) && len(cfg.multiHost) != len(cfg.multiHostaddr) { + return Config{}, fmt.Errorf("pq: could not match %d host names to %d hostaddr values", + len(cfg.multiHost)+1, len(cfg.multiHostaddr)+1) + } + // Need one port that applies to all or exactly the same number of ports as hosts. 
+ l, ll := max(len(cfg.multiHost), len(cfg.multiHostaddr)), len(cfg.multiPort) + if l > 0 && ll > 0 && l != ll { + return Config{}, fmt.Errorf("pq: could not match %d port numbers to %d hosts", ll+1, l+1) + } + + // Populate Multi + if len(cfg.multiHostaddr) > len(cfg.multiHost) { + cfg.multiHost = make([]string, len(cfg.multiHostaddr)) + } + for i, h := range cfg.multiHost { + p := cfg.Port + if len(cfg.multiPort) > 0 { + p = cfg.multiPort[i] + } + var addr netip.Addr + if len(cfg.multiHostaddr) > 0 { + addr = cfg.multiHostaddr[i] + } + cfg.Multi = append(cfg.Multi, ConfigMultihost{ + Host: h, + Port: p, + Hostaddr: addr, + }) + } + + // Use the "fallback" application name if necessary + if cfg.isset("fallback_application_name") && !cfg.isset("application_name") { + cfg.ApplicationName = cfg.FallbackApplicationName + } + + // We can't work with any client_encoding other than UTF-8 currently. + // However, we have historically allowed the user to set it to UTF-8 + // explicitly, and there's no reason to break such programs, so allow that. + // Note that the "options" setting could also set client_encoding, but + // parsing its value is not worth it. Instead, we always explicitly send + // client_encoding as a separate run-time parameter, which should override + // anything set in options. + if cfg.isset("client_encoding") && !isUTF8(cfg.ClientEncoding) { + return Config{}, fmt.Errorf(`pq: unsupported client_encoding %q: must be absent or "UTF8"`, cfg.ClientEncoding) + } + // DateStyle needs a similar treatment. + if cfg.isset("datestyle") && cfg.Datestyle != "ISO, MDY" { + return Config{}, fmt.Errorf(`pq: unsupported datestyle %q: must be absent or "ISO, MDY"`, cfg.Datestyle) + } + cfg.ClientEncoding, cfg.Datestyle = "UTF8", "ISO, MDY" + + // Set default user if not explicitly provided. 
+ if !cfg.isset("user") { + u, err := pqutil.User() + if err != nil { + return Config{}, err + } + cfg.User = u + } + + // SSL is not necessary or supported over UNIX domain sockets. + if nw, _ := cfg.network(); nw == "unix" { + cfg.SSLMode = SSLModeDisable + } + + if cfg.MinProtocolVersion > cfg.MaxProtocolVersion { + return Config{}, fmt.Errorf("pq: min_protocol_version %q cannot be greater than max_protocol_version %q", + cfg.MinProtocolVersion, cfg.MaxProtocolVersion) + } + if cfg.SSLNegotiation == SSLNegotiationDirect { + switch cfg.SSLMode { + case SSLModeDisable, SSLModeAllow, SSLModePrefer: + return Config{}, fmt.Errorf( + `pq: weak sslmode %q may not be used with sslnegotiation=direct (use "require", "verify-ca", or "verify-full")`, + cfg.SSLMode) + } + } + if cfg.SSLRootCert == "system" { + if !cfg.isset("sslmode") { + cfg.SSLMode = SSLModeVerifyFull + } + if cfg.SSLMode != SSLModeVerifyFull { + return Config{}, fmt.Errorf( + `pq: weak sslmode %q may not be used with sslrootcert=system (use "verify-full")`, + cfg.SSLMode) + } + } + + return cfg, nil +} + +func (cfg Config) network() (string, string) { + if cfg.Hostaddr != (netip.Addr{}) { + return "tcp", net.JoinHostPort(cfg.Hostaddr.String(), strconv.Itoa(int(cfg.Port))) + } + // UNIX domain sockets are either represented by an (absolute) file system + // path or they live in the abstract name space (starting with an @). + if filepath.IsAbs(cfg.Host) || strings.HasPrefix(cfg.Host, "@") { + sockPath := filepath.Join(cfg.Host, ".s.PGSQL."+strconv.Itoa(int(cfg.Port))) + return "unix", sockPath + } + return "tcp", net.JoinHostPort(cfg.Host, strconv.Itoa(int(cfg.Port))) +} + +func (cfg *Config) fromEnv(env []string) error { + e := make(map[string]string) + for _, v := range env { + k, v, ok := strings.Cut(v, "=") + if !ok { + continue + } + switch k { + case "PGREQUIRESSL", "PGSSLCOMPRESSION", // Deprecated. 
+ "PGREALM", "PGGSSENCMODE", "PGGSSDELEGATION", "PGGSSLIB", // krb stuff + "PGREQUIREAUTH", "PGCHANNELBINDING", + "PGSSLCERTMODE", "PGSSLCRL", "PGSSLCRLDIR", "PGREQUIREPEER": + return fmt.Errorf("pq: environment variable $%s is not supported", k) + case "PGKRBSRVNAME": + if newGss == nil { + return fmt.Errorf("pq: environment variable $%s is not supported as Kerberos is not enabled", k) + } + } + e[k] = v + } + return cfg.setFromTag(e, "env", false) +} + +// parseOpts parses the options from name and adds them to the values. +// +// The parsing code is based on conninfo_parse from libpq's fe-connect.c +func (cfg *Config) fromDSN(dsn string) error { + if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { + var err error + dsn, err = convertURL(dsn) + if err != nil { + return err + } + } + + var ( + opt = make(map[string]string) + s = []rune(dsn) + i int + next = func() (rune, bool) { + if i >= len(s) { + return 0, false + } + r := s[i] + i++ + return r, true + } + skipSpaces = func() (rune, bool) { + r, ok := next() + for unicode.IsSpace(r) && ok { + r, ok = next() + } + return r, ok + } + ) + + for { + var ( + keyRunes, valRunes []rune + r rune + ok bool + ) + + if r, ok = skipSpaces(); !ok { + break + } + + // Scan the key + for !unicode.IsSpace(r) && r != '=' { + keyRunes = append(keyRunes, r) + if r, ok = next(); !ok { + break + } + } + + // Skip any whitespace if we're not at the = yet + if r != '=' { + r, ok = skipSpaces() + } + + // The current character should be = + if r != '=' || !ok { + return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) + } + + // Skip any whitespace after the = + if r, ok = skipSpaces(); !ok { + // If we reach the end here, the last value is just an empty string as per libpq. 
+ opt[string(keyRunes)] = "" + break + } + + if r != '\'' { + for !unicode.IsSpace(r) { + if r == '\\' { + if r, ok = next(); !ok { + return fmt.Errorf(`missing character after backslash`) + } + } + valRunes = append(valRunes, r) + + if r, ok = next(); !ok { + break + } + } + } else { + quote: + for { + if r, ok = next(); !ok { + return fmt.Errorf(`unterminated quoted string literal in connection string`) + } + switch r { + case '\'': + break quote + case '\\': + r, _ = next() + fallthrough + default: + valRunes = append(valRunes, r) + } + } + } + + opt[string(keyRunes)] = string(valRunes) + } + + return cfg.setFromTag(opt, "postgres", false) +} + +func (cfg *Config) fromService() error { + if cfg.Service == "" { + return nil + } + + if !cfg.isset("PGSERVICEFILE") { + if home := pqutil.Home(); home != "" { + if runtime.GOOS != "windows" { + home = filepath.Dir(home) // Unlike other files this uses ~/ and not ~/.postgresql + } + cfg.ServiceFile = filepath.Join(home, ".pg_service.conf") + } + } + + opts, err := pgservice.FindService(cfg.ServiceFile, cfg.Service) + if err != nil { + return fmt.Errorf("pq: %w", err) + } + return cfg.setFromTag(opts, "postgres", true) +} + +func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) error { + f := "pq: wrong value for %q: " + if tag == "env" { + f = "pq: wrong value for $%s: " + } + var ( + types = reflect.TypeOf(cfg).Elem() + values = reflect.ValueOf(cfg).Elem() + ) + for i := 0; i < types.NumField(); i++ { + var ( + rt = types.Field(i) + rv = values.Field(i) + k = rt.Tag.Get(tag) + connectTimeout = (tag == "postgres" && k == "connect_timeout") || (tag == "env" && k == "PGCONNECT_TIMEOUT") + host = (tag == "postgres" && k == "host") || (tag == "env" && k == "PGHOST") + hostaddr = (tag == "postgres" && k == "hostaddr") || (tag == "env" && k == "PGHOSTADDR") + port = (tag == "postgres" && k == "port") || (tag == "env" && k == "PGPORT") + sslmode = (tag == "postgres" && k == "sslmode") || (tag == "env" && 
k == "PGSSLMODE") + sslnegotiation = (tag == "postgres" && k == "sslnegotiation") || (tag == "env" && k == "PGSSLNEGOTIATION") + targetsessionattrs = (tag == "postgres" && k == "target_session_attrs") || (tag == "env" && k == "PGTARGETSESSIONATTRS") + loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts") || (tag == "env" && k == "PGLOADBALANCEHOSTS") + minprotocolversion = (tag == "postgres" && k == "min_protocol_version") || (tag == "env" && k == "PGMINPROTOCOLVERSION") + maxprotocolversion = (tag == "postgres" && k == "max_protocol_version") || (tag == "env" && k == "PGMAXPROTOCOLVERSION") + sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version") || (tag == "env" && k == "SSLPGMINPROTOCOLVERSION") + sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version") || (tag == "env" && k == "SSLPGMAXPROTOCOLVERSION") + ) + if k == "" || k == "-" { + continue + } + + v, ok := o[k] + delete(o, k) + if ok { + t, ok := rt.Tag.Lookup("postgres") + if !ok || t == "" || t == "-" { // For PGSERVICEFILE, which can only be from env + t, ok = rt.Tag.Lookup("env") + } + if ok && t != "" && t != "-" { + cfg.set = append(cfg.set, t) + } + switch rt.Type.Kind() { + default: + return fmt.Errorf("don't know how to set %s: unknown type %s", rt.Name, rt.Type.Kind()) + case reflect.Struct: + if rt.Type == reflect.TypeOf(netip.Addr{}) { + if hostaddr { + vv := strings.Split(v, ",") + v = vv[0] + for _, vvv := range vv[1:] { + if vvv == "" { + cfg.multiHostaddr = append(cfg.multiHostaddr, netip.Addr{}) + } else { + ip, err := netip.ParseAddr(vvv) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + cfg.multiHostaddr = append(cfg.multiHostaddr, ip) + } + } + } + ip, err := netip.ParseAddr(v) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.Set(reflect.ValueOf(ip)) + } else { + return fmt.Errorf("don't know how to set %s: unknown type %s", rt.Name, rt.Type) + } + case reflect.String: + if sslmode && 
!slices.Contains(sslModes, SSLMode(v)) && !(strings.HasPrefix(v, "pqgo-") && hasTLSConfig(v[5:])) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(sslModes)) + } + if sslnegotiation && !slices.Contains(sslNegotiations, SSLNegotiation(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(sslNegotiations)) + } + if targetsessionattrs && !slices.Contains(targetSessionAttrs, TargetSessionAttrs(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(targetSessionAttrs)) + } + if loadbalancehosts && !slices.Contains(loadBalanceHosts, LoadBalanceHosts(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(loadBalanceHosts)) + } + if (minprotocolversion || maxprotocolversion) && !slices.Contains(protocolVersions, ProtocolVersion(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(protocolVersions)) + } + if (sslminprotocolversion || sslmaxprotocolversion) && !slices.Contains(sslProtocolVersions, SSLProtocolVersion(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(sslProtocolVersions)) + } + if host { + vv := strings.Split(v, ",") + v = vv[0] + for i, vvv := range vv[1:] { + if vvv == "" { + vv[i+1] = "localhost" + } + } + cfg.multiHost = append(cfg.multiHost, vv[1:]...) 
+ } + rv.SetString(v) + case reflect.Int64: + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + if connectTimeout { + n = int64(time.Duration(n) * time.Second) + } + rv.SetInt(n) + case reflect.Uint16: + if port { + vv := strings.Split(v, ",") + v = vv[0] + for _, vvv := range vv[1:] { + if vvv == "" { + vvv = "5432" + } + n, err := strconv.ParseUint(vvv, 10, 16) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + cfg.multiPort = append(cfg.multiPort, uint16(n)) + } + } + n, err := strconv.ParseUint(v, 10, 16) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.SetUint(n) + case reflect.Bool: + b, err := pqutil.ParseBool(v) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.SetBool(b) + } + } + } + + if service && len(o) > 0 { + // TODO(go1.23): use maps.Keys once we require Go 1.23. + var key string + for k := range o { + key = k + break + } + return fmt.Errorf("pq: unknown setting %q in service file for service %q", key, cfg.Service) + } + + // Set run-time; we delete map keys as they're set in the struct. + if !service && tag == "postgres" { + // Make sure database= sets dbname=, as that previously worked (kind of + // by accident). + // TODO(v2): remove + if d, ok := o["database"]; ok { + cfg.Database = d + delete(o, "database") + } + cfg.Runtime = o + } + + return nil +} + +func (cfg Config) isset(name string) bool { + return slices.Contains(cfg.set, name) +} + +// Convert to a map; used only in tests. 
+func (cfg Config) tomap() map[string]string { + var ( + o = make(map[string]string) + values = reflect.ValueOf(cfg) + types = reflect.TypeOf(cfg) + ) + for i := 0; i < types.NumField(); i++ { + var ( + rt = types.Field(i) + rv = values.Field(i) + k = rt.Tag.Get("postgres") + ) + if k == "" || k == "-" { + continue + } + if !rv.IsZero() || slices.Contains(cfg.set, k) { + switch rt.Type.Kind() { + default: + if s, ok := rv.Interface().(fmt.Stringer); ok { + o[k] = s.String() + } else { + o[k] = rv.String() + } + case reflect.Uint16: + n := rv.Uint() + o[k] = strconv.FormatUint(n, 10) + case reflect.Int64: + n := rv.Int() + if k == "connect_timeout" { + n = int64(time.Duration(n) / time.Second) + } + o[k] = strconv.FormatInt(n, 10) + case reflect.Bool: + if rv.Bool() { + o[k] = "yes" + } else { + o[k] = "no" + } + } + } + } + for k, v := range cfg.Runtime { + o[k] = v + } + return o +} + +// Create DSN for this config; used only in tests. +func (cfg Config) string() string { + var ( + m = cfg.tomap() + keys = make([]string, 0, len(m)) + ) + for k := range m { + switch k { + case "datestyle", "client_encoding": + continue + case "host", "port", "user", "sslsni", "sslmode", "min_protocol_version", "max_protocol_version": + if !cfg.isset(k) { + continue + } + } + if k == "host" && len(cfg.multiHost) > 0 { + m[k] += "," + strings.Join(cfg.multiHost, ",") + } + if k == "hostaddr" && len(cfg.multiHostaddr) > 0 { + for _, ha := range cfg.multiHostaddr { + m[k] += "," + if ha != (netip.Addr{}) { + m[k] += ha.String() + } + } + } + if k == "port" && len(cfg.multiPort) > 0 { + for _, p := range cfg.multiPort { + m[k] += "," + strconv.Itoa(int(p)) + } + } + keys = append(keys, k) + } + sort.Strings(keys) + + var b strings.Builder + for i, k := range keys { + if i > 0 { + b.WriteByte(' ') + } + b.WriteString(k) + b.WriteByte('=') + var ( + v = m[k] + nv = make([]rune, 0, len(v)+2) + quote = v == "" + ) + for _, c := range v { + if c == ' ' { + quote = true + } + if c == '\'' { + 
nv = append(nv, '\\') + } + nv = append(nv, c) + } + if quote { + b.WriteByte('\'') + } + b.WriteString(string(nv)) + if quote { + b.WriteByte('\'') + } + } + return b.String() +} + +// Recognize all sorts of silly things as "UTF-8", like Postgres does +func isUTF8(name string) bool { + s := strings.Map(func(c rune) rune { + if 'A' <= c && c <= 'Z' { + return c + ('a' - 'A') + } + if 'a' <= c && c <= 'z' || '0' <= c && c <= '9' { + return c + } + return -1 // discard + }, name) + return s == "utf8" || s == "unicode" +} + +func convertURL(url string) (string, error) { + u, err := neturl.Parse(url) + if err != nil { + return "", err + } + + if u.Scheme != "postgres" && u.Scheme != "postgresql" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + var kvs []string + escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) + accrue := func(k, v string) { + if v != "" { + kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") + } + } + + if u.User != nil { + pw, _ := u.User.Password() + accrue("user", u.User.Username()) + accrue("password", pw) + } + + if host, port, err := net.SplitHostPort(u.Host); err != nil { + accrue("host", u.Host) + } else { + accrue("host", host) + accrue("port", port) + } + + if u.Path != "" { + accrue("dbname", u.Path[1:]) + } + + q := u.Query() + for k := range q { + accrue(k, q.Get(k)) + } + + sort.Strings(kvs) // Makes testing easier (not a performance concern) + return strings.Join(kvs, " "), nil +} diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go new file mode 100644 index 00000000..4c9a8cc7 --- /dev/null +++ b/vendor/github.com/lib/pq/copy.go @@ -0,0 +1,337 @@ +package pq + +import ( + "context" + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "sync" + + "github.com/lib/pq/internal/proto" +) + +var ( + errCopyInClosed = errors.New("pq: copyin statement has already been closed") + errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") + 
errCopyToNotSupported = errors.New("pq: COPY TO is not supported") + errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") +) + +type copyin struct { + cn *conn + buffer []byte + rowData chan []byte + done chan bool + closed bool + mu struct { + sync.Mutex + err error + driver.Result + } +} + +const ( + ciBufferSize = 64 * 1024 + // flush buffer before the buffer is filled up and needs reallocation + ciBufferFlushSize = 63 * 1024 +) + +func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, resErr error) { + if !cn.isInTransaction() { + return nil, errCopyNotSupportedOutsideTxn + } + + ci := ©in{ + cn: cn, + buffer: make([]byte, 0, ciBufferSize), + rowData: make(chan []byte), + done: make(chan bool, 1), + } + // add CopyData identifier + 4 bytes for message length + ci.buffer = append(ci.buffer, byte(proto.CopyDataRequest), 0, 0, 0, 0) + + b := cn.writeBuf(proto.Query) + b.string(q) + err := cn.send(b) + if err != nil { + return nil, err + } + +awaitCopyInResponse: + for { + t, r, err := cn.recv1() + if err != nil { + return nil, err + } + switch t { + case proto.CopyInResponse: + if r.byte() != 0 { + resErr = errBinaryCopyNotSupported + break awaitCopyInResponse + } + go ci.resploop() + return ci, nil + case proto.CopyOutResponse: + resErr = errCopyToNotSupported + break awaitCopyInResponse + case proto.ErrorResponse: + resErr = parseError(r, q) + case proto.ReadyForQuery: + if resErr == nil { + ci.setBad(driver.ErrBadConn) + return nil, fmt.Errorf("pq: unexpected ReadyForQuery in response to COPY") + } + cn.processReadyForQuery(r) + return nil, resErr + default: + ci.setBad(driver.ErrBadConn) + return nil, fmt.Errorf("pq: unknown response for copy query: %q", t) + } + } + + // something went wrong, abort COPY before we return + b = cn.writeBuf(proto.CopyFail) + b.string(resErr.Error()) + err = cn.send(b) + if err != nil { + return nil, err + } + + for { + t, r, err := cn.recv1() + if err != nil { + return nil, err + } + + 
switch t { + case proto.CopyDoneResponse, proto.CommandComplete, proto.ErrorResponse: + case proto.ReadyForQuery: + // correctly aborted, we're done + cn.processReadyForQuery(r) + return nil, resErr + default: + ci.setBad(driver.ErrBadConn) + return nil, fmt.Errorf("pq: unknown response for CopyFail: %q", t) + } + } +} + +func (ci *copyin) flush(buf []byte) error { + if len(buf)-1 > proto.MaxUint32 { + return errors.New("pq: too many columns") + } + // set message length (without message identifier) + binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1)) + + _, err := ci.cn.c.Write(buf) + return err +} + +func (ci *copyin) resploop() { + for { + var r readBuf + t, err := ci.cn.recvMessage(&r) + if err != nil { + ci.setBad(driver.ErrBadConn) + ci.setError(err) + ci.done <- true + return + } + switch t { + case proto.CommandComplete: + // complete + res, _, err := ci.cn.parseComplete(r.string()) + if err != nil { + panic(err) + } + ci.setResult(res) + case proto.NoticeResponse: + if n := ci.cn.noticeHandler; n != nil { + n(parseError(&r, "")) + } + case proto.ReadyForQuery: + ci.cn.processReadyForQuery(&r) + ci.done <- true + return + case proto.ErrorResponse: + err := parseError(&r, "") + ci.setError(err) + default: + ci.setBad(driver.ErrBadConn) + ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) + ci.done <- true + return + } + } +} + +func (ci *copyin) setBad(err error) { + ci.cn.err.set(err) +} + +func (ci *copyin) getBad() error { + return ci.cn.err.get() +} + +func (ci *copyin) err() error { + ci.mu.Lock() + err := ci.mu.err + ci.mu.Unlock() + return err +} + +// setError() sets ci.err if one has not been set already. Caller must not be +// holding ci.Mutex. 
+func (ci *copyin) setError(err error) { + ci.mu.Lock() + if ci.mu.err == nil { + ci.mu.err = err + } + ci.mu.Unlock() +} + +func (ci *copyin) setResult(result driver.Result) { + ci.mu.Lock() + ci.mu.Result = result + ci.mu.Unlock() +} + +func (ci *copyin) getResult() driver.Result { + ci.mu.Lock() + result := ci.mu.Result + ci.mu.Unlock() + if result == nil { + return driver.RowsAffected(0) + } + return result +} + +func (ci *copyin) NumInput() int { + return -1 +} + +func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) { + return nil, ErrNotSupported +} + +// Exec inserts values into the COPY stream. The insert is asynchronous +// and Exec can return errors from previous Exec calls to the same +// COPY stmt. +// +// You need to call Exec(nil) to sync the COPY stream and to get any +// errors from pending data, since Stmt.Close() doesn't return errors +// to the user. +func (ci *copyin) Exec(v []driver.Value) (driver.Result, error) { + if ci.closed { + return nil, errCopyInClosed + } + if err := ci.getBad(); err != nil { + return nil, err + } + if err := ci.err(); err != nil { + return nil, err + } + + if len(v) == 0 { + if err := ci.Close(); err != nil { + return driver.RowsAffected(0), err + } + return ci.getResult(), nil + } + + var ( + numValues = len(v) + err error + ) + for i, value := range v { + ci.buffer, err = appendEncodedText(ci.buffer, value) + if err != nil { + return nil, ci.cn.handleError(err) + } + if i < numValues-1 { + ci.buffer = append(ci.buffer, '\t') + } + } + + ci.buffer = append(ci.buffer, '\n') + + if len(ci.buffer) > ciBufferFlushSize { + err := ci.flush(ci.buffer) + if err != nil { + return nil, ci.cn.handleError(err) + } + // reset buffer, keep bytes for message identifier and length + ci.buffer = ci.buffer[:5] + } + + return driver.RowsAffected(0), nil +} + +// CopyData inserts a raw string into the COPY stream. 
The insert is +// asynchronous and CopyData can return errors from previous CopyData calls to +// the same COPY stmt. +// +// You need to call Exec(nil) to sync the COPY stream and to get any +// errors from pending data, since Stmt.Close() doesn't return errors +// to the user. +func (ci *copyin) CopyData(ctx context.Context, line string) (driver.Result, error) { + if ci.closed { + return nil, errCopyInClosed + } + if finish := ci.cn.watchCancel(ctx); finish != nil { + defer finish() + } + if err := ci.getBad(); err != nil { + return nil, err + } + if err := ci.err(); err != nil { + return nil, err + } + + ci.buffer = append(ci.buffer, []byte(line)...) + ci.buffer = append(ci.buffer, '\n') + + if len(ci.buffer) > ciBufferFlushSize { + err := ci.flush(ci.buffer) + if err != nil { + return nil, ci.cn.handleError(err) + } + + // reset buffer, keep bytes for message identifier and length + ci.buffer = ci.buffer[:5] + } + + return driver.RowsAffected(0), nil +} + +func (ci *copyin) Close() error { + if ci.closed { // Don't do anything, we're already closed + return nil + } + ci.closed = true + + if err := ci.getBad(); err != nil { + return err + } + + if len(ci.buffer) > 0 { + err := ci.flush(ci.buffer) + if err != nil { + return ci.cn.handleError(err) + } + } + // Avoid touching the scratch buffer as resploop could be using it. + err := ci.cn.sendSimpleMessage(proto.CopyDoneRequest) + if err != nil { + return ci.cn.handleError(err) + } + + <-ci.done + ci.cn.inProgress.Store(false) + + if err := ci.err(); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/lib/pq/deprecated.go b/vendor/github.com/lib/pq/deprecated.go new file mode 100644 index 00000000..d43934a0 --- /dev/null +++ b/vendor/github.com/lib/pq/deprecated.go @@ -0,0 +1,133 @@ +package pq + +import ( + "bytes" + "database/sql" + + "github.com/lib/pq/pqerror" +) + +// [pq.Error.Severity] values. +// +// Deprecated: use pqerror.Severity[..] values. 
+// +//go:fix inline +const ( + Efatal = pqerror.SeverityFatal + Epanic = pqerror.SeverityPanic + Ewarning = pqerror.SeverityWarning + Enotice = pqerror.SeverityNotice + Edebug = pqerror.SeverityDebug + Einfo = pqerror.SeverityInfo + Elog = pqerror.SeverityLog +) + +// PGError is an interface used by previous versions of pq. +// +// Deprecated: use the Error type. This is never used. +type PGError interface { + Error() string + Fatal() bool + Get(k byte) (v string) +} + +// Get implements the legacy PGError interface. +// +// Deprecated: new code should use the fields of the Error struct directly. +func (e *Error) Get(k byte) (v string) { + switch k { + case 'S': + return e.Severity + case 'C': + return string(e.Code) + case 'M': + return e.Message + case 'D': + return e.Detail + case 'H': + return e.Hint + case 'P': + return e.Position + case 'p': + return e.InternalPosition + case 'q': + return e.InternalQuery + case 'W': + return e.Where + case 's': + return e.Schema + case 't': + return e.Table + case 'c': + return e.Column + case 'd': + return e.DataTypeName + case 'n': + return e.Constraint + case 'F': + return e.File + case 'L': + return e.Line + case 'R': + return e.Routine + } + return "" +} + +// ParseURL converts a url to a connection string for driver.Open. +// +// Deprecated: directly passing an URL to sql.Open("postgres", "postgres://...") +// now works, and calling this manually is no longer required. +func ParseURL(url string) (string, error) { return convertURL(url) } + +// NullTime represents a [time.Time] that may be null. +// +// Deprecated: this is an alias for [sql.NullTime]. +// +//go:fix inline +type NullTime = sql.NullTime + +// CopyIn creates a COPY FROM statement which can be prepared with Tx.Prepare(). +// The target table should be visible in search_path. +// +// It copies all columns if the list of columns is empty. 
+// +// Deprecated: there is no need to use this query builder, you can use: +// +// tx.Prepare("copy tbl (col1, col2) from stdin") +func CopyIn(table string, columns ...string) string { + b := bytes.NewBufferString("COPY ") + BufferQuoteIdentifier(table, b) + makeStmt(b, columns...) + return b.String() +} + +// CopyInSchema creates a COPY FROM statement which can be prepared with +// Tx.Prepare(). +// +// Deprecated: there is no need to use this query builder, you can use: +// +// tx.Prepare("copy schema.tbl (col1, col2) from stdin") +func CopyInSchema(schema, table string, columns ...string) string { + b := bytes.NewBufferString("COPY ") + BufferQuoteIdentifier(schema, b) + b.WriteRune('.') + BufferQuoteIdentifier(table, b) + makeStmt(b, columns...) + return b.String() +} + +func makeStmt(b *bytes.Buffer, columns ...string) { + if len(columns) == 0 { + b.WriteString(" FROM STDIN") + return + } + b.WriteString(" (") + for i, col := range columns { + if i != 0 { + b.WriteString(", ") + } + BufferQuoteIdentifier(col, b) + } + b.WriteString(") FROM STDIN") +} diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go new file mode 100644 index 00000000..9d9d78e4 --- /dev/null +++ b/vendor/github.com/lib/pq/doc.go @@ -0,0 +1,137 @@ +/* +Package pq is a Go PostgreSQL driver for database/sql. + +Most clients will use the database/sql package instead of using this package +directly. For example: + + import ( + "database/sql" + + _ "github.com/lib/pq" + ) + + func main() { + dsn := "user=pqgo dbname=pqgo sslmode=verify-full" + db, err := sql.Open("postgres", dsn) + if err != nil { + log.Fatal(err) + } + + age := 21 + rows, err := db.Query("select name from users where age = $1", age) + // … + } + +You can also connect with an URL: + + dsn := "postgres://pqgo:password@localhost/pqgo?sslmode=verify-full" + db, err := sql.Open("postgres", dsn) + +# Connection String Parameters + +See [NewConfig]. 
+ +# Queries + +database/sql does not dictate any specific format for parameter placeholders, +and pq uses the PostgreSQL-native ordinal markers ($1, $2, etc.). The same +placeholder can be used more than once: + + rows, err := db.Query( + `select * from users where name = $1 or age between $2 and $2 + 3`, + "Duck", 64) + +pq does not support [sql.Result.LastInsertId]. Use the RETURNING clause with a +Query or QueryRow call instead to return the identifier: + + row := db.QueryRow(`insert into users(name, age) values('Scrooge McDuck', 93) returning id`) + + var userid int + err := row.Scan(&userid) + +# Data Types + +Parameters pass through [driver.DefaultParameterConverter] before they are handled +by this package. When the binary_parameters connection option is enabled, []byte +values are sent directly to the backend as data in binary format. + +This package returns the following types for values from the PostgreSQL backend: + + - integer types smallint, integer, and bigint are returned as int64 + - floating-point types real and double precision are returned as float64 + - character types char, varchar, and text are returned as string + - temporal types date, time, timetz, timestamp, and timestamptz are + returned as time.Time + - the boolean type is returned as bool + - the bytea type is returned as []byte + +All other types are returned directly from the backend as []byte values in text format. + +# Errors + +pq may return errors of type [*pq.Error] which contain error details: + + pqErr := new(pq.Error) + if errors.As(err, &pqErr) { + fmt.Println("pq error:", pqErr.Code.Name()) + } + +# Bulk imports + +You can perform bulk imports by preparing a "COPY [..] FROM STDIN" statement in +a transaction ([sql.Tx]). The returned [sql.Stmt] handle can then be repeatedly +"executed" to copy data into the target table. After all data has been processed +you should call Exec() once with no arguments to flush all buffered data. 
Any +call to Exec() might return an error which should be handled appropriately, but +because of the internal buffering an error returned by Exec() might not be +related to the data passed in the call that failed. + +It is not possible to COPY outside of an explicit transaction in pq. + +Use nil for NULL, or explicitly add WITH NULL 'SOME STRING' (the default of \N +doesn't work). + +# Notifications + +PostgreSQL supports a simple publish/subscribe model using PostgreSQL's [NOTIFY] mechanism. + +To start listening for notifications, you first have to open a new connection to +the database by calling [NewListener]. This connection can not be used for +anything other than LISTEN / NOTIFY. Calling Listen will open a "notification +channel"; once a notification channel is open, a notification generated on that +channel will effect a send on the Listener.Notify channel. A notification +channel will remain open until Unlisten is called, though connection loss might +result in some notifications being lost. To solve this problem, Listener sends a +nil pointer over the Notify channel any time the connection is re-established +following a connection loss. The application can get information about the state +of the underlying connection by setting an event callback in the call to +NewListener. + +A single [Listener] can safely be used from concurrent goroutines, which means +that there is often no need to create more than one Listener in your +application. However, a Listener is always connected to a single database, so +you will need to create a new Listener instance for every database you want to +receive notifications in. + +The channel name in both Listen and Unlisten is case sensitive, and can contain +any characters legal in an [identifier]. Note that the channel name will be +truncated to 63 bytes by the PostgreSQL server. 
+ +# Kerberos Support + +If you need support for Kerberos authentication, add the following to your main +package: + + import "github.com/lib/pq/auth/kerberos" + + func init() { + pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() }) + } + +This package is in a separate module so that users who don't need Kerberos don't +have to add unnecessary dependencies. + +[identifier]: http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +[NOTIFY]: http://www.postgresql.org/docs/current/static/sql-notify.html +*/ +package pq diff --git a/vendor/github.com/lib/pq/encode.go b/vendor/github.com/lib/pq/encode.go new file mode 100644 index 00000000..f9b65051 --- /dev/null +++ b/vendor/github.com/lib/pq/encode.go @@ -0,0 +1,400 @@ +package pq + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/lib/pq/internal/pqtime" + "github.com/lib/pq/oid" +) + +func binaryEncode(x any) ([]byte, error) { + switch v := x.(type) { + case []byte: + return v, nil + default: + return encode(x, oid.T_unknown) + } +} + +func encode(x any, pgtypOid oid.Oid) ([]byte, error) { + switch v := x.(type) { + case int64: + return strconv.AppendInt(nil, v, 10), nil + case float64: + return strconv.AppendFloat(nil, v, 'f', -1, 64), nil + case []byte: + if v == nil { + return nil, nil + } + if pgtypOid == oid.T_bytea { + return encodeBytea(v), nil + } + return v, nil + case string: + if pgtypOid == oid.T_bytea { + return encodeBytea([]byte(v)), nil + } + return []byte(v), nil + case bool: + return strconv.AppendBool(nil, v), nil + case time.Time: + return formatTS(v), nil + default: + return nil, fmt.Errorf("pq: encode: unknown type for %T", v) + } +} + +func decode(ps *parameterStatus, s []byte, typ oid.Oid, f format) (any, error) { + switch f { + case formatBinary: + return binaryDecode(s, typ) + case formatText: + return textDecode(ps, s, typ) + default: + panic("unreachable") 
+ } +} + +func binaryDecode(s []byte, typ oid.Oid) (any, error) { + switch typ { + case oid.T_bytea: + return s, nil + case oid.T_int8: + return int64(binary.BigEndian.Uint64(s)), nil + case oid.T_int4: + return int64(int32(binary.BigEndian.Uint32(s))), nil + case oid.T_int2: + return int64(int16(binary.BigEndian.Uint16(s))), nil + case oid.T_uuid: + return decodeUUIDBinary(s) + default: + return nil, fmt.Errorf("pq: don't know how to decode binary parameter of type %d", uint32(typ)) + } + +} + +// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. +func decodeUUIDBinary(src []byte) ([]byte, error) { + if len(src) != 16 { + return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) + } + + dst := make([]byte, 36) + dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' + hex.Encode(dst[0:], src[0:4]) + hex.Encode(dst[9:], src[4:6]) + hex.Encode(dst[14:], src[6:8]) + hex.Encode(dst[19:], src[8:10]) + hex.Encode(dst[24:], src[10:16]) + return dst, nil +} + +func textDecode(ps *parameterStatus, s []byte, typ oid.Oid) (any, error) { + switch typ { + case oid.T_char, oid.T_bpchar, oid.T_varchar, oid.T_text: + return string(s), nil + case oid.T_bytea: + b, err := parseBytea(s) + if err != nil { + err = errors.New("pq: " + err.Error()) + } + return b, err + case oid.T_timestamptz: + return parseTS(ps.currentLocation, string(s)) + case oid.T_timestamp, oid.T_date: + return parseTS(nil, string(s)) + case oid.T_time: + return parseTime(typ, s) + case oid.T_timetz: + return parseTime(typ, s) + case oid.T_bool: + return s[0] == 't', nil + case oid.T_int8, oid.T_int4, oid.T_int2: + i, err := strconv.ParseInt(string(s), 10, 64) + if err != nil { + err = errors.New("pq: " + err.Error()) + } + return i, err + case oid.T_float4, oid.T_float8: + // We always use 64 bit parsing, regardless of whether the input text is for + // a float4 or float8, because clients expect float64s for all float datatypes + // and returning a 32-bit 
parsed float64 produces lossy results. + f, err := strconv.ParseFloat(string(s), 64) + if err != nil { + err = errors.New("pq: " + err.Error()) + } + return f, err + } + return s, nil +} + +// appendEncodedText encodes item in text format as required by COPY +// and appends to buf +func appendEncodedText(buf []byte, x any) ([]byte, error) { + switch v := x.(type) { + case int64: + return strconv.AppendInt(buf, v, 10), nil + case float64: + return strconv.AppendFloat(buf, v, 'f', -1, 64), nil + case []byte: + encodedBytea := encodeBytea(v) + return appendEscapedText(buf, string(encodedBytea)), nil + case string: + return appendEscapedText(buf, v), nil + case bool: + return strconv.AppendBool(buf, v), nil + case time.Time: + return append(buf, formatTS(v)...), nil + case nil: + return append(buf, `\N`...), nil + default: + return nil, fmt.Errorf("pq: encode: unknown type for %T", v) + } +} + +func appendEscapedText(buf []byte, text string) []byte { + escapeNeeded := false + startPos := 0 + + // check if we need to escape + for i := 0; i < len(text); i++ { + c := text[i] + if c == '\\' || c == '\n' || c == '\r' || c == '\t' { + escapeNeeded = true + startPos = i + break + } + } + if !escapeNeeded { + return append(buf, text...) + } + + // copy till first char to escape, iterate the rest + result := append(buf, text[:startPos]...) 
+ for i := startPos; i < len(text); i++ { + switch c := text[i]; c { + case '\\': + result = append(result, '\\', '\\') + case '\n': + result = append(result, '\\', 'n') + case '\r': + result = append(result, '\\', 'r') + case '\t': + result = append(result, '\\', 't') + default: + result = append(result, c) + } + } + return result +} + +func parseTime(typ oid.Oid, s []byte) (time.Time, error) { + str := string(s) + + f := "15:04:05" + if typ == oid.T_timetz { + f = "15:04:05-07" + // PostgreSQL just sends the hour if the minute and second is 0: + // 22:04:59+00 + // 22:04:59+08 + // 22:04:59+08:30 + // 22:04:59+08:30:40 + // 23:00:00.112321+02:12:13 + // So add those to the format string. + c := strings.Count(str, ":") + if c > 3 { + f = "15:04:05-07:00:00" + } else if c > 2 { + f = "15:04:05-07:00" + } + } + + // Go doesn't parse 24:00, so manually set that to midnight on Jan 2. 24:00 + // is never with subseconds but may have a timezone: + // 24:00:00 + // 24:00:00+08 + // 24:00:00-08:01:01 + var is2400Time bool + if strings.HasPrefix(str, "24:00:00") { + is2400Time = true + if len(str) > 8 { + str = "00:00:00" + str[8:] + } else { + str = "00:00:00" + } + } + + t, err := time.Parse(f, str) + if err != nil { + return time.Time{}, errors.New("pq: " + err.Error()) + } + if is2400Time { + t = t.Add(24 * time.Hour) + } + // TODO(v2): it uses UTC, which it shouldn't. But I'm afraid changing it now + // will break people's code. + //if typ == oid.T_time { + // // Don't use UTC but time.FixedZone("", 0) + // t = t.In(globalLocationCache.getLocation(0)) + //} + return t, nil +} + +var ( + infinityTSEnabled = false + infinityTSNegative time.Time + infinityTSPositive time.Time +) + +// EnableInfinityTs controls the handling of Postgres' "-infinity" and +// "infinity" "timestamp"s. 
+// +// If EnableInfinityTs is not called, "-infinity" and "infinity" will return +// []byte("-infinity") and []byte("infinity") respectively, and potentially +// cause error "sql: Scan error on column index 0: unsupported driver -> Scan +// pair: []uint8 -> *time.Time", when scanning into a time.Time value. +// +// Once EnableInfinityTs has been called, all connections created using this +// driver will decode Postgres' "-infinity" and "infinity" for "timestamp", +// "timestamp with time zone" and "date" types to the predefined minimum and +// maximum times, respectively. When encoding time.Time values, any time which +// equals or precedes the predefined minimum time will be encoded to +// "-infinity". Any values at or past the maximum time will similarly be +// encoded to "infinity". +// +// If EnableInfinityTs is called with negative >= positive, it will panic. +// Calling EnableInfinityTs after a connection has been established results in +// undefined behavior. If EnableInfinityTs is called more than once, it will +// panic. +func EnableInfinityTs(negative time.Time, positive time.Time) { + if infinityTSEnabled { + panic("pq: infinity timestamp already enabled") + } + if !negative.Before(positive) { + panic("pq: infinity timestamp: negative value must be smaller (before) than positive") + } + infinityTSEnabled = true + infinityTSNegative = negative + infinityTSPositive = positive +} + +// Testing might want to toggle infinityTSEnabled +func disableInfinityTS() { + infinityTSEnabled = false +} + +// This is a time function specific to the Postgres default DateStyle setting +// ("ISO, MDY"), the only one we currently support. This accounts for the +// discrepancies between the parsing available with time.Parse and the Postgres +// date formatting quirks. 
+func parseTS(currentLocation *time.Location, str string) (any, error) { + switch str { + case "-infinity": + if infinityTSEnabled { + return infinityTSNegative, nil + } + return []byte(str), nil + case "infinity": + if infinityTSEnabled { + return infinityTSPositive, nil + } + return []byte(str), nil + } + t, err := ParseTimestamp(currentLocation, str) + if err != nil { + err = errors.New("pq: " + err.Error()) + } + return t, err +} + +// ParseTimestamp parses Postgres' text format. It returns a time.Time in +// currentLocation iff that time's offset agrees with the offset sent from the +// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the fixed +// offset offset provided by the Postgres server. +func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) { + return pqtime.Parse(currentLocation, str) +} + +// formatTS formats t into a format postgres understands. +func formatTS(t time.Time) []byte { + if infinityTSEnabled { + // t <= -infinity : ! (t > -infinity) + if !t.After(infinityTSNegative) { + return []byte("-infinity") + } + // t >= infinity : ! (!t < infinity) + if !t.Before(infinityTSPositive) { + return []byte("infinity") + } + } + return FormatTimestamp(t) +} + +// FormatTimestamp formats t into Postgres' text format for timestamps. +func FormatTimestamp(t time.Time) []byte { + return pqtime.Format(t) +} + +// Parse a bytea value received from the server. Both "hex" and the legacy +// "escape" format are supported. +func parseBytea(s []byte) (result []byte, err error) { + // Hex format. + if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { + s = s[2:] // trim off leading "\\x" + result = make([]byte, hex.DecodedLen(len(s))) + _, err := hex.Decode(result, s) + if err != nil { + return nil, err + } + return result, nil + } + + // Escape format. 
+ for len(s) > 0 { + if s[0] == '\\' { + // escaped '\\' + if len(s) >= 2 && s[1] == '\\' { + result = append(result, '\\') + s = s[2:] + continue + } + + // '\\' followed by an octal number + if len(s) < 4 { + return nil, fmt.Errorf("invalid bytea sequence %v", s) + } + r, err := strconv.ParseUint(string(s[1:4]), 8, 8) + if err != nil { + return nil, fmt.Errorf("could not parse bytea value: %w", err) + } + result = append(result, byte(r)) + s = s[4:] + } else { + // We hit an unescaped, raw byte. Try to read in as many as + // possible in one go. + i := bytes.IndexByte(s, '\\') + if i == -1 { + result = append(result, s...) + break + } + result = append(result, s[:i]...) + s = s[i:] + } + } + return result, nil +} + +func encodeBytea(v []byte) (result []byte) { + result = make([]byte, 2+hex.EncodedLen(len(v))) + result[0] = '\\' + result[1] = 'x' + hex.Encode(result[2:], v) + return result +} diff --git a/vendor/github.com/lib/pq/error.go b/vendor/github.com/lib/pq/error.go new file mode 100644 index 00000000..0851a66b --- /dev/null +++ b/vendor/github.com/lib/pq/error.go @@ -0,0 +1,324 @@ +package pq + +import ( + "database/sql/driver" + "fmt" + "io" + "net" + "runtime" + "strconv" + "strings" + "unicode/utf8" + + "github.com/lib/pq/pqerror" +) + +// Error returned by the PostgreSQL server. +// +// The [Error] method returns the error message and error code: +// +// pq: invalid input syntax for type json (22P02) +// +// The [ErrorWithDetail] method also includes the error Detail, Hint, and +// location context (if any): +// +// ERROR: invalid input syntax for type json (22P02) +// DETAIL: Token "asd" is invalid. +// CONTEXT: line 5, column 8: +// +// 3 | 'def', +// 4 | 123, +// 5 | 'foo', 'asd'::jsonb +// ^ +type Error struct { + // [Efatal], [Epanic], [Ewarning], [Enotice], [Edebug], [Einfo], or [Elog]. + // Always present. + Severity string + + // SQLSTATE code. Always present. + Code pqerror.Code + + // Primary human-readable error message. 
This should be accurate but terse + // (typically one line). Always present. + Message string + + // Optional secondary error message carrying more detail about the problem. + // Might run to multiple lines. + Detail string + + // Optional suggestion what to do about the problem. This is intended to + // differ from Detail in that it offers advice (potentially inappropriate) + // rather than hard facts. Might run to multiple lines. + Hint string + + // error position as an index into the original query string, as decimal + // ASCII integer. The first character has index 1, and positions are + // measured in characters not bytes. + Position string + + // This is defined the same as the Position field, but it is used when the + // cursor position refers to an internally generated command rather than the + // one submitted by the client. The InternalQuery field will always appear + // when this field appears. + InternalPosition string + + // Text of a failed internally-generated command. This could be, for + // example, an SQL query issued by a PL/pgSQL function. + InternalQuery string + + // An indication of the context in which the error occurred. Presently this + // includes a call stack traceback of active procedural language functions + // and internally-generated queries. The trace is one entry per line, most + // recent first. + Where string + + // If the error was associated with a specific database object, the name of + // the schema containing that object, if any. + Schema string + + // If the error was associated with a specific table, the name of the table. + // (Refer to the schema name field for the name of the table's schema.) + Table string + + // If the error was associated with a specific table column, the name of the + // column. (Refer to the schema and table name fields to identify the + // table.) + Column string + + // If the error was associated with a specific data type, the name of the + // data type. 
(Refer to the schema name field for the name of the data + // type's schema.) + DataTypeName string + + // If the error was associated with a specific constraint, the name of the + // constraint. Refer to fields listed above for the associated table or + // domain. (For this purpose, indexes are treated as constraints, even if + // they weren't created with constraint syntax.) + Constraint string + + // File name of the source-code location where the error was reported. + File string + + // Line number of the source-code location where the error was reported. + Line string + + // Name of the source-code routine reporting the error. + Routine string + + query string +} + +type ( + // ErrorCode is a five-character error code. + // + // Deprecated: use pqerror.Code + // + //go:fix inline + ErrorCode = pqerror.Code + + // ErrorClass is only the class part of an error code. + // + // Deprecated: use pqerror.Class + // + //go:fix inline + ErrorClass = pqerror.Class +) + +func parseError(r *readBuf, q string) *Error { + err := &Error{query: q} + for t := r.byte(); t != 0; t = r.byte() { + msg := r.string() + switch t { + case 'S': + err.Severity = msg + case 'C': + err.Code = pqerror.Code(msg) + case 'M': + err.Message = msg + case 'D': + err.Detail = msg + case 'H': + err.Hint = msg + case 'P': + err.Position = msg + case 'p': + err.InternalPosition = msg + case 'q': + err.InternalQuery = msg + case 'W': + err.Where = msg + case 's': + err.Schema = msg + case 't': + err.Table = msg + case 'c': + err.Column = msg + case 'd': + err.DataTypeName = msg + case 'n': + err.Constraint = msg + case 'F': + err.File = msg + case 'L': + err.Line = msg + case 'R': + err.Routine = msg + } + } + return err +} + +// Fatal returns true if the Error Severity is fatal. +func (e *Error) Fatal() bool { return e.Severity == pqerror.SeverityFatal } + +// SQLState returns the SQLState of the error. 
+func (e *Error) SQLState() string { return string(e.Code) } + +func (e *Error) Error() string { + msg := e.Message + if e.query != "" && e.Position != "" { + pos, err := strconv.Atoi(e.Position) + if err == nil { + lines := strings.Split(e.query, "\n") + line, col := posToLine(pos, lines) + if len(lines) == 1 { + msg += " at column " + strconv.Itoa(col) + } else { + msg += " at position " + strconv.Itoa(line) + ":" + strconv.Itoa(col) + } + } + } + + if e.Code != "" { + return "pq: " + msg + " (" + string(e.Code) + ")" + } + return "pq: " + msg +} + +// ErrorWithDetail returns the error message with detailed information and +// location context (if any). +// +// See the documentation on [Error]. +func (e *Error) ErrorWithDetail() string { + b := new(strings.Builder) + b.Grow(len(e.Message) + len(e.Detail) + len(e.Hint) + 30) + b.WriteString("ERROR: ") + b.WriteString(e.Message) + if e.Code != "" { + b.WriteString(" (") + b.WriteString(string(e.Code)) + b.WriteByte(')') + } + if e.Detail != "" { + b.WriteString("\nDETAIL: ") + b.WriteString(e.Detail) + } + if e.Hint != "" { + b.WriteString("\nHINT: ") + b.WriteString(e.Hint) + } + + if e.query != "" && e.Position != "" { + b.Grow(512) + pos, err := strconv.Atoi(e.Position) + if err != nil { + return b.String() + } + lines := strings.Split(e.query, "\n") + line, col := posToLine(pos, lines) + + fmt.Fprintf(b, "\nCONTEXT: line %d, column %d:\n\n", line, col) + if line > 2 { + fmt.Fprintf(b, "% 7d | %s\n", line-2, expandTab(lines[line-3])) + } + if line > 1 { + fmt.Fprintf(b, "% 7d | %s\n", line-1, expandTab(lines[line-2])) + } + /// Expand tabs, so that the ^ is at at the correct position, but leave + /// "column 10-13" intact. Adjusting this to the visual column would be + /// better, but we don't know the tabsize of the user in their editor, + /// which can be 8, 4, 2, or something else. We can't know. So leaving + /// it as the character index is probably the "most correct". 
+ expanded := expandTab(lines[line-1]) + diff := len(expanded) - len(lines[line-1]) + fmt.Fprintf(b, "% 7d | %s\n", line, expanded) + fmt.Fprintf(b, "% 10s%s%s\n", "", strings.Repeat(" ", col-1+diff), "^") + } + + return b.String() +} + +func posToLine(pos int, lines []string) (line, col int) { + read := 0 + for i := range lines { + line++ + ll := utf8.RuneCountInString(lines[i]) + 1 // +1 for the removed newline + if read+ll >= pos { + col = max(pos-read, 1) // Should be lower than 1, but just in case. + break + } + read += ll + } + return line, col +} + +func expandTab(s string) string { + var ( + b strings.Builder + l int + fill = func(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = ' ' + } + return string(b) + } + ) + b.Grow(len(s)) + for _, r := range s { + switch r { + case '\t': + tw := 8 - l%8 + b.WriteString(fill(tw)) + l += tw + default: + b.WriteRune(r) + l += 1 + } + } + return b.String() +} + +func (cn *conn) handleError(reported error, query ...string) error { + switch err := reported.(type) { + case nil: + return nil + case runtime.Error, *net.OpError: + cn.err.set(driver.ErrBadConn) + case *safeRetryError: + cn.err.set(driver.ErrBadConn) + reported = driver.ErrBadConn + case *Error: + if len(query) > 0 && query[0] != "" { + err.query = query[0] + reported = err + } + if err.Fatal() { + reported = driver.ErrBadConn + } + case error: + if err == io.EOF || err.Error() == "remote error: handshake failure" { + reported = driver.ErrBadConn + } + default: + cn.err.set(driver.ErrBadConn) + reported = fmt.Errorf("pq: unknown error %T: %[1]s", err) + } + + // Any time we return ErrBadConn, we need to remember it since *Tx doesn't + // mark the connection bad in database/sql. 
+ if reported == driver.ErrBadConn { + cn.err.set(driver.ErrBadConn) + } + return reported +} diff --git a/vendor/github.com/lib/pq/internal/pgpass/pgpass.go b/vendor/github.com/lib/pq/internal/pgpass/pgpass.go new file mode 100644 index 00000000..002631da --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pgpass/pgpass.go @@ -0,0 +1,71 @@ +package pgpass + +import ( + "bufio" + "os" + "path/filepath" + "strings" + + "github.com/lib/pq/internal/pqutil" +) + +func PasswordFromPgpass(passfile, user, password, host, port, dbname string, passwordSet bool) string { + // Do not process .pgpass if a password was supplied. + if passwordSet { + return password + } + + filename := pqutil.Pgpass(passfile) + if filename == "" { + return "" + } + + fp, err := os.Open(filename) + if err != nil { + return "" + } + defer fp.Close() + + scan := bufio.NewScanner(fp) + for scan.Scan() { + line := scan.Text() + if len(line) == 0 || line[0] == '#' { + continue + } + split := splitFields(line) + if len(split) != 5 { + continue + } + + socket := host == "" || filepath.IsAbs(host) || strings.HasPrefix(host, "@") + if (split[0] == "*" || split[0] == host || (split[0] == "localhost" && socket)) && + (split[1] == "*" || split[1] == port) && + (split[2] == "*" || split[2] == dbname) && + (split[3] == "*" || split[3] == user) { + return split[4] + } + } + + return "" +} + +func splitFields(s string) []string { + var ( + fs = make([]string, 0, 5) + f = make([]rune, 0, len(s)) + esc bool + ) + for _, c := range s { + switch { + case esc: + f, esc = append(f, c), false + case c == '\\': + esc = true + case c == ':': + fs, f = append(fs, string(f)), f[:0] + default: + f = append(f, c) + } + } + return append(fs, string(f)) +} diff --git a/vendor/github.com/lib/pq/internal/pgservice/pgservice.go b/vendor/github.com/lib/pq/internal/pgservice/pgservice.go new file mode 100644 index 00000000..9842648c --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pgservice/pgservice.go @@ -0,0 +1,70 @@ 
+package pgservice + +import ( + "bufio" + "fmt" + "os" + "strings" + + "github.com/lib/pq/internal/pqutil" +) + +func FindService(path string, service string) (map[string]string, error) { + fp, err := os.Open(path) + if err != nil { + if pqutil.ErrNotExists(err) { + // libpq just returns "definition of service not found" if the + // default file doesn't exist, but IMO that's confusing. + return nil, fmt.Errorf("service file %q not found", path) + } + return nil, err + } + defer fp.Close() + + var ( + scan = bufio.NewScanner(fp) + i int + ) + for scan.Scan() { + i++ + line := strings.TrimSpace(scan.Text()) + if line == "" || line[0] == '#' { + continue + } + + // [service] header that we want. + if line[0] == '[' && line[len(line)-1] == ']' && strings.TrimSpace(line[1:len(line)-1]) == service { + opts := make(map[string]string) + for scan.Scan() { + i++ + line := strings.TrimSpace(scan.Text()) + if line == "" || line[0] == '#' { + continue + } + // Next header: our work here is done. + if line[0] == '[' && line[len(line)-1] == ']' { + return opts, nil + } + + k, v, ok := strings.Cut(line, "=") + if !ok { + return nil, fmt.Errorf("line %d: missing '=' in %q", i, line) + } + k, v = strings.TrimSpace(k), strings.TrimSpace(v) + if k == "" { + return nil, fmt.Errorf("line %d: no value before '=' in %q", i, line) + } + opts[k] = v + } + if scan.Err() != nil { + return nil, scan.Err() + } + return opts, nil + } + } + if scan.Err() != nil { + return nil, scan.Err() + } + + return nil, fmt.Errorf("definition of service %q not found", service) +} diff --git a/vendor/github.com/lib/pq/internal/pqsql/copy.go b/vendor/github.com/lib/pq/internal/pqsql/copy.go new file mode 100644 index 00000000..ccb688f6 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqsql/copy.go @@ -0,0 +1,37 @@ +package pqsql + +// StartsWithCopy reports if the SQL strings start with "copy", ignoring +// whitespace, comments, and casing. 
+func StartsWithCopy(query string) bool { + if len(query) < 4 { + return false + } + var linecmt, blockcmt bool + for i := 0; i < len(query); i++ { + c := query[i] + if linecmt { + linecmt = c != '\n' + continue + } + if blockcmt { + blockcmt = !(c == '/' && query[i-1] == '*') + continue + } + if c == '-' && len(query) > i+1 && query[i+1] == '-' { + linecmt = true + continue + } + if c == '/' && len(query) > i+1 && query[i+1] == '*' { + blockcmt = true + continue + } + if c == ' ' || c == '\t' || c == '\r' || c == '\n' { + continue + } + + // First non-comment and non-whitespace. + return len(query) > i+3 && c|0x20 == 'c' && query[i+1]|0x20 == 'o' && + query[i+2]|0x20 == 'p' && query[i+3]|0x20 == 'y' + } + return false +} diff --git a/vendor/github.com/lib/pq/internal/pqtime/loc.go b/vendor/github.com/lib/pq/internal/pqtime/loc.go new file mode 100644 index 00000000..d23dd5b0 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqtime/loc.go @@ -0,0 +1,37 @@ +package pqtime + +import ( + "sync" + "time" +) + +// The location cache caches the time zones typically used by the client. +type locationCache struct { + cache map[int]*time.Location + lock sync.Mutex +} + +// All connections share the same list of timezones. Benchmarking shows that +// about 5% speed could be gained by putting the cache in the connection and +// losing the mutex, at the cost of a small amount of memory and a somewhat +// significant increase in code complexity. +var globalLocationCache = &locationCache{cache: make(map[int]*time.Location)} + +func Reset() { + globalLocationCache = &locationCache{cache: make(map[int]*time.Location)} +} + +// Returns the cached timezone for the specified offset, creating and caching +// it if necessary. +func (c *locationCache) getLocation(offset int) *time.Location { + c.lock.Lock() + defer c.lock.Unlock() + l, ok := c.cache[offset] + if !ok { + // TODO(v2): for offset=0 it should use some descriptive text like + // "without time zone". 
+ l = time.FixedZone("", offset) + c.cache[offset] = l + } + return l +} diff --git a/vendor/github.com/lib/pq/internal/pqtime/pqtime.go b/vendor/github.com/lib/pq/internal/pqtime/pqtime.go new file mode 100644 index 00000000..28008e86 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqtime/pqtime.go @@ -0,0 +1,190 @@ +package pqtime + +import ( + "errors" + "fmt" + "math" + "strconv" + "strings" + "time" +) + +var errInvalidTimestamp = errors.New("invalid timestamp") + +type timestampParser struct { + err error +} + +func (p *timestampParser) expect(str string, char byte, pos int) { + if p.err != nil { + return + } + if pos+1 > len(str) { + p.err = errInvalidTimestamp + return + } + if c := str[pos]; c != char && p.err == nil { + p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c) + } +} + +func (p *timestampParser) mustAtoi(str string, begin int, end int) int { + if p.err != nil { + return 0 + } + if begin < 0 || end < 0 || begin > end || end > len(str) { + p.err = errInvalidTimestamp + return 0 + } + result, err := strconv.Atoi(str[begin:end]) + if err != nil { + if p.err == nil { + p.err = fmt.Errorf("expected number; got '%v'", str) + } + return 0 + } + return result +} + +func Parse(currentLocation *time.Location, str string) (time.Time, error) { + p := timestampParser{} + + monSep := strings.IndexRune(str, '-') + // this is Gregorian year, not ISO Year + // In Gregorian system, the year 1 BC is followed by AD 1 + year := p.mustAtoi(str, 0, monSep) + daySep := monSep + 3 + month := p.mustAtoi(str, monSep+1, daySep) + p.expect(str, '-', daySep) + timeSep := daySep + 3 + day := p.mustAtoi(str, daySep+1, timeSep) + + minLen := monSep + len("01-01") + 1 + + isBC := strings.HasSuffix(str, " BC") + if isBC { + minLen += 3 + } + + var hour, minute, second int + if len(str) > minLen { + p.expect(str, ' ', timeSep) + minSep := timeSep + 3 + p.expect(str, ':', minSep) + hour = p.mustAtoi(str, timeSep+1, minSep) + secSep := minSep + 3 + 
p.expect(str, ':', secSep) + minute = p.mustAtoi(str, minSep+1, secSep) + secEnd := secSep + 3 + second = p.mustAtoi(str, secSep+1, secEnd) + } + remainderIdx := monSep + len("01-01 00:00:00") + 1 + // Three optional (but ordered) sections follow: the + // fractional seconds, the time zone offset, and the BC + // designation. We set them up here and adjust the other + // offsets if the preceding sections exist. + + nanoSec := 0 + tzOff := 0 + + if remainderIdx < len(str) && str[remainderIdx] == '.' { + fracStart := remainderIdx + 1 + fracOff := strings.IndexAny(str[fracStart:], "-+Z ") + if fracOff < 0 { + fracOff = len(str) - fracStart + } + fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff) + nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff)))) + + remainderIdx += fracOff + 1 + } + if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') { + // time zone separator is always '-' or '+' or 'Z' (UTC is +00) + var tzSign int + switch c := str[tzStart]; c { + case '-': + tzSign = -1 + case '+': + tzSign = +1 + default: + return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c) + } + tzHours := p.mustAtoi(str, tzStart+1, tzStart+3) + remainderIdx += 3 + var tzMin, tzSec int + if remainderIdx < len(str) && str[remainderIdx] == ':' { + tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) + remainderIdx += 3 + } + if remainderIdx < len(str) && str[remainderIdx] == ':' { + tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) + remainderIdx += 3 + } + tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) + } else if tzStart < len(str) && str[tzStart] == 'Z' { + // time zone Z separator indicates UTC is +00 + remainderIdx += 1 + } + + var isoYear int + + if isBC { + isoYear = 1 - year + remainderIdx += 3 + } else { + isoYear = year + } + if remainderIdx < len(str) { + return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:]) + } + t := 
time.Date(isoYear, time.Month(month), day, + hour, minute, second, nanoSec, + globalLocationCache.getLocation(tzOff)) + + if currentLocation != nil { + // Set the location of the returned Time based on the session's + // TimeZone value, but only if the local time zone database agrees with + // the remote database on the offset. + lt := t.In(currentLocation) + _, newOff := lt.Zone() + if newOff == tzOff { + t = lt + } + } + + return t, p.err +} + +// Format into Postgres' text format for timestamps. +func Format(t time.Time) []byte { + // Need to send dates before 0001 A.D. with " BC" suffix, instead of the + // minus sign preferred by Go. + // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on + bc := false + if t.Year() <= 0 { + // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" + t = t.AddDate((-t.Year())*2+1, 0, 0) + bc = true + } + b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) + + _, offset := t.Zone() + offset %= 60 + if offset != 0 { + // RFC3339Nano already printed the minus sign + if offset < 0 { + offset = -offset + } + + b = append(b, ':') + if offset < 10 { + b = append(b, '0') + } + b = strconv.AppendInt(b, int64(offset), 10) + } + + if bc { + b = append(b, " BC"...) + } + return b +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/path.go b/vendor/github.com/lib/pq/internal/pqutil/path.go new file mode 100644 index 00000000..a28fc95f --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/path.go @@ -0,0 +1,86 @@ +package pqutil + +import ( + "errors" + "fmt" + "io" + "os" + "os/user" + "path/filepath" + "runtime" + "syscall" +) + +// Home gets the PostgreSQL configuration dir in the user's home directory: +// %APPDATA%/postgresql on Windows, and $HOME/.postgresql/postgresql.crt +// everywhere else. +// +// Returns an empty string if no home directory was found. +// +// Matches pqGetHomeDirectory() from PostgreSQL.
+// https://github.com/postgres/postgres/blob/2b117bb/src/interfaces/libpq/fe-connect.c#L8214 +func Home() string { + if runtime.GOOS == "windows" { + // pq uses SHGetFolderPath(), which is deprecated but x/sys/windows has + // KnownFolderPath(). We don't really want to pull that in though, so + // use APPDATA env. This is also what PostgreSQL uses in some other + // codepaths (get_home_path() for example). + ad := os.Getenv("APPDATA") + if ad == "" { + return "" + } + return filepath.Join(ad, "postgresql") + } + + home, _ := os.UserHomeDir() + if home == "" { + u, err := user.Current() + if err != nil { + return "" + } + home = u.HomeDir + } + return filepath.Join(home, ".postgresql") +} + +// ErrNotExists reports if err is a "path doesn't exist" type error. +// +// fs.ErrNotExist is not enough, as "/dev/null/somefile" will return ENOTDIR +// instead of ENOENT. +func ErrNotExists(err error) bool { + perr := new(os.PathError) + if errors.As(err, &perr) && (perr.Err == syscall.ENOENT || perr.Err == syscall.ENOTDIR) { + return true + } + return false +} + +var WarnFD io.Writer = os.Stderr + +// Pgpass gets the filepath to the pgpass file to use, returning "" if a pgpass +// file shouldn't be used. +func Pgpass(passfile string) string { + // Get passfile from the options. + if passfile == "" { + home := Home() + if home == "" { + return "" + } + passfile = filepath.Join(home, ".pgpass") + } + + // On Win32, the directory is protected, so we don't have to check the file. 
+ if runtime.GOOS != "windows" { + fi, err := os.Stat(passfile) + if err != nil { + return "" + } + if fi.Mode().Perm()&(0x77) != 0 { + fmt.Fprintf(WarnFD, + "WARNING: password file %q has group or world access; permissions should be u=rw (0600) or less\n", + passfile) + return "" + } + } + return passfile +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/perm.go b/vendor/github.com/lib/pq/internal/pqutil/perm.go new file mode 100644 index 00000000..05fb9a6a --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/perm.go @@ -0,0 +1,64 @@ +//go:build !windows && !plan9 + +package pqutil + +import ( + "errors" + "os" + "syscall" +) + +var ( + ErrSSLKeyUnknownOwnership = errors.New("pq: could not get owner information for private key, may not be properly protected") + ErrSSLKeyHasWorldPermissions = errors.New("pq: private key has world access; permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less") +) + +// SSLKeyPermissions checks the permissions on user-supplied SSL key files, +// which should have very little access. libpq does not check key file +// permissions on Windows. +// +// If the file is owned by the same user the process is running as, the file +// should only have 0600. If the file is owned by root, and the group matches +// the group that the process is running in, the permissions cannot be more than +// 0640. The file should never have world permissions. +// +// Returns an error when the permission check fails. +func SSLKeyPermissions(sslkey string) error { + fi, err := os.Stat(sslkey) + if err != nil { + return err + } + + return CheckPermissions(fi) +} + +func CheckPermissions(fi os.FileInfo) error { + // The maximum permissions that a private key file owned by a regular user + // is allowed to have. This translates to u=rw. Regardless of if we're + // running as root or not, 0600 is acceptable, so we return if no bits + // beyond the regular user permission mask are set. 
+ if fi.Mode().Perm()&^os.FileMode(0o600) == 0 { + return nil + } + + // We need to pull the Unix file information to get the file's owner. + // If we can't access it, there's some sort of operating system level error + // and we should fail rather than attempting to use faulty information. + sys, ok := fi.Sys().(*syscall.Stat_t) + if !ok { + return ErrSSLKeyUnknownOwnership + } + + // if the file is owned by root, we allow 0640 (u=rw,g=r) to match what + // Postgres does. + if sys.Uid == 0 { + // The maximum permissions that a private key file owned by root is + // allowed to have. This translates to u=rw,g=r. + if fi.Mode().Perm()&^os.FileMode(0o640) != 0 { + return ErrSSLKeyHasWorldPermissions + } + return nil + } + + return ErrSSLKeyHasWorldPermissions +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/perm_unsupported.go b/vendor/github.com/lib/pq/internal/pqutil/perm_unsupported.go new file mode 100644 index 00000000..3ce75957 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/perm_unsupported.go @@ -0,0 +1,12 @@ +//go:build windows || plan9 + +package pqutil + +import "errors" + +var ( + ErrSSLKeyUnknownOwnership = errors.New("unused") + ErrSSLKeyHasWorldPermissions = errors.New("unused") +) + +func SSLKeyPermissions(sslkey string) error { return nil } diff --git a/vendor/github.com/lib/pq/internal/pqutil/pqutil.go b/vendor/github.com/lib/pq/internal/pqutil/pqutil.go new file mode 100644 index 00000000..ca869e9c --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/pqutil.go @@ -0,0 +1,32 @@ +package pqutil + +import ( + "strconv" + "strings" +) + +// ParseBool is like strconv.ParseBool, but also accepts "yes"/"no" and +// "on"/"off". 
+func ParseBool(str string) (bool, error) { + switch str { + case "1", "t", "T", "true", "TRUE", "True", "yes", "on": + return true, nil + case "0", "f", "F", "false", "FALSE", "False", "no", "off": + return false, nil + } + return false, &strconv.NumError{Func: "ParseBool", Num: str, Err: strconv.ErrSyntax} +} + +func Join[S ~[]E, E ~string](s S) string { + var b strings.Builder + for i := range s { + if i > 0 { + b.WriteString(", ") + } + if i == len(s)-1 { + b.WriteString("or ") + } + b.WriteString(string(s[i])) + } + return b.String() +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/user_other.go b/vendor/github.com/lib/pq/internal/pqutil/user_other.go new file mode 100644 index 00000000..09e4f8df --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/user_other.go @@ -0,0 +1,9 @@ +//go:build js || android || hurd || zos || wasip1 || appengine + +package pqutil + +import "errors" + +func User() (string, error) { + return "", errors.New("pqutil.User: not supported on current platform") +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/user_posix.go b/vendor/github.com/lib/pq/internal/pqutil/user_posix.go new file mode 100644 index 00000000..bd0ece6d --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/user_posix.go @@ -0,0 +1,25 @@ +//go:build !windows && !js && !android && !hurd && !zos && !wasip1 && !appengine + +package pqutil + +import ( + "os" + "os/user" + "runtime" +) + +func User() (string, error) { + env := "USER" + if runtime.GOOS == "plan9" { + env = "user" + } + if n := os.Getenv(env); n != "" { + return n, nil + } + + u, err := user.Current() + if err != nil { + return "", err + } + return u.Username, nil +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/user_windows.go b/vendor/github.com/lib/pq/internal/pqutil/user_windows.go new file mode 100644 index 00000000..960cb805 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/user_windows.go @@ -0,0 +1,28 @@ +//go:build windows && !appengine + +package 
pqutil + +import ( + "path/filepath" + "syscall" +) + +func User() (string, error) { + // Perform Windows user name lookup identically to libpq. + // + // The PostgreSQL code makes use of the legacy Win32 function GetUserName, + // and that function has not been imported into stock Go. GetUserNameEx is + // available though, the difference being that a wider range of names are + // available. To get the output to be the same as GetUserName, only the + // base (or last) component of the result is returned. + var ( + name = make([]uint16, 128) + pwnameSz = uint32(len(name)) - 1 + ) + err := syscall.GetUserNameEx(syscall.NameSamCompatible, &name[0], &pwnameSz) + if err != nil { + return "", err + } + s := syscall.UTF16ToString(name) + return filepath.Base(s), nil +} diff --git a/vendor/github.com/lib/pq/internal/proto/proto.go b/vendor/github.com/lib/pq/internal/proto/proto.go new file mode 100644 index 00000000..e8b4bc59 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/proto/proto.go @@ -0,0 +1,186 @@ +// From src/include/libpq/protocol.h and src/include/libpq/pqcomm.h – PostgreSQL 18.1 + +package proto + +import ( + "fmt" + "strconv" +) + +// Constants from pqcomm.h +const ( + ProtocolVersion30 = (3 << 16) | 0 //lint:ignore SA4016 x + ProtocolVersion32 = (3 << 16) | 2 // PostgreSQL ≥18. + CancelRequestCode = (1234 << 16) | 5678 + NegotiateSSLCode = (1234 << 16) | 5679 + NegotiateGSSCode = (1234 << 16) | 5680 +) + +// Constants from fe-connect.c +const ( + MaxErrlen = 30_000 // https://github.com/postgres/postgres/blob/c6a10a89f/src/interfaces/libpq/fe-connect.c#L4067 +) + +// RequestCode is a request codes sent by the frontend. +type RequestCode byte + +// These are the request codes sent by the frontend. 
+const ( + Bind = RequestCode('B') + Close = RequestCode('C') + Describe = RequestCode('D') + Execute = RequestCode('E') + FunctionCall = RequestCode('F') + Flush = RequestCode('H') + Parse = RequestCode('P') + Query = RequestCode('Q') + Sync = RequestCode('S') + Terminate = RequestCode('X') + CopyFail = RequestCode('f') + GSSResponse = RequestCode('p') + PasswordMessage = RequestCode('p') + SASLInitialResponse = RequestCode('p') + SASLResponse = RequestCode('p') + CopyDoneRequest = RequestCode('c') + CopyDataRequest = RequestCode('d') +) + +func (r RequestCode) String() string { + s, ok := map[RequestCode]string{ + Bind: "Bind", + Close: "Close", + Describe: "Describe", + Execute: "Execute", + FunctionCall: "FunctionCall", + Flush: "Flush", + Parse: "Parse", + Query: "Query", + Sync: "Sync", + Terminate: "Terminate", + CopyFail: "CopyFail", + // These are all the same :-/ + //GSSResponse: "GSSResponse", + PasswordMessage: "PasswordMessage", + //SASLInitialResponse: "SASLInitialResponse", + //SASLResponse: "SASLResponse", + CopyDoneRequest: "CopyDone", + CopyDataRequest: "CopyData", + }[r] + if !ok { + s = "" + } + c := string(r) + if r <= 0x1f || r == 0x7f { + c = fmt.Sprintf("0x%x", string(r)) + } + return "(" + c + ") " + s +} + +// ResponseCode is a response codes sent by the backend. +type ResponseCode byte + +// These are the response codes sent by the backend. 
+const ( + ParseComplete = ResponseCode('1') + BindComplete = ResponseCode('2') + CloseComplete = ResponseCode('3') + NotificationResponse = ResponseCode('A') + CommandComplete = ResponseCode('C') + DataRow = ResponseCode('D') + ErrorResponse = ResponseCode('E') + CopyInResponse = ResponseCode('G') + CopyOutResponse = ResponseCode('H') + EmptyQueryResponse = ResponseCode('I') + BackendKeyData = ResponseCode('K') + NoticeResponse = ResponseCode('N') + AuthenticationRequest = ResponseCode('R') + ParameterStatus = ResponseCode('S') + RowDescription = ResponseCode('T') + FunctionCallResponse = ResponseCode('V') + CopyBothResponse = ResponseCode('W') + ReadyForQuery = ResponseCode('Z') + NoData = ResponseCode('n') + PortalSuspended = ResponseCode('s') + ParameterDescription = ResponseCode('t') + NegotiateProtocolVersion = ResponseCode('v') + CopyDoneResponse = ResponseCode('c') + CopyDataResponse = ResponseCode('d') +) + +func (r ResponseCode) String() string { + s, ok := map[ResponseCode]string{ + ParseComplete: "ParseComplete", + BindComplete: "BindComplete", + CloseComplete: "CloseComplete", + NotificationResponse: "NotificationResponse", + CommandComplete: "CommandComplete", + DataRow: "DataRow", + ErrorResponse: "ErrorResponse", + CopyInResponse: "CopyInResponse", + CopyOutResponse: "CopyOutResponse", + EmptyQueryResponse: "EmptyQueryResponse", + BackendKeyData: "BackendKeyData", + NoticeResponse: "NoticeResponse", + AuthenticationRequest: "AuthRequest", + ParameterStatus: "ParamStatus", + RowDescription: "RowDescription", + FunctionCallResponse: "FunctionCallResponse", + CopyBothResponse: "CopyBothResponse", + ReadyForQuery: "ReadyForQuery", + NoData: "NoData", + PortalSuspended: "PortalSuspended", + ParameterDescription: "ParamDescription", + NegotiateProtocolVersion: "NegotiateProtocolVersion", + CopyDoneResponse: "CopyDone", + CopyDataResponse: "CopyData", + }[r] + if !ok { + s = "" + } + c := string(r) + if r <= 0x1f || r == 0x7f { + c = fmt.Sprintf("0x%x", 
string(r)) + } + return "(" + c + ") " + s +} + +// AuthCode are authentication request codes sent by the backend. +type AuthCode int32 + +// These are the authentication request codes sent by the backend. +const ( + AuthReqOk = AuthCode(0) // User is authenticated + AuthReqKrb4 = AuthCode(1) // Kerberos V4. Not supported any more. + AuthReqKrb5 = AuthCode(2) // Kerberos V5. Not supported any more. + AuthReqPassword = AuthCode(3) // Password + AuthReqCrypt = AuthCode(4) // crypt password. Not supported any more. + AuthReqMD5 = AuthCode(5) // md5 password + _ = AuthCode(6) // 6 is available. It was used for SCM creds, not supported any more. + AuthReqGSS = AuthCode(7) // GSSAPI without wrap() + AuthReqGSSCont = AuthCode(8) // Continue GSS exchanges + AuthReqSSPI = AuthCode(9) // SSPI negotiate without wrap() + AuthReqSASL = AuthCode(10) // Begin SASL authentication + AuthReqSASLCont = AuthCode(11) // Continue SASL authentication + AuthReqSASLFin = AuthCode(12) // Final SASL message +) + +func (a AuthCode) String() string { + s, ok := map[AuthCode]string{ + AuthReqOk: "ok", + AuthReqKrb4: "krb4", + AuthReqKrb5: "krb5", + AuthReqPassword: "password", + AuthReqCrypt: "crypt", + AuthReqMD5: "md5", + AuthReqGSS: "GSS", + AuthReqGSSCont: "GSSCont", + AuthReqSSPI: "SSPI", + AuthReqSASL: "SASL", + AuthReqSASLCont: "SASLCont", + AuthReqSASLFin: "SASLFin", + }[a] + if !ok { + s = "" + } + return s + " (" + strconv.Itoa(int(a)) + ")" +} diff --git a/vendor/github.com/lib/pq/internal/proto/sz_32.go b/vendor/github.com/lib/pq/internal/proto/sz_32.go new file mode 100644 index 00000000..68065591 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/proto/sz_32.go @@ -0,0 +1,7 @@ +//go:build 386 || arm || mips || mipsle + +package proto + +import "math" + +const MaxUint32 = math.MaxInt diff --git a/vendor/github.com/lib/pq/internal/proto/sz_64.go b/vendor/github.com/lib/pq/internal/proto/sz_64.go new file mode 100644 index 00000000..2b8ad897 --- /dev/null +++
b/vendor/github.com/lib/pq/internal/proto/sz_64.go @@ -0,0 +1,7 @@ +//go:build !386 && !arm && !mips && !mipsle + +package proto + +import "math" + +const MaxUint32 = math.MaxUint32 diff --git a/vendor/github.com/lib/pq/krb.go b/vendor/github.com/lib/pq/krb.go new file mode 100644 index 00000000..408ec01f --- /dev/null +++ b/vendor/github.com/lib/pq/krb.go @@ -0,0 +1,27 @@ +package pq + +// NewGSSFunc creates a GSS authentication provider, for use with +// RegisterGSSProvider. +type NewGSSFunc func() (GSS, error) + +var newGss NewGSSFunc + +// RegisterGSSProvider registers a GSS authentication provider. For example, if +// you need to use Kerberos to authenticate with your server, add this to your +// main package: +// +// import "github.com/lib/pq/auth/kerberos" +// +// func init() { +// pq.RegisterGSSProvider(func() (pq.GSS, error) { return kerberos.NewGSS() }) +// } +func RegisterGSSProvider(newGssArg NewGSSFunc) { + newGss = newGssArg +} + +// GSS provides GSSAPI authentication (e.g., Kerberos). +type GSS interface { + GetInitToken(host string, service string) ([]byte, error) + GetInitTokenFromSpn(spn string) ([]byte, error) + Continue(inToken []byte) (done bool, outToken []byte, err error) +} diff --git a/vendor/github.com/lib/pq/notice.go b/vendor/github.com/lib/pq/notice.go new file mode 100644 index 00000000..7b9ff392 --- /dev/null +++ b/vendor/github.com/lib/pq/notice.go @@ -0,0 +1,69 @@ +package pq + +import ( + "context" + "database/sql/driver" +) + +// NoticeHandler returns the notice handler on the given connection, if any. A +// runtime panic occurs if c is not a pq connection. This is rarely used +// directly, use [ConnectorNoticeHandler] and [ConnectorWithNoticeHandler] instead. +func NoticeHandler(c driver.Conn) func(*Error) { + return c.(*conn).noticeHandler +} + +// SetNoticeHandler sets the given notice handler on the given connection. A +// runtime panic occurs if c is not a pq connection. A nil handler may be used +// to unset it. 
This is rarely used directly, use ConnectorNoticeHandler and +// [ConnectorWithNoticeHandler] instead. +// +// Note: Notice handlers are executed synchronously by pq meaning commands +// won't continue to be processed until the handler returns. +func SetNoticeHandler(c driver.Conn, handler func(*Error)) { + c.(*conn).noticeHandler = handler +} + +// NoticeHandlerConnector wraps a regular connector and sets a notice handler +// on it. +type NoticeHandlerConnector struct { + driver.Connector + noticeHandler func(*Error) +} + +// Connect calls the underlying connector's connect method and then sets the +// notice handler. +func (n *NoticeHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) { + c, err := n.Connector.Connect(ctx) + if err == nil { + SetNoticeHandler(c, n.noticeHandler) + } + return c, err +} + +// ConnectorNoticeHandler returns the currently set notice handler, if any. If +// the given connector is not a result of [ConnectorWithNoticeHandler], nil is +// returned. +func ConnectorNoticeHandler(c driver.Connector) func(*Error) { + if c, ok := c.(*NoticeHandlerConnector); ok { + return c.noticeHandler + } + return nil +} + +// ConnectorWithNoticeHandler creates or sets the given handler for the given +// connector. If the given connector is a result of calling this function +// previously, it is simply set on the given connector and returned. Otherwise, +// this returns a new connector wrapping the given one and setting the notice +// handler. A nil notice handler may be used to unset it. +// +// The returned connector is intended to be used with database/sql.OpenDB. +// +// Note: Notice handlers are executed synchronously by pq meaning commands +// won't continue to be processed until the handler returns. 
+func ConnectorWithNoticeHandler(c driver.Connector, handler func(*Error)) *NoticeHandlerConnector { + if c, ok := c.(*NoticeHandlerConnector); ok { + c.noticeHandler = handler + return c + } + return &NoticeHandlerConnector{Connector: c, noticeHandler: handler} +} diff --git a/vendor/github.com/lib/pq/notify.go b/vendor/github.com/lib/pq/notify.go new file mode 100644 index 00000000..4f4c4227 --- /dev/null +++ b/vendor/github.com/lib/pq/notify.go @@ -0,0 +1,834 @@ +package pq + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/lib/pq/internal/proto" +) + +// Notification represents a single notification from the database. +type Notification struct { + BePid int // Process ID (PID) of the notifying postgres backend. + Channel string // Name of the channel the notification was sent on. + Extra string // Payload, or the empty string if unspecified. +} + +func recvNotification(r *readBuf) *Notification { + bePid := r.int32() + channel := r.string() + extra := r.string() + return &Notification{bePid, channel, extra} +} + +// SetNotificationHandler sets the given notification handler on the given +// connection. A runtime panic occurs if c is not a pq connection. A nil handler +// may be used to unset it. +// +// Note: Notification handlers are executed synchronously by pq meaning commands +// won't continue to be processed until the handler returns. +func SetNotificationHandler(c driver.Conn, handler func(*Notification)) { + c.(*conn).notificationHandler = handler +} + +// NotificationHandlerConnector wraps a regular connector and sets a +// notification handler on it. +type NotificationHandlerConnector struct { + driver.Connector + notificationHandler func(*Notification) +} + +// Connect calls the underlying connector's connect method and then sets the +// notification handler. 
+func (n *NotificationHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) { + c, err := n.Connector.Connect(ctx) + if err == nil { + SetNotificationHandler(c, n.notificationHandler) + } + return c, err +} + +// ConnectorNotificationHandler returns the currently set notification handler, +// if any. If the given connector is not a result of +// [ConnectorWithNotificationHandler], nil is returned. +func ConnectorNotificationHandler(c driver.Connector) func(*Notification) { + if c, ok := c.(*NotificationHandlerConnector); ok { + return c.notificationHandler + } + return nil +} + +// ConnectorWithNotificationHandler creates or sets the given handler for the +// given connector. If the given connector is a result of calling this function +// previously, it is simply set on the given connector and returned. Otherwise, +// this returns a new connector wrapping the given one and setting the +// notification handler. A nil notification handler may be used to unset it. +// +// The returned connector is intended to be used with database/sql.OpenDB. +// +// Note: Notification handlers are executed synchronously by pq meaning commands +// won't continue to be processed until the handler returns. +func ConnectorWithNotificationHandler(c driver.Connector, handler func(*Notification)) *NotificationHandlerConnector { + if c, ok := c.(*NotificationHandlerConnector); ok { + c.notificationHandler = handler + return c + } + return &NotificationHandlerConnector{Connector: c, notificationHandler: handler} +} + +const ( + connStateIdle int32 = iota + connStateExpectResponse + connStateExpectReadyForQuery +) + +type message struct { + typ proto.ResponseCode + err error +} + +var errListenerConnClosed = errors.New("pq: ListenerConn has been closed") + +// ListenerConn is a low-level interface for waiting for notifications. You +// should use [Listener] instead. 
+type ListenerConn struct { + connectionLock sync.Mutex // guards cn and err + senderLock sync.Mutex // the sending goroutine will be holding this lock + cn *conn + err error + connState int32 + notificationChan chan<- *Notification + replyChan chan message +} + +// NewListenerConn creates a new ListenerConn. Use NewListener instead. +func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) { + return newDialListenerConn(defaultDialer{}, name, notificationChan) +} + +func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*ListenerConn, error) { + cn, err := DialOpen(d, name) + if err != nil { + return nil, err + } + + l := &ListenerConn{ + cn: cn.(*conn), + notificationChan: c, + connState: connStateIdle, + replyChan: make(chan message, 2), + } + + go l.listenerConnMain() + return l, nil +} + +// We can only allow one goroutine at a time to be running a query on the +// connection for various reasons, so the goroutine sending on the connection +// must be holding senderLock. +// +// Returns an error if an unrecoverable error has occurred and the ListenerConn +// should be abandoned. +func (l *ListenerConn) acquireSenderLock() error { + // we must acquire senderLock first to avoid deadlocks; see ExecSimpleQuery + l.senderLock.Lock() + + l.connectionLock.Lock() + err := l.err + l.connectionLock.Unlock() + if err != nil { + l.senderLock.Unlock() + return err + } + return nil +} + +func (l *ListenerConn) releaseSenderLock() { + l.senderLock.Unlock() +} + +// setState advances the protocol state to newState. Returns false if moving +// to that state from the current state is not allowed. 
func (l *ListenerConn) setState(newState int32) bool {
	var expectedState int32

	// Each state has exactly one legal predecessor; anything else indicates
	// the protocol stream and our bookkeeping have diverged.
	switch newState {
	case connStateIdle:
		expectedState = connStateExpectReadyForQuery
	case connStateExpectResponse:
		expectedState = connStateIdle
	case connStateExpectReadyForQuery:
		expectedState = connStateExpectResponse
	default:
		panic(fmt.Sprintf("unexpected listenerConnState %d", newState))
	}

	return atomic.CompareAndSwapInt32(&l.connState, expectedState, newState)
}

// Main logic is here: receive messages from the postgres backend, forward
// notifications and query replies and keep the internal state in sync with the
// protocol state. Returns when the connection has been lost, is about to go
// away or should be discarded because we couldn't agree on the state with the
// server backend.
func (l *ListenerConn) listenerConnLoop() (err error) {
	r := &readBuf{}
	for {
		t, err := l.cn.recvMessage(r)
		if err != nil {
			return err
		}

		switch t {
		case proto.NotificationResponse:
			// recvNotification copies all the data so we don't need to worry
			// about the scratch buffer being overwritten.
			l.notificationChan <- recvNotification(r)

		case proto.RowDescription, proto.DataRow:
			// only used by tests; ignore

		case proto.ErrorResponse:
			// We might receive an ErrorResponse even when not in a query; it
			// is expected that the server will close the connection after
			// that, but we should make sure that the error we display is the
			// one from the stray ErrorResponse, not io.ErrUnexpectedEOF.
			if !l.setState(connStateExpectReadyForQuery) {
				return parseError(r, "")
			}
			l.replyChan <- message{t, parseError(r, "")}

		case proto.CommandComplete, proto.EmptyQueryResponse:
			if !l.setState(connStateExpectReadyForQuery) {
				// protocol out of sync
				return fmt.Errorf("unexpected CommandComplete")
			}
			// ExecSimpleQuery doesn't need to know about this message

		case proto.ReadyForQuery:
			if !l.setState(connStateIdle) {
				// protocol out of sync
				return fmt.Errorf("unexpected ReadyForQuery")
			}
			l.replyChan <- message{t, nil}

		case proto.ParameterStatus:
			// ignore
		case proto.NoticeResponse:
			if n := l.cn.noticeHandler; n != nil {
				n(parseError(r, ""))
			}
		default:
			return fmt.Errorf("unexpected message %q from server in listenerConnLoop", t)
		}
	}
}

// This is the main routine for the goroutine receiving on the database
// connection. Most of the main logic is in listenerConnLoop.
func (l *ListenerConn) listenerConnMain() {
	err := l.listenerConnLoop()

	// listenerConnLoop terminated; we're done, but we still have to clean up.
	// Make sure nobody tries to start any new queries by making sure the err
	// pointer is set. It is important that we do not overwrite its value; a
	// connection could be closed by either this goroutine or one sending on the
	// connection – whoever closes the connection is assumed to have the more
	// meaningful error message (as the other one will probably get
	// net.errClosed), so that goroutine sets the error we expose while the
	// other error is discarded. If the connection is lost while two goroutines
	// are operating on the socket, it probably doesn't matter which error we
	// expose so we don't try to do anything more complex.
	l.connectionLock.Lock()
	if l.err == nil {
		l.err = err
	}
	_ = l.cn.Close()
	l.connectionLock.Unlock()

	// There might be a query in-flight; make sure nobody's waiting for a
	// response to it, since there's not going to be one.
	close(l.replyChan)

	// let the listener know we're done
	close(l.notificationChan)

	// this ListenerConn is done
}

// Listen sends a LISTEN query to the server. See ExecSimpleQuery.
func (l *ListenerConn) Listen(channel string) (bool, error) {
	return l.ExecSimpleQuery("LISTEN " + QuoteIdentifier(channel))
}

// Unlisten sends an UNLISTEN query to the server. See ExecSimpleQuery.
func (l *ListenerConn) Unlisten(channel string) (bool, error) {
	return l.ExecSimpleQuery("UNLISTEN " + QuoteIdentifier(channel))
}

// UnlistenAll sends an `UNLISTEN *` query to the server. See ExecSimpleQuery.
func (l *ListenerConn) UnlistenAll() (bool, error) {
	return l.ExecSimpleQuery("UNLISTEN *")
}

// Ping the remote server to make sure it's alive. Non-nil error means the
// connection has failed and should be abandoned.
func (l *ListenerConn) Ping() error {
	// An empty query elicits only EmptyQueryResponse/ReadyForQuery, so any
	// executed=true result with a non-nil error is impossible.
	sent, err := l.ExecSimpleQuery("")
	if !sent {
		return err
	}
	if err != nil { // shouldn't happen
		panic(err)
	}
	return nil
}

// Attempt to send a query on the connection. Returns an error if sending the
// query failed, and the caller should initiate closure of this connection. The
// caller must be holding senderLock (see acquireSenderLock and
// releaseSenderLock).
func (l *ListenerConn) sendSimpleQuery(q string) (err error) {
	// Must set connection state before sending the query
	if !l.setState(connStateExpectResponse) {
		return errors.New("pq: two queries running at the same time")
	}

	// Can't use l.cn.writeBuf here because it uses the scratch buffer which
	// might get overwritten by listenerConnLoop.
	// "Q" is the simple-query message type; the four zero bytes reserve space
	// for the length word.
	b := &writeBuf{
		buf: []byte("Q\x00\x00\x00\x00"),
		pos: 1,
	}
	b.string(q)
	return l.cn.send(b)
}

// ExecSimpleQuery executes a "simple query" (i.e. one with no bindable
// parameters) on the connection. The possible return values are:
//  1. "executed" is true; the query was executed to completion on the database
//     server. If the query failed, err will be set to the error returned by the
//     database, otherwise err will be nil.
//  2. If "executed" is false, the query could not be executed on the remote
//     server. err will be non-nil.
//
// After a call to ExecSimpleQuery has returned an executed=false value, the
// connection has either been closed or will be closed shortly thereafter, and
// all subsequently executed queries will return an error.
func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) {
	if err = l.acquireSenderLock(); err != nil {
		return false, err
	}
	defer l.releaseSenderLock()

	err = l.sendSimpleQuery(q)
	if err != nil {
		// We can't know what state the protocol is in, so we need to abandon
		// this connection.
		l.connectionLock.Lock()
		// Set the error pointer if it hasn't been set already; see
		// listenerConnMain.
		if l.err == nil {
			l.err = err
		}
		l.connectionLock.Unlock()
		_ = l.cn.c.Close()
		return false, err
	}

	// now we just wait for a reply..
	for {
		m, ok := <-l.replyChan
		if !ok {
			// We lost the connection to server, don't bother waiting for a
			// a response. err should have been set already.
			l.connectionLock.Lock()
			err := l.err
			l.connectionLock.Unlock()
			return false, err
		}
		switch m.typ {
		case proto.ReadyForQuery:
			// sanity check
			if m.err != nil {
				panic("m.err != nil")
			}
			// done; err might or might not be set
			return true, err

		case proto.ErrorResponse:
			// sanity check
			if m.err == nil {
				panic("m.err == nil")
			}
			// server responded with an error; ReadyForQuery to follow
			err = m.err

		default:
			return false, fmt.Errorf("unknown response for simple query: %q", m.typ)
		}
	}
}

// Close closes the connection.
func (l *ListenerConn) Close() error {
	l.connectionLock.Lock()
	if l.err != nil {
		l.connectionLock.Unlock()
		return errListenerConnClosed
	}
	l.err = errListenerConnClosed
	l.connectionLock.Unlock()
	// We can't send anything on the connection without holding senderLock.
	// Simply close the net.Conn to wake up everyone operating on it.
	return l.cn.c.Close()
}

// Err returns the reason the connection was closed. It is not safe to call
// this function until the notification channel passed to NewListenerConn has
// been closed (listenerConnMain sets the error before closing that channel).
func (l *ListenerConn) Err() error {
	return l.err
}

// ErrChannelAlreadyOpen is returned from Listen when a channel is already
// open.
var ErrChannelAlreadyOpen = errors.New("pq: channel is already open")

// ErrChannelNotOpen is returned from Unlisten when a channel is not open.
var ErrChannelNotOpen = errors.New("pq: channel is not open")

// ListenerEventType is an enumeration of listener event types.
type ListenerEventType int

const (
	// ListenerEventConnected is emitted only when the database connection has
	// been initially initialized. The err argument of the callback will always
	// be nil.
	ListenerEventConnected ListenerEventType = iota

	// ListenerEventDisconnected is emitted after a database connection has been
	// lost, either because of an error or because Close has been called. The
	// err argument will be set to the reason the database connection was lost.
	ListenerEventDisconnected

	// ListenerEventReconnected is emitted after a database connection has been
	// re-established after connection loss. The err argument of the callback
	// will always be nil. After this event has been emitted, a nil
	// pq.Notification is sent on the Listener.Notify channel.
	ListenerEventReconnected

	// ListenerEventConnectionAttemptFailed is emitted after a connection to the
	// database was attempted, but failed. The err argument will be set to an
	// error describing why the connection attempt did not succeed.
	ListenerEventConnectionAttemptFailed
)

// EventCallbackType is the event callback type. See also ListenerEventType
// constants' documentation.
type EventCallbackType func(event ListenerEventType, err error)

// String returns a human-readable name for the event type, or the empty
// string for an unknown value.
func (l ListenerEventType) String() string {
	return map[ListenerEventType]string{
		ListenerEventConnected:               "connected",
		ListenerEventDisconnected:            "disconnected",
		ListenerEventReconnected:             "reconnected",
		ListenerEventConnectionAttemptFailed: "connectionAttemptFailed",
	}[l]
}

// Listener provides an interface for listening to notifications from a
// PostgreSQL database. For general usage information, see section
// "Notifications".
//
// Listener can safely be used from concurrently running goroutines.
type Listener struct {
	// Channel for receiving notifications from the database. In some cases a
	// nil value will be sent. See section "Notifications" above.
	Notify chan *Notification

	dsn                  string
	minReconnectInterval time.Duration
	maxReconnectInterval time.Duration
	dialer               Dialer
	eventCallback        EventCallbackType

	// lock guards the mutable fields below.
	lock     sync.Mutex
	isClosed bool
	// reconnectCond (on lock) is broadcast when a connection has been
	// established or the Listener is closed; Listen waits on it.
	reconnectCond        *sync.Cond
	cn                   *ListenerConn
	connNotificationChan <-chan *Notification
	channels             map[string]struct{}
}

// NewListener creates a new database connection dedicated to LISTEN / NOTIFY.
//
// name should be set to a connection string to be used to establish the
// database connection (see section "Connection String Parameters" above).
//
// minReconnect controls the duration to wait before trying to re-establish the
// database connection after connection loss. After each consecutive failure
// this interval is doubled, until maxReconnect is reached. Successfully
// completing the connection establishment procedure resets the interval back to
// minReconnect.
//
// The last parameter cb can be set to a function which will be called by the
// Listener when the state of the underlying database connection changes. This
// callback will be called by the goroutine which dispatches the notifications
// over the Notify channel, so you should try to avoid doing potentially
// time-consuming operations from the callback.
func NewListener(dsn string, minReconnect, maxReconnect time.Duration, cb EventCallbackType) *Listener {
	return NewDialListener(defaultDialer{}, dsn, minReconnect, maxReconnect, cb)
}

// NewDialListener is like NewListener but it takes a Dialer.
func NewDialListener(d Dialer, dsn string, minReconnect, maxReconnect time.Duration, cb EventCallbackType) *Listener {
	l := &Listener{
		dsn:                  dsn,
		minReconnectInterval: minReconnect,
		maxReconnectInterval: maxReconnect,
		dialer:               d,
		eventCallback:        cb,
		channels:             make(map[string]struct{}),
		Notify:               make(chan *Notification, 32),
	}
	l.reconnectCond = sync.NewCond(&l.lock)
	go l.listenerMain()
	return l
}

// NotificationChannel returns the notification channel for this listener. This
// is the same channel as Notify, and will not be recreated during the life time
// of the Listener.
func (l *Listener) NotificationChannel() <-chan *Notification {
	return l.Notify
}

// Listen starts listening for notifications on a channel. Calls to this
// function will block until an acknowledgement has been received from the
// server. Note that Listener automatically re-establishes the connection after
// connection loss, so this function may block indefinitely if the connection
// can not be re-established.
//
// Listen will only fail in three conditions:
//  1. The channel is already open. The returned error will be
//     [ErrChannelAlreadyOpen].
//  2. The query was executed on the remote server, but PostgreSQL returned an
//     error message in response to the query. The returned error will be a
//     [pq.Error] containing the information the server supplied.
//  3. Close is called on the Listener before the request could be completed.
//
// The channel name is case-sensitive.
func (l *Listener) Listen(channel string) error {
	l.lock.Lock()
	defer l.lock.Unlock()
	if l.isClosed {
		return net.ErrClosed
	}

	// The server allows you to issue a LISTEN on a channel which is already
	// open, but it seems useful to be able to detect this case to spot for
	// mistakes in application logic. If the application genuinely doesn't
	// care, it can check the exported error and ignore it.
	_, exists := l.channels[channel]
	if exists {
		return ErrChannelAlreadyOpen
	}

	if l.cn != nil {
		// If resp is true but error is set then the query was executed on the
		// remote server but resulted in an error. This should be relatively
		// rare, so it's fine if we just pass the error to our caller.
		// If resp is false then we could not complete the query on the remote
		// server and our underlying connection is about to go away, so we only
		// add relname to l.channels, and wait for resync() to take care of the
		// rest.
		resp, err := l.cn.Listen(channel)
		if resp && err != nil {
			return err
		}
	}

	l.channels[channel] = struct{}{}
	// If there is no connection, block until one has been established (resync
	// will then issue the LISTEN for us) or the Listener is closed.
	for l.cn == nil {
		l.reconnectCond.Wait()
		// we let go of the mutex for a while
		if l.isClosed {
			return net.ErrClosed
		}
	}

	return nil
}

// Unlisten removes a channel from the Listener's channel list. Returns
// ErrChannelNotOpen if the Listener is not listening on the specified channel.
// Returns immediately with no error if there is no connection. Note that you
// might still get notifications for this channel even after Unlisten has
// returned.
//
// The channel name is case-sensitive.
func (l *Listener) Unlisten(channel string) error {
	l.lock.Lock()
	defer l.lock.Unlock()

	if l.isClosed {
		return net.ErrClosed
	}

	// Similarly to LISTEN, this is not an error in Postgres, but it seems
	// useful to distinguish from the normal conditions.
	_, exists := l.channels[channel]
	if !exists {
		return ErrChannelNotOpen
	}

	if l.cn != nil {
		// Similarly to Listen (see comment there), the caller should only be
		// bothered with an error if it came from the backend as a response to
		// our query.
		resp, err := l.cn.Unlisten(channel)
		if resp && err != nil {
			return err
		}
	}

	// Don't bother waiting for resync if there's no connection.
	delete(l.channels, channel)
	return nil
}

// UnlistenAll removes all channels from the Listener's channel list. Returns
// immediately with no error if there is no connection. Note that you might
// still get notifications for any of the deleted channels even after
// UnlistenAll has returned.
func (l *Listener) UnlistenAll() error {
	l.lock.Lock()
	defer l.lock.Unlock()

	if l.isClosed {
		return net.ErrClosed
	}

	if l.cn != nil {
		// Similarly to Listen (see comment in that function), the caller
		// should only be bothered with an error if it came from the backend as
		// a response to our query.
		gotResponse, err := l.cn.UnlistenAll()
		if gotResponse && err != nil {
			return err
		}
	}

	// Don't bother waiting for resync if there's no connection.
	l.channels = make(map[string]struct{})
	return nil
}

// Ping the remote server to make sure it's alive. Non-nil return value means
// that there is no active connection.
func (l *Listener) Ping() error {
	l.lock.Lock()
	defer l.lock.Unlock()

	if l.isClosed {
		return net.ErrClosed
	}
	if l.cn == nil {
		return errors.New("no connection")
	}

	return l.cn.Ping()
}

// Clean up after losing the server connection. Returns l.cn.Err(), which should
// have the reason the connection was lost.
+func (l *Listener) disconnectCleanup() error { + l.lock.Lock() + defer l.lock.Unlock() + + // sanity check; can't look at Err() until the channel has been closed + select { + case _, ok := <-l.connNotificationChan: + if ok { + panic("connNotificationChan not closed") + } + default: + panic("connNotificationChan not closed") + } + + err := l.cn.Err() + _ = l.cn.Close() + l.cn = nil + return err +} + +// Synchronize the list of channels we want to be listening on with the server +// after the connection has been established. +func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notification) error { + doneChan := make(chan error) + go func(notificationChan <-chan *Notification) { + for channel := range l.channels { + // If we got a response, return that error to our caller as it's + // going to be more descriptive than cn.Err(). + gotResponse, err := cn.Listen(channel) + if gotResponse && err != nil { + doneChan <- err + return + } + + // If we couldn't reach the server, wait for notificationChan to + // close and then return the error message from the connection, as + // per ListenerConn's interface. + if err != nil { + for range notificationChan { + } + doneChan <- cn.Err() + return + } + } + doneChan <- nil + }(notificationChan) + + // Ignore notifications while synchronization is going on to avoid + // deadlocks. We have to send a nil notification over Notify anyway as we + // can't possibly know which notifications (if any) were lost while the + // connection was down, so there's no reason to try and process these + // messages at all. 
+ for { + select { + case _, ok := <-notificationChan: + if !ok { + notificationChan = nil + } + + case err := <-doneChan: + return err + } + } +} + +// caller should NOT be holding l.lock +func (l *Listener) closed() bool { + l.lock.Lock() + defer l.lock.Unlock() + + return l.isClosed +} + +func (l *Listener) connect() error { + l.lock.Lock() + defer l.lock.Unlock() + if l.isClosed { + return net.ErrClosed + } + + notificationChan := make(chan *Notification, 32) + + var err error + l.cn, err = newDialListenerConn(l.dialer, l.dsn, notificationChan) + if err != nil { + return err + } + + err = l.resync(l.cn, notificationChan) + if err != nil { + _ = l.cn.Close() + return err + } + + l.connNotificationChan = notificationChan + l.reconnectCond.Broadcast() + return nil +} + +// Close disconnects the Listener from the database and shuts it down. +// Subsequent calls to its methods will return an error. Close returns an error +// if the connection has already been closed. +func (l *Listener) Close() error { + l.lock.Lock() + defer l.lock.Unlock() + + if l.isClosed { + return net.ErrClosed + } + + if l.cn != nil { + _ = l.cn.Close() + } + l.isClosed = true + + // Unblock calls to Listen() + l.reconnectCond.Broadcast() + + return nil +} + +func (l *Listener) emitEvent(event ListenerEventType, err error) { + if l.eventCallback != nil { + l.eventCallback(event, err) + } +} + +// Main logic here: maintain a connection to the server when possible, wait +// for notifications and emit events. 
func (l *Listener) listenerConnLoop() {
	var (
		nextReconnect     time.Time
		reconnectInterval = l.minReconnectInterval
	)
	for {
		// Connect with exponential backoff, doubling the interval after each
		// failed attempt up to maxReconnectInterval.
		for {
			err := l.connect()
			if err == nil {
				break
			}
			if l.closed() {
				return
			}

			l.emitEvent(ListenerEventConnectionAttemptFailed, err)
			time.Sleep(reconnectInterval)
			reconnectInterval *= 2
			if reconnectInterval > l.maxReconnectInterval {
				reconnectInterval = l.maxReconnectInterval
			}
		}

		// First successful connection emits Connected; later ones emit
		// Reconnected followed by a nil notification as documented on
		// ListenerEventReconnected.
		if nextReconnect.IsZero() {
			l.emitEvent(ListenerEventConnected, nil)
		} else {
			l.emitEvent(ListenerEventReconnected, nil)
			l.Notify <- nil
		}

		reconnectInterval = l.minReconnectInterval
		nextReconnect = time.Now().Add(reconnectInterval)

		// Forward notifications until the connection's channel is closed.
		for {
			notification, ok := <-l.connNotificationChan
			if !ok { // lost connection, loop again
				break
			}
			l.Notify <- notification
		}

		err := l.disconnectCleanup()
		if l.closed() {
			return
		}
		l.emitEvent(ListenerEventDisconnected, err)

		time.Sleep(time.Until(nextReconnect))
	}
}

// listenerMain runs the connection-maintenance loop and closes the Notify
// channel when the Listener shuts down.
func (l *Listener) listenerMain() {
	l.listenerConnLoop()
	close(l.Notify)
}
diff --git a/vendor/github.com/lib/pq/oid/doc.go b/vendor/github.com/lib/pq/oid/doc.go
new file mode 100644
index 00000000..a4865066
--- /dev/null
+++ b/vendor/github.com/lib/pq/oid/doc.go
@@ -0,0 +1,7 @@
+//go:generate go run ./gen.go
+
+// Package oid contains OID constants as defined by the Postgres server.
+package oid
+
+// Oid is a Postgres Object ID.
+type Oid uint32
diff --git a/vendor/github.com/lib/pq/oid/types.go b/vendor/github.com/lib/pq/oid/types.go
new file mode 100644
index 00000000..ecc84c2c
--- /dev/null
+++ b/vendor/github.com/lib/pq/oid/types.go
@@ -0,0 +1,343 @@
+// Code generated by gen.go. DO NOT EDIT.
+ +package oid + +const ( + T_bool Oid = 16 + T_bytea Oid = 17 + T_char Oid = 18 + T_name Oid = 19 + T_int8 Oid = 20 + T_int2 Oid = 21 + T_int2vector Oid = 22 + T_int4 Oid = 23 + T_regproc Oid = 24 + T_text Oid = 25 + T_oid Oid = 26 + T_tid Oid = 27 + T_xid Oid = 28 + T_cid Oid = 29 + T_oidvector Oid = 30 + T_pg_ddl_command Oid = 32 + T_pg_type Oid = 71 + T_pg_attribute Oid = 75 + T_pg_proc Oid = 81 + T_pg_class Oid = 83 + T_json Oid = 114 + T_xml Oid = 142 + T__xml Oid = 143 + T_pg_node_tree Oid = 194 + T__json Oid = 199 + T_smgr Oid = 210 + T_index_am_handler Oid = 325 + T_point Oid = 600 + T_lseg Oid = 601 + T_path Oid = 602 + T_box Oid = 603 + T_polygon Oid = 604 + T_line Oid = 628 + T__line Oid = 629 + T_cidr Oid = 650 + T__cidr Oid = 651 + T_float4 Oid = 700 + T_float8 Oid = 701 + T_abstime Oid = 702 + T_reltime Oid = 703 + T_tinterval Oid = 704 + T_unknown Oid = 705 + T_circle Oid = 718 + T__circle Oid = 719 + T_money Oid = 790 + T__money Oid = 791 + T_macaddr Oid = 829 + T_inet Oid = 869 + T__bool Oid = 1000 + T__bytea Oid = 1001 + T__char Oid = 1002 + T__name Oid = 1003 + T__int2 Oid = 1005 + T__int2vector Oid = 1006 + T__int4 Oid = 1007 + T__regproc Oid = 1008 + T__text Oid = 1009 + T__tid Oid = 1010 + T__xid Oid = 1011 + T__cid Oid = 1012 + T__oidvector Oid = 1013 + T__bpchar Oid = 1014 + T__varchar Oid = 1015 + T__int8 Oid = 1016 + T__point Oid = 1017 + T__lseg Oid = 1018 + T__path Oid = 1019 + T__box Oid = 1020 + T__float4 Oid = 1021 + T__float8 Oid = 1022 + T__abstime Oid = 1023 + T__reltime Oid = 1024 + T__tinterval Oid = 1025 + T__polygon Oid = 1027 + T__oid Oid = 1028 + T_aclitem Oid = 1033 + T__aclitem Oid = 1034 + T__macaddr Oid = 1040 + T__inet Oid = 1041 + T_bpchar Oid = 1042 + T_varchar Oid = 1043 + T_date Oid = 1082 + T_time Oid = 1083 + T_timestamp Oid = 1114 + T__timestamp Oid = 1115 + T__date Oid = 1182 + T__time Oid = 1183 + T_timestamptz Oid = 1184 + T__timestamptz Oid = 1185 + T_interval Oid = 1186 + T__interval Oid = 1187 + T__numeric 
Oid = 1231 + T_pg_database Oid = 1248 + T__cstring Oid = 1263 + T_timetz Oid = 1266 + T__timetz Oid = 1270 + T_bit Oid = 1560 + T__bit Oid = 1561 + T_varbit Oid = 1562 + T__varbit Oid = 1563 + T_numeric Oid = 1700 + T_refcursor Oid = 1790 + T__refcursor Oid = 2201 + T_regprocedure Oid = 2202 + T_regoper Oid = 2203 + T_regoperator Oid = 2204 + T_regclass Oid = 2205 + T_regtype Oid = 2206 + T__regprocedure Oid = 2207 + T__regoper Oid = 2208 + T__regoperator Oid = 2209 + T__regclass Oid = 2210 + T__regtype Oid = 2211 + T_record Oid = 2249 + T_cstring Oid = 2275 + T_any Oid = 2276 + T_anyarray Oid = 2277 + T_void Oid = 2278 + T_trigger Oid = 2279 + T_language_handler Oid = 2280 + T_internal Oid = 2281 + T_opaque Oid = 2282 + T_anyelement Oid = 2283 + T__record Oid = 2287 + T_anynonarray Oid = 2776 + T_pg_authid Oid = 2842 + T_pg_auth_members Oid = 2843 + T__txid_snapshot Oid = 2949 + T_uuid Oid = 2950 + T__uuid Oid = 2951 + T_txid_snapshot Oid = 2970 + T_fdw_handler Oid = 3115 + T_pg_lsn Oid = 3220 + T__pg_lsn Oid = 3221 + T_tsm_handler Oid = 3310 + T_anyenum Oid = 3500 + T_tsvector Oid = 3614 + T_tsquery Oid = 3615 + T_gtsvector Oid = 3642 + T__tsvector Oid = 3643 + T__gtsvector Oid = 3644 + T__tsquery Oid = 3645 + T_regconfig Oid = 3734 + T__regconfig Oid = 3735 + T_regdictionary Oid = 3769 + T__regdictionary Oid = 3770 + T_jsonb Oid = 3802 + T__jsonb Oid = 3807 + T_anyrange Oid = 3831 + T_event_trigger Oid = 3838 + T_int4range Oid = 3904 + T__int4range Oid = 3905 + T_numrange Oid = 3906 + T__numrange Oid = 3907 + T_tsrange Oid = 3908 + T__tsrange Oid = 3909 + T_tstzrange Oid = 3910 + T__tstzrange Oid = 3911 + T_daterange Oid = 3912 + T__daterange Oid = 3913 + T_int8range Oid = 3926 + T__int8range Oid = 3927 + T_pg_shseclabel Oid = 4066 + T_regnamespace Oid = 4089 + T__regnamespace Oid = 4090 + T_regrole Oid = 4096 + T__regrole Oid = 4097 +) + +var TypeName = map[Oid]string{ + T_bool: "BOOL", + T_bytea: "BYTEA", + T_char: "CHAR", + T_name: "NAME", + T_int8: "INT8", + 
T_int2: "INT2", + T_int2vector: "INT2VECTOR", + T_int4: "INT4", + T_regproc: "REGPROC", + T_text: "TEXT", + T_oid: "OID", + T_tid: "TID", + T_xid: "XID", + T_cid: "CID", + T_oidvector: "OIDVECTOR", + T_pg_ddl_command: "PG_DDL_COMMAND", + T_pg_type: "PG_TYPE", + T_pg_attribute: "PG_ATTRIBUTE", + T_pg_proc: "PG_PROC", + T_pg_class: "PG_CLASS", + T_json: "JSON", + T_xml: "XML", + T__xml: "_XML", + T_pg_node_tree: "PG_NODE_TREE", + T__json: "_JSON", + T_smgr: "SMGR", + T_index_am_handler: "INDEX_AM_HANDLER", + T_point: "POINT", + T_lseg: "LSEG", + T_path: "PATH", + T_box: "BOX", + T_polygon: "POLYGON", + T_line: "LINE", + T__line: "_LINE", + T_cidr: "CIDR", + T__cidr: "_CIDR", + T_float4: "FLOAT4", + T_float8: "FLOAT8", + T_abstime: "ABSTIME", + T_reltime: "RELTIME", + T_tinterval: "TINTERVAL", + T_unknown: "UNKNOWN", + T_circle: "CIRCLE", + T__circle: "_CIRCLE", + T_money: "MONEY", + T__money: "_MONEY", + T_macaddr: "MACADDR", + T_inet: "INET", + T__bool: "_BOOL", + T__bytea: "_BYTEA", + T__char: "_CHAR", + T__name: "_NAME", + T__int2: "_INT2", + T__int2vector: "_INT2VECTOR", + T__int4: "_INT4", + T__regproc: "_REGPROC", + T__text: "_TEXT", + T__tid: "_TID", + T__xid: "_XID", + T__cid: "_CID", + T__oidvector: "_OIDVECTOR", + T__bpchar: "_BPCHAR", + T__varchar: "_VARCHAR", + T__int8: "_INT8", + T__point: "_POINT", + T__lseg: "_LSEG", + T__path: "_PATH", + T__box: "_BOX", + T__float4: "_FLOAT4", + T__float8: "_FLOAT8", + T__abstime: "_ABSTIME", + T__reltime: "_RELTIME", + T__tinterval: "_TINTERVAL", + T__polygon: "_POLYGON", + T__oid: "_OID", + T_aclitem: "ACLITEM", + T__aclitem: "_ACLITEM", + T__macaddr: "_MACADDR", + T__inet: "_INET", + T_bpchar: "BPCHAR", + T_varchar: "VARCHAR", + T_date: "DATE", + T_time: "TIME", + T_timestamp: "TIMESTAMP", + T__timestamp: "_TIMESTAMP", + T__date: "_DATE", + T__time: "_TIME", + T_timestamptz: "TIMESTAMPTZ", + T__timestamptz: "_TIMESTAMPTZ", + T_interval: "INTERVAL", + T__interval: "_INTERVAL", + T__numeric: "_NUMERIC", + 
T_pg_database: "PG_DATABASE", + T__cstring: "_CSTRING", + T_timetz: "TIMETZ", + T__timetz: "_TIMETZ", + T_bit: "BIT", + T__bit: "_BIT", + T_varbit: "VARBIT", + T__varbit: "_VARBIT", + T_numeric: "NUMERIC", + T_refcursor: "REFCURSOR", + T__refcursor: "_REFCURSOR", + T_regprocedure: "REGPROCEDURE", + T_regoper: "REGOPER", + T_regoperator: "REGOPERATOR", + T_regclass: "REGCLASS", + T_regtype: "REGTYPE", + T__regprocedure: "_REGPROCEDURE", + T__regoper: "_REGOPER", + T__regoperator: "_REGOPERATOR", + T__regclass: "_REGCLASS", + T__regtype: "_REGTYPE", + T_record: "RECORD", + T_cstring: "CSTRING", + T_any: "ANY", + T_anyarray: "ANYARRAY", + T_void: "VOID", + T_trigger: "TRIGGER", + T_language_handler: "LANGUAGE_HANDLER", + T_internal: "INTERNAL", + T_opaque: "OPAQUE", + T_anyelement: "ANYELEMENT", + T__record: "_RECORD", + T_anynonarray: "ANYNONARRAY", + T_pg_authid: "PG_AUTHID", + T_pg_auth_members: "PG_AUTH_MEMBERS", + T__txid_snapshot: "_TXID_SNAPSHOT", + T_uuid: "UUID", + T__uuid: "_UUID", + T_txid_snapshot: "TXID_SNAPSHOT", + T_fdw_handler: "FDW_HANDLER", + T_pg_lsn: "PG_LSN", + T__pg_lsn: "_PG_LSN", + T_tsm_handler: "TSM_HANDLER", + T_anyenum: "ANYENUM", + T_tsvector: "TSVECTOR", + T_tsquery: "TSQUERY", + T_gtsvector: "GTSVECTOR", + T__tsvector: "_TSVECTOR", + T__gtsvector: "_GTSVECTOR", + T__tsquery: "_TSQUERY", + T_regconfig: "REGCONFIG", + T__regconfig: "_REGCONFIG", + T_regdictionary: "REGDICTIONARY", + T__regdictionary: "_REGDICTIONARY", + T_jsonb: "JSONB", + T__jsonb: "_JSONB", + T_anyrange: "ANYRANGE", + T_event_trigger: "EVENT_TRIGGER", + T_int4range: "INT4RANGE", + T__int4range: "_INT4RANGE", + T_numrange: "NUMRANGE", + T__numrange: "_NUMRANGE", + T_tsrange: "TSRANGE", + T__tsrange: "_TSRANGE", + T_tstzrange: "TSTZRANGE", + T__tstzrange: "_TSTZRANGE", + T_daterange: "DATERANGE", + T__daterange: "_DATERANGE", + T_int8range: "INT8RANGE", + T__int8range: "_INT8RANGE", + T_pg_shseclabel: "PG_SHSECLABEL", + T_regnamespace: "REGNAMESPACE", + T__regnamespace: 
"_REGNAMESPACE", + T_regrole: "REGROLE", + T__regrole: "_REGROLE", +} diff --git a/vendor/github.com/lib/pq/pqerror/codes.go b/vendor/github.com/lib/pq/pqerror/codes.go new file mode 100644 index 00000000..f5576644 --- /dev/null +++ b/vendor/github.com/lib/pq/pqerror/codes.go @@ -0,0 +1,581 @@ +// Code generated by gen.go. DO NOT EDIT. + +// Last updated for PostgreSQL 18.3 + +package pqerror + +var ( + ClassSuccessfulCompletion = Class("00") // Successful Completion + ClassWarning = Class("01") // Warning + ClassNoData = Class("02") // No Data (this is also a warning class per the SQL standard) + ClassSQLStatementNotYetComplete = Class("03") // SQL Statement Not Yet Complete + ClassConnectionException = Class("08") // Connection Exception + ClassTriggeredActionException = Class("09") // Triggered Action Exception + ClassFeatureNotSupported = Class("0A") // Feature Not Supported + ClassInvalidTransactionInitiation = Class("0B") // Invalid Transaction Initiation + ClassLocatorException = Class("0F") // Locator Exception + ClassInvalidGrantor = Class("0L") // Invalid Grantor + ClassInvalidRoleSpecification = Class("0P") // Invalid Role Specification + ClassDiagnosticsException = Class("0Z") // Diagnostics Exception + ClassCaseNotFound = Class("20") // Case Not Found + ClassCardinalityViolation = Class("21") // Cardinality Violation + ClassDataException = Class("22") // Data Exception + ClassIntegrityConstraintViolation = Class("23") // Integrity Constraint Violation + ClassInvalidCursorState = Class("24") // Invalid Cursor State + ClassInvalidTransactionState = Class("25") // Invalid Transaction State + ClassInvalidSQLStatementName = Class("26") // Invalid SQL Statement Name + ClassTriggeredDataChangeViolation = Class("27") // Triggered Data Change Violation + ClassInvalidAuthorizationSpecification = Class("28") // Invalid Authorization Specification + ClassDependentPrivilegeDescriptorsStillExist = Class("2B") // Dependent Privilege Descriptors Still Exist + 
ClassInvalidTransactionTermination = Class("2D") // Invalid Transaction Termination + ClassSQLRoutineException = Class("2F") // SQL Routine Exception + ClassInvalidCursorName = Class("34") // Invalid Cursor Name + ClassExternalRoutineException = Class("38") // External Routine Exception + ClassExternalRoutineInvocationException = Class("39") // External Routine Invocation Exception + ClassSavepointException = Class("3B") // Savepoint Exception + ClassInvalidCatalogName = Class("3D") // Invalid Catalog Name + ClassInvalidSchemaName = Class("3F") // Invalid Schema Name + ClassTransactionRollback = Class("40") // Transaction Rollback + ClassSyntaxErrorOrAccessRuleViolation = Class("42") // Syntax Error or Access Rule Violation + ClassWithCheckOptionViolation = Class("44") // WITH CHECK OPTION Violation + ClassInsufficientResources = Class("53") // Insufficient Resources + ClassProgramLimitExceeded = Class("54") // Program Limit Exceeded + ClassObjectNotInPrerequisiteState = Class("55") // Object Not In Prerequisite State + ClassOperatorIntervention = Class("57") // Operator Intervention + ClassSystemError = Class("58") // System Error (errors external to PostgreSQL itself) + ClassConfigFileError = Class("F0") // Configuration File Error + ClassFDWError = Class("HV") // Foreign Data Wrapper Error (SQL/MED) + ClassPLpgSQLError = Class("P0") // PL/pgSQL Error + ClassInternalError = Class("XX") // Internal Error +) + +// A list of all error codes used in PostgreSQL. 
+var ( + SuccessfulCompletion = Code("00000") // Class 00 - Successful Completion + Warning = Code("01000") // Class 01 - Warning + WarningDynamicResultSetsReturned = Code("0100C") + WarningImplicitZeroBitPadding = Code("01008") + WarningNullValueEliminatedInSetFunction = Code("01003") + WarningPrivilegeNotGranted = Code("01007") + WarningPrivilegeNotRevoked = Code("01006") + WarningStringDataRightTruncation = Code("01004") + WarningDeprecatedFeature = Code("01P01") + NoData = Code("02000") // Class 02 - No Data (this is also a warning class per the SQL standard) + NoAdditionalDynamicResultSetsReturned = Code("02001") + SQLStatementNotYetComplete = Code("03000") // Class 03 - SQL Statement Not Yet Complete + ConnectionException = Code("08000") // Class 08 - Connection Exception + ConnectionDoesNotExist = Code("08003") + ConnectionFailure = Code("08006") + SQLClientUnableToEstablishSQLConnection = Code("08001") + SQLServerRejectedEstablishmentOfSQLConnection = Code("08004") + TransactionResolutionUnknown = Code("08007") + ProtocolViolation = Code("08P01") + TriggeredActionException = Code("09000") // Class 09 - Triggered Action Exception + FeatureNotSupported = Code("0A000") // Class 0A - Feature Not Supported + InvalidTransactionInitiation = Code("0B000") // Class 0B - Invalid Transaction Initiation + LocatorException = Code("0F000") // Class 0F - Locator Exception + LEInvalidSpecification = Code("0F001") + InvalidGrantor = Code("0L000") // Class 0L - Invalid Grantor + InvalidGrantOperation = Code("0LP01") + InvalidRoleSpecification = Code("0P000") // Class 0P - Invalid Role Specification + DiagnosticsException = Code("0Z000") // Class 0Z - Diagnostics Exception + StackedDiagnosticsAccessedWithoutActiveHandler = Code("0Z002") + InvalidArgumentForXquery = Code("10608") + CaseNotFound = Code("20000") // Class 20 - Case Not Found + CardinalityViolation = Code("21000") // Class 21 - Cardinality Violation + DataException = Code("22000") // Class 22 - Data Exception + 
ArraySubscriptError = Code("2202E") + CharacterNotInRepertoire = Code("22021") + DatetimeFieldOverflow = Code("22008") + DivisionByZero = Code("22012") + ErrorInAssignment = Code("22005") + EscapeCharacterConflict = Code("2200B") + IndicatorOverflow = Code("22022") + IntervalFieldOverflow = Code("22015") + InvalidArgumentForLog = Code("2201E") + InvalidArgumentForNtile = Code("22014") + InvalidArgumentForNthValue = Code("22016") + InvalidArgumentForPowerFunction = Code("2201F") + InvalidArgumentForWidthBucketFunction = Code("2201G") + InvalidCharacterValueForCast = Code("22018") + InvalidDatetimeFormat = Code("22007") + InvalidEscapeCharacter = Code("22019") + InvalidEscapeOctet = Code("2200D") + InvalidEscapeSequence = Code("22025") + NonstandardUseOfEscapeCharacter = Code("22P06") + InvalidIndicatorParameterValue = Code("22010") + InvalidParameterValue = Code("22023") + InvalidPrecedingOrFollowingSize = Code("22013") + InvalidRegularExpression = Code("2201B") + InvalidRowCountInLimitClause = Code("2201W") + InvalidRowCountInResultOffsetClause = Code("2201X") + InvalidTablesampleArgument = Code("2202H") + InvalidTablesampleRepeat = Code("2202G") + InvalidTimeZoneDisplacementValue = Code("22009") + InvalidUseOfEscapeCharacter = Code("2200C") + MostSpecificTypeMismatch = Code("2200G") + NullValueNotAllowed = Code("22004") + NullValueNoIndicatorParameter = Code("22002") + NumericValueOutOfRange = Code("22003") + SequenceGeneratorLimitExceeded = Code("2200H") + StringDataLengthMismatch = Code("22026") + StringDataRightTruncation = Code("22001") + SubstringError = Code("22011") + TrimError = Code("22027") + UnterminatedCString = Code("22024") + ZeroLengthCharacterString = Code("2200F") + FloatingPointException = Code("22P01") + InvalidTextRepresentation = Code("22P02") + InvalidBinaryRepresentation = Code("22P03") + BadCopyFileFormat = Code("22P04") + UntranslatableCharacter = Code("22P05") + NotAnXMLDocument = Code("2200L") + InvalidXMLDocument = Code("2200M") + 
InvalidXMLContent = Code("2200N") + InvalidXMLComment = Code("2200S") + InvalidXMLProcessingInstruction = Code("2200T") + DuplicateJSONObjectKeyValue = Code("22030") + InvalidArgumentForSQLJSONDatetimeFunction = Code("22031") + InvalidJSONText = Code("22032") + InvalidSQLJSONSubscript = Code("22033") + MoreThanOneSQLJSONItem = Code("22034") + NoSQLJSONItem = Code("22035") + NonNumericSQLJSONItem = Code("22036") + NonUniqueKeysInAJSONObject = Code("22037") + SingletonSQLJSONItemRequired = Code("22038") + SQLJSONArrayNotFound = Code("22039") + SQLJSONMemberNotFound = Code("2203A") + SQLJSONNumberNotFound = Code("2203B") + SQLJSONObjectNotFound = Code("2203C") + TooManyJSONArrayElements = Code("2203D") + TooManyJSONObjectMembers = Code("2203E") + SQLJSONScalarRequired = Code("2203F") + SQLJSONItemCannotBeCastToTargetType = Code("2203G") + IntegrityConstraintViolation = Code("23000") // Class 23 - Integrity Constraint Violation + RestrictViolation = Code("23001") + NotNullViolation = Code("23502") + ForeignKeyViolation = Code("23503") + UniqueViolation = Code("23505") + CheckViolation = Code("23514") + ExclusionViolation = Code("23P01") + InvalidCursorState = Code("24000") // Class 24 - Invalid Cursor State + InvalidTransactionState = Code("25000") // Class 25 - Invalid Transaction State + ActiveSQLTransaction = Code("25001") + BranchTransactionAlreadyActive = Code("25002") + HeldCursorRequiresSameIsolationLevel = Code("25008") + InappropriateAccessModeForBranchTransaction = Code("25003") + InappropriateIsolationLevelForBranchTransaction = Code("25004") + NoActiveSQLTransactionForBranchTransaction = Code("25005") + ReadOnlySQLTransaction = Code("25006") + SchemaAndDataStatementMixingNotSupported = Code("25007") + NoActiveSQLTransaction = Code("25P01") + InFailedSQLTransaction = Code("25P02") + IdleInTransactionSessionTimeout = Code("25P03") + TransactionTimeout = Code("25P04") + InvalidSQLStatementName = Code("26000") // Class 26 - Invalid SQL Statement Name + 
TriggeredDataChangeViolation = Code("27000") // Class 27 - Triggered Data Change Violation + InvalidAuthorizationSpecification = Code("28000") // Class 28 - Invalid Authorization Specification + InvalidPassword = Code("28P01") + DependentPrivilegeDescriptorsStillExist = Code("2B000") // Class 2B - Dependent Privilege Descriptors Still Exist + DependentObjectsStillExist = Code("2BP01") + InvalidTransactionTermination = Code("2D000") // Class 2D - Invalid Transaction Termination + SQLRoutineException = Code("2F000") // Class 2F - SQL Routine Exception + SREFunctionExecutedNoReturnStatement = Code("2F005") + SREModifyingSQLDataNotPermitted = Code("2F002") + SREProhibitedSQLStatementAttempted = Code("2F003") + SREReadingSQLDataNotPermitted = Code("2F004") + InvalidCursorName = Code("34000") // Class 34 - Invalid Cursor Name + ExternalRoutineException = Code("38000") // Class 38 - External Routine Exception + EREContainingSQLNotPermitted = Code("38001") + EREModifyingSQLDataNotPermitted = Code("38002") + EREProhibitedSQLStatementAttempted = Code("38003") + EREReadingSQLDataNotPermitted = Code("38004") + ExternalRoutineInvocationException = Code("39000") // Class 39 - External Routine Invocation Exception + ERIEInvalidSQLSTATEReturned = Code("39001") + ERIENullValueNotAllowed = Code("39004") + ERIETriggerProtocolViolated = Code("39P01") + ERIESrfProtocolViolated = Code("39P02") + ERIEEventTriggerProtocolViolated = Code("39P03") + SavepointException = Code("3B000") // Class 3B - Savepoint Exception + SEInvalidSpecification = Code("3B001") + InvalidCatalogName = Code("3D000") // Class 3D - Invalid Catalog Name + InvalidSchemaName = Code("3F000") // Class 3F - Invalid Schema Name + TransactionRollback = Code("40000") // Class 40 - Transaction Rollback + TRIntegrityConstraintViolation = Code("40002") + TRSerializationFailure = Code("40001") + TRStatementCompletionUnknown = Code("40003") + TRDeadlockDetected = Code("40P01") + SyntaxErrorOrAccessRuleViolation = Code("42000") 
// Class 42 - Syntax Error or Access Rule Violation + SyntaxError = Code("42601") + InsufficientPrivilege = Code("42501") + CannotCoerce = Code("42846") + GroupingError = Code("42803") + WindowingError = Code("42P20") + InvalidRecursion = Code("42P19") + InvalidForeignKey = Code("42830") + InvalidName = Code("42602") + NameTooLong = Code("42622") + ReservedName = Code("42939") + DatatypeMismatch = Code("42804") + IndeterminateDatatype = Code("42P18") + CollationMismatch = Code("42P21") + IndeterminateCollation = Code("42P22") + WrongObjectType = Code("42809") + GeneratedAlways = Code("428C9") + UndefinedColumn = Code("42703") + UndefinedFunction = Code("42883") + UndefinedTable = Code("42P01") + UndefinedParameter = Code("42P02") + UndefinedObject = Code("42704") + DuplicateColumn = Code("42701") + DuplicateCursor = Code("42P03") + DuplicateDatabase = Code("42P04") + DuplicateFunction = Code("42723") + DuplicatePstatement = Code("42P05") + DuplicateSchema = Code("42P06") + DuplicateTable = Code("42P07") + DuplicateAlias = Code("42712") + DuplicateObject = Code("42710") + AmbiguousColumn = Code("42702") + AmbiguousFunction = Code("42725") + AmbiguousParameter = Code("42P08") + AmbiguousAlias = Code("42P09") + InvalidColumnReference = Code("42P10") + InvalidColumnDefinition = Code("42611") + InvalidCursorDefinition = Code("42P11") + InvalidDatabaseDefinition = Code("42P12") + InvalidFunctionDefinition = Code("42P13") + InvalidPstatementDefinition = Code("42P14") + InvalidSchemaDefinition = Code("42P15") + InvalidTableDefinition = Code("42P16") + InvalidObjectDefinition = Code("42P17") + WithCheckOptionViolation = Code("44000") // Class 44 - WITH CHECK OPTION Violation + InsufficientResources = Code("53000") // Class 53 - Insufficient Resources + DiskFull = Code("53100") + OutOfMemory = Code("53200") + TooManyConnections = Code("53300") + ConfigurationLimitExceeded = Code("53400") + ProgramLimitExceeded = Code("54000") // Class 54 - Program Limit Exceeded + 
StatementTooComplex = Code("54001") + TooManyColumns = Code("54011") + TooManyArguments = Code("54023") + ObjectNotInPrerequisiteState = Code("55000") // Class 55 - Object Not In Prerequisite State + ObjectInUse = Code("55006") + CantChangeRuntimeParam = Code("55P02") + LockNotAvailable = Code("55P03") + UnsafeNewEnumValueUsage = Code("55P04") + OperatorIntervention = Code("57000") // Class 57 - Operator Intervention + QueryCanceled = Code("57014") + AdminShutdown = Code("57P01") + CrashShutdown = Code("57P02") + CannotConnectNow = Code("57P03") + DatabaseDropped = Code("57P04") + IdleSessionTimeout = Code("57P05") + SystemError = Code("58000") // Class 58 - System Error (errors external to PostgreSQL itself) + IOError = Code("58030") + UndefinedFile = Code("58P01") + DuplicateFile = Code("58P02") + FileNameTooLong = Code("58P03") + ConfigFileError = Code("F0000") // Class F0 - Configuration File Error + LockFileExists = Code("F0001") + FDWError = Code("HV000") // Class HV - Foreign Data Wrapper Error (SQL/MED) + FDWColumnNameNotFound = Code("HV005") + FDWDynamicParameterValueNeeded = Code("HV002") + FDWFunctionSequenceError = Code("HV010") + FDWInconsistentDescriptorInformation = Code("HV021") + FDWInvalidAttributeValue = Code("HV024") + FDWInvalidColumnName = Code("HV007") + FDWInvalidColumnNumber = Code("HV008") + FDWInvalidDataType = Code("HV004") + FDWInvalidDataTypeDescriptors = Code("HV006") + FDWInvalidDescriptorFieldIdentifier = Code("HV091") + FDWInvalidHandle = Code("HV00B") + FDWInvalidOptionIndex = Code("HV00C") + FDWInvalidOptionName = Code("HV00D") + FDWInvalidStringLengthOrBufferLength = Code("HV090") + FDWInvalidStringFormat = Code("HV00A") + FDWInvalidUseOfNullPointer = Code("HV009") + FDWTooManyHandles = Code("HV014") + FDWOutOfMemory = Code("HV001") + FDWNoSchemas = Code("HV00P") + FDWOptionNameNotFound = Code("HV00J") + FDWReplyHandle = Code("HV00K") + FDWSchemaNotFound = Code("HV00Q") + FDWTableNotFound = Code("HV00R") + 
FDWUnableToCreateExecution = Code("HV00L") + FDWUnableToCreateReply = Code("HV00M") + FDWUnableToEstablishConnection = Code("HV00N") + PLpgSQLError = Code("P0000") // Class P0 - PL/pgSQL Error + RaiseException = Code("P0001") + NoDataFound = Code("P0002") + TooManyRows = Code("P0003") + AssertFailure = Code("P0004") + InternalError = Code("XX000") // Class XX - Internal Error + DataCorrupted = Code("XX001") + IndexCorrupted = Code("XX002") +) + +var errorCodeNames = map[Code]string{ + "00000": "successful_completion", + "01000": "warning", + "0100C": "dynamic_result_sets_returned", + "01008": "implicit_zero_bit_padding", + "01003": "null_value_eliminated_in_set_function", + "01007": "privilege_not_granted", + "01006": "privilege_not_revoked", + "01004": "string_data_right_truncation", + "01P01": "deprecated_feature", + "02000": "no_data", + "02001": "no_additional_dynamic_result_sets_returned", + "03000": "sql_statement_not_yet_complete", + "08000": "connection_exception", + "08003": "connection_does_not_exist", + "08006": "connection_failure", + "08001": "sqlclient_unable_to_establish_sqlconnection", + "08004": "sqlserver_rejected_establishment_of_sqlconnection", + "08007": "transaction_resolution_unknown", + "08P01": "protocol_violation", + "09000": "triggered_action_exception", + "0A000": "feature_not_supported", + "0B000": "invalid_transaction_initiation", + "0F000": "locator_exception", + "0F001": "invalid_locator_specification", + "0L000": "invalid_grantor", + "0LP01": "invalid_grant_operation", + "0P000": "invalid_role_specification", + "0Z000": "diagnostics_exception", + "0Z002": "stacked_diagnostics_accessed_without_active_handler", + "10608": "invalid_argument_for_xquery", + "20000": "case_not_found", + "21000": "cardinality_violation", + "22000": "data_exception", + "2202E": "array_subscript_error", + "22021": "character_not_in_repertoire", + "22008": "datetime_field_overflow", + "22012": "division_by_zero", + "22005": "error_in_assignment", + "2200B": 
"escape_character_conflict", + "22022": "indicator_overflow", + "22015": "interval_field_overflow", + "2201E": "invalid_argument_for_logarithm", + "22014": "invalid_argument_for_ntile_function", + "22016": "invalid_argument_for_nth_value_function", + "2201F": "invalid_argument_for_power_function", + "2201G": "invalid_argument_for_width_bucket_function", + "22018": "invalid_character_value_for_cast", + "22007": "invalid_datetime_format", + "22019": "invalid_escape_character", + "2200D": "invalid_escape_octet", + "22025": "invalid_escape_sequence", + "22P06": "nonstandard_use_of_escape_character", + "22010": "invalid_indicator_parameter_value", + "22023": "invalid_parameter_value", + "22013": "invalid_preceding_or_following_size", + "2201B": "invalid_regular_expression", + "2201W": "invalid_row_count_in_limit_clause", + "2201X": "invalid_row_count_in_result_offset_clause", + "2202H": "invalid_tablesample_argument", + "2202G": "invalid_tablesample_repeat", + "22009": "invalid_time_zone_displacement_value", + "2200C": "invalid_use_of_escape_character", + "2200G": "most_specific_type_mismatch", + "22004": "null_value_not_allowed", + "22002": "null_value_no_indicator_parameter", + "22003": "numeric_value_out_of_range", + "2200H": "sequence_generator_limit_exceeded", + "22026": "string_data_length_mismatch", + "22001": "string_data_right_truncation", + "22011": "substring_error", + "22027": "trim_error", + "22024": "unterminated_c_string", + "2200F": "zero_length_character_string", + "22P01": "floating_point_exception", + "22P02": "invalid_text_representation", + "22P03": "invalid_binary_representation", + "22P04": "bad_copy_file_format", + "22P05": "untranslatable_character", + "2200L": "not_an_xml_document", + "2200M": "invalid_xml_document", + "2200N": "invalid_xml_content", + "2200S": "invalid_xml_comment", + "2200T": "invalid_xml_processing_instruction", + "22030": "duplicate_json_object_key_value", + "22031": "invalid_argument_for_sql_json_datetime_function", + 
"22032": "invalid_json_text", + "22033": "invalid_sql_json_subscript", + "22034": "more_than_one_sql_json_item", + "22035": "no_sql_json_item", + "22036": "non_numeric_sql_json_item", + "22037": "non_unique_keys_in_a_json_object", + "22038": "singleton_sql_json_item_required", + "22039": "sql_json_array_not_found", + "2203A": "sql_json_member_not_found", + "2203B": "sql_json_number_not_found", + "2203C": "sql_json_object_not_found", + "2203D": "too_many_json_array_elements", + "2203E": "too_many_json_object_members", + "2203F": "sql_json_scalar_required", + "2203G": "sql_json_item_cannot_be_cast_to_target_type", + "23000": "integrity_constraint_violation", + "23001": "restrict_violation", + "23502": "not_null_violation", + "23503": "foreign_key_violation", + "23505": "unique_violation", + "23514": "check_violation", + "23P01": "exclusion_violation", + "24000": "invalid_cursor_state", + "25000": "invalid_transaction_state", + "25001": "active_sql_transaction", + "25002": "branch_transaction_already_active", + "25008": "held_cursor_requires_same_isolation_level", + "25003": "inappropriate_access_mode_for_branch_transaction", + "25004": "inappropriate_isolation_level_for_branch_transaction", + "25005": "no_active_sql_transaction_for_branch_transaction", + "25006": "read_only_sql_transaction", + "25007": "schema_and_data_statement_mixing_not_supported", + "25P01": "no_active_sql_transaction", + "25P02": "in_failed_sql_transaction", + "25P03": "idle_in_transaction_session_timeout", + "25P04": "transaction_timeout", + "26000": "invalid_sql_statement_name", + "27000": "triggered_data_change_violation", + "28000": "invalid_authorization_specification", + "28P01": "invalid_password", + "2B000": "dependent_privilege_descriptors_still_exist", + "2BP01": "dependent_objects_still_exist", + "2D000": "invalid_transaction_termination", + "2F000": "sql_routine_exception", + "2F005": "function_executed_no_return_statement", + "2F002": "modifying_sql_data_not_permitted", + "2F003": 
"prohibited_sql_statement_attempted", + "2F004": "reading_sql_data_not_permitted", + "34000": "invalid_cursor_name", + "38000": "external_routine_exception", + "38001": "containing_sql_not_permitted", + "38002": "modifying_sql_data_not_permitted", + "38003": "prohibited_sql_statement_attempted", + "38004": "reading_sql_data_not_permitted", + "39000": "external_routine_invocation_exception", + "39001": "invalid_sqlstate_returned", + "39004": "null_value_not_allowed", + "39P01": "trigger_protocol_violated", + "39P02": "srf_protocol_violated", + "39P03": "event_trigger_protocol_violated", + "3B000": "savepoint_exception", + "3B001": "invalid_savepoint_specification", + "3D000": "invalid_catalog_name", + "3F000": "invalid_schema_name", + "40000": "transaction_rollback", + "40002": "transaction_integrity_constraint_violation", + "40001": "serialization_failure", + "40003": "statement_completion_unknown", + "40P01": "deadlock_detected", + "42000": "syntax_error_or_access_rule_violation", + "42601": "syntax_error", + "42501": "insufficient_privilege", + "42846": "cannot_coerce", + "42803": "grouping_error", + "42P20": "windowing_error", + "42P19": "invalid_recursion", + "42830": "invalid_foreign_key", + "42602": "invalid_name", + "42622": "name_too_long", + "42939": "reserved_name", + "42804": "datatype_mismatch", + "42P18": "indeterminate_datatype", + "42P21": "collation_mismatch", + "42P22": "indeterminate_collation", + "42809": "wrong_object_type", + "428C9": "generated_always", + "42703": "undefined_column", + "42883": "undefined_function", + "42P01": "undefined_table", + "42P02": "undefined_parameter", + "42704": "undefined_object", + "42701": "duplicate_column", + "42P03": "duplicate_cursor", + "42P04": "duplicate_database", + "42723": "duplicate_function", + "42P05": "duplicate_prepared_statement", + "42P06": "duplicate_schema", + "42P07": "duplicate_table", + "42712": "duplicate_alias", + "42710": "duplicate_object", + "42702": "ambiguous_column", + "42725": 
"ambiguous_function", + "42P08": "ambiguous_parameter", + "42P09": "ambiguous_alias", + "42P10": "invalid_column_reference", + "42611": "invalid_column_definition", + "42P11": "invalid_cursor_definition", + "42P12": "invalid_database_definition", + "42P13": "invalid_function_definition", + "42P14": "invalid_prepared_statement_definition", + "42P15": "invalid_schema_definition", + "42P16": "invalid_table_definition", + "42P17": "invalid_object_definition", + "44000": "with_check_option_violation", + "53000": "insufficient_resources", + "53100": "disk_full", + "53200": "out_of_memory", + "53300": "too_many_connections", + "53400": "configuration_limit_exceeded", + "54000": "program_limit_exceeded", + "54001": "statement_too_complex", + "54011": "too_many_columns", + "54023": "too_many_arguments", + "55000": "object_not_in_prerequisite_state", + "55006": "object_in_use", + "55P02": "cant_change_runtime_param", + "55P03": "lock_not_available", + "55P04": "unsafe_new_enum_value_usage", + "57000": "operator_intervention", + "57014": "query_canceled", + "57P01": "admin_shutdown", + "57P02": "crash_shutdown", + "57P03": "cannot_connect_now", + "57P04": "database_dropped", + "57P05": "idle_session_timeout", + "58000": "system_error", + "58030": "io_error", + "58P01": "undefined_file", + "58P02": "duplicate_file", + "58P03": "file_name_too_long", + "F0000": "config_file_error", + "F0001": "lock_file_exists", + "HV000": "fdw_error", + "HV005": "fdw_column_name_not_found", + "HV002": "fdw_dynamic_parameter_value_needed", + "HV010": "fdw_function_sequence_error", + "HV021": "fdw_inconsistent_descriptor_information", + "HV024": "fdw_invalid_attribute_value", + "HV007": "fdw_invalid_column_name", + "HV008": "fdw_invalid_column_number", + "HV004": "fdw_invalid_data_type", + "HV006": "fdw_invalid_data_type_descriptors", + "HV091": "fdw_invalid_descriptor_field_identifier", + "HV00B": "fdw_invalid_handle", + "HV00C": "fdw_invalid_option_index", + "HV00D": "fdw_invalid_option_name", 
+ "HV090": "fdw_invalid_string_length_or_buffer_length", + "HV00A": "fdw_invalid_string_format", + "HV009": "fdw_invalid_use_of_null_pointer", + "HV014": "fdw_too_many_handles", + "HV001": "fdw_out_of_memory", + "HV00P": "fdw_no_schemas", + "HV00J": "fdw_option_name_not_found", + "HV00K": "fdw_reply_handle", + "HV00Q": "fdw_schema_not_found", + "HV00R": "fdw_table_not_found", + "HV00L": "fdw_unable_to_create_execution", + "HV00M": "fdw_unable_to_create_reply", + "HV00N": "fdw_unable_to_establish_connection", + "P0000": "plpgsql_error", + "P0001": "raise_exception", + "P0002": "no_data_found", + "P0003": "too_many_rows", + "P0004": "assert_failure", + "XX000": "internal_error", + "XX001": "data_corrupted", + "XX002": "index_corrupted", +} diff --git a/vendor/github.com/lib/pq/pqerror/pqerror.go b/vendor/github.com/lib/pq/pqerror/pqerror.go new file mode 100644 index 00000000..29e49e99 --- /dev/null +++ b/vendor/github.com/lib/pq/pqerror/pqerror.go @@ -0,0 +1,35 @@ +//go:generate go run gen.go + +// Package pqerror contains PostgreSQL error codes for use with pq.Error. +package pqerror + +// Code is a five-character error code. +type Code string + +// Name returns a more human friendly rendering of the error code, namely the +// "condition name". +func (ec Code) Name() string { return errorCodeNames[ec] } + +// Class returns the error class, e.g. "28". +func (ec Code) Class() Class { return Class(ec[:2]) } + +// Class is only the class part of an error code. +type Class string + +// Name returns the condition name of an error class. It is equivalent to the +// condition name of the "standard" error code (i.e. the one having the last +// three characters "000"). +func (ec Class) Name() string { return errorCodeNames[Code(ec+"000")] } + +// TODO(v2): use "type Severity string" for the below. + +// Error severity values. 
+const ( + SeverityFatal = "FATAL" + SeverityPanic = "PANIC" + SeverityWarning = "WARNING" + SeverityNotice = "NOTICE" + SeverityDebug = "DEBUG" + SeverityInfo = "INFO" + SeverityLog = "LOG" +) diff --git a/vendor/github.com/lib/pq/quote.go b/vendor/github.com/lib/pq/quote.go new file mode 100644 index 00000000..909e41ec --- /dev/null +++ b/vendor/github.com/lib/pq/quote.go @@ -0,0 +1,71 @@ +package pq + +import ( + "bytes" + "strings" +) + +// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be +// used as part of an SQL statement. For example: +// +// tblname := "my_table" +// data := "my_data" +// quoted := pq.QuoteIdentifier(tblname) +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) +// +// Any double quotes in name will be escaped. The quoted identifier will be case +// sensitive when used in a query. If the input string contains a zero byte, the +// result will be truncated immediately before it. +func QuoteIdentifier(name string) string { + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + return `"` + strings.Replace(name, `"`, `""`, -1) + `"` +} + +// BufferQuoteIdentifier satisfies the same purpose as QuoteIdentifier, but backed by a +// byte buffer. +func BufferQuoteIdentifier(name string, buffer *bytes.Buffer) { + // TODO(v2): this should have accepted an io.Writer, not *bytes.Buffer. + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + buffer.WriteRune('"') + buffer.WriteString(strings.Replace(name, `"`, `""`, -1)) + buffer.WriteRune('"') +} + +// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal +// to DDL and other statements that do not accept parameters) to be used as part +// of an SQL statement. For example: +// +// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") +// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) +// +// Any single quotes in name will be escaped. Any backslashes (i.e. 
"\") will be +// replaced by two backslashes (i.e. "\\") and the C-style escape identifier +// that PostgreSQL provides ('E') will be prepended to the string. +func QuoteLiteral(literal string) string { + // This follows the PostgreSQL internal algorithm for handling quoted literals + // from libpq, which can be found in the "PQEscapeStringInternal" function, + // which is found in the libpq/fe-exec.c source file: + // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c + // + // substitute any single-quotes (') with two single-quotes ('') + literal = strings.Replace(literal, `'`, `''`, -1) + // determine if the string has any backslashes (\) in it. + // if it does, replace any backslashes (\) with two backslashes (\\) + // then, we need to wrap the entire string with a PostgreSQL + // C-style escape. Per how "PQEscapeStringInternal" handles this case, we + // also add a space before the "E" + if strings.Contains(literal, `\`) { + literal = strings.Replace(literal, `\`, `\\`, -1) + literal = ` E'` + literal + `'` + } else { + // otherwise, we can just wrap the literal with a pair of single quotes + literal = `'` + literal + `'` + } + return literal +} diff --git a/vendor/github.com/lib/pq/rows.go b/vendor/github.com/lib/pq/rows.go new file mode 100644 index 00000000..2029bfed --- /dev/null +++ b/vendor/github.com/lib/pq/rows.go @@ -0,0 +1,245 @@ +package pq + +import ( + "database/sql/driver" + "fmt" + "io" + "math" + "reflect" + "time" + + "github.com/lib/pq/internal/proto" + "github.com/lib/pq/oid" +) + +type noRows struct{} + +var emptyRows noRows + +var _ driver.Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { return 0, errNoLastInsertID } +func (noRows) RowsAffected() (int64, error) { return 0, errNoRowsAffected } + +type ( + rowsHeader struct { + colNames []string + colTyps []fieldDesc + colFmts []format + } + rows struct { + cn *conn + finish func() + rowsHeader + done bool + rb readBuf + result 
driver.Result + tag string + + next *rowsHeader + } +) + +func (rs *rows) Close() error { + if finish := rs.finish; finish != nil { + defer finish() + } + // no need to look at cn.bad as Next() will + for { + err := rs.Next(nil) + switch err { + case nil: + case io.EOF: + // rs.Next can return io.EOF on both ReadyForQuery and + // RowDescription (used with HasNextResultSet). We need to fetch + // messages until we hit a ReadyForQuery, which is done by waiting + // for done to be set. + if rs.done { + return nil + } + default: + return err + } + } +} + +func (rs *rows) Columns() []string { + return rs.colNames +} + +func (rs *rows) Result() driver.Result { + if rs.result == nil { + return emptyRows + } + return rs.result +} + +func (rs *rows) Tag() string { + return rs.tag +} + +func (rs *rows) Next(dest []driver.Value) (resErr error) { + if rs.done { + return io.EOF + } + if err := rs.cn.err.getForNext(); err != nil { + return err + } + + for { + t, err := rs.cn.recv1Buf(&rs.rb) + if err != nil { + return rs.cn.handleError(err) + } + switch t { + case proto.ErrorResponse: + resErr = parseError(&rs.rb, "") + case proto.CommandComplete, proto.EmptyQueryResponse: + if t == proto.CommandComplete { + rs.result, rs.tag, err = rs.cn.parseComplete(rs.rb.string()) + if err != nil { + return rs.cn.handleError(err) + } + } + continue + case proto.ReadyForQuery: + rs.cn.processReadyForQuery(&rs.rb) + rs.done = true + if resErr != nil { + return rs.cn.handleError(resErr) + } + return io.EOF + case proto.DataRow: + n := rs.rb.int16() + if resErr != nil { + rs.cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected DataRow after error %s", resErr) + } + if n < len(dest) { + dest = dest[:n] + } + for i := range dest { + l := rs.rb.int32() + if l == -1 { + dest[i] = nil + continue + } + dest[i], err = decode(&rs.cn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) + if err != nil { + return rs.cn.handleError(err) + } + } + return rs.cn.handleError(resErr) 
+ case proto.RowDescription: + next := parsePortalRowDescribe(&rs.rb) + rs.next = &next + return io.EOF + default: + return fmt.Errorf("pq: unexpected message after execute: %q", t) + } + } +} + +func (rs *rows) HasNextResultSet() bool { + hasNext := rs.next != nil && !rs.done + return hasNext +} + +func (rs *rows) NextResultSet() error { + if rs.next == nil { + return io.EOF + } + rs.rowsHeader = *rs.next + rs.next = nil + return nil +} + +// ColumnTypeScanType returns the value type that can be used to scan types into. +func (rs *rows) ColumnTypeScanType(index int) reflect.Type { + return rs.colTyps[index].Type() +} + +// ColumnTypeDatabaseTypeName return the database system type name. +func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { + return rs.colTyps[index].Name() +} + +// ColumnTypeLength returns the length of the column type if the column is a +// variable length type. If the column is not a variable length type ok +// should return false. +func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { + return rs.colTyps[index].Length() +} + +// ColumnTypePrecisionScale should return the precision and scale for decimal +// types. If not applicable, ok should be false. +func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return rs.colTyps[index].PrecisionScale() +} + +const headerSize = 4 + +type fieldDesc struct { + // The object ID of the data type. + OID oid.Oid + // The data type size (see pg_type.typlen). + // Note that negative values denote variable-width types. + Len int + // The type modifier (see pg_attribute.atttypmod). + // The meaning of the modifier is type-specific. 
+ Mod int +} + +func (fd fieldDesc) Type() reflect.Type { + switch fd.OID { + case oid.T_int8: + return reflect.TypeOf(int64(0)) + case oid.T_int4: + return reflect.TypeOf(int32(0)) + case oid.T_int2: + return reflect.TypeOf(int16(0)) + case oid.T_float8: + return reflect.TypeOf(float64(0)) + case oid.T_float4: + return reflect.TypeOf(float32(0)) + case oid.T_varchar, oid.T_text, oid.T_varbit, oid.T_bit: + return reflect.TypeOf("") + case oid.T_bool: + return reflect.TypeOf(false) + case oid.T_date, oid.T_time, oid.T_timetz, oid.T_timestamp, oid.T_timestamptz: + return reflect.TypeOf(time.Time{}) + case oid.T_bytea: + return reflect.TypeOf([]byte(nil)) + default: + return reflect.TypeOf(new(any)).Elem() + } +} + +func (fd fieldDesc) Name() string { + return oid.TypeName[fd.OID] +} + +func (fd fieldDesc) Length() (length int64, ok bool) { + switch fd.OID { + case oid.T_text, oid.T_bytea: + return math.MaxInt64, true + case oid.T_varchar, oid.T_bpchar: + return int64(fd.Mod - headerSize), true + case oid.T_varbit, oid.T_bit: + return int64(fd.Mod), true + default: + return 0, false + } +} + +func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { + switch fd.OID { + case oid.T_numeric, oid.T__numeric: + mod := fd.Mod - headerSize + precision = int64((mod >> 16) & 0xffff) + scale = int64(mod & 0xffff) + return precision, scale, true + default: + return 0, 0, false + } +} diff --git a/vendor/github.com/lib/pq/scram/scram.go b/vendor/github.com/lib/pq/scram/scram.go new file mode 100644 index 00000000..7ed7a993 --- /dev/null +++ b/vendor/github.com/lib/pq/scram/scram.go @@ -0,0 +1,261 @@ +// Copyright (c) 2014 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. 
Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802. +// +// http://tools.ietf.org/html/rfc5802 +package scram + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "encoding/base64" + "fmt" + "hash" + "strconv" + "strings" +) + +// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc). +// +// A Client may be used within a SASL conversation with logic resembling: +// +// var in []byte +// var client = scram.NewClient(sha1.New, user, pass) +// for client.Step(in) { +// out := client.Out() +// // send out to server +// in := serverOut +// } +// if client.Err() != nil { +// // auth failed +// } +type Client struct { + newHash func() hash.Hash + + user string + pass string + step int + out bytes.Buffer + err error + + clientNonce []byte + serverNonce []byte + saltedPass []byte + authMsg bytes.Buffer +} + +// NewClient returns a new SCRAM-* client with the provided hash algorithm. 
+// +// For SCRAM-SHA-256, for example, use: +// +// client := scram.NewClient(sha256.New, user, pass) +func NewClient(newHash func() hash.Hash, user, pass string) *Client { + c := &Client{ + newHash: newHash, + user: user, + pass: pass, + } + c.out.Grow(256) + c.authMsg.Grow(256) + return c +} + +// Out returns the data to be sent to the server in the current step. +func (c *Client) Out() []byte { + if c.out.Len() == 0 { + return nil + } + return c.out.Bytes() +} + +// Err returns the error that occurred, or nil if there were no errors. +func (c *Client) Err() error { + return c.err +} + +// SetNonce sets the client nonce to the provided value. +// If not set, the nonce is generated automatically out of crypto/rand on the first step. +func (c *Client) SetNonce(nonce []byte) { + c.clientNonce = nonce +} + +var escaper = strings.NewReplacer("=", "=3D", ",", "=2C") + +// Step processes the incoming data from the server and makes the +// next round of data for the server available via Client.Out. +// Step returns false if there are no errors and more data is +// still expected. 
+func (c *Client) Step(in []byte) bool { + c.out.Reset() + if c.step > 2 || c.err != nil { + return false + } + c.step++ + switch c.step { + case 1: + c.err = c.step1(in) + case 2: + c.err = c.step2(in) + case 3: + c.err = c.step3(in) + } + return c.step > 2 || c.err != nil +} + +func (c *Client) step1(in []byte) error { + if len(c.clientNonce) == 0 { + const nonceLen = 16 + buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen)) + if _, err := rand.Read(buf[:nonceLen]); err != nil { + return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %w", err) + } + c.clientNonce = buf[nonceLen:] + b64.Encode(c.clientNonce, buf[:nonceLen]) + } + c.authMsg.WriteString("n=") + escaper.WriteString(&c.authMsg, c.user) + c.authMsg.WriteString(",r=") + c.authMsg.Write(c.clientNonce) + + c.out.WriteString("n,,") + c.out.Write(c.authMsg.Bytes()) + return nil +} + +var b64 = base64.StdEncoding + +func (c *Client) step2(in []byte) error { + c.authMsg.WriteByte(',') + c.authMsg.Write(in) + + fields := bytes.Split(in, []byte(",")) + if len(fields) != 3 { + return fmt.Errorf("expected 3 fields in first SCRAM-SHA-256 server message, got %d: %q", len(fields), in) + } + if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 nonce: %q", fields[0]) + } + if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 salt: %q", fields[1]) + } + if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) + } + + c.serverNonce = fields[0][2:] + if !bytes.HasPrefix(c.serverNonce, c.clientNonce) { + return fmt.Errorf("server SCRAM-SHA-256 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce) + } + + salt := make([]byte, b64.DecodedLen(len(fields[1][2:]))) + n, err := b64.Decode(salt, fields[1][2:]) + 
if err != nil { + return fmt.Errorf("cannot decode SCRAM-SHA-256 salt sent by server: %q", fields[1]) + } + salt = salt[:n] + iterCount, err := strconv.Atoi(string(fields[2][2:])) + if err != nil { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) + } + c.saltPassword(salt, iterCount) + + c.authMsg.WriteString(",c=biws,r=") + c.authMsg.Write(c.serverNonce) + + c.out.WriteString("c=biws,r=") + c.out.Write(c.serverNonce) + c.out.WriteString(",p=") + c.out.Write(c.clientProof()) + return nil +} + +func (c *Client) step3(in []byte) error { + var isv, ise bool + var fields = bytes.Split(in, []byte(",")) + if len(fields) == 1 { + isv = bytes.HasPrefix(fields[0], []byte("v=")) + ise = bytes.HasPrefix(fields[0], []byte("e=")) + } + if ise { + return fmt.Errorf("SCRAM-SHA-256 authentication error: %s", fields[0][2:]) + } else if !isv { + return fmt.Errorf("unsupported SCRAM-SHA-256 final message from server: %q", in) + } + if !bytes.Equal(c.serverSignature(), fields[0][2:]) { + return fmt.Errorf("cannot authenticate SCRAM-SHA-256 server signature: %q", fields[0][2:]) + } + return nil +} + +func (c *Client) saltPassword(salt []byte, iterCount int) { + mac := hmac.New(c.newHash, []byte(c.pass)) + mac.Write(salt) + mac.Write([]byte{0, 0, 0, 1}) + ui := mac.Sum(nil) + hi := make([]byte, len(ui)) + copy(hi, ui) + for i := 1; i < iterCount; i++ { + mac.Reset() + mac.Write(ui) + mac.Sum(ui[:0]) + for j, b := range ui { + hi[j] ^= b + } + } + c.saltedPass = hi +} + +func (c *Client) clientProof() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Client Key")) + clientKey := mac.Sum(nil) + hash := c.newHash() + hash.Write(clientKey) + storedKey := hash.Sum(nil) + mac = hmac.New(c.newHash, storedKey) + mac.Write(c.authMsg.Bytes()) + clientProof := mac.Sum(nil) + for i, b := range clientKey { + clientProof[i] ^= b + } + clientProof64 := make([]byte, b64.EncodedLen(len(clientProof))) + b64.Encode(clientProof64, clientProof) 
+ return clientProof64 +} + +func (c *Client) serverSignature() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Server Key")) + serverKey := mac.Sum(nil) + + mac = hmac.New(c.newHash, serverKey) + mac.Write(c.authMsg.Bytes()) + serverSignature := mac.Sum(nil) + + encoded := make([]byte, b64.EncodedLen(len(serverSignature))) + b64.Encode(encoded, serverSignature) + return encoded +} diff --git a/vendor/github.com/lib/pq/ssl.go b/vendor/github.com/lib/pq/ssl.go new file mode 100644 index 00000000..71b0b288 --- /dev/null +++ b/vendor/github.com/lib/pq/ssl.go @@ -0,0 +1,312 @@ +package pq + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "slices" + "strings" + "sync" + + "github.com/lib/pq/internal/pqutil" +) + +// Registry for custom tls.Configs +var ( + tlsConfs = make(map[string]*tls.Config) + tlsConfsMu sync.RWMutex +) + +// RegisterTLSConfig registers a custom [tls.Config]. They are used by using +// sslmode=pqgo-«key» in the connection string. +// +// Set the config to nil to remove a configuration. +func RegisterTLSConfig(key string, config *tls.Config) error { + key = strings.TrimPrefix(key, "pqgo-") + if config == nil { + tlsConfsMu.Lock() + delete(tlsConfs, key) + tlsConfsMu.Unlock() + return nil + } + + tlsConfsMu.Lock() + tlsConfs[key] = config + tlsConfsMu.Unlock() + return nil +} + +func hasTLSConfig(key string) bool { + tlsConfsMu.RLock() + defer tlsConfsMu.RUnlock() + _, ok := tlsConfs[key] + return ok +} + +func getTLSConfigClone(key string) *tls.Config { + tlsConfsMu.RLock() + defer tlsConfsMu.RUnlock() + if v, ok := tlsConfs[key]; ok { + return v.Clone() + } + return nil +} + +// ssl generates a function to upgrade a net.Conn based on the "sslmode" and +// related settings. The function is nil when no upgrade should take place. +// +// Don't refer to Config.SSLMode here, as the mode in arguments may be different +// in case of sslmode=allow or prefer. 
+func ssl(cfg Config, mode SSLMode) (func(net.Conn) (net.Conn, error), error) { + var ( + home = pqutil.Home() + // Don't set defaults here, because tlsConf may be overwritten if a + // custom one was registered. Set it after the sslmode switch. + tlsConf = &tls.Config{} + // Only verify the CA signing but not the hostname. + verifyCaOnly = false + ) + if mode.useSSL() && !cfg.SSLInline && cfg.SSLRootCert == "" && home != "" { + f := filepath.Join(home, "root.crt") + if _, err := os.Stat(f); err == nil { + cfg.SSLRootCert = f + } + } + switch { + case mode == SSLModeDisable || mode == SSLModeAllow: + return nil, nil + + case mode == SSLModeRequire || mode == SSLModePrefer: + // Skip TLS's own verification since it requires full verification. + tlsConf.InsecureSkipVerify = true + + // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: + // + // For backwards compatibility with earlier versions of PostgreSQL, if a + // root CA file exists, the behavior of sslmode=require will be the same + // as that of verify-ca, meaning the server certificate is validated + // against the CA. Relying on this behavior is discouraged, and + // applications that need certificate validation should always use + // verify-ca or verify-full. + if cfg.SSLRootCert != "" { + if cfg.SSLInline { + verifyCaOnly = true + } else if _, err := os.Stat(cfg.SSLRootCert); err == nil { + verifyCaOnly = true + } else if cfg.SSLRootCert != "system" { + cfg.SSLRootCert = "" + } + } + case mode == SSLModeVerifyCA: + // Skip TLS's own verification since it requires full verification. 
+ tlsConf.InsecureSkipVerify = true + verifyCaOnly = true + case mode == SSLModeVerifyFull: + tlsConf.ServerName = cfg.Host + case strings.HasPrefix(string(mode), "pqgo-"): + tlsConf = getTLSConfigClone(string(mode[5:])) + if tlsConf == nil { + return nil, fmt.Errorf(`pq: unknown custom sslmode %q`, mode) + } + default: + panic("unreachable") + } + + tlsConf.MinVersion = cfg.SSLMinProtocolVersion.tlsconf() + tlsConf.MaxVersion = cfg.SSLMaxProtocolVersion.tlsconf() + + // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 or + // IPv6). This check is coded already crypto.tls.hostnameInSNI, so just + // always set ServerName here and let crypto/tls do the filtering. + if cfg.SSLSNI { + tlsConf.ServerName = cfg.Host + } + + err := sslClientCertificates(tlsConf, cfg, home) + if err != nil { + return nil, err + } + rootPem, err := sslCertificateAuthority(tlsConf, cfg) + if err != nil { + return nil, err + } + sslAppendIntermediates(tlsConf, cfg, rootPem) + + // Accept renegotiation requests initiated by the backend. + // + // Renegotiation was deprecated then removed from PostgreSQL 9.5, but the + // default configuration of older versions has it enabled. Redshift also + // initiates renegotiations and cannot be reconfigured. + // + // TODO: I think this can be removed? 
+ tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient + + return func(conn net.Conn) (net.Conn, error) { + client := tls.Client(conn, tlsConf) + if verifyCaOnly { + err := client.Handshake() + if err != nil { + return client, err + } + var ( + certs = client.ConnectionState().PeerCertificates + opts = x509.VerifyOptions{Intermediates: x509.NewCertPool(), Roots: tlsConf.RootCAs} + ) + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + _, err = certs[0].Verify(opts) + return client, err + } + return client, nil + }, nil +} + +// sslClientCertificates adds the certificate specified in the "sslcert" and +// +// "sslkey" settings, or if they aren't set, from the .postgresql directory +// in the user's home directory. The configured files must exist and have +// the correct permissions. +func sslClientCertificates(tlsConf *tls.Config, cfg Config, home string) error { + if cfg.SSLInline { + cert, err := tls.X509KeyPair([]byte(cfg.SSLCert), []byte(cfg.SSLKey)) + if err != nil { + return err + } + // Use GetClientCertificate instead of the Certificates field. When + // Certificates is set, Go's TLS client only sends the cert if the + // server's CertificateRequest includes a CA that issued it. When the + // client cert was signed by an intermediate CA but the server only + // advertises the root CA, Go skips sending the cert entirely. + // GetClientCertificate bypasses this filtering. + tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return &cert, nil + } + return nil + } + + // Only load client certificate and key if the setting is not blank, like libpq. + if cfg.SSLCert == "" && home != "" { + cfg.SSLCert = filepath.Join(home, "postgresql.crt") + } + if cfg.SSLCert == "" { + return nil + } + _, err := os.Stat(cfg.SSLCert) + if err != nil { + if pqutil.ErrNotExists(err) { + return nil + } + return err + } + + // In libpq, the ssl key is only loaded if the setting is not blank. 
+ if cfg.SSLKey == "" && home != "" { + cfg.SSLKey = filepath.Join(home, "postgresql.key") + } + if cfg.SSLKey != "" { + err := pqutil.SSLKeyPermissions(cfg.SSLKey) + if err != nil { + return err + } + } + + cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey) + if err != nil { + return err + } + + // Using GetClientCertificate instead of Certificates per comment above. + tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return &cert, nil + } + return nil +} + +var testSystemRoots *x509.CertPool + +// sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. +func sslCertificateAuthority(tlsConf *tls.Config, cfg Config) ([]byte, error) { + // Only load root certificate if not blank, like libpq. + if cfg.SSLRootCert == "" { + return nil, nil + } + + if cfg.SSLRootCert == "system" { + // No work to do as system CAs are used by default if RootCAs is nil. + tlsConf.RootCAs = testSystemRoots + return nil, nil + } + + tlsConf.RootCAs = x509.NewCertPool() + + var cert []byte + if cfg.SSLInline { + cert = []byte(cfg.SSLRootCert) + } else { + var err error + cert, err = os.ReadFile(cfg.SSLRootCert) + if err != nil { + return nil, err + } + } + + if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { + return nil, errors.New("pq: couldn't parse pem from sslrootcert") + } + return cert, nil +} + +// sslAppendIntermediates appends intermediate CA certificates from sslrootcert +// to the client certificate chain. This is needed so the server can verify the +// client cert when it was signed by an intermediate CA — without this, the TLS +// handshake only sends the leaf client cert. 
+func sslAppendIntermediates(tlsConf *tls.Config, cfg Config, rootPem []byte) { + if cfg.SSLRootCert == "" || tlsConf.GetClientCertificate == nil || len(rootPem) == 0 { + return + } + + var ( + pemData = slices.Clone(rootPem) + intermediates [][]byte + ) + for { + var block *pem.Block + block, pemData = pem.Decode(pemData) + if block == nil { + break + } + if block.Type != "CERTIFICATE" { + continue + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + continue + } + // Skip self-signed root CAs; only append intermediates. + if cert.IsCA && !bytes.Equal(cert.RawIssuer, cert.RawSubject) { + intermediates = append(intermediates, block.Bytes) + } + } + if len(intermediates) == 0 { + return + } + + // Wrap the existing GetClientCertificate to append intermediate certs to + // the certificate chain returned during the TLS handshake. + origGetCert := tlsConf.GetClientCertificate + tlsConf.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { + cert, err := origGetCert(info) + if err != nil { + return cert, err + } + cert.Certificate = append(cert.Certificate, intermediates...) 
+ return cert, nil + } +} diff --git a/vendor/github.com/lib/pq/staticcheck.conf b/vendor/github.com/lib/pq/staticcheck.conf new file mode 100644 index 00000000..83abe48e --- /dev/null +++ b/vendor/github.com/lib/pq/staticcheck.conf @@ -0,0 +1,5 @@ +checks = [ + 'all', + '-ST1000', # "Must have at least one package comment" + '-ST1003', # "func EnableInfinityTs should be EnableInfinityTS" +] diff --git a/vendor/github.com/lib/pq/stmt.go b/vendor/github.com/lib/pq/stmt.go new file mode 100644 index 00000000..ca6ecc89 --- /dev/null +++ b/vendor/github.com/lib/pq/stmt.go @@ -0,0 +1,150 @@ +package pq + +import ( + "context" + "database/sql/driver" + "fmt" + "os" + + "github.com/lib/pq/internal/proto" + "github.com/lib/pq/oid" +) + +type stmt struct { + cn *conn + name string + rowsHeader + colFmtData []byte + paramTyps []oid.Oid + closed bool +} + +func (st *stmt) Close() error { + if st.closed { + return nil + } + if err := st.cn.err.get(); err != nil { + return err + } + + w := st.cn.writeBuf(proto.Close) + w.byte(proto.Sync) + w.string(st.name) + err := st.cn.send(w) + if err != nil { + return st.cn.handleError(err) + } + err = st.cn.send(st.cn.writeBuf(proto.Sync)) + if err != nil { + return st.cn.handleError(err) + } + + t, _, err := st.cn.recv1() + if err != nil { + return st.cn.handleError(err) + } + if t != proto.CloseComplete { + st.cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected close response: %q", t) + } + st.closed = true + + t, r, err := st.cn.recv1() + if err != nil { + return st.cn.handleError(err) + } + if t != proto.ReadyForQuery { + st.cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: expected ready for query, but got: %q", t) + } + st.cn.processReadyForQuery(r) + + return nil +} + +func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { + return st.query(toNamedValue(v)) +} + +func (st *stmt) query(v []driver.NamedValue) (*rows, error) { + if err := st.cn.err.get(); err != nil { + return nil, err + } + + err := 
st.exec(v) + if err != nil { + return nil, st.cn.handleError(err) + } + return &rows{ + cn: st.cn, + rowsHeader: st.rowsHeader, + }, nil +} + +func (st *stmt) Exec(v []driver.Value) (driver.Result, error) { + return st.ExecContext(context.Background(), toNamedValue(v)) +} + +func (st *stmt) exec(v []driver.NamedValue) error { + if debugProto { + fmt.Fprintf(os.Stderr, " START stmt.exec\n") + defer fmt.Fprintf(os.Stderr, " END stmt.exec\n") + } + if len(v) >= 65536 { + return fmt.Errorf("pq: got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) + } + if len(v) != len(st.paramTyps) { + return fmt.Errorf("pq: got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) + } + + cn := st.cn + w := cn.writeBuf(proto.Bind) + w.byte(0) // unnamed portal + w.string(st.name) + + if cn.cfg.BinaryParameters { + err := cn.sendBinaryParameters(w, v) + if err != nil { + return err + } + } else { + w.int16(0) + w.int16(len(v)) + for i, x := range v { + if x.Value == nil { + w.int32(-1) + } else { + b, err := encode(x.Value, st.paramTyps[i]) + if err != nil { + return err + } + if b == nil { + w.int32(-1) + } else { + w.int32(len(b)) + w.bytes(b) + } + } + } + } + w.bytes(st.colFmtData) + + w.next(proto.Execute) + w.byte(0) + w.int32(0) + + w.next(proto.Sync) + err := cn.send(w) + if err != nil { + return err + } + err = cn.readBindResponse() + if err != nil { + return err + } + return cn.postExecuteWorkaround() +} + +func (st *stmt) NumInput() int { + return len(st.paramTyps) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 603bba86..01c951c3 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -74,6 +74,18 @@ github.com/hashicorp/serf/coordinate # github.com/howeyc/gopass v0.0.0-20210920133722-c8aef6fb66ef ## explicit github.com/howeyc/gopass +# github.com/lib/pq v1.12.0 +## explicit; go 1.21 +github.com/lib/pq +github.com/lib/pq/internal/pgpass +github.com/lib/pq/internal/pgservice +github.com/lib/pq/internal/pqsql 
+github.com/lib/pq/internal/pqtime +github.com/lib/pq/internal/pqutil +github.com/lib/pq/internal/proto +github.com/lib/pq/oid +github.com/lib/pq/pqerror +github.com/lib/pq/scram # github.com/mattn/go-colorable v0.1.14 ## explicit; go 1.18 github.com/mattn/go-colorable From 890d2bac39c732b89ea0642cd6f436978e988606 Mon Sep 17 00:00:00 2001 From: Rene Cannao Date: Tue, 24 Mar 2026 14:06:32 +0000 Subject: [PATCH 2/2] Address review: fix standby lag query, configurable SSL, safe DSN, recovery guards --- go/config/config.go | 2 ++ go/inst/provider_postgresql.go | 56 ++++++++++++++++++++++------------ 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/go/config/config.go b/go/config/config.go index 28461d62..6f2cec31 100644 --- a/go/config/config.go +++ b/go/config/config.go @@ -98,6 +98,7 @@ type Configuration struct { MySQLTopologyPassword string PostgreSQLTopologyUser string // Username for connecting to PostgreSQL topology instances PostgreSQLTopologyPassword string // Password for connecting to PostgreSQL topology instances + PostgreSQLSSLMode string // SSL mode for PostgreSQL connections: disable, require, verify-ca, verify-full. Default: "require" MySQLTopologyCredentialsConfigFile string // my.cnf style configuration file from where to pick credentials. 
Expecting `user`, `password` under `[client]` section MySQLTopologySSLPrivateKeyFile string // Private key file used to authenticate with a Topology mysql instance with TLS MySQLTopologySSLCertFile string // Certificate PEM file used to authenticate with a Topology mysql instance with TLS @@ -336,6 +337,7 @@ func newConfiguration() *Configuration { MySQLOrchestratorPort: 3306, MySQLTopologyUseMutualTLS: false, MySQLTopologyUseMixedTLS: true, + PostgreSQLSSLMode: "require", MySQLTopologyMaxAllowedPacket: -1, MySQLOrchestratorUseMutualTLS: false, MySQLConnectTimeoutSeconds: 2, diff --git a/go/inst/provider_postgresql.go b/go/inst/provider_postgresql.go index 8f6ad253..d4579459 100644 --- a/go/inst/provider_postgresql.go +++ b/go/inst/provider_postgresql.go @@ -19,6 +19,7 @@ package inst import ( "database/sql" "fmt" + "net/url" _ "github.com/lib/pq" "github.com/proxysql/golib/log" @@ -41,13 +42,15 @@ func (p *PostgreSQLProvider) ProviderName() string { // openPostgreSQLTopology opens a connection to a PostgreSQL instance using // credentials from the orchestrator configuration. 
-func openPostgreSQLTopology(hostname string, port int) (*sql.DB, error) { - cfg := config.Config - connStr := fmt.Sprintf( - "host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable connect_timeout=5", - hostname, port, cfg.PostgreSQLTopologyUser, cfg.PostgreSQLTopologyPassword, - ) - db, err := sql.Open("postgres", connStr) +func openPostgreSQLTopology(key InstanceKey) (*sql.DB, error) { + u := &url.URL{ + Scheme: "postgres", + User: url.UserPassword(config.Config.PostgreSQLTopologyUser, config.Config.PostgreSQLTopologyPassword), + Host: fmt.Sprintf("%s:%d", key.Hostname, key.Port), + Path: "postgres", + RawQuery: fmt.Sprintf("sslmode=%s&connect_timeout=5", config.Config.PostgreSQLSSLMode), + } + db, err := sql.Open("postgres", u.String()) if err != nil { return nil, err } @@ -58,9 +61,9 @@ func openPostgreSQLTopology(hostname string, port int) (*sql.DB, error) { // GetReplicationStatus retrieves the replication state for a PostgreSQL instance. // On a standby it queries pg_stat_wal_receiver; on a primary it queries -// pg_stat_replication. +// pg_current_wal_lsn(). 
func (p *PostgreSQLProvider) GetReplicationStatus(key InstanceKey) (*ReplicationStatus, error) { - db, err := openPostgreSQLTopology(key.Hostname, key.Port) + db, err := openPostgreSQLTopology(key) if err != nil { return nil, log.Errore(err) } @@ -86,12 +89,10 @@ func (p *PostgreSQLProvider) getStandbyReplicationStatus(db *sql.DB) (*Replicati err := db.QueryRow(` SELECT - w.status, + COALESCE(r.status, ''), pg_last_wal_replay_lsn()::text, - EXTRACT(EPOCH FROM replay_lag) - FROM pg_stat_wal_receiver w - LEFT JOIN pg_stat_replication r ON true - LIMIT 1 + COALESCE(EXTRACT(EPOCH FROM now() - pg_last_xact_replay_timestamp()), -1) + FROM (SELECT 'streaming' as status FROM pg_stat_wal_receiver LIMIT 1) r `).Scan(&status, &lsn, &lagSeconds) if err == sql.ErrNoRows { @@ -148,7 +149,7 @@ func (p *PostgreSQLProvider) getPrimaryReplicationStatus(db *sql.DB) (*Replicati // IsReplicaRunning checks whether the WAL receiver is active on a PostgreSQL // standby instance. func (p *PostgreSQLProvider) IsReplicaRunning(key InstanceKey) (bool, error) { - db, err := openPostgreSQLTopology(key.Hostname, key.Port) + db, err := openPostgreSQLTopology(key) if err != nil { return false, log.Errore(err) } @@ -168,7 +169,7 @@ func (p *PostgreSQLProvider) IsReplicaRunning(key InstanceKey) (bool, error) { // SetReadOnly sets or clears the default_transaction_read_only parameter on // a PostgreSQL instance and reloads the configuration. func (p *PostgreSQLProvider) SetReadOnly(key InstanceKey, readOnly bool) error { - db, err := openPostgreSQLTopology(key.Hostname, key.Port) + db, err := openPostgreSQLTopology(key) if err != nil { return log.Errore(err) } @@ -190,7 +191,7 @@ func (p *PostgreSQLProvider) SetReadOnly(key InstanceKey, readOnly bool) error { // IsReadOnly checks whether default_transaction_read_only is enabled on a // PostgreSQL instance. 
func (p *PostgreSQLProvider) IsReadOnly(key InstanceKey) (bool, error) { - db, err := openPostgreSQLTopology(key.Hostname, key.Port) + db, err := openPostgreSQLTopology(key) if err != nil { return false, log.Errore(err) } @@ -209,12 +210,21 @@ func (p *PostgreSQLProvider) IsReadOnly(key InstanceKey) (bool, error) { func (p *PostgreSQLProvider) StartReplication(key InstanceKey) error { log.Infof("PostgreSQL streaming replication on %s:%d starts automatically; resuming WAL replay if paused", key.Hostname, key.Port) - db, err := openPostgreSQLTopology(key.Hostname, key.Port) + db, err := openPostgreSQLTopology(key) if err != nil { return log.Errore(err) } defer db.Close() + var inRecovery bool + if err := db.QueryRow("SELECT pg_is_in_recovery()").Scan(&inRecovery); err != nil { + return log.Errore(err) + } + if !inRecovery { + log.Infof("StartReplication: %s:%d is a primary, WAL replay resume not applicable", key.Hostname, key.Port) + return nil + } + if _, err := db.Exec("SELECT pg_wal_replay_resume()"); err != nil { return log.Errore(err) } @@ -225,12 +235,20 @@ func (p *PostgreSQLProvider) StartReplication(key InstanceKey) error { // closest equivalent to stopping replication in MySQL. Note that the WAL // receiver (IO thread equivalent) remains connected; only replay is paused. func (p *PostgreSQLProvider) StopReplication(key InstanceKey) error { - db, err := openPostgreSQLTopology(key.Hostname, key.Port) + db, err := openPostgreSQLTopology(key) if err != nil { return log.Errore(err) } defer db.Close() + var inRecovery bool + if err := db.QueryRow("SELECT pg_is_in_recovery()").Scan(&inRecovery); err != nil { + return log.Errore(err) + } + if !inRecovery { + return fmt.Errorf("StopReplication: %s:%d is a primary, WAL replay pause not applicable", key.Hostname, key.Port) + } + if _, err := db.Exec("SELECT pg_wal_replay_pause()"); err != nil { return log.Errore(err) }