Skip to content

Commit 72371f6

Browse files
committed
add typcasting test
1 parent 7a79af4 commit 72371f6

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

doubleml/data/tests/test_dml_data.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from decimal import Decimal
2+
13
import numpy as np
24
import pandas as pd
35
import pytest
@@ -661,3 +663,29 @@ def test_property_setter_rollback_on_validation_failure():
661663
dml_data.z_cols = ["y"]
662664
# Object should remain unchanged
663665
assert dml_data.z_cols == original_z_cols
666+
667+
668+
@pytest.mark.ci
669+
def test_dml_data_decimal_to_float_conversion():
670+
"""Test that Decimal type columns are converted to float for y and d."""
671+
n_obs = 100
672+
data = {
673+
"y": [Decimal(i * 0.1) for i in range(n_obs)],
674+
"d": [Decimal(i * 0.05) for i in range(n_obs)],
675+
"X": [Decimal(i) for i in range(n_obs)],
676+
}
677+
df = pd.DataFrame(data)
678+
679+
dml_data = DoubleMLData(df, y_col="y", d_cols="d", x_cols="X")
680+
681+
assert dml_data.y.dtype == np.float64, f"Expected y to be float64, got {dml_data.y.dtype}"
682+
assert dml_data.d.dtype == np.float64, f"Expected d to be float64, got {dml_data.d.dtype}"
683+
assert dml_data.x.dtype == np.float64, f"Expected x to be float64, got {dml_data.x.dtype}"
684+
685+
expected_y = np.array([float(Decimal(i * 0.1)) for i in range(n_obs)])
686+
expected_d = np.array([float(Decimal(i * 0.05)) for i in range(n_obs)])
687+
expected_x = np.array([float(Decimal(i)) for i in range(n_obs)]).reshape(-1, 1)
688+
689+
np.testing.assert_array_almost_equal(dml_data.y, expected_y)
690+
np.testing.assert_array_almost_equal(dml_data.d, expected_d)
691+
np.testing.assert_array_almost_equal(dml_data.x, expected_x)

0 commit comments

Comments
 (0)