11from collections import defaultdict
2- from typing import Any , Optional , Union
2+ from typing import Any , Dict , Optional , Union , List
33
44import numpy as np
55import pandas as pd
@@ -25,11 +25,11 @@ def is_column_categorical(values: pd.Series) -> bool:
2525
2626def infer_schema (
2727 df : pd .DataFrame ,
28- numerical_columns : Optional [list [str ]] = None ,
29- categorical_columns : Optional [list [str ]] = None ,
30- text_columns : Optional [list [str ]] = None ,
31- time_columns : Optional [list [str ]] = None ,
32- prediction_columns : Optional [list [PredictionColumn ]] = None ,
28+ numerical_columns : Optional [List [str ]] = None ,
29+ categorical_columns : Optional [List [str ]] = None ,
30+ text_columns : Optional [List [str ]] = None ,
31+ time_columns : Optional [List [str ]] = None ,
32+ prediction_columns : Optional [List [PredictionColumn ]] = None ,
3333) -> ColumnMapping :
3434 """
3535 The keyword arguments will take precedence over the inferred schema.
@@ -159,8 +159,8 @@ def check_all_columns_in_df(df: pd.DataFrame, column_mapping: ColumnMapping):
159159
160160
161161def check_categorical_columns (
162- df : pd .DataFrame , categorical_columns : list [str ]
163- ) -> dict [str , list [Union [str , np .number ]]]:
162+ df : pd .DataFrame , categorical_columns : List [str ]
163+ ) -> Dict [str , List [Union [str , np .number ]]]:
164164 """
165165 Make sure the dtype is numeric or string (not mixed) and that unique categories <= MAX_CATEGORICAL_UNIQUE
166166 """
@@ -184,33 +184,33 @@ def check_categorical_columns(
184184 return column_to_categories
185185
186186
def check_numerical_columns(df: pd.DataFrame, numerical_columns: List[str]):
    """Verify that every listed column of ``df`` has a numeric dtype.

    Args:
        df: DataFrame holding the columns to validate.
        numerical_columns: Names of the columns expected to be numeric.

    Raises:
        Exception: For the first listed column whose dtype is rejected by
            ``pd.api.types.is_numeric_dtype``.
    """
    for column in numerical_columns:
        if not pd.api.types.is_numeric_dtype(df[column].dtype):
            raise Exception(f"Column {column} is not of type numerical.")
191191
192192
def check_text_columns(df: pd.DataFrame, text_columns: List[str]):
    """Verify that every listed column of ``df`` is inferred as string data.

    Args:
        df: DataFrame holding the columns to validate.
        text_columns: Names of the columns expected to contain text.

    Raises:
        Exception: For the first listed column for which
            ``pd.api.types.infer_dtype`` (missing values skipped) returns
            anything other than ``"string"``.
    """
    for column in text_columns:
        # skipna=True so missing values do not change the inferred kind.
        if pd.api.types.infer_dtype(df[column], skipna=True) != "string":
            raise Exception(f"Text column {column} is not of type string.")
197197
198198
def check_time_columns(df: pd.DataFrame, time_columns: List[str]):
    """Verify that every listed column of ``df`` can be parsed as datetimes.

    Args:
        df: DataFrame holding the columns to validate.
        time_columns: Names of the columns expected to hold timestamps.

    Raises:
        Exception: For the first listed column for which looking it up or
            calling ``pd.to_datetime(..., errors="raise")`` fails. The
            underlying error is chained as ``__cause__``.
    """
    for column in time_columns:
        try:
            # Only the success/failure of the conversion matters here; the
            # converted values are discarded.
            pd.to_datetime(df[column], errors="raise")
        except Exception as err:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit still propagate untouched.
            raise Exception(
                f"Column {column} cannot be cast to a datetime."
            ) from err
205205
206206
def is_subset(list1: List[Any], list2: List[Any]) -> bool:
    """Return True when every element of ``list1`` also appears in ``list2``.

    Duplicates and ordering are ignored because both inputs are compared as
    sets; an empty ``list1`` is a subset of anything. Elements must be
    hashable.
    """
    return set(list1).issubset(list2)
209209
210210
211211def check_prediction_columns (
212212 column_mapping : ColumnMapping ,
213- column_to_categories : dict [str , list [Union [str , np .number ]]],
213+ column_to_categories : Dict [str , List [Union [str , np .number ]]],
214214) -> dict :
215215 schema_dict = dict (
216216 (column , {"type" : ColumnType .CATEGORICAL })
@@ -297,8 +297,8 @@ def check_prediction_columns(
297297def format_validated_schema (
298298 df : pd .DataFrame ,
299299 schema_dict : dict ,
300- prediction_columns : list [PredictionColumn ],
301- column_to_categories : dict [str , list [Union [str , np .number ]]],
300+ prediction_columns : List [PredictionColumn ],
301+ column_to_categories : Dict [str , List [Union [str , np .number ]]],
302302) -> dict :
303303 new_schema = []
304304 prediction_columns_dict = {
0 commit comments