diff --git a/collab/foraging/toolkit/evaluate.py b/collab/foraging/toolkit/evaluate.py index 65596198..d41da03a 100644 --- a/collab/foraging/toolkit/evaluate.py +++ b/collab/foraging/toolkit/evaluate.py @@ -69,3 +69,9 @@ def evaluate_performance( fig.suptitle("Model evaluation", fontsize=16) plt.show() + else: + return { + "coverage": coverage, + "mae": mae, + "rsquared": rsquared, + } diff --git a/collab/foraging/toolkit/inference.py b/collab/foraging/toolkit/inference.py index da731d99..eece48fa 100644 --- a/collab/foraging/toolkit/inference.py +++ b/collab/foraging/toolkit/inference.py @@ -100,7 +100,7 @@ def run_svi_inference( loss.backward() losses.append(loss.item()) adam.step() - if (step % 200 == 0) or (step == 1) & verbose: + if verbose and ((step % 200 == 0) or (step == 1)): print("[iteration %04d] loss: %.4f" % (step, loss)) if plot: diff --git a/collab/foraging/toolkit/visualization.py b/collab/foraging/toolkit/visualization.py index f6d6e899..8fbd0492 100644 --- a/collab/foraging/toolkit/visualization.py +++ b/collab/foraging/toolkit/visualization.py @@ -201,7 +201,7 @@ def update(frame): return foragers_scat, *predictors_scat_list # Create the animation - print(num_frames) + print(f"Animation generation complete. Generated {num_frames} frames.") ani = animation.FuncAnimation( fig, update, @@ -211,4 +211,5 @@ def update(frame): interval=500, repeat_delay=3500, ) + plt.close(fig) # Close the figure to prevent displaying the static plot return ani