feat: add benchmark support to PrunaDataModule and implement PartiPrompts #502
Open
davidberenstein1957 wants to merge 8 commits into main from feat/add-partiprompts-benchmark-to-pruna
+207 −6
Commits
6db8f0b feat: add benchmark support to PrunaDataModule and implement PartiPro…
7c53c95 refactor: simplify benchmark system, extend PartiPrompts with subset …
975adb3 fix: add Numpydoc parameter docs for BenchmarkInfo
6b0f4f7 feat: add benchmark discovery functions and expand benchmark registry
56f2167 fix: use correct metric names from MetricRegistry
2157057 fix: address PR comments - category filter, unused imports, unused fu…
c067b82 fix: remove shuffle from test-only datasets
1a4e4de fix: resolve linting TODO warning
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import dataclass, field
 from functools import partial
 from typing import Any, Callable, Tuple
 
@@ -97,8 +98,165 @@
         {"img_size": 224},
     ),
     "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
-    "PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
+    "PartiPrompts": (setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}),
     "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
     "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}),
     "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}),
 }
+
+
+@dataclass
+class BenchmarkInfo:
+    """
+    Metadata for a benchmark dataset.
+
+    Parameters
+    ----------
+    name : str
+        Internal identifier for the benchmark.
+    display_name : str
+        Human-readable name for display purposes.
+    description : str
+        Description of what the benchmark evaluates.
+    metrics : list[str]
+        List of metric names used for evaluation.
+    task_type : str
+        Type of task the benchmark evaluates (e.g., 'text_to_image').
+    subsets : list[str]
+        Optional list of benchmark subset names.
+    """
+
+    name: str
+    display_name: str
+    description: str
+    metrics: list[str]
+    task_type: str
+    subsets: list[str] = field(default_factory=list)
+
+
+benchmark_info: dict[str, BenchmarkInfo] = {
+    "PartiPrompts": BenchmarkInfo(
+        name="parti_prompts",
+        display_name="Parti Prompts",
+        description=(
+            "Over 1,600 diverse English prompts across 12 categories with 11 challenge aspects "
+            "ranging from basic to complex, enabling comprehensive assessment of model capabilities "
+            "across different domains and difficulty levels."
+        ),
+        metrics=["arniqa", "clip_score", "clipiqa", "sharpness"],
+        task_type="text_to_image",
+        subsets=[
+            "Abstract",
+            "Animals",
+            "Artifacts",
+            "Arts",
+            "Food & Beverage",
+            "Illustrations",
+            "Indoor Scenes",
+            "Outdoor Scenes",
+            "People",
+            "Produce & Plants",
+            "Vehicles",
+            "World Knowledge",
+            "Basic",
+            "Complex",
+            "Fine-grained Detail",
+            "Imagination",
+            "Linguistic Structures",
+            "Perspective",
+            "Properties & Positioning",
+            "Quantity",
+            "Simple Detail",
+            "Style & Format",
+            "Writing & Symbols",
+        ],
+    ),
+    "DrawBench": BenchmarkInfo(
+        name="drawbench",
+        display_name="DrawBench",
+        description="A comprehensive benchmark for evaluating text-to-image generation models.",
+        metrics=["clip_score", "clipiqa", "sharpness"],
+        task_type="text_to_image",
+    ),
+    "GenAIBench": BenchmarkInfo(
+        name="genai_bench",
+        display_name="GenAI Bench",
+        description="A benchmark for evaluating generative AI models.",
+        metrics=["clip_score", "clipiqa", "sharpness"],
+        task_type="text_to_image",
+    ),
+    "VBench": BenchmarkInfo(
+        name="vbench",
+        display_name="VBench",
+        description="A benchmark for evaluating video generation models.",
+        metrics=["clip_score"],
+        task_type="text_to_video",
+    ),
+    "COCO": BenchmarkInfo(
+        name="coco",
+        display_name="COCO",
+        description="Microsoft COCO dataset for image generation evaluation with real image-caption pairs.",
+        metrics=["fid", "clip_score", "clipiqa"],
+        task_type="text_to_image",
+    ),
+    "ImageNet": BenchmarkInfo(
+        name="imagenet",
+        display_name="ImageNet",
+        description="Large-scale image classification benchmark with 1000 classes.",
+        metrics=["accuracy"],
+        task_type="image_classification",
+    ),
+    "WikiText": BenchmarkInfo(
+        name="wikitext",
+        display_name="WikiText",
+        description="Language modeling benchmark based on Wikipedia articles.",
+        metrics=["perplexity"],
+        task_type="text_generation",
+    ),
+}
+
+
+# NOTE: Functions kept for future from_benchmark implementation
+def list_benchmarks(task_type: str | None = None) -> list[str]:
+    """
+    List available benchmark names.
+
+    Parameters
+    ----------
+    task_type : str | None
+        Filter by task type (e.g., 'text_to_image', 'text_to_video').
+        If None, returns all benchmarks.
+
+    Returns
+    -------
+    list[str]
+        List of benchmark names.
+    """
+    if task_type is None:
+        return list(benchmark_info.keys())
+    return [name for name, info in benchmark_info.items() if info.task_type == task_type]
+
+
+def get_benchmark_info(name: str) -> BenchmarkInfo:
+    """
+    Get benchmark metadata by name.
+
+    Parameters
+    ----------
+    name : str
+        The benchmark name.
+
+    Returns
+    -------
+    BenchmarkInfo
+        The benchmark metadata.
+
+    Raises
+    ------
+    KeyError
+        If benchmark name is not found.
+    """
+    if name not in benchmark_info:
+        available = ", ".join(benchmark_info.keys())
+        raise KeyError(f"Benchmark '{name}' not found. Available: {available}")
+    return benchmark_info[name]
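A minimal sketch of how the two new discovery helpers behave, for reviewers who want to try them locally. The flat import below is a placeholder, not the real package path; the behavior itself follows directly from the code in the diff above.

# Sketch: exercising the helpers added in this PR.
# "benchmark_registry" is a placeholder import path for the module shown in the diff.
from benchmark_registry import list_benchmarks, get_benchmark_info

# All registered benchmarks, then only the text-to-image ones.
print(list_benchmarks())
print(list_benchmarks(task_type="text_to_image"))
# -> ['PartiPrompts', 'DrawBench', 'GenAIBench', 'COCO']

# Metadata for PartiPrompts: suggested metrics plus its category/challenge subsets.
info = get_benchmark_info("PartiPrompts")
print(info.display_name, info.metrics)
print(len(info.subsets))  # 23 subsets: 12 categories + 11 challenge aspects

# Unknown names raise KeyError and list what is available.
try:
    get_benchmark_info("DoesNotExist")
except KeyError as err:
    print(err)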
Review comment:
I don't think we should tightly couple benchmarking datasets with metrics. Benchmarks should make their datasets available as PrunaDataModules, and the metrics for the benchmarks should be Pruna Metrics. That way the user keeps the flexibility to pair whichever dataset with whichever metric they choose. How do you feel?
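A rough sketch of the decoupled shape this suggests: the benchmark contributes only a dataset, and the user pairs it with whichever Pruna metrics they want. The import paths, PrunaDataModule.from_string, and the Task/EvaluationAgent usage below are assumptions about Pruna's existing evaluation API, not code from this PR, so treat the names as illustrative.

# Hypothetical decoupled flow: dataset from the benchmark, metrics chosen freely.
# Import paths and signatures are assumed, not taken from this PR.
from pruna.data.pruna_datamodule import PrunaDataModule          # assumed path
from pruna.evaluation.task import Task                           # assumed path
from pruna.evaluation.evaluation_agent import EvaluationAgent    # assumed path

# The benchmark only provides the data...
datamodule = PrunaDataModule.from_string("PartiPrompts")

# ...while the metrics are a user choice; BenchmarkInfo.metrics could act as a
# default suggestion instead of a hard coupling.
task = Task(["clip_score", "clipiqa"], datamodule=datamodule)
agent = EvaluationAgent(task)
# results = agent.evaluate(model)  # model supplied by the user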