55from plotly .subplots import make_subplots
66
77import maxplotlib .backends .matplotlib .utils as plt_utils
8- import maxplotlib .subfigure .line_plot as lp
9- import maxplotlib .subfigure .tikz_figure as tf
8+ from maxplotlib .subfigure .line_plot import LinePlot
9+ from maxplotlib .subfigure .tikz_figure import TikzFigure
1010
1111
1212class Canvas :
@@ -37,11 +37,15 @@ def __init__(self, **kwargs):
3737
3838 # Dictionary to store lines for each subplot
3939 # Key: (row, col), Value: list of lines with their data and kwargs
40- self .subplots = {}
40+ self ._subplots = {}
4141 self ._num_subplots = 0
4242
4343 self ._subplot_matrix = [[None ] * self .ncols for _ in range (self .nrows )]
4444
45+ @property
46+ def subplots (self ):
47+ return self ._subplots
48+
4549 @property
4650 def layers (self ):
4751 layers = []
@@ -66,34 +70,92 @@ def generate_new_rowcol(self, row, col):
6670 assert col is not None , "Not enough columns!"
6771 return row , col
6872
69- def add_tikzfigure (self , ** kwargs ):
73+ def add_line (
74+ self ,
75+ x_data ,
76+ y_data ,
77+ layer = 0 ,
78+ subplot : LinePlot | None = None ,
79+ row : int | None = None ,
80+ col : int | None = None ,
81+ plot_type = "plot" ,
82+ ** kwargs ,
83+ ):
84+ if row is not None and col is not None :
85+ try :
86+ subplot = self ._subplot_matrix [row ][col ]
87+ except KeyError :
88+ raise ValueError ("Invalid subplot position." )
89+ else :
90+ row , col = 0 , 0
91+ subplot = self ._subplot_matrix [row ][col ]
92+
93+ if subplot is None :
94+ row , col = self .generate_new_rowcol (row , col )
95+ subplot = self .add_subplot (col = col , row = row )
96+
97+ subplot .add_line (
98+ x_data = x_data ,
99+ y_data = y_data ,
100+ layer = layer ,
101+ plot_type = plot_type ,
102+ ** kwargs ,
103+ )
104+
105+ def add_tikzfigure (
106+ self ,
107+ col = None ,
108+ row = None ,
109+ label = None ,
110+ ** kwargs ,
111+ ):
70112 """
71113 Adds a subplot to the figure.
72114
73115 Parameters:
74116 **kwargs: Arbitrary keyword arguments.
75- - col (int): Column index for the subplot.
76- - row (int): Row index for the subplot.
77- - label (str): Label to identify the subplot.
78117 """
79- col = kwargs .get ("col" , None )
80- row = kwargs .get ("row" , None )
81- label = kwargs .get ("label" , None )
82118
83119 row , col = self .generate_new_rowcol (row , col )
84120
85121 # Initialize the LinePlot for the given subplot position
86- tikz_figure = tf .TikzFigure (** kwargs )
122+ tikz_figure = TikzFigure (
123+ col = col ,
124+ row = row ,
125+ label = label ,
126+ ** kwargs ,
127+ )
87128 self ._subplot_matrix [row ][col ] = tikz_figure
88129
89130 # Store the LinePlot instance by its position for easy access
90131 if label is None :
91- self .subplots [(row , col )] = tikz_figure
132+ self ._subplots [(row , col )] = tikz_figure
92133 else :
93- self .subplots [label ] = tikz_figure
134+ self ._subplots [label ] = tikz_figure
94135 return tikz_figure
95136
96- def add_subplot (self , ** kwargs ):
137+ def add_subplot (
138+ self ,
139+ col : int | None = None ,
140+ row : int | None = None ,
141+ figsize : tuple = (10 , 6 ),
142+ title : str | None = None ,
143+ caption : str | None = None ,
144+ description : str | None = None ,
145+ label : str | None = None ,
146+ grid : bool = False ,
147+ legend : bool = False ,
148+ xmin : float | int | None = None ,
149+ xmax : float | int | None = None ,
150+ ymin : float | int | None = None ,
151+ ymax : float | int | None = None ,
152+ xlabel : str | None = None ,
153+ ylabel : str | None = None ,
154+ xscale : float | int = 1.0 ,
155+ yscale : float | int = 1.0 ,
156+ xshift : float | int = 0.0 ,
157+ yshift : float | int = 0.0 ,
158+ ):
97159 """
98160 Adds a subplot to the figure.
99161
@@ -103,21 +165,32 @@ def add_subplot(self, **kwargs):
103165 - row (int): Row index for the subplot.
104166 - label (str): Label to identify the subplot.
105167 """
106- col = kwargs .get ("col" , None )
107- row = kwargs .get ("row" , None )
108- label = kwargs .get ("label" , None )
109168
110169 row , col = self .generate_new_rowcol (row , col )
111170
112171 # Initialize the LinePlot for the given subplot position
113- line_plot = lp .LinePlot (** kwargs )
172+ line_plot = LinePlot (
173+ title = title ,
174+ grid = grid ,
175+ legend = legend ,
176+ xmin = xmin ,
177+ xmax = xmax ,
178+ ymin = ymin ,
179+ ymax = ymax ,
180+ xlabel = xlabel ,
181+ ylabel = ylabel ,
182+ xscale = xscale ,
183+ yscale = yscale ,
184+ xshift = xshift ,
185+ yshift = yshift ,
186+ )
114187 self ._subplot_matrix [row ][col ] = line_plot
115188
116189 # Store the LinePlot instance by its position for easy access
117190 if label is None :
118- self .subplots [(row , col )] = line_plot
191+ self ._subplots [(row , col )] = line_plot
119192 else :
120- self .subplots [label ] = line_plot
193+ self ._subplots [label ] = line_plot
121194 return line_plot
122195
123196 def savefig (
@@ -136,7 +209,10 @@ def savefig(
136209 for layer in self .layers :
137210 layers .append (layer )
138211 fig , axs = self .plot (
139- show = False , backend = "matplotlib" , savefig = True , layers = layers
212+ show = False ,
213+ backend = "matplotlib" ,
214+ savefig = True ,
215+ layers = layers ,
140216 )
141217 _fn = f"{ filename_no_extension } _{ layers } .{ extension } "
142218 fig .savefig (_fn )
@@ -153,25 +229,38 @@ def savefig(
153229 else :
154230
155231 fig , axs = self .plot (
156- show = False , backend = "matplotlib" , savefig = True , layers = layers
232+ show = False ,
233+ backend = "matplotlib" ,
234+ savefig = True ,
235+ layers = layers ,
157236 )
158237 fig .savefig (full_filepath )
159238 if verbose :
160239 print (f"Saved { full_filepath } " )
161240
162- def plot (self , backend = "matplotlib" , show = True , savefig = False , layers = None ):
241+ def plot (self , backend = "matplotlib" , savefig = False , layers = None ):
163242 if backend == "matplotlib" :
164- return self .plot_matplotlib (show = show , savefig = savefig , layers = layers )
243+ return self .plot_matplotlib (savefig = savefig , layers = layers )
165244 elif backend == "plotly" :
166- self .plot_plotly (show = show , savefig = savefig )
245+ return self .plot_plotly (savefig = savefig )
246+ else :
247+ raise ValueError (f"Invalid backend: { backend } " )
167248
168- def plot_matplotlib (self , show = True , savefig = False , layers = None , usetex = False ):
249+ def show (self , backend = "matplotlib" ):
250+ if backend == "matplotlib" :
251+ self .plot (backend = "matplotlib" , savefig = False , layers = None )
252+ self ._matplotlib_fig .show ()
253+ elif backend == "plotly" :
254+ plot = self .plot_plotly (savefig = False )
255+ else :
256+ raise ValueError ("Invalid backend" )
257+
258+ def plot_matplotlib (self , savefig = False , layers = None , usetex = False ):
169259 """
170260 Generate and optionally display the subplots.
171261
172262 Parameters:
173263 filename (str, optional): Filename to save the figure.
174- show (bool): Whether to display the plot.
175264 """
176265
177266 tex_fonts = plt_utils .setup_tex_fonts (fontsize = self .fontsize , usetex = usetex )
@@ -205,17 +294,11 @@ def plot_matplotlib(self, show=True, savefig=False, layers=None, usetex=False):
205294
206295 for (row , col ), subplot in self .subplots .items ():
207296 ax = axes [row ][col ]
208- # print(f"{subplot = }")
209297 subplot .plot_matplotlib (ax , layers = layers )
210298 # ax.set_title(f"Subplot ({row}, {col})")
211299 ax .grid ()
212- # Set caption, labels, etc., if needed
213- # plt.tight_layout()
214300
215- if show :
216- plt .show ()
217- # else:
218- # plt.close()
301+ # Set caption, labels, etc., if needed
219302 self ._plotted = True
220303 self ._matplotlib_fig = fig
221304 self ._matplotlib_axes = axes
@@ -240,7 +323,8 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
240323 fig_width , fig_height = self ._figsize
241324 else :
242325 fig_width , fig_height = plt_utils .set_size (
243- width = self ._width , ratio = self ._ratio
326+ width = self ._width ,
327+ ratio = self ._ratio ,
244328 )
245329 # print(self._width, fig_width, fig_height)
246330 # Create subplots
@@ -271,8 +355,8 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
271355 fig .write_image (savefig )
272356
273357 # Show or return the figure
274- if show :
275- fig .show ()
358+ # if show:
359+ # fig.show()
276360 return fig
277361
278362 # Property getters
0 commit comments