@@ -2,6 +2,7 @@ use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
22use geo_types:: { coord, Coord , Line as LineSegment , LineString , Point , Rect } ;
33use itertools:: Itertools ;
44use macaddr:: { MacAddr6 , MacAddr8 } ;
5+ use pg_interval:: Interval ;
56use postgres_types:: { Field , FromSql , Kind , ToSql } ;
67use rust_decimal:: Decimal ;
78use serde_json:: { json, Map , Value } ;
@@ -13,8 +14,8 @@ use postgres_protocol::types;
1314use pyo3:: {
1415 sync:: GILOnceCell ,
1516 types:: {
16- PyAnyMethods , PyBool , PyBytes , PyDate , PyDateTime , PyDict , PyDictMethods , PyFloat , PyInt ,
17- PyIterator , PyList , PyListMethods , PySequence , PySet , PyString , PyTime , PyTuple , PyType ,
17+ PyAnyMethods , PyBool , PyBytes , PyDate , PyDateTime , PyDelta , PyDict , PyDictMethods , PyFloat ,
18+ PyInt , PyList , PyListMethods , PySequence , PySet , PyString , PyTime , PyTuple , PyType ,
1819 PyTypeMethods ,
1920 } ,
2021 Bound , FromPyObject , IntoPy , Py , PyAny , PyObject , PyResult , Python , ToPyObject ,
@@ -35,6 +36,7 @@ use crate::{
3536use postgres_array:: { array:: Array , Dimension } ;
3637
3738static DECIMAL_CLS : GILOnceCell < Py < PyType > > = GILOnceCell :: new ( ) ;
39+ static TIMEDELTA_CLS : GILOnceCell < Py < PyType > > = GILOnceCell :: new ( ) ;
3840
3941pub type QueryParameter = ( dyn ToSql + Sync ) ;
4042
@@ -50,6 +52,18 @@ fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
5052 . map ( |ty| ty. bind ( py) )
5153}
5254
55+ fn get_timedelta_cls ( py : Python < ' _ > ) -> PyResult < & Bound < ' _ , PyType > > {
56+ TIMEDELTA_CLS
57+ . get_or_try_init ( py, || {
58+ let type_object = py
59+ . import_bound ( "datetime" ) ?
60+ . getattr ( "timedelta" ) ?
61+ . downcast_into ( ) ?;
62+ Ok ( type_object. unbind ( ) )
63+ } )
64+ . map ( |ty| ty. bind ( py) )
65+ }
66+
5367/// Struct for Uuid.
5468///
5569/// We use custom struct because we need to implement external traits
@@ -138,13 +152,43 @@ impl<'a> FromSql<'a> for InnerDecimal {
138152 }
139153}
140154
155+ struct InnerInterval ( Interval ) ;
156+
157+ impl ToPyObject for InnerInterval {
158+ fn to_object ( & self , py : Python < ' _ > ) -> PyObject {
159+ let td_cls = get_timedelta_cls ( py) . expect ( "failed to load datetime.timedelta" ) ;
160+ let pydict = PyDict :: new_bound ( py) ;
161+ let months = self . 0 . months * 30 ;
162+ let _ = pydict. set_item ( "days" , self . 0 . days + months) ;
163+ let _ = pydict. set_item ( "microseconds" , self . 0 . microseconds ) ;
164+ let ret = td_cls
165+ . call ( ( ) , Some ( & pydict) )
166+ . expect ( "failed to call datetime.timedelta(days=<>, microseconds=<>)" ) ;
167+ ret. to_object ( py)
168+ }
169+ }
170+
171+ impl < ' a > FromSql < ' a > for InnerInterval {
172+ fn from_sql (
173+ ty : & Type ,
174+ raw : & ' a [ u8 ] ,
175+ ) -> Result < Self , Box < dyn std:: error:: Error + Sync + Send > > {
176+ Ok ( InnerInterval ( <Interval as FromSql >:: from_sql ( ty, raw) ?) )
177+ }
178+
179+ fn accepts ( _ty : & Type ) -> bool {
180+ true
181+ }
182+ }
183+
141184/// Additional type for types come from Python.
142185///
143186/// It's necessary because we need to pass this
144187/// enum into `to_sql` method of `ToSql` trait from
145188/// `postgres` crate.
146189#[ derive( Debug , Clone , PartialEq ) ]
147190pub enum PythonDTO {
191+ // Primitive
148192 PyNone ,
149193 PyBytes ( Vec < u8 > ) ,
150194 PyBool ( bool ) ,
@@ -164,6 +208,7 @@ pub enum PythonDTO {
164208 PyTime ( NaiveTime ) ,
165209 PyDateTime ( NaiveDateTime ) ,
166210 PyDateTimeTz ( DateTime < FixedOffset > ) ,
211+ PyInterval ( Interval ) ,
167212 PyIpAddress ( IpAddr ) ,
168213 PyList ( Vec < PythonDTO > ) ,
169214 PyArray ( Array < PythonDTO > ) ,
@@ -180,6 +225,7 @@ pub enum PythonDTO {
180225 PyLine ( Line ) ,
181226 PyLineSegment ( LineSegment ) ,
182227 PyCircle ( Circle ) ,
228+ // Arrays
183229 PyBoolArray ( Array < PythonDTO > ) ,
184230 PyUuidArray ( Array < PythonDTO > ) ,
185231 PyVarCharArray ( Array < PythonDTO > ) ,
@@ -206,6 +252,7 @@ pub enum PythonDTO {
206252 PyLineArray ( Array < PythonDTO > ) ,
207253 PyLsegArray ( Array < PythonDTO > ) ,
208254 PyCircleArray ( Array < PythonDTO > ) ,
255+ PyIntervalArray ( Array < PythonDTO > ) ,
209256}
210257
211258impl ToPyObject for PythonDTO {
@@ -267,6 +314,7 @@ impl PythonDTO {
267314 PythonDTO :: PyLine ( _) => Ok ( tokio_postgres:: types:: Type :: LINE_ARRAY ) ,
268315 PythonDTO :: PyLineSegment ( _) => Ok ( tokio_postgres:: types:: Type :: LSEG_ARRAY ) ,
269316 PythonDTO :: PyCircle ( _) => Ok ( tokio_postgres:: types:: Type :: CIRCLE_ARRAY ) ,
317+ PythonDTO :: PyInterval ( _) => Ok ( tokio_postgres:: types:: Type :: INTERVAL_ARRAY ) ,
270318 _ => Err ( RustPSQLDriverError :: PyToRustValueConversionError (
271319 "Can't process array type, your type doesn't have support yet" . into ( ) ,
272320 ) ) ,
@@ -385,6 +433,9 @@ impl ToSql for PythonDTO {
385433 PythonDTO :: PyDateTimeTz ( pydatetime_tz) => {
386434 <& DateTime < FixedOffset > as ToSql >:: to_sql ( & pydatetime_tz, ty, out) ?;
387435 }
436+ PythonDTO :: PyInterval ( pyinterval) => {
437+ <& Interval as ToSql >:: to_sql ( & pyinterval, ty, out) ?;
438+ }
388439 PythonDTO :: PyIpAddress ( pyidaddress) => {
389440 <& IpAddr as ToSql >:: to_sql ( & pyidaddress, ty, out) ?;
390441 }
@@ -525,6 +576,9 @@ impl ToSql for PythonDTO {
525576 PythonDTO :: PyCircleArray ( array) => {
526577 array. to_sql ( & Type :: CIRCLE_ARRAY , out) ?;
527578 }
579+ PythonDTO :: PyIntervalArray ( array) => {
580+ array. to_sql ( & Type :: INTERVAL_ARRAY , out) ?;
581+ }
528582 }
529583
530584 if return_is_null_true {
@@ -787,6 +841,16 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
787841 return Ok ( PythonDTO :: PyTime ( parameter. extract :: < NaiveTime > ( ) ?) ) ;
788842 }
789843
844+ if parameter. is_instance_of :: < PyDelta > ( ) {
845+ let duration = parameter. extract :: < chrono:: Duration > ( ) ?;
846+ if let Some ( interval) = Interval :: from_duration ( duration) {
847+ return Ok ( PythonDTO :: PyInterval ( interval) ) ;
848+ }
849+ return Err ( RustPSQLDriverError :: PyToRustValueConversionError (
850+ "Cannot convert timedelta from Python to inner Rust type." . to_string ( ) ,
851+ ) ) ;
852+ }
853+
790854 if parameter. is_instance_of :: < PyList > ( ) | parameter. is_instance_of :: < PyTuple > ( ) {
791855 return Ok ( PythonDTO :: PyArray ( py_sequence_into_postgres_array (
792856 parameter,
@@ -1052,6 +1116,12 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
10521116 . _convert_to_python_dto ( ) ;
10531117 }
10541118
1119+ if parameter. is_instance_of :: < extra_types:: IntervalArray > ( ) {
1120+ return parameter
1121+ . extract :: < extra_types:: IntervalArray > ( ) ?
1122+ . _convert_to_python_dto ( ) ;
1123+ }
1124+
10551125 if let Ok ( id_address) = parameter. extract :: < IpAddr > ( ) {
10561126 return Ok ( PythonDTO :: PyIpAddress ( id_address) ) ;
10571127 }
@@ -1065,9 +1135,6 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
10651135 }
10661136 }
10671137
1068- let a = parameter. downcast :: < PyIterator > ( ) ;
1069- println ! ( "{:?}" , a. iter( ) ) ;
1070-
10711138 Err ( RustPSQLDriverError :: PyToRustValueConversionError ( format ! (
10721139 "Can not covert you type {parameter} into inner one" ,
10731140 ) ) )
@@ -1387,6 +1454,13 @@ fn postgres_bytes_to_py(
13871454 None => Ok ( py. None ( ) . to_object ( py) ) ,
13881455 }
13891456 }
1457+ Type :: INTERVAL => {
1458+ let interval = _composite_field_postgres_to_py :: < Option < Interval > > ( type_, buf, is_simple) ?;
1459+ if let Some ( interval) = interval {
1460+ return Ok ( InnerInterval ( interval) . to_object ( py) ) ;
1461+ }
1462+ Ok ( py. None ( ) )
1463+ }
13901464 // ---------- Array Text Types ----------
13911465 Type :: BOOL_ARRAY => Ok ( postgres_array_to_py ( py, _composite_field_postgres_to_py :: < Option < Array < bool > > > (
13921466 type_, buf, is_simple,
@@ -1505,6 +1579,11 @@ fn postgres_bytes_to_py(
15051579
15061580 Ok ( postgres_array_to_py ( py, circle_array_) . to_object ( py) )
15071581 }
1582+ Type :: INTERVAL_ARRAY => {
1583+ let interval_array_ = _composite_field_postgres_to_py :: < Option < Array < InnerInterval > > > ( type_, buf, is_simple) ?;
1584+
1585+ Ok ( postgres_array_to_py ( py, interval_array_) . to_object ( py) )
1586+ }
15081587 _ => Err ( RustPSQLDriverError :: RustToPyValueConversionError (
15091588 format ! ( "Cannot convert {type_} into Python type, please look at the custom_decoders functionality." )
15101589 ) ) ,
0 commit comments