Better Cached Latents #123

Open
wants to merge 3 commits into base: main
69 changes: 55 additions & 14 deletions train.py
@@ -263,40 +263,61 @@ def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
    for model in model_list:
        if model is not None: model.to(accelerator.device, dtype=weight_dtype)

def get_cached_latent_dir(c_dir):
    # Normalize the cached latent directory setting: a plain string becomes an
    # absolute path, an OmegaConf ListConfig becomes a plain Python list, and
    # anything else falls back to None.
    from omegaconf import ListConfig

    if isinstance(c_dir, str):
        return os.path.abspath(c_dir) if c_dir is not None else None

    if isinstance(c_dir, ListConfig):
        c_dir = OmegaConf.to_object(c_dir)
        return c_dir

    return None
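
For context, a minimal sketch (illustration only, not part of the diff) of how this helper treats the two config shapes it accepts; the key names below are made up for the example:

# --- Illustration only, not part of the PR ---
# get_cached_latent_dir (above) normalizes the cached_latent_dir config value:
# a plain string becomes an absolute path, an OmegaConf ListConfig becomes a
# plain Python list, and anything else falls through to None.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "single_dir": "./cached_latents",           # str -> absolute path
    "multi_dir": ["./cache_a", "./cache_b"],    # ListConfig -> plain list
})

print(get_cached_latent_dir(cfg.single_dir))    # e.g. /abs/path/to/cached_latents
print(get_cached_latent_dir(cfg.multi_dir))     # ['./cache_a', './cache_b']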

def handle_cache_latents(
    should_cache,
    output_dir,
    train_dataloader,
    train_batch_size,
    vae,
    cached_latent_dir=None,
-    shuffle=False
+    shuffle=False,
+    minimum_required_frames=0
):

    # Cache latents by encoding them once and saving them to disk.
    # Speeds up training and saves memory by not encoding during the train loop.
    if not should_cache: return None
-    vae.to('cuda', dtype=torch.float16)
+    vae.to('cuda', dtype=torch.float32)
    vae.enable_slicing()

-    cached_latent_dir = (
-        os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
-    )
+    cached_latent_dir = get_cached_latent_dir(cached_latent_dir)

    if cached_latent_dir is None:
        cache_save_dir = f"{output_dir}/cached_latents"
        os.makedirs(cache_save_dir, exist_ok=True)

        for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):
            if batch['pixel_values'].shape[1] > 2 and batch['pixel_values'].shape[1] < minimum_required_frames:
                print(f"""
                Batch item at index {i} does not meet the required minimum frame count: {minimum_required_frames}.
                Seeing this message means some of your videos are too short, but training will continue without them.
                """
                )
                continue

            save_name = f"cached_{i}"
            full_out_path = f"{cache_save_dir}/{save_name}.pt"

-            pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
-            batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae)
-            for k, v in batch.items(): batch[k] = v[0]
+            pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float32)
+            batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae)
+            for k, v in batch.items():
+                batch[k] = v[0]

            torch.save(batch, full_out_path)

            del pixel_values
            del batch
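
For reference, a minimal sketch (illustration only, not part of the diff) of the save/load round trip behind these cached files; the tensor shapes and the prompt_ids key are assumptions, not the repo's exact schema:

# --- Illustration only, not part of the PR ---
# Each cached item is a dict of tensors written with torch.save; CachedDataset
# later reads these .pt files back. Shapes and the 'prompt_ids' key below are
# made up for this sketch.
import torch

cached_batch = {
    'pixel_values': torch.randn(16, 4, 32, 32),    # pretend VAE latents
    'prompt_ids': torch.randint(0, 1000, (77,)),   # pretend tokenized caption
}
torch.save(cached_batch, 'cached_0.pt')

restored = torch.load('cached_0.pt')
print(restored['pixel_values'].shape)              # torch.Size([16, 4, 32, 32])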

@@ -305,13 +326,31 @@ def handle_cache_latents(
    else:
        cache_save_dir = cached_latent_dir

    # Convert a single directory string into a list of strings so that one or
    # more cached latent directories can be processed the same way.
    cache_save_dir = (
        [cache_save_dir] if not isinstance(cache_save_dir, list)
        else cache_save_dir
    )

    cached_dataset_list = []

    for save_dir in cache_save_dir:
        cached_dataset = CachedDataset(cache_dir=save_dir)
        cached_dataset_list.append(cached_dataset)

    if len(cached_dataset_list) > 1:
        print(f"Found {len(cached_dataset_list)} cached datasets. Merging...")
        new_cached_dataset = torch.utils.data.ConcatDataset(cached_dataset_list)
    else:
        new_cached_dataset = cached_dataset_list[0]

    return torch.utils.data.DataLoader(
-        CachedDataset(cache_dir=cache_save_dir),
-        batch_size=train_batch_size,
-        shuffle=shuffle,
-        num_workers=0
-    )
+        new_cached_dataset,
+        batch_size=train_batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+    )
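
To make the merge-then-load pattern above concrete, here is a self-contained sketch (illustration only); TensorDataset stands in for CachedDataset:

# --- Illustration only, not part of the PR ---
# Several cached-latent datasets are concatenated into one dataset and served
# through a single DataLoader, mirroring the logic in handle_cache_latents.
# TensorDataset is used here as a stand-in for CachedDataset.
import torch
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader

cache_a = TensorDataset(torch.randn(4, 3))    # pretend: 4 cached latent items
cache_b = TensorDataset(torch.randn(6, 3))    # pretend: 6 cached latent items

merged = ConcatDataset([cache_a, cache_b])    # len(merged) == 10
loader = DataLoader(merged, batch_size=2, shuffle=True, num_workers=0)

for (latents,) in loader:
    print(latents.shape)                      # torch.Size([2, 3])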

def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
    global already_printed_trainables
@@ -650,7 +689,9 @@ def main(
        train_dataloader,
        train_batch_size,
        vae,
-        cached_latent_dir
+        cached_latent_dir,
+        shuffle=shuffle,
+        minimum_required_frames=kwargs.get("minimum_required_frames", 0)
    )

    if cached_data_loader is not None:
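
Finally, a hypothetical call site for the extended signature (the directory list, output path, and frame threshold below are made-up values; only the parameter names come from the diff):

# --- Illustration only, not part of the PR ---
# Hypothetical usage of the extended handle_cache_latents signature; the
# directory names and the 16-frame threshold are assumptions for this sketch.
cached_data_loader = handle_cache_latents(
    should_cache=True,
    output_dir="./outputs",
    train_dataloader=train_dataloader,
    train_batch_size=train_batch_size,
    vae=vae,
    cached_latent_dir=["./cached_latents_a", "./cached_latents_b"],  # str or list of dirs
    shuffle=True,
    minimum_required_frames=16,  # clips shorter than 16 frames are skipped during caching
)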