|  | 
|  | 1 | +import numpy as np | 
|  | 2 | +import plotly.graph_objects as go | 
|  | 3 | +from plotly.subplots import make_subplots | 
|  | 4 | + | 
|  | 5 | + | 
|  | 6 | +def plot_subplots( | 
|  | 7 | +    data1, | 
|  | 8 | +    data2, | 
|  | 9 | +    labels_x=None, | 
|  | 10 | +    labels_y=None, | 
|  | 11 | +    subplot_titles=None, | 
|  | 12 | +    title="", | 
|  | 13 | +    nrows=None, | 
|  | 14 | +    ncols=None, | 
|  | 15 | +    linewidth=1, | 
|  | 16 | +    markersize=4, | 
|  | 17 | +    linecolor=None, | 
|  | 18 | +    markercolor=None, | 
|  | 19 | +    fontsize=12, | 
|  | 20 | +    fig=None, | 
|  | 21 | +): | 
|  | 22 | +    """ | 
|  | 23 | +    Plot a grid of subplots using Plotly, handling both single-component (scalar vs scalar) and multi-component data. | 
|  | 24 | +
 | 
|  | 25 | +    Parameters: | 
|  | 26 | +    - data1: numpy array, first set of data to plot (e.g., strain, time) with shape (n_datapoints, n_plots) | 
|  | 27 | +    - data2: numpy array, second set of data to plot (e.g., stress) with shape (n_datapoints, n_plots) | 
|  | 28 | +    - labels_x: list of strings, labels for the x axes of each subplot (optional, default=None) | 
|  | 29 | +    - labels_y: list of strings, labels for the y axes of each subplot (optional, default=None) | 
|  | 30 | +    - subplot_titles: list of strings, titles for each subplot (optional, default=None) | 
|  | 31 | +    - title: string, title of the overall plot | 
|  | 32 | +    - nrows: int, number of rows in the subplot grid (optional) | 
|  | 33 | +    - ncols: int, number of columns in the subplot grid (optional) | 
|  | 34 | +    - linewidth: int, line width for the plots (optional, default=1) | 
|  | 35 | +    - markersize: int, size of the markers (optional, default=4) | 
|  | 36 | +    - linecolor: list of strings, colors of the lines for each subplot (optional, default=None, all blue) | 
|  | 37 | +    - markercolor: list of strings, colors of the markers for each subplot (optional, default=None, all blue) | 
|  | 38 | +    - fontsize: int, font size for axis labels, subplot titles, and tick labels (optional, default=12) | 
|  | 39 | +    - fig: existing Plotly figure to overlay the new subplots (optional, default=None, creates a new figure) | 
|  | 40 | +    """ | 
|  | 41 | +    # Validate data shapes | 
|  | 42 | +    if not isinstance(data1, np.ndarray) or not isinstance(data2, np.ndarray): | 
|  | 43 | +        raise ValueError("data1 and data2 must be numpy arrays.") | 
|  | 44 | + | 
|  | 45 | +    if data1.shape[0] != data2.shape[0]: | 
|  | 46 | +        raise ValueError( | 
|  | 47 | +            "data1 and data2 must have the same number of data points (rows)." | 
|  | 48 | +        ) | 
|  | 49 | + | 
|  | 50 | +    if data1.shape[1] != data2.shape[1]: | 
|  | 51 | +        raise ValueError( | 
|  | 52 | +            "data1 and data2 must have the same number of components (columns)." | 
|  | 53 | +        ) | 
|  | 54 | + | 
|  | 55 | +    # Set the number of components based on data shape | 
|  | 56 | +    n_components = data1.shape[1] | 
|  | 57 | + | 
|  | 58 | +    # Initialize linecolor and markercolor lists if not provided | 
|  | 59 | +    if linecolor is None: | 
|  | 60 | +        linecolor = ["blue"] * n_components | 
|  | 61 | +    elif len(linecolor) != n_components: | 
|  | 62 | +        raise ValueError( | 
|  | 63 | +            f"The length of linecolor must match the number of components ({n_components})." | 
|  | 64 | +        ) | 
|  | 65 | + | 
|  | 66 | +    if markercolor is None: | 
|  | 67 | +        markercolor = ["blue"] * n_components | 
|  | 68 | +    elif len(markercolor) != n_components: | 
|  | 69 | +        raise ValueError( | 
|  | 70 | +            f"The length of markercolor must match the number of components ({n_components})." | 
|  | 71 | +        ) | 
|  | 72 | + | 
|  | 73 | +    # If nrows or ncols is not specified, determine an optimal grid layout | 
|  | 74 | +    if nrows is None or ncols is None: | 
|  | 75 | +        nrows = int(np.ceil(np.sqrt(n_components))) | 
|  | 76 | +        ncols = int(np.ceil(n_components / nrows)) | 
|  | 77 | + | 
|  | 78 | +    # Handle subplot titles | 
|  | 79 | +    if subplot_titles is None: | 
|  | 80 | +        subplot_titles = [f"Component {i+1}" for i in range(n_components)] | 
|  | 81 | +    elif len(subplot_titles) != n_components: | 
|  | 82 | +        raise ValueError( | 
|  | 83 | +            f"The length of subplot_titles must match the number of components ({n_components})." | 
|  | 84 | +        ) | 
|  | 85 | + | 
|  | 86 | +    # Handle labels_x and labels_y | 
|  | 87 | +    if labels_x is None: | 
|  | 88 | +        labels_x = [""] * n_components | 
|  | 89 | +    elif len(labels_x) != n_components: | 
|  | 90 | +        raise ValueError( | 
|  | 91 | +            f"The length of labels_x must match the number of components ({n_components})." | 
|  | 92 | +        ) | 
|  | 93 | + | 
|  | 94 | +    if labels_y is None: | 
|  | 95 | +        labels_y = [""] * n_components | 
|  | 96 | +    elif len(labels_y) != n_components: | 
|  | 97 | +        raise ValueError( | 
|  | 98 | +            f"The length of labels_y must match the number of components ({n_components})." | 
|  | 99 | +        ) | 
|  | 100 | + | 
|  | 101 | +    # Create the subplot figure if not provided | 
|  | 102 | +    if fig is None: | 
|  | 103 | +        fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=subplot_titles) | 
|  | 104 | + | 
|  | 105 | +    # Add traces for each component | 
|  | 106 | +    for i in range(n_components): | 
|  | 107 | +        row = i // ncols + 1 | 
|  | 108 | +        col = i % ncols + 1 | 
|  | 109 | +        fig.add_trace( | 
|  | 110 | +            go.Scatter( | 
|  | 111 | +                x=data1[:, i], | 
|  | 112 | +                y=data2[:, i], | 
|  | 113 | +                mode="lines+markers", | 
|  | 114 | +                marker=dict(symbol="x", size=markersize, color=markercolor[i]), | 
|  | 115 | +                line=dict(width=linewidth, color=linecolor[i]), | 
|  | 116 | +                name=f"Component {i+1}", | 
|  | 117 | +            ), | 
|  | 118 | +            row=row, | 
|  | 119 | +            col=col, | 
|  | 120 | +        ) | 
|  | 121 | + | 
|  | 122 | +        # Update axes with text labels | 
|  | 123 | +        fig.update_xaxes( | 
|  | 124 | +            title_text=labels_x[i], | 
|  | 125 | +            row=row, | 
|  | 126 | +            col=col, | 
|  | 127 | +            showgrid=True, | 
|  | 128 | +            mirror=True, | 
|  | 129 | +            ticks="inside", | 
|  | 130 | +            tickwidth=2, | 
|  | 131 | +            ticklen=6, | 
|  | 132 | +            title_font=dict(size=fontsize), | 
|  | 133 | +            tickfont=dict(size=fontsize), | 
|  | 134 | +            automargin=True, | 
|  | 135 | +        ) | 
|  | 136 | +        fig.update_yaxes( | 
|  | 137 | +            title_text=labels_y[i], | 
|  | 138 | +            row=row, | 
|  | 139 | +            col=col, | 
|  | 140 | +            showgrid=True, | 
|  | 141 | +            mirror=True, | 
|  | 142 | +            ticks="inside", | 
|  | 143 | +            tickwidth=2, | 
|  | 144 | +            ticklen=6, | 
|  | 145 | +            title_font=dict(size=fontsize), | 
|  | 146 | +            tickfont=dict(size=fontsize), | 
|  | 147 | +            automargin=True, | 
|  | 148 | +        ) | 
|  | 149 | + | 
|  | 150 | +    # Update layout with the overall plot title and styling | 
|  | 151 | +    fig.update_layout( | 
|  | 152 | +        height=500, | 
|  | 153 | +        width=800, | 
|  | 154 | +        title_text=title, | 
|  | 155 | +        title_font=dict(size=fontsize), | 
|  | 156 | +        showlegend=False,  # Legends removed | 
|  | 157 | +        template="plotly_white", | 
|  | 158 | +        margin=dict(l=50, r=50, t=50, b=50),  # Adjust margins to prevent overlap | 
|  | 159 | +        title_x=0.5, | 
|  | 160 | +        autosize=False, | 
|  | 161 | +    ) | 
|  | 162 | + | 
|  | 163 | +    # Add a box outline around all subplots | 
|  | 164 | +    for i in range(1, nrows * ncols + 1): | 
|  | 165 | +        fig.update_xaxes( | 
|  | 166 | +            showline=True, | 
|  | 167 | +            linewidth=2, | 
|  | 168 | +            linecolor="black", | 
|  | 169 | +            row=(i - 1) // ncols + 1, | 
|  | 170 | +            col=(i - 1) % ncols + 1, | 
|  | 171 | +        ) | 
|  | 172 | +        fig.update_yaxes( | 
|  | 173 | +            showline=True, | 
|  | 174 | +            linewidth=2, | 
|  | 175 | +            linecolor="black", | 
|  | 176 | +            row=(i - 1) // ncols + 1, | 
|  | 177 | +            col=(i - 1) % ncols + 1, | 
|  | 178 | +        ) | 
|  | 179 | + | 
|  | 180 | +    # Update subplot titles with the specified fontsize | 
|  | 181 | +    for annotation in fig["layout"]["annotations"]: | 
|  | 182 | +        annotation["font"] = dict(size=fontsize) | 
|  | 183 | + | 
|  | 184 | +    # Return the figure for further customization or overlaying | 
|  | 185 | +    return fig | 
0 commit comments