From ceffb82247e1a3a9569e6f6828fca10ad9b91c94 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Sun, 2 Apr 2023 13:30:25 +0300 Subject: [PATCH 01/12] add infinet training --- configs/my_config.yaml | 1 + eval.py | 213 ++++++++++++++++++++++++++++++++++++ models/unet_3d_blocks.py | 183 ++++++++++++++++++++++++++++++- models/unet_3d_condition.py | 66 ++++++++++- requirements.txt | 1 + train.py | 20 +++- utils/dataset.py | 12 +- 7 files changed, 487 insertions(+), 9 deletions(-) create mode 100644 eval.py diff --git a/configs/my_config.yaml b/configs/my_config.yaml index 29450dc..94dccd7 100644 --- a/configs/my_config.yaml +++ b/configs/my_config.yaml @@ -34,6 +34,7 @@ validation_steps: 100 trainable_modules: - "attn1.to_out" - "attn2.to_out" + - "infinet" seed: 64 mixed_precision: "fp16" use_8bit_adam: False # This seems to be incompatible at the moment. diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..12e3813 --- /dev/null +++ b/eval.py @@ -0,0 +1,213 @@ +import argparse +import datetime +import logging +import inspect +import math +import os +import random +import gc +import subprocess +import tempfile + +from typing import Dict, Optional, Tuple, List + +import numpy as np +from omegaconf import OmegaConf + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms as T +import diffusers +import transformers + +from pkg_resources import resource_filename +from torchvision import transforms +from tqdm.auto import tqdm + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed + +from models.unet_3d_condition import UNet3DConditionModel +from diffusers.models import AutoencoderKL +from diffusers import DPMSolverMultistepScheduler, DDPMScheduler, TextToVideoSDPipeline +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention_processor import AttnProcessor2_0, Attention +from diffusers.models.attention import BasicTransformerBlock + +from transformers import CLIPTextModel, CLIPTokenizer +from utils.dataset import VideoDataset +from einops import rearrange, repeat + +already_printed_unet = False + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +def create_logging(logging, logger, accelerator): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + +def accelerate_set_verbose(accelerator): + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + +def create_output_folders(output_dir, config): + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + out_dir = os.path.join(output_dir, f"train_{now}") + + os.makedirs(out_dir, exist_ok=True) + os.makedirs(f"{out_dir}/samples", exist_ok=True) + OmegaConf.save(config, os.path.join(out_dir, 'config.yaml')) + + return out_dir + + +def load_primary_models(pretrained_model_path): + noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + + unet = UNet3DConditionModel() + + model_path = os.path.join(os.getcwd(), pretrained_model_path, 'unet', 'diffusion_pytorch_model.bin') + # Load the pretrained weights + pretrained_dict = torch.load( + model_path, + map_location=torch.device('cuda'), + ) + unet.load_state_dict(pretrained_dict, strict=False) + + unet.infinet._init_weights() + + unet.infinet.diffusion_depth = 1 + #unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet") + + return noise_scheduler, tokenizer, text_encoder, vae, unet + + +def main(): + pretrained_model_path = "models/model_scope_diffusers" + # Load scheduler, tokenizer and models. + noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(pretrained_model_path) + + vae.to("cuda") + unet.to("cuda") + text_encoder.to("cuda") + + # Enable VAE slicing to save memory. 
+ vae.enable_slicing() + + + #unet.eval() + #text_encoder.eval() + + pipeline = TextToVideoSDPipeline.from_pretrained( + pretrained_model_path, + text_encoder=text_encoder, + vae=vae, + unet=unet + ) + + pipeline.enable_xformers_memory_efficient_attention() + + diffusion_scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline.scheduler = diffusion_scheduler + + prompt = "Couple walking on the beach" + os.makedirs("samples", exist_ok=True) + out_file = f"samples/eval_{prompt}.mp4" + + with torch.no_grad(): + video_frames = pipeline( + prompt, + width=512, + height=384, + num_frames=20, + num_inference_steps=50, + guidance_scale=7.5 + ).frames + video_path = export_to_video(video_frames, out_file) + + del pipeline + gc.collect() + +from PIL import Image +import cv2 +def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None, fps: int = 8) -> str: + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name + + #fps = 8 + h, w, c = video_frames[0].shape + + os.makedirs(os.path.join(os.getcwd(), 'out'), exist_ok=True) + for i in range(len(video_frames)): +# Image.fromarray(video_frames[i]).save(os.path.join(os.getcwd(), 'out', f"frame_{i}.png")) + cv2.imwrite(os.path.join(os.getcwd(), 'out', + f"{i:06}.png"), video_frames[i]) + + # create a pipe for ffmpeg to write the video frames to + ffmpeg_pipe = subprocess.Popen( + [ + "ffmpeg", + "-y", # overwrite output file if it already exists + "-f", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{w}x{h}", + "-r", str(fps), + "-i", "-", + "-vcodec", "libx264", + "-preset", "medium", + "-crf", "23", + output_video_path, + ], + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + # write each video frame to the ffmpeg pipe + for frame in video_frames: + ffmpeg_pipe.stdin.write(frame.tobytes()) + + # close the ffmpeg pipe and wait for it to finish writing the video file + ffmpeg_pipe.stdin.close() + ffmpeg_pipe.wait() + + return output_video_path + +if __name__ == "__main__": + main() + +def find_ffmpeg_binary(): + try: + import google.colab + return 'ffmpeg' + except: + pass + for package in ['imageio_ffmpeg', 'imageio-ffmpeg']: + try: + package_path = resource_filename(package, 'binaries') + files = [os.path.join(package_path, f) for f in os.listdir( + package_path) if f.startswith("ffmpeg-")] + files.sort(key=lambda x: os.path.getmtime(x), reverse=True) + return files[0] if files else 'ffmpeg' + except: + return 'ffmpeg' \ No newline at end of file diff --git a/models/unet_3d_blocks.py b/models/unet_3d_blocks.py index fa91387..e2e2c59 100755 --- a/models/unet_3d_blocks.py +++ b/models/unet_3d_blocks.py @@ -19,6 +19,104 @@ from diffusers.models.transformer_2d import Transformer2DModel from diffusers.models.transformer_temporal import TransformerTemporalModel +class DoDBlock(nn.Module): + """ + A downconvolution layer with masked video latents + Gets the masked video latents (the first and the last frame) and makes a masked convolution + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: Always 3D, downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, + channels, + dims=3, + depth=0, + out_channels=None, + padding=1, + is_up=False,): + super().__init__() + self.channels = channels + self.out_channels = min((out_channels or channels) * (2 ** depth), 1280) if not is_up else (out_channels or channels) + self.dims = dims + self.is_up = is_up + stride = 2**depth if dims != 3 else (1, 2**depth, 2**depth) # if depth is zero, the stride is 1 + + # Convolution block, which should be initialized with zero weights and biases + # (zero conv) + self.conv_w = nn.Conv2d( + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + + self.conv_b = nn.Conv2d( + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + + # Conv for masking + self.mask_conv_w = nn.Conv2d( + 1, # only black and white + self.out_channels, + 3, + stride=stride, + padding=padding) + + self.mask_conv_b = nn.Conv2d( + 1, # only black and white + self.out_channels, + 3, + stride=stride, + padding=padding) + + # h - hidden states, x_c - frame conditioning, x_m - masked video latents + def forward(self, h, x_c=None, x_m=None): + + # When no frame conditioning is provided (top DoD iteration) + # return the untouched hidden states + if x_c is None or x_m is None: + return h + + # Add image conditioning as linear operation + + #print('IS UP ', self.is_up) + #print('h', h.shape) + #print('xc', x_c.shape) + #print('xm', x_m.shape) + + #print('cw', self.conv_w.weight.shape) + + # get weights and biases from frame conditioning + # vid convolution (initialized with zero weights and biases at first) + x_c_w = self.conv_w(x_c) + x_c_b = self.conv_b(x_c) + + #print('xcw', x_c_w.shape) + #print('xcb', x_c_b.shape) + + h = x_c_w * h + x_c_b + h # uses hadamard product + + # Use masked video latents to mask the convolution + x_m_w = self.mask_conv_w(x_m) + x_m_b = self.mask_conv_b(x_m) + + h = x_m_w * h + x_m_b + h # uses hadamard product + + return h + + def _init_weights(self): + # Zero initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.constant_(m.weight, 0) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + # Assign gradient checkpoint function to simple variable for readability. 
g_c = checkpoint.checkpoint @@ -119,6 +217,7 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + infinet=None, ): if down_block_type == "DownBlock3D": return DownBlock3D( @@ -132,6 +231,7 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, + infinet=infinet, ) elif down_block_type == "CrossAttnDownBlock3D": if cross_attention_dim is None: @@ -153,6 +253,7 @@ def get_down_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + infinet=infinet, ) raise ValueError(f"{down_block_type} does not exist.") @@ -175,6 +276,7 @@ def get_up_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + infinet=None, ): if up_block_type == "UpBlock3D": return UpBlock3D( @@ -188,6 +290,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, + infinet=infinet, ) elif up_block_type == "CrossAttnUpBlock3D": if cross_attention_dim is None: @@ -209,6 +312,7 @@ def get_up_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + infinet=infinet, ) raise ValueError(f"{up_block_type} does not exist.") @@ -373,6 +477,7 @@ def __init__( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + infinet=None, ): super().__init__() resnets = [] @@ -406,6 +511,17 @@ def __init__( out_channels, ) ) + + if infinet is not None: + infinet.input_blocks_injections.append(DoDBlock( + infinet.in_channels, + 2, + len(infinet.input_blocks_injections), + out_channels, + is_up=False, + ) + ) + attentions.append( Transformer2DModel( out_channels // attn_num_head_channels, @@ -453,6 +569,9 @@ def forward( attention_mask=None, num_frames=1, cross_attention_kwargs=None, + dod_block=None, + x_c=None, + x_m=None, ): # TODO(Patrick, William) - attention mask is not used output_states = () @@ -462,6 +581,8 @@ def forward( ): if self.gradient_checkpointing: + # TODO: Infinet is not implemented here yet! 
+ # so don't use it for now with gradient checkpointing on hidden_states = cross_attn_g_c( attn, temp_attn, @@ -477,6 +598,10 @@ def forward( else: hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if dod_block is not None: + hidden_states = dod_block(hidden_states, x_c, x_m) + hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -511,6 +636,7 @@ def __init__( output_scale_factor=1.0, add_downsample=True, downsample_padding=1, + infinet=None, ): super().__init__() resnets = [] @@ -540,6 +666,16 @@ def __init__( ) ) + if infinet is not None: + infinet.input_blocks_injections.append(DoDBlock( + infinet.in_channels, + 2, # dims + len(infinet.input_blocks_injections), + out_channels, + is_up=False, + ) + ) + self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) @@ -554,7 +690,7 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states, temb=None, num_frames=1): + def forward(self, hidden_states, temb=None, num_frames=1, dod_block=None, x_c=None, x_m=None,): output_states = () for resnet, temp_conv in zip(self.resnets, self.temp_convs): @@ -563,6 +699,9 @@ def forward(self, hidden_states, temb=None, num_frames=1): else: hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if dod_block is not None: + hidden_states = dod_block(hidden_states, x_c, x_m) output_states += (hidden_states,) @@ -597,6 +736,7 @@ def __init__( use_linear_projection=False, only_cross_attention=False, upcast_attention=False, + infinet=None, ): super().__init__() resnets = [] @@ -632,6 +772,19 @@ def __init__( out_channels, ) ) + + if infinet is not None: + #print(len(infinet.input_blocks_injections)) + #print(len(infinet.output_blocks_injections)) + infinet.output_blocks_injections.append(DoDBlock( + infinet.in_channels, + 2, + max(0, 3 - len(infinet.output_blocks_injections)),#max(0, len(infinet.input_blocks_injections) - len(infinet.output_blocks_injections)), + (out_channels // 2**(len(infinet.output_blocks_injections)-1)) if len(infinet.output_blocks_injections) > 1 else out_channels, + is_up=True, + ) + ) + attentions.append( Transformer2DModel( out_channels // attn_num_head_channels, @@ -675,6 +828,9 @@ def forward( attention_mask=None, num_frames=1, cross_attention_kwargs=None, + dod_block=None, + x_c=None, + x_m=None, ): # TODO(Patrick, William) - attention mask is not used for resnet, temp_conv, attn, temp_attn in zip( @@ -686,6 +842,8 @@ def forward( hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.gradient_checkpointing: + # TODO: Infinet is not implemented here yet! 
+ # so don't use it with gradient checkpointing hidden_states = cross_attn_g_c( attn, temp_attn, @@ -701,6 +859,10 @@ def forward( else: hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if dod_block is not None: + hidden_states = dod_block(hidden_states, x_c, x_m) + hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -731,10 +893,12 @@ def __init__( resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, + infinet=None, ): super().__init__() resnets = [] temp_convs = [] + self.gradient_checkpointing=False for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels @@ -761,6 +925,18 @@ def __init__( ) ) + if infinet is not None: + #print(len(infinet.input_blocks_injections)) + #print(len(infinet.output_blocks_injections)) + infinet.output_blocks_injections.append(DoDBlock( + infinet.in_channels, + 2, + max(0, 3 - len(infinet.output_blocks_injections)),#max(0, len(infinet.input_blocks_injections) - len(infinet.output_blocks_injections)), + (out_channels // 2**(len(infinet.output_blocks_injections)-1)) if len(infinet.output_blocks_injections) > 1 else out_channels, + is_up=True, + ) + ) + self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) @@ -769,7 +945,7 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1, dod_block=None, x_c=None, x_m=None,): for resnet, temp_conv in zip(self.resnets, self.temp_convs): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -781,6 +957,9 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si else: hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if dod_block is not None: + hidden_states = dod_block(hidden_states, x_c, x_m) if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/models/unet_3d_condition.py b/models/unet_3d_condition.py index 8857cad..98d2ce8 100755 --- a/models/unet_3d_condition.py +++ b/models/unet_3d_condition.py @@ -32,11 +32,34 @@ UpBlock3D, get_down_block, get_up_block, + DoDBlock, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Class to keep DiffusionOverDiffusion modules as a separate model +# with weights saveable as a detachable checkpoint +class InfiNet(nn.Module): + def __init__(self, in_channels): + super(InfiNet, self).__init__() + + self.in_channels = in_channels + + self.diffusion_depth = 0 # Placeholder, because it's not passable into the pipeline + + self.input_blocks_injections = nn.ModuleList() + self.output_blocks_injections = nn.ModuleList() + + def _init_weights(self): + # Zero initialization + for m in self.modules(): + if isinstance(m, DoDBlock): + m._init_weights() + if isinstance(m, nn.ModuleList): + for l in m: + if isinstance(l, DoDBlock): + l._init_weights() @dataclass class UNet3DConditionOutput(BaseOutput): @@ -103,11 +126,14 @@ def __init__( norm_eps: float = 1e-5, cross_attention_dim: int = 1024, attention_head_dim: Union[int, Tuple[int]] = 64, + use_infinet=True, ): super().__init__() self.sample_size = sample_size + self.use_infinet = use_infinet + # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( @@ -154,6 +180,10 @@ def __init__( 
self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) + if self.use_infinet: + # InfiNet insertion + self.infinet = InfiNet(in_channels) + if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) @@ -178,6 +208,7 @@ def __init__( attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, dual_cross_attention=False, + infinet=self.infinet if self.use_infinet else None, ) self.down_blocks.append(down_block) @@ -230,6 +261,7 @@ def __init__( cross_attention_dim=cross_attention_dim, attn_num_head_channels=reversed_attention_head_dim[i], dual_cross_attention=False, + infinet=self.infinet if self.use_infinet else None, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -248,6 +280,10 @@ def __init__( self.conv_out = nn.Conv2d( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) + + if use_infinet: + + self.infinet._init_weights() def set_attention_slice(self, slice_size): r""" @@ -331,6 +367,7 @@ def forward( down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, + #diffusion_depth: int = 1, #until diffusion_depth is in Diffusers, we'll have to set it up in class properties instead ) -> Union[UNet3DConditionOutput, Tuple]: r""" Args: @@ -397,15 +434,31 @@ def forward( emb = emb.repeat_interleave(repeats=num_frames, dim=0) encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + # If aiming for DiffusionOverDiffusion and have InfiNet enabled, keep the original video + # + its mask of the first and last frames + + if self.use_infinet and self.infinet.diffusion_depth > 0: + x_c = sample.clone().detach() + x_m = torch.zeros(x_c.shape[:1] + (1,) + x_c.shape[2:], dtype=x_c.dtype, device=x_c.device) + x_m[:, :, 0, :, :] = torch.ones_like(x_m[:, :, 0, :, :]) + x_m[:, :, -1, :, :] = torch.ones_like(x_m[:, :, -1, :, :]) + x_c = x_c.permute(0, 2, 1, 3, 4).reshape((x_c.shape[0] * num_frames, -1) + x_c.shape[3:]) + x_m = x_m.permute(0, 2, 1, 3, 4).reshape((x_m.shape[0] * num_frames, -1) + x_m.shape[3:]) + else: + x_c = None + x_m = None + # 2. pre-process sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = self.conv_in(sample) + # InfiNet TODO: do we have to add a module injection here as well? + sample = self.transformer_in(sample, num_frames=num_frames).sample # 3. 
down down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: + for i, downsample_block in enumerate(self.down_blocks): if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, @@ -414,9 +467,12 @@ def forward( attention_mask=attention_mask, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, + dod_block=self.infinet.input_blocks_injections[i] if self.infinet is not None else None, + x_c=x_c, + x_m=x_m, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames, dod_block=self.infinet.input_blocks_injections[i] if self.infinet is not None else None, x_c=x_c, x_m=x_m,) down_block_res_samples += res_samples @@ -467,6 +523,9 @@ def forward( attention_mask=attention_mask, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, + dod_block=self.infinet.output_blocks_injections[i] if self.infinet is not None else None, + x_c=x_c, + x_m=x_m, ) else: sample = upsample_block( @@ -475,6 +534,9 @@ def forward( res_hidden_states_tuple=res_samples, upsample_size=upsample_size, num_frames=num_frames, + dod_block=self.infinet.output_blocks_injections[i] if self.infinet is not None else None, + x_c=x_c, + x_m=x_m, ) # 6. post-process diff --git a/requirements.txt b/requirements.txt index 4ffbc0d..30f44f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ accelerate torch torchvision torchaudio +imageio_ffmpeg git+https://github.com/huggingface/diffusers.git transformers einops diff --git a/train.py b/train.py index 4b3796e..95379de 100644 --- a/train.py +++ b/train.py @@ -84,7 +84,18 @@ def load_primary_models(pretrained_model_path): tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") - unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet") + + unet = UNet3DConditionModel() + + model_path = os.path.join(os.getcwd(), pretrained_model_path, 'unet', 'diffusion_pytorch_model.bin') + # Load the pretrained weights + pretrained_dict = torch.load( + model_path, + map_location=torch.device('cuda'), + ) + unet.load_state_dict(pretrained_dict, strict=False) + + unet.infinet._init_weights() return noise_scheduler, tokenizer, text_encoder, vae, unet @@ -222,7 +233,7 @@ def main( train_data: Dict, validation_data: Dict, validation_steps: int = 100, - trainable_modules: Tuple[str] = ("attn1", "attn2" ), + trainable_modules: Tuple[str] = ("attn1", "attn2", "infinet"), train_batch_size: int = 1, max_train_steps: int = 500, learning_rate: float = 5e-5, @@ -318,7 +329,7 @@ def main( if train_data.pop("type", "regular") == "folder": train_dataset = VideoFolderDataset(**train_data, tokenizer=tokenizer) else: - train_dataset = VideoDataset(**train_data, tokenizer=tokenizer) + train_dataset = VideoDataset(**train_data, tokenizer=tokenizer, train_infinet='infinet' in trainable_modules if trainable_modules is not None else False) # DataLoaders creation: train_dataloader = torch.utils.data.DataLoader( @@ -391,6 +402,9 @@ def finetune_unet(batch, train_encoder=False): #noise_scheduler.beta_schedule = "squaredcos_cap_v2" unet.train() + + # Set up diffusion depth for infinet training + 
unet.infinet.diffusion_depth = batch["diffusion_depth"] # Convert videos to latent space pixel_values = batch["pixel_values"].to(weight_dtype) diff --git a/utils/dataset.py b/utils/dataset.py index 8860f25..10339b7 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -26,6 +26,7 @@ def __init__( preprocessed: bool = False, single_video_path: str = "", single_video_prompt: str = "", + train_infinet = False, **kwargs ): @@ -49,6 +50,8 @@ def __init__( self.sample_frame_rate = sample_frame_rate self.sample_frame_rate_init = sample_frame_rate + self.train_infinet = train_infinet + def load_from_json(self, path): # Don't load a JSON file if we're doing single video training if os.path.exists(self.single_video_path): return @@ -107,7 +110,7 @@ def get_sample_idx(self, idx, vr): def get_vid_idx(self, vr, vid_data=None): - if self.use_random_start_idx and self.n_sample_frames == 1: + if self.use_random_start_idx and self.n_sample_frames == 1 and not self.train_infinet: # Randomize the frame rate at different speeds self.sample_frame_rate = random.randint(1, self.sample_frame_rate_init) @@ -119,6 +122,7 @@ def get_vid_idx(self, vr, vid_data=None): idx = random.randint(1, max_frame) else: + if vid_data is not None: idx = vid_data['frame_index'] else: @@ -187,6 +191,9 @@ def __getitem__(self, index): # Get video prompt prompt = vid_data['prompt'] + # Get diffusion depth for training Infinet + diffusion_depth = vid_data['diffusion_depth'] if 'diffusion_depth' in vid_data.keys() else 0 + video = vr.get_batch(sample_index) video = rearrange(video, "f h w c -> f c h w") @@ -195,7 +202,8 @@ def __getitem__(self, index): example = { "pixel_values": (video / 127.5 - 1.0), "prompt_ids": prompt_ids[0], - "text_prompt": prompt + "text_prompt": prompt, + "diffusion_depth": diffusion_depth, } return example From f31deef0937f3fc755960b609e42584197f71a3c Mon Sep 17 00:00:00 2001 From: kabachuha Date: Tue, 4 Apr 2023 23:33:43 +0300 Subject: [PATCH 02/12] add video chopping script --- video_chop.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 video_chop.py diff --git a/video_chop.py b/video_chop.py new file mode 100644 index 0000000..44d98ae --- /dev/null +++ b/video_chop.py @@ -0,0 +1,69 @@ +import os +import sys +import argparse +import cv2 +from tqdm import tqdm +from pathlib import Path + +def chop_video(video_path: str, L: int) -> None: + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file '{video_path}' not found.") + + video = cv2.VideoCapture(video_path) + fps = int(video.get(cv2.CAP_PROP_FPS)) + total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + + total_frames = (total_frames // L) * L + + video_frames = [] + for _ in tqdm(range(total_frames), desc='Reading video frames'): + ret, frame = video.read() + if ret: + video_frames.append(frame) + + # Calculate the maximum depth level + max_depth = 0 + while L ** (max_depth) <= total_frames // L: + max_depth += 1 + + dir_name = Path(video_path).stem + + for curr_depth in range(max_depth): + num_splits = L ** curr_depth + frames_per_split = total_frames // num_splits + if dir_name == "": + dir_name = f"depth_{curr_depth}" + else: + dir_name = os.path.join(dir_name, f"depth_{curr_depth}") + os.makedirs(dir_name, exist_ok=True) + + for i in tqdm(range(num_splits), desc=f'Depth {curr_depth}'): + + os.makedirs(os.path.join(dir_name, f"part_{i//L}"), exist_ok=True) + output_filename = f"{dir_name}/part_{i//L}/subset_{i%L}.mp4" + height, width, _ = video_frames[0].shape + 
fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height)) + + start_index = i * frames_per_split + end_index = (i + 1) * frames_per_split + + for j in tqdm(range(start_index, end_index), desc=f'Subset {i}, {len(range(start_index, end_index))} Frames'): + out.write(video_frames[j]) + + out.release() + # create a txt file alongside the video + with open(f"{dir_name}/part_{i//L}/subset_{i%L}.txt", "w") as f: + f.write(f"") + + video.release() + +def main(): + parser = argparse.ArgumentParser(description="Chop a video file into subsets of frames.") + parser.add_argument("video_file", help="Path to the video file.") + parser.add_argument("--L", help="Num of splits on each level.") + args = parser.parse_args() + chop_video(args.video_file, int(args.L)) + +if __name__ == "__main__": + main() From 0821aa7245dd24f45555700bbd0aeaa8b6bf7d45 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Fri, 7 Apr 2023 13:00:12 +0300 Subject: [PATCH 03/12] add infinet dataloader --- configs/my_config.yaml | 6 ++- utils/dataset.py | 115 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 109 insertions(+), 12 deletions(-) diff --git a/configs/my_config.yaml b/configs/my_config.yaml index 94dccd7..edc50eb 100644 --- a/configs/my_config.yaml +++ b/configs/my_config.yaml @@ -7,7 +7,7 @@ train_data: preprocessed: True n_sample_frames: 8 width: 256 - height: 256 + height: 128 sample_start_idx: 0 sample_frame_rate: 15 use_random_start_idx: False @@ -16,12 +16,14 @@ train_data: single_video_path: "" single_video_prompt: "" + infinet_source_path: "video" + validation_data: prompt: "" sample_preview: True num_frames: 16 width: 256 - height: 256 + height: 128 num_inference_steps: 50 guidance_scale: 9 diff --git a/utils/dataset.py b/utils/dataset.py index 10339b7..f8115e1 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -1,4 +1,7 @@ import os +import re +import secrets + import decord import numpy as np import random @@ -27,16 +30,19 @@ def __init__( single_video_path: str = "", single_video_prompt: str = "", train_infinet = False, + infinet_source_path: str = "", **kwargs ): - self.tokenizer = tokenizer self.preprocessed = preprocessed self.single_video_path = single_video_path self.single_video_prompt = single_video_prompt - self.train_data = self.load_from_json(json_path) + if not train_infinet: + self.train_data = self.load_from_json(json_path) + else: + self.train_data = infinet_source_path self.vid_data_key = vid_data_key self.sample_iters = 0 self.original_start_idx = sample_start_idx @@ -51,7 +57,58 @@ def __init__( self.sample_frame_rate_init = sample_frame_rate self.train_infinet = train_infinet - + self.depth_data = {} + + if self.train_infinet: + self.init_infinet_video_data() + + def init_infinet_video_data(self): + print("Initializing InfiNet Dataset") + self.depths, self.train_data = self.count_folders(self.train_data) + self.infindex = 0 + self.prev_depth = 0 + + def count_folders(self, path, depth_counter=0, parts_counter=None): + if parts_counter is None: + parts_counter = {} + + for entry in os.scandir(path): + if entry.is_dir(): + if entry.name.startswith("depth"): + depth_counter += 1 + current_depth = entry.name + current_depth = int(re.search(r'\d+', current_depth).group()) + if current_depth not in parts_counter: + parts_counter[current_depth] = {'depth':current_depth, 'part_count': 0, 'mp4_data': []} + depth_counter, parts_counter = self.count_folders(entry.path, depth_counter, parts_counter) + elif entry.name.startswith("part"): + 
parent_depth = os.path.basename(os.path.dirname(entry.path)) + if parent_depth.startswith("depth"): + parent_depth = int(re.search(r'\d+', parent_depth).group()) + parts_counter[parent_depth]['part_count'] += 1 + mp4_data = self.get_mp4_and_txt_data(entry.path) + parts_counter[parent_depth]['mp4_data'].extend(mp4_data) + else: + depth_counter, parts_counter = self.count_folders(entry.path, depth_counter, parts_counter) + + return depth_counter, parts_counter + + def get_mp4_and_txt_data(self, path): + mp4_data = [] + for entry in os.scandir(path): + if entry.is_file() and entry.name.lower().endswith(".mp4"): + mp4_path = entry.path + txt_path = os.path.splitext(mp4_path)[0] + ".txt" + + if os.path.exists(txt_path): + with open(txt_path, "r") as txt_file: + txt_content = txt_file.read() + else: + txt_content = None + + mp4_data.append({'mp4_path': mp4_path, 'txt_content': txt_content}) + + return mp4_data def load_from_json(self, path): # Don't load a JSON file if we're doing single video training if os.path.exists(self.single_video_path): return @@ -92,7 +149,25 @@ def get_prompt_ids(self, prompt): def get_frame_range(self, idx, vr): return list(range(idx, len(vr), self.sample_frame_rate))[:self.n_sample_frames] - + + def get_infinet_frame_range(self, current_depth, vr): + max_depths = self.depths + min_frames = 2 + max_frames = self.n_sample_frames + + # Calculate the number of frames to be sampled at the current depth + num_frames = min_frames + int((max_frames - min_frames) * (current_depth / (max_depths - 1))) + + # Limit the number of frames to the total number of frames in the video + num_frames = min(num_frames, len(vr)) + + # Calculate the step size between frames + step = (len(vr) - 1) // (num_frames - 1) + + # Generate the frame indices + frame_range = [idx for idx in range(0, len(vr), step)][:num_frames] + + return frame_range def get_sample_idx(self, idx, vr): # Get the frame idx range based on the get_vid_idx function # We have a fallback here just in case we the frame cannot be read @@ -136,15 +211,38 @@ def __len__(self): else: return 1 - def __getitem__(self, index): + def __getitem__(self, index, depth=None): # Initialize variables video = None prompt = None prompt_ids = None + if self.train_infinet and depth != None: + parts = int(self.train_data[depth]["part_count"]) + if depth != self.prev_depth or self.infindex > parts: + self.infindex = 0 + self.prev_depth = depth + if self.use_random_start_idx: + self.infindex = secrets.randbelow(parts) + + train_data = self.train_data[depth]["mp4_data"][self.infindex]["mp4_path"] + vr = decord.VideoReader(train_data, width=self.width, height=self.height) + + sample_index = self.get_infinet_frame_range(depth, vr) + + # Process video and rearrange + video = vr.get_batch(sample_index) + video = rearrange(video, "f h w c -> f c h w") + + prompt = self.train_data[depth]["mp4_data"][self.infindex]["txt_content"] + prompt_ids = self.get_prompt_ids(prompt) + + self.infindex += 1 + + # Check if we're doing single video training - if os.path.exists(self.single_video_path): + elif not self.train_infinet and os.path.exists(self.single_video_path): train_data = self.single_video_path # Load and sample video frames @@ -191,9 +289,6 @@ def __getitem__(self, index): # Get video prompt prompt = vid_data['prompt'] - # Get diffusion depth for training Infinet - diffusion_depth = vid_data['diffusion_depth'] if 'diffusion_depth' in vid_data.keys() else 0 - video = vr.get_batch(sample_index) video = rearrange(video, "f h w c -> f c h w") @@ -203,7 +298,7 
@@ def __getitem__(self, index): "pixel_values": (video / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, - "diffusion_depth": diffusion_depth, + "diffusion_depth": depth, } return example From e49f8e4c7d1467f1b3b7ab36f1d19be97b975077 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Sun, 9 Apr 2023 12:49:29 +0300 Subject: [PATCH 04/12] don't reset infinet weights at start --- eval.py | 10 +++++++++- train.py | 11 ++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/eval.py b/eval.py index 12e3813..3c1529b 100644 --- a/eval.py +++ b/eval.py @@ -95,7 +95,15 @@ def load_primary_models(pretrained_model_path): ) unet.load_state_dict(pretrained_dict, strict=False) - unet.infinet._init_weights() + has_pretrained_weights = False + + for k, v in pretrained_dict.items(): + if k.startswith('infinet'): + has_pretrained_weights = True + + if not has_pretrained_weights: + print('Pretrained Infinet not found, setting its weights to zeros') + unet.infinet._init_weights() unet.infinet.diffusion_depth = 1 #unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet") diff --git a/train.py b/train.py index 95379de..4ddd208 100644 --- a/train.py +++ b/train.py @@ -95,7 +95,16 @@ def load_primary_models(pretrained_model_path): ) unet.load_state_dict(pretrained_dict, strict=False) - unet.infinet._init_weights() + has_pretrained_weights = False + + for k, v in pretrained_dict.items(): + if k.startswith('infinet'): + has_pretrained_weights = True + + if not has_pretrained_weights: + print('Pretrained Infinet not found, setting its weights to zeros') + print("It's expected when training for the first time") + unet.infinet._init_weights() return noise_scheduler, tokenizer, text_encoder, vae, unet From bd69c67596d04eee4c6b92a72756320160f7ad68 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Sun, 9 Apr 2023 12:52:55 +0300 Subject: [PATCH 05/12] try catch for empty items --- utils/dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/utils/dataset.py b/utils/dataset.py index f8115e1..1ad592a 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -207,7 +207,10 @@ def get_vid_idx(self, vr, vid_data=None): def __len__(self): if self.train_data is not None: - return len(self.train_data['data']) + try: + return len(self.train_data['data']) + except: + return 1 else: return 1 From a7abd9749d8a44c88108ed25a8bcd44bcacca599 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Sun, 9 Apr 2023 12:53:09 +0300 Subject: [PATCH 06/12] code for depth alteration --- utils/dataset.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/utils/dataset.py b/utils/dataset.py index 1ad592a..ab3e1be 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -61,6 +61,7 @@ def __init__( if self.train_infinet: self.init_infinet_video_data() + self.depth = None def init_infinet_video_data(self): print("Initializing InfiNet Dataset") @@ -214,14 +215,19 @@ def __len__(self): else: return 1 - def __getitem__(self, index, depth=None): + def __getitem__(self, index): # Initialize variables video = None prompt = None prompt_ids = None - if self.train_infinet and depth != None: + if self.train_infinet and self.depth != None: + + print("Using Depth:",self.depth) + + depth = self.depth + parts = int(self.train_data[depth]["part_count"]) if depth != self.prev_depth or self.infindex > parts: self.infindex = 0 From 380eeb72b928ea5de48f6dab92c2f9bd266ead70 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Mon, 17 Apr 2023 02:13:34 +0300 Subject: [PATCH 07/12] 
fix the dataset composer --- video_chop.py | 66 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 11 deletions(-) diff --git a/video_chop.py b/video_chop.py index 44d98ae..2529dc0 100644 --- a/video_chop.py +++ b/video_chop.py @@ -5,34 +5,40 @@ from tqdm import tqdm from pathlib import Path -def chop_video(video_path: str, L: int) -> None: +def chop_video(video_path: str, folder:str, L: int, start_frame:int) -> int: if not os.path.exists(video_path): raise FileNotFoundError(f"Video file '{video_path}' not found.") video = cv2.VideoCapture(video_path) fps = int(video.get(cv2.CAP_PROP_FPS)) - total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) - start_frame + + # Calculate the maximum depth level + max_depth = 0 + while L ** (max_depth) <= total_frames: + max_depth += 1 + + max_depth = max_depth - 1 - total_frames = (total_frames // L) * L + total_frames = L**max_depth video_frames = [] + for _ in tqdm(range(start_frame), desc='Reading dummy frames'): + _, _ = video.read() + for _ in tqdm(range(total_frames), desc='Reading video frames'): ret, frame = video.read() if ret: video_frames.append(frame) - # Calculate the maximum depth level - max_depth = 0 - while L ** (max_depth) <= total_frames // L: - max_depth += 1 - - dir_name = Path(video_path).stem + dir_name = folder#Path(video_path).stem + #dir_name = os.path.join(folder, dir_name) for curr_depth in range(max_depth): num_splits = L ** curr_depth frames_per_split = total_frames // num_splits if dir_name == "": - dir_name = f"depth_{curr_depth}" + dir_name = os.path.join(f"depth_{curr_depth}") else: dir_name = os.path.join(dir_name, f"depth_{curr_depth}") os.makedirs(dir_name, exist_ok=True) @@ -48,6 +54,9 @@ def chop_video(video_path: str, L: int) -> None: start_index = i * frames_per_split end_index = (i + 1) * frames_per_split + print(f'start_index: {start_index}') + print(f'end_index: {end_index}') + for j in tqdm(range(start_index, end_index), desc=f'Subset {i}, {len(range(start_index, end_index))} Frames'): out.write(video_frames[j]) @@ -57,13 +66,48 @@ def chop_video(video_path: str, L: int) -> None: f.write(f"") video.release() + return total_frames + +def stuff(video_path: str, L: int): + + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file '{video_path}' not found.") + + video = cv2.VideoCapture(video_path) + fps = int(video.get(cv2.CAP_PROP_FPS)) + total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + + cur_dir_name = os.path.split(video_path)[0]#Path(video_path).stem + orig_name = Path(video_path).stem + #dir_name = cur_dir_name + #os.mkdir(dir_name) + vid_name = os.path.split(video_path)[1] + + scenario = 0 + start_frame = 0 + + while start_frame < total_frames - L: + dir_name = f"scenario_{scenario}" + os.mkdir(os.path.join(cur_dir_name, dir_name)) + video_path_new = os.path.join(cur_dir_name, dir_name, vid_name) + os.rename(video_path, video_path_new) + video_path = video_path_new + start_frame += chop_video(video_path, dir_name, L, start_frame) + scenario += 1 + + os.rename(video_path, os.path.join(os.getcwd(), vid_name)) + os.mkdir(orig_name) + + for i in os.listdir(os.getcwd()): + if i.startswith('scenario_'): + os.rename(i, os.path.join(orig_name, i)) def main(): parser = argparse.ArgumentParser(description="Chop a video file into subsets of frames.") parser.add_argument("video_file", help="Path to the video file.") parser.add_argument("--L", help="Num of splits on each level.") args = 
parser.parse_args() - chop_video(args.video_file, int(args.L)) + stuff(args.video_file, int(args.L)) if __name__ == "__main__": main() From 67b7a4b83caac5fe3a13d16c4591ca09bcf7faa7 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Sat, 22 Apr 2023 14:53:11 +0300 Subject: [PATCH 08/12] fix infinet dataset collection --- utils/chops_to_folder_dataset.py | 95 ++++++++++++++++++++++++ utils/video_chop.py | 120 +++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+) create mode 100644 utils/chops_to_folder_dataset.py create mode 100644 utils/video_chop.py diff --git a/utils/chops_to_folder_dataset.py b/utils/chops_to_folder_dataset.py new file mode 100644 index 0000000..7c716f4 --- /dev/null +++ b/utils/chops_to_folder_dataset.py @@ -0,0 +1,95 @@ +import os +import sys +import argparse +import cv2 +from tqdm import tqdm +import shutil +from pathlib import Path + +def write_as_video(output_filename, video_frames, overwrite_dims, width, height, fps): + if overwrite_dims: + height, width, _ = video_frames[0].shape + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height)) + for j in video_frames: + out.write(j) + + out.release() + +def read_first_frame(video_path): + patience = 5 + p = 0 + video = cv2.VideoCapture(video_path) + ret = False + while not ret: + ret, frame = video.read() + p += 1 + if p > patience: + raise Exception(f'Cannot read the video at {video_path}') + video.release() + return frame + +def get_fps(video_path): + video = cv2.VideoCapture(video_path) + fps = int(video.get(cv2.CAP_PROP_FPS)) + video.release() + return fps + +def move_the_files(init_path, L, depth, overwrite_dims, width, height, overwrite_fps, fps): + + folder_dataset_path = os.path.join(init_path, 'folder_dataset') + depth_name = init_path + + t_counter=0 + for d in range(0, depth): + for j in range(L**(d-1) if d > 1 else 1): + for i in range(L if d > 0 else 1): + t_counter+=1 + tq = tqdm(total=t_counter) + + for d in range(0, depth): + depth_name = os.path.join(depth_name, f'depth_{d}') + for j in range(L**(d-1) if d > 1 else 1): + part_path = os.path.join(depth_name, f'part_{j}') + # sample the text info for the next subset + for i in range(L if d > 0 else 1): + txt_path = os.path.join(part_path, f'subset_{i}.txt') + + # go to the subset for video frames sampling + next_depth_name = os.path.join(depth_name, f'depth_{d+1}') + next_part_path = os.path.join(next_depth_name, f'part_{i}') # `i` cause we want to sample each corresponding *subset* + + # depths > 0 are *guaranteed* to have L videos in their part_j folders + + # now sampling each first frame at the next level + L_frames = [read_first_frame(os.path.join(next_part_path, f'subset_{k}.mp4')) for k in range(L)] + + # write all the L sampled frames to an mp4 in the folder dataset + if overwrite_fps: + fps = get_fps(os.path.join(next_part_path, f'subset_{0}.mp4')) + + write_as_video(os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i}.mp4'), L_frames, overwrite_dims, width, height, fps) + shutil.copy(txt_path, os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i}.txt')) + + t += 1 + tq.set_description(f'Depth {d}, part {j}, subset{i}') + tq.update(t) + + tq.close() + +def main(): + parser = argparse.ArgumentParser(description="Convert the chopped labeled tree-like data into a FolderDataset") + parser.add_argument("video_file", help="Path to the video file.") + parser.add_argument("--L", help="Num of splits on each level.") + parser.add_argument("--D", help="Tree depth") + 
parser.add_argument("--overwrite_dims", help="Preserve the original video dims", action="store_true") + parser.add_argument("--w", help="Output video width", default=384) + parser.add_argument("--h", help="Output video height", default=256) + parser.add_argument("--overwrite_fps", help="Preserve the original video fps", action="store_true") + parser.add_argument("--fps", help="Output video fps", default=12) + args = parser.parse_args() + move_the_files(args.video_file, int(args.L), int(args.D)) + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/utils/video_chop.py b/utils/video_chop.py new file mode 100644 index 0000000..4b2b33d --- /dev/null +++ b/utils/video_chop.py @@ -0,0 +1,120 @@ +import os +import sys +import argparse +import cv2 +from tqdm import tqdm +from pathlib import Path + +def chop_video(video_path: str, folder:str, L: int, start_frame:int) -> int: + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file '{video_path}' not found.") + + video = cv2.VideoCapture(video_path) + fps = int(video.get(cv2.CAP_PROP_FPS)) + total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) - start_frame + + # Calculate the maximum depth level + max_depth = 0 + while L ** (max_depth) <= total_frames: + max_depth += 1 + + max_depth = max_depth - 1 + + total_frames = L**max_depth + + video_frames = [] + for _ in tqdm(range(start_frame), desc='Reading dummy frames'): + _, _ = video.read() + + for _ in tqdm(range(total_frames), desc='Reading video frames'): + ret, frame = video.read() + if ret: + video_frames.append(frame) + + dir_name = folder#Path(video_path).stem + #dir_name = os.path.join(folder, dir_name) + + for curr_depth in range(max_depth): + num_splits = L ** curr_depth + frames_per_split = total_frames // num_splits + if dir_name == "": + dir_name = os.path.join(f"depth_{curr_depth}") + else: + dir_name = os.path.join(dir_name, f"depth_{curr_depth}") + os.makedirs(dir_name, exist_ok=True) + + for i in tqdm(range(num_splits), desc=f'Depth {curr_depth}'): + + os.makedirs(os.path.join(dir_name, f"part_{i//L}"), exist_ok=True) + output_filename = f"{dir_name}/part_{i//L}/subset_{i%L}.mp4" + height, width, _ = video_frames[0].shape + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height)) + + start_index = i * frames_per_split + end_index = (i + 1) * frames_per_split + + print(f'start_index: {start_index}') + print(f'end_index: {end_index}') + + for j in tqdm(range(start_index, end_index), desc=f'Subset {i}, {len(range(start_index, end_index))} Frames'): + out.write(video_frames[j]) + + out.release() + # create a txt file alongside the video + with open(f"{dir_name}/part_{i//L}/subset_{i%L}.txt", "w") as f: + f.write(f"") + + video.release() + return total_frames + +def stuff(video_path: str, L: int, only_once = True): + + only_once = True + + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file '{video_path}' not found.") + + video = cv2.VideoCapture(video_path) + fps = int(video.get(cv2.CAP_PROP_FPS)) + total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + + cur_dir_name = os.path.split(video_path)[0]#Path(video_path).stem + orig_name = Path(video_path).stem + #dir_name = cur_dir_name + #os.mkdir(dir_name) + vid_name = os.path.split(video_path)[1] + + scenario = 0 + start_frame = 0 + + while start_frame < total_frames - L: + + dir_name = f"scenario_{scenario}" + os.mkdir(os.path.join(cur_dir_name, dir_name)) + video_path_new = os.path.join(cur_dir_name, dir_name, 
vid_name) + os.rename(video_path, video_path_new) + video_path = video_path_new + start_frame += chop_video(video_path, dir_name, L, start_frame) + scenario += 1 + + if only_once: + break + + os.rename(video_path, os.path.join(os.getcwd(), vid_name)) + os.mkdir(orig_name) + + for i in os.listdir(os.getcwd()): + if i.startswith('scenario_'): + os.rename(i, os.path.join(orig_name, i)) + +def main(): + parser = argparse.ArgumentParser(description="Chop a video file into subsets of frames.") + parser.add_argument("video_file", help="Path to the video file.") + parser.add_argument("--L", help="Num of splits on each level.") + parser.add_argument("--subscenariosplit", help="Should it split ", action='store_true', default=False) + args = parser.parse_args() + stuff(args.video_file, int(args.L), bool(args.subscenariosplit != None and args.subscenariosplit)) + +if __name__ == "__main__": + main() From 5b860d0e1f35fb049b1151573039ad0b7717f4ed Mon Sep 17 00:00:00 2001 From: kabachuha Date: Sat, 22 Apr 2023 15:29:15 +0300 Subject: [PATCH 09/12] fix dataset converter not reading parts > L --- utils/chops_to_folder_dataset.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/utils/chops_to_folder_dataset.py b/utils/chops_to_folder_dataset.py index 7c716f4..f670703 100644 --- a/utils/chops_to_folder_dataset.py +++ b/utils/chops_to_folder_dataset.py @@ -38,6 +38,7 @@ def get_fps(video_path): def move_the_files(init_path, L, depth, overwrite_dims, width, height, overwrite_fps, fps): folder_dataset_path = os.path.join(init_path, 'folder_dataset') + os.mkdir(folder_dataset_path) depth_name = init_path t_counter=0 @@ -57,7 +58,7 @@ def move_the_files(init_path, L, depth, overwrite_dims, width, height, overwrite # go to the subset for video frames sampling next_depth_name = os.path.join(depth_name, f'depth_{d+1}') - next_part_path = os.path.join(next_depth_name, f'part_{i}') # `i` cause we want to sample each corresponding *subset* + next_part_path = os.path.join(next_depth_name, f'part_{i+L*j}') # `i` cause we want to sample each corresponding *subset* # depths > 0 are *guaranteed* to have L videos in their part_j folders @@ -68,18 +69,18 @@ def move_the_files(init_path, L, depth, overwrite_dims, width, height, overwrite if overwrite_fps: fps = get_fps(os.path.join(next_part_path, f'subset_{0}.mp4')) - write_as_video(os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i}.mp4'), L_frames, overwrite_dims, width, height, fps) - shutil.copy(txt_path, os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i}.txt')) + write_as_video(os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.mp4'), L_frames, overwrite_dims, width, height, fps) + shutil.copy(txt_path, os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.txt')) - t += 1 tq.set_description(f'Depth {d}, part {j}, subset{i}') - tq.update(t) + #tq.set_description(os.path.join(next_part_path, f'subset_{0}.mp4')) + tq.update(1) tq.close() def main(): parser = argparse.ArgumentParser(description="Convert the chopped labeled tree-like data into a FolderDataset") - parser.add_argument("video_file", help="Path to the video file.") + parser.add_argument("outpath", help="Path where to save the end FolderDataset", default=os.getcwd()) parser.add_argument("--L", help="Num of splits on each level.") parser.add_argument("--D", help="Tree depth") parser.add_argument("--overwrite_dims", help="Preserve the original video dims", action="store_true") @@ -88,7 +89,7 @@ def main(): 
     parser.add_argument("--overwrite_fps", help="Preserve the original video fps", action="store_true")
     parser.add_argument("--fps", help="Output video fps", default=12)
     args = parser.parse_args()
-    move_the_files(args.video_file, int(args.L), int(args.D))
+    move_the_files(args.outpath, int(args.L), int(args.D), bool(args.overwrite_dims), int(args.w), int(args.h), bool(args.overwrite_fps), int(args.fps))
 
 if __name__ == "__main__":
     main()

From 664f647ddd24100d5075c29b6d2e87a6e2fda712 Mon Sep 17 00:00:00 2001
From: kabachuha
Date: Sat, 22 Apr 2023 16:25:23 +0300
Subject: [PATCH 10/12] save to PIL gifs instead of mp4s as mp4 often fails for such short videos

---
 utils/chops_to_folder_dataset.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/utils/chops_to_folder_dataset.py b/utils/chops_to_folder_dataset.py
index f670703..1861662 100644
--- a/utils/chops_to_folder_dataset.py
+++ b/utils/chops_to_folder_dataset.py
@@ -5,16 +5,22 @@ from tqdm import tqdm
 import shutil
 from pathlib import Path
+from PIL import Image
 
 def write_as_video(output_filename, video_frames, overwrite_dims, width, height, fps):
     if overwrite_dims:
         height, width, _ = video_frames[0].shape
-    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-    out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height))
-    for j in video_frames:
-        out.write(j)
+    #fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    #out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height))
+    #for j in video_frames:
+    #    out.write(j)
+    video_frames = [Image.fromarray(cv2.cvtColor(j, cv2.COLOR_BGR2RGB)) for j in video_frames]
+    if overwrite_dims:
+        video_frames = [j.resize(width, height) for j in video_frames]
+
+    video_frames[0].save(output_filename, save_all=True, append_images=video_frames[1:])
 
-    out.release()
+    #out.release()
 
 def read_first_frame(video_path):
     patience = 5
@@ -69,7 +75,7 @@ def move_the_files(init_path, L, depth, overwrite_dims, width, height, overwrite
             if overwrite_fps:
                 fps = get_fps(os.path.join(next_part_path, f'subset_{0}.mp4'))
 
-            write_as_video(os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.mp4'), L_frames, overwrite_dims, width, height, fps)
+            write_as_video(os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.gif'), L_frames, overwrite_dims, width, height, fps)
             shutil.copy(txt_path, os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.txt'))
 
             tq.set_description(f'Depth {d}, part {j}, subset{i}')

From a821d8b88baf94ad7803deae73e1d029ead4b37e Mon Sep 17 00:00:00 2001
From: kabachuha
Date: Sat, 22 Apr 2023 18:49:42 +0300
Subject: [PATCH 11/12] add support for InfiNet training for FolderDataset

---
 chops_to_folder_dataset.py  | 102 ++++++++++++++++++++++++++++++++++++
 configs/video_folder.yaml   |   1 +
 models/unet_3d_blocks.py    |   0
 models/unet_3d_condition.py |   0
 train.py                    |   7 ++-
 utils/dataset.py            |  17 ++++--
 video_chop.py               |  11 +++-
 7 files changed, 130 insertions(+), 8 deletions(-)
 create mode 100644 chops_to_folder_dataset.py
 mode change 100755 => 100644 models/unet_3d_blocks.py
 mode change 100755 => 100644 models/unet_3d_condition.py

diff --git a/chops_to_folder_dataset.py b/chops_to_folder_dataset.py
new file mode 100644
index 0000000..1861662
--- /dev/null
+++ b/chops_to_folder_dataset.py
@@ -0,0 +1,102 @@
+import os
+import sys
+import argparse
+import cv2
+from tqdm import tqdm
+import shutil
+from pathlib import Path
+from PIL import Image
+
+def write_as_video(output_filename, video_frames, overwrite_dims, width, height, fps):
+    if overwrite_dims:
+        height, width, _ = video_frames[0].shape
+    #fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    #out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height))
+    #for j in video_frames:
+    #    out.write(j)
+    video_frames = [Image.fromarray(cv2.cvtColor(j, cv2.COLOR_BGR2RGB)) for j in video_frames]
+    if overwrite_dims:
+        video_frames = [j.resize(width, height) for j in video_frames]
+
+    video_frames[0].save(output_filename, save_all=True, append_images=video_frames[1:])
+
+    #out.release()
+
+def read_first_frame(video_path):
+    patience = 5
+    p = 0
+    video = cv2.VideoCapture(video_path)
+    ret = False
+    while not ret:
+        ret, frame = video.read()
+        p += 1
+        if p > patience:
+            raise Exception(f'Cannot read the video at {video_path}')
+    video.release()
+    return frame
+
+def get_fps(video_path):
+    video = cv2.VideoCapture(video_path)
+    fps = int(video.get(cv2.CAP_PROP_FPS))
+    video.release()
+    return fps
+
+def move_the_files(init_path, L, depth, overwrite_dims, width, height, overwrite_fps, fps):
+
+    folder_dataset_path = os.path.join(init_path, 'folder_dataset')
+    os.mkdir(folder_dataset_path)
+    depth_name = init_path
+
+    t_counter=0
+    for d in range(0, depth):
+        for j in range(L**(d-1) if d > 1 else 1):
+            for i in range(L if d > 0 else 1):
+                t_counter+=1
+    tq = tqdm(total=t_counter)
+
+    for d in range(0, depth):
+        depth_name = os.path.join(depth_name, f'depth_{d}')
+        for j in range(L**(d-1) if d > 1 else 1):
+            part_path = os.path.join(depth_name, f'part_{j}')
+            # sample the text info for the next subset
+            for i in range(L if d > 0 else 1):
+                txt_path = os.path.join(part_path, f'subset_{i}.txt')
+
+                # go to the subset for video frames sampling
+                next_depth_name = os.path.join(depth_name, f'depth_{d+1}')
+                next_part_path = os.path.join(next_depth_name, f'part_{i+L*j}') # `i` cause we want to sample each corresponding *subset*
+
+                # depths > 0 are *guaranteed* to have L videos in their part_j folders
+
+                # now sampling each first frame at the next level
+                L_frames = [read_first_frame(os.path.join(next_part_path, f'subset_{k}.mp4')) for k in range(L)]
+
+                # write all the L sampled frames to an mp4 in the folder dataset
+                if overwrite_fps:
+                    fps = get_fps(os.path.join(next_part_path, f'subset_{0}.mp4'))
+
+                write_as_video(os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.gif'), L_frames, overwrite_dims, width, height, fps)
+                shutil.copy(txt_path, os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.txt'))
+
+                tq.set_description(f'Depth {d}, part {j}, subset{i}')
+                #tq.set_description(os.path.join(next_part_path, f'subset_{0}.mp4'))
+                tq.update(1)
+
+    tq.close()
+
+def main():
+    parser = argparse.ArgumentParser(description="Convert the chopped labeled tree-like data into a FolderDataset")
+    parser.add_argument("outpath", help="Path where to save the end FolderDataset", default=os.getcwd())
+    parser.add_argument("--L", help="Num of splits on each level.")
+    parser.add_argument("--D", help="Tree depth")
+    parser.add_argument("--overwrite_dims", help="Preserve the original video dims", action="store_true")
+    parser.add_argument("--w", help="Output video width", default=384)
+    parser.add_argument("--h", help="Output video height", default=256)
+    parser.add_argument("--overwrite_fps", help="Preserve the original video fps", action="store_true")
+    parser.add_argument("--fps", help="Output video fps", default=12)
+    args = parser.parse_args()
+    move_the_files(args.outpath, int(args.L), int(args.D), bool(args.overwrite_dims), int(args.w), int(args.h), bool(args.overwrite_fps), int(args.fps))
+
+if __name__ == "__main__":
+    main()
+    
\ No newline at end of file
diff --git a/configs/video_folder.yaml b/configs/video_folder.yaml
index d0529b1..a3b8628 100644
--- a/configs/video_folder.yaml
+++ b/configs/video_folder.yaml
@@ -30,6 +30,7 @@ trainable_modules:
 - "attn1"
 - "attn2"
 - "attn3"
+ - "infinet"
 seed: 1234
 mixed_precision: "fp16"
 use_8bit_adam: False # This seems to be incompatible at the moment.
diff --git a/models/unet_3d_blocks.py b/models/unet_3d_blocks.py
old mode 100755
new mode 100644
diff --git a/models/unet_3d_condition.py b/models/unet_3d_condition.py
old mode 100755
new mode 100644
diff --git a/train.py b/train.py
index 4ddd208..f070239 100644
--- a/train.py
+++ b/train.py
@@ -334,11 +334,13 @@ def main(
         num_training_steps=max_train_steps * gradient_accumulation_steps,
     )
 
+    train_infinet = 'infinet' in trainable_modules if trainable_modules is not None else False
+
     # Get the training dataset
     if train_data.pop("type", "regular") == "folder":
-        train_dataset = VideoFolderDataset(**train_data, tokenizer=tokenizer)
+        train_dataset = VideoFolderDataset(**train_data, tokenizer=tokenizer, train_infinet=train_infinet)
     else:
-        train_dataset = VideoDataset(**train_data, tokenizer=tokenizer, train_infinet='infinet' in trainable_modules if trainable_modules is not None else False)
+        train_dataset = VideoDataset(**train_data, tokenizer=tokenizer, train_infinet=train_infinet)
 
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
@@ -398,6 +400,7 @@ def main(
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
     logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
     logger.info(f"  Total optimization steps = {max_train_steps}")
+    logger.info(f"  InfiNet training = {train_infinet}")
 
     global_step = 0
     first_epoch = 0
diff --git a/utils/dataset.py b/utils/dataset.py
index ab3e1be..c7b6154 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -322,6 +322,7 @@ def __init__(
         fps: int = 8,
         path: str = "./data",
         fallback_prompt: str = "",
+        train_infinet=False,
         **kwargs
     ):
         self.tokenizer = tokenizer
@@ -335,6 +336,7 @@ def __init__(
 
         self.n_sample_frames = n_sample_frames
         self.fps = fps
+        self.train_infinet = train_infinet
 
     def get_prompt_ids(self, prompt):
         return self.tokenizer(
@@ -349,7 +351,8 @@ def __len__(self):
         return len(self.video_files)
 
     def __getitem__(self, index):
-        vr = decord.VideoReader(self.video_files[index], width=self.width, height=self.height)
+        vid_filename = self.video_files[index]
+        vr = decord.VideoReader(vid_filename, width=self.width, height=self.height)
 
         native_fps = vr.get_avg_fps()
         every_nth_frame = round(native_fps / self.fps)
@@ -365,12 +368,18 @@ def __getitem__(self, index):
         video = vr.get_batch(idxs)
         video = rearrange(video, "f h w c -> f c h w")
 
-        if os.path.exists(self.video_files[index].replace(".mp4", ".txt")):
-            with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f:
+        if os.path.exists(vid_filename.replace(".mp4", ".txt")):
+            with open(vid_filename.replace(".mp4", ".txt"), "r") as f:
                 prompt = f.read()
         else:
             prompt = self.fallback_prompt
+
+        # TODO: use regex
+        depth = 0
+        if vid_filename.startswith('depth_'):
+            depth = int(vid_filename[len('depth_'):vid_filename[len('depth_'):].index('_')])
 
         prompt_ids = self.get_prompt_ids(prompt)
 
-        return {"pixel_values": (video / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt}
+        return {"pixel_values": (video / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, "diffusion_depth":depth}
+
diff --git a/video_chop.py b/video_chop.py
index 2529dc0..4b2b33d 100644
--- a/video_chop.py
+++ b/video_chop.py
@@ -68,7 +68,9 @@ def chop_video(video_path: str, folder:str, L: int, start_frame:int) -> int:
     video.release()
     return total_frames
 
-def stuff(video_path: str, L: int):
+def stuff(video_path: str, L: int, only_once = True):
+
+    only_once = True
 
     if not os.path.exists(video_path):
         raise FileNotFoundError(f"Video file '{video_path}' not found.")
@@ -87,6 +89,7 @@ def stuff(video_path: str, L: int):
     start_frame = 0
 
     while start_frame < total_frames - L:
+
         dir_name = f"scenario_{scenario}"
         os.mkdir(os.path.join(cur_dir_name, dir_name))
         video_path_new = os.path.join(cur_dir_name, dir_name, vid_name)
@@ -94,6 +97,9 @@ def stuff(video_path: str, L: int):
         video_path = video_path_new
         start_frame += chop_video(video_path, dir_name, L, start_frame)
         scenario += 1
+
+        if only_once:
+            break
 
     os.rename(video_path, os.path.join(os.getcwd(), vid_name))
     os.mkdir(orig_name)
@@ -106,8 +112,9 @@ def main():
     parser = argparse.ArgumentParser(description="Chop a video file into subsets of frames.")
     parser.add_argument("video_file", help="Path to the video file.")
     parser.add_argument("--L", help="Num of splits on each level.")
+    parser.add_argument("--subscenariosplit", help="Should it split ", action='store_true', default=False)
     args = parser.parse_args()
-    stuff(args.video_file, int(args.L))
+    stuff(args.video_file, int(args.L), bool(args.subscenariosplit != None and args.subscenariosplit))
 
 if __name__ == "__main__":
     main()

From 1e60627e2bec5945f60393a64bdcbd0aa7015d8f Mon Sep 17 00:00:00 2001
From: kabachuha
Date: Sat, 22 Apr 2023 18:50:40 +0300
Subject: [PATCH 12/12] move to the utils

---
 chops_to_folder_dataset.py | 102 -------------------------------
 video_chop.py              | 120 -------------------------------------
 2 files changed, 222 deletions(-)
 delete mode 100644 chops_to_folder_dataset.py
 delete mode 100644 video_chop.py

diff --git a/chops_to_folder_dataset.py b/chops_to_folder_dataset.py
deleted file mode 100644
index 1861662..0000000
--- a/chops_to_folder_dataset.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import os
-import sys
-import argparse
-import cv2
-from tqdm import tqdm
-import shutil
-from pathlib import Path
-from PIL import Image
-
-def write_as_video(output_filename, video_frames, overwrite_dims, width, height, fps):
-    if overwrite_dims:
-        height, width, _ = video_frames[0].shape
-    #fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-    #out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height))
-    #for j in video_frames:
-    #    out.write(j)
-    video_frames = [Image.fromarray(cv2.cvtColor(j, cv2.COLOR_BGR2RGB)) for j in video_frames]
-    if overwrite_dims:
-        video_frames = [j.resize(width, height) for j in video_frames]
-
-    video_frames[0].save(output_filename, save_all=True, append_images=video_frames[1:])
-
-    #out.release()
-
-def read_first_frame(video_path):
-    patience = 5
-    p = 0
-    video = cv2.VideoCapture(video_path)
-    ret = False
-    while not ret:
-        ret, frame = video.read()
-        p += 1
-        if p > patience:
-            raise Exception(f'Cannot read the video at {video_path}')
-    video.release()
-    return frame
-
-def get_fps(video_path):
-    video = cv2.VideoCapture(video_path)
-    fps = int(video.get(cv2.CAP_PROP_FPS))
-    video.release()
-    return fps
-
-def move_the_files(init_path, L, depth, overwrite_dims, width, height, overwrite_fps, fps):
-
-    folder_dataset_path = os.path.join(init_path, 'folder_dataset')
-    os.mkdir(folder_dataset_path)
-    depth_name = init_path
-
-    t_counter=0
-    for d in range(0, depth):
-        for j in range(L**(d-1) if d > 1 else 1):
-            for i in range(L if d > 0 else 1):
-                t_counter+=1
-    tq = tqdm(total=t_counter)
-
-    for d in range(0, depth):
-        depth_name = os.path.join(depth_name, f'depth_{d}')
-        for j in range(L**(d-1) if d > 1 else 1):
-            part_path = os.path.join(depth_name, f'part_{j}')
-            # sample the text info for the next subset
-            for i in range(L if d > 0 else 1):
-                txt_path = os.path.join(part_path, f'subset_{i}.txt')
-
-                # go to the subset for video frames sampling
-                next_depth_name = os.path.join(depth_name, f'depth_{d+1}')
-                next_part_path = os.path.join(next_depth_name, f'part_{i+L*j}') # `i` cause we want to sample each corresponding *subset*
-
-                # depths > 0 are *guaranteed* to have L videos in their part_j folders
-
-                # now sampling each first frame at the next level
-                L_frames = [read_first_frame(os.path.join(next_part_path, f'subset_{k}.mp4')) for k in range(L)]
-
-                # write all the L sampled frames to an mp4 in the folder dataset
-                if overwrite_fps:
-                    fps = get_fps(os.path.join(next_part_path, f'subset_{0}.mp4'))
-
-                write_as_video(os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.gif'), L_frames, overwrite_dims, width, height, fps)
-                shutil.copy(txt_path, os.path.join(folder_dataset_path, f'depth_{d}_part_{j}_subset{i+L*j}.txt'))
-
-                tq.set_description(f'Depth {d}, part {j}, subset{i}')
-                #tq.set_description(os.path.join(next_part_path, f'subset_{0}.mp4'))
-                tq.update(1)
-
-    tq.close()
-
-def main():
-    parser = argparse.ArgumentParser(description="Convert the chopped labeled tree-like data into a FolderDataset")
-    parser.add_argument("outpath", help="Path where to save the end FolderDataset", default=os.getcwd())
-    parser.add_argument("--L", help="Num of splits on each level.")
-    parser.add_argument("--D", help="Tree depth")
-    parser.add_argument("--overwrite_dims", help="Preserve the original video dims", action="store_true")
-    parser.add_argument("--w", help="Output video width", default=384)
-    parser.add_argument("--h", help="Output video height", default=256)
-    parser.add_argument("--overwrite_fps", help="Preserve the original video fps", action="store_true")
-    parser.add_argument("--fps", help="Output video fps", default=12)
-    args = parser.parse_args()
-    move_the_files(args.outpath, int(args.L), int(args.D), bool(args.overwrite_dims), int(args.w), int(args.h), bool(args.overwrite_fps), int(args.fps))
-
-if __name__ == "__main__":
-    main()
-    
\ No newline at end of file
diff --git a/video_chop.py b/video_chop.py
deleted file mode 100644
index 4b2b33d..0000000
--- a/video_chop.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import os
-import sys
-import argparse
-import cv2
-from tqdm import tqdm
-from pathlib import Path
-
-def chop_video(video_path: str, folder:str, L: int, start_frame:int) -> int:
-    if not os.path.exists(video_path):
-        raise FileNotFoundError(f"Video file '{video_path}' not found.")
-
-    video = cv2.VideoCapture(video_path)
-    fps = int(video.get(cv2.CAP_PROP_FPS))
-    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) - start_frame
-
-    # Calculate the maximum depth level
-    max_depth = 0
-    while L ** (max_depth) <= total_frames:
-        max_depth += 1
-
-    max_depth = max_depth - 1
-
-    total_frames = L**max_depth
-
-    video_frames = []
-    for _ in tqdm(range(start_frame), desc='Reading dummy frames'):
-        _, _ = video.read()
-
-    for _ in tqdm(range(total_frames), desc='Reading video frames'):
-        ret, frame = video.read()
-        if ret:
-            video_frames.append(frame)
-
-    dir_name = folder#Path(video_path).stem
-    #dir_name = os.path.join(folder, dir_name)
-
-    for curr_depth in range(max_depth):
-        num_splits = L ** curr_depth
-        frames_per_split = total_frames // num_splits
-        if dir_name == "":
-            dir_name = os.path.join(f"depth_{curr_depth}")
-        else:
-            dir_name = os.path.join(dir_name, f"depth_{curr_depth}")
-        os.makedirs(dir_name, exist_ok=True)
-
-        for i in tqdm(range(num_splits), desc=f'Depth {curr_depth}'):
-
-            os.makedirs(os.path.join(dir_name, f"part_{i//L}"), exist_ok=True)
-            output_filename = f"{dir_name}/part_{i//L}/subset_{i%L}.mp4"
-            height, width, _ = video_frames[0].shape
-            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-            out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height))
-
-            start_index = i * frames_per_split
-            end_index = (i + 1) * frames_per_split
-
-            print(f'start_index: {start_index}')
-            print(f'end_index: {end_index}')
-
-            for j in tqdm(range(start_index, end_index), desc=f'Subset {i}, {len(range(start_index, end_index))} Frames'):
-                out.write(video_frames[j])
-
-            out.release()
-            # create a txt file alongside the video
-            with open(f"{dir_name}/part_{i//L}/subset_{i%L}.txt", "w") as f:
-                f.write(f"")
-
-    video.release()
-    return total_frames
-
-def stuff(video_path: str, L: int, only_once = True):
-
-    only_once = True
-
-    if not os.path.exists(video_path):
-        raise FileNotFoundError(f"Video file '{video_path}' not found.")
-
-    video = cv2.VideoCapture(video_path)
-    fps = int(video.get(cv2.CAP_PROP_FPS))
-    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
-
-    cur_dir_name = os.path.split(video_path)[0]#Path(video_path).stem
-    orig_name = Path(video_path).stem
-    #dir_name = cur_dir_name
-    #os.mkdir(dir_name)
-    vid_name = os.path.split(video_path)[1]
-
-    scenario = 0
-    start_frame = 0
-
-    while start_frame < total_frames - L:
-
-        dir_name = f"scenario_{scenario}"
-        os.mkdir(os.path.join(cur_dir_name, dir_name))
-        video_path_new = os.path.join(cur_dir_name, dir_name, vid_name)
-        os.rename(video_path, video_path_new)
-        video_path = video_path_new
-        start_frame += chop_video(video_path, dir_name, L, start_frame)
-        scenario += 1
-
-        if only_once:
-            break
-
-    os.rename(video_path, os.path.join(os.getcwd(), vid_name))
-    os.mkdir(orig_name)
-
-    for i in os.listdir(os.getcwd()):
-        if i.startswith('scenario_'):
-            os.rename(i, os.path.join(orig_name, i))
-
-def main():
-    parser = argparse.ArgumentParser(description="Chop a video file into subsets of frames.")
-    parser.add_argument("video_file", help="Path to the video file.")
-    parser.add_argument("--L", help="Num of splits on each level.")
-    parser.add_argument("--subscenariosplit", help="Should it split ", action='store_true', default=False)
-    args = parser.parse_args()
-    stuff(args.video_file, int(args.L), bool(args.subscenariosplit != None and args.subscenariosplit))
-
-if __name__ == "__main__":
-    main()
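
A few implementation notes on the patches above, with small illustrative sketches that are not part of the commits.

On write_as_video() in utils/chops_to_folder_dataset.py: Pillow's Image.resize expects a single (width, height) tuple rather than two positional arguments, and Image.save only reflects a frame rate in a GIF when a per-frame duration is passed, so the fps argument is currently unused on the PIL path. A minimal sketch addressing both points while keeping the structure of the committed function (the duration/loop keywords are standard Pillow GIF options):

    import cv2
    from PIL import Image

    def write_as_video(output_filename, video_frames, overwrite_dims, width, height, fps):
        if overwrite_dims:
            height, width, _ = video_frames[0].shape
        # OpenCV frames are BGR numpy arrays; convert them to RGB PIL images.
        video_frames = [Image.fromarray(cv2.cvtColor(j, cv2.COLOR_BGR2RGB)) for j in video_frames]
        if overwrite_dims:
            # Image.resize takes a (width, height) tuple, not two separate arguments.
            video_frames = [j.resize((width, height)) for j in video_frames]
        # duration is the per-frame display time in milliseconds; loop=0 repeats the GIF forever.
        video_frames[0].save(
            output_filename,
            save_all=True,
            append_images=video_frames[1:],
            duration=int(round(1000 / fps)),
            loop=0,
        )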
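On the caption lookup in utils/dataset.py: patch 10 writes the folder-dataset clips as .gif, while __getitem__ finds prompt files by swapping a ".mp4" suffix for ".txt". If the dataset is pointed at those .gif clips, the lookup will never match and every prompt falls back to fallback_prompt. A suffix-agnostic variant is a one-liner; the helper name below is illustrative only:

    import os

    def prompt_path_for(vid_filename):
        # Derive the caption path from the clip path regardless of extension (.mp4, .gif, ...).
        return os.path.splitext(vid_filename)[0] + ".txt"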
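On the depth extraction added to __getitem__: it carries a "TODO: use regex" and slices the raw value of vid_filename, which only works when that value is a bare file name starting with depth_ rather than a full path. A regex sketch keyed to the depth_{d}_part_{j}_subset{n} naming produced by chops_to_folder_dataset.py (the helper name and default value are illustrative):

    import os
    import re

    # Matches names such as "depth_2_part_5_subset13.gif" (or .mp4 / .txt).
    _DEPTH_RE = re.compile(r"depth_(\d+)_part_(\d+)_subset(\d+)\.")

    def parse_diffusion_depth(vid_filename, default=0):
        # Use the basename so absolute and relative paths are both handled.
        m = _DEPTH_RE.match(os.path.basename(vid_filename))
        return int(m.group(1)) if m else default

depth = parse_diffusion_depth(vid_filename) could then replace the slicing block while leaving the returned "diffusion_depth" entry unchanged.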
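Lastly, on video_chop.py: stuff() immediately reassigns only_once = True, so the value computed from --subscenariosplit in main() never reaches the loop, and the bool(args.subscenariosplit != None and args.subscenariosplit) expression is redundant because store_true already yields a bool. If the hard-coding is not a deliberate temporary override, the wiring could look like the sketch below; the flag's meaning is kept exactly as passed in the patch and the body of stuff() is elided:

    import argparse

    def stuff(video_path: str, L: int, only_once: bool = True):
        # Same body as in the patch, minus the `only_once = True` reassignment.
        ...

    def main():
        parser = argparse.ArgumentParser(description="Chop a video file into subsets of frames.")
        parser.add_argument("video_file", help="Path to the video file.")
        parser.add_argument("--L", help="Num of splits on each level.")
        parser.add_argument("--subscenariosplit", help="Should it split ", action="store_true", default=False)
        args = parser.parse_args()
        # store_true already produces a bool, so it can be passed through directly.
        stuff(args.video_file, int(args.L), only_once=args.subscenariosplit)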