1+ from itertools import cycle
12import numpy as np
2- import pandas as pd
3- from sklearn .metrics import confusion_matrix
4- import plotly
5- import plotly .offline as py
6- from plotly .tools import FigureFactory as FF
7-
83from cesium import featurize
9- from .config import cfg
4+ from bokeh .plotting import figure
5+ from bokeh .layouts import gridplot
6+ from bokeh .palettes import Viridis as palette
7+ from bokeh .core .json_encoder import serialize_json
8+ from bokeh .document import Document
9+ from bokeh .util .serialization import make_id
1010
1111
1212def feature_scatterplot (fset_path , features_to_plot ):
@@ -21,42 +21,39 @@ def feature_scatterplot(fset_path, features_to_plot):
2121
2222 Returns
2323 -------
24- (fig.data, fig.layout)
25- Returns (fig.data, fig.layout) where `fig` is an instance of
26- `plotly.tools.FigureFactory`.
24+ (str, str)
25+ Returns (docs_json, render_items) json for the desired plot.
2726 """
2827 fset , data = featurize .load_featureset (fset_path )
2928 fset = fset [features_to_plot ]
30-
31- if 'label' in data :
32- fset ['label' ] = data ['label' ]
33- index = 'label'
34- else :
35- index = None
36-
37- # TODO replace 'trace {i}' with class labels
38- fig = FF .create_scatterplotmatrix (fset , diag = 'box' , index = index ,
39- height = 800 , width = 800 )
40-
41- py .plot (fig , auto_open = False , output_type = 'div' )
42-
43- return fig .data , fig .layout
44-
45-
46- #def prediction_heatmap(pred_path):
47- # with xr.open_dataset(pred_path) as pset:
48- # pred_df = pd.DataFrame(pset.prediction.values, index=pset.name,
49- # columns=pset.class_label.values)
50- # pred_labels = pred_df.idxmax(axis=1)
51- # C = confusion_matrix(pset.label, pred_labels)
52- # row_sums = C.sum(axis=1)
53- # C = C / row_sums[:, np.newaxis]
54- # fig = FF.create_annotated_heatmap(C, x=[str(el) for el in
55- # pset.class_label.values],
56- # y=[str(el) for el in
57- # pset.class_label.values],
58- # colorscale='Viridis')
59- #
60- # py.plot(fig, auto_open=False, output_type='div')
61- #
62- # return fig.data, fig.layout
29+ colors = cycle (palette [5 ])
30+ plots = np .array ([[figure (width = 300 , height = 200 )
31+ for j in range (len (features_to_plot ))]
32+ for i in range (len (features_to_plot ))])
33+
34+ for (j , i ), p in np .ndenumerate (plots ):
35+ if (j == i == 0 ):
36+ p .title .text = "Scatterplot matrix"
37+ p .circle (fset .values [:,i ], fset .values [:,j ], color = next (colors ))
38+ p .xaxis .minor_tick_line_color = None
39+ p .yaxis .minor_tick_line_color = None
40+ p .ygrid [0 ].ticker .desired_num_ticks = 2
41+ p .xgrid [0 ].ticker .desired_num_ticks = 4
42+ p .outline_line_color = None
43+ p .axis .visible = None
44+
45+ plot = gridplot (plots .tolist (), ncol = len (features_to_plot ), mergetools = True , responsive = True , title = "Test" )
46+
47+ # Convert plot to json objects necessary for rendering with bokeh on the
48+ # frontend
49+ render_items = [{'docid' : plot ._id , 'elementid' : make_id ()}]
50+
51+ doc = Document ()
52+ doc .add_root (plot )
53+ docs_json_inner = doc .to_json ()
54+ docs_json = {render_items [0 ]['docid' ]: docs_json_inner }
55+
56+ docs_json = serialize_json (docs_json )
57+ render_items = serialize_json (render_items )
58+
59+ return docs_json , render_items
0 commit comments