Skip to content

Commit 84232e1

Browse files
authored
Merge pull request #35 from qaspen-python/feature/custom_encoder_decoder
Custom decoders and encoders for not supported types
2 parents 15dcaf9 + 23bc3d0 commit 84232e1

File tree

9 files changed

+122
-16
lines changed

9 files changed

+122
-16
lines changed

python/psqlpy/_internal/__init__.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ _CustomClass = TypeVar(
1111
class QueryResult:
1212
"""Result."""
1313

14-
def result(self: Self) -> list[dict[Any, Any]]:
14+
def result(
15+
self: Self,
16+
custom_decoders: dict[str, Callable[[bytes], Any]] | None = None,
17+
) -> list[dict[Any, Any]]:
1518
"""Return result from database as a list of dicts."""
1619
def as_class(
1720
self: Self,

python/psqlpy/_internal/extra_types.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,6 @@ class PyMacAddr8:
123123
### Parameters:
124124
- `value`: value for MACADDR8 field.
125125
"""
126+
127+
class PyCustomType:
128+
def __init__(self, value: bytes) -> None: ...

python/psqlpy/extra_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ._internal.extra_types import (
22
BigInt,
33
Integer,
4+
PyCustomType,
45
PyJSON,
56
PyJSONB,
67
PyMacAddr6,
@@ -22,4 +23,5 @@
2223
"PyMacAddr8",
2324
"PyVarChar",
2425
"PyText",
26+
"PyCustomType",
2527
]

python/tests/test_value_converter.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from tests.conftest import DefaultPydanticModel, DefaultPythonModelClass
99

1010
from psqlpy import ConnectionPool
11+
from psqlpy._internal.extra_types import PyCustomType
1112
from psqlpy.extra_types import (
1213
BigInt,
1314
Integer,
@@ -425,3 +426,52 @@ class TopLevelModel(BaseModel):
425426
)
426427

427428
assert isinstance(model_result[0], TopLevelModel)
429+
430+
431+
async def test_custom_type_as_parameter(
432+
psql_pool: ConnectionPool,
433+
) -> None:
434+
"""Tests that we can use `PyCustomType`."""
435+
await psql_pool.execute("DROP TABLE IF EXISTS for_test")
436+
await psql_pool.execute(
437+
"CREATE TABLE for_test (nickname VARCHAR)",
438+
)
439+
440+
await psql_pool.execute(
441+
querystring="INSERT INTO for_test VALUES ($1)",
442+
parameters=[PyCustomType(b"Some Real Nickname")],
443+
)
444+
445+
qs_result = await psql_pool.execute(
446+
"SELECT * FROM for_test",
447+
)
448+
449+
result = qs_result.result()
450+
assert result[0]["nickname"] == "Some Real Nickname"
451+
452+
453+
async def test_custom_decoder(
454+
psql_pool: ConnectionPool,
455+
) -> None:
456+
await psql_pool.execute("DROP TABLE IF EXISTS for_test")
457+
await psql_pool.execute(
458+
"CREATE TABLE for_test (geo_point POINT)",
459+
)
460+
461+
await psql_pool.execute(
462+
"INSERT INTO for_test VALUES ('(1, 1)')",
463+
)
464+
465+
def point_encoder(point_bytes: bytes) -> str:
466+
return "Just An Example"
467+
468+
qs_result = await psql_pool.execute(
469+
"SELECT * FROM for_test",
470+
)
471+
result = qs_result.result(
472+
custom_decoders={
473+
"geo_point": point_encoder,
474+
},
475+
)
476+
477+
assert result[0]["geo_point"] == "Just An Example"

src/driver/connection.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ impl Connection {
257257
};
258258

259259
Python::with_gil(|gil| match result.columns().first() {
260-
Some(first_column) => postgres_to_py(gil, &result, first_column, 0),
260+
Some(first_column) => postgres_to_py(gil, &result, first_column, 0, &None),
261261
None => Ok(gil.None()),
262262
})
263263
}

src/driver/transaction.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ impl Transaction {
331331
};
332332

333333
Python::with_gil(|gil| match result.columns().first() {
334-
Some(first_column) => postgres_to_py(gil, &result, first_column, 0),
334+
Some(first_column) => postgres_to_py(gil, &result, first_column, 0, &None),
335335
None => Ok(gil.None()),
336336
})
337337
}

src/extra_types.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,27 @@ macro_rules! build_macaddr_type {
190190
build_macaddr_type!(PyMacAddr6, MacAddr6);
191191
build_macaddr_type!(PyMacAddr8, MacAddr8);
192192

193+
#[pyclass]
194+
#[derive(Clone, Debug)]
195+
pub struct PyCustomType {
196+
inner: Vec<u8>,
197+
}
198+
199+
impl PyCustomType {
200+
#[must_use]
201+
pub fn inner(&self) -> Vec<u8> {
202+
self.inner.clone()
203+
}
204+
}
205+
206+
#[pymethods]
207+
impl PyCustomType {
208+
#[new]
209+
fn new_class(type_bytes: Vec<u8>) -> Self {
210+
PyCustomType { inner: type_bytes }
211+
}
212+
}
213+
193214
#[allow(clippy::module_name_repetitions)]
194215
#[allow(clippy::missing_errors_doc)]
195216
pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> {
@@ -203,5 +224,6 @@ pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyRes
203224
pymod.add_class::<PyJSON>()?;
204225
pymod.add_class::<PyMacAddr6>()?;
205226
pymod.add_class::<PyMacAddr8>()?;
227+
pymod.add_class::<PyCustomType>()?;
206228
Ok(())
207229
}

src/query_result.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ use crate::{exceptions::rust_errors::RustPSQLDriverPyResult, value_converter::po
1313
fn row_to_dict<'a>(
1414
py: Python<'a>,
1515
postgres_row: &'a Row,
16+
custom_decoders: &Option<Py<PyDict>>,
1617
) -> RustPSQLDriverPyResult<pyo3::Bound<'a, PyDict>> {
1718
let python_dict = PyDict::new_bound(py);
1819
for (column_idx, column) in postgres_row.columns().iter().enumerate() {
19-
let python_type = postgres_to_py(py, postgres_row, column, column_idx)?;
20+
let python_type = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?;
2021
python_dict.set_item(column.name().to_object(py), python_type)?;
2122
}
2223
Ok(python_dict)
@@ -55,10 +56,14 @@ impl PSQLDriverPyQueryResult {
5556
/// postgres type to python or set new key-value pair
5657
/// in python dict.
5758
#[allow(clippy::needless_pass_by_value)]
58-
pub fn result(&self, py: Python<'_>) -> RustPSQLDriverPyResult<Py<PyAny>> {
59+
pub fn result(
60+
&self,
61+
py: Python<'_>,
62+
custom_decoders: Option<Py<PyDict>>,
63+
) -> RustPSQLDriverPyResult<Py<PyAny>> {
5964
let mut result: Vec<pyo3::Bound<'_, PyDict>> = vec![];
6065
for row in &self.inner {
61-
result.push(row_to_dict(py, row)?);
66+
result.push(row_to_dict(py, row, &custom_decoders)?);
6267
}
6368
Ok(result.to_object(py))
6469
}
@@ -77,7 +82,7 @@ impl PSQLDriverPyQueryResult {
7782
) -> RustPSQLDriverPyResult<Py<PyAny>> {
7883
let mut res: Vec<Py<PyAny>> = vec![];
7984
for row in &self.inner {
80-
let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row)?;
85+
let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &None)?;
8186
let convert_class_inst = as_class.call_bound(py, (), Some(&pydict))?;
8287
res.push(convert_class_inst);
8388
}
@@ -117,7 +122,7 @@ impl PSQLDriverSinglePyQueryResult {
117122
/// postgres type to python, can not set new key-value pair
118123
/// in python dict or there are no result.
119124
pub fn result(&self, py: Python<'_>) -> RustPSQLDriverPyResult<Py<PyAny>> {
120-
Ok(row_to_dict(py, &self.inner)?.to_object(py))
125+
Ok(row_to_dict(py, &self.inner, &None)?.to_object(py))
121126
}
122127

123128
/// Convert result from database to any class passed from Python.
@@ -133,7 +138,7 @@ impl PSQLDriverSinglePyQueryResult {
133138
py: Python<'a>,
134139
as_class: Py<PyAny>,
135140
) -> RustPSQLDriverPyResult<Py<PyAny>> {
136-
let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, &self.inner)?;
141+
let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, &self.inner, &None)?;
137142
Ok(as_class.call_bound(py, (), Some(&pydict))?)
138143
}
139144
}

src/value_converter.rs

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ use crate::{
2323
additional_types::{RustMacAddr6, RustMacAddr8},
2424
exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult},
2525
extra_types::{
26-
BigInt, Integer, PyJSON, PyJSONB, PyMacAddr6, PyMacAddr8, PyText, PyUUID, PyVarChar,
27-
SmallInt,
26+
BigInt, Integer, PyCustomType, PyJSON, PyJSONB, PyMacAddr6, PyMacAddr8, PyText, PyUUID,
27+
PyVarChar, SmallInt,
2828
},
2929
};
3030

@@ -62,6 +62,7 @@ pub enum PythonDTO {
6262
PyJson(Value),
6363
PyMacAddr6(MacAddr6),
6464
PyMacAddr8(MacAddr8),
65+
PyCustomType(Vec<u8>),
6566
}
6667

6768
impl PythonDTO {
@@ -174,6 +175,9 @@ impl ToSql for PythonDTO {
174175

175176
match self {
176177
PythonDTO::PyNone => {}
178+
PythonDTO::PyCustomType(some_bytes) => {
179+
<&[u8] as ToSql>::to_sql(&some_bytes.as_slice(), ty, out)?;
180+
}
177181
PythonDTO::PyBytes(pybytes) => {
178182
<Vec<u8> as ToSql>::to_sql(pybytes, ty, out)?;
179183
}
@@ -284,6 +288,12 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
284288
return Ok(PythonDTO::PyNone);
285289
}
286290

291+
if parameter.is_instance_of::<PyCustomType>() {
292+
return Ok(PythonDTO::PyCustomType(
293+
parameter.extract::<PyCustomType>()?.inner(),
294+
));
295+
}
296+
287297
if parameter.is_instance_of::<PyBool>() {
288298
return Ok(PythonDTO::PyBool(parameter.extract::<bool>()?));
289299
}
@@ -652,10 +662,9 @@ fn postgres_bytes_to_py(
652662
None => Ok(py.None().to_object(py)),
653663
}
654664
}
655-
_ => Ok(
656-
_composite_field_postgres_to_py::<Option<Vec<u8>>>(type_, buf, is_simple)?
657-
.to_object(py),
658-
),
665+
_ => Err(RustPSQLDriverError::RustToPyValueConversionError(
666+
format!("Cannot convert {type_} into Python type, please look at the custom_decoders functionality.")
667+
)),
659668
}
660669
}
661670

@@ -720,9 +729,21 @@ pub fn postgres_to_py(
720729
row: &Row,
721730
column: &Column,
722731
column_i: usize,
732+
custom_decoders: &Option<Py<PyDict>>,
723733
) -> RustPSQLDriverPyResult<Py<PyAny>> {
724-
let column_type = column.type_();
725734
let raw_bytes_data = row.col_buffer(column_i);
735+
736+
if let Some(custom_decoders) = custom_decoders {
737+
let py_encoder_func = custom_decoders
738+
.bind(py)
739+
.get_item(column.name().to_lowercase());
740+
741+
if let Ok(Some(py_encoder_func)) = py_encoder_func {
742+
return Ok(py_encoder_func.call((raw_bytes_data,), None)?.unbind());
743+
}
744+
}
745+
746+
let column_type = column.type_();
726747
match raw_bytes_data {
727748
Some(mut raw_bytes_data) => match column_type.kind() {
728749
Kind::Simple | Kind::Array(_) => {

0 commit comments

Comments
 (0)