1
1
from collections import defaultdict
2
- from typing import Any , Optional , Union
2
+ from typing import Any , Dict , Optional , Union , List
3
3
4
4
import numpy as np
5
5
import pandas as pd
@@ -25,11 +25,11 @@ def is_column_categorical(values: pd.Series) -> bool:
25
25
26
26
def infer_schema (
27
27
df : pd .DataFrame ,
28
- numerical_columns : Optional [list [str ]] = None ,
29
- categorical_columns : Optional [list [str ]] = None ,
30
- text_columns : Optional [list [str ]] = None ,
31
- time_columns : Optional [list [str ]] = None ,
32
- prediction_columns : Optional [list [PredictionColumn ]] = None ,
28
+ numerical_columns : Optional [List [str ]] = None ,
29
+ categorical_columns : Optional [List [str ]] = None ,
30
+ text_columns : Optional [List [str ]] = None ,
31
+ time_columns : Optional [List [str ]] = None ,
32
+ prediction_columns : Optional [List [PredictionColumn ]] = None ,
33
33
) -> ColumnMapping :
34
34
"""
35
35
The keyword arguments will take precedence over the inferred schema.
@@ -159,8 +159,8 @@ def check_all_columns_in_df(df: pd.DataFrame, column_mapping: ColumnMapping):
159
159
160
160
161
161
def check_categorical_columns (
162
- df : pd .DataFrame , categorical_columns : list [str ]
163
- ) -> dict [str , list [Union [str , np .number ]]]:
162
+ df : pd .DataFrame , categorical_columns : List [str ]
163
+ ) -> Dict [str , List [Union [str , np .number ]]]:
164
164
"""
165
165
Make sure the dtype is numeric or string (not mixed) and that unique categories <= MAX_CATEGORICAL_UNIQUE
166
166
"""
@@ -184,33 +184,33 @@ def check_categorical_columns(
184
184
return column_to_categories
185
185
186
186
187
def check_numerical_columns(df: pd.DataFrame, numerical_columns: List[str]) -> None:
    """
    Make sure every column listed in ``numerical_columns`` has a numeric dtype.

    Columns are checked in the order given and the check stops at the first
    failure.

    Raises:
        Exception: if a listed column's dtype is not accepted by
            ``pd.api.types.is_numeric_dtype``.
    """
    for column in numerical_columns:
        if not pd.api.types.is_numeric_dtype(df[column].dtype):
            raise Exception(f"Column {column} is not of type numerical.")
191
191
192
192
193
def check_text_columns(df: pd.DataFrame, text_columns: List[str]) -> None:
    """
    Make sure every column listed in ``text_columns`` is inferred as "string".

    The inference uses ``pd.api.types.infer_dtype`` with ``skipna=True``, so
    missing values are ignored when deciding the column's type.

    Raises:
        Exception: for the first listed column whose inferred dtype is not
            "string".
    """
    for column in text_columns:
        inferred = pd.api.types.infer_dtype(df[column], skipna=True)
        if inferred != "string":
            raise Exception(f"Text column {column} is not of type string.")
197
197
198
198
199
def check_time_columns(df: pd.DataFrame, time_columns: List[str]) -> None:
    """
    Make sure every column listed in ``time_columns`` can be parsed as datetimes.

    Each column is passed to ``pd.to_datetime`` with ``errors="raise"``; the
    parsed result is discarded, only the success of the conversion matters.

    Raises:
        Exception: for the first listed column whose conversion fails. The
            underlying error is attached as ``__cause__``.
    """
    for column in time_columns:
        try:
            pd.to_datetime(df[column], errors="raise")
        except Exception as err:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt / SystemExit are not swallowed, and chain the
            # original error so the reason for the failure stays visible.
            raise Exception(f"Column {column} cannot be cast to a datetime.") from err
205
205
206
206
207
def is_subset(list1: List[Any], list2: List[Any]) -> bool:
    """
    Return True when every element of ``list1`` also appears in ``list2``.

    Duplicates and ordering are ignored; an empty ``list1`` is a subset of
    anything. Elements must be hashable.
    """
    return set(list1).issubset(list2)
209
209
210
210
211
211
def check_prediction_columns (
212
212
column_mapping : ColumnMapping ,
213
- column_to_categories : dict [str , list [Union [str , np .number ]]],
213
+ column_to_categories : Dict [str , List [Union [str , np .number ]]],
214
214
) -> dict :
215
215
schema_dict = dict (
216
216
(column , {"type" : ColumnType .CATEGORICAL })
@@ -297,8 +297,8 @@ def check_prediction_columns(
297
297
def format_validated_schema (
298
298
df : pd .DataFrame ,
299
299
schema_dict : dict ,
300
- prediction_columns : list [PredictionColumn ],
301
- column_to_categories : dict [str , list [Union [str , np .number ]]],
300
+ prediction_columns : List [PredictionColumn ],
301
+ column_to_categories : Dict [str , List [Union [str , np .number ]]],
302
302
) -> dict :
303
303
new_schema = []
304
304
prediction_columns_dict = {
0 commit comments