@@ -672,20 +672,23 @@ def test_dml_data_decimal_to_float_conversion():
672672 data = {
673673 "y" : [Decimal (i * 0.1 ) for i in range (n_obs )],
674674 "d" : [Decimal (i * 0.05 ) for i in range (n_obs )],
675- "X" : [Decimal (i ) for i in range (n_obs )],
675+ "x" : [Decimal (i ) for i in range (n_obs )],
676+ "z" : [Decimal (i * 2 ) for i in range (n_obs )],
676677 }
677678 df = pd .DataFrame (data )
678679
679- dml_data = DoubleMLData (df , y_col = "y" , d_cols = "d" , x_cols = "X " )
680+ dml_data = DoubleMLData (df , y_col = "y" , d_cols = "d" , x_cols = "x" , z_cols = "z " )
680681
681682 assert dml_data .y .dtype == np .float64 , f"Expected y to be float64, got { dml_data .y .dtype } "
682683 assert dml_data .d .dtype == np .float64 , f"Expected d to be float64, got { dml_data .d .dtype } "
683- assert dml_data .x .dtype == np .float64 , f"Expected x to be float64, got { dml_data .x .dtype } "
684+ assert dml_data .z .dtype == np .float64 , f"Expected z to be float64, got { dml_data .z .dtype } "
685+ # x is not converted to float, so its dtype remains Decimal
686+ assert dml_data .x .dtype == Decimal
684687
685688 expected_y = np .array ([float (Decimal (i * 0.1 )) for i in range (n_obs )])
686689 expected_d = np .array ([float (Decimal (i * 0.05 )) for i in range (n_obs )])
687- expected_x = np .array ([float (Decimal (i )) for i in range (n_obs )]).reshape (- 1 , 1 )
690+ expected_z = np .array ([float (Decimal (i * 2 )) for i in range (n_obs )]).reshape (- 1 , 1 )
688691
689692 np .testing .assert_array_almost_equal (dml_data .y , expected_y )
690693 np .testing .assert_array_almost_equal (dml_data .d , expected_d )
691- np .testing .assert_array_almost_equal (dml_data .x , expected_x )
694+ np .testing .assert_array_almost_equal (dml_data .z , expected_z )
0 commit comments