The long_time_metrics saved to time_logs inside the Trainer.validation_loop method are overwritten on every batch. This means the reported metrics reflect only the last batch instead of an average over all validation batches.
A fix would be something like this:
# FIX: Accumulate time_logs averages instead of overwriting with |=
if k in long_time_metrics or "spectral_error" in k:
    for key, val in new_time_logs.items():
        if key in time_logs:
            time_logs[key] += val / denom
        else:
            time_logs[key] = val / denom
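
To make the difference concrete, here is a minimal standalone sketch that does not use the actual Trainer code. The metric name "rmse", the batch values, and the helper names overwrite and accumulate are made up for illustration, and it assumes denom is the number of validation batches, which may not match how denom is defined in validation_loop.

```python
# Hypothetical per-batch metrics; the values are chosen only for illustration.
batches = [{"rmse": 1.0}, {"rmse": 2.0}, {"rmse": 3.0}, {"rmse": 6.0}]
denom = len(batches)  # assumption: denom equals the number of batches


def overwrite(batches):
    """Mimic the reported behaviour: each batch replaces the previous values."""
    time_logs = {}
    for new_time_logs in batches:
        time_logs |= new_time_logs  # dict merge in place (Python 3.9+)
    return time_logs


def accumulate(batches, denom):
    """Mimic the proposed fix: add each batch's contribution to a running mean."""
    time_logs = {}
    for new_time_logs in batches:
        for key, val in new_time_logs.items():
            time_logs[key] = time_logs.get(key, 0.0) + val / denom
    return time_logs


overwritten = overwrite(batches)
averaged = accumulate(batches, denom)
print(overwritten)  # {'rmse': 6.0}, the last batch only
print(averaged)     # {'rmse': 3.0}, the mean over all four batches
```

With overwriting, the logged value is whatever the final batch produced. With accumulation, every batch contributes val / denom, so the result is the mean as long as denom really is the batch count.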