diff --git a/train.py b/train.py
index d5d9292..1ee4ca6 100644
--- a/train.py
+++ b/train.py
@@ -263,6 +263,18 @@ def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
     for model in model_list:
         if model is not None:
             model.to(accelerator.device, dtype=weight_dtype)
+def get_cached_latent_dir(c_dir):
+    from omegaconf import ListConfig
+
+    if isinstance(c_dir, str):
+        return os.path.abspath(c_dir) if c_dir is not None else None
+
+    if isinstance(c_dir, ListConfig):
+        c_dir = OmegaConf.to_object(c_dir)
+        return c_dir
+
+    return None
+
 def handle_cache_latents(
     should_cache,
     output_dir,
@@ -270,33 +282,42 @@ def handle_cache_latents(
     train_batch_size,
     vae,
     cached_latent_dir=None,
-    shuffle=False
+    shuffle=False,
+    minimum_required_frames=0
 ):
 
     # Cache latents by storing them in VRAM.
     # Speeds up training and saves memory by not encoding during the train loop.
     if not should_cache:
         return None
-    vae.to('cuda', dtype=torch.float16)
+    vae.to('cuda', dtype=torch.float32)
     vae.enable_slicing()
 
-    cached_latent_dir = (
-        os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
-    )
+    cached_latent_dir = get_cached_latent_dir(cached_latent_dir)
 
     if cached_latent_dir is None:
         cache_save_dir = f"{output_dir}/cached_latents"
         os.makedirs(cache_save_dir, exist_ok=True)
 
         for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):
+            if batch['pixel_values'].shape[1] > 2 and batch['pixel_values'].shape[1] < minimum_required_frames:
+                print(f"""
+                Batch item at index {i} does not meet required minimum frames: {minimum_required_frames}.
+                Seeing this error means that some of your video lengths are too short, but training will continue.
+                """
+                )
+                continue
+
             save_name = f"cached_{i}"
             full_out_path = f"{cache_save_dir}/{save_name}.pt"
 
-            pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
-            batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae)
-            for k, v in batch.items(): batch[k] = v[0]
+            pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float32)
+            batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae)
+
+            for k, v in batch.items():
+                batch[k] = v[0]
+
             torch.save(batch, full_out_path)
+
             del pixel_values
             del batch
@@ -305,13 +326,31 @@ def handle_cache_latents(
     else:
         cache_save_dir = cached_latent_dir
 
+    # Convert a string to a list of strings for processing if we have more than one.
+    cache_save_dir = (
+        [cache_save_dir] if not isinstance(cache_save_dir, list) else cache_save_dir
+    )
+
+    cached_dataset_list = []
+
+    for save_dir in cache_save_dir:
+        cached_dataset = CachedDataset(cache_dir=save_dir)
+        cached_dataset_list.append(cached_dataset)
+
+    if len(cached_dataset_list) > 1:
+        print(f"Found {len(cached_dataset_list)} cached datasets. Merging...")
+        new_cached_dataset = torch.utils.data.ConcatDataset(cached_dataset_list)
+    else:
+        new_cached_dataset = cached_dataset_list[0]
+
     return torch.utils.data.DataLoader(
-        CachedDataset(cache_dir=cache_save_dir),
-        batch_size=train_batch_size,
-        shuffle=shuffle,
-        num_workers=0
-    )
+        new_cached_dataset,
+        batch_size=train_batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+    )
 
 def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
     global already_printed_trainables
@@ -650,7 +689,9 @@ def main(
             train_dataloader,
             train_batch_size,
             vae,
-            cached_latent_dir
+            cached_latent_dir,
+            shuffle=shuffle,
+            minimum_required_frames=kwargs.get("minimum_required_frames", 0)
         )
 
     if cached_data_loader is not None:
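
For reference, here is a minimal, hypothetical call site for the updated handle_cache_latents (a sketch, assuming train_dataloader and vae are already constructed elsewhere in train.py; the directory paths and numeric values are illustrative, not taken from the repository). Note that get_cached_latent_dir only recognizes a str or an OmegaConf ListConfig, which is what a YAML list becomes once the config is loaded, so the multi-directory value is built with OmegaConf.create below; a plain Python list would fall through and return None.

    from omegaconf import OmegaConf

    # Hypothetical usage sketch; names follow the diff above, values are made up.
    cached_data_loader = handle_cache_latents(
        should_cache=True,
        output_dir="./outputs",
        train_dataloader=train_dataloader,
        train_batch_size=1,
        vae=vae,
        # A single string still works. A ListConfig of directories is now also
        # accepted: each directory becomes a CachedDataset, and multiple
        # datasets are merged with torch.utils.data.ConcatDataset.
        cached_latent_dir=OmegaConf.create([
            "./outputs/cached_latents_run_a",
            "./outputs/cached_latents_run_b",
        ]),
        shuffle=True,
        # Batches with more than 2 but fewer than this many frames are skipped
        # with a printed notice instead of aborting the caching pass.
        minimum_required_frames=16,
    )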