From 84283f9b9a9260fff04f2b133653a2fc1114b8d8 Mon Sep 17 00:00:00 2001 From: Dmitry Batenkov Date: Sat, 14 Jun 2025 14:03:37 -0400 Subject: [PATCH 1/2] minor tweaks to public facing functions --- collab/foraging/toolkit/evaluate.py | 6 ++++++ collab/foraging/toolkit/inference.py | 2 +- collab/foraging/toolkit/visualization.py | 3 ++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/collab/foraging/toolkit/evaluate.py b/collab/foraging/toolkit/evaluate.py index 65596198..f5116bb0 100644 --- a/collab/foraging/toolkit/evaluate.py +++ b/collab/foraging/toolkit/evaluate.py @@ -69,3 +69,9 @@ def evaluate_performance( fig.suptitle("Model evaluation", fontsize=16) plt.show() + else: + return { + "coverage": coverage, + "mae": mae, + "rsquared": rsquared, + } \ No newline at end of file diff --git a/collab/foraging/toolkit/inference.py b/collab/foraging/toolkit/inference.py index da731d99..eece48fa 100644 --- a/collab/foraging/toolkit/inference.py +++ b/collab/foraging/toolkit/inference.py @@ -100,7 +100,7 @@ def run_svi_inference( loss.backward() losses.append(loss.item()) adam.step() - if (step % 200 == 0) or (step == 1) & verbose: + if verbose and ((step % 200 == 0) or (step == 1)): print("[iteration %04d] loss: %.4f" % (step, loss)) if plot: diff --git a/collab/foraging/toolkit/visualization.py b/collab/foraging/toolkit/visualization.py index f6d6e899..8fbd0492 100644 --- a/collab/foraging/toolkit/visualization.py +++ b/collab/foraging/toolkit/visualization.py @@ -201,7 +201,7 @@ def update(frame): return foragers_scat, *predictors_scat_list # Create the animation - print(num_frames) + print(f"Animation generation complete. Generated {num_frames} frames.") ani = animation.FuncAnimation( fig, update, @@ -211,4 +211,5 @@ def update(frame): interval=500, repeat_delay=3500, ) + plt.close(fig) # Close the figure to prevent displaying the static plot return ani From 349d585240a3c08e5f977e560f02b04aeea94a50 Mon Sep 17 00:00:00 2001 From: Dmitry Batenkov Date: Tue, 24 Jun 2025 12:35:50 -0400 Subject: [PATCH 2/2] format lint --- collab/foraging/toolkit/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/collab/foraging/toolkit/evaluate.py b/collab/foraging/toolkit/evaluate.py index f5116bb0..d41da03a 100644 --- a/collab/foraging/toolkit/evaluate.py +++ b/collab/foraging/toolkit/evaluate.py @@ -74,4 +74,4 @@ def evaluate_performance( "coverage": coverage, "mae": mae, "rsquared": rsquared, - } \ No newline at end of file + }