Skip to content

Commit 74802c1

Browse files
committed
bugfixes
1 parent 74edae0 commit 74802c1

File tree

1 file changed

+93
-32
lines changed

1 file changed

+93
-32
lines changed

aeolis/gui.py

Lines changed: 93 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ def browse_nc_file(self):
251251

252252
self.nc_file_entry.delete(0, END)
253253
self.nc_file_entry.insert(0, file_path)
254+
255+
# Auto-load and plot the data
256+
self.plot_nc_2d()
254257

255258
def load_new_config(self):
256259
"""Load a new configuration file and update all fields"""
@@ -486,15 +489,6 @@ def create_plot_output_2d_tab(self, tab_control):
486489
self.time_slider.pack(side=LEFT, fill=X, expand=1, padx=5)
487490
self.time_slider.set(0)
488491

489-
# Create a frame for buttons
490-
output_button_frame = ttk.Frame(plot_frame)
491-
output_button_frame.pack(pady=5)
492-
493-
# Single Load & Plot button
494-
plot_button = ttk.Button(output_button_frame, text="Load & Plot",
495-
command=self.plot_nc_2d)
496-
plot_button.grid(row=0, column=0, padx=5)
497-
498492
def create_plot_output_1d_tab(self, tab_control):
499493
# Create the 'Plot Output 1D' tab
500494
tab6 = ttk.Frame(tab_control)
@@ -584,15 +578,6 @@ def create_plot_output_1d_tab(self, tab_control):
584578
self.time_slider_1d.pack(side=LEFT, fill=X, expand=1, padx=5)
585579
self.time_slider_1d.set(0)
586580

587-
# Create a frame for buttons
588-
output_button_frame_1d = ttk.Frame(plot_frame_1d)
589-
output_button_frame_1d.pack(pady=5)
590-
591-
# Create plot button
592-
plot_button_1d = ttk.Button(output_button_frame_1d, text="Load & Plot",
593-
command=self.plot_1d_transect)
594-
plot_button_1d.grid(row=0, column=0, padx=5)
595-
596581
def browse_nc_file_1d(self):
597582
"""Open file dialog to select a NetCDF file for 1D plotting"""
598583
# Get initial directory from config file location
@@ -632,6 +617,9 @@ def browse_nc_file_1d(self):
632617

633618
self.nc_file_entry_1d.delete(0, END)
634619
self.nc_file_entry_1d.insert(0, file_path)
620+
621+
# Auto-load and plot the data
622+
self.plot_1d_transect()
635623

636624
def on_variable_changed(self, event):
637625
"""Update plot when variable selection changes"""
@@ -766,17 +754,21 @@ def plot_1d_transect(self):
766754
if 'time' in var.dimensions:
767755
# Load all time steps
768756
var_data = var[:]
769-
# Need at least 3 dimensions: (time, n, s)
757+
# Need at least 3 dimensions: (time, n, s) or (time, n, s, fractions)
770758
if var_data.ndim < 3:
771759
continue # Skip variables without spatial dimensions
772760
n_times = max(n_times, var_data.shape[0])
773761
else:
774762
# Single time step - validate shape
775-
# Need exactly 2 spatial dimensions: (n, s)
776-
if var.ndim != 2:
777-
continue # Skip variables without 2D spatial dimensions
778-
var_data = var[:, :]
779-
var_data = np.expand_dims(var_data, axis=0) # Add time dimension
763+
# Need at least 2 spatial dimensions: (n, s) or (n, s, fractions)
764+
if var.ndim < 2:
765+
continue # Skip variables without spatial dimensions
766+
if var.ndim == 2:
767+
var_data = var[:, :]
768+
var_data = np.expand_dims(var_data, axis=0) # Add time dimension
769+
elif var.ndim == 3: # (n, s, fractions)
770+
var_data = var[:, :, :]
771+
var_data = np.expand_dims(var_data, axis=0) # Add time dimension
780772

781773
var_data_dict[var_name] = var_data
782774
candidate_vars.append(var_name)
@@ -863,10 +855,19 @@ def update_1d_plot(self):
863855
# Get the data
864856
var_data = self.nc_data_cache_1d['vars'][var_name]
865857

858+
# Check if variable has fractions dimension (4D: time, n, s, fractions)
859+
has_fractions = var_data.ndim == 4
860+
866861
# Extract transect based on direction
867862
if self.transect_direction_var.get() == 'cross-shore':
868863
# Fix y-index (n), vary along x (s)
869-
transect_data = var_data[time_idx, transect_idx, :]
864+
if has_fractions:
865+
# Extract all fractions for this transect: (fractions,)
866+
transect_data = var_data[time_idx, transect_idx, :, :] # (s, fractions)
867+
# Average or select first fraction
868+
transect_data = transect_data.mean(axis=1) # Average across fractions
869+
else:
870+
transect_data = var_data[time_idx, transect_idx, :]
870871

871872
# Get x-coordinates
872873
if self.nc_data_cache_1d['x'] is not None:
@@ -884,7 +885,13 @@ def update_1d_plot(self):
884885
xlabel = 'Grid Index'
885886
else:
886887
# Fix x-index (s), vary along y (n)
887-
transect_data = var_data[time_idx, :, transect_idx]
888+
if has_fractions:
889+
# Extract all fractions for this transect: (fractions,)
890+
transect_data = var_data[time_idx, :, transect_idx, :] # (n, fractions)
891+
# Average or select first fraction
892+
transect_data = transect_data.mean(axis=1) # Average across fractions
893+
else:
894+
transect_data = var_data[time_idx, :, transect_idx]
888895

889896
# Get y-coordinates
890897
if self.nc_data_cache_1d['y'] is not None:
@@ -912,9 +919,22 @@ def update_1d_plot(self):
912919
'ustars': 'Shear Velocity S-component (m/s)',
913920
'ustarn': 'Shear Velocity N-component (m/s)',
914921
'zs': 'Surface Elevation (m)',
915-
'zsep': 'Separation Elevation (m)'
922+
'zsep': 'Separation Elevation (m)',
923+
'Ct': 'Sediment Concentration (kg/m²)',
924+
'Cu': 'Equilibrium Concentration (kg/m²)',
925+
'q': 'Sediment Flux (kg/m/s)',
926+
'qs': 'Sediment Flux S-component (kg/m/s)',
927+
'qn': 'Sediment Flux N-component (kg/m/s)',
928+
'pickup': 'Sediment Entrainment (kg/m²)',
929+
'uth': 'Threshold Shear Velocity (m/s)',
930+
'w': 'Fraction Weight (-)',
916931
}
917932
ylabel = ylabel_dict.get(var_name, var_name)
933+
934+
# Add indication if variable has fractions dimension
935+
if has_fractions:
936+
ylabel += ' (avg. fractions)'
937+
918938
self.output_1d_ax.set_ylabel(ylabel)
919939

920940
# Set title
@@ -1083,9 +1103,26 @@ def get_variable_label(self, var_name):
10831103
'ustars': 'Shear Velocity S-component (m/s)',
10841104
'ustarn': 'Shear Velocity N-component (m/s)',
10851105
'zs': 'Surface Elevation (m)',
1086-
'zsep': 'Separation Elevation (m)'
1106+
'zsep': 'Separation Elevation (m)',
1107+
'Ct': 'Sediment Concentration (kg/m²)',
1108+
'Cu': 'Equilibrium Concentration (kg/m²)',
1109+
'q': 'Sediment Flux (kg/m/s)',
1110+
'qs': 'Sediment Flux S-component (kg/m/s)',
1111+
'qn': 'Sediment Flux N-component (kg/m/s)',
1112+
'pickup': 'Sediment Entrainment (kg/m²)',
1113+
'uth': 'Threshold Shear Velocity (m/s)',
1114+
'w': 'Fraction Weight (-)',
10871115
}
1088-
return label_dict.get(var_name, var_name)
1116+
base_label = label_dict.get(var_name, var_name)
1117+
1118+
# Check if this variable has fractions dimension
1119+
if hasattr(self, 'nc_data_cache') and self.nc_data_cache is not None:
1120+
if var_name in self.nc_data_cache.get('vars', {}):
1121+
var_data = self.nc_data_cache['vars'][var_name]
1122+
if var_data.ndim == 4:
1123+
base_label += ' (avg. fractions)'
1124+
1125+
return base_label
10891126

10901127
def get_variable_title(self, var_name):
10911128
"""Get title for variable"""
@@ -1095,9 +1132,26 @@ def get_variable_title(self, var_name):
10951132
'ustars': 'Shear Velocity (S-component)',
10961133
'ustarn': 'Shear Velocity (N-component)',
10971134
'zs': 'Surface Elevation',
1098-
'zsep': 'Separation Elevation'
1135+
'zsep': 'Separation Elevation',
1136+
'Ct': 'Sediment Concentration',
1137+
'Cu': 'Equilibrium Concentration',
1138+
'q': 'Sediment Flux',
1139+
'qs': 'Sediment Flux (S-component)',
1140+
'qn': 'Sediment Flux (N-component)',
1141+
'pickup': 'Sediment Entrainment',
1142+
'uth': 'Threshold Shear Velocity',
1143+
'w': 'Fraction Weight',
10991144
}
1100-
return title_dict.get(var_name, var_name)
1145+
base_title = title_dict.get(var_name, var_name)
1146+
1147+
# Check if this variable has fractions dimension
1148+
if hasattr(self, 'nc_data_cache') and self.nc_data_cache is not None:
1149+
if var_name in self.nc_data_cache.get('vars', {}):
1150+
var_data = self.nc_data_cache['vars'][var_name]
1151+
if var_data.ndim == 4:
1152+
base_title += ' (avg. fractions)'
1153+
1154+
return base_title
11011155

11021156
def update_2d_plot(self):
11031157
"""Update the 2D plot with current settings"""
@@ -1121,7 +1175,14 @@ def update_2d_plot(self):
11211175

11221176
# Get the data
11231177
var_data = self.nc_data_cache['vars'][var_name]
1124-
z_data = var_data[time_idx, :, :]
1178+
1179+
# Check if variable has fractions dimension (4D: time, n, s, fractions)
1180+
if var_data.ndim == 4:
1181+
# Average across fractions or select first fraction
1182+
z_data = var_data[time_idx, :, :, :].mean(axis=2) # Average across fractions
1183+
else:
1184+
z_data = var_data[time_idx, :, :]
1185+
11251186
x_data = self.nc_data_cache['x']
11261187
y_data = self.nc_data_cache['y']
11271188

0 commit comments

Comments
 (0)