diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index ee2a788a7..75e9a7023 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -4,8 +4,9 @@ use pyo3::exceptions::PyTypeError; use pyo3::intern; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; -use pyo3::types::PyComplex; -use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple}; +use pyo3::types::{ + PyByteArray, PyBytes, PyComplex, PyDict, PyFrozenSet, PyIterator, PyList, PyModule, PySet, PyString, PyTuple, +}; use pyo3::IntoPyObjectExt; use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer}; @@ -241,6 +242,10 @@ pub(crate) fn infer_to_python_known( let complex_str = type_serializers::complex::complex_to_str(v); complex_str.into_py_any(py)? } + ObType::Module => { + let v = value.downcast::()?; + v.name()?.into() + } ObType::Path => value.str()?.into_py_any(py)?, ObType::Pattern => value.getattr(intern!(py, "pattern"))?.unbind(), ObType::Unknown => { @@ -554,6 +559,11 @@ pub(crate) fn infer_serialize_known( .map_err(py_err_se_err)?; serializer.serialize_str(&s) } + ObType::Module => { + let v = value.downcast::().map_err(py_err_se_err)?; + let s: PyBackedStr = v.name().and_then(|name| name.extract()).map_err(py_err_se_err)?; + serializer.serialize_str(&s) + } ObType::Unknown => { if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,)).map_err(py_err_se_err)?; @@ -678,6 +688,10 @@ pub(crate) fn infer_json_key_known<'a>( let v = key.downcast::()?; Ok(type_serializers::complex::complex_to_str(v).into()) } + ObType::Module => { + let v = key.downcast::()?; + Ok(Cow::Owned(v.name()?.to_string_lossy().into_owned())) + } ObType::Pattern => Ok(Cow::Owned( key.getattr(intern!(key.py(), "pattern"))? .str()? diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index f1c161dfb..50bba51be 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -2,7 +2,7 @@ use pyo3::prelude::*; use pyo3::sync::PyOnceLock; use pyo3::types::{ PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt, - PyIterator, PyList, PyNone, PySet, PyString, PyTime, PyTuple, PyType, + PyIterator, PyList, PyModule, PyNone, PySet, PyString, PyTime, PyTuple, PyType, }; use pyo3::{intern, PyTypeInfo}; @@ -48,6 +48,7 @@ pub struct ObTypeLookup { pattern_object: Py, // uuid type uuid_object: Py, + module_object: usize, complex: usize, } @@ -87,6 +88,7 @@ impl ObTypeLookup { path_object: py.import("pathlib").unwrap().getattr("Path").unwrap().unbind(), pattern_object: py.import("re").unwrap().getattr("Pattern").unwrap().unbind(), uuid_object: py.import("uuid").unwrap().getattr("UUID").unwrap().unbind(), + module_object: PyModule::type_object_raw(py) as usize, complex: PyComplex::type_object_raw(py) as usize, } } @@ -157,8 +159,9 @@ impl ObTypeLookup { ObType::Path => self.path_object.as_ptr() as usize == ob_type, ObType::Pattern => self.path_object.as_ptr() as usize == ob_type, ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type, - ObType::Unknown => false, ObType::Complex => self.complex == ob_type, + ObType::Module => self.module_object == ob_type, + ObType::Unknown => false, }; if ans { @@ -241,6 +244,8 @@ impl ObTypeLookup { ObType::Complex } else if ob_type == self.uuid_object.as_ptr() as usize { ObType::Uuid + } else if ob_type == self.module_object { + ObType::Module } else if is_pydantic_serializable(op_value) { ObType::PydanticSerializable } else if is_dataclass(op_value) { @@ -414,9 +419,10 @@ pub enum ObType { Pattern, // Uuid Uuid, + Complex, + Module, // unknown type Unknown, - Complex, } impl PartialEq for ObType { diff --git a/tests/serializers/test_infer.py b/tests/serializers/test_infer.py index 7762f68bc..c1fd79836 100644 --- a/tests/serializers/test_infer.py +++ b/tests/serializers/test_infer.py @@ -1,10 +1,11 @@ +import os from enum import Enum from pydantic_core import SchemaSerializer, core_schema # serializing enum calls methods in serializers::infer -def test_infer_to_python(): +def test_infer_complex_to_python(): class MyEnum(Enum): complex_ = complex(1, 2) @@ -12,7 +13,7 @@ class MyEnum(Enum): assert v.to_python(MyEnum.complex_, mode='json') == '1+2j' -def test_infer_serialize(): +def test_infer_complex_serialize(): class MyEnum(Enum): complex_ = complex(1, 2) @@ -20,9 +21,26 @@ class MyEnum(Enum): assert v.to_json(MyEnum.complex_) == b'"1+2j"' -def test_infer_json_key(): +def test_infer_complex_json_key(): class MyEnum(Enum): complex_ = {complex(1, 2): 1} v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) assert v.to_json(MyEnum.complex_) == b'{"1+2j":1}' + + +def test_infer_module_type(): + v = SchemaSerializer(core_schema.any_schema()) + assert v.to_python(os) is os + assert v.to_json(os).decode('utf-8') == '"os"' + assert v.to_python(os, serialize_as_any=True) is os + assert v.to_json(os, serialize_as_any=True).decode('utf-8') == '"os"' + + v_as_key = SchemaSerializer( + core_schema.dict_schema(keys_schema=core_schema.any_schema(), values_schema=core_schema.any_schema()) + ) + + assert v_as_key.to_python({os: 1}) == {os: 1} + assert v_as_key.to_json({os: 1}).decode('utf-8') == '{"os":1}' + assert v_as_key.to_python({os: 1}, serialize_as_any=True) == {os: 1} + assert v_as_key.to_json({os: 1}, serialize_as_any=True).decode('utf-8') == '{"os":1}'