diff --git a/logger.py b/logger.py index b6e617b..53d1295 100644 --- a/logger.py +++ b/logger.py @@ -1,44 +1,43 @@ import torch - -class Logger(object): +class Logger: def __init__(self, runs, info=None): self.info = info self.results = [[] for _ in range(runs)] def add_result(self, run, result): - assert len(result) == 3 - assert run >= 0 and run < len(self.results) + if len(result) != 3 or not (0 <= run < len(self.results)): + raise ValueError("Invalid result format or run index.") self.results[run].append(result) def print_statistics(self, run=None): + def calculate_statistics(data): + train_max = data[:, 0].max().item() + valid_max = data[:, 1].max().item() + best_index = data[:, 1].argmax() + final_train = data[best_index, 0].item() + final_test = data[best_index, 2].item() + return train_max, valid_max, final_train, final_test + if run is not None: - result = 100 * torch.tensor(self.results[run]) - argmax = result[:, 1].argmax().item() + data = 100 * torch.tensor(self.results[run]) + train_max, valid_max, final_train, final_test = calculate_statistics(data) print(f'Run {run + 1:02d}:') - print(f'Highest Train: {result[:, 0].max():.2f}') - print(f'Highest Valid: {result[:, 1].max():.2f}') - print(f' Final Train: {result[argmax, 0]:.2f}') - print(f' Final Test: {result[argmax, 2]:.2f}') + print(f'Highest Train: {train_max:.2f}') + print(f'Highest Valid: {valid_max:.2f}') + print(f' Final Train: {final_train:.2f}') + print(f' Final Test: {final_test:.2f}') else: - result = 100 * torch.tensor(self.results) - - best_results = [] - for r in result: - train1 = r[:, 0].max().item() - valid = r[:, 1].max().item() - train2 = r[r[:, 1].argmax(), 0].item() - test = r[r[:, 1].argmax(), 2].item() - best_results.append((train1, valid, train2, test)) + all_results = 100 * torch.tensor(self.results) + stats = [calculate_statistics(run_data) for run_data in all_results] + stats_tensor = torch.tensor(stats) - best_result = torch.tensor(best_results) + def print_mean_std(idx, label): + values = stats_tensor[:, idx] + print(f'{label}: {values.mean():.2f} ± {values.std():.2f}') print(f'All runs:') - r = best_result[:, 0] - print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') - r = best_result[:, 1] - print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') - r = best_result[:, 2] - print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') - r = best_result[:, 3] - print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}') + print_mean_std(0, 'Highest Train') + print_mean_std(1, 'Highest Valid') + print_mean_std(2, 'Final Train') + print_mean_std(3, 'Final Test')