Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ yapf==0.40.2
matplotlib==3.9.2
pydantic==2.9.2
scikit-learn==1.5.2
termcolor==3.1.0
tiktoken==0.12.0
diskcache==5.6.3
azure-identity==1.25.1
flaml==2.3.6
gdown==5.2.0

open_clip_torch==2.29.0
Expand Down
92 changes: 92 additions & 0 deletions train_methods/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,98 @@ def __getitem__(self, i):
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
return example


class COGFDDataset(Dataset):
    """Dataset yielding (prompt, label, image) examples for concept combinations.

    The dataset length is the number of prompts in ``concept_combination``.
    Images are read from ``data_dir`` and reused cyclically when there are
    fewer images than prompts.

    Args:
        data_dir: Directory whose files are used as instance images.
        tokenizer: Tokenizer applied to each prompt; its ``model_max_length``
            is used as the padding/truncation length.
        size: Target size passed to the resize and crop transforms.
        center_crop: Use ``CenterCrop`` when true, ``RandomCrop`` otherwise.
        use_pooler: Stored on the instance; not read by this class.
        task_info: Sequence of length 2 (``[concept, theme]``). It is only
            validated here and is not stored.
        concept_combination: Non-empty sequence of prompts.
        labels: Non-empty sequence with one label per prompt.

    Raises:
        ValueError: If any argument fails validation, ``data_dir`` does not
            exist, or ``data_dir`` contains no files.
    """

    def __init__(
        self,
        data_dir: str,
        tokenizer: CLIPTokenizer,
        size: int = 512,
        center_crop: bool = False,
        use_pooler: bool = False,
        task_info=None,
        concept_combination=None,
        labels=None,
    ):
        self.use_pooler = use_pooler
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        if task_info is None or len(task_info) != 2:
            raise ValueError("task_info must be a list/tuple of length 2 containing [concept, theme]")

        if concept_combination is None or len(concept_combination) == 0:
            raise ValueError("concept_combination cannot be None or empty")

        if labels is None or len(labels) == 0:
            raise ValueError("labels cannot be None or empty")

        if len(concept_combination) != len(labels):
            raise ValueError(f"Length mismatch: concept_combination ({len(concept_combination)}) != labels ({len(labels)})")

        self.instance_images_path = []
        # Kept for attribute compatibility; nothing in this class fills it.
        self.instance_prompt = []

        p = Path(data_dir)
        if not p.exists():
            raise ValueError(f"Instance {p} images root doesn't exists.")

        # iterdir() yields entries in arbitrary order and includes
        # sub-directories. Sorting makes the image <-> prompt pairing
        # reproducible; keeping only files avoids handing a directory to
        # Image.open later on.
        image_paths = sorted(entry for entry in p.iterdir() if entry.is_file())
        if len(image_paths) == 0:
            raise ValueError(f"No images found in {p}")

        self.instance_images_path += image_paths

        self.prompts = concept_combination
        self.labels = labels

        self.num_instance_images = len(self.instance_images_path)
        self._length = len(self.prompts)

        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        """Return the number of prompts."""
        return self._length

    def __getitem__(self, index) -> dict:
        """Build the example for ``index``.

        Returns:
            A dict with keys ``concept``, ``label``, ``instance_images``,
            ``prompt_ids`` and ``attention_mask``.

        Raises:
            IndexError: If ``index`` is greater than or equal to the dataset
                length.
        """
        if index >= self._length:
            raise IndexError(f"Index {index} out of range for dataset of length {self._length}")

        example = {}
        concept = self.prompts[index % self._length]
        label = self.labels[index % self._length]

        # Images are cycled when there are fewer images than prompts.
        image_path = self.instance_images_path[index % self.num_instance_images]
        # convert() loads the pixel data and returns a new image, so the file
        # handle can be released as soon as the block exits.
        with Image.open(image_path) as raw_image:
            instance_image = raw_image.convert("RGB")

        example["concept"] = concept
        example["label"] = label
        example["instance_images"] = self.image_transforms(instance_image)

        # Tokenize once; both the ids and the mask come from the same encoding.
        encoding = self.tokenizer(
            concept,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        example["prompt_ids"] = encoding.input_ids
        example["attention_mask"] = encoding.attention_mask

        return example

class MCEDataset(Dataset):
def __init__(
self,
Expand Down
105 changes: 105 additions & 0 deletions train_methods/legacy_autogen/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from pathlib import Path
from types import TracebackType
from typing import Any, Protocol, Self

import diskcache

class AbstractCache(Protocol):

def get(self, key: str, default: Any | None = None) -> Any | None:
...

def set(self, key: str, value: Any) -> None:
...

def close(self) -> None:
...

def __enter__(self) -> Self:
...

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
...

class DiskCache(AbstractCache):
    """``AbstractCache`` implementation that delegates to a ``diskcache.Cache``."""

    def __init__(self, seed: str | int):
        # ``seed`` is handed to diskcache.Cache unchanged as its first
        # positional argument.
        self.cache = diskcache.Cache(seed)

    def get(self, key: str, default: Any | None = None) -> Any | None:
        """Return the stored value for ``key``, or ``default``."""
        stored = self.cache.get(key, default)
        return stored

    def set(self, key: str, value: Any) -> None:
        """Write ``value`` under ``key`` in the underlying cache."""
        self.cache.set(key, value)

    def close(self) -> None:
        """Close the underlying cache."""
        self.cache.close()

    def __enter__(self) -> Self:
        """Return this instance for use inside a ``with`` block."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Close the cache on leaving the ``with`` block; exceptions propagate."""
        self.close()

class CacheFactory:
    """Builds cache instances from a seed and a root directory."""

    @staticmethod
    def cache_factory(
        seed: str | int,
        cache_path_root: str = ".cache",
    ) -> AbstractCache:
        """Return a ``DiskCache`` located at ``<cache_path_root>/<seed>``.

        Args:
            seed: Value whose string form names the cache sub-directory.
            cache_path_root: Directory under which the sub-directory lives.
        """
        cache_dir = Path(cache_path_root) / str(seed)
        return DiskCache(cache_dir)

class Cache(AbstractCache):
    """Config-driven cache facade that forwards to a cache built by ``CacheFactory``.

    Args:
        config: Mapping that may contain only the keys listed in
            ``ALLOWED_CONFIG_KEYS``. ``cache_seed`` defaults to ``42`` and is
            always stored as a string.

    Raises:
        ValueError: If ``config`` contains a key that is not allowed.
    """

    ALLOWED_CONFIG_KEYS = [
        "cache_seed",
        "cache_path_root",
    ]

    @staticmethod
    def disk(cache_seed: str | int = 42, cache_path_root: str = ".cache") -> "Cache":
        """Create a ``Cache`` from a seed and a root directory."""
        return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root})

    def __init__(self, config: dict[str, Any]):
        # Validate first so an invalid config is rejected before any state is
        # created.
        for key in config:
            if key not in self.ALLOWED_CONFIG_KEYS:
                raise ValueError(f"Invalid config key: {key}")

        # Work on a copy: the seed normalisation below must not leak into the
        # dictionary owned by the caller.
        self.config = dict(config)
        # Ensure that the seed is always treated as a string before being
        # passed to any cache factory or stored.
        self.config["cache_seed"] = str(self.config.get("cache_seed", 42))

        # NOTE(review): the fallback root here is "" while disk() and
        # CacheFactory.cache_factory default to ".cache"; kept as-is so that
        # existing cache locations do not move.
        self.cache = CacheFactory.cache_factory(
            seed=self.config["cache_seed"],
            cache_path_root=self.config.get("cache_path_root", ""),
        )

    def __enter__(self) -> "Cache":
        """Enter the backing cache and return this ``Cache``."""
        # Return self so the object bound by ``with ... as c`` matches the
        # annotated type and exposes ``config`` as well as get/set/close.
        self.cache.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Forward the exit to the backing cache."""
        return self.cache.__exit__(exc_type, exc_value, traceback)

    def get(self, key: str, default: Any | None = None) -> Any | None:
        """Return the value stored for ``key``, or ``default``."""
        return self.cache.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the backing cache."""
        self.cache.set(key, value)

    def close(self) -> None:
        """Close the backing cache."""
        self.cache.close()
Loading