Skip to content

Commit 6514c98

Browse files
committed
Supported ENUM PostgreSQL Type
Signed-off-by: chandr-andr (Kiselev Aleksandr) <chandr@chandr.net>
1 parent 0be222e commit 6514c98

File tree

3 files changed

+75
-75
lines changed

3 files changed

+75
-75
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/tests/test_value_converter.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,10 @@ async def test_deserialization_composite_into_python(
283283
"""Test that it's possible to deserialize custom postgresql type."""
284284
await psql_pool.execute("DROP TABLE IF EXISTS for_test")
285285
await psql_pool.execute("DROP TYPE IF EXISTS all_types")
286+
await psql_pool.execute("DROP TYPE IF EXISTS inner_type")
287+
await psql_pool.execute("DROP TYPE IF EXISTS enum_type")
288+
await psql_pool.execute("CREATE TYPE inner_type AS (inner_value VARCHAR)")
289+
await psql_pool.execute("CREATE TYPE enum_type AS ENUM ('sad', 'ok', 'happy')")
286290
create_type_query = """
287291
CREATE type all_types AS (
288292
bytea_ BYTEA,
@@ -317,7 +321,9 @@ async def test_deserialization_composite_into_python(
317321
uuid_arr UUID ARRAY,
318322
inet_arr INET ARRAY,
319323
jsonb_arr JSONB ARRAY,
320-
json_arr JSON ARRAY
324+
json_arr JSON ARRAY,
325+
test_inner_value inner_type,
326+
test_enum_type enum_type
321327
)
322328
"""
323329
create_table_query = """
@@ -331,7 +337,7 @@ async def test_deserialization_composite_into_python(
331337
querystring=create_table_query,
332338
)
333339
await psql_pool.execute(
334-
querystring="INSERT INTO for_test VALUES (ROW($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32))", # noqa: E501
340+
querystring="INSERT INTO for_test VALUES (ROW($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, ROW($33), $34))", # noqa: E501
335341
parameters=[
336342
b"Bytes",
337343
"Some String",
@@ -395,9 +401,19 @@ async def test_deserialization_composite_into_python(
395401
},
396402
),
397403
],
404+
"inner type value",
405+
"ok",
398406
],
399407
)
400408

409+
class TestEnum(Enum):
410+
OK = "ok"
411+
SAD = "sad"
412+
HAPPY = "happy"
413+
414+
class ValidateModelForInnerValueType(BaseModel):
415+
inner_value: str
416+
401417
class ValidateModelForCustomType(BaseModel):
402418
bytea_: List[int]
403419
varchar_: str
@@ -433,6 +449,9 @@ class ValidateModelForCustomType(BaseModel):
433449
jsonb_arr: List[Dict[str, List[Union[str, int, List[str]]]]]
434450
json_arr: List[Dict[str, List[Union[str, int, List[str]]]]]
435451

452+
test_inner_value: ValidateModelForInnerValueType
453+
test_enum_type: TestEnum
454+
436455
class TopLevelModel(BaseModel):
437456
custom_type: ValidateModelForCustomType
438457

src/value_converter.rs

Lines changed: 52 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
22
use macaddr::{MacAddr6, MacAddr8};
3-
use postgres_types::{Field, FromSql, Kind};
3+
use postgres_types::{Field, FromSql, Kind, ToSql};
44
use serde_json::{json, Map, Value};
55
use std::{fmt::Debug, net::IpAddr};
66
use uuid::Uuid;
@@ -15,7 +15,7 @@ use pyo3::{
1515
Bound, Py, PyAny, Python, ToPyObject,
1616
};
1717
use tokio_postgres::{
18-
types::{to_sql_checked, ToSql, Type},
18+
types::{to_sql_checked, Type},
1919
Column, Row,
2020
};
2121

@@ -700,16 +700,12 @@ fn postgres_bytes_to_py(
700700
pub fn composite_postgres_to_py(
701701
py: Python<'_>,
702702
fields: &Vec<Field>,
703-
buf: &[u8],
703+
buf: &mut &[u8],
704704
custom_decoders: &Option<Py<PyDict>>,
705705
) -> RustPSQLDriverPyResult<Py<PyAny>> {
706-
let mut vec_buf: Vec<u8> = vec![];
707-
vec_buf.extend_from_slice(buf);
708-
let mut buf: &[u8] = vec_buf.as_slice();
709-
710706
let result_py_dict: Bound<'_, PyDict> = PyDict::new_bound(py);
711707

712-
let num_fields = postgres_types::private::read_be_i32(&mut buf).map_err(|err| {
708+
let num_fields = postgres_types::private::read_be_i32(buf).map_err(|err| {
713709
RustPSQLDriverError::RustToPyValueConversionError(format!(
714710
"Cannot read bytes data from PostgreSQL: {err}"
715711
))
@@ -723,7 +719,7 @@ pub fn composite_postgres_to_py(
723719
}
724720

725721
for field in fields {
726-
let oid = postgres_types::private::read_be_i32(&mut buf).map_err(|err| {
722+
let oid = postgres_types::private::read_be_i32(buf).map_err(|err| {
727723
RustPSQLDriverError::RustToPyValueConversionError(format!(
728724
"Cannot read bytes data from PostgreSQL: {err}"
729725
))
@@ -735,22 +731,28 @@ pub fn composite_postgres_to_py(
735731
));
736732
}
737733

738-
if let Ok(data_from_psql) = postgres_bytes_to_py(py, field.type_(), &mut buf, false) {
739-
result_py_dict.set_item(field.name(), data_from_psql.to_object(py))?;
740-
} else {
741-
let new_buf = &buf[4..];
742-
743-
result_py_dict.set_item(
744-
field.name(),
745-
raw_bytes_data_process(
746-
py,
747-
Some(new_buf),
734+
match field.type_().kind() {
735+
Kind::Simple | Kind::Array(_) => {
736+
result_py_dict.set_item(
748737
field.name(),
749-
field.type_(),
750-
custom_decoders,
751-
)?
752-
.to_object(py),
753-
)?;
738+
postgres_bytes_to_py(py, field.type_(), buf, false)?.to_object(py),
739+
)?;
740+
}
741+
Kind::Enum(_) => {
742+
result_py_dict.set_item(
743+
field.name(),
744+
postgres_bytes_to_py(py, &Type::VARCHAR, buf, false)?.to_object(py),
745+
)?;
746+
}
747+
_ => {
748+
let (_, tail) = buf.split_at(4_usize);
749+
*buf = tail;
750+
result_py_dict.set_item(
751+
field.name(),
752+
raw_bytes_data_process(py, buf, field.name(), field.type_(), custom_decoders)?
753+
.to_object(py),
754+
)?;
755+
}
754756
}
755757
}
756758

@@ -765,7 +767,7 @@ pub fn composite_postgres_to_py(
765767
/// type into rust one.
766768
pub fn raw_bytes_data_process(
767769
py: Python<'_>,
768-
raw_bytes_data: Option<&[u8]>,
770+
raw_bytes_data: &mut &[u8],
769771
column_name: &str,
770772
column_type: &Type,
771773
custom_decoders: &Option<Py<PyDict>>,
@@ -776,24 +778,24 @@ pub fn raw_bytes_data_process(
776778
.get_item(column_name.to_lowercase());
777779

778780
if let Ok(Some(py_encoder_func)) = py_encoder_func {
779-
return Ok(py_encoder_func.call((raw_bytes_data,), None)?.unbind());
781+
return Ok(py_encoder_func
782+
.call((raw_bytes_data.to_vec(),), None)?
783+
.unbind());
780784
}
781785
}
782786

783-
match raw_bytes_data {
784-
Some(mut raw_bytes_data) => match column_type.kind() {
785-
Kind::Simple | Kind::Array(_) => {
786-
postgres_bytes_to_py(py, column_type, &mut raw_bytes_data, true)
787-
}
788-
Kind::Composite(fields) => {
789-
composite_postgres_to_py(py, fields, raw_bytes_data, custom_decoders)
790-
}
791-
Kind::Enum(_) => postgres_bytes_to_py(py, &Type::VARCHAR, &mut raw_bytes_data, true),
792-
_ => Err(RustPSQLDriverError::RustToPyValueConversionError(
793-
column_type.to_string(),
794-
)),
795-
},
796-
None => Ok(py.None()),
787+
match column_type.kind() {
788+
Kind::Simple | Kind::Array(_) => {
789+
postgres_bytes_to_py(py, column_type, raw_bytes_data, true)
790+
}
791+
Kind::Composite(fields) => {
792+
println!("1 {:p}", &raw_bytes_data);
793+
composite_postgres_to_py(py, fields, raw_bytes_data, custom_decoders)
794+
}
795+
Kind::Enum(_) => postgres_bytes_to_py(py, &Type::VARCHAR, raw_bytes_data, true),
796+
_ => Err(RustPSQLDriverError::RustToPyValueConversionError(
797+
column_type.to_string(),
798+
)),
797799
}
798800
}
799801

@@ -811,37 +813,16 @@ pub fn postgres_to_py(
811813
custom_decoders: &Option<Py<PyDict>>,
812814
) -> RustPSQLDriverPyResult<Py<PyAny>> {
813815
let raw_bytes_data = row.col_buffer(column_i);
814-
raw_bytes_data_process(
815-
py,
816-
raw_bytes_data,
817-
column.name(),
818-
column.type_(),
819-
custom_decoders,
820-
)
821-
// if let Some(custom_decoders) = custom_decoders {
822-
// let py_encoder_func = custom_decoders
823-
// .bind(py)
824-
// .get_item(column.name().to_lowercase());
825-
826-
// if let Ok(Some(py_encoder_func)) = py_encoder_func {
827-
// return Ok(py_encoder_func.call((raw_bytes_data,), None)?.unbind());
828-
// }
829-
// }
830-
831-
// let column_type = column.type_();
832-
// match raw_bytes_data {
833-
// Some(mut raw_bytes_data) => match column_type.kind() {
834-
// Kind::Simple | Kind::Array(_) => {
835-
// postgres_bytes_to_py(py, column.type_(), &mut raw_bytes_data, true)
836-
// }
837-
// Kind::Composite(fields) => composite_postgres_to_py(py, fields, raw_bytes_data, column),
838-
// Kind::Enum(_) => postgres_bytes_to_py(py, &Type::VARCHAR, &mut raw_bytes_data, true),
839-
// _ => Err(RustPSQLDriverError::RustToPyValueConversionError(
840-
// column.type_().to_string(),
841-
// )),
842-
// },
843-
// None => Ok(py.None()),
844-
// }
816+
if let Some(mut raw_bytes_data) = raw_bytes_data {
817+
return raw_bytes_data_process(
818+
py,
819+
&mut raw_bytes_data,
820+
column.name(),
821+
column.type_(),
822+
custom_decoders,
823+
);
824+
}
825+
Ok(py.None())
845826
}
846827

847828
/// Convert python List of Dict type or just Dict into serde `Value`.

0 commit comments

Comments
 (0)