|
| 1 | +from decimal import Decimal |
| 2 | + |
1 | 3 | import numpy as np |
2 | 4 | import pandas as pd |
3 | 5 | import pytest |
@@ -661,3 +663,32 @@ def test_property_setter_rollback_on_validation_failure(): |
661 | 663 | dml_data.z_cols = ["y"] |
662 | 664 | # Object should remain unchanged |
663 | 665 | assert dml_data.z_cols == original_z_cols |
| 666 | + |
| 667 | + |
@pytest.mark.ci
def test_dml_data_decimal_to_float_conversion():
    """Test that Decimal-valued y, d and z columns are exposed as float64 arrays.

    Every column of the input frame holds ``decimal.Decimal`` objects. The test
    expects ``DoubleMLData`` to hand back ``y``, ``d`` and ``z`` as float64 with
    the same numeric values, while ``x`` is expected to stay an object array.
    """
    n_obs = 100

    # Build each Decimal column once so the input frame and the expected
    # arrays below are derived from exactly the same values.
    y_decimals = [Decimal(i * 0.1) for i in range(n_obs)]
    d_decimals = [Decimal(i * 0.05) for i in range(n_obs)]
    x_decimals = [Decimal(i) for i in range(n_obs)]
    z_decimals = [Decimal(i * 2) for i in range(n_obs)]

    df = pd.DataFrame(
        {
            "y": y_decimals,
            "d": d_decimals,
            "x": x_decimals,
            "z": z_decimals,
        }
    )

    dml_data = DoubleMLData(df, y_col="y", d_cols="d", x_cols="x", z_cols="z")

    assert dml_data.y.dtype == np.float64, f"Expected y to be float64, got {dml_data.y.dtype}"
    assert dml_data.d.dtype == np.float64, f"Expected d to be float64, got {dml_data.d.dtype}"
    assert dml_data.z.dtype == np.float64, f"Expected z to be float64, got {dml_data.z.dtype}"
    # x is expected to be left unconverted. NumPy has no Decimal dtype: an array
    # of Decimal objects reports object dtype, so compare against ``object``
    # explicitly (comparing against ``Decimal`` only passes because NumPy
    # coerces arbitrary Python classes to object dtype).
    assert dml_data.x.dtype == object, f"Expected x to be object, got {dml_data.x.dtype}"

    expected_y = np.array([float(value) for value in y_decimals])
    expected_d = np.array([float(value) for value in d_decimals])
    # z is compared as a single-column 2-D array.
    expected_z = np.array([float(value) for value in z_decimals]).reshape(-1, 1)

    np.testing.assert_array_almost_equal(dml_data.y, expected_y)
    np.testing.assert_array_almost_equal(dml_data.d, expected_d)
    np.testing.assert_array_almost_equal(dml_data.z, expected_z)
0 commit comments