diff --git a/src/biotite/structure/io/pdbx/encoding.pyx b/src/biotite/structure/io/pdbx/encoding.py similarity index 72% rename from src/biotite/structure/io/pdbx/encoding.pyx rename to src/biotite/structure/io/pdbx/encoding.py index 2b1187b14..44da196c4 100644 --- a/src/biotite/structure/io/pdbx/encoding.pyx +++ b/src/biotite/structure/io/pdbx/encoding.py @@ -8,54 +8,31 @@ __name__ = "biotite.structure.io.pdbx" __author__ = "Patrick Kunzmann" -__all__ = ["ByteArrayEncoding", "FixedPointEncoding", - "IntervalQuantizationEncoding", "RunLengthEncoding", - "DeltaEncoding", "IntegerPackingEncoding", "StringArrayEncoding", - "TypeCode"] +__all__ = [ + "ByteArrayEncoding", + "FixedPointEncoding", + "IntervalQuantizationEncoding", + "RunLengthEncoding", + "DeltaEncoding", + "IntegerPackingEncoding", + "StringArrayEncoding", + "TypeCode", +] -cimport cython -cimport numpy as np - -from dataclasses import dataclass +import re from abc import ABCMeta, abstractmethod -from numbers import Integral +from dataclasses import dataclass from enum import IntEnum -import re +from numbers import Integral import numpy as np -from .component import _Component -from ....file import InvalidFileError - -ctypedef np.int8_t int8 -ctypedef np.int16_t int16 -ctypedef np.int32_t int32 -ctypedef np.uint8_t uint8 -ctypedef np.uint16_t uint16 -ctypedef np.uint32_t uint32 -ctypedef np.float32_t float32 -ctypedef np.float64_t float64 - -ctypedef fused Integer: - uint8 - uint16 - uint32 - int8 - int16 - int32 - -# Used to create cartesian product of type combinations -# in run-length encoding -ctypedef fused OutputInteger: - uint8 - uint16 - uint32 - int8 - int16 - int32 - -ctypedef fused Float: - float32 - float64 - +from biotite.file import InvalidFileError +from biotite.rust.structure.io.pdbx import ( + integer_packing_decode, + integer_packing_encode, + run_length_decode, + run_length_encode, +) +from biotite.structure.io.pdbx.component import _Component CAMEL_CASE_PATTERN = re.compile(r"(?>> print(ByteArrayEncoding().encode(data)) b'\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00' """ + type: ... = None def __post_init__(self): @@ -289,7 +248,7 @@ class FixedPointEncoding(Encoding): Lossy encoding that multiplies floating point values with a given factor and subsequently rounds them to the nearest integer. - Parameters + Attributes ---------- factor : float The factor by which the data is multiplied before rounding. @@ -300,11 +259,6 @@ class FixedPointEncoding(Encoding): If omitted, the data type is taken from the data the first time :meth:`encode()` is called. - Attributes - ---------- - factor : float - src_type : TypeCode - Examples -------- @@ -314,6 +268,7 @@ class FixedPointEncoding(Encoding): >>> print(FixedPointEncoding(factor=100).encode(data)) [987 654] """ + factor: ... src_type: ... = None @@ -321,36 +276,30 @@ def __post_init__(self): if self.src_type is not None: self.src_type = TypeCode.from_dtype(self.src_type) if self.src_type not in (TypeCode.FLOAT32, TypeCode.FLOAT64): - raise ValueError( - "Only floating point types are supported" - ) + raise ValueError("Only floating point types are supported") def encode(self, data): # If not given in constructor, it is determined from the data if self.src_type is None: self.src_type = TypeCode.from_dtype(data.dtype) if self.src_type not in (TypeCode.FLOAT32, TypeCode.FLOAT64): - raise ValueError( - "Only floating point types are supported" - ) + raise ValueError("Only floating point types are supported") # Round to avoid wrong values due to floating point inaccuracies scaled_data = np.round(data * self.factor) return _safe_cast(scaled_data, np.int32, allow_decimal_loss=True) def decode(self, data): - return (data / self.factor).astype( - dtype=self.src_type.to_dtype(), copy=False - ) + return (data / self.factor).astype(dtype=self.src_type.to_dtype(), copy=False) @dataclass class IntervalQuantizationEncoding(Encoding): """ Lossy encoding that sorts floating point values into bins. - Each bin is represented by an integer + Each bin is represented by an integer. - Parameters + Attributes ---------- min, max : float The minimum and maximum value the bins comprise. @@ -363,12 +312,6 @@ class IntervalQuantizationEncoding(Encoding): If omitted, the data type is taken from the data the first time :meth:`encode()` is called. - Attributes - ---------- - min, max : float - num_steps : int - src_type : TypeCode - Examples -------- @@ -385,6 +328,7 @@ class IntervalQuantizationEncoding(Encoding): >>> print(decoded) [11.0 11.5 11.5 12.0 12.0 12.0] """ + min: ... max: ... num_steps: ... @@ -399,9 +343,7 @@ def encode(self, data): if self.src_type is None: self.src_type = TypeCode.from_dtype(data.dtype) - steps = np.linspace( - self.min, self.max, self.num_steps, dtype=data.dtype - ) + steps = np.linspace(self.min, self.max, self.num_steps, dtype=data.dtype) indices = np.searchsorted(steps, data, side="left") return _safe_cast(indices, np.int32) @@ -418,7 +360,7 @@ class RunLengthEncoding(Encoding): Encoding that compresses runs of equal values into pairs of (value, run length). - Parameters + Attributes ---------- src_size : int, optional The size of the array to be encoded. @@ -431,11 +373,6 @@ class RunLengthEncoding(Encoding): If omitted, the data type is taken from the data the first time :meth:`encode()` is called. - Attributes - ---------- - src_size : int - src_type : TypeCode - Examples -------- @@ -451,6 +388,7 @@ class RunLengthEncoding(Encoding): [5 1] [3 2]] """ + src_size: ... = None src_type: ... = None @@ -465,72 +403,17 @@ def encode(self, data): if self.src_size is None: self.src_size = data.shape[0] elif self.src_size != data.shape[0]: - raise IndexError( - "Given source size does not match actual data size" - ) - return self._encode(_safe_cast(data, self.src_type.to_dtype())) + raise IndexError("Given source size does not match actual data size") + return np.asarray(run_length_encode(_safe_cast(data, self.src_type.to_dtype()))) def decode(self, data): - return self._decode( - data, np.empty(0, dtype=self.src_type.to_dtype()) - ) - - def _encode(self, const Integer[:] data): - # Pessimistic allocation of output array - # -> Run length is 1 for every element - cdef int32[:] output = np.zeros(data.shape[0] * 2, dtype=np.int32) - cdef int i=0, j=0 - cdef int val = data[0] - cdef int run_length = 0 - cdef int curr_val - for i in range(data.shape[0]): - curr_val = data[i] - if curr_val == val: - run_length += 1 - else: - # New element -> Write element with run-length - output[j] = val - output[j+1] = run_length - j += 2 - val = curr_val - run_length = 1 - # Write last element - output[j] = val - output[j+1] = run_length - j += 2 - # Trim to correct size - return np.asarray(output)[:j] - - def _decode(self, const Integer[:] data, OutputInteger[:] output_type): - """ - `output_type` is merely a typed placeholder to allow for static - typing of output. - """ - if data.shape[0] % 2 != 0: - raise ValueError("Invalid run-length encoded data") - - cdef int length = 0 - cdef int i, j - cdef int value, repeat - - if self.src_size is None: - # Determine length of output array by summing run lengths - for i in range(1, data.shape[0], 2): - length += data[i] - else: - length = self.src_size - - cdef OutputInteger[:] output = np.zeros( - length, dtype=np.asarray(output_type).dtype + return np.asarray( + run_length_decode( + data.astype(np.int32, copy=False), + self.src_size, + np.dtype(self.src_type.to_dtype()), + ) ) - # Fill output array - j = 0 - for i in range(0, data.shape[0], 2): - value = data[i] - repeat = data[i+1] - output[j : j+repeat] = value - j += repeat - return np.asarray(output) @dataclass @@ -539,7 +422,7 @@ class DeltaEncoding(Encoding): Encoding that encodes an array of integers into an array of consecutive differences. - Parameters + Attributes ---------- src_type : dtype or TypeCode, optional The data type of the array to be encoded. @@ -552,11 +435,6 @@ class DeltaEncoding(Encoding): If omitted, the value is taken from the first array element the first time :meth:`encode()` is called. - Attributes - ---------- - src_type : TypeCode - origin : int - Examples -------- @@ -567,6 +445,7 @@ class DeltaEncoding(Encoding): >>> print(encoding.origin) 1 """ + src_type: ... = None origin: ... = None @@ -606,7 +485,7 @@ class IntegerPackingEncoding(Encoding): the integer is represented by a sum of consecutive elements in the compressed array. - Parameters + Attributes ---------- byte_count : int The number of bytes the packed integers should occupy. @@ -622,12 +501,6 @@ class IntegerPackingEncoding(Encoding): If omitted, first time :meth:`encode()` is called, determines whether the values fit into unsigned integers. - Attributes - ---------- - byte_count : int - src_size : int - is_unsigned : bool - Examples -------- @@ -637,6 +510,7 @@ class IntegerPackingEncoding(Encoding): >>> print(IntegerPackingEncoding(byte_count=1).encode(data)) [ 1 2 -3 127 1] """ + byte_count: ... src_size: ... = None is_unsigned: ... = None @@ -645,44 +519,31 @@ def encode(self, data): if self.src_size is None: self.src_size = len(data) elif self.src_size != len(data): - raise IndexError( - "Given source size does not match actual data size" - ) + raise IndexError("Given source size does not match actual data size") if self.is_unsigned is None: # Only positive values -> use unsigned integers self.is_unsigned = data.min().item() >= 0 data = _safe_cast(data, np.int32) - return self._encode( - data, np.empty(0, dtype=self._determine_packed_dtype()) + packed_dtype = np.dtype(self._determine_packed_dtype()) + return np.asarray( + integer_packing_encode( + data, + self.byte_count, + self.is_unsigned, + packed_dtype, + ) ) - def decode(self, const Integer[:] data): - cdef int i, j - cdef int min_val, max_val - cdef int packed_val, unpacked_val - bounds = self._get_bounds(data) - min_val = bounds[0] - max_val = bounds[1] - # For signed integers, do not check lower bound (is always 0) - # -> Set lower bound to value that is never reached - if min_val == 0: - min_val = -1 - - cdef int32[:] output = np.zeros(self.src_size, dtype=np.int32) - j = 0 - unpacked_val = 0 - for i in range(data.shape[0]): - packed_val = data[i] - if packed_val == max_val or packed_val == min_val: - unpacked_val += packed_val - else: - unpacked_val += packed_val - output[j] = unpacked_val - unpacked_val = 0 - j += 1 - # Trim to correct size and return - return np.asarray(output) + def decode(self, data): + return np.asarray( + integer_packing_decode( + data, + self.byte_count, + self.is_unsigned, + self.src_size, + ) + ) def _determine_packed_dtype(self): if self.byte_count == 1: @@ -698,80 +559,6 @@ def _determine_packed_dtype(self): else: raise ValueError("Unsupported byte count") - @cython.cdivision(True) - def _encode(self, const Integer[:] data, OutputInteger[:] output_type): - """ - `output_type` is merely a typed placeholder to allow for static - typing of output. - """ - cdef int i=0, j=0 - - packed_type = np.asarray(output_type).dtype - cdef int min_val = np.iinfo(packed_type).min - cdef int max_val = np.iinfo(packed_type).max - - # Get length of output array - # by summing up required length of each element - cdef int number - cdef long length = 0 - for i in range(data.shape[0]): - number = data[i] - if number < 0: - if min_val == 0: - raise ValueError( - "Cannot pack negative numbers into unsigned type" - ) - # The required packed length for an element is the - # number of times min_val/max_val need to be repeated - length += number // min_val + 1 - elif number > 0: - length += number // max_val + 1 - else: - # number = 0 - length += 1 - - # Fill output - cdef OutputInteger[:] output = np.zeros(length, dtype=packed_type) - cdef int remainder - j = 0 - for i in range(data.shape[0]): - remainder = data[i] - if remainder < 0: - if min_val == 0: - raise ValueError( - "Cannot pack negative numbers into unsigned type" - ) - while remainder <= min_val: - remainder -= min_val - output[j] = min_val - j += 1 - elif remainder > 0: - while remainder >= max_val: - remainder -= max_val - output[j] = max_val - j += 1 - output[j] = remainder - j += 1 - return np.asarray(output) - - @staticmethod - def _get_bounds(const Integer[:] data): - if Integer is int8: - info = np.iinfo(np.int8) - elif Integer is int16: - info = np.iinfo(np.int16) - elif Integer is int32: - info = np.iinfo(np.int32) - elif Integer is uint8: - info = np.iinfo(np.uint8) - elif Integer is uint16: - info = np.iinfo(np.uint16) - elif Integer is uint32: - info = np.iinfo(np.uint32) - else: - raise ValueError("Unsupported integer type") - return info.min, info.max - @dataclass class StringArrayEncoding(Encoding): @@ -836,23 +623,20 @@ def __init__(self, strings=None, data_encoding=None, offset_encoding=None): @staticmethod def deserialize(content): - data_encoding = [ - deserialize_encoding(e) for e in content["dataEncoding"] - ] - offset_encoding = [ - deserialize_encoding(e) for e in content["offsetEncoding"] - ] - cdef str concatenated_strings = content["stringData"] - cdef np.ndarray offsets = decode_stepwise( - content["offsets"], offset_encoding + data_encoding = [deserialize_encoding(e) for e in content["dataEncoding"]] + offset_encoding = [deserialize_encoding(e) for e in content["offsetEncoding"]] + concatenated_strings = content["stringData"] + offsets = decode_stepwise(content["offsets"], offset_encoding) + + strings = np.array( + [ + concatenated_strings[offsets[i] : offsets[i + 1]] + # The final offset is the exclusive stop index + for i in range(len(offsets) - 1) + ], + dtype="U", ) - strings = np.array([ - concatenated_strings[offsets[i]:offsets[i+1]] - # The final offset is the exclusive stop index - for i in range(len(offsets)-1) - ], dtype="U") - return StringArrayEncoding(strings, data_encoding, offset_encoding) def serialize(self): @@ -970,9 +754,7 @@ def deserialize_encoding(content): try: encoding_class = _encoding_classes[content["kind"]] except KeyError: - raise ValueError( - f"Unknown encoding kind '{content['kind']}'" - ) + raise ValueError(f"Unknown encoding kind '{content['kind']}'") return encoding_class.deserialize(content) @@ -1048,9 +830,7 @@ def _camel_to_snake_case(attribute_name): def _snake_to_camel_case(attribute_name): - attribute_name = "".join( - word.capitalize() for word in attribute_name.split("_") - ) + attribute_name = "".join(word.capitalize() for word in attribute_name.split("_")) return attribute_name[0].lower() + attribute_name[1:] @@ -1075,4 +855,4 @@ def _safe_cast(array, dtype, allow_decimal_loss=False): if np.max(array) > dtype_info.max or np.min(array) < dtype_info.min: raise ValueError("Values do not fit into the given dtype") - return array.astype(target_dtype) \ No newline at end of file + return array.astype(target_dtype) diff --git a/src/rust/structure/io/mod.rs b/src/rust/structure/io/mod.rs index 91c7372df..d553ab1de 100644 --- a/src/rust/structure/io/mod.rs +++ b/src/rust/structure/io/mod.rs @@ -2,6 +2,7 @@ use crate::add_subpackage; use pyo3::prelude::*; pub mod pdb; +pub mod pdbx; pub fn module<'py>(parent_module: &Bound<'py, PyModule>) -> PyResult> { let module = PyModule::new(parent_module.py(), "io")?; @@ -10,5 +11,10 @@ pub fn module<'py>(parent_module: &Bound<'py, PyModule>) -> PyResult>(val: i32) -> PyResult { + T::try_from(val).map_err(|_| { + exceptions::PyOverflowError::new_err(format!("Value {} does not fit into target type", val)) + }) +} + +/// Convert a `T` to `i32`, returning a `PyErr` on overflow. +#[inline(always)] +fn to_i32>(val: T) -> PyResult { + val.try_into() + .map_err(|_| exceptions::PyOverflowError::new_err("Value does not fit into i32")) +} + +/// Convenience wrapper around [`dispatch_dtype!`] for the six BinaryCIF integer types. +macro_rules! dispatch_integer_dtype { + ($py:expr, $dtype:expr, $func:ident $args:tt) => { + dispatch_dtype!($py, $dtype, [i8, i16, i32, u8, u16, u32], $func $args) + }; +} + +/// Encode an integer array using run-length encoding. +/// +/// The input array can be any integer dtype (i8, i16, i32, u8, u16, u32). +/// The output is always int32 (value/run-length pairs). +#[pyfunction] +pub fn run_length_encode<'py>( + py: Python<'py>, + data: &Bound<'py, numpy::PyUntypedArray>, +) -> PyResult> { + let dt = data.dtype(); + dispatch_integer_dtype!(py, &dt, run_length_encode_inner(py, data)) +} + +fn run_length_encode_inner<'py, T: numpy::Element + Copy + TryInto>( + py: Python<'py>, + data: &Bound<'py, numpy::PyUntypedArray>, +) -> PyResult>> { + let typed = data.cast::>()?; + let data = unsafe { typed.as_slice()? }; + + if data.is_empty() { + return Ok(Vec::::new().into_pyarray(py)); + } + + let mut output: Vec = Vec::with_capacity(data.len() * 2); + let mut val = data[0]; + let mut run_length: i32 = 1; + for &curr_val in &data[1..] { + if to_i32(curr_val)? == to_i32(val)? { + run_length += 1; + } else { + output.push(to_i32(val)?); + output.push(run_length); + val = curr_val; + run_length = 1; + } + } + output.push(to_i32(val)?); + output.push(run_length); + Ok(output.into_pyarray(py)) +} + +/// Decode a run-length encoded int32 array into the given output dtype. +#[pyfunction] +pub fn run_length_decode<'py>( + py: Python<'py>, + data: PyReadonlyArray1<'py, i32>, + src_size: Option, + out_dtype: &Bound<'py, PyArrayDescr>, +) -> PyResult> { + let data = data.as_slice()?; + if data.len() % 2 != 0 { + return Err(exceptions::PyValueError::new_err( + "Invalid run-length encoded data", + )); + } + + let length = match src_size { + Some(s) => s, + None => (1..data.len()).step_by(2).map(|i| data[i] as usize).sum(), + }; + + dispatch_integer_dtype!(py, out_dtype, run_length_decode_inner(py, data, length)) +} + +fn run_length_decode_inner<'py, T: numpy::Element + Copy + Default + TryFrom>( + py: Python<'py>, + data: &[i32], + length: usize, +) -> PyResult>> { + let mut output: Vec = vec![T::default(); length]; + let mut j: usize = 0; + for i in (0..data.len()).step_by(2) { + let value = from_i32::(data[i])?; + let repeat = data[i + 1] as usize; + let end = j + repeat; + if end > length { + return Err(exceptions::PyValueError::new_err( + "Run-length data exceeds expected output size", + )); + } + for slot in &mut output[j..end] { + *slot = value; + } + j = end; + } + Ok(output.into_pyarray(py)) +} + +/// Encode an int32 array using integer packing into the given packed dtype. +#[pyfunction] +pub fn integer_packing_encode<'py>( + py: Python<'py>, + data: PyReadonlyArray1<'py, i32>, + byte_count: usize, + is_unsigned: bool, + out_dtype: &Bound<'py, PyArrayDescr>, +) -> PyResult> { + let data = data.as_slice()?; + let (min_val, max_val) = get_bounds(byte_count, is_unsigned)?; + + dispatch_integer_dtype!( + py, + out_dtype, + integer_packing_encode_inner(py, data, min_val as i32, max_val as i32) + ) +} + +fn integer_packing_encode_inner<'py, T: numpy::Element + Default + Clone + TryFrom>( + py: Python<'py>, + data: &[i32], + min_val: i32, + max_val: i32, +) -> PyResult>> { + // Compute output length + let mut length: usize = 0; + for &number in data { + match number.cmp(&0) { + std::cmp::Ordering::Less => { + if min_val == 0 { + return Err(exceptions::PyValueError::new_err( + "Cannot pack negative numbers into unsigned type", + )); + } + length += (number / min_val) as usize + 1; + } + std::cmp::Ordering::Greater => { + length += (number / max_val) as usize + 1; + } + std::cmp::Ordering::Equal => { + length += 1; + } + } + } + + let mut output: Vec = vec![T::default(); length]; + let mut j: usize = 0; + for &number in data { + let mut remainder = number; + match remainder.cmp(&0) { + std::cmp::Ordering::Less => { + while remainder <= min_val { + remainder -= min_val; + output[j] = from_i32::(min_val)?; + j += 1; + } + } + std::cmp::Ordering::Greater => { + while remainder >= max_val { + remainder -= max_val; + output[j] = from_i32::(max_val)?; + j += 1; + } + } + std::cmp::Ordering::Equal => {} + } + output[j] = from_i32::(remainder)?; + j += 1; + } + Ok(output.into_pyarray(py)) +} + +/// Decode an integer-packed array back to int32. +/// +/// Accepts any integer-typed packed array. +#[pyfunction] +pub fn integer_packing_decode<'py>( + py: Python<'py>, + data: &Bound<'py, numpy::PyUntypedArray>, + byte_count: usize, + is_unsigned: bool, + src_size: usize, +) -> PyResult> { + let dt = data.dtype(); + dispatch_integer_dtype!( + py, + &dt, + integer_packing_decode_inner(py, data, byte_count, is_unsigned, src_size) + ) +} + +fn integer_packing_decode_inner<'py, T: numpy::Element + Copy + TryInto>( + py: Python<'py>, + data: &Bound<'py, numpy::PyUntypedArray>, + byte_count: usize, + is_unsigned: bool, + src_size: usize, +) -> PyResult>> { + let typed = data.cast::>()?; + let data = unsafe { typed.as_slice()? }; + + let (min_val, max_val) = get_bounds(byte_count, is_unsigned)?; + // For unsigned integers, do not check lower bound + let effective_min = if min_val == 0 { -1 } else { min_val }; + + let mut output: Vec = vec![0; src_size]; + let mut j: usize = 0; + let mut unpacked_val: i32 = 0; + for &packed in data { + let packed_val = to_i32(packed)?; + if packed_val == max_val as i32 || packed_val == effective_min as i32 { + unpacked_val += packed_val; + } else { + unpacked_val += packed_val; + if j >= src_size { + return Err(exceptions::PyValueError::new_err( + "Decoded data exceeds expected output size", + )); + } + output[j] = unpacked_val; + unpacked_val = 0; + j += 1; + } + } + Ok(output.into_pyarray(py)) +} + +/// Return the `(min, max)` bounds for a packed integer type given its byte count and signedness. +fn get_bounds(byte_count: usize, is_unsigned: bool) -> PyResult<(isize, isize)> { + match (byte_count, is_unsigned) { + (1, false) => Ok((i8::MIN as isize, i8::MAX as isize)), + (1, true) => Ok((0, u8::MAX as isize)), + (2, false) => Ok((i16::MIN as isize, i16::MAX as isize)), + (2, true) => Ok((0, u16::MAX as isize)), + _ => Err(exceptions::PyValueError::new_err("Unsupported byte count")), + } +} diff --git a/src/rust/structure/io/pdbx/mod.rs b/src/rust/structure/io/pdbx/mod.rs new file mode 100644 index 000000000..66ca9603e --- /dev/null +++ b/src/rust/structure/io/pdbx/mod.rs @@ -0,0 +1,12 @@ +use pyo3::prelude::*; + +pub mod encoding; + +pub fn module<'py>(parent_module: &Bound<'py, PyModule>) -> PyResult> { + let module = PyModule::new(parent_module.py(), "pdbx")?; + module.add_function(wrap_pyfunction!(encoding::run_length_encode, &module)?)?; + module.add_function(wrap_pyfunction!(encoding::run_length_decode, &module)?)?; + module.add_function(wrap_pyfunction!(encoding::integer_packing_encode, &module)?)?; + module.add_function(wrap_pyfunction!(encoding::integer_packing_decode, &module)?)?; + Ok(module) +} diff --git a/src/rust/util.rs b/src/rust/util.rs index 26953f2ff..b558aaef7 100644 --- a/src/rust/util.rs +++ b/src/rust/util.rs @@ -11,3 +11,85 @@ pub fn check_signals_periodically(py: Python<'_>, iteration: usize) -> PyResult< } Ok(()) } + +/// Dispatch a generic function call based on a NumPy dtype descriptor. +/// +/// Accepts a custom type list. +/// Arguments are captured as a single token tree (the parenthesized group), +/// which avoids the `$T repeats N times but $arg repeats M times` error. +/// The inner function's return value is erased to `Bound<'py, PyAny>` via +/// `.into_any()` so that all branches have a uniform type. +/// +/// # Usage +/// +/// ```ignore +/// dispatch_dtype!(py, dtype, [i8, i16, i32, f32, f64], inner_fn(py, data)) +/// ``` +#[macro_export] +macro_rules! dispatch_dtype { + ($py:expr, $dtype:expr, [$($T:ty),+], $func:ident $args:tt) => {{ + if false { unreachable!() } + $( + else if $dtype.is_equiv_to(&numpy::dtype::<$T>($py)) { + $func::<$T> $args .map(|v| v.into_any()) + } + )+ + else { + Err(pyo3::exceptions::PyTypeError::new_err(format!( + "Unsupported dtype: {}", $dtype + ))) + } + }}; +} + +/// Dispatch a generic function call over the cartesian product of two dtype +/// descriptors. +/// +/// The inner function is called as `$func::(args)` for the matching +/// pair of types. Uses an internal `@inner` arm to avoid nesting two +/// repetition levels of independently captured type lists. +/// +/// # Usage +/// +/// ```ignore +/// dispatch_dtypes!( +/// py, +/// in_dtype, [i8, i16, i32], +/// out_dtype, [u8, u16, u32], +/// inner_fn(py, data, length) +/// ) +/// ``` +#[macro_export] +macro_rules! dispatch_dtypes { + ( + $py:expr, + $dtype1:expr, [$($T1:ty),+], + $dtype2:expr, [$($T2:ty),+], + $func:ident $args:tt + ) => {{ + if false { unreachable!() } + $( + else if $dtype1.is_equiv_to(&numpy::dtype::<$T1>($py)) { + dispatch_dtypes!(@inner $py, $dtype2, [$($T2),+], $func, [$T1] $args) + } + )+ + else { + Err(pyo3::exceptions::PyTypeError::new_err(format!( + "Unsupported dtype: {}", $dtype1 + ))) + } + }}; + (@inner $py:expr, $dtype2:expr, [$($T2:ty),+], $func:ident, [$T1:ty] $args:tt) => {{ + if false { unreachable!() } + $( + else if $dtype2.is_equiv_to(&numpy::dtype::<$T2>($py)) { + $func::<$T1, $T2> $args .map(|v| v.into_any()) + } + )+ + else { + Err(pyo3::exceptions::PyTypeError::new_err(format!( + "Unsupported dtype: {}", $dtype2 + ))) + } + }}; +}