Skip to content

Commit 94a1426

Browse files
Refactoring of canvas (#16)
* import Canvas directly * Update plot directly from canvas * replace kwargs with explicit arguments * Fix tikz figure * Added trailing commas * Update src/maxplotlib/canvas/canvas.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Formatting --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
1 parent 6f9918f commit 94a1426

File tree

11 files changed

+264
-165
lines changed

11 files changed

+264
-165
lines changed

src/maxplotlib/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from maxplotlib.canvas.canvas import Canvas
2+
3+
__all__ = ["Canvas"]

src/maxplotlib/backends/matplotlib/utils_old.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,12 @@ def setup_plotstyle(
299299

300300
def set_common_xlabel(self, xlabel="common X"):
301301
self.fig.text(
302-
0.5, -0.075, xlabel, va="center", ha="center", fontsize=self.fontsize
302+
0.5,
303+
-0.075,
304+
xlabel,
305+
va="center",
306+
ha="center",
307+
fontsize=self.fontsize,
303308
)
304309
# fig.text(0.04, 0.5, 'common Y', va='center', ha='center', rotation='vertical', fontsize=rcParams['axes.labelsize'])
305310

@@ -438,7 +443,9 @@ def scale_axis(
438443
i0 = int(xmin / delta)
439444
i1 = int(xmax / delta + 1)
440445
locs = np.arange(
441-
includepoint - width, includepoint + width + delta, delta
446+
includepoint - width,
447+
includepoint + width + delta,
448+
delta,
442449
)
443450
locs = locs[locs >= xmin - 1e-12]
444451
locs = locs[locs <= xmax + 1e-12]
@@ -473,7 +480,9 @@ def scale_axis(
473480
i0 = int(ymin / delta)
474481
i1 = int(ymax / delta + 1)
475482
locs = np.arange(
476-
includepoint - width, includepoint + width + delta, delta
483+
includepoint - width,
484+
includepoint + width + delta,
485+
delta,
477486
)
478487
locs = locs[locs >= ymin - 1e-12]
479488
locs = locs[locs <= ymax + 1e-12]
@@ -507,7 +516,10 @@ def adjustFigAspect(self, aspect=1):
507516
else:
508517
ylim /= aspect
509518
self.fig.subplots_adjust(
510-
left=0.5 - xlim, right=0.5 + xlim, bottom=0.5 - ylim, top=0.5 + ylim
519+
left=0.5 - xlim,
520+
right=0.5 + xlim,
521+
bottom=0.5 - ylim,
522+
top=0.5 + ylim,
511523
)
512524

513525
def add_figure_label(
@@ -619,14 +631,16 @@ def savefig(
619631
# self.fig.savefig(self.directory + filename + '.' + format,bbox_inches='tight', transparent=False)
620632
if tight_layout:
621633
self.fig.savefig(
622-
self.directory + filename + "." + format, bbox_inches="tight"
634+
self.directory + filename + "." + format,
635+
bbox_inches="tight",
623636
)
624637
else:
625638
self.fig.savefig(self.directory + filename + "." + format)
626639
elif format == "pgf":
627640
# Save pgf figure
628641
self.fig.savefig(
629-
self.directory + filename + "." + format, bbox_inches="tight"
642+
self.directory + filename + "." + format,
643+
bbox_inches="tight",
630644
)
631645

632646
# Replace pgf figure colors with colorlet
@@ -672,15 +686,16 @@ def savefig(
672686
else:
673687
try:
674688
plt.savefig(
675-
self.directory + filename + "." + format, bbox_inches="tight"
689+
self.directory + filename + "." + format,
690+
bbox_inches="tight",
676691
)
677692
except Exception as e:
678693
print(
679694
"ERROR: Could not save figure: "
680695
+ self.directory
681696
+ filename
682697
+ "."
683-
+ format
698+
+ format,
684699
)
685700
print(e)
686701

@@ -690,7 +705,7 @@ def savefig(
690705
for format in formats:
691706
if format in imgcat_formats:
692707
f.write(
693-
"imgcat " + self.directory + filename + "." + format + "\n"
708+
"imgcat " + self.directory + filename + "." + format + "\n",
694709
)
695710

696711
if print_imgcat and ("png" in formats or "pdf" in formats):

src/maxplotlib/canvas/canvas.py

Lines changed: 121 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from plotly.subplots import make_subplots
66

77
import maxplotlib.backends.matplotlib.utils as plt_utils
8-
import maxplotlib.subfigure.line_plot as lp
9-
import maxplotlib.subfigure.tikz_figure as tf
8+
from maxplotlib.subfigure.line_plot import LinePlot
9+
from maxplotlib.subfigure.tikz_figure import TikzFigure
1010

1111

1212
class Canvas:
@@ -37,11 +37,15 @@ def __init__(self, **kwargs):
3737

3838
# Dictionary to store lines for each subplot
3939
# Key: (row, col), Value: list of lines with their data and kwargs
40-
self.subplots = {}
40+
self._subplots = {}
4141
self._num_subplots = 0
4242

4343
self._subplot_matrix = [[None] * self.ncols for _ in range(self.nrows)]
4444

45+
@property
46+
def subplots(self):
47+
return self._subplots
48+
4549
@property
4650
def layers(self):
4751
layers = []
@@ -66,34 +70,92 @@ def generate_new_rowcol(self, row, col):
6670
assert col is not None, "Not enough columns!"
6771
return row, col
6872

69-
def add_tikzfigure(self, **kwargs):
73+
def add_line(
74+
self,
75+
x_data,
76+
y_data,
77+
layer=0,
78+
subplot: LinePlot | None = None,
79+
row: int | None = None,
80+
col: int | None = None,
81+
plot_type="plot",
82+
**kwargs,
83+
):
84+
if row is not None and col is not None:
85+
try:
86+
subplot = self._subplot_matrix[row][col]
87+
except KeyError:
88+
raise ValueError("Invalid subplot position.")
89+
else:
90+
row, col = 0, 0
91+
subplot = self._subplot_matrix[row][col]
92+
93+
if subplot is None:
94+
row, col = self.generate_new_rowcol(row, col)
95+
subplot = self.add_subplot(col=col, row=row)
96+
97+
subplot.add_line(
98+
x_data=x_data,
99+
y_data=y_data,
100+
layer=layer,
101+
plot_type=plot_type,
102+
**kwargs,
103+
)
104+
105+
def add_tikzfigure(
106+
self,
107+
col=None,
108+
row=None,
109+
label=None,
110+
**kwargs,
111+
):
70112
"""
71113
Adds a subplot to the figure.
72114
73115
Parameters:
74116
**kwargs: Arbitrary keyword arguments.
75-
- col (int): Column index for the subplot.
76-
- row (int): Row index for the subplot.
77-
- label (str): Label to identify the subplot.
78117
"""
79-
col = kwargs.get("col", None)
80-
row = kwargs.get("row", None)
81-
label = kwargs.get("label", None)
82118

83119
row, col = self.generate_new_rowcol(row, col)
84120

85121
# Initialize the LinePlot for the given subplot position
86-
tikz_figure = tf.TikzFigure(**kwargs)
122+
tikz_figure = TikzFigure(
123+
col=col,
124+
row=row,
125+
label=label,
126+
**kwargs,
127+
)
87128
self._subplot_matrix[row][col] = tikz_figure
88129

89130
# Store the LinePlot instance by its position for easy access
90131
if label is None:
91-
self.subplots[(row, col)] = tikz_figure
132+
self._subplots[(row, col)] = tikz_figure
92133
else:
93-
self.subplots[label] = tikz_figure
134+
self._subplots[label] = tikz_figure
94135
return tikz_figure
95136

96-
def add_subplot(self, **kwargs):
137+
def add_subplot(
138+
self,
139+
col: int | None = None,
140+
row: int | None = None,
141+
figsize: tuple = (10, 6),
142+
title: str | None = None,
143+
caption: str | None = None,
144+
description: str | None = None,
145+
label: str | None = None,
146+
grid: bool = False,
147+
legend: bool = False,
148+
xmin: float | int | None = None,
149+
xmax: float | int | None = None,
150+
ymin: float | int | None = None,
151+
ymax: float | int | None = None,
152+
xlabel: str | None = None,
153+
ylabel: str | None = None,
154+
xscale: float | int = 1.0,
155+
yscale: float | int = 1.0,
156+
xshift: float | int = 0.0,
157+
yshift: float | int = 0.0,
158+
):
97159
"""
98160
Adds a subplot to the figure.
99161
@@ -103,21 +165,32 @@ def add_subplot(self, **kwargs):
103165
- row (int): Row index for the subplot.
104166
- label (str): Label to identify the subplot.
105167
"""
106-
col = kwargs.get("col", None)
107-
row = kwargs.get("row", None)
108-
label = kwargs.get("label", None)
109168

110169
row, col = self.generate_new_rowcol(row, col)
111170

112171
# Initialize the LinePlot for the given subplot position
113-
line_plot = lp.LinePlot(**kwargs)
172+
line_plot = LinePlot(
173+
title=title,
174+
grid=grid,
175+
legend=legend,
176+
xmin=xmin,
177+
xmax=xmax,
178+
ymin=ymin,
179+
ymax=ymax,
180+
xlabel=xlabel,
181+
ylabel=ylabel,
182+
xscale=xscale,
183+
yscale=yscale,
184+
xshift=xshift,
185+
yshift=yshift,
186+
)
114187
self._subplot_matrix[row][col] = line_plot
115188

116189
# Store the LinePlot instance by its position for easy access
117190
if label is None:
118-
self.subplots[(row, col)] = line_plot
191+
self._subplots[(row, col)] = line_plot
119192
else:
120-
self.subplots[label] = line_plot
193+
self._subplots[label] = line_plot
121194
return line_plot
122195

123196
def savefig(
@@ -136,7 +209,10 @@ def savefig(
136209
for layer in self.layers:
137210
layers.append(layer)
138211
fig, axs = self.plot(
139-
show=False, backend="matplotlib", savefig=True, layers=layers
212+
show=False,
213+
backend="matplotlib",
214+
savefig=True,
215+
layers=layers,
140216
)
141217
_fn = f"{filename_no_extension}_{layers}.{extension}"
142218
fig.savefig(_fn)
@@ -153,25 +229,38 @@ def savefig(
153229
else:
154230

155231
fig, axs = self.plot(
156-
show=False, backend="matplotlib", savefig=True, layers=layers
232+
show=False,
233+
backend="matplotlib",
234+
savefig=True,
235+
layers=layers,
157236
)
158237
fig.savefig(full_filepath)
159238
if verbose:
160239
print(f"Saved {full_filepath}")
161240

162-
def plot(self, backend="matplotlib", show=True, savefig=False, layers=None):
241+
def plot(self, backend="matplotlib", savefig=False, layers=None):
163242
if backend == "matplotlib":
164-
return self.plot_matplotlib(show=show, savefig=savefig, layers=layers)
243+
return self.plot_matplotlib(savefig=savefig, layers=layers)
165244
elif backend == "plotly":
166-
self.plot_plotly(show=show, savefig=savefig)
245+
return self.plot_plotly(savefig=savefig)
246+
else:
247+
raise ValueError(f"Invalid backend: {backend}")
167248

168-
def plot_matplotlib(self, show=True, savefig=False, layers=None, usetex=False):
249+
def show(self, backend="matplotlib"):
250+
if backend == "matplotlib":
251+
self.plot(backend="matplotlib", savefig=False, layers=None)
252+
self._matplotlib_fig.show()
253+
elif backend == "plotly":
254+
plot = self.plot_plotly(savefig=False)
255+
else:
256+
raise ValueError("Invalid backend")
257+
258+
def plot_matplotlib(self, savefig=False, layers=None, usetex=False):
169259
"""
170260
Generate and optionally display the subplots.
171261
172262
Parameters:
173263
filename (str, optional): Filename to save the figure.
174-
show (bool): Whether to display the plot.
175264
"""
176265

177266
tex_fonts = plt_utils.setup_tex_fonts(fontsize=self.fontsize, usetex=usetex)
@@ -205,17 +294,11 @@ def plot_matplotlib(self, show=True, savefig=False, layers=None, usetex=False):
205294

206295
for (row, col), subplot in self.subplots.items():
207296
ax = axes[row][col]
208-
# print(f"{subplot = }")
209297
subplot.plot_matplotlib(ax, layers=layers)
210298
# ax.set_title(f"Subplot ({row}, {col})")
211299
ax.grid()
212-
# Set caption, labels, etc., if needed
213-
# plt.tight_layout()
214300

215-
if show:
216-
plt.show()
217-
# else:
218-
# plt.close()
301+
# Set caption, labels, etc., if needed
219302
self._plotted = True
220303
self._matplotlib_fig = fig
221304
self._matplotlib_axes = axes
@@ -240,7 +323,8 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
240323
fig_width, fig_height = self._figsize
241324
else:
242325
fig_width, fig_height = plt_utils.set_size(
243-
width=self._width, ratio=self._ratio
326+
width=self._width,
327+
ratio=self._ratio,
244328
)
245329
# print(self._width, fig_width, fig_height)
246330
# Create subplots
@@ -271,8 +355,8 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
271355
fig.write_image(savefig)
272356

273357
# Show or return the figure
274-
if show:
275-
fig.show()
358+
# if show:
359+
# fig.show()
276360
return fig
277361

278362
# Property getters

0 commit comments

Comments
 (0)