diff --git a/configs/my_config.yaml b/configs/my_config.yaml index 29450dc..edc50eb 100644 --- a/configs/my_config.yaml +++ b/configs/my_config.yaml @@ -7,7 +7,7 @@ train_data: preprocessed: True n_sample_frames: 8 width: 256 - height: 256 + height: 128 sample_start_idx: 0 sample_frame_rate: 15 use_random_start_idx: False @@ -16,12 +16,14 @@ train_data: single_video_path: "" single_video_prompt: "" + infinet_source_path: "video" + validation_data: prompt: "" sample_preview: True num_frames: 16 width: 256 - height: 256 + height: 128 num_inference_steps: 50 guidance_scale: 9 @@ -34,6 +36,7 @@ validation_steps: 100 trainable_modules: - "attn1.to_out" - "attn2.to_out" + - "infinet" seed: 64 mixed_precision: "fp16" use_8bit_adam: False # This seems to be incompatible at the moment. diff --git a/configs/video_folder.yaml b/configs/video_folder.yaml index d0529b1..a3b8628 100644 --- a/configs/video_folder.yaml +++ b/configs/video_folder.yaml @@ -30,6 +30,7 @@ trainable_modules: - "attn1" - "attn2" - "attn3" + - "infinet" seed: 1234 mixed_precision: "fp16" use_8bit_adam: False # This seems to be incompatible at the moment. diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..3c1529b --- /dev/null +++ b/eval.py @@ -0,0 +1,221 @@ +import argparse +import datetime +import logging +import inspect +import math +import os +import random +import gc +import subprocess +import tempfile + +from typing import Dict, Optional, Tuple, List + +import numpy as np +from omegaconf import OmegaConf + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms as T +import diffusers +import transformers + +from pkg_resources import resource_filename +from torchvision import transforms +from tqdm.auto import tqdm + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed + +from models.unet_3d_condition import UNet3DConditionModel +from diffusers.models import AutoencoderKL +from diffusers import DPMSolverMultistepScheduler, DDPMScheduler, TextToVideoSDPipeline +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention_processor import AttnProcessor2_0, Attention +from diffusers.models.attention import BasicTransformerBlock + +from transformers import CLIPTextModel, CLIPTokenizer +from utils.dataset import VideoDataset +from einops import rearrange, repeat + +already_printed_unet = False + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +def create_logging(logging, logger, accelerator): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + +def accelerate_set_verbose(accelerator): + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + +def create_output_folders(output_dir, config): + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + out_dir = os.path.join(output_dir, f"train_{now}") + + os.makedirs(out_dir, exist_ok=True) + os.makedirs(f"{out_dir}/samples", exist_ok=True) + OmegaConf.save(config, os.path.join(out_dir, 'config.yaml')) + + return out_dir + + +def load_primary_models(pretrained_model_path): + noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + + unet = UNet3DConditionModel() + + model_path = os.path.join(os.getcwd(), pretrained_model_path, 'unet', 'diffusion_pytorch_model.bin') + # Load the pretrained weights + pretrained_dict = torch.load( + model_path, + map_location=torch.device('cuda'), + ) + unet.load_state_dict(pretrained_dict, strict=False) + + has_pretrained_weights = False + + for k, v in pretrained_dict.items(): + if k.startswith('infinet'): + has_pretrained_weights = True + + if not has_pretrained_weights: + print('Pretrained Infinet not found, setting its weights to zeros') + unet.infinet._init_weights() + + unet.infinet.diffusion_depth = 1 + #unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet") + + return noise_scheduler, tokenizer, text_encoder, vae, unet + + +def main(): + pretrained_model_path = "models/model_scope_diffusers" + # Load scheduler, tokenizer and models. + noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(pretrained_model_path) + + vae.to("cuda") + unet.to("cuda") + text_encoder.to("cuda") + + # Enable VAE slicing to save memory. 
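+    # Sliced decoding runs the VAE over the latent batch one slice at a time instead of
+    # all at once, trading a little speed for a lower peak-memory footprint.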
+ vae.enable_slicing() + + + #unet.eval() + #text_encoder.eval() + + pipeline = TextToVideoSDPipeline.from_pretrained( + pretrained_model_path, + text_encoder=text_encoder, + vae=vae, + unet=unet + ) + + pipeline.enable_xformers_memory_efficient_attention() + + diffusion_scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline.scheduler = diffusion_scheduler + + prompt = "Couple walking on the beach" + os.makedirs("samples", exist_ok=True) + out_file = f"samples/eval_{prompt}.mp4" + + with torch.no_grad(): + video_frames = pipeline( + prompt, + width=512, + height=384, + num_frames=20, + num_inference_steps=50, + guidance_scale=7.5 + ).frames + video_path = export_to_video(video_frames, out_file) + + del pipeline + gc.collect() + +from PIL import Image +import cv2 +def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None, fps: int = 8) -> str: + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + #fps = 8 + h, w, c = video_frames[0].shape + + os.makedirs(os.path.join(os.getcwd(), 'out'), exist_ok=True) + for i in range(len(video_frames)): +# Image.fromarray(video_frames[i]).save(os.path.join(os.getcwd(), 'out', f"frame_{i}.png")) + cv2.imwrite(os.path.join(os.getcwd(), 'out', + f"{i:06}.png"), video_frames[i]) + + # create a pipe for ffmpeg to write the video frames to + ffmpeg_pipe = subprocess.Popen( + [ + "ffmpeg", + "-y", # overwrite output file if it already exists + "-f", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{w}x{h}", + "-r", str(fps), + "-i", "-", + "-vcodec", "libx264", + "-preset", "medium", + "-crf", "23", + output_video_path, + ], + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + # write each video frame to the ffmpeg pipe + for frame in video_frames: + ffmpeg_pipe.stdin.write(frame.tobytes()) + + # close the ffmpeg pipe and wait for it to finish writing the video file + ffmpeg_pipe.stdin.close() + ffmpeg_pipe.wait() + + return output_video_path + +if __name__ == "__main__": + main() + +def find_ffmpeg_binary(): + try: + import google.colab + return 'ffmpeg' + except: + pass + for package in ['imageio_ffmpeg', 'imageio-ffmpeg']: + try: + package_path = resource_filename(package, 'binaries') + files = [os.path.join(package_path, f) for f in os.listdir( + package_path) if f.startswith("ffmpeg-")] + files.sort(key=lambda x: os.path.getmtime(x), reverse=True) + return files[0] if files else 'ffmpeg' + except: + return 'ffmpeg' \ No newline at end of file diff --git a/models/unet_3d_blocks.py b/models/unet_3d_blocks.py old mode 100755 new mode 100644 index fa91387..e2e2c59 --- a/models/unet_3d_blocks.py +++ b/models/unet_3d_blocks.py @@ -19,6 +19,104 @@ from diffusers.models.transformer_2d import Transformer2DModel from diffusers.models.transformer_temporal import TransformerTemporalModel +class DoDBlock(nn.Module): + """ + A downconvolution layer with masked video latents + Gets the masked video latents (the first and the last frame) and makes a masked convolution + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: Always 3D, downsampling occurs in the inner-two dimensions. 
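+    :param depth: index of the UNet level the block is injected into; the spatial
+        convolution stride is 2 ** depth, so the conditioning is downsampled to match
+        that level's resolution.
+    :param is_up: set for decoder-side (up block) injections, where the requested
+        out_channels is used as-is instead of being scaled by 2 ** depth.
+
+    The forward pass turns the frame conditioning x_c into a per-pixel scale and bias
+    through two convolutions, h = conv_w(x_c) * h + conv_b(x_c) + h, and then repeats
+    the same pattern with the single-channel first/last-frame mask x_m. The convolutions
+    are meant to start zero-initialized (see _init_weights), so the block acts as an
+    identity mapping until trained, and it is also a no-op when no conditioning is given.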
+ """ + + def __init__(self, + channels, + dims=3, + depth=0, + out_channels=None, + padding=1, + is_up=False,): + super().__init__() + self.channels = channels + self.out_channels = min((out_channels or channels) * (2 ** depth), 1280) if not is_up else (out_channels or channels) + self.dims = dims + self.is_up = is_up + stride = 2**depth if dims != 3 else (1, 2**depth, 2**depth) # if depth is zero, the stride is 1 + + # Convolution block, which should be initialized with zero weights and biases + # (zero conv) + self.conv_w = nn.Conv2d( + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + + self.conv_b = nn.Conv2d( + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + + # Conv for masking + self.mask_conv_w = nn.Conv2d( + 1, # only black and white + self.out_channels, + 3, + stride=stride, + padding=padding) + + self.mask_conv_b = nn.Conv2d( + 1, # only black and white + self.out_channels, + 3, + stride=stride, + padding=padding) + + # h - hidden states, x_c - frame conditioning, x_m - masked video latents + def forward(self, h, x_c=None, x_m=None): + + # When no frame conditioning is provided (top DoD iteration) + # return the untouched hidden states + if x_c is None or x_m is None: + return h + + # Add image conditioning as linear operation + + #print('IS UP ', self.is_up) + #print('h', h.shape) + #print('xc', x_c.shape) + #print('xm', x_m.shape) + + #print('cw', self.conv_w.weight.shape) + + # get weights and biases from frame conditioning + # vid convolution (initialized with zero weights and biases at first) + x_c_w = self.conv_w(x_c) + x_c_b = self.conv_b(x_c) + + #print('xcw', x_c_w.shape) + #print('xcb', x_c_b.shape) + + h = x_c_w * h + x_c_b + h # uses hadamard product + + # Use masked video latents to mask the convolution + x_m_w = self.mask_conv_w(x_m) + x_m_b = self.mask_conv_b(x_m) + + h = x_m_w * h + x_m_b + h # uses hadamard product + + return h + + def _init_weights(self): + # Zero initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.constant_(m.weight, 0) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + # Assign gradient checkpoint function to simple variable for readability. 
g_c = checkpoint.checkpoint @@ -119,6 +217,7 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + infinet=None, ): if down_block_type == "DownBlock3D": return DownBlock3D( @@ -132,6 +231,7 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, + infinet=infinet, ) elif down_block_type == "CrossAttnDownBlock3D": if cross_attention_dim is None: @@ -153,6 +253,7 @@ def get_down_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + infinet=infinet, ) raise ValueError(f"{down_block_type} does not exist.") @@ -175,6 +276,7 @@ def get_up_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + infinet=None, ): if up_block_type == "UpBlock3D": return UpBlock3D( @@ -188,6 +290,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, + infinet=infinet, ) elif up_block_type == "CrossAttnUpBlock3D": if cross_attention_dim is None: @@ -209,6 +312,7 @@ def get_up_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + infinet=infinet, ) raise ValueError(f"{up_block_type} does not exist.") @@ -373,6 +477,7 @@ def __init__( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + infinet=None, ): super().__init__() resnets = [] @@ -406,6 +511,17 @@ def __init__( out_channels, ) ) + + if infinet is not None: + infinet.input_blocks_injections.append(DoDBlock( + infinet.in_channels, + 2, + len(infinet.input_blocks_injections), + out_channels, + is_up=False, + ) + ) + attentions.append( Transformer2DModel( out_channels // attn_num_head_channels, @@ -453,6 +569,9 @@ def forward( attention_mask=None, num_frames=1, cross_attention_kwargs=None, + dod_block=None, + x_c=None, + x_m=None, ): # TODO(Patrick, William) - attention mask is not used output_states = () @@ -462,6 +581,8 @@ def forward( ): if self.gradient_checkpointing: + # TODO: Infinet is not implemented here yet! 
+ # so don't use it for now with gradient checkpointing on hidden_states = cross_attn_g_c( attn, temp_attn, @@ -477,6 +598,10 @@ def forward( else: hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if dod_block is not None: + hidden_states = dod_block(hidden_states, x_c, x_m) + hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -511,6 +636,7 @@ def __init__( output_scale_factor=1.0, add_downsample=True, downsample_padding=1, + infinet=None, ): super().__init__() resnets = [] @@ -540,6 +666,16 @@ def __init__( ) ) + if infinet is not None: + infinet.input_blocks_injections.append(DoDBlock( + infinet.in_channels, + 2, # dims + len(infinet.input_blocks_injections), + out_channels, + is_up=False, + ) + ) + self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) @@ -554,7 +690,7 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states, temb=None, num_frames=1): + def forward(self, hidden_states, temb=None, num_frames=1, dod_block=None, x_c=None, x_m=None,): output_states = () for resnet, temp_conv in zip(self.resnets, self.temp_convs): @@ -563,6 +699,9 @@ def forward(self, hidden_states, temb=None, num_frames=1): else: hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if dod_block is not None: + hidden_states = dod_block(hidden_states, x_c, x_m) output_states += (hidden_states,) @@ -597,6 +736,7 @@ def __init__( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + infinet=None, ): super().__init__() resnets = [] @@ -632,6 +772,19 @@ def __init__( out_channels, ) ) + + if infinet is not None: + #print(len(infinet.input_blocks_injections)) + #print(len(infinet.output_blocks_injections)) + infinet.output_blocks_injections.append(DoDBlock( + infinet.in_channels, + 2, + max(0, 3 - len(infinet.output_blocks_injections)),#max(0, len(infinet.input_blocks_injections) - len(infinet.output_blocks_injections)), + (out_channels // 2**(len(infinet.output_blocks_injections)-1)) if len(infinet.output_blocks_injections) > 1 else out_channels, + is_up=True, + ) + ) + attentions.append( Transformer2DModel( out_channels // attn_num_head_channels, @@ -675,6 +828,9 @@ def forward( attention_mask=None, num_frames=1, cross_attention_kwargs=None, + dod_block=None, + x_c=None, + x_m=None, ): # TODO(Patrick, William) - attention mask is not used for resnet, temp_conv, attn, temp_attn in zip( @@ -686,6 +842,8 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.gradient_checkpointing: + # TODO: Infinet is not implemented here yet! 
+ # so don't use it with gradient checkpointing hidden_states = cross_attn_g_c( attn, temp_attn, @@ -701,6 +859,10 @@ def forward( else: hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if dod_block is not None: + hidden_states = dod_block(hidden_states, x_c, x_m) + hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -731,10 +893,12 @@ def __init__( resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, + infinet=None, ): super().__init__() resnets = [] temp_convs = [] + self.gradient_checkpointing=False for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels @@ -761,6 +925,18 @@ def __init__( ) ) + if infinet is not None: + #print(len(infinet.input_blocks_injections)) + #print(len(infinet.output_blocks_injections)) + infinet.output_blocks_injections.append(DoDBlock( + infinet.in_channels, + 2, + max(0, 3 - len(infinet.output_blocks_injections)),#max(0, len(infinet.input_blocks_injections) - len(infinet.output_blocks_injections)), + (out_channels // 2**(len(infinet.output_blocks_injections)-1)) if len(infinet.output_blocks_injections) > 1 else out_channels, + is_up=True, + ) + ) + self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) @@ -769,7 +945,7 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1, dod_block=None, x_c=None, x_m=None,): for resnet, temp_conv in zip(self.resnets, self.temp_convs): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -781,6 +957,9 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si else: hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if dod_block is not None: + hidden_states = dod_block(hidden_states, x_c, x_m) if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/models/unet_3d_condition.py b/models/unet_3d_condition.py old mode 100755 new mode 100644 index 8857cad..98d2ce8 --- a/models/unet_3d_condition.py +++ b/models/unet_3d_condition.py @@ -32,11 +32,34 @@ UpBlock3D, get_down_block, get_up_block, + DoDBlock, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Class to keep DiffusionOverDiffusion modules as a separate model +# with weights saveable as a detachable checkpoint +class InfiNet(nn.Module): + def __init__(self, in_channels): + super(InfiNet, self).__init__() + + self.in_channels = in_channels + + self.diffusion_depth = 0 # Placeholder, because it's not passable into the pipeline + + self.input_blocks_injections = nn.ModuleList() + self.output_blocks_injections = nn.ModuleList() + + def _init_weights(self): + # Zero initialization + for m in self.modules(): + if isinstance(m, DoDBlock): + m._init_weights() + if isinstance(m, nn.ModuleList): + for l in m: + if isinstance(l, DoDBlock): + l._init_weights() @dataclass class UNet3DConditionOutput(BaseOutput): @@ -103,11 +126,14 @@ def __init__( norm_eps: float = 1e-5, cross_attention_dim: int = 1024, attention_head_dim: Union[int, Tuple[int]] = 64, + use_infinet=True, ): super().__init__() self.sample_size = sample_size + self.use_infinet = use_infinet + # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( @@ -154,6 
+180,10 @@ def __init__( self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) + if self.use_infinet: + # InfiNet insertion + self.infinet = InfiNet(in_channels) + if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) @@ -178,6 +208,7 @@ def __init__( attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=False, + infinet=self.infinet if self.use_infinet else None, ) self.down_blocks.append(down_block) @@ -230,6 +261,7 @@ def __init__( cross_attention_dim=cross_attention_dim, attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=False, + infinet=self.infinet if self.use_infinet else None, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -248,6 +280,10 @@ def __init__( self.conv_out = nn.Conv2d( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) + + if use_infinet: + + self.infinet._init_weights() def set_attention_slice(self, slice_size): r""" @@ -331,6 +367,7 @@ def forward( down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, + #diffusion_depth: int = 1, #until diffusion_depth is in Diffusers, we'll have to set it up in class properties instead ) -> Union[UNet3DConditionOutput, Tuple]: r""" Args: @@ -397,15 +434,31 @@ def forward( emb = emb.repeat_interleave(repeats=num_frames, dim=0) encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + # If aiming for DiffusionOverDiffusion and have InfiNet enabled, keep the original video + # + its mask of the first and last frames + + if self.use_infinet and self.infinet.diffusion_depth > 0: + x_c = sample.clone().detach() + x_m = torch.zeros(x_c.shape[:1] + (1,) + x_c.shape[2:], dtype=x_c.dtype, device=x_c.device) + x_m[:, :, 0, :, :] = torch.ones_like(x_m[:, :, 0, :, :]) + x_m[:, :, -1, :, :] = torch.ones_like(x_m[:, :, -1, :, :]) + x_c = x_c.permute(0, 2, 1, 3, 4).reshape((x_c.shape[0] * num_frames, -1) + x_c.shape[3:]) + x_m = x_m.permute(0, 2, 1, 3, 4).reshape((x_m.shape[0] * num_frames, -1) + x_m.shape[3:]) + else: + x_c = None + x_m = None + # 2. pre-process sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = self.conv_in(sample) + # InfiNet TODO: do we have to add a module injection here as well? + sample = self.transformer_in(sample, num_frames=num_frames).sample # 3. 
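        #   x_c: (batch * num_frames, in_channels, H, W)  clean conditioning latents
        #   x_m: (batch * num_frames, 1, H, W)            1.0 on every pixel of the first
        #                                                 and last frame, 0.0 elsewhere
        # Both tensors are handed to every down/up block below together with that block's
        # DoDBlock injection, so the frame conditioning is injected at each UNet level.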
down down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: + for i, downsample_block in enumerate(self.down_blocks): if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, @@ -414,9 +467,12 @@ def forward( attention_mask=attention_mask, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, + dod_block=self.infinet.input_blocks_injections[i] if self.infinet is not None else None, + x_c=x_c, + x_m=x_m, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames, dod_block=self.infinet.input_blocks_injections[i] if self.infinet is not None else None, x_c=x_c, x_m=x_m,) down_block_res_samples += res_samples @@ -467,6 +523,9 @@ def forward( attention_mask=attention_mask, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, + dod_block=self.infinet.output_blocks_injections[i] if self.infinet is not None else None, + x_c=x_c, + x_m=x_m, ) else: sample = upsample_block( @@ -475,6 +534,9 @@ def forward( res_hidden_states_tuple=res_samples, upsample_size=upsample_size, num_frames=num_frames, + dod_block=self.infinet.output_blocks_injections[i] if self.infinet is not None else None, + x_c=x_c, + x_m=x_m, ) # 6. post-process diff --git a/requirements.txt b/requirements.txt index 4ffbc0d..30f44f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ accelerate torch torchvision torchaudio +imageio_ffmpeg git+https://github.com/huggingface/diffusers.git transformers einops diff --git a/train.py b/train.py index 4b3796e..f070239 100644 --- a/train.py +++ b/train.py @@ -84,7 +84,27 @@ def load_primary_models(pretrained_model_path): tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") - unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet") + + unet = UNet3DConditionModel() + + model_path = os.path.join(os.getcwd(), pretrained_model_path, 'unet', 'diffusion_pytorch_model.bin') + # Load the pretrained weights + pretrained_dict = torch.load( + model_path, + map_location=torch.device('cuda'), + ) + unet.load_state_dict(pretrained_dict, strict=False) + + has_pretrained_weights = False + + for k, v in pretrained_dict.items(): + if k.startswith('infinet'): + has_pretrained_weights = True + + if not has_pretrained_weights: + print('Pretrained Infinet not found, setting its weights to zeros') + print("It's expected when training for the first time") + unet.infinet._init_weights() return noise_scheduler, tokenizer, text_encoder, vae, unet @@ -222,7 +242,7 @@ def main( train_data: Dict, validation_data: Dict, validation_steps: int = 100, - trainable_modules: Tuple[str] = ("attn1", "attn2" ), + trainable_modules: Tuple[str] = ("attn1", "attn2", "infinet"), train_batch_size: int = 1, max_train_steps: int = 500, learning_rate: float = 5e-5, @@ -314,11 +334,13 @@ def main( num_training_steps=max_train_steps * gradient_accumulation_steps, ) + train_infinet = 'infinet' in trainable_modules if trainable_modules is not None else False + # Get the training dataset if train_data.pop("type", "regular") == "folder": - train_dataset = VideoFolderDataset(**train_data, tokenizer=tokenizer) + 
train_dataset = VideoFolderDataset(**train_data, tokenizer=tokenizer, train_infinet=train_infinet) else: - train_dataset = VideoDataset(**train_data, tokenizer=tokenizer) + train_dataset = VideoDataset(**train_data, tokenizer=tokenizer, train_infinet=train_infinet) # DataLoaders creation: train_dataloader = torch.utils.data.DataLoader( @@ -378,6 +400,7 @@ def main( logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_train_steps}") + logger.info(f" InfiNet training = {train_infinet}") global_step = 0 first_epoch = 0 @@ -391,6 +414,9 @@ def finetune_unet(batch, train_encoder=False): #noise_scheduler.beta_schedule = "squaredcos_cap_v2" unet.train() + + # Set up diffusion depth for infinet training + unet.infinet.diffusion_depth = batch["diffusion_depth"] # Convert videos to latent space pixel_values = batch["pixel_values"].to(weight_dtype) diff --git a/utils/chops_to_folder_dataset.py b/utils/chops_to_folder_dataset.py new file mode 100644 index 0000000..1861662 --- /dev/null +++ b/utils/chops_to_folder_dataset.py @@ -0,0 +1,102 @@ +import os +import sys +import argparse +import cv2 +from tqdm import tqdm +import shutil +from pathlib import Path +from PIL import Image + +def write_as_video(output_filename, video_frames, overwrite_dims, width, height, fps): + if overwrite_dims: + height, width, _ = video_frames[0].shape + #fourcc = cv2.VideoWriter_fourcc(*'mp4v') + #out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height)) + #for j in video_frames: + # out.write(j) + video_frames = [Image.fromarray(cv2.cvtColor(j, cv2.COLOR_BGR2RGB)) for j in video_frames] + if overwrite_dims: + video_frames = [j.resize(width, height) for j in video_frames] + + video_frames[0].save(output_filename, save_all=True, append_images=video_frames[1:]) + + #out.release() + +def read_first_frame(video_path): + patience = 5 + p = 0 + video = cv2.VideoCapture(video_path) + ret = False + while not ret: + ret, frame = video.read() + p += 1 + if p > patience: + raise Exception(f'Cannot read the video at {video_path}') + video.release() + return frame + +def get_fps(video_path): + video = cv2.VideoCapture(video_path) + fps = int(video.get(cv2.CAP_PROP_FPS)) + video.release() + return fps + +def move_the_files(init_path, L, depth, overwrite_dims, width, height, overwrite_fps, fps): + + folder_dataset_path = os.path.join(init_path, 'folder_dataset') + os.mkdir(folder_dataset_path) + depth_name = init_path + + t_counter=0 + for d in range(0, depth): + for j in range(L**(d-1) if d > 1 else 1): + for i in range(L if d > 0 else 1): + t_counter+=1 + tq = tqdm(total=t_counter) + + for d in range(0, depth): + depth_name = os.path.join(depth_name, f'depth_{d}') + for j in range(L**(d-1) if d > 1 else 1): + part_path = os.path.join(depth_name, f'part_{j}') + # sample the text info for the next subset + for i in range(L if d > 0 else 1): + txt_path = os.path.join(part_path, f'subset_{i}.txt') + + # go to the subset for video frames sampling + next_depth_name = os.path.join(depth_name, f'depth_{d+1}') + next_part_path = os.path.join(next_depth_name, f'part_{i+L*j}') # `i` cause we want to sample each corresponding *subset* + + # depths > 0 are *guaranteed* to have L videos in their part_j folders + + # now sampling each first frame at the next level + L_frames = [read_first_frame(os.path.join(next_part_path, f'subset_{k}.mp4')) for k in range(L)] + + # 
write all the L sampled frames to an mp4 in the folder dataset + if overwrite_fps: + fps = get_fps(os.path.join(next_part_path, f'subset_{0}.mp4')) + + write_as_video(os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.gif'), L_frames, overwrite_dims, width, height, fps) + shutil.copy(txt_path, os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.txt')) + + tq.set_description(f'Depth {d}, part {j}, subset{i}') + #tq.set_description(os.path.join(next_part_path, f'subset_{0}.mp4')) + tq.update(1) + + tq.close() + +def main(): + parser = argparse.ArgumentParser(description="Convert the chopped labeled tree-like data into a FolderDataset") + parser.add_argument("outpath", help="Path where to save the end FolderDataset", default=os.getcwd()) + parser.add_argument("--L", help="Num of splits on each level.") + parser.add_argument("--D", help="Tree depth") + parser.add_argument("--overwrite_dims", help="Preserve the original video dims", action="store_true") + parser.add_argument("--w", help="Output video width", default=384) + parser.add_argument("--h", help="Output video height", default=256) + parser.add_argument("--overwrite_fps", help="Preserve the original video fps", action="store_true") + parser.add_argument("--fps", help="Output video fps", default=12) + args = parser.parse_args() + move_the_files(args.outpath, int(args.L), int(args.D), bool(args.overwrite_dims), int(args.w), int(args.h), bool(args.overwrite_fps), int(args.fps)) + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/utils/dataset.py b/utils/dataset.py index 8860f25..c7b6154 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -1,4 +1,7 @@ import os +import re +import secrets + import decord import numpy as np import random @@ -26,16 +29,20 @@ def __init__( preprocessed: bool = False, single_video_path: str = "", single_video_prompt: str = "", + train_infinet = False, + infinet_source_path: str = "", **kwargs ): - self.tokenizer = tokenizer self.preprocessed = preprocessed self.single_video_path = single_video_path self.single_video_prompt = single_video_prompt - self.train_data = self.load_from_json(json_path) + if not train_infinet: + self.train_data = self.load_from_json(json_path) + else: + self.train_data = infinet_source_path self.vid_data_key = vid_data_key self.sample_iters = 0 self.original_start_idx = sample_start_idx @@ -49,6 +56,60 @@ def __init__( self.sample_frame_rate = sample_frame_rate self.sample_frame_rate_init = sample_frame_rate + self.train_infinet = train_infinet + self.depth_data = {} + + if self.train_infinet: + self.init_infinet_video_data() + self.depth = None + + def init_infinet_video_data(self): + print("Initializing InfiNet Dataset") + self.depths, self.train_data = self.count_folders(self.train_data) + self.infindex = 0 + self.prev_depth = 0 + + def count_folders(self, path, depth_counter=0, parts_counter=None): + if parts_counter is None: + parts_counter = {} + + for entry in os.scandir(path): + if entry.is_dir(): + if entry.name.startswith("depth"): + depth_counter += 1 + current_depth = entry.name + current_depth = int(re.search(r'\d+', current_depth).group()) + if current_depth not in parts_counter: + parts_counter[current_depth] = {'depth':current_depth, 'part_count': 0, 'mp4_data': []} + depth_counter, parts_counter = self.count_folders(entry.path, depth_counter, parts_counter) + elif entry.name.startswith("part"): + parent_depth = os.path.basename(os.path.dirname(entry.path)) + if parent_depth.startswith("depth"): + 
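                    # Expected on-disk layout (as produced by utils/video_chop.py):
                    #   <root>/depth_0/part_0/subset_0.mp4 (+ subset_0.txt caption)
                    #   <root>/depth_0/depth_1/part_0/subset_0.mp4 ...
                    # i.e. each depth_{d} folder nests depth_{d+1}, and every part_{j}
                    # folder holds subset_{i}.mp4 clips with matching .txt captions, so
                    # `parent_depth` below is recovered from the enclosing depth_{d} name.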
parent_depth = int(re.search(r'\d+', parent_depth).group()) + parts_counter[parent_depth]['part_count'] += 1 + mp4_data = self.get_mp4_and_txt_data(entry.path) + parts_counter[parent_depth]['mp4_data'].extend(mp4_data) + else: + depth_counter, parts_counter = self.count_folders(entry.path, depth_counter, parts_counter) + + return depth_counter, parts_counter + + def get_mp4_and_txt_data(self, path): + mp4_data = [] + for entry in os.scandir(path): + if entry.is_file() and entry.name.lower().endswith(".mp4"): + mp4_path = entry.path + txt_path = os.path.splitext(mp4_path)[0] + ".txt" + + if os.path.exists(txt_path): + with open(txt_path, "r") as txt_file: + txt_content = txt_file.read() + else: + txt_content = None + + mp4_data.append({'mp4_path': mp4_path, 'txt_content': txt_content}) + + return mp4_data def load_from_json(self, path): # Don't load a JSON file if we're doing single video training if os.path.exists(self.single_video_path): return @@ -89,7 +150,25 @@ def get_prompt_ids(self, prompt): def get_frame_range(self, idx, vr): return list(range(idx, len(vr), self.sample_frame_rate))[:self.n_sample_frames] - + + def get_infinet_frame_range(self, current_depth, vr): + max_depths = self.depths + min_frames = 2 + max_frames = self.n_sample_frames + + # Calculate the number of frames to be sampled at the current depth + num_frames = min_frames + int((max_frames - min_frames) * (current_depth / (max_depths - 1))) + + # Limit the number of frames to the total number of frames in the video + num_frames = min(num_frames, len(vr)) + + # Calculate the step size between frames + step = (len(vr) - 1) // (num_frames - 1) + + # Generate the frame indices + frame_range = [idx for idx in range(0, len(vr), step)][:num_frames] + + return frame_range def get_sample_idx(self, idx, vr): # Get the frame idx range based on the get_vid_idx function # We have a fallback here just in case we the frame cannot be read @@ -107,7 +186,7 @@ def get_sample_idx(self, idx, vr): def get_vid_idx(self, vr, vid_data=None): - if self.use_random_start_idx and self.n_sample_frames == 1: + if self.use_random_start_idx and self.n_sample_frames == 1 and not self.train_infinet: # Randomize the frame rate at different speeds self.sample_frame_rate = random.randint(1, self.sample_frame_rate_init) @@ -119,6 +198,7 @@ def get_vid_idx(self, vr, vid_data=None): idx = random.randint(1, max_frame) else: + if vid_data is not None: idx = vid_data['frame_index'] else: @@ -128,7 +208,10 @@ def get_vid_idx(self, vr, vid_data=None): def __len__(self): if self.train_data is not None: - return len(self.train_data['data']) + try: + return len(self.train_data['data']) + except: + return 1 else: return 1 @@ -139,8 +222,36 @@ def __getitem__(self, index): prompt = None prompt_ids = None + if self.train_infinet and self.depth != None: + + print("Using Depth:",self.depth) + + depth = self.depth + + parts = int(self.train_data[depth]["part_count"]) + if depth != self.prev_depth or self.infindex > parts: + self.infindex = 0 + self.prev_depth = depth + if self.use_random_start_idx: + self.infindex = secrets.randbelow(parts) + + train_data = self.train_data[depth]["mp4_data"][self.infindex]["mp4_path"] + vr = decord.VideoReader(train_data, width=self.width, height=self.height) + + sample_index = self.get_infinet_frame_range(depth, vr) + + # Process video and rearrange + video = vr.get_batch(sample_index) + video = rearrange(video, "f h w c -> f c h w") + + prompt = self.train_data[depth]["mp4_data"][self.infindex]["txt_content"] + prompt_ids = 
self.get_prompt_ids(prompt) + + self.infindex += 1 + + # Check if we're doing single video training - if os.path.exists(self.single_video_path): + elif not self.train_infinet and os.path.exists(self.single_video_path): train_data = self.single_video_path # Load and sample video frames @@ -195,7 +306,8 @@ def __getitem__(self, index): example = { "pixel_values": (video / 127.5 - 1.0), "prompt_ids": prompt_ids[0], - "text_prompt": prompt + "text_prompt": prompt, + "diffusion_depth": depth, } return example @@ -210,6 +322,7 @@ def __init__( fps: int = 8, path: str = "./data", fallback_prompt: str = "", + train_infinet=False, **kwargs ): self.tokenizer = tokenizer @@ -223,6 +336,7 @@ def __init__( self.n_sample_frames = n_sample_frames self.fps = fps + self.train_infinet = train_infinet def get_prompt_ids(self, prompt): return self.tokenizer( @@ -237,7 +351,8 @@ def __len__(self): return len(self.video_files) def __getitem__(self, index): - vr = decord.VideoReader(self.video_files[index], width=self.width, height=self.height) + vid_filename = self.video_files[index] + vr = decord.VideoReader(vid_filename, width=self.width, height=self.height) native_fps = vr.get_avg_fps() every_nth_frame = round(native_fps / self.fps) @@ -253,12 +368,18 @@ def __getitem__(self, index): video = vr.get_batch(idxs) video = rearrange(video, "f h w c -> f c h w") - if os.path.exists(self.video_files[index].replace(".mp4", ".txt")): - with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f: + if os.path.exists(vid_filename.replace(".mp4", ".txt")): + with open(vid_filename.replace(".mp4", ".txt"), "r") as f: prompt = f.read() else: prompt = self.fallback_prompt + + # TODO: use regex + depth = 0 + if vid_filename.startswith('depth_'): + depth = int(vid_filename[len('depth_'):vid_filename[len('depth_'):].index('_')]) prompt_ids = self.get_prompt_ids(prompt) - return {"pixel_values": (video / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt} + return {"pixel_values": (video / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, "diffusion_depth":depth} + diff --git a/utils/video_chop.py b/utils/video_chop.py new file mode 100644 index 0000000..4b2b33d --- /dev/null +++ b/utils/video_chop.py @@ -0,0 +1,120 @@ +import os +import sys +import argparse +import cv2 +from tqdm import tqdm +from pathlib import Path + +def chop_video(video_path: str, folder:str, L: int, start_frame:int) -> int: + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file '{video_path}' not found.") + + video = cv2.VideoCapture(video_path) + fps = int(video.get(cv2.CAP_PROP_FPS)) + total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) - start_frame + + # Calculate the maximum depth level + max_depth = 0 + while L ** (max_depth) <= total_frames: + max_depth += 1 + + max_depth = max_depth - 1 + + total_frames = L**max_depth + + video_frames = [] + for _ in tqdm(range(start_frame), desc='Reading dummy frames'): + _, _ = video.read() + + for _ in tqdm(range(total_frames), desc='Reading video frames'): + ret, frame = video.read() + if ret: + video_frames.append(frame) + + dir_name = folder#Path(video_path).stem + #dir_name = os.path.join(folder, dir_name) + + for curr_depth in range(max_depth): + num_splits = L ** curr_depth + frames_per_split = total_frames // num_splits + if dir_name == "": + dir_name = os.path.join(f"depth_{curr_depth}") + else: + dir_name = os.path.join(dir_name, f"depth_{curr_depth}") + os.makedirs(dir_name, exist_ok=True) + + for i in tqdm(range(num_splits), 
desc=f'Depth {curr_depth}'): + + os.makedirs(os.path.join(dir_name, f"part_{i//L}"), exist_ok=True) + output_filename = f"{dir_name}/part_{i//L}/subset_{i%L}.mp4" + height, width, _ = video_frames[0].shape + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height)) + + start_index = i * frames_per_split + end_index = (i + 1) * frames_per_split + + print(f'start_index: {start_index}') + print(f'end_index: {end_index}') + + for j in tqdm(range(start_index, end_index), desc=f'Subset {i}, {len(range(start_index, end_index))} Frames'): + out.write(video_frames[j]) + + out.release() + # create a txt file alongside the video + with open(f"{dir_name}/part_{i//L}/subset_{i%L}.txt", "w") as f: + f.write(f"") + + video.release() + return total_frames + +def stuff(video_path: str, L: int, only_once = True): + + only_once = True + + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file '{video_path}' not found.") + + video = cv2.VideoCapture(video_path) + fps = int(video.get(cv2.CAP_PROP_FPS)) + total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + + cur_dir_name = os.path.split(video_path)[0]#Path(video_path).stem + orig_name = Path(video_path).stem + #dir_name = cur_dir_name + #os.mkdir(dir_name) + vid_name = os.path.split(video_path)[1] + + scenario = 0 + start_frame = 0 + + while start_frame < total_frames - L: + + dir_name = f"scenario_{scenario}" + os.mkdir(os.path.join(cur_dir_name, dir_name)) + video_path_new = os.path.join(cur_dir_name, dir_name, vid_name) + os.rename(video_path, video_path_new) + video_path = video_path_new + start_frame += chop_video(video_path, dir_name, L, start_frame) + scenario += 1 + + if only_once: + break + + os.rename(video_path, os.path.join(os.getcwd(), vid_name)) + os.mkdir(orig_name) + + for i in os.listdir(os.getcwd()): + if i.startswith('scenario_'): + os.rename(i, os.path.join(orig_name, i)) + +def main(): + parser = argparse.ArgumentParser(description="Chop a video file into subsets of frames.") + parser.add_argument("video_file", help="Path to the video file.") + parser.add_argument("--L", help="Num of splits on each level.") + parser.add_argument("--subscenariosplit", help="Should it split ", action='store_true', default=False) + args = parser.parse_args() + stuff(args.video_file, int(args.L), bool(args.subscenariosplit != None and args.subscenariosplit)) + +if __name__ == "__main__": + main()
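As a quick illustration of the mechanism this patch adds: the sketch below builds the conditioning latents x_c and the first/last-frame mask x_m the same way UNet3DConditionModel.forward does, pushes dummy hidden states through a zero-initialized DoDBlock, and confirms the block starts out as an identity. It assumes the repo from this diff is on the import path; the batch size, channel counts (4 latent channels, 320 block channels) and 32x32 latent resolution are illustrative values, not read from a checkpoint.

import torch
from models.unet_3d_blocks import DoDBlock

batch, frames, latent_c, h_lat, w_lat = 1, 8, 4, 32, 32
block_c = 320  # illustrative: out_channels of the first down block

# Conditioning latents and first/last-frame mask, flattened to (batch * frames, C, H, W)
# the same way UNet3DConditionModel.forward prepares x_c and x_m.
x_c = torch.randn(batch, latent_c, frames, h_lat, w_lat)
x_m = torch.zeros(batch, 1, frames, h_lat, w_lat)
x_m[:, :, 0] = 1.0
x_m[:, :, -1] = 1.0
x_c = x_c.permute(0, 2, 1, 3, 4).reshape(batch * frames, latent_c, h_lat, w_lat)
x_m = x_m.permute(0, 2, 1, 3, 4).reshape(batch * frames, 1, h_lat, w_lat)

# Hidden states as they look inside the first down block (depth 0, so stride 1).
hidden = torch.randn(batch * frames, block_c, h_lat, w_lat)

dod = DoDBlock(latent_c, 2, 0, block_c, is_up=False)
dod._init_weights()  # zero conv: the block starts out as an identity mapping

out = dod(hidden, x_c, x_m)
print(out.shape)                            # torch.Size([8, 320, 32, 32])
print(torch.allclose(out, hidden))          # True, until the zero convs are trained
print(torch.allclose(dod(hidden), hidden))  # also True: no conditioning is a pass-through

Starting from an identity appears to be the point of the zero-conv design here: the base ModelScope weights can be loaded with strict=False and only the `infinet` modules trained, so at step 0 the UNet's output is unchanged and the DoD conditioning is learned on top.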