From 8ba8bb585cfb68430231370729d21991f4bd1a03 Mon Sep 17 00:00:00 2001 From: Nils Petersohn Date: Wed, 31 Dec 2025 01:22:00 +0100 Subject: [PATCH] allow absolute paths for the glob arg of the batch-caption script --- scripts/README.md | 3 ++- scripts/batch-caption.py | 18 +++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/scripts/README.md b/scripts/README.md index 7f84fdd..26d3e7d 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -10,7 +10,8 @@ To run the script, use the following command: ./batch-caption.py --glob "path/to/images/*.jpg" --prompt "Write a descriptive caption for this image in a formal tone." ``` -This command will caption all the `.jpg` images in the specified directory using the provided prompt, writing `.txt` files alongside each image. +This command will caption all the `.jpg` images in the relative or absolute directory using the provided prompt, writing `.txt` files alongside each image. + ## Command-Line Arguments diff --git a/scripts/batch-caption.py b/scripts/batch-caption.py index ca5e58e..06ed9a8 100755 --- a/scripts/batch-caption.py +++ b/scripts/batch-caption.py @@ -4,11 +4,12 @@ """ import argparse import dataclasses +import glob as glob_module import json import logging import os import random -from pathlib import Path +from pathlib import Path, PurePath import PIL.Image import torch @@ -291,25 +292,28 @@ def parse_prompts(prompt_str: str | None, prompt_file: str | None) -> list[Promp prompts.append(Prompt(prompt=item["prompt"], weight=item["weight"])) else: raise ValueError(f"Invalid prompt in JSON file. Should be either a string or an object with 'prompt' and 'weight' fields: {item}") - + if len(prompts) == 0: raise ValueError("No prompts found in JSON file") - + if sum(p.weight for p in prompts) <= 0.0: raise ValueError("Prompt weights must sum to a positive number") - + return prompts def find_images(glob: str | None, filelist: str | Path | None) -> list[Path]: if glob is None and filelist is None: raise ValueError("Must specify either --glob or --filelist") - + paths = [] if glob is not None: - paths.extend(Path(".").glob(glob)) - + if PurePath(glob.split('*')[0].split('?')[0]).is_absolute(): + paths.extend(Path(p) for p in glob_module.glob(glob)) + else: + paths.extend(Path(".").glob(glob)) + if filelist is not None: paths.extend((Path(line.strip()) for line in Path(filelist).read_text().strip().splitlines() if line.strip() != ""))