Skip to content

Commit a3583cc

Browse files
author
Dmitry Razdoburdin
committed
fix
1 parent 99cbf61 commit a3583cc

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

sklbench/datasets/loaders.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -677,14 +677,14 @@ def load_szilard_1m(
677 677
678 678   label_col = "dep_delayed_15min"
679 679   y_train = (d_train[label_col] == "Y").astype(int).values
680     - y_test = (d_test[label_col] == "Y").astype(int).values
    680 + y_test = (d_test[label_col] == "Y").astype(int).values
681 681   y = np.concatenate([y_train, y_test])
682 682
683 683   X_train_raw = d_train.drop(columns=[label_col])
684     - X_test_raw = d_test.drop(columns=[label_col])
    684 + X_test_raw = d_test.drop(columns=[label_col])
685 685
686 686   combined = pd.concat([X_train_raw, X_test_raw], axis=0, ignore_index=True)
687     - X_combined_oh = pd.get_dummies(combined, drop_first=False, dtype=np.uint8)
    687 + X_combined_oh = pd.get_dummies(combined)
688 688   x = sparse.csr_matrix(X_combined_oh.values)
689 689
690 690   n_train = len(d_train)
@@ -705,18 +705,18 @@ def load_szilard_10m(
705 705   d_train = download_and_read_csv(url, raw_data_cache)
706 706
707 707   url = "https://s3.amazonaws.com/benchm-ml--main/test.csv"
708     - d_test = download_and_read_csv(url, raw_data_cache)
    708 + d_test = download_and_read_csv(url, raw_data_cache)
709 709
710 710   label_col = "dep_delayed_15min"
711 711   y_train = (d_train[label_col] == "Y").astype(int).values
712     - y_test = (d_test[label_col] == "Y").astype(int).values
    712 + y_test = (d_test[label_col] == "Y").astype(int).values
713 713   y = np.concatenate([y_train, y_test])
714 714
715 715   X_train_raw = d_train.drop(columns=[label_col])
716     - X_test_raw = d_test.drop(columns=[label_col])
    716 + X_test_raw = d_test.drop(columns=[label_col])
717 717
718 718   combined = pd.concat([X_train_raw, X_test_raw], axis=0, ignore_index=True)
719     - X_combined_oh = pd.get_dummies(combined, drop_first=False, dtype=np.uint8)
    719 + X_combined_oh = pd.get_dummies(combined)
720 720   x = sparse.csr_matrix(X_combined_oh.values)
721 721
722 722   n_train = len(d_train)

0 commit comments

Comments
 (0)