diff --git a/src/pgo.erl b/src/pgo.erl index b2aab52..bfac17e 100644 --- a/src/pgo.erl +++ b/src/pgo.erl @@ -14,6 +14,10 @@ query/2, query/3, query/4, + prepare/2, + prepare/3, + query_prepared/3, + query_prepared/4, transaction/1, transaction/2, transaction/3, @@ -148,6 +152,89 @@ query(Query, Params, Options, Conn=#conn{trace=TraceDefault, #{queue_time => undefined}) end). +%% @doc Prepare a named statement on the given pool. +%% Returns {ok, Name, ParameterOIDs} which can be passed to query_prepared/3,4. +-spec prepare(iodata(), iodata()) -> {ok, iodata(), [pg_types:oid()]} | {error, term()}. +prepare(Name, Query) -> + prepare(Name, Query, #{}). + +-spec prepare(iodata(), iodata(), options()) -> {ok, iodata(), [pg_types:oid()]} | {error, term()}. +prepare(Name, Query, Options) -> + pgo_prepared_cache:init(), + Pool = maps:get(pool, Options, default), + PoolOptions = maps:get(pool_options, Options, []), + case checkout(Pool, PoolOptions) of + {ok, Ref, Conn} -> + try + case pgo_handler:prepare(Conn, Name, Query) of + {ok, N, OIDs} = Result -> + pgo_prepared_cache:store(N, Query, OIDs), + Result; + Error -> + Error + end + after + checkin(Ref, Conn) + end; + {error, _} = E -> + E + end. + +%% @doc Execute a previously prepared named statement. +%% ParameterOIDs is the list returned by prepare/2,3. +-spec query_prepared(iodata(), list(), [pg_types:oid()]) -> result(). +query_prepared(Name, Params, ParameterOIDs) -> + query_prepared(Name, Params, ParameterOIDs, #{}). + +-spec query_prepared(iodata(), list(), [pg_types:oid()], options()) -> result(). +query_prepared(Name, Params, ParameterOIDs, Options) -> + pgo_prepared_cache:init(), + Pool = maps:get(pool, Options, default), + PoolOptions = maps:get(pool_options, Options, []), + DecodeOptions = maps:get(decode_opts, Options, []), + case checkout(Pool, PoolOptions) of + {ok, Ref={_, _, _, Holder}, Conn=#conn{owner=Owner, decode_opts=DefaultDecodeOpts}} -> + try + NameBin = iolist_to_binary(Name), + _ = maybe_prepare_on_conn(Owner, NameBin, Conn), + pgo_handler:prepared_query(Conn, Name, Params, ParameterOIDs, + DecodeOptions ++ DefaultDecodeOpts) + of + {error, closed} -> + maybe_timeout_error(Holder); + {error, einval} -> + maybe_timeout_error(Holder); + Result -> + Result + after + checkin(Ref, Conn) + end; + {error, _} = E -> + E + end. + +maybe_prepare_on_conn(Owner, NameBin, Conn) -> + Key = {Owner, NameBin}, + case pgo_prepared_cache:is_conn_prepared(Key) of + true -> + ok; + false -> + case pgo_prepared_cache:lookup(NameBin) of + {ok, Query, _OIDs} -> + case pgo_handler:prepare(Conn, NameBin, Query) of + {ok, _, _} -> + pgo_prepared_cache:mark_conn_prepared(Key); + {error, {pgsql_error, #{code := <<"42P05">>}}} -> + %% Already prepared (e.g. from a previous session) + pgo_prepared_cache:mark_conn_prepared(Key); + Error -> + Error + end; + not_found -> + ok + end + end. + %% @equiv transaction(default, Fun, []) -spec transaction(fun(() -> any())) -> any() | {error, any()}. transaction(Fun) -> diff --git a/src/pgo_handler.erl b/src/pgo_handler.erl index a3ac18c..e96a64f 100644 --- a/src/pgo_handler.erl +++ b/src/pgo_handler.erl @@ -5,6 +5,9 @@ extended_query/3, extended_query/4, extended_query/5, + prepare/3, + prepared_query/4, + prepared_query/5, ping/1, setopts/3, simple_query/2, @@ -78,6 +81,108 @@ extended_query(Socket, Query, Parameters, DecodeOptions, _Timings) -> DecodeFun = proplists:get_value(decode_fun, DecodeOptions, undefined), extended_query(Socket, Query, Parameters, DecodeOptions, DecodeFun, []). +%% @doc Parse a named prepared statement. Returns {ok, Name, ParameterOIDs} +%% on success. The statement is cached server-side per connection. +-spec prepare(#conn{}, iodata(), iodata()) -> {ok, iodata(), [pg_types:oid()]} | {error, term()}. +prepare(Conn=#conn{socket=Socket, + socket_module=SocketModule}, Name, Query) -> + _ = setopts(SocketModule, Socket, [{active, false}]), + ParseMessage = pgo_protocol:encode_parse_message(Name, Query, []), + DescribeMessage = pgo_protocol:encode_describe_message(statement, Name), + FlushMessage = pgo_protocol:encode_flush_message(), + SyncMessage = pgo_protocol:encode_sync_message(), + Packet = [ParseMessage, DescribeMessage, FlushMessage, SyncMessage], + Result = case SocketModule:send(Socket, Packet) of + ok -> + prepare_receive_loop(Name, Conn); + {error, _} = SendError -> + SendError + end, + _ = setopts(SocketModule, Socket, [{active, once}]), + Result. + +prepare_receive_loop(Name, Conn=#conn{socket=Socket, socket_module=SocketModule}) -> + case receive_message(SocketModule, Socket, Conn, []) of + {ok, #parse_complete{}} -> + prepare_receive_loop_describe(Name, Conn); + {ok, #error_response{fields = Fields}} -> + flush_until_ready_for_query({error, {pgsql_error, Fields}}, Conn); + {error, _} = Error -> + Error + end. + +prepare_receive_loop_describe(Name, Conn=#conn{socket=Socket, socket_module=SocketModule}) -> + case receive_message(SocketModule, Socket, Conn, []) of + {ok, #parameter_description{data_types=DataTypes}} -> + prepare_skip_to_ready(Name, DataTypes, Conn); + {ok, #error_response{fields = Fields}} -> + flush_until_ready_for_query({error, {pgsql_error, Fields}}, Conn); + {error, _} = Error -> + Error + end. + +prepare_skip_to_ready(Name, DataTypes, Conn=#conn{socket=Socket, socket_module=SocketModule}) -> + case receive_message(SocketModule, Socket, Conn, []) of + {ok, #ready_for_query{}} -> + {ok, Name, DataTypes}; + {ok, #error_response{fields = Fields}} -> + flush_until_ready_for_query({error, {pgsql_error, Fields}}, Conn); + {ok, #parameter_description{data_types = DTs}} -> + prepare_skip_to_ready(Name, DTs, Conn); + {ok, _} -> + prepare_skip_to_ready(Name, DataTypes, Conn); + {error, _} = Error -> + Error + end. + +%% @doc Execute a previously prepared named statement. Skips PARSE entirely — +%% only sends BIND, DESCRIBE portal, EXECUTE, SYNC. The statement must have +%% been prepared on this connection via prepare/3 first. +-spec prepared_query(#conn{}, iodata(), list(), [pg_types:oid()]) -> pgo:result(). +prepared_query(Conn, Name, Parameters, ParameterDataTypes) -> + prepared_query(Conn, Name, Parameters, ParameterDataTypes, []). + +-spec prepared_query(#conn{}, iodata(), list(), [pg_types:oid()], pgo:decode_opts()) -> pgo:result(). +prepared_query(Conn=#conn{socket=Socket, + socket_module=SocketModule}, + Name, Parameters, ParameterDataTypes, DecodeOptions) -> + _ = setopts(SocketModule, Socket, [{active, false}]), + DecodeFun = proplists:get_value(decode_fun, DecodeOptions, undefined), + Result = case encode_bind_describe_execute_named(Conn, Name, Parameters, ParameterDataTypes) of + {ok, SinglePacket} -> + case SocketModule:send(Socket, SinglePacket) of + ok -> + try + receive_loop(bind_complete, DecodeFun, [], DecodeOptions, Conn) + catch + Class:Reason:Stacktrace -> + flush_until_ready_for_query(error, Conn), + erlang:raise(Class, Reason, Stacktrace) + end; + {error, _} = SendError -> + SendError + end; + {_, _} = Error -> + Error + end, + _ = setopts(SocketModule, Socket, [{active, once}]), + Result. + +-spec encode_bind_describe_execute_named(pgo_pool:conn(), iodata(), [any()], [pg_types:oid()]) -> + {ok, iodata()} | {term(), any()}. +encode_bind_describe_execute_named(Conn, StatementName, Parameters, ParameterDataTypes) -> + DescribeMessage = pgo_protocol:encode_describe_message(portal, ""), + ExecuteMessage = pgo_protocol:encode_execute_message("", 0), + SyncMessage = pgo_protocol:encode_sync_message(), + try + BindMessage = pgo_protocol:encode_bind_message(Conn, "", StatementName, Parameters, ParameterDataTypes), + SinglePacket = [BindMessage, DescribeMessage, ExecuteMessage, SyncMessage], + {ok, SinglePacket} + catch + Class:Exception -> + {Class, Exception} + end. + -spec ping(#conn{}) -> ok | {error, term()}. ping(Conn=#conn{socket=Socket, socket_module=SocketModule}) -> diff --git a/src/pgo_prepared_cache.erl b/src/pgo_prepared_cache.erl new file mode 100644 index 0000000..094c67b --- /dev/null +++ b/src/pgo_prepared_cache.erl @@ -0,0 +1,48 @@ +%% @doc ETS-based cache for prepared statement metadata. +%% +%% Stores statement name to {query, parameter OIDs} mappings, and tracks +%% which connections have each statement prepared. +-module(pgo_prepared_cache). + +-export([init/0, store/3, lookup/1, is_conn_prepared/1, mark_conn_prepared/1]). + +-define(TABLE, pgo_prepared_cache). +-define(CONN_TABLE, pgo_prepared_conn_cache). + +%% @doc Initialize cache tables. Safe to call multiple times. +init() -> + init_table(?TABLE), + init_table(?CONN_TABLE). + +%% @doc Store a prepared statement's query and parameter OIDs. +-spec store(iodata(), iodata(), [pg_types:oid()]) -> ok. +store(Name, Query, OIDs) -> + ets:insert(?TABLE, {iolist_to_binary(Name), iolist_to_binary(Query), OIDs}), + ok. + +%% @doc Look up a prepared statement's query and OIDs by name. +-spec lookup(iodata()) -> {ok, binary(), [pg_types:oid()]} | not_found. +lookup(Name) -> + case ets:lookup(?TABLE, iolist_to_binary(Name)) of + [{_, Query, OIDs}] -> {ok, Query, OIDs}; + [] -> not_found + end. + +%% @doc Check if a statement has been prepared on a specific connection. +-spec is_conn_prepared({pid(), binary()}) -> boolean(). +is_conn_prepared(Key) -> + ets:member(?CONN_TABLE, Key). + +%% @doc Mark a statement as prepared on a specific connection. +-spec mark_conn_prepared({pid(), binary()}) -> ok. +mark_conn_prepared(Key) -> + ets:insert(?CONN_TABLE, {Key}), + ok. + +init_table(Name) -> + case ets:whereis(Name) of + undefined -> + ets:new(Name, [named_table, public, set, {read_concurrency, true}]); + _ -> + ok + end. diff --git a/test/pgo_prepared_SUITE.erl b/test/pgo_prepared_SUITE.erl new file mode 100644 index 0000000..518d75c --- /dev/null +++ b/test/pgo_prepared_SUITE.erl @@ -0,0 +1,173 @@ +-module(pgo_prepared_SUITE). + +-compile(export_all). + +-include_lib("common_test/include/ct.hrl"). +-include_lib("stdlib/include/assert.hrl"). + +all() -> + [prepare_select, + prepare_returns_oids, + prepared_query_select, + prepared_query_insert, + prepared_query_with_params, + prepared_query_multiple_rows, + prepared_query_no_rows, + prepared_query_wrong_params, + prepare_invalid_sql, + prepared_query_not_prepared, + prepare_cache_stores_metadata, + prepared_query_rows_as_maps, + with_conn_prepare_and_query, + auto_prepare_across_pool]. + +init_per_suite(Config) -> + application:ensure_all_started(pgo), + {ok, _} = pgo_sup:start_child(default, #{pool_size => 1, + port => 5432, + database => "test", + user => "test", + password => "password"}), + pgo:query("CREATE TABLE IF NOT EXISTS prepared_test (" + " id BIGSERIAL PRIMARY KEY," + " name VARCHAR(255) NOT NULL," + " value INTEGER" + ")"), + pgo:query("TRUNCATE prepared_test RESTART IDENTITY"), + pgo:query("INSERT INTO prepared_test (name, value) VALUES ('alice', 10)"), + pgo:query("INSERT INTO prepared_test (name, value) VALUES ('bob', 20)"), + pgo:query("INSERT INTO prepared_test (name, value) VALUES ('charlie', 30)"), + Config. + +end_per_suite(_Config) -> + pgo:query("DROP TABLE IF EXISTS prepared_test"), + application:stop(pgo), + ok. + +init_per_testcase(TestCase, Config) -> + %% Deallocate all prepared statements between tests + pgo:query("DEALLOCATE ALL"), + [{testcase, TestCase} | Config]. + +end_per_testcase(_, _Config) -> + ok. + +%%---------------------------------------------------------------------- +%% prepare/2,3 tests +%%---------------------------------------------------------------------- + +prepare_select(_Config) -> + {ok, _, OIDs} = pgo:prepare("test_select", "SELECT * FROM prepared_test WHERE id = $1"), + ?assert(is_list(OIDs)), + ?assertEqual(1, length(OIDs)). + +prepare_returns_oids(_Config) -> + {ok, Name, OIDs} = pgo:prepare("test_oids", "SELECT * FROM prepared_test WHERE id = $1 AND name = $2"), + ?assertEqual("test_oids", Name), + ?assertEqual(2, length(OIDs)), + %% OIDs should be integers + lists:foreach(fun(Oid) -> ?assert(is_integer(Oid)) end, OIDs). + +prepare_invalid_sql(_Config) -> + Result = pgo:prepare("bad_sql", "SELECTT * FROMM nonexistent"), + ?assertMatch({error, {pgsql_error, _}}, Result). + +prepare_cache_stores_metadata(_Config) -> + pgo_prepared_cache:init(), + {ok, _, _OIDs} = pgo:prepare("cached_stmt", "SELECT 1"), + ?assertMatch({ok, _, _}, pgo_prepared_cache:lookup(<<"cached_stmt">>)). + +%%---------------------------------------------------------------------- +%% query_prepared/3,4 tests +%%---------------------------------------------------------------------- + +prepared_query_select(_Config) -> + {ok, _, OIDs} = pgo:prepare("q_select", "SELECT id, name, value FROM prepared_test WHERE id = $1"), + Result = pgo:query_prepared("q_select", [1], OIDs), + ?assertMatch(#{command := select, num_rows := 1, rows := [{1, <<"alice">>, 10}]}, Result). + +prepared_query_insert(_Config) -> + {ok, _, OIDs} = pgo:prepare("q_insert", "INSERT INTO prepared_test (name, value) VALUES ($1, $2)"), + Result = pgo:query_prepared("q_insert", [<<"dave">>, 40], OIDs), + ?assertMatch(#{command := insert, num_rows := 1}, Result), + %% Verify it was inserted + ?assertMatch(#{rows := [{<<"dave">>, 40}]}, + pgo:query("SELECT name, value FROM prepared_test WHERE name = $1", [<<"dave">>])), + %% Clean up + pgo:query("DELETE FROM prepared_test WHERE name = $1", [<<"dave">>]). + +prepared_query_with_params(_Config) -> + {ok, _, OIDs} = pgo:prepare("q_params", "SELECT name FROM prepared_test WHERE value > $1 ORDER BY value"), + Result = pgo:query_prepared("q_params", [15], OIDs), + ?assertMatch(#{command := select, rows := [{<<"bob">>}, {<<"charlie">>}]}, Result). + +prepared_query_multiple_rows(_Config) -> + {ok, _, OIDs} = pgo:prepare("q_multi", "SELECT id, name FROM prepared_test ORDER BY id"), + Result = pgo:query_prepared("q_multi", [], OIDs), + ?assertMatch(#{command := select, num_rows := 3}, Result). + +prepared_query_no_rows(_Config) -> + {ok, _, OIDs} = pgo:prepare("q_empty", "SELECT * FROM prepared_test WHERE id = $1"), + Result = pgo:query_prepared("q_empty", [999], OIDs), + ?assertMatch(#{command := select, num_rows := 0, rows := []}, Result). + +prepared_query_wrong_params(_Config) -> + {ok, _, OIDs} = pgo:prepare("q_wrong", "SELECT * FROM prepared_test WHERE id = $1"), + %% Wrong number of parameters + Result = pgo:query_prepared("q_wrong", [1, 2], OIDs), + ?assertMatch({error, _}, Result). + +prepared_query_not_prepared(_Config) -> + %% Query a statement that doesn't exist on the connection + Result = pgo:query_prepared("nonexistent_stmt", [1], [23]), + ?assertMatch({error, {pgsql_error, _}}, Result). + +prepared_query_rows_as_maps(_Config) -> + {ok, _, OIDs} = pgo:prepare("q_maps", "SELECT id, name, value FROM prepared_test WHERE id = $1"), + Result = pgo:query_prepared("q_maps", [1], OIDs, #{decode_opts => [return_rows_as_maps, column_name_as_atom]}), + ?assertMatch(#{command := select, rows := [#{id := 1, name := <<"alice">>, value := 10}]}, Result). + +%%---------------------------------------------------------------------- +%% with_conn tests (prepare + query on same connection) +%%---------------------------------------------------------------------- + +with_conn_prepare_and_query(_Config) -> + %% Use with_conn to ensure prepare and query happen on same connection + Result = pgo:with_conn(default, fun() -> + {ok, _, OIDs} = pgo:prepare("wc_test", "SELECT name FROM prepared_test WHERE id = $1"), + pgo:query_prepared("wc_test", [2], OIDs) + end), + ?assertMatch(#{command := select, rows := [{<<"bob">>}]}, Result). + +auto_prepare_across_pool(_Config) -> + %% Start a pool with multiple connections + {ok, _} = pgo_sup:start_child(multi_pool, #{pool_size => 5, + port => 5432, + database => "test", + user => "test", + password => "password"}), + %% Prepare on one connection + {ok, _, OIDs} = pgo:prepare("auto_prep_test", + "SELECT name FROM prepared_test WHERE id = $1", + #{pool => multi_pool}), + %% Execute many times — will hit different connections, auto-prepare should kick in + Results = [pgo:query_prepared("auto_prep_test", [I], OIDs, #{pool => multi_pool}) + || I <- lists:seq(1, 20)], + %% All should succeed (no "statement not found" errors) + lists:foreach( + fun(R) -> ?assertMatch(#{command := select}, R) end, + Results + ), + %% Verify correct data comes back + ?assertMatch(#{rows := [{<<"alice">>}]}, + pgo:query_prepared("auto_prep_test", [1], OIDs, #{pool => multi_pool})), + ?assertMatch(#{rows := [{<<"bob">>}]}, + pgo:query_prepared("auto_prep_test", [2], OIDs, #{pool => multi_pool})), + application:stop(pgo), + application:ensure_all_started(pgo), + {ok, _} = pgo_sup:start_child(default, #{pool_size => 1, + port => 5432, + database => "test", + user => "test", + password => "password"}), + ok.