Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion scripts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ To run the script, use the following command:
./batch-caption.py --glob "path/to/images/*.jpg" --prompt "Write a descriptive caption for this image in a formal tone."
```

This command will caption all the `.jpg` images in the specified directory using the provided prompt, writing `.txt` files alongside each image.
This command will caption all the `.jpg` images matching the glob pattern, which may be given as either a relative or an absolute path, using the provided prompt. A `.txt` file containing the caption is written alongside each image.


## Command-Line Arguments

Expand Down
18 changes: 11 additions & 7 deletions scripts/batch-caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
"""
import argparse
import dataclasses
import glob as glob_module
import json
import logging
import os
import random
from pathlib import Path
from pathlib import Path, PurePath

import PIL.Image
import torch
Expand Down Expand Up @@ -291,25 +292,28 @@ def parse_prompts(prompt_str: str | None, prompt_file: str | None) -> list[Promp
prompts.append(Prompt(prompt=item["prompt"], weight=item["weight"]))
else:
raise ValueError(f"Invalid prompt in JSON file. Should be either a string or an object with 'prompt' and 'weight' fields: {item}")

if len(prompts) == 0:
raise ValueError("No prompts found in JSON file")

if sum(p.weight for p in prompts) <= 0.0:
raise ValueError("Prompt weights must sum to a positive number")

return prompts


def find_images(glob: str | None, filelist: str | Path | None) -> list[Path]:
if glob is None and filelist is None:
raise ValueError("Must specify either --glob or --filelist")

paths = []

if glob is not None:
paths.extend(Path(".").glob(glob))

if PurePath(glob.split('*')[0].split('?')[0]).is_absolute():
paths.extend(Path(p) for p in glob_module.glob(glob))
else:
paths.extend(Path(".").glob(glob))

if filelist is not None:
paths.extend((Path(line.strip()) for line in Path(filelist).read_text().strip().splitlines() if line.strip() != ""))

Expand Down