diff --git a/cesium_app/handlers/model.py b/cesium_app/handlers/model.py index a43e5f2..d2fa973 100644 --- a/cesium_app/handlers/model.py +++ b/cesium_app/handlers/model.py @@ -67,11 +67,24 @@ def _build_model_compute_statistics(fset_path, model_type, model_params, model = GridSearchCV(model, params_to_optimize, n_jobs=n_jobs) model.fit(fset, data['labels']) - score = model.score(fset, data['labels']) + + metrics = {} + metrics['train_score'] = model.score(fset, data['labels']) + best_params = model.best_params_ if params_to_optimize else {} joblib.dump(model, model_path) - return score, best_params + if model_type == 'RandomForestClassifier': + if params_to_optimize: + model = model.best_estimator_ + if hasattr(model, 'oob_score_'): + metrics['oob_score'] = model.oob_score_ + if hasattr(model, 'feature_importances_'): + metrics['feature_importances'] = dict(zip( + fset.columns.get_level_values(0).tolist(), + model.feature_importances_.tolist())) + + return metrics, best_params class ModelHandler(BaseHandler): @@ -102,12 +115,12 @@ def get(self, model_id=None, action=None): @auth_or_token async def _await_model_statistics(self, model_stats_future, model): try: - score, best_params = await model_stats_future + model_metrics, best_params = await model_stats_future model = DBSession().merge(model) model.task_id = None model.finished = datetime.datetime.now() - model.train_score = score + model.metrics = model_metrics model.params.update(best_params) DBSession().commit() diff --git a/cesium_app/models.py b/cesium_app/models.py index adf357a..dadcb1e 100644 --- a/cesium_app/models.py +++ b/cesium_app/models.py @@ -90,7 +90,7 @@ class Model(Base): file_uri = sa.Column(sa.String(), nullable=True, index=True) task_id = sa.Column(sa.String()) finished = sa.Column(sa.DateTime) - train_score = sa.Column(sa.Float) + metrics = sa.Column(sa.JSON, nullable=True) featureset = relationship('Featureset') project = relationship('Project') diff --git a/cesium_app/tests/frontend/test_build_model.py b/cesium_app/tests/frontend/test_build_model.py index af27b64..7fe3f71 100644 --- a/cesium_app/tests/frontend/test_build_model.py +++ b/cesium_app/tests/frontend/test_build_model.py @@ -96,16 +96,19 @@ def test_cannot_build_model_unlabeled_data(driver, project, featureset): "//div[contains(.,'Cannot build model for unlabeled feature set.')]") -def test_model_info_display(driver, project, featureset, model): +@pytest.mark.parametrize('featureset__name, model_type', + [('class', 'RandomForestClassifier (fast)')]) +def test_model_info_display(driver, project, featureset, model_type): driver.refresh() - proj_select = Select(driver.find_element_by_css_selector('[name=project]')) - proj_select.select_by_value(str(project.id)) - driver.find_element_by_id('react-tabs-6').click() + _build_model(project.id, model_type, driver) - driver.wait_for_xpath("//td[contains(text(),'{}')]".format(model.name)).click() + driver.wait_for_xpath("//td[contains(text(), 'Completed')]").click() + time.sleep(0.5) assert driver.wait_for_xpath("//th[contains(text(),'Model Type')]")\ .is_displayed() - assert driver.wait_for_xpath("//th[contains(text(),'Hyper" - "parameters')]").is_displayed() - assert driver.wait_for_xpath("//th[contains(text(),'Training " - "Data Score')]").is_displayed() + assert driver.wait_for_xpath("//th[contains(text()," + "'Hyperparameters')]").is_displayed() + assert driver.wait_for_xpath("//th[contains(text()," + "'train_score')]").is_displayed() + assert driver.wait_for_xpath("//canvas[@class='chartjs-render-monitor']")\ + .is_displayed() diff --git a/package.json b/package.json index 331ba17..76c7eba 100644 --- a/package.json +++ b/package.json @@ -9,12 +9,14 @@ "bokehjs": "^0.12.5", "bootstrap": "^3.3.7", "bootstrap-css": "^3.0.0", + "chart.js": "^2.7.1", "css-loader": "^0.26.2", "exports-loader": "^0.6.4", "imports-loader": "^0.7.1", "jquery": "^3.1.1", "prop-types": "^15.5.10", "react": "^15.1.0", + "react-chartjs-2": "^2.7.0", "react-dom": "^15.1.0", "react-redux": "^5.0.3", "react-tabs": "^0.8.2", diff --git a/static/js/components/FeatureImportances.jsx b/static/js/components/FeatureImportances.jsx new file mode 100644 index 0000000..10d3431 --- /dev/null +++ b/static/js/components/FeatureImportances.jsx @@ -0,0 +1,42 @@ +import React from 'react'; +import PropTypes from 'prop-types'; +import { HorizontalBar } from 'react-chartjs-2'; + + +const FeatureImportancesBarchart = props => { + const sorted_features = Object.keys(props.data).sort( + (a, b) => props.data[b] - props.data[a] + ).slice(0, 15); + const values = sorted_features.map(feature => props.data[feature]); + const data = { + labels: sorted_features, + datasets: [ + { + label: 'Feature Importance', + backgroundColor: '#2222ff', + hoverBackgroundColor: '#5555ff', + data: values + } + ] + }; + const options = { + scales: { + xAxes: [{ + ticks: { + min: 0.0 + } + }] + } + }; + + return ( +
+ +
+ ); +}; +FeatureImportancesBarchart.propTypes = { + data: PropTypes.array.isRequired +}; + +export default FeatureImportancesBarchart; diff --git a/static/js/components/Models.jsx b/static/js/components/Models.jsx index 8f3b6aa..38d16be 100644 --- a/static/js/components/Models.jsx +++ b/static/js/components/Models.jsx @@ -13,6 +13,7 @@ import Delete from './Delete'; import Download from './Download'; import { $try, reformatDatetime } from '../utils'; import FoldableRow from './FoldableRow'; +import FeatureImportances from './FeatureImportances'; const ModelsTab = props => ( @@ -178,7 +179,8 @@ const ModelInfo = props => ( Model Type Hyperparameters - Training Data Score + {Object.keys(props.model.metrics).map(metric => + {metric})} @@ -200,9 +202,16 @@ const ModelInfo = props => ( - - {props.model.train_score} - + { + Object.keys(props.model.metrics).map(metric => ( + + { + metric == 'feature_importances' ? + : + props.model.metrics[metric].toFixed(3) + } + )) + }