From 705133c8036f102cdb9d62d78f7a424360de3e39 Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Wed, 21 May 2025 18:57:53 -0700 Subject: [PATCH 1/2] refactor: clean imports, add argparse, modularize config, and improve evaluation handling --- .github/workflows/build-and-test.yml | 45 +++ .github/workflows/pull_request.yml | 37 +++ baseline_finetuning.py | 195 ------------- .../example_config.yaml | 0 job_template.slurm => jobs/job_template.slurm | 0 requirements.txt.licence | 1 + .../download_unpack_isic2019.sh | 0 .../submit_from_config.sh | 11 +- src/__init__.py | 0 .../evaluation/evaluate_isic_results.py | 166 +++++------ src/finetune/baseline_finetuning.py | 260 ++++++++++++++++++ .../models/model_comparison.py | 0 .../models/model_comparison_2.py | 0 13 files changed, 428 insertions(+), 287 deletions(-) create mode 100644 .github/workflows/build-and-test.yml create mode 100644 .github/workflows/pull_request.yml delete mode 100644 baseline_finetuning.py rename example_config.yaml => configs/example_config.yaml (100%) rename job_template.slurm => jobs/job_template.slurm (100%) create mode 100644 requirements.txt.licence rename download_unpack_isic2019.sh => scripts/download_unpack_isic2019.sh (100%) rename submit_from_config.sh => scripts/submit_from_config.sh (76%) create mode 100644 src/__init__.py rename evaluate_isic_results.py => src/evaluation/evaluate_isic_results.py (54%) create mode 100644 src/finetune/baseline_finetuning.py rename model_comparison.py => src/models/model_comparison.py (100%) rename model_comparison_2.py => src/models/model_comparison_2.py (100%) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml new file mode 100644 index 0000000..767cad6 --- /dev/null +++ b/.github/workflows/build-and-test.yml @@ -0,0 +1,45 @@ +# +# This source file is part of the ARPA-H CARE LLM project +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +# +# 
SPDX-License-Identifier: MIT +# + +name: Build and Test + +on: + push: + branches: + - main + pull_request: + workflow_dispatch: + workflow_call: + +jobs: + pylint: + name: PyLint + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + - name: Install Infrastructure + run: | + pip install -r requirements.txt + pip install pylint + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') + black_lint: + name: Black Code Formatter Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + - name: Install Black + run: pip install black[jupyter] + - name: Check code formatting with Black + run: black . --exclude '\.ipynb$' diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml new file mode 100644 index 0000000..b98bed7 --- /dev/null +++ b/.github/workflows/pull_request.yml @@ -0,0 +1,37 @@ +# +# This source file is part of the ARPA-H CARE LLM project +# +# SPDX-FileCopyrightText: 2025 Stanford University and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# + +name: Pull Request + +on: + pull_request: + workflow_dispatch: + +jobs: + reuse_action: + name: REUSE Compliance Check + uses: DaneshjouLab/.github/.github/workflows/reuse.yml@main + markdown_link_check: + name: Markdown Link Check + uses: DaneshjouLab/.github/.github/workflows/markdown-link-check.yml@main + yamllint: + name: YAML Lint Check + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + + - name: Install yamllint + run: pip install yamllint + + - name: Run yamllint with custom config + run: yamllint -c .yamllint .github/workflows/*.yml diff --git a/baseline_finetuning.py b/baseline_finetuning.py deleted file mode 100644 index 9255bfc..0000000 --- a/baseline_finetuning.py 
+++ /dev/null @@ -1,195 +0,0 @@ -''' -This script fine-tunes a ViT model on the ISIC 2019 dataset with various resolutions. -It includes data augmentation, model evaluation, and GPU memory monitoring. -''' - -#TODO: pass the file paths as os.imports - -import torch -import torch.nn as nn -from torch.utils.data import Dataset -from PIL import Image -from transformers import ViTForImageClassification, ViTFeatureExtractor, TrainingArguments, Trainer -from datasets import load_dataset -from torchvision import transforms -from sklearn.metrics import accuracy_score, f1_score, roc_auc_score -import numpy as np -import time -import json -import os -import matplotlib.pyplot as plt -import seaborn as sns -from sklearn.metrics import confusion_matrix - -try: - import pynvml - pynvml.nvmlInit() - GPU_AVAILABLE = True -except ImportError: - GPU_AVAILABLE = False - print("pynvml not installed, GPU memory monitoring disabled.") -from thop import profile - -# Compute metrics for evaluation -def compute_metrics(eval_pred, model_name, resolution): - logits, labels = eval_pred - predictions = np.argmax(logits, axis=-1) - acc = accuracy_score(labels, predictions) - f1 = f1_score(labels, predictions, average='weighted') - auc = roc_auc_score(labels, logits, multi_class='ovr') - - # Plot confusion matrix - conf_mat = confusion_matrix(labels, predictions) - fig, ax = plt.subplots(figsize=(10, 10)) - sns.heatmap(conf_mat, annot=True, cmap='Blues') - ax.set_xlabel('Predicted labels') - ax.set_ylabel('True labels') - ax.set_title(f'{model_name}_{resolution}_conf_mat') - plt.savefig(f'{model_name}_{resolution}_conf_mat.png', dpi=300, bbox_inches='tight') - plt.close() - - # Classification breakdown - unique, counts = np.unique(predictions, return_counts=True) - class_breakdown = dict(zip(unique, counts)) - with open(f'{model_name}_{resolution}_class_breakdown.json', 'w') as f: - json.dump(class_breakdown, f) - - return {'accuracy': acc, 'f1': f1, 'auc': auc} - -# Measure GPU memory usage 
-def get_gpu_memory(device_id=0): - if not GPU_AVAILABLE: - return -1 - try: - handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) - mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) - return mem_info.used / 1024**2 # MB - except: - return -1 - -# Main function for fine-tuning -def main(): - # Models and resolutions to compare - models = [ - {'name': 'vit', 'model_id': 'google/vit-base-patch16-224', 'type': 'vit'}, - ] - resolutions = [224, 112, 56] - - # Results storage - results = {model['name']: {} for model in models} - - # Load dataset - dataset = load_dataset("MKZuziak/ISIC_2019_224") - full_dataset = dataset['train'].train_test_split(test_size=0.2, stratify_by_column='label', seed=42) - train_dataset = full_dataset['train'] - val_dataset = full_dataset['test'] - - # Data augmentation - transform = transforms.Compose([ - transforms.RandomHorizontalFlip(), - transforms.RandomRotation(20), - transforms.ColorJitter(brightness=0.2, contrast=0.2), - ]) - - for model_info in models: - model_name = model_info['name'] - model_id = model_info['model_id'] - model_type = model_info['type'] - print(f"\nFine-tuning model: {model_name}") - - for resolution in resolutions: - print(f"Resolution: {resolution}x{resolution}") - - # Load preprocessor - if model_type == 'vit': - preprocessor = ViTFeatureExtractor.from_pretrained(model_id, size=resolution) - else: - preprocessor = None # timm models use manual preprocessing - - # Create datasets - train_ds = ISICDataset(train_dataset, preprocessor, resolution, transform, model_type) - val_ds = ISICDataset(val_dataset, preprocessor, resolution, model_type=model_type) - - # Load model - if model_type == 'vit': - model = ViTForImageClassification.from_pretrained( - model_id, - num_labels=8, - ignore_mismatched_sizes=True - ) - else: - pass # Load other models as needed - - # Estimate FLOPs - input_tensor = torch.randn(1, 3, resolution, resolution) - try: - flops, _ = profile(model, inputs=(input_tensor,)) - flops = flops / 1e9 # 
GFLOPs - except: - flops = -1 # Fallback if FLOPs estimation fails - - # Training arguments - training_args = TrainingArguments( - output_dir=f'./results_{model_name}_{resolution}', - num_train_epochs=3, - per_device_train_batch_size=16, - per_device_eval_batch_size=16, - warmup_steps=500, - weight_decay=0.01, - logging_dir=f'./logs_{model_name}_{resolution}', - logging_steps=10, - eval_strategy='epoch', - save_strategy='epoch', - load_best_model_at_end=True, - metric_for_best_model='accuracy', - ) - - # Initialize Trainer - trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_ds, - eval_dataset=val_ds, - compute_metrics=lambda pred: compute_metrics(pred, model_name, resolution), - ) - - # Measure memory and time - start_time = time.time() - peak_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - - # Fine-tune - trainer.train() - - # Update peak memory - current_memory = get_gpu_memory() if GPU_AVAILABLE else -1 - peak_memory = max(peak_memory, current_memory) - - # Evaluate - eval_start_time = time.time() - eval_results = trainer.evaluate() - eval_time = time.time() - eval_start_time - - # Total training time - train_time = time.time() - start_time - eval_time - - # Save model - model.save_pretrained(f'./finetuned_{model_name}_{resolution}') - if model_type == 'vit': - preprocessor.save_pretrained(f'./finetuned_{model_name}_{resolution}') - - # Store results - results[model_name][resolution] = { - 'peak_memory_mb': peak_memory, - 'flops_giga': flops, - 'train_time_seconds': train_time, - 'eval_time_seconds': eval_time, - 'eval_metrics': eval_results - } - print(f"Results for {model_name} at {resolution}x{resolution}: {results[model_name][resolution]}") - - # Save results to JSON - with open('results_metrics.json', 'w') as f: - json.dump(results, f, indent=4) - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/example_config.yaml b/configs/example_config.yaml similarity index 100% rename from 
example_config.yaml rename to configs/example_config.yaml diff --git a/job_template.slurm b/jobs/job_template.slurm similarity index 100% rename from job_template.slurm rename to jobs/job_template.slurm diff --git a/requirements.txt.licence b/requirements.txt.licence new file mode 100644 index 0000000..176ed16 --- /dev/null +++ b/requirements.txt.licence @@ -0,0 +1 @@ +# MIT \ No newline at end of file diff --git a/download_unpack_isic2019.sh b/scripts/download_unpack_isic2019.sh similarity index 100% rename from download_unpack_isic2019.sh rename to scripts/download_unpack_isic2019.sh diff --git a/submit_from_config.sh b/scripts/submit_from_config.sh similarity index 76% rename from submit_from_config.sh rename to scripts/submit_from_config.sh index 2914198..ba7a062 100644 --- a/submit_from_config.sh +++ b/scripts/submit_from_config.sh @@ -39,7 +39,16 @@ for placeholder in "${!config_map[@]}"; do if [ -z "$value" ]; then echo "Warning: Config key '${config_map[$placeholder]}' is empty or missing" fi - sed -i '' "s|{{${placeholder}}}|${value}|g" "$JOB_FILE" + + # Escape characters that might break sed (like slashes or ampersands) + value_escaped=$(printf '%s\n' "$value" | sed 's/[&/\]/\\&/g') + + # Detect OS and use correct sed syntax + if [[ "$OSTYPE" == "darwin"* ]]; then + sed -i '' "s|{{${placeholder}}}|${value}|g" "$JOB_FILE" + else + sed -i "s|{{${placeholder}}}|${value}|g" "$JOB_FILE" + fi done # Submit the job diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evaluate_isic_results.py b/src/evaluation/evaluate_isic_results.py similarity index 54% rename from evaluate_isic_results.py rename to src/evaluation/evaluate_isic_results.py index fcd8e68..75f16cb 100644 --- a/evaluate_isic_results.py +++ b/src/evaluation/evaluate_isic_results.py @@ -2,9 +2,10 @@ Evaluates results from ISIC 2019 fine-tuning and linear probing experiments. 
Conducts paired t-tests to assess statistical significance across JPEG quality
levels and model variations, requiring multiple runs. For single runs, skips t-tests and
-summarizes performance. Evaluates performance under degradation settings (JPEG
+summarizes performance. Evaluates performance under degradation settings (JPEG
 qualities 90, 50, 20). Generates plots and saves results to JSON.
 """
+# pylint: disable=broad-exception-caught
 
 import json
 import itertools
@@ -26,11 +27,14 @@ def load_results(
     multiple runs or single run per condition.
     """
     results = {}
-    for file, mode in [(finetune_file, "finetune"), (linear_probe_file, "linear_probe")]:
+    for file, mode in [
+        (finetune_file, "finetune"),
+        (linear_probe_file, "linear_probe"),
+    ]:
         if not Path(file).exists():
             print(f"Warning: {file} not found, skipping {mode} results.")
             continue
-        with open(file, "r") as f:
+        with open(file, "r", encoding="utf-8") as f:
             data = json.load(f)
         for model, qualities in data.items():
             for quality, metrics in qualities.items():
@@ -48,113 +52,92 @@ def paired_t_tests(metrics, models, qualities, modes):
-    """Conducts paired t-tests across JPEG qualities, models, and modes if multiple
+    """
+    Conducts paired t-tests across JPEG qualities, models, and modes if multiple
     runs are available.
""" t_test_results = [] - has_multiple_runs = any(len(runs) > 1 for runs in metrics.values()) + metric_names = ["accuracy", "f1", "auc"] + + def run_if_valid(k1, k2, label): + if ( + k1 in metrics and k2 in metrics and + len(metrics[k1]) == len(metrics[k2]) > 1 + ): + for metric in metric_names: + try: + v1 = [r[metric] for r in metrics[k1]] + v2 = [r[metric] for r in metrics[k2]] + stat, p = ttest_rel(v1, v2) + t_test_results.append({ + "comparison": label.format(metric=metric), + "statistic": stat, + "p_value": p, + }) + except Exception as e: + print(f"Error in t-test for {label.format(metric=metric)}: {e}") + has_multiple_runs = any(len(r) > 1 for r in metrics.values()) if not has_multiple_runs: - print( - "Warning: Only single-run metrics available. Skipping t-tests. " - "Multiple runs required for statistical significance testing." - ) + print("Warning: Only single-run metrics available. Skipping t-tests.") return t_test_results - # Within-model, across JPEG quality + # 1. Within-model: JPEG quality comparisons for model, mode in itertools.product(models, modes): for q1, q2 in itertools.combinations(qualities, 2): - key1 = (model, q1, mode) - key2 = (model, q2, mode) - if key1 not in metrics or key2 not in metrics: - continue - if len(metrics[key1]) != len(metrics[key2]): - print(f"Warning: Mismatched run counts for {key1} vs {key2}, skipping.") - continue - for metric in ["accuracy", "f1", "auc"]: - data1 = [run[metric] for run in metrics[key1]] - data2 = [run[metric] for run in metrics[key2]] - try: - stat, p = ttest_rel(data1, data2) - t_test_results.append({ - "comparison": f"{model}_{mode}_{metric}_jpeg{q1}_vs_jpeg{q2}", - "statistic": stat, - "p_value": p, - }) - except ValueError as e: - print(f"Error in t-test for {model}_{mode}_{metric}_jpeg{q1}_vs_jpeg{q2}: {e}") + run_if_valid( + (model, q1, mode), + (model, q2, mode), + f"{model}_{mode}" + "_{{metric}}_jpeg{q1}_vs_jpeg{q2}" + ) - # Across models, same JPEG quality and mode + # 2. 
Across models: same quality & mode
     for m1, m2 in itertools.combinations(models, 2):
         for quality, mode in itertools.product(qualities, modes):
-            key1 = (m1, quality, mode)
-            key2 = (m2, quality, mode)
-            if key1 not in metrics or key2 not in metrics:
-                continue
-            if len(metrics[key1]) != len(metrics[key2]):
-                print(f"Warning: Mismatched run counts for {key1} vs {key2}, skipping.")
-                continue
-            for metric in ["accuracy", "f1", "auc"]:
-                data1 = [run[metric] for run in metrics[key1]]
-                data2 = [run[metric] for run in metrics[key2]]
-                try:
-                    stat, p = ttest_rel(data1, data2)
-                    t_test_results.append({
-                        "comparison": f"{m1}_vs_{m2}_{mode}_{metric}_jpeg{quality}",
-                        "statistic": stat,
-                        "p_value": p,
-                    })
-                except ValueError as e:
-                    print(f"Error in t-test for {m1}_vs_{m2}_{mode}_{metric}_jpeg{quality}: {e}")
+            # NOTE(review): single f-string so {quality} is interpolated here;
+            # the concatenated plain string raised KeyError('quality') in .format().
+            run_if_valid(
+                (m1, quality, mode),
+                (m2, quality, mode),
+                f"{m1}_vs_{m2}_{mode}_{{metric}}_jpeg{quality}",
+            )
 
-    # Fine-tune vs. linear probe, same model and JPEG quality
+    # 3. Finetune vs.
Linear Probe: same model & quality
     for model, quality in itertools.product(models, qualities):
-        key1 = (model, quality, "finetune")
-        key2 = (model, quality, "linear_probe")
-        if key1 not in metrics or key2 not in metrics:
-            continue
-        if len(metrics[key1]) != len(metrics[key2]):
-            print(f"Warning: Mismatched run counts for {key1} vs {key2}, skipping.")
-            continue
-        for metric in ["accuracy", "f1", "auc"]:
-            data1 = [run[metric] for run in metrics[key1]]
-            data2 = [run[metric] for run in metrics[key2]]
-            try:
-                stat, p = ttest_rel(data1, data2)
-                t_test_results.append({
-                    "comparison": f"{model}_finetune_vs_linear_probe_{metric}_jpeg{quality}",
-                    "statistic": stat,
-                    "p_value": p,
-                })
-            except ValueError as e:
-                print(
-                    f"Error in t-test for {model}_finetune_vs_linear_probe_{metric}_jpeg{quality}: {e}"
-                )
+        # NOTE(review): single f-string so {quality} is interpolated here;
+        # the concatenated plain string raised KeyError('quality') in .format().
+        run_if_valid(
+            (model, quality, "finetune"),
+            (model, quality, "linear_probe"),
+            f"{model}_finetune_vs_linear_probe_{{metric}}_jpeg{quality}",
+        )
 
     return t_test_results
 
 
 def summarize_performance(metrics):
-    """Summarizes performance metrics (mean, std) across models, qualities, and modes."""
+    """
+    Summarizes performance metrics (mean, std) across models, qualities, and modes.
+ """ summary = [] for (model, quality, mode), runs in metrics.items(): if not runs: continue + # Compute mean and std for each metric - accuracies = [run["accuracy"] for run in runs] - f1s = [run["f1"] for run in runs] - aucs = [run["auc"] for run in runs] - summary.append({ - "model": model, - "jpeg_quality": quality, - "mode": mode, - "accuracy_mean": np.mean(accuracies), - "accuracy_std": np.std(accuracies, ddof=1) if len(runs) > 1 else 0, - "f1_mean": np.mean(f1s), - "f1_std": np.std(f1s, ddof=1) if len(runs) > 1 else 0, - "auc_mean": np.mean(aucs), - "auc_std": np.std(aucs, ddof=1) if len(runs) > 1 else 0, - }) + accuracies = [run.get("accuracy") for run in runs if "accuracy" in run] + f1s = [run.get("f1") for run in runs if "f1" in run] + aucs = [run.get("auc") for run in runs if "auc" in run] + + summary.append( + { + "model": model, + "jpeg_quality": quality, + "mode": mode, + "accuracy_mean": np.mean(accuracies), + "accuracy_std": np.std(accuracies, ddof=1) if len(runs) > 1 else 0, + "f1_mean": np.mean(f1s), + "f1_std": np.std(f1s, ddof=1) if len(runs) > 1 else 0, + "auc_mean": np.mean(aucs), + "auc_std": np.std(aucs, ddof=1) if len(runs) > 1 else 0, + } + ) return pd.DataFrame(summary) @@ -219,9 +202,10 @@ def main(): return # Conduct paired t-tests (if multiple runs) - models = ["vit", "dinov2", "simclr"] - qualities = [90, 50, 20] - modes = ["finetune", "linear_probe"] + # Dynamic model/quality/mode extraction + models = sorted(set(k[0] for k in metrics)) + qualities = sorted(set(k[1] for k in metrics)) + modes = sorted(set(k[2] for k in metrics)) t_test_results = paired_t_tests(metrics, models, qualities, modes) # Summarize performance @@ -235,7 +219,7 @@ def main(): "t_test_results": t_test_results, "performance_summary": performance_df.to_dict(orient="records"), } - with open("evaluation_results.json", "w") as f: + with open("evaluation_results.json", "w", encoding="utf-8") as f: json.dump(results, f, indent=4) print("Evaluation complete. 
Results saved to 'evaluation_results.json'.") @@ -246,4 +230,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/finetune/baseline_finetuning.py b/src/finetune/baseline_finetuning.py new file mode 100644 index 0000000..6d43dcd --- /dev/null +++ b/src/finetune/baseline_finetuning.py @@ -0,0 +1,260 @@ +""" +This script fine-tunes a ViT model on the ISIC 2019 dataset with various resolutions. +It includes data augmentation, model evaluation, and GPU memory monitoring. + +Example run: +python finetune_isic.py \ + --dataset_name MKZuziak/ISIC_2019_224 \ + --resolutions 224 112 \ + --num_epochs 5 \ + --output_dir ./results_vit_isic +""" + +# TODO: pass the file paths as os.imports + +# Standard libraries +import json +import time + +# Third-party libraries +import numpy as np +import matplotlib.pyplot as plt +from scipy.special import softmax +import seaborn as sns +from sklearn.metrics import ( + accuracy_score, + confusion_matrix, + f1_score, + roc_auc_score, +) +import torch +from torchvision import transforms +from transformers import ( + ViTForImageClassification, + ViTFeatureExtractor, + TrainingArguments, + Trainer, +) +from datasets import load_dataset + +# Profiling +from thop import profile + +# GPU memory monitoring via pynvml +try: + import pynvml + pynvml.nvmlInit() + GPU_AVAILABLE = True +except ImportError: + GPU_AVAILABLE = False + print("pynvml not installed, GPU memory monitoring disabled.") + +import argparse +from pathlib import Path + +def parse_args(): + parser = argparse.ArgumentParser(description="Fine-tune ViT on ISIC 2019 dataset.") + parser.add_argument("--dataset_name", type=str, default="MKZuziak/ISIC_2019_224", + help="HuggingFace dataset name") + parser.add_argument("--output_dir", type=Path, default=Path("./results"), + help="Where to save models and results") + parser.add_argument("--resolutions", type=int, nargs="+", default=[224, 112, 56], + help="List of resolutions to fine-tune 
on") + parser.add_argument("--num_epochs", type=int, default=3, + help="Number of training epochs") + return parser.parse_args() + + +# Compute metrics for evaluation +def compute_metrics(eval_pred, model_name, resolution): + """ + Compute accuracy, F1 score, and AUC for the model predictions. + Also generates a confusion matrix and saves it as an image. + """ + logits, labels = eval_pred + predictions = np.argmax(logits, axis=-1) + acc = accuracy_score(labels, predictions) + f1 = f1_score(labels, predictions, average="weighted") + + probs = softmax(logits, axis=1) + auc = roc_auc_score(labels, probs, multi_class="ovr") + # You're computing AUC from logits, not probabilities. Use soft-max first. + # auc = roc_auc_score(labels, logits, multi_class="ovr") + + # Plot confusion matrix + conf_mat = confusion_matrix(labels, predictions) + _, ax = plt.subplots(figsize=(10, 10)) + sns.heatmap(conf_mat, annot=True, cmap="Blues") + ax.set_xlabel("Predicted labels") + ax.set_ylabel("True labels") + ax.set_title(f"{model_name}_{resolution}_conf_mat") + plt.savefig(f"{model_name}_{resolution}_conf_mat.png", dpi=300, bbox_inches="tight") + plt.close() + + # Classification breakdown + unique, counts = np.unique(predictions, return_counts=True) + class_breakdown = dict(zip(unique, counts)) + with open(f"{model_name}_{resolution}_class_breakdown.json", "w") as f: + json.dump(class_breakdown, f) + + return {"accuracy": acc, "f1": f1, "auc": auc} + + +# Measure GPU memory usage +def get_gpu_memory(device_id=0): + if not GPU_AVAILABLE: + return -1 + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + return mem_info.used / 1024**2 # MB + except: + return -1 + + +# Main function for fine-tuning +def main(): + """ + Main function to fine-tune the model on the ISIC 2019 dataset. 
+ """ + # Models and resolutions to compare + models = [ + {"name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit"}, + ] + resolutions = args.resolutions + + # Results storage + results = {model["name"]: {} for model in models} + + # Load dataset + dataset = load_dataset(args.dataset_name) + if dataset is None: + print(f"Dataset {args.dataset_name} not found.") + return + full_dataset = dataset["train"].train_test_split( + test_size=0.2, stratify_by_column="label", seed=42 + ) + train_dataset = full_dataset["train"] + val_dataset = full_dataset["test"] + + # Data augmentation + transform = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomRotation(20), + transforms.ColorJitter(brightness=0.2, contrast=0.2), + ] + ) + + for model_info in models: + model_name = model_info["name"] + model_id = model_info["model_id"] + model_type = model_info["type"] + print(f"\nFine-tuning model: {model_name}") + + for resolution in resolutions: + print(f"Resolution: {resolution}x{resolution}") + + # Load preprocessor + if model_type == "vit": + preprocessor = ViTFeatureExtractor.from_pretrained( + model_id, size=resolution + ) + else: + preprocessor = None # timm models use manual preprocessing + + # Create datasets + train_ds = ISICDataset( + train_dataset, preprocessor, resolution, transform, model_type + ) + val_ds = ISICDataset( + val_dataset, preprocessor, resolution, model_type=model_type + ) + + # Load model + if model_type == "vit": + model = ViTForImageClassification.from_pretrained( + model_id, num_labels=8, ignore_mismatched_sizes=True + ) + else: + pass # Load other models as needed + + # Estimate FLOPs + input_tensor = torch.randn(1, 3, resolution, resolution) + try: + flops, _ = profile(model, inputs=(input_tensor,)) + flops = flops / 1e9 # GFLOPs + except: + flops = -1 # Fallback if FLOPs estimation fails + + # Training arguments + training_args = TrainingArguments( + output_dir=args.output_dir / 
f"{model_name}_{resolution}", + num_train_epochs=args.num_epochs, + per_device_train_batch_size=16, + per_device_eval_batch_size=16, + warmup_steps=500, + weight_decay=0.01, + logging_dir=f"./logs_{model_name}_{resolution}", + logging_steps=10, + eval_strategy="epoch", + save_strategy="epoch", + load_best_model_at_end=True, + metric_for_best_model="accuracy", + ) + + # Initialize Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_ds, + eval_dataset=val_ds, + compute_metrics=lambda pred: compute_metrics( + pred, model_name, resolution + ), + ) + + # Measure memory and time + start_time = time.time() + peak_memory = get_gpu_memory() if GPU_AVAILABLE else -1 + + # Fine-tune + trainer.train() + + # Update peak memory + current_memory = get_gpu_memory() if GPU_AVAILABLE else -1 + peak_memory = max(peak_memory, current_memory) + + # Evaluate + eval_start_time = time.time() + eval_results = trainer.evaluate() + eval_time = time.time() - eval_start_time + + # Total training time + train_time = time.time() - start_time - eval_time + + # Save model + model.save_pretrained(f"./finetuned_{model_name}_{resolution}") + if model_type == "vit": + preprocessor.save_pretrained(f"./finetuned_{model_name}_{resolution}") + + # Store results + results[model_name][resolution] = { + "peak_memory_mb": peak_memory, + "flops_giga": flops, + "train_time_seconds": train_time, + "eval_time_seconds": eval_time, + "eval_metrics": eval_results, + } + print( + f"Results for {model_name} at {resolution}x{resolution}: {results[model_name][resolution]}" + ) + + # Save results to JSON + args.output_dir.mkdir(parents=True, exist_ok=True) + with open(args.output_dir / "results_metrics.json", "w", encoding="utf-8") as f: + json.dump(results, f, indent=4) + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/model_comparison.py b/src/models/model_comparison.py similarity index 100% rename from model_comparison.py rename to 
src/models/model_comparison.py diff --git a/model_comparison_2.py b/src/models/model_comparison_2.py similarity index 100% rename from model_comparison_2.py rename to src/models/model_comparison_2.py From 23a90d67c86e917a1c04987ebbd3a7443e88d61d Mon Sep 17 00:00:00 2001 From: Vicky Bikia Date: Wed, 21 May 2025 19:24:09 -0700 Subject: [PATCH 2/2] Refactor and add --- README.md | 80 ++++++++ src/models/constants.py | 5 + src/models/model_comparison_2.py | 302 +++++++++++++++++++++++-------- 3 files changed, 307 insertions(+), 80 deletions(-) create mode 100644 README.md create mode 100644 src/models/constants.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..ef72269 --- /dev/null +++ b/README.md @@ -0,0 +1,80 @@ +# Finetuning Pretrained Models for Compressed Dermatology Image Analysis + +This project explores how compressed and degraded dermatology images (from the ISIC 2019 dataset) affect classification performance using pretrained vision models. It compares fine-tuning vs. linear probing across multiple JPEG quality levels. + +## Project Goals + +- Evaluate model robustness to image compression (JPEG 90/50/20) +- Compare pretrained models: ViT, DINOv2, and SimCLR +- Benchmark fine-tuning vs. 
linear probing +- Analyze FLOPs, GPU memory, and classification accuracy + +## Models + +- `ViT`: Vision Transformer from Hugging Face +- `DINOv2`: Self-supervised ViT from Meta +- `SimCLR`: Contrastive ResNet50 trained with linear classifier + +## Metrics Tracked + +- Accuracy, F1 Score, AUC +- FLOPs (GFLOPs) +- GPU memory usage +- Training and evaluation time + +## Project Structure + +``` +CS231N/ +├── configs/ +│ └── example_config.yaml # Configs for job submissions +│ +├── scripts/ # Lightweight utility or shell scripts +│ ├── download_unpack_isic2019.sh # Downloads and unpacks ISIC data +│ └── submit_from_config.sh # SLURM submission helper +│ +├── jobs/ # SLURM-related job definitions +│ └── job_template.slurm +│ +├── src/ # Source code, logically grouped +│ ├── __init__.py +│ ├── finetune/ # Fine-tuning workflows +│ │ └── baseline_finetuning.py +│ ├── evaluation/ # Evaluation + plotting +│ │ └── evaluate_isic_results.py +│ └── models/ # Model-related scripts +│ ├── model_comparison.py # Config file with constant strings +│ ├── model_comparison.py +│ └── model_comparison_2.py + +│ +├── results/ # Auto-generated results +│ ├── plots/ # Accuracy/f1/AUC plots +│ └── logs/ # Training logs or SLURM outputs +│ +├── requirements.txt +├── .gitignore +├── .github +└── README.md +``` + +## Quick Start + +1. Install requirements: + ```bash + pip install -r requirements.txt + ``` + +2. Run training: + ```bash + python train_models.py + ``` + +3. 
Run evaluation:
+   ```bash
+   python evaluate_isic_results.py
+   ```
+
+## 📦 Dataset
+
+- [ISIC 2019 (Hugging Face)](https://huggingface.co/datasets/MKZuziak/ISIC_2019_224)
\ No newline at end of file
diff --git a/src/models/constants.py b/src/models/constants.py
new file mode 100644
index 0000000..f28f8ae
--- /dev/null
+++ b/src/models/constants.py
@@ -0,0 +1,5 @@
+HF_MODELS = ["vit", "dinov2"]
+SSL_MODEL = "simclr"
+ALL_MODEL_TYPES = HF_MODELS + [SSL_MODEL]
+SIMCLR_BACKBONE = "resnet50"
+NUM_CLASSES = 8
diff --git a/src/models/model_comparison_2.py b/src/models/model_comparison_2.py
index 24ace8a..8a012ed 100644
--- a/src/models/model_comparison_2.py
+++ b/src/models/model_comparison_2.py
@@ -1,44 +1,85 @@
+# Environment Setup
 import os
+
+os.environ["WANDB_DISABLED"] = "true"  # Disable Weights & Biases logging
+
+# Standard Libraries
+import io
+import json
+import random
+import time
+
+# Scientific & Visualization Libraries
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from PIL import Image
+
+# PyTorch & Torchvision
 import torch
 import torch.nn as nn
 from torch.utils.data import Dataset
-from PIL import Image
-from transformers import ViTForImageClassification, ViTFeatureExtractor, TrainingArguments, Trainer
-from transformers import AutoModelForImageClassification, AutoImageProcessor
-from transformers import TrainerCallback
-from datasets import load_dataset, ClassLabel
 from torchvision import transforms
-from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
-import numpy as np
-import time
-import json
-import matplotlib.pyplot as plt
-import seaborn as sns
-from sklearn.metrics import confusion_matrix
+
+# Hugging Face Transformers & Datasets
+from transformers import (
+    AutoImageProcessor,
+    AutoModelForImageClassification,
+    Trainer,
+    TrainerCallback,
+    TrainingArguments,
+    ViTFeatureExtractor,
+    ViTForImageClassification,
+)
+from datasets import load_dataset, ClassLabel
+
+
+# Metrics
+from sklearn.metrics
import ( + accuracy_score, + confusion_matrix, + f1_score, + roc_auc_score, +) + +# Model Profiling & Vision Backbones import timm -import random -import io -import os -os.environ["WANDB_DISABLED"] = "true" +from thop import profile + +# Local Application Imports +from constants import HF_MODELS, SSL_MODEL, SIMCLR_BACKBONE, NUM_CLASSES + + +# GPU Memory Monitoring (optional) try: import pynvml + pynvml.nvmlInit() GPU_AVAILABLE = True except ImportError: GPU_AVAILABLE = False print("pynvml not installed, GPU memory monitoring disabled.") -from thop import profile # Cache paths -os.environ["TRANSFORMERS_CACHE"] = os.getenv("TRANSFORMERS_CACHE", "~/.cache/huggingface/transformers") -os.environ["HF_DATASETS_CACHE"] = os.getenv("HF_DATASETS_CACHE", "~/.cache/huggingface/datasets") +os.environ["TRANSFORMERS_CACHE"] = os.getenv( + "TRANSFORMERS_CACHE", "~/.cache/huggingface/transformers" +) +os.environ["HF_DATASETS_CACHE"] = os.getenv( + "HF_DATASETS_CACHE", "~/.cache/huggingface/datasets" +) os.environ["HF_HOME"] = os.getenv("HF_HOME", "~/.cache/huggingface") + def env_path(key, default): + """Get environment variable or default value.""" return os.environ.get(key, default) + class DegradationTransform: + """ + Applies random JPEG compression and Gaussian blur to an image.""" + def __init__(self, p=0.5): self.p = p @@ -63,9 +104,12 @@ def __call__(self, img): img = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)(img) if random.random() < self.p: num_colors = random.randint(16, 64) - img = img.quantize(colors=num_colors, method=Image.Quantize.MAXCOVERAGE).convert('RGB') + img = img.quantize( + colors=num_colors, method=Image.Quantize.MAXCOVERAGE + ).convert("RGB") return img + class JPEGCompressionTransform: def __init__(self, quality): self.quality = quality @@ -78,8 +122,17 @@ def __call__(self, img): buffer.seek(0) return Image.open(buffer) + class ISICDataset(Dataset): - def __init__(self, dataset, preprocessor=None, resolution=224, transform=None, 
model_type="vit", jpeg_quality=None): + def __init__( + self, + dataset, + preprocessor=None, + resolution=224, + transform=None, + model_type="vit", + jpeg_quality=None, + ): self.dataset = dataset self.preprocessor = preprocessor self.resolution = resolution @@ -87,12 +140,15 @@ def __init__(self, dataset, preprocessor=None, resolution=224, transform=None, m self.model_type = model_type self.jpeg_quality = jpeg_quality if model_type == "simclr": - self.preprocessor = transforms.Compose([ - transforms.Resize((resolution, resolution)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - ]) + self.preprocessor = transforms.Compose( + [ + transforms.Resize((resolution, resolution)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) def __len__(self): return len(self.dataset) @@ -111,7 +167,7 @@ def __getitem__(self, idx): if self.jpeg_quality is not None: image = JPEGCompressionTransform(self.jpeg_quality)(image) - if self.model_type in ["vit", "dinov2"]: + if self.model_type in HF_MODELS: encoding = self.preprocessor(images=image, return_tensors="pt") pixel_values = encoding["pixel_values"].squeeze(0) elif self.model_type == "simclr": @@ -122,6 +178,7 @@ def __getitem__(self, idx): label = torch.tensor(label, dtype=torch.long) return {"pixel_values": pixel_values, "labels": label} + def compute_metrics(eval_pred, model_name, jpeg_quality): logits, labels = eval_pred predictions = np.argmax(logits, axis=-1) @@ -130,7 +187,9 @@ def compute_metrics(eval_pred, model_name, jpeg_quality): probs = torch.softmax(torch.tensor(logits), dim=1).numpy() auc = roc_auc_score(labels, probs, multi_class="ovr") - plot_dir = os.path.join(env_path("PLOT_DIR", "."), model_name, f"jpeg_{jpeg_quality}") + plot_dir = os.path.join( + env_path("PLOT_DIR", "."), model_name, f"jpeg_{jpeg_quality}" + ) os.makedirs(plot_dir, exist_ok=True) conf_mat = 
confusion_matrix(labels, predictions) @@ -147,9 +206,9 @@ def compute_metrics(eval_pred, model_name, jpeg_quality): with open(os.path.join(plot_dir, "class_breakdown.json"), "w") as f: json.dump(class_breakdown, f) - return {"accuracy": acc, "f1": f1, "auc": auc} + def get_gpu_memory(device_id=0): if not GPU_AVAILABLE: return -1 @@ -160,6 +219,7 @@ def get_gpu_memory(device_id=0): except: return -1 + class SimCLRForClassification(nn.Module): def __init__(self, backbone, num_classes=8): super().__init__() @@ -172,10 +232,13 @@ def forward(self, pixel_values, labels=None): loss = None if labels is not None: loss = nn.CrossEntropyLoss()(logits, labels) - return {'logits': logits, 'loss': loss} if loss is not None else {'logits': logits} + return ( + {"logits": logits, "loss": loss} if loss is not None else {"logits": logits} + ) + def freeze_backbone(model, model_type): - if model_type in ["vit", "dinov2"]: + if model_type in HF_MODELS: for name, param in model.named_parameters(): if "classifier" not in name: param.requires_grad = False @@ -188,7 +251,6 @@ def freeze_backbone(model, model_type): raise ValueError(f"Unsupported model_type: {model_type}") - class LossLoggerCallback(TrainerCallback): """ Logs each training step's loss and other metrics to a structured JSON Lines file. 
@@ -196,7 +258,9 @@ class LossLoggerCallback(TrainerCallback): def __init__(self, log_dir: str, phase: str, model_name: str, jpeg_quality: int): os.makedirs(log_dir, exist_ok=True) - self.log_file = os.path.join(log_dir, f"{model_name}_jpeg{jpeg_quality}_{phase}_log.jsonl") + self.log_file = os.path.join( + log_dir, f"{model_name}_jpeg{jpeg_quality}_{phase}_log.jsonl" + ) def on_log(self, args, state, control, logs=None, **kwargs): if logs is None: @@ -205,6 +269,7 @@ def on_log(self, args, state, control, logs=None, **kwargs): json.dump({"step": state.global_step, **logs}, f) f.write("\n") + def main(): models = [ {"name": "vit", "model_id": "google/vit-base-patch16-224", "type": "vit"}, @@ -217,23 +282,35 @@ def main(): results = {m["name"]: {} for m in models} results_linear_probe = {m["name"]: {} for m in models} - dataset = load_dataset("MKZuziak/ISIC_2019_224", cache_dir=os.environ["HF_DATASETS_CACHE"]) + dataset = load_dataset( + "MKZuziak/ISIC_2019_224", cache_dir=os.environ["HF_DATASETS_CACHE"] + ) dataset = dataset.cast_column("label", ClassLabel(num_classes=8)) - full_dataset = dataset["train"].train_test_split(test_size=0.2, stratify_by_column="label", seed=42) + full_dataset = dataset["train"].train_test_split( + test_size=0.2, stratify_by_column="label", seed=42 + ) train_dataset, val_dataset = full_dataset["train"], full_dataset["test"] - transform = transforms.Compose([ - transforms.RandomHorizontalFlip(), - transforms.RandomRotation(20), - transforms.ColorJitter(brightness=0.2, contrast=0.2), - DegradationTransform(p=0.5), - ]) + transform = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomRotation(20), + transforms.ColorJitter(brightness=0.2, contrast=0.2), + DegradationTransform(p=0.5), + ] + ) for model_info in models: - name, model_id, typ = model_info["name"], model_info["model_id"], model_info["type"] + name, model_id, typ = ( + model_info["name"], + model_info["model_id"], + model_info["type"], + ) if typ == 
"vit": - preprocessor = ViTFeatureExtractor.from_pretrained(model_id, size=resolution) + preprocessor = ViTFeatureExtractor.from_pretrained( + model_id, size=resolution + ) elif typ == "dinov2": preprocessor = AutoImageProcessor.from_pretrained(model_id, size=resolution) else: @@ -241,11 +318,17 @@ def main(): train_ds = ISICDataset(train_dataset, preprocessor, resolution, transform, typ) if typ == "vit": - model = ViTForImageClassification.from_pretrained(model_id, num_labels=8, ignore_mismatched_sizes=True) + model = ViTForImageClassification.from_pretrained( + model_id, num_labels=8, ignore_mismatched_sizes=True + ) elif typ == "dinov2": - model = AutoModelForImageClassification.from_pretrained(model_id, num_labels=8, ignore_mismatched_sizes=True) + model = AutoModelForImageClassification.from_pretrained( + model_id, num_labels=8, ignore_mismatched_sizes=True + ) elif typ == "simclr": - model = SimCLRForClassification(timm.create_model("resnet50", pretrained=True, num_classes=0), 8) + model = SimCLRForClassification( + timm.create_model("resnet50", pretrained=True, num_classes=0), 8 + ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) @@ -275,15 +358,27 @@ def main(): ) for jpeg_quality in jpeg_qualities: - val_ds = ISICDataset(val_dataset, preprocessor, resolution, model_type=typ, jpeg_quality=jpeg_quality) + val_ds = ISICDataset( + val_dataset, + preprocessor, + resolution, + model_type=typ, + jpeg_quality=jpeg_quality, + ) trainer = Trainer( model=model, args=train_args, train_dataset=train_ds, eval_dataset=val_ds, compute_metrics=lambda pred: compute_metrics(pred, name, jpeg_quality), - callbacks=[LossLoggerCallback(log_dir=os.environ["LOG_DIR"], phase="finetune", model_name=name, jpeg_quality=jpeg_quality)] - + callbacks=[ + LossLoggerCallback( + log_dir=os.environ["LOG_DIR"], + phase="finetune", + model_name=name, + jpeg_quality=jpeg_quality, + ) + ], ) # ---- TRAINING PHASE ---- @@ -299,22 +394,33 @@ def main(): 
eval_start_time = time.time() eval_results = trainer.evaluate() eval_time = time.time() - eval_start_time - train_time = time.time() - start_time - eval_time if jpeg_quality == jpeg_qualities[0] else 0 + train_time = ( + time.time() - start_time - eval_time + if jpeg_quality == jpeg_qualities[0] + else 0 + ) - model_dir = os.path.join(env_path("MODEL_DIR", "."), f"{name}_jpeg{jpeg_quality}") + model_dir = os.path.join( + env_path("MODEL_DIR", "."), f"{name}_jpeg{jpeg_quality}" + ) os.makedirs(model_dir, exist_ok=True) - if typ in ["vit", "dinov2"]: + if typ in HF_MODELS: model.save_pretrained(model_dir) preprocessor.save_pretrained(model_dir) - elif typ == "simclr": - torch.save(model.state_dict(), os.path.join(model_dir, "pytorch_model.bin")) + elif typ == SSL_MODEL: + torch.save( + model.state_dict(), os.path.join(model_dir, "pytorch_model.bin") + ) with open(os.path.join(model_dir, "config.json"), "w") as f: - json.dump({ - "model_type": "simclr", - "backbone": "resnet50", - "num_classes": 8 - }, f) + json.dump( + { + "model_type": SSL_MODEL, + "backbone": "resnet50", + "num_classes": 8, + }, + f, + ) results[name][jpeg_quality] = { "peak_memory_mb": peak_memory, @@ -324,14 +430,20 @@ def main(): "eval_metrics": eval_results, } - print(f"[Finetune] {name} @ JPEG {jpeg_quality}: {results[name][jpeg_quality]}") + print( + f"[Finetune] {name} @ JPEG {jpeg_quality}: {results[name][jpeg_quality]}" + ) # ---- LINEAR PROBE PHASE ---- if typ == "vit": - model = ViTForImageClassification.from_pretrained(model_id, num_labels=8, ignore_mismatched_sizes=True) + model = ViTForImageClassification.from_pretrained( + model_id, num_labels=8, ignore_mismatched_sizes=True + ) elif typ == "dinov2": - model = AutoModelForImageClassification.from_pretrained(model_id, num_labels=8, ignore_mismatched_sizes=True) - elif typ == "simclr": + model = AutoModelForImageClassification.from_pretrained( + model_id, num_labels=8, ignore_mismatched_sizes=True + ) + elif typ == SSL_MODEL: backbone 
= timm.create_model("resnet50", pretrained=True, num_classes=0) model = SimCLRForClassification(backbone, 8) @@ -339,13 +451,18 @@ def main(): freeze_backbone(model, typ) linear_args = TrainingArguments( - output_dir=os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), f"{name}_jpeg{jpeg_quality}_linear_probe"), + output_dir=os.path.join( + env_path("TRAIN_OUTPUT_DIR", "."), + f"{name}_jpeg{jpeg_quality}_linear_probe", + ), num_train_epochs=1, per_device_train_batch_size=16, per_device_eval_batch_size=16, warmup_steps=100, weight_decay=0.01, - logging_dir=os.path.join(env_path("LOG_DIR", "."), f"{name}_jpeg{jpeg_quality}_linear_probe"), + logging_dir=os.path.join( + env_path("LOG_DIR", "."), f"{name}_jpeg{jpeg_quality}_linear_probe" + ), logging_steps=1, eval_strategy="epoch", save_strategy="epoch", @@ -359,8 +476,14 @@ def main(): train_dataset=train_ds, eval_dataset=val_ds, compute_metrics=lambda pred: compute_metrics(pred, name, jpeg_quality), - callbacks=[LossLoggerCallback(log_dir=os.environ["LOG_DIR"], phase="linear_probe", model_name=name, jpeg_quality=jpeg_quality)] - + callbacks=[ + LossLoggerCallback( + log_dir=os.environ["LOG_DIR"], + phase="linear_probe", + model_name=name, + jpeg_quality=jpeg_quality, + ) + ], ) start_time = time.time() @@ -374,20 +497,27 @@ def main(): eval_time = time.time() - eval_start_time train_time = time.time() - start_time - eval_time - model_dir = os.path.join(env_path("MODEL_DIR", "."), f"{name}_jpeg{jpeg_quality}_linear_probe") + model_dir = os.path.join( + env_path("MODEL_DIR", "."), f"{name}_jpeg{jpeg_quality}_linear_probe" + ) os.makedirs(model_dir, exist_ok=True) - if typ in ["vit", "dinov2"]: + if typ in HF_MODELS: model.save_pretrained(model_dir) preprocessor.save_pretrained(model_dir) - elif typ == "simclr": - torch.save(model.state_dict(), os.path.join(model_dir, "pytorch_model.bin")) + elif typ == SSL_MODEL: + torch.save( + model.state_dict(), os.path.join(model_dir, "pytorch_model.bin") + ) with 
open(os.path.join(model_dir, "config.json"), "w") as f: - json.dump({ - "model_type": "simclr", - "backbone": "resnet50", - "num_classes": 8 - }, f) + json.dump( + { + "model_type": SSL_MODEL, + "backbone": "resnet50", + "num_classes": 8, + }, + f, + ) results_linear_probe[name][jpeg_quality] = { "peak_memory_mb": peak_memory, @@ -397,12 +527,24 @@ def main(): "eval_metrics": eval_results, } - print(f"[LinearProbe] {name} @ JPEG {jpeg_quality}: {results_linear_probe[name][jpeg_quality]}") + print( + f"[LinearProbe] {name} @ JPEG {jpeg_quality}: {results_linear_probe[name][jpeg_quality]}" + ) - with open(os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), "results_metrics_finetune.json"), "w") as f: + with open( + os.path.join( + env_path("TRAIN_OUTPUT_DIR", "."), "results_metrics_finetune.json" + ), + "w", + ) as f: json.dump(results, f, indent=4) - with open(os.path.join(env_path("TRAIN_OUTPUT_DIR", "."), "results_metrics_linear_probe.json"), "w") as f: + with open( + os.path.join( + env_path("TRAIN_OUTPUT_DIR", "."), "results_metrics_linear_probe.json" + ), + "w", + ) as f: json.dump(results_linear_probe, f, indent=4)