Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/diffusers/commands/custom_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ def run(self):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")

def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:
Expand Down
62 changes: 51 additions & 11 deletions src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..utils import PushToHubMixin, is_accelerate_available, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.import_utils import _is_package_available
from .components_manager import ComponentsManager
from .modular_pipeline_utils import (
ComponentSpec,
Expand Down Expand Up @@ -240,6 +241,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):

config_name = "modular_config.json"
model_name = None
_requirements: Union[List[Tuple[str, str]], Tuple[str, str]] = None

@classmethod
def _get_signature_keys(cls, obj):
Expand Down Expand Up @@ -302,6 +304,19 @@ def from_pretrained(
trust_remote_code: bool = False,
**kwargs,
):
config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)

if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])

hub_kwargs_names = [
"cache_dir",
"force_download",
Expand All @@ -314,16 +329,6 @@ def from_pretrained(
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}

config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)

class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
Expand All @@ -349,8 +354,13 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}

self.register_to_config(auto_map=auto_map)

# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)

self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
Expand Down Expand Up @@ -2539,3 +2549,33 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
return state.get(output)
else:
raise ValueError(f"Output '{output}' is not a valid output type")


def _validate_requirements(reqs):
normalized_reqs = _normalize_requirements(reqs)
if not normalized_reqs:
return []

final: List[Tuple[str, str]] = []
for req, specified_ver in normalized_reqs:
req_available, req_actual_ver = _is_package_available(req)
if not req_available:
raise ValueError(f"{req} was specified in the requirements but wasn't found in the current environment.")
if specified_ver != req_actual_ver:
logger.warning(
f"Version of {req} was specified to be {specified_ver} in the configuration. However, the actual installed version if {req_actual_ver}. Things might work unexpected."
)

final.append((req, specified_ver))

return final


def _normalize_requirements(reqs):
if not reqs:
return []
if isinstance(reqs, tuple) and len(reqs) == 2 and isinstance(reqs[0], str):
req_seq: List[Tuple[str, str]] = [reqs] # single pair
else:
req_seq = reqs
return req_seq