|
| 1 | +from decimal import Decimal |
| 2 | + |
1 | 3 | import numpy as np |
2 | 4 | import pandas as pd |
3 | 5 | import pytest |
@@ -661,3 +663,29 @@ def test_property_setter_rollback_on_validation_failure(): |
661 | 663 | dml_data.z_cols = ["y"] |
662 | 664 | # Object should remain unchanged |
663 | 665 | assert dml_data.z_cols == original_z_cols |
| 666 | + |
| 667 | + |
| 668 | +@pytest.mark.ci |
| 669 | +def test_dml_data_decimal_to_float_conversion(): |
| 670 | + """Test that Decimal type columns are converted to float for y and d.""" |
| 671 | + n_obs = 100 |
| 672 | + data = { |
| 673 | + "y": [Decimal(i * 0.1) for i in range(n_obs)], |
| 674 | + "d": [Decimal(i * 0.05) for i in range(n_obs)], |
| 675 | + "X": [Decimal(i) for i in range(n_obs)], |
| 676 | + } |
| 677 | + df = pd.DataFrame(data) |
| 678 | + |
| 679 | + dml_data = DoubleMLData(df, y_col="y", d_cols="d", x_cols="X") |
| 680 | + |
| 681 | + assert dml_data.y.dtype == np.float64, f"Expected y to be float64, got {dml_data.y.dtype}" |
| 682 | + assert dml_data.d.dtype == np.float64, f"Expected d to be float64, got {dml_data.d.dtype}" |
| 683 | + assert dml_data.x.dtype == np.float64, f"Expected x to be float64, got {dml_data.x.dtype}" |
| 684 | + |
| 685 | + expected_y = np.array([float(Decimal(i * 0.1)) for i in range(n_obs)]) |
| 686 | + expected_d = np.array([float(Decimal(i * 0.05)) for i in range(n_obs)]) |
| 687 | + expected_x = np.array([float(Decimal(i)) for i in range(n_obs)]).reshape(-1, 1) |
| 688 | + |
| 689 | + np.testing.assert_array_almost_equal(dml_data.y, expected_y) |
| 690 | + np.testing.assert_array_almost_equal(dml_data.d, expected_d) |
| 691 | + np.testing.assert_array_almost_equal(dml_data.x, expected_x) |
0 commit comments