|
7 | 7 | from os.path import join as pjoin |
8 | 8 | import numpy as np |
9 | 9 | import numpy.testing as npt |
| 10 | +import pandas as pd |
10 | 11 | from cesium_app.config import cfg |
11 | 12 | import json |
12 | 13 | import requests |
@@ -185,6 +186,26 @@ def test_download_prediction_csv_class(driver): |
185 | 186 | os.remove('/tmp/cesium_prediction_results.csv') |
186 | 187 |
|
187 | 188 |
|
| 189 | +def test_download_prediction_csv_class_prob(driver): |
| 190 | + driver.get('/') |
| 191 | + with create_test_project() as p, create_test_dataset(p) as ds,\ |
| 192 | + create_test_featureset(p) as fs,\ |
| 193 | + create_test_model(fs, model_type='RandomForestClassifier') as m,\ |
| 194 | + create_test_prediction(ds, m): |
| 195 | + _click_download(p.id, driver) |
| 196 | + assert os.path.exists('/tmp/cesium_prediction_results.csv') |
| 197 | + try: |
| 198 | + result = pd.read_csv('/tmp/cesium_prediction_results.csv') |
| 199 | + npt.assert_array_equal(result.ts_name, np.arange(5)) |
| 200 | + npt.assert_array_equal(result.label, ['Mira', 'Classical_Cepheid', |
| 201 | + 'Mira', 'Classical_Cepheid', |
| 202 | + 'Mira']) |
| 203 | + npt.assert_array_equal(result.label, result.prediction) |
| 204 | + assert (result.probability >= 0.0).all() |
| 205 | + finally: |
| 206 | + os.remove('/tmp/cesium_prediction_results.csv') |
| 207 | + |
| 208 | + |
188 | 209 | def test_download_prediction_csv_regr(driver): |
189 | 210 | driver.get('/') |
190 | 211 | with create_test_project() as p, create_test_dataset(p, label_type='regr') as ds,\ |
|
0 commit comments