Skip to content

Commit 815f651

Browse files
authored
Merge pull request #358 from DoubleML/s-add-type-casting
Convert Outcome, Treatment and instrument columns to float
2 parents da9bf49 + d422dce commit 815f651

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

doubleml/data/base_data.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -702,8 +702,12 @@ def _set_y_z(self):
702702
def _set_attr(col):
703703
if col is None:
704704
return None
705-
assert_all_finite(self.data.loc[:, col])
706-
return self.data.loc[:, col]
705+
if isinstance(col, list):
706+
converted_data = self.data.loc[:, col].apply(pd.to_numeric, errors="raise")
707+
else:
708+
converted_data = pd.to_numeric(self.data.loc[:, col], errors="raise")
709+
assert_all_finite(converted_data)
710+
return converted_data
707711

708712
self._y = _set_attr(self.y_col)
709713
self._z = _set_attr(self.z_cols)
@@ -740,7 +744,13 @@ def set_x_d(self, treatment_var):
740744
assert_all_finite(self.data.loc[:, self.d_cols], allow_nan=self.force_all_d_finite == "allow-nan")
741745
if self.force_all_x_finite:
742746
assert_all_finite(self.data.loc[:, xd_list], allow_nan=self.force_all_x_finite == "allow-nan")
743-
self._d = self.data.loc[:, treatment_var]
747+
748+
treatment_data = self.data.loc[:, treatment_var]
749+
# For panel data, preserve datetime type for treatment variables
750+
if pd.api.types.is_datetime64_any_dtype(treatment_data):
751+
self._d = treatment_data
752+
else:
753+
self._d = pd.to_numeric(treatment_data, errors="raise")
744754
self._X = self.data.loc[:, xd_list]
745755

746756
def _get_optional_col_sets(self):

doubleml/data/tests/test_dml_data.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from decimal import Decimal
2+
13
import numpy as np
24
import pandas as pd
35
import pytest
@@ -661,3 +663,32 @@ def test_property_setter_rollback_on_validation_failure():
661663
dml_data.z_cols = ["y"]
662664
# Object should remain unchanged
663665
assert dml_data.z_cols == original_z_cols
666+
667+
668+
@pytest.mark.ci
669+
def test_dml_data_decimal_to_float_conversion():
670+
"""Test that Decimal type columns are converted to float for y and d."""
671+
n_obs = 100
672+
data = {
673+
"y": [Decimal(i * 0.1) for i in range(n_obs)],
674+
"d": [Decimal(i * 0.05) for i in range(n_obs)],
675+
"x": [Decimal(i) for i in range(n_obs)],
676+
"z": [Decimal(i * 2) for i in range(n_obs)],
677+
}
678+
df = pd.DataFrame(data)
679+
680+
dml_data = DoubleMLData(df, y_col="y", d_cols="d", x_cols="x", z_cols="z")
681+
682+
assert dml_data.y.dtype == np.float64, f"Expected y to be float64, got {dml_data.y.dtype}"
683+
assert dml_data.d.dtype == np.float64, f"Expected d to be float64, got {dml_data.d.dtype}"
684+
assert dml_data.z.dtype == np.float64, f"Expected z to be float64, got {dml_data.z.dtype}"
685+
# x is not converted to float, so its dtype remains Decimal
686+
assert dml_data.x.dtype == Decimal
687+
688+
expected_y = np.array([float(Decimal(i * 0.1)) for i in range(n_obs)])
689+
expected_d = np.array([float(Decimal(i * 0.05)) for i in range(n_obs)])
690+
expected_z = np.array([float(Decimal(i * 2)) for i in range(n_obs)]).reshape(-1, 1)
691+
692+
np.testing.assert_array_almost_equal(dml_data.y, expected_y)
693+
np.testing.assert_array_almost_equal(dml_data.d, expected_d)
694+
np.testing.assert_array_almost_equal(dml_data.z, expected_z)

0 commit comments

Comments
 (0)