diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index ce0ae516..562d5c71 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -117,13 +117,16 @@ def make_tfrecord_iterator( check out preparation script maxdiffusion/pedagogical_examples/to_tfrecords.py """ - # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset. # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord. # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator. + + # checks that the dataset path is valid. In case of gcs, the existence of the dir is not checked. + is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location) + if ( config.cache_latents_text_encoder_outputs - and os.path.isdir(config.dataset_save_location) + and is_dataset_dir_valid and "load_tfrecord_cached" in config.get_keys() and config.load_tfrecord_cached ): diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 2ceb0f7e..5a27591d 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -57,8 +57,10 @@ def rename_for_custom_trasformer(key): return renamed_pt_key -def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): - device = jax.devices(device)[0] +def load_fusionx_transformer( + pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40 +): + device = jax.local_devices(backend=device)[0] with jax.default_device(device): if hf_download: ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="Wan14BT2VFusioniX_fp16_.safetensors") @@ -97,7 +99,7 @@ 
def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di if flax_key in flax_state_dict: new_tensor = flax_state_dict[flax_key] else: - new_tensor = jnp.zeros((40,) + flax_tensor.shape) + new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape) flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) @@ -107,8 +109,10 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di return flax_state_dict -def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): - device = jax.devices(device)[0] +def load_causvid_transformer( + pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40 +): + device = jax.local_devices(backend=device)[0] with jax.default_device(device): if hf_download: ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, filename="causal_model.pt") @@ -145,7 +149,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di if flax_key in flax_state_dict: new_tensor = flax_state_dict[flax_key] else: - new_tensor = jnp.zeros((40,) + flax_tensor.shape) + new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape) flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) @@ -155,18 +159,22 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di return flax_state_dict -def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): +def load_wan_transformer( + pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40 +): if 
pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: - return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) + return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers) elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH: - return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) + return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers) else: - return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download) + return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers) -def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): - device = jax.devices(device)[0] +def load_base_wan_transformer( + pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40 +): + device = jax.local_devices(backend=device)[0] subfolder = "transformer" filename = "diffusion_pytorch_model.safetensors.index.json" local_files = False @@ -237,7 +245,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d if flax_key in flax_state_dict: new_tensor = flax_state_dict[flax_key] else: - new_tensor = jnp.zeros((40,) + flax_tensor.shape) + new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape) flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index abf44929..9ca2e03b 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -95,7 +95,9 
@@ def create_model(rngs: nnx.Rngs, wan_config: dict): # 4. Load pretrained weights and move them to device using the state shardings from (3) above. # This helps with loading sharded weights directly into the accelerators without fist copying them # all to one device and then distributing them, thus using low HBM memory. - params = load_wan_transformer(config.wan_transformer_pretrained_model_name_or_path, params, "cpu") + params = load_wan_transformer( + config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"] + ) params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) for path, val in flax.traverse_util.flatten_dict(params).items(): sharding = logical_state_sharding[path].value