@@ -677,14 +677,14 @@ def load_szilard_1m(
677677
678678 label_col = "dep_delayed_15min"
679679 y_train = (d_train [label_col ] == "Y" ).astype (int ).values
680- y_test = (d_test [label_col ] == "Y" ).astype (int ).values
680+ y_test = (d_test [label_col ] == "Y" ).astype (int ).values
681681 y = np .concatenate ([y_train , y_test ])
682682
683683 X_train_raw = d_train .drop (columns = [label_col ])
684- X_test_raw = d_test .drop (columns = [label_col ])
684+ X_test_raw = d_test .drop (columns = [label_col ])
685685
686686 combined = pd .concat ([X_train_raw , X_test_raw ], axis = 0 , ignore_index = True )
687- X_combined_oh = pd .get_dummies (combined , drop_first = False , dtype = np . uint8 )
687+ X_combined_oh = pd .get_dummies (combined )
688688 x = sparse .csr_matrix (X_combined_oh .values )
689689
690690 n_train = len (d_train )
@@ -705,18 +705,18 @@ def load_szilard_10m(
705705 d_train = download_and_read_csv (url , raw_data_cache )
706706
707707 url = "https://s3.amazonaws.com/benchm-ml--main/test.csv"
708- d_test = download_and_read_csv (url , raw_data_cache )
708+ d_test = download_and_read_csv (url , raw_data_cache )
709709
710710 label_col = "dep_delayed_15min"
711711 y_train = (d_train [label_col ] == "Y" ).astype (int ).values
712- y_test = (d_test [label_col ] == "Y" ).astype (int ).values
712+ y_test = (d_test [label_col ] == "Y" ).astype (int ).values
713713 y = np .concatenate ([y_train , y_test ])
714714
715715 X_train_raw = d_train .drop (columns = [label_col ])
716- X_test_raw = d_test .drop (columns = [label_col ])
716+ X_test_raw = d_test .drop (columns = [label_col ])
717717
718718 combined = pd .concat ([X_train_raw , X_test_raw ], axis = 0 , ignore_index = True )
719- X_combined_oh = pd .get_dummies (combined , drop_first = False , dtype = np . uint8 )
719+ X_combined_oh = pd .get_dummies (combined )
720720 x = sparse .csr_matrix (X_combined_oh .values )
721721
722722 n_train = len (d_train )
0 commit comments