diff --git a/src/blase/emulator.py b/src/blase/emulator.py index f7641d54..46a5eda0 100644 --- a/src/blase/emulator.py +++ b/src/blase/emulator.py @@ -9,6 +9,7 @@ from torch import nn import numpy as np from scipy.signal import find_peaks, peak_prominences, peak_widths +from scipy.special import voigt_profile import torch.optim as optim from tqdm import trange from torch.special import erfc @@ -327,6 +328,47 @@ def pseudo_voigt_profiles(self, wavelengths): def optimize(self): """Optimize the model parameters""" raise NotImplementedError + + + def animate(self, old_state_dict, size=1.5): + """Animate the model from a previous state to the current state""" + try: + import manim + except ImportError: + print("Manim is required for the .animate() feature, but it is not installed. Please install manim and try again.") + return None + + ## Manim requires a scene class, so we'll make one here + class ArgMinExample(manim.Scene, size=0.9): + def construct(self, size): + ax = manim.Axes( + x_range=[-10, 10], y_range=[0, 1, 0.25], axis_config={"include_tip": False} + ) + labels = ax.get_axis_labels(x_label=manim.Tex(r"$\lambda$"), + y_label=manim.Tex(r"$f(\lambda)$")) + + graph1 = ax.plot(lambda x: 1 - 2*voigt_profile(x, 1.0, 0), color=manim.MAROON) + graph2 = ax.plot(lambda x: 1 - 2*voigt_profile(x, 1.0, size), color=manim.MAROON) + + # Plot noisy data with manim: + x = np.linspace(-10, 10, 100) + y = 1 - 2*voigt_profile(x, 1.0, size) + y += np.random.normal(0, 0.03, y.shape) + coords = np.vstack((x,y)).T + + dots = manim.VGroup(*[manim.Dot().move_to(ax.c2p(coord[0],coord[1])) for coord in coords]) + self.add(dots) + + + self.add(ax, labels) + self.play(manim.Create(graph1)) + self.play(manim.Transform(graph1, graph2)) + self.play(manim.FadeOut(graph2)) + + scene = ArgMinExample(size=size) + scene.render(preview=False) # That's it! + return scene.renderer.file_writer.movie_file_path + class SparseLinearEmulator(LinearEmulator):