Skip to content

Commit 993cb7e

Browse files
authored
Merge pull request #191 from bnaul/download_fix
Fix prediction downloads for probabilistic classifiers
2 parents b6d6a4f + 6a5f968 commit 993cb7e

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

cesium_app/handlers/prediction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def get(self, prediction_id=None, action=None):
141141
'label': data['labels'],
142142
'prediction': data['preds']},
143143
columns=['ts_name', 'label', 'prediction'])
144-
if data.get('pred_probs'):
145-
result['probability'] = np.max(data['pred_probs'], axis=1)
144+
if len(data.get('pred_probs', [])) > 0:
145+
result['probability'] = data['pred_probs'].max(axis=1).values
146146
self.set_header("Content-Type", 'text/csv; charset="utf-8"')
147147
self.set_header("Content-Disposition", "attachment; "
148148
"filename=cesium_prediction_results.csv")

cesium_app/tests/frontend/test_predict.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from os.path import join as pjoin
88
import numpy as np
99
import numpy.testing as npt
10+
import pandas as pd
1011
from cesium_app.config import cfg
1112
import json
1213
import requests
@@ -185,6 +186,26 @@ def test_download_prediction_csv_class(driver):
185186
os.remove('/tmp/cesium_prediction_results.csv')
186187

187188

189+
def test_download_prediction_csv_class_prob(driver):
190+
driver.get('/')
191+
with create_test_project() as p, create_test_dataset(p) as ds,\
192+
create_test_featureset(p) as fs,\
193+
create_test_model(fs, model_type='RandomForestClassifier') as m,\
194+
create_test_prediction(ds, m):
195+
_click_download(p.id, driver)
196+
assert os.path.exists('/tmp/cesium_prediction_results.csv')
197+
try:
198+
result = pd.read_csv('/tmp/cesium_prediction_results.csv')
199+
npt.assert_array_equal(result.ts_name, np.arange(5))
200+
npt.assert_array_equal(result.label, ['Mira', 'Classical_Cepheid',
201+
'Mira', 'Classical_Cepheid',
202+
'Mira'])
203+
npt.assert_array_equal(result.label, result.prediction)
204+
assert (result.probability >= 0.0).all()
205+
finally:
206+
os.remove('/tmp/cesium_prediction_results.csv')
207+
208+
188209
def test_download_prediction_csv_regr(driver):
189210
driver.get('/')
190211
with create_test_project() as p, create_test_dataset(p, label_type='regr') as ds,\

0 commit comments

Comments
 (0)