From db42bfdb75eca2ab1361e74f1ab273526d84c06e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 3 Jul 2023 18:07:34 +0200 Subject: [PATCH 1/5] first draft --- nwbwidgets/time_grid.py | 262 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 262 insertions(+) create mode 100644 nwbwidgets/time_grid.py diff --git a/nwbwidgets/time_grid.py b/nwbwidgets/time_grid.py new file mode 100644 index 00000000..d49ceb5f --- /dev/null +++ b/nwbwidgets/time_grid.py @@ -0,0 +1,262 @@ +import plotly.graph_objects as go +import matplotlib.cm as mplcm +import matplotlib.colors as colors +import matplotlib.pyplot as plt +from pynwb import ProcessingModule +from hdmf.utils import LabelledDict +from hdmf.common import DynamicTable, VectorData +from hdmf.data_utils import DataIO +from pynwb import NWBHDF5IO, NWBFile, TimeSeries +from pynwb.core import NWBContainer +from pynwb.behavior import Position +from pynwb.ecephys import LFP + + +def get_partial_path_from_parents(obj): + # ndx-events do not have object.data + parents = [] + while obj is not None and obj.name != 'root': + parents.append(obj.name) + obj = obj.parent # assuming the object has a 'parent' attribute + return parents + + +def extract_timing_information_from_nwbfile(nwbfile: NWBFile, verbose: bool = False) -> dict: + + temporal_information_dict = {} + for object_id, object in nwbfile.objects.items(): + object_name = object.name + if hasattr(object, "data") and hasattr(object.data, "name"): + object_path = object.data.name + top_module = object_path[1:].split("/")[0] + if top_module == "processing": + top_module = object_path[1:].split("/")[1] + + elif hasattr(object, "data"): + object_path = get_partial_path_from_parents(object)[::-1] + top_module = object_path[0] if object_path else "" + else: + object_path = get_partial_path_from_parents(object)[::-1] + top_module = object_path[0] if object_path else "" + + + + if isinstance(object, DynamicTable) and hasattr(object, "start_time") and hasattr(object, "stop_time"): + if verbose: + print("------------------") + print(f"branch: dynamic table with intervals for object {object_name} with path {object_path}") + print(f"Top module {top_module}") + print("------------------") + + start_time = object.start_time.data[:] + stop_time = object.stop_time.data[:] + + intervals = [{"start_time": start_time, "stop_time": stop_time} for start_time, stop_time in zip(start_time, stop_time)] + timing_dict = dict(start_time=intervals[0]["start_time"], stop_time=intervals[-1]["stop_time"]) + object_info = dict(object_path=object_path, object_name=object_name, object_id=object_id, top_module=top_module) + temporal_information_dict[object_name] = dict(timing_dict=timing_dict, info=object_info, intervals=intervals) + + elif isinstance(object, DynamicTable) and hasattr(object, "start_time"): + if verbose: + print("------------------") + print("NOT DONE YET") + print(f"Branch of dynamic table with only start_time for object {object_name} with path {object_path}") + print(f"Top module {top_module}") + print("------------------") + + elif isinstance(object, TimeSeries): + time_series = object + if verbose: + print("------------------") + print(f"branch: time series for object {object_name} with path {object_path}") + print(f"Top module {top_module}") + print("------------------") + timing_dict = {} + if time_series.timestamps is not None: + timing_dict["start_time"] = time_series.timestamps[0] + timing_dict["stop_time"] = time_series.timestamps[-1] + else: + timing_dict["start_time"] = time_series.starting_time + timing_dict["stop_time"] = time_series.starting_time + time_series.num_samples / time_series.rate + + + # TODO This fails when the series has timestamps with gaps. One option to find out the gaps and write the series + # As intervals is that the naive way np.diff generates a too large memory allocation. + # Probably, a bisection algorithm can be used to find the gaps. + + object_info = dict(object_path=object_path, object_name=object_name, object_id=object_id, top_module=top_module) + temporal_information_dict[object_name] = dict(timing_dict=timing_dict, info=object_info) + + return temporal_information_dict + +def generate_epoch_guides_trace(temporal_data_dict): + epoch_dict = temporal_data_dict["epochs"] + intervals = epoch_dict["intervals"] + x_epoch = [] + y_epoch = [] + for i, epoch in enumerate(intervals, start=1): + start_time = epoch['start_time'] + stop_time = epoch['stop_time'] + x_epoch.extend([start_time, start_time, None]) + x_epoch.extend([stop_time, stop_time, None]) + y_epoch.extend([0.9, len(temporal_data_dict) + 1, None]) + y_epoch.extend([0.9, len(temporal_data_dict) + 1, None]) + + + trace = go.Scatter(x=x_epoch, + y=y_epoch, + mode='lines', + line=dict(color='#808080', width=1.5, dash='dash'), + name='epochs_guides', + visible=False) + + return trace + +def generate_epoch_annotations(temporal_data_dict) -> list: + if "epochs" not in temporal_data_dict: + return [] + + epoch_dict = temporal_data_dict["epochs"] + intervals = epoch_dict["intervals"] + + annotation_dict_list = [] + for i, epoch in enumerate(intervals, start=1): + start_time = epoch['start_time'] + stop_time = epoch['stop_time'] + + # Add a text box for each epoch + annotation_x = start_time + (stop_time - start_time) / 2 + annotation_y = len(temporal_data_dict) + len(temporal_data_dict) * 0.15 + annotation_dict = dict( + x=annotation_x, + y=annotation_y, + text=f'Epoch {i}', + showarrow=False, + font=dict( + size=14, + color="#000000" + ), + align="center", + bordercolor="#c7c7c7", + borderwidth=2, + borderpad=4, + bgcolor="#d3d3d3", + opacity=0.8 + ) + annotation_dict_list.append(annotation_dict) + return annotation_dict_list + +def generate_total_duration_traces(temporal_data_dict, styling_dict): + + traces_dict = dict() + for i, object_name in enumerate(styling_dict, start=1): + times_info_dict = temporal_data_dict[object_name] + info_dict = times_info_dict["info"] + timing_dict = times_info_dict["timing_dict"] + start_time = timing_dict['start_time'] + stop_time = timing_dict['stop_time'] + top_module = info_dict["top_module"] + color = styling_dict[object_name]["color"] + + trace = go.Scatter(x=[start_time, stop_time], + y=[i, i], + mode='lines', + line=dict(color=color, width=6), + name=f'{top_module} - {object_name}', + legendgroup=top_module, # Group by top_module + hovertemplate = f'{object_name}
start: {start_time}
end: {stop_time}') + traces_dict[object_name] = trace + + + return traces_dict + + +def generate_time_grid_widget(temporal_data_dict: dict) -> go.FigureWidget: + + fig = go.FigureWidget() + + # Key modules for temporal organization ad trials and epochs + key_modules = ["epochs", "trials"] + sorted_keys = [key for key in temporal_data_dict.keys() if key not in key_modules] + sorted_keys.sort(key=lambda x: temporal_data_dict[x]["info"]["top_module"]) + unique_top_modules = list(set([temporal_data_dict[key]["info"]["top_module"] for key in sorted_keys])) + + # if "epochs" in temporal_data_dict: + # sorted_keys = ["epochs"] + sorted_keys + # if "trials" in temporal_data_dict: + # sorted_keys = sorted_keys + ["trials"] + + # Generate a color map for the top modules + num_top_modules = len(unique_top_modules) + cm = plt.get_cmap('Accent') # Use 'gist_rainbow' colormap + cNorm = colors.Normalize(vmin=0, vmax=num_top_modules - 1) + scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm) + color_map = {module: colors.rgb2hex(scalarMap.to_rgba(i)) for i, module in enumerate(unique_top_modules)} + + styling_dict_epochs = dict(epochs=dict(color="#d3d3d3")) + styling_dict_rest = {key: dict(color=color_map[temporal_data_dict[key]["info"]["top_module"]]) for key in sorted_keys} + styling_dict_trials = dict(trials=dict(color="#d3d3d3")) + + styling_dict = dict() + if "trials" in temporal_data_dict: + styling_dict.update(styling_dict_trials) + + styling_dict.update(styling_dict_rest) + + if "epochs" in temporal_data_dict: + styling_dict.update(styling_dict_epochs) + + traces_dict = generate_total_duration_traces(temporal_data_dict, styling_dict) + for trace in traces_dict.values(): + fig.add_trace(trace) + + epoch_guide_trace = generate_epoch_guides_trace(temporal_data_dict) + fig.add_trace(epoch_guide_trace) + + annotation_dict_list = generate_epoch_annotations(temporal_data_dict) + # Button for toggling epoch guides + epoch_guide_button = dict( + type="buttons", + direction = "down", + showactive=True, + buttons=[ + dict( + label="Toggle Epoch Guides", + method="update", + args=[{"visible": [True]*len(traces_dict) + [False]}, {"annotations": []}], + args2 = [{"visible": [True]*len(traces_dict) + [True]}, {"annotations": annotation_dict_list}] , + ), + ], + x=1.15, # Position from left (0 to 1) + y=1.15, # Position from bottom (0 to 1) + xanchor="right", # Anchor point is on the right + yanchor="top", # Anchor point is at the top + ) + + ticktext = list(styling_dict.keys()) + tickvals = list(range(1, len(temporal_data_dict) + 1)) + fig.update_layout( + xaxis_title="Time (s)", + yaxis=dict(tickvals=tickvals, ticktext=ticktext, range=[0.5, len(ticktext) * 1.25], ticks="", showline=False, showticklabels=True, showgrid=True, gridwidth=0.1), + updatemenus=[epoch_guide_button], + ) + + return fig + + +from ipywidgets import Layout, fixed, widgets + + +class TimeGrid(widgets.VBox): + + def __init__(self, nwbfile: NWBFile): + super().__init__() + + self.nwbfile = nwbfile + # This is the calculation part + self.temporal_data_dict = extract_timing_information_from_nwbfile(nwbfile=self.nwbfile) + + + # This is the figure widget + self.figure_widget = generate_time_grid_widget(temporal_data_dict=self.temporal_data_dict) + self.children = [self.figure_widget] From 02803febe893697cd8ccfdd2eaabb4103ea0df18 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 3 Jul 2023 18:46:56 +0200 Subject: [PATCH 2/5] integrate on file --- nwbwidgets/file.py | 4 +- nwbwidgets/time_grid.py | 198 ++++++++++++++++++++-------------------- 2 files changed, 104 insertions(+), 98 deletions(-) diff --git a/nwbwidgets/file.py b/nwbwidgets/file.py index 8b036a44..20411f98 100644 --- a/nwbwidgets/file.py +++ b/nwbwidgets/file.py @@ -56,4 +56,6 @@ def show_nwbfile(nwbfile: NWBFile, neurodata_vis_spec: dict) -> widgets.Widget: func_ = partial(view.nwb2widget, neurodata_vis_spec=neurodata_vis_spec) accordion = lazy_show_over_data(neuro_data, func_, labels=labels, style=widgets.Accordion) - return widgets.VBox(info + [accordion]) + from .time_grid import TimeGrid + time_grid_widget = TimeGrid(nwbfile) + return widgets.VBox(info + [time_grid_widget] + [accordion]) diff --git a/nwbwidgets/time_grid.py b/nwbwidgets/time_grid.py index d49ceb5f..952b4ff5 100644 --- a/nwbwidgets/time_grid.py +++ b/nwbwidgets/time_grid.py @@ -2,45 +2,36 @@ import matplotlib.cm as mplcm import matplotlib.colors as colors import matplotlib.pyplot as plt -from pynwb import ProcessingModule -from hdmf.utils import LabelledDict from hdmf.common import DynamicTable, VectorData -from hdmf.data_utils import DataIO -from pynwb import NWBHDF5IO, NWBFile, TimeSeries -from pynwb.core import NWBContainer -from pynwb.behavior import Position -from pynwb.ecephys import LFP +from pynwb import NWBFile, TimeSeries def get_partial_path_from_parents(obj): # ndx-events do not have object.data parents = [] - while obj is not None and obj.name != 'root': + while obj is not None and obj.name != "root": parents.append(obj.name) obj = obj.parent # assuming the object has a 'parent' attribute return parents def extract_timing_information_from_nwbfile(nwbfile: NWBFile, verbose: bool = False) -> dict: - temporal_information_dict = {} for object_id, object in nwbfile.objects.items(): object_name = object.name if hasattr(object, "data") and hasattr(object.data, "name"): object_path = object.data.name - top_module = object_path[1:].split("/")[0] + top_module = object_path[1:].split("/")[0] if top_module == "processing": top_module = object_path[1:].split("/")[1] - + elif hasattr(object, "data"): object_path = get_partial_path_from_parents(object)[::-1] top_module = object_path[0] if object_path else "" else: object_path = get_partial_path_from_parents(object)[::-1] top_module = object_path[0] if object_path else "" - - if isinstance(object, DynamicTable) and hasattr(object, "start_time") and hasattr(object, "stop_time"): if verbose: print("------------------") @@ -48,14 +39,21 @@ def extract_timing_information_from_nwbfile(nwbfile: NWBFile, verbose: bool = Fa print(f"Top module {top_module}") print("------------------") - start_time = object.start_time.data[:] + start_time = object.start_time.data[:] stop_time = object.stop_time.data[:] - - intervals = [{"start_time": start_time, "stop_time": stop_time} for start_time, stop_time in zip(start_time, stop_time)] + + intervals = [ + {"start_time": start_time, "stop_time": stop_time} + for start_time, stop_time in zip(start_time, stop_time) + ] timing_dict = dict(start_time=intervals[0]["start_time"], stop_time=intervals[-1]["stop_time"]) - object_info = dict(object_path=object_path, object_name=object_name, object_id=object_id, top_module=top_module) - temporal_information_dict[object_name] = dict(timing_dict=timing_dict, info=object_info, intervals=intervals) - + object_info = dict( + object_path=object_path, object_name=object_name, object_id=object_id, top_module=top_module + ) + temporal_information_dict[object_name] = dict( + timing_dict=timing_dict, info=object_info, intervals=intervals + ) + elif isinstance(object, DynamicTable) and hasattr(object, "start_time"): if verbose: print("------------------") @@ -63,7 +61,7 @@ def extract_timing_information_from_nwbfile(nwbfile: NWBFile, verbose: bool = Fa print(f"Branch of dynamic table with only start_time for object {object_name} with path {object_path}") print(f"Top module {top_module}") print("------------------") - + elif isinstance(object, TimeSeries): time_series = object if verbose: @@ -78,16 +76,18 @@ def extract_timing_information_from_nwbfile(nwbfile: NWBFile, verbose: bool = Fa else: timing_dict["start_time"] = time_series.starting_time timing_dict["stop_time"] = time_series.starting_time + time_series.num_samples / time_series.rate - - + # TODO This fails when the series has timestamps with gaps. One option to find out the gaps and write the series - # As intervals is that the naive way np.diff generates a too large memory allocation. + # As intervals is that the naive way np.diff generates a too large memory allocation. # Probably, a bisection algorithm can be used to find the gaps. - object_info = dict(object_path=object_path, object_name=object_name, object_id=object_id, top_module=top_module) + object_info = dict( + object_path=object_path, object_name=object_name, object_id=object_id, top_module=top_module + ) temporal_information_dict[object_name] = dict(timing_dict=timing_dict, info=object_info) - - return temporal_information_dict + + return temporal_information_dict + def generate_epoch_guides_trace(temporal_data_dict): epoch_dict = temporal_data_dict["epochs"] @@ -95,23 +95,25 @@ def generate_epoch_guides_trace(temporal_data_dict): x_epoch = [] y_epoch = [] for i, epoch in enumerate(intervals, start=1): - start_time = epoch['start_time'] - stop_time = epoch['stop_time'] + start_time = epoch["start_time"] + stop_time = epoch["stop_time"] x_epoch.extend([start_time, start_time, None]) x_epoch.extend([stop_time, stop_time, None]) y_epoch.extend([0.9, len(temporal_data_dict) + 1, None]) y_epoch.extend([0.9, len(temporal_data_dict) + 1, None]) - - - trace = go.Scatter(x=x_epoch, - y=y_epoch, - mode='lines', - line=dict(color='#808080', width=1.5, dash='dash'), - name='epochs_guides', - visible=False) - + + trace = go.Scatter( + x=x_epoch, + y=y_epoch, + mode="lines", + line=dict(color="#808080", width=1.5, dash="dash"), + name="epochs_guides", + visible=False, + ) + return trace + def generate_epoch_annotations(temporal_data_dict) -> list: if "epochs" not in temporal_data_dict: return [] @@ -121,58 +123,53 @@ def generate_epoch_annotations(temporal_data_dict) -> list: annotation_dict_list = [] for i, epoch in enumerate(intervals, start=1): - start_time = epoch['start_time'] - stop_time = epoch['stop_time'] - + start_time = epoch["start_time"] + stop_time = epoch["stop_time"] + # Add a text box for each epoch annotation_x = start_time + (stop_time - start_time) / 2 - annotation_y = len(temporal_data_dict) + len(temporal_data_dict) * 0.15 - annotation_dict = dict( + annotation_y = len(temporal_data_dict) + len(temporal_data_dict) * 0.05 + annotation_dict = dict( x=annotation_x, y=annotation_y, - text=f'Epoch {i}', + text=f"Epoch {i}", showarrow=False, - font=dict( - size=14, - color="#000000" - ), + font=dict(size=8, color="#000000"), align="center", - bordercolor="#c7c7c7", borderwidth=2, borderpad=4, - bgcolor="#d3d3d3", - opacity=0.8 + opacity=0.6, ) annotation_dict_list.append(annotation_dict) return annotation_dict_list + def generate_total_duration_traces(temporal_data_dict, styling_dict): - traces_dict = dict() for i, object_name in enumerate(styling_dict, start=1): times_info_dict = temporal_data_dict[object_name] info_dict = times_info_dict["info"] timing_dict = times_info_dict["timing_dict"] - start_time = timing_dict['start_time'] - stop_time = timing_dict['stop_time'] + start_time = timing_dict["start_time"] + stop_time = timing_dict["stop_time"] top_module = info_dict["top_module"] color = styling_dict[object_name]["color"] - - trace = go.Scatter(x=[start_time, stop_time], - y=[i, i], - mode='lines', - line=dict(color=color, width=6), - name=f'{top_module} - {object_name}', - legendgroup=top_module, # Group by top_module - hovertemplate = f'{object_name}
start: {start_time}
end: {stop_time}') - traces_dict[object_name] = trace - + + trace = go.Scatter( + x=[start_time, stop_time], + y=[i, i], + mode="lines", + line=dict(color=color, width=6), + name=f"{top_module} - {object_name}", + legendgroup=top_module, # Group by top_module + hovertemplate=f"{object_name}
start: {start_time}
end: {stop_time}", + ) + traces_dict[object_name] = trace return traces_dict def generate_time_grid_widget(temporal_data_dict: dict) -> go.FigureWidget: - fig = go.FigureWidget() # Key modules for temporal organization ad trials and epochs @@ -180,65 +177,74 @@ def generate_time_grid_widget(temporal_data_dict: dict) -> go.FigureWidget: sorted_keys = [key for key in temporal_data_dict.keys() if key not in key_modules] sorted_keys.sort(key=lambda x: temporal_data_dict[x]["info"]["top_module"]) unique_top_modules = list(set([temporal_data_dict[key]["info"]["top_module"] for key in sorted_keys])) - + # if "epochs" in temporal_data_dict: - # sorted_keys = ["epochs"] + sorted_keys - # if "trials" in temporal_data_dict: + # sorted_keys = ["epochs"] + sorted_keys + # if "trials" in temporal_data_dict: # sorted_keys = sorted_keys + ["trials"] # Generate a color map for the top modules num_top_modules = len(unique_top_modules) - cm = plt.get_cmap('Accent') # Use 'gist_rainbow' colormap - cNorm = colors.Normalize(vmin=0, vmax=num_top_modules - 1) + cm = plt.get_cmap("Accent") # Use 'gist_rainbow' colormap + cNorm = colors.Normalize(vmin=0, vmax=num_top_modules - 1) scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm) color_map = {module: colors.rgb2hex(scalarMap.to_rgba(i)) for i, module in enumerate(unique_top_modules)} - + styling_dict_epochs = dict(epochs=dict(color="#d3d3d3")) - styling_dict_rest = {key: dict(color=color_map[temporal_data_dict[key]["info"]["top_module"]]) for key in sorted_keys} + styling_dict_rest = { + key: dict(color=color_map[temporal_data_dict[key]["info"]["top_module"]]) for key in sorted_keys + } styling_dict_trials = dict(trials=dict(color="#d3d3d3")) - + styling_dict = dict() if "trials" in temporal_data_dict: styling_dict.update(styling_dict_trials) - + styling_dict.update(styling_dict_rest) - + if "epochs" in temporal_data_dict: styling_dict.update(styling_dict_epochs) traces_dict = generate_total_duration_traces(temporal_data_dict, styling_dict) for trace in traces_dict.values(): fig.add_trace(trace) - + epoch_guide_trace = generate_epoch_guides_trace(temporal_data_dict) fig.add_trace(epoch_guide_trace) annotation_dict_list = generate_epoch_annotations(temporal_data_dict) # Button for toggling epoch guides epoch_guide_button = dict( - type="buttons", - direction = "down", - showactive=True, - buttons=[ - dict( - label="Toggle Epoch Guides", - method="update", - args=[{"visible": [True]*len(traces_dict) + [False]}, {"annotations": []}], - args2 = [{"visible": [True]*len(traces_dict) + [True]}, {"annotations": annotation_dict_list}] , - ), - ], - x=1.15, # Position from left (0 to 1) - y=1.15, # Position from bottom (0 to 1) - xanchor="right", # Anchor point is on the right - yanchor="top", # Anchor point is at the top - ) - + type="buttons", + direction="down", + showactive=True, + buttons=[ + dict( + label="Toggle Epoch Guides", + method="update", + args=[{"visible": [True] * len(traces_dict) + [False]}, {"annotations": []}], + args2=[{"visible": [True] * len(traces_dict) + [True]}, {"annotations": annotation_dict_list}], + ), + ], + x=0.10, # Position from left (0 to 1) + y=-0.15, # Position from bottom (0 to 1) + ) + ticktext = list(styling_dict.keys()) tickvals = list(range(1, len(temporal_data_dict) + 1)) fig.update_layout( xaxis_title="Time (s)", - yaxis=dict(tickvals=tickvals, ticktext=ticktext, range=[0.5, len(ticktext) * 1.25], ticks="", showline=False, showticklabels=True, showgrid=True, gridwidth=0.1), - updatemenus=[epoch_guide_button], + yaxis=dict( + tickvals=tickvals, + ticktext=ticktext, + range=[0.5, len(ticktext) * 1.25], + ticks="", + showline=False, + showticklabels=True, + showgrid=True, + gridwidth=0.1, + ), + updatemenus=[epoch_guide_button], ) return fig @@ -248,15 +254,13 @@ def generate_time_grid_widget(temporal_data_dict: dict) -> go.FigureWidget: class TimeGrid(widgets.VBox): - def __init__(self, nwbfile: NWBFile): super().__init__() self.nwbfile = nwbfile # This is the calculation part self.temporal_data_dict = extract_timing_information_from_nwbfile(nwbfile=self.nwbfile) - - + # This is the figure widget self.figure_widget = generate_time_grid_widget(temporal_data_dict=self.temporal_data_dict) self.children = [self.figure_widget] From a055b04a036c66630c37409c78578dfefa79844a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 3 Jul 2023 19:04:59 +0200 Subject: [PATCH 3/5] fixes when testing on different dandisets --- nwbwidgets/time_grid.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/nwbwidgets/time_grid.py b/nwbwidgets/time_grid.py index 952b4ff5..226e76ef 100644 --- a/nwbwidgets/time_grid.py +++ b/nwbwidgets/time_grid.py @@ -90,6 +90,9 @@ def extract_timing_information_from_nwbfile(nwbfile: NWBFile, verbose: bool = Fa def generate_epoch_guides_trace(temporal_data_dict): + if "epochs" not in temporal_data_dict: + return [] + epoch_dict = temporal_data_dict["epochs"] intervals = epoch_dict["intervals"] x_epoch = [] @@ -210,7 +213,8 @@ def generate_time_grid_widget(temporal_data_dict: dict) -> go.FigureWidget: fig.add_trace(trace) epoch_guide_trace = generate_epoch_guides_trace(temporal_data_dict) - fig.add_trace(epoch_guide_trace) + if epoch_guide_trace: + fig.add_trace(epoch_guide_trace) annotation_dict_list = generate_epoch_annotations(temporal_data_dict) # Button for toggling epoch guides @@ -229,7 +233,10 @@ def generate_time_grid_widget(temporal_data_dict: dict) -> go.FigureWidget: x=0.10, # Position from left (0 to 1) y=-0.15, # Position from bottom (0 to 1) ) - + if epoch_guide_trace: + updatemenus = [epoch_guide_button] + else: + updatemenus = None ticktext = list(styling_dict.keys()) tickvals = list(range(1, len(temporal_data_dict) + 1)) fig.update_layout( @@ -244,7 +251,7 @@ def generate_time_grid_widget(temporal_data_dict: dict) -> go.FigureWidget: showgrid=True, gridwidth=0.1, ), - updatemenus=[epoch_guide_button], + updatemenus=updatemenus, ) return fig From 60b02c8e17e8ff157714766f65ac169e1cc45f77 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Jul 2023 17:43:03 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nwbwidgets/file.py | 1 + nwbwidgets/time_grid.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/nwbwidgets/file.py b/nwbwidgets/file.py index 20411f98..d7b36c6d 100644 --- a/nwbwidgets/file.py +++ b/nwbwidgets/file.py @@ -57,5 +57,6 @@ def show_nwbfile(nwbfile: NWBFile, neurodata_vis_spec: dict) -> widgets.Widget: accordion = lazy_show_over_data(neuro_data, func_, labels=labels, style=widgets.Accordion) from .time_grid import TimeGrid + time_grid_widget = TimeGrid(nwbfile) return widgets.VBox(info + [time_grid_widget] + [accordion]) diff --git a/nwbwidgets/time_grid.py b/nwbwidgets/time_grid.py index 952b4ff5..fd72ddb6 100644 --- a/nwbwidgets/time_grid.py +++ b/nwbwidgets/time_grid.py @@ -1,7 +1,7 @@ -import plotly.graph_objects as go import matplotlib.cm as mplcm import matplotlib.colors as colors import matplotlib.pyplot as plt +import plotly.graph_objects as go from hdmf.common import DynamicTable, VectorData from pynwb import NWBFile, TimeSeries From f7a4add23d0ee145e0910e5797ca7ea2bbed7f8e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Jul 2023 17:48:59 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nwbwidgets/time_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nwbwidgets/time_grid.py b/nwbwidgets/time_grid.py index 6a342b99..edad125e 100644 --- a/nwbwidgets/time_grid.py +++ b/nwbwidgets/time_grid.py @@ -92,7 +92,7 @@ def extract_timing_information_from_nwbfile(nwbfile: NWBFile, verbose: bool = Fa def generate_epoch_guides_trace(temporal_data_dict): if "epochs" not in temporal_data_dict: return [] - + epoch_dict = temporal_data_dict["epochs"] intervals = epoch_dict["intervals"] x_epoch = []