Skip to content

Commit 27fc955

Browse files
committed
Supported ENUM PostgreSQL Type
Signed-off-by: chandr-andr (Kiselev Aleksandr) <chandr@chandr.net>
1 parent 4b39ac1 commit 27fc955

File tree

2 files changed

+123
-16
lines changed

2 files changed

+123
-16
lines changed

python/tests/test_value_converter.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import uuid
3+
from enum import Enum, StrEnum
34
from ipaddress import IPv4Address
45
from typing import Any, Dict, List, Union
56

@@ -446,6 +447,41 @@ class TopLevelModel(BaseModel):
446447
assert isinstance(model_result[0], TopLevelModel)
447448

448449

450+
async def test_enum_type(psql_pool: ConnectionPool) -> None:
451+
"""Test that we can decode ENUM type from PostgreSQL."""
452+
453+
class TestEnum(Enum):
454+
OK = "ok"
455+
SAD = "sad"
456+
HAPPY = "happy"
457+
458+
class TestStrEnum(StrEnum):
459+
OK = "ok"
460+
SAD = "sad"
461+
HAPPY = "happy"
462+
463+
await psql_pool.execute("DROP TABLE IF EXISTS for_test")
464+
await psql_pool.execute("DROP TYPE IF EXISTS mood")
465+
await psql_pool.execute(
466+
"CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')",
467+
)
468+
await psql_pool.execute(
469+
"CREATE TABLE for_test (test_mood mood, test_mood2 mood)",
470+
)
471+
472+
await psql_pool.execute(
473+
querystring="INSERT INTO for_test VALUES ($1, $2)",
474+
parameters=[TestEnum.HAPPY, TestEnum.OK],
475+
)
476+
477+
qs_result = await psql_pool.execute(
478+
"SELECT * FROM for_test",
479+
)
480+
assert qs_result.result()[0]["test_mood"] == TestEnum.HAPPY.value
481+
assert qs_result.result()[0]["test_mood"] != TestEnum.HAPPY
482+
assert qs_result.result()[0]["test_mood2"] == TestStrEnum.OK
483+
484+
449485
async def test_custom_type_as_parameter(
450486
psql_pool: ConnectionPool,
451487
) -> None:

src/value_converter.rs

Lines changed: 87 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,15 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
447447
return Ok(PythonDTO::PyIpAddress(id_address));
448448
}
449449

450+
// It's used for Enum.
451+
// If StrEnum is used on Python side,
452+
// we simply stop at the `is_instance_of::<PyString>``.
453+
if let Ok(value_attr) = parameter.getattr("value") {
454+
if let Ok(possible_string) = value_attr.extract::<String>() {
455+
return Ok(PythonDTO::PyString(possible_string));
456+
}
457+
}
458+
450459
Err(RustPSQLDriverError::PyToRustValueConversionError(format!(
451460
"Can not covert you type {parameter} into inner one",
452461
)))
@@ -692,6 +701,7 @@ pub fn composite_postgres_to_py(
692701
py: Python<'_>,
693702
fields: &Vec<Field>,
694703
buf: &[u8],
704+
custom_decoders: &Option<Py<PyDict>>,
695705
) -> RustPSQLDriverPyResult<Py<PyAny>> {
696706
let mut vec_buf: Vec<u8> = vec![];
697707
vec_buf.extend_from_slice(buf);
@@ -718,61 +728,122 @@ pub fn composite_postgres_to_py(
718728
"Cannot read bytes data from PostgreSQL: {err}"
719729
))
720730
})? as u32;
731+
721732
if oid != field.type_().oid() {
722733
return Err(RustPSQLDriverError::RustToPyValueConversionError(
723734
"unexpected OID".into(),
724735
));
725736
}
726737

727-
result_py_dict.set_item(
728-
field.name(),
729-
postgres_bytes_to_py(py, field.type_(), &mut buf, false)?.to_object(py),
730-
)?;
738+
if let Ok(data_from_psql) = postgres_bytes_to_py(py, field.type_(), &mut buf, false) {
739+
result_py_dict.set_item(field.name(), data_from_psql.to_object(py))?;
740+
} else {
741+
let new_buf = &buf[4..];
742+
743+
result_py_dict.set_item(
744+
field.name(),
745+
raw_bytes_data_process(
746+
py,
747+
Some(new_buf),
748+
field.name(),
749+
field.type_(),
750+
custom_decoders,
751+
)?
752+
.to_object(py),
753+
)?;
754+
}
731755
}
732756

733757
Ok(result_py_dict.to_object(py))
734758
}
735759

736-
/// Convert type from postgres to python type.
760+
/// Process raw bytes from `PostgreSQL`.
737761
///
738762
/// # Errors
739763
///
740764
/// May return Err Result if cannot convert postgres
741765
/// type into rust one.
742-
pub fn postgres_to_py(
766+
pub fn raw_bytes_data_process(
743767
py: Python<'_>,
744-
row: &Row,
745-
column: &Column,
746-
column_i: usize,
768+
raw_bytes_data: Option<&[u8]>,
769+
column_name: &str,
770+
column_type: &Type,
747771
custom_decoders: &Option<Py<PyDict>>,
748772
) -> RustPSQLDriverPyResult<Py<PyAny>> {
749-
let raw_bytes_data = row.col_buffer(column_i);
750-
751773
if let Some(custom_decoders) = custom_decoders {
752774
let py_encoder_func = custom_decoders
753775
.bind(py)
754-
.get_item(column.name().to_lowercase());
776+
.get_item(column_name.to_lowercase());
755777

756778
if let Ok(Some(py_encoder_func)) = py_encoder_func {
757779
return Ok(py_encoder_func.call((raw_bytes_data,), None)?.unbind());
758780
}
759781
}
760782

761-
let column_type = column.type_();
762783
match raw_bytes_data {
763784
Some(mut raw_bytes_data) => match column_type.kind() {
764785
Kind::Simple | Kind::Array(_) => {
765-
postgres_bytes_to_py(py, column.type_(), &mut raw_bytes_data, true)
786+
postgres_bytes_to_py(py, column_type, &mut raw_bytes_data, true)
787+
}
788+
Kind::Composite(fields) => {
789+
composite_postgres_to_py(py, fields, raw_bytes_data, custom_decoders)
766790
}
767-
Kind::Composite(fields) => composite_postgres_to_py(py, fields, raw_bytes_data),
791+
Kind::Enum(_) => postgres_bytes_to_py(py, &Type::VARCHAR, &mut raw_bytes_data, true),
768792
_ => Err(RustPSQLDriverError::RustToPyValueConversionError(
769-
column.type_().to_string(),
793+
column_type.to_string(),
770794
)),
771795
},
772796
None => Ok(py.None()),
773797
}
774798
}
775799

800+
/// Convert type from postgres to python type.
801+
///
802+
/// # Errors
803+
///
804+
/// May return Err Result if cannot convert postgres
805+
/// type into rust one.
806+
pub fn postgres_to_py(
807+
py: Python<'_>,
808+
row: &Row,
809+
column: &Column,
810+
column_i: usize,
811+
custom_decoders: &Option<Py<PyDict>>,
812+
) -> RustPSQLDriverPyResult<Py<PyAny>> {
813+
let raw_bytes_data = row.col_buffer(column_i);
814+
raw_bytes_data_process(
815+
py,
816+
raw_bytes_data,
817+
column.name(),
818+
column.type_(),
819+
custom_decoders,
820+
)
821+
// if let Some(custom_decoders) = custom_decoders {
822+
// let py_encoder_func = custom_decoders
823+
// .bind(py)
824+
// .get_item(column.name().to_lowercase());
825+
826+
// if let Ok(Some(py_encoder_func)) = py_encoder_func {
827+
// return Ok(py_encoder_func.call((raw_bytes_data,), None)?.unbind());
828+
// }
829+
// }
830+
831+
// let column_type = column.type_();
832+
// match raw_bytes_data {
833+
// Some(mut raw_bytes_data) => match column_type.kind() {
834+
// Kind::Simple | Kind::Array(_) => {
835+
// postgres_bytes_to_py(py, column.type_(), &mut raw_bytes_data, true)
836+
// }
837+
// Kind::Composite(fields) => composite_postgres_to_py(py, fields, raw_bytes_data, column),
838+
// Kind::Enum(_) => postgres_bytes_to_py(py, &Type::VARCHAR, &mut raw_bytes_data, true),
839+
// _ => Err(RustPSQLDriverError::RustToPyValueConversionError(
840+
// column.type_().to_string(),
841+
// )),
842+
// },
843+
// None => Ok(py.None()),
844+
// }
845+
}
846+
776847
/// Convert python List of Dict type or just Dict into serde `Value`.
777848
///
778849
/// # Errors

0 commit comments

Comments
 (0)