Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 155 additions & 113 deletions pyaml/configuration/factory.py
Original file line number Diff line number Diff line change
@@ -1,122 +1,164 @@
# PyAML factory (construct AML objects from config files)
import importlib
import pprint as pp
import traceback
from threading import Lock

from .config_exception import PyAMLConfigException
from ..exception import PyAMLException
from ..lattice.element import Element

#TODO:
#Implement trace for error management. Hints: Implement private field __file__ in dictionary to report errors.

_ALL_ELEMENTS: dict = {}

def buildObject(d:dict):
"""Build an object from the dict"""

if not isinstance(d,dict):
raise PyAMLException("Unexpected object " + str(d))
if not "type" in d:
raise PyAMLException("No type specified for " + str(type(d)) + ":" + str(d))
type_str = d["type"]
del d["type"]

try:
module = importlib.import_module(type_str)
except ModuleNotFoundError as ex:
raise PyAMLException(f"Module referenced in type cannot be founded: '{type_str}'") from ex

# Get the config object
config_cls = getattr(module, "ConfigModel", None)
if config_cls is None:
raise ValueError(f"ConfigModel class '{type_str}.ConfigModel' not found")

# Get the class name
cls_name = getattr(module, "PYAMLCLASS", None)
if cls_name is None:
raise ValueError(f"PYAMLCLASS definition not found in '{type_str}'")

try:

# Validate the model
cfg = config_cls.model_validate(d)

# Construct and return the object
elem_cls = getattr(module, cls_name, None)
if elem_cls is None:
raise ValueError(
f"Unknown element class '{type_str}.{cls_name}'"
)

obj = elem_cls(cfg)
register_element(obj)
return obj

except Exception as e:

print(traceback.format_exc())
print(e)
print(type_str)
pp.pprint(d)
#Fatal
quit()


def depthFirstBuild(d):
"""Main factory function (Depth-first factory)"""

if isinstance(d,list):
# list can be a list of objects or a list of native types
l = []
for index, e in enumerate(d):
if isinstance(e,dict) or isinstance(e,list):
try:
obj = depthFirstBuild(e)
l.append(obj)
except PyAMLException as pyaml_ex:
raise PyAMLConfigException(f"[{index}]", pyaml_ex) from pyaml_ex
except Exception as ex:
raise PyAMLConfigException(f"[{index}]") from ex
else:
l.append(e)
return l

elif isinstance(d,dict):
for key, value in d.items():
if isinstance(value,dict) or isinstance(value,list):
class BuildStrategy:
def can_handle(self, module: object, config_dict: dict) -> bool:
"""Return True if this strategy can handle the module/config."""
raise NotImplementedError

def build(self, module: object, config_dict: dict):
"""Build the object according to custom logic."""
raise NotImplementedError

class PyAMLFactory:
"""Singleton factory to build PyAML elements with future compatibility logic."""

_instance = None
_lock = Lock()

def __new__(cls):
"""
No matter how many times you call PyAMLFactory(), it will be created only once.
"""
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._elements = {}
cls._instance._strategies = []
return cls._instance

def register_strategy(self, strategy: BuildStrategy):
"""Register a plugin-based strategy for object creation."""
self._strategies.append(strategy)

def remove_strategy(self, strategy: BuildStrategy):
"""Register a plugin-based strategy for object creation."""
self._strategies.remove(strategy)

def build_object(self, d:dict):
"""Build an object from the dict"""

if not isinstance(d,dict):
raise PyAMLException("Unexpected object " + str(d))
if not "type" in d:
raise PyAMLException("No type specified for " + str(type(d)) + ":" + str(d))
type_str = d.pop("type")

try:
module = importlib.import_module(type_str)
except ModuleNotFoundError as ex:
raise PyAMLException(f"Module referenced in type cannot be founded: '{type_str}'") from ex

# Try plugin strategies first
for strategy in self._strategies:
try:
obj = depthFirstBuild(value)
# Replace the inner dict by the object itself
d[key]=obj
except PyAMLException as pyaml_ex:
raise PyAMLConfigException(key, pyaml_ex) from pyaml_ex
except Exception as ex:
raise PyAMLConfigException(key) from ex

# We are now on leaf (no nested object), we can construct
try:
obj = buildObject(d)
except PyAMLException as pyaml_ex:
raise PyAMLConfigException(None, pyaml_ex) from pyaml_ex
except Exception as ex:
raise PyAMLException("An exception occurred while building object") from ex
return obj

raise PyAMLException("Unexpected element found.")

def register_element(elt):
if isinstance(elt,Element):
name = str(elt)
if name in _ALL_ELEMENTS:
raise PyAMLException(f"element {name} already defined")
_ALL_ELEMENTS[name] = elt


def get_element(name:str):
if name not in _ALL_ELEMENTS:
raise PyAMLException(f"element {name} not defined")
return _ALL_ELEMENTS[name]

def clear():
_ALL_ELEMENTS.clear()
if strategy.can_handle(module, d):
obj = strategy.build(module, d)
self.register_element(obj)
return obj
except Exception as e:
raise PyAMLException("Custom strategy failed") from e

# Default loading strategy
# Get the config object
config_cls = getattr(module, "ConfigModel", None)
if config_cls is None:
raise ValueError(f"ConfigModel class '{type_str}.ConfigModel' not found")

# Get the class name
cls_name = getattr(module, "PYAMLCLASS", None)
if cls_name is None:
raise ValueError(f"PYAMLCLASS definition not found in '{type_str}'")

try:

# Validate the model
cfg = config_cls.model_validate(d)

# Construct and return the object
elem_cls = getattr(module, cls_name, None)
if elem_cls is None:
raise ValueError(
f"Unknown element class '{type_str}.{cls_name}'"
)

obj = elem_cls(cfg)
self.register_element(obj)
return obj

except Exception as e:
raise PyAMLConfigException(f'{type_str}.{cls_name}') from e


def depth_first_build(self, d):
"""Main factory function (Depth-first factory)"""

if isinstance(d,list):
# list can be a list of objects or a list of native types
l = []
for index, e in enumerate(d):
if isinstance(e,dict) or isinstance(e,list):
try:
obj = self.depth_first_build(e)
l.append(obj)
except PyAMLException as pyaml_ex:
raise PyAMLConfigException(f"[{index}]", pyaml_ex) from pyaml_ex
except Exception as ex:
raise PyAMLConfigException(f"[{index}]") from ex
else:
l.append(e)
return l

elif isinstance(d,dict):
for key, value in d.items():
if isinstance(value,dict) or isinstance(value,list):
try:
obj = self.depth_first_build(value)
# Replace the inner dict by the object itself
d[key]=obj
except PyAMLException as pyaml_ex:
raise PyAMLConfigException(key, pyaml_ex) from pyaml_ex
except Exception as ex:
raise PyAMLConfigException(key) from ex

# We are now on leaf (no nested object), we can construct
try:
obj = self.build_object(d)
except PyAMLException as pyaml_ex:
raise PyAMLConfigException(None, pyaml_ex) from pyaml_ex
except Exception as ex:
raise PyAMLException("An exception occurred while building object") from ex
return obj

raise PyAMLException("Unexpected element found.")

def register_element(self, elt):
if isinstance(elt,Element):
name = str(elt)
if name in self._elements:
raise PyAMLException(f"element {name} already defined")
self._elements[name] = elt


def get_element(self, name:str):
if name not in self._elements:
raise PyAMLException(f"element {name} not defined")
return self._elements[name]

def clear(self):
self._elements.clear()

factory = PyAMLFactory()

# For backward compatibility
buildObject = factory.build_object
depthFirstBuild = factory.depth_first_build
register_element = factory.register_element
get_element = factory.get_element
clear = factory.clear
56 changes: 56 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import types
import pytest
import subprocess
import sys
import pathlib
import numpy as np
from pyaml.control.readback_value import Value
from pydantic import BaseModel
from pyaml.configuration.factory import factory, BuildStrategy, clear


@pytest.fixture
Expand Down Expand Up @@ -105,3 +108,56 @@ def scalar_vector():
def broadcast_matrix():
"""Return a 3x3 matrix filled with 2s for broadcasted multiplication tests."""
return np.full((3, 3), 2)


# ────────────── Simulated module ──────────────

class MockConfig(BaseModel):
name: str

class MockElement:
def __init__(self, config):
self.name = config.name

mock_module = types.ModuleType("mock_module")
mock_module.ConfigModel = MockConfig
mock_module.PYAMLCLASS = "MockElement"
mock_module.MockElement = MockElement


# ────────────── Custom strategy ──────────────

class MockStrategy(BuildStrategy):
def can_handle(self, module, config_dict):
return config_dict.get("custom") is True

def build(self, module, config_dict):
name = config_dict.get("name", "default")
return MockElement(config=MockConfig(name=f"custom_{name}"))


# ────────────── Pytest fixtures ──────────────

@pytest.fixture(scope="module", autouse=True)
def inject_mock_module():
"""Inject a simulated external module into sys.modules."""
sys.modules["mock_module"] = mock_module
yield
sys.modules.pop("mock_module", None)


@pytest.fixture(autouse=True)
def clear_factory_registry():
"""Clear element registry before/after each test."""
clear()
yield
clear()


@pytest.fixture(autouse=True)
def register_mock_strategy():
"""Register and unregister mock build strategy."""
strategy = MockStrategy()
factory.register_strategy(strategy)
yield
factory.remove_strategy(strategy)
25 changes: 25 additions & 0 deletions tests/test_factory_custom_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pyaml.configuration.factory import depthFirstBuild
from tests.conftest import MockElement


def test_factory_build_default():
"""Test default PyAML module loading."""
data = {
"type": "mock_module",
"name": "simple"
}
obj = depthFirstBuild(data)
assert isinstance(obj, MockElement)
assert obj.name == "simple"


def test_factory_with_custom_strategy():
"""Test that custom BuildStrategy overrides default logic."""
data = {
"type": "mock_module",
"name": "injected",
"custom": True
}
obj = depthFirstBuild(data)
assert isinstance(obj, MockElement)
assert obj.name == "custom_injected"
Loading