From 8ea84ef443eaa28bbe9217cdf8bbd606f6bf6e70 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Thu, 13 Mar 2025 20:36:48 +0000 Subject: [PATCH 1/2] dictable --- .github/workflows/main.yml | 88 +- autoconf/__init__.py | 1 + autoconf/class_path.py | 122 +-- autoconf/dictable.py | 762 +++++++++--------- autoconf/directory_config.py | 462 +++++------ autoconf/output.py | 128 +-- autoconf/tools/decorators.py | 65 ++ requirements.txt | 8 +- scripts/convert_config.py | 74 +- scripts/convert_prior_configs.py | 110 +-- test_autoconf/files/config/embedded.yaml | 8 +- test_autoconf/files/config/logging.yaml | 24 +- test_autoconf/files/config/one/two.ini | 4 +- test_autoconf/files/config/output.yaml | 4 +- .../config/priors/subdirectory/subconfig.yaml | 8 +- .../files/config/priors/test_yaml_config.yaml | 8 +- test_autoconf/files/default/embedded.yaml | 6 +- test_autoconf/files/default/logging.yaml | 24 +- test_autoconf/files/default/one.yaml | 6 +- .../source_code/subdirectory/subconfig.py | 6 +- test_autoconf/json_prior/test_yaml_config.py | 54 +- test_autoconf/test_decorator.py | 38 +- test_autoconf/test_default.py | 176 ++-- test_autoconf/test_dictable.py | 402 ++++----- test_autoconf/test_output_config.py | 92 +-- 25 files changed, 1373 insertions(+), 1307 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4b14c48..f0ab395 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,44 +1,44 @@ -name: Tests - -on: [push, pull_request] - -jobs: - unittest: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.9, '3.10', '3.11', '3.12'] - steps: - - name: Checkout PyAutoConf - uses: actions/checkout@v2 - with: - path: PyAutoConf - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip3 install --upgrade pip - pip3 install wheel - pip3 install numpy - pip3 install pytest==6.2.5 coverage pytest-cov - pip3 install -r PyAutoConf/requirements.txt - - name: Run tests - run: | - pushd PyAutoConf - python3 -m pytest --cov autoconf --cov-report xml:coverage.xml - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 - - name: Slack send - if: ${{ failure() }} - id: slack - uses: slackapi/slack-github-action@v1.21.0 - env: - SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} - with: - channel-id: C03S98FEDK2 - payload: | - { - "text": "${{ github.repository }}/${{ github.ref_name }} (Python ${{ matrix.python-version }}) build result: ${{ job.status }}\n${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}" - } +name: Tests + +on: [push, pull_request] + +jobs: + unittest: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.9, '3.10', '3.11', '3.12'] + steps: + - name: Checkout PyAutoConf + uses: actions/checkout@v2 + with: + path: PyAutoConf + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip3 install --upgrade pip + pip3 install wheel + pip3 install numpy + pip3 install pytest==6.2.5 coverage pytest-cov + pip3 install -r PyAutoConf/requirements.txt + - name: Run tests + run: | + pushd PyAutoConf + python3 -m pytest --cov autoconf --cov-report xml:coverage.xml + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + - name: Slack send + if: ${{ failure() }} + id: slack + uses: slackapi/slack-github-action@v1.21.0 + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} + with: + channel-id: C03S98FEDK2 + payload: | + { + "text": "${{ github.repository }}/${{ github.ref_name }} (Python ${{ matrix.python-version }}) build result: ${{ job.status }}\n${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}" + } diff --git a/autoconf/__init__.py b/autoconf/__init__.py index f9ff537..cdecd83 100644 --- a/autoconf/__init__.py +++ b/autoconf/__init__.py @@ -1,5 +1,6 @@ from . import exc from .tools.decorators import cached_property +from .tools.decorators import lazy_property from .conf import Config from .conf import instance from .json_prior.config import default_prior diff --git a/autoconf/class_path.py b/autoconf/class_path.py index 3c257c0..5d3f185 100644 --- a/autoconf/class_path.py +++ b/autoconf/class_path.py @@ -1,61 +1,61 @@ -import builtins -import importlib -import re -from typing import List, Type - - -def get_class_path(cls: type) -> str: - """ - The full import path of the type - """ - if hasattr(cls, "__class_path__"): - cls = cls.__class_path__ - return re.search("'(.*)'", str(cls))[1] - - -def get_class(class_path: str) -> Type[object]: - return GetClass(class_path).cls - - -class GetClass: - def __init__(self, class_path): - self.class_path = class_path - - @property - def _class_path_array(self) -> List[str]: - """ - A list of strings describing the module and class of the - real object represented here - """ - return self.class_path.split(".") - - @property - def _class_name(self) -> str: - """ - The name of the real class - """ - return self._class_path_array[-1] - - @property - def _module_path(self) -> str: - """ - The path of the module containing the real class - """ - return ".".join(self._class_path_array[:-1]) - - @property - def _module(self): - """ - The module containing the real class - """ - try: - return importlib.import_module(self._module_path) - except ValueError: - return builtins - - @property - def cls(self) -> Type[object]: - """ - The class of the real object - """ - return getattr(self._module, self._class_name) +import builtins +import importlib +import re +from typing import List, Type + + +def get_class_path(cls: type) -> str: + """ + The full import path of the type + """ + if hasattr(cls, "__class_path__"): + cls = cls.__class_path__ + return re.search("'(.*)'", str(cls))[1] + + +def get_class(class_path: str) -> Type[object]: + return GetClass(class_path).cls + + +class GetClass: + def __init__(self, class_path): + self.class_path = class_path + + @property + def _class_path_array(self) -> List[str]: + """ + A list of strings describing the module and class of the + real object represented here + """ + return self.class_path.split(".") + + @property + def _class_name(self) -> str: + """ + The name of the real class + """ + return self._class_path_array[-1] + + @property + def _module_path(self) -> str: + """ + The path of the module containing the real class + """ + return ".".join(self._class_path_array[:-1]) + + @property + def _module(self): + """ + The module containing the real class + """ + try: + return importlib.import_module(self._module_path) + except ValueError: + return builtins + + @property + def cls(self) -> Type[object]: + """ + The class of the real object + """ + return getattr(self._module, self._class_name) diff --git a/autoconf/dictable.py b/autoconf/dictable.py index a2a8dfd..a59f6c0 100644 --- a/autoconf/dictable.py +++ b/autoconf/dictable.py @@ -1,381 +1,381 @@ -import inspect -import json -import logging - -import numpy as np -from pathlib import Path -from typing import Union, Callable, Set, Tuple - -from autoconf.class_path import get_class_path, get_class - -logger = logging.getLogger(__name__) - -np_type_map = { - "bool": "bool_", -} - - -def nd_array_as_dict(obj: np.ndarray) -> dict: - """ - Converts a numpy array to a dictionary representation. - """ - np_type = str(obj.dtype) - return { - "type": "ndarray", - "array": obj.tolist(), - "dtype": np_type_map.get(np_type, np_type), - } - - -def nd_array_from_dict(nd_array_dict: dict, **_) -> np.ndarray: - """ - Converts a dictionary representation back to a numpy array. - """ - return np.array(nd_array_dict["array"], dtype=getattr(np, nd_array_dict["dtype"])) - - -def is_array(obj) -> bool: - """ - True if the object is a numpy array or an ArrayImpl (i.e. from JAX) - """ - if isinstance(obj, np.ndarray): - return True - try: - return obj.__class__.__name__ == "ArrayImpl" - except AttributeError: - return False - - -def compound_key_dict(obj): - """ - Converts a dictionary with compound keys to a dictionary with a single key. - """ - return { - "type": "compound_dict", - "arguments": [ - { - "key": to_dict(key), - "value": to_dict(value), - } - for key, value in obj.items() - ], - } - - -def to_dict(obj, filter_args: Tuple[str, ...] = ()) -> dict: - if isinstance(obj, (int, float, str, bool, type(None))): - return obj - - if isinstance(obj, slice): - return { - "type": "slice", - "start": to_dict(obj.start), - "stop": to_dict(obj.stop), - "step": to_dict(obj.step), - } - - if isinstance(obj, np.number): - return { - "type": "np.number", - "dtype": str(obj.dtype), - "value": obj.item(), - } - - if inspect.isfunction(obj): - return { - "type": "function", - "class_path": obj.__module__ + "." + obj.__qualname__, - } - - if hasattr(obj, "dict"): - try: - return obj.dict() - except TypeError as e: - logger.debug(e) - - if is_array(obj): - try: - return nd_array_as_dict(obj) - except Exception as e: - logger.info(e) - - if isinstance(obj, Path): - return { - "type": "path", - "path": str(obj), - } - - if inspect.isclass(obj): - return { - "type": "type", - "class_path": get_class_path(obj), - } - - if isinstance(obj, list): - return {"type": "list", "values": list(map(to_dict, obj))} - if isinstance(obj, tuple): - return {"type": "tuple", "values": list(map(to_dict, obj))} - if isinstance(obj, dict): - result = { - "type": "dict", - "arguments": { - key: to_dict(value) - for key, value in obj.items() - if key not in filter_args - }, - } - try: - json.dumps(result) - return result - except TypeError: - return compound_key_dict(obj) - - if obj.__class__.__name__ == "method": - return to_dict(obj()) - if obj.__class__.__module__ == "builtins": - return obj - - if inspect.isclass(type(obj)): - return instance_as_dict(obj, filter_args=filter_args) - - return obj - - -def get_arguments(obj) -> Set[str]: - """ - Get the arguments of a class. This is done by inspecting the constructor. - - If the constructor has a **kwargs parameter, the arguments of the base classes are also included. - - Parameters - ---------- - obj - The class to get the arguments of. - - Returns - ------- - A set of the arguments of the class. - """ - args_spec = inspect.getfullargspec(obj.__init__) - args = set(args_spec.args[1:]) - if args_spec.varkw: - for base in obj.__bases__: - if base is object: - continue - args |= get_arguments(base) - return args - - -def instance_as_dict(obj, filter_args: Tuple[str, ...] = ()): - """ - Convert an instance of a class to a dictionary representation. - - Serialises any children of the object which are given as constructor arguments - or included in the __identifier_fields__ attribute. - - Sets any fields in the __nullify_fields__ attribute to None. - - Parameters - ---------- - obj - The instance of the class to be converted to a dictionary representation. - filter_args - A tuple of arguments to exclude from the dictionary representation. - - Returns - ------- - A dictionary representation of the instance. - """ - arguments = get_arguments(type(obj)) - try: - arguments |= set(obj.__identifier_fields__) - except (AttributeError, TypeError): - pass - - argument_dict = { - arg: getattr(obj, arg) - for arg in arguments - if arg not in filter_args - if hasattr(obj, arg) - and not inspect.ismethod( - getattr(obj, arg), - ) - } - try: - for field in obj.__nullify_fields__: - argument_dict[field] = None - except (AttributeError, TypeError): - pass - - try: - for field in obj.__exclude_fields__: - try: - argument_dict.pop(field) - except KeyError: - logger.debug(f"Field {field} not found in object") - except (AttributeError, TypeError): - pass - - return { - "type": "instance", - "class_path": get_class_path(obj.__class__), - "arguments": {key: to_dict(value) for key, value in argument_dict.items()}, - } - - -__parsers = { - "ndarray": nd_array_from_dict, -} - - -def register_parser(type_: str, parser: Callable[[dict], object]): - """ - Register a parser for a given type. - - This parser will be used to instantiate objects of the given type from a - dictionary representation. - - Parameters - ---------- - type_ - The type of the object to be parsed. This is a string uniquely - identifying the type. - parser - A function which takes a dictionary representation of an object and - returns an instance of the object. - """ - __parsers[type_] = parser - - -def from_dict(dictionary, **kwargs): - """ - Instantiate an instance of a class from its dictionary representation. - - Parameters - ---------- - dictionary - An object which may be a dictionary representation of an object. - - This may contain the following keys: - type: str - The type of the object. This may be a built-in type, a numpy array, - a list, a dictionary, a class, or an instance of a class. - - If a parser has been registered for the given type that parser will - be used to instantiate the object. - class_path: str - The path to the class of the object. This is used to instantiate - the object if it is not a built-in type. - arguments: dict - A dictionary of arguments to pass to the class constructor. - - Returns - ------- - An object that was represented by the dictionary. - """ - if isinstance(dictionary, (int, float, str, bool, type(None))): - return dictionary - - if isinstance(dictionary, list): - return list(map(from_dict, dictionary)) - - if isinstance(dictionary, tuple): - return tuple(map(from_dict, dictionary)) - - try: - type_ = dictionary["type"] - except KeyError: - logger.debug("No type field in dictionary") - return dictionary - except TypeError as e: - logger.debug(e) - return None - - if type_ == "path": - return Path(dictionary["path"]) - - if type_ == "slice": - return slice( - from_dict(dictionary["start"]), - from_dict(dictionary["stop"]), - from_dict(dictionary["step"]), - ) - - if type_ == "np.number": - return getattr( - np, - dictionary["dtype"], - )(dictionary["value"]) - - if type_ == "function": - return get_class(dictionary["class_path"]) - - if type_ in __parsers: - return __parsers[type_](dictionary, **kwargs) - - if type_ == "list": - return list(map(from_dict, dictionary["values"])) - if type_ == "tuple": - return tuple(map(from_dict, dictionary["values"])) - if type_ == "dict": - return { - key: from_dict(value, **kwargs) - for key, value in dictionary["arguments"].items() - } - if type_ == "compound_dict": - return { - from_dict(item["key"], **kwargs): from_dict(item["value"], **kwargs) - for item in dictionary["arguments"] - } - - if type_ == "type": - return get_class(dictionary["class_path"]) - - cls = get_class(dictionary["class_path"]) - - if cls is np.ndarray: - return nd_array_from_dict(dictionary) - if hasattr(cls, "from_dict"): - return cls.from_dict(dictionary, **kwargs) - - # noinspection PyArgumentList - return cls( - **{ - name: from_dict(value, **kwargs) - for name, value in dictionary["arguments"].items() - } - ) - - -def from_json(file_path: str): - """ - Load the dictable object to a .json file, whereby all attributes are converted from the .json file's dictionary - representation to create the instance of the object - - A json file of the instance can be created from the .json file via the `output_to_json` method. - - Parameters - ---------- - file_path - The path to the .json file that the dictionary representation of the object is loaded from. - """ - with open(file_path, "r+") as f: - cls_dict = json.load(f) - - return from_dict(cls_dict) - - -def output_to_json(obj, file_path: Union[Path, str]): - """ - Output the dictable object to a .json file, whereby all attributes are converted to a dictionary representation - first. - - An instance of the object can be created from the .json file via the `from_json` method. - - Parameters - ---------- - file_path - The path to the .json file that the dictionary representation of the object is written too. - """ - with open(file_path, "w+") as f: - json.dump(to_dict(obj), f, indent=4) +import inspect +import json +import logging + +import numpy as np +from pathlib import Path +from typing import Union, Callable, Set, Tuple + +from autoconf.class_path import get_class_path, get_class + +logger = logging.getLogger(__name__) + +np_type_map = { + "bool": "bool_", +} + + +def nd_array_as_dict(obj: np.ndarray) -> dict: + """ + Converts a numpy array to a dictionary representation. + """ + np_type = str(obj.dtype) + return { + "type": "ndarray", + "array": obj.tolist(), + "dtype": np_type_map.get(np_type, np_type), + } + + +def nd_array_from_dict(nd_array_dict: dict, **_) -> np.ndarray: + """ + Converts a dictionary representation back to a numpy array. + """ + return np.array(nd_array_dict["array"], dtype=getattr(np, nd_array_dict["dtype"])) + + +def is_array(obj) -> bool: + """ + True if the object is a numpy array or an ArrayImpl (i.e. from JAX) + """ + if isinstance(obj, np.ndarray): + return True + try: + return obj.__class__.__name__ == "ArrayImpl" + except AttributeError: + return False + + +def compound_key_dict(obj): + """ + Converts a dictionary with compound keys to a dictionary with a single key. + """ + return { + "type": "compound_dict", + "arguments": [ + { + "key": to_dict(key), + "value": to_dict(value), + } + for key, value in obj.items() + ], + } + + +def to_dict(obj, filter_args: Tuple[str, ...] = ()) -> dict: + if isinstance(obj, (int, float, str, bool, type(None))): + return obj + + if isinstance(obj, slice): + return { + "type": "slice", + "start": to_dict(obj.start), + "stop": to_dict(obj.stop), + "step": to_dict(obj.step), + } + + if isinstance(obj, np.number): + return { + "type": "np.number", + "dtype": str(obj.dtype), + "value": obj.item(), + } + + if inspect.isfunction(obj): + return { + "type": "function", + "class_path": obj.__module__ + "." + obj.__qualname__, + } + + if hasattr(obj, "dict"): + try: + return obj.dict() + except TypeError as e: + logger.debug(e) + + if is_array(obj): + try: + return nd_array_as_dict(obj) + except Exception as e: + logger.info(e) + + if isinstance(obj, Path): + return { + "type": "path", + "path": str(obj), + } + + if inspect.isclass(obj): + return { + "type": "type", + "class_path": get_class_path(obj), + } + + if isinstance(obj, list): + return {"type": "list", "values": list(map(to_dict, obj))} + if isinstance(obj, tuple): + return {"type": "tuple", "values": list(map(to_dict, obj))} + if isinstance(obj, dict): + result = { + "type": "dict", + "arguments": { + key: to_dict(value) + for key, value in obj.items() + if key not in filter_args + }, + } + try: + json.dumps(result) + return result + except TypeError: + return compound_key_dict(obj) + + if obj.__class__.__name__ == "method": + return to_dict(obj()) + if obj.__class__.__module__ == "builtins": + return obj + + if inspect.isclass(type(obj)): + return instance_as_dict(obj, filter_args=filter_args) + + return obj + + +def get_arguments(obj) -> Set[str]: + """ + Get the arguments of a class. This is done by inspecting the constructor. + + If the constructor has a **kwargs parameter, the arguments of the base classes are also included. + + Parameters + ---------- + obj + The class to get the arguments of. + + Returns + ------- + A set of the arguments of the class. + """ + args_spec = inspect.getfullargspec(obj.__init__) + args = set(args_spec.args[1:]) + if args_spec.varkw: + for base in obj.__bases__: + if base is object: + continue + args |= get_arguments(base) + return args + + +def instance_as_dict(obj, filter_args: Tuple[str, ...] = ()): + """ + Convert an instance of a class to a dictionary representation. + + Serialises any children of the object which are given as constructor arguments + or included in the __identifier_fields__ attribute. + + Sets any fields in the __nullify_fields__ attribute to None. + + Parameters + ---------- + obj + The instance of the class to be converted to a dictionary representation. + filter_args + A tuple of arguments to exclude from the dictionary representation. + + Returns + ------- + A dictionary representation of the instance. + """ + arguments = get_arguments(type(obj)) + try: + arguments |= set(obj.__identifier_fields__) + except (AttributeError, TypeError): + pass + + argument_dict = { + arg: getattr(obj, arg) + for arg in arguments + if arg not in filter_args + if hasattr(obj, arg) + and not inspect.ismethod( + getattr(obj, arg), + ) + } + try: + for field in obj.__nullify_fields__: + argument_dict[field] = None + except (AttributeError, TypeError): + pass + + try: + for field in obj.__exclude_fields__: + try: + argument_dict.pop(field) + except KeyError: + logger.debug(f"Field {field} not found in object") + except (AttributeError, TypeError): + pass + + return { + "type": "instance", + "class_path": get_class_path(obj.__class__), + "arguments": {key: to_dict(value) for key, value in argument_dict.items()}, + } + + +__parsers = { + "ndarray": nd_array_from_dict, +} + + +def register_parser(type_: str, parser: Callable[[dict], object]): + """ + Register a parser for a given type. + + This parser will be used to instantiate objects of the given type from a + dictionary representation. + + Parameters + ---------- + type_ + The type of the object to be parsed. This is a string uniquely + identifying the type. + parser + A function which takes a dictionary representation of an object and + returns an instance of the object. + """ + __parsers[type_] = parser + + +def from_dict(dictionary, **kwargs): + """ + Instantiate an instance of a class from its dictionary representation. + + Parameters + ---------- + dictionary + An object which may be a dictionary representation of an object. + + This may contain the following keys: + type: str + The type of the object. This may be a built-in type, a numpy array, + a list, a dictionary, a class, or an instance of a class. + + If a parser has been registered for the given type that parser will + be used to instantiate the object. + class_path: str + The path to the class of the object. This is used to instantiate + the object if it is not a built-in type. + arguments: dict + A dictionary of arguments to pass to the class constructor. + + Returns + ------- + An object that was represented by the dictionary. + """ + if isinstance(dictionary, (int, float, str, bool, type(None))): + return dictionary + + if isinstance(dictionary, list): + return list(map(from_dict, dictionary)) + + if isinstance(dictionary, tuple): + return tuple(map(from_dict, dictionary)) + + try: + type_ = dictionary["type"] + except KeyError: + logger.debug("No type field in dictionary") + return dictionary + except TypeError as e: + logger.debug(e) + return None + + if type_ == "path": + return Path(dictionary["path"]) + + if type_ == "slice": + return slice( + from_dict(dictionary["start"]), + from_dict(dictionary["stop"]), + from_dict(dictionary["step"]), + ) + + if type_ == "np.number": + return getattr( + np, + dictionary["dtype"], + )(dictionary["value"]) + + if type_ == "function": + return get_class(dictionary["class_path"]) + + if type_ in __parsers: + return __parsers[type_](dictionary, **kwargs) + + if type_ == "list": + return list(map(from_dict, dictionary["values"])) + if type_ == "tuple": + return tuple(map(from_dict, dictionary["values"])) + if type_ == "dict": + return { + key: from_dict(value, **kwargs) + for key, value in dictionary["arguments"].items() + } + if type_ == "compound_dict": + return { + from_dict(item["key"], **kwargs): from_dict(item["value"], **kwargs) + for item in dictionary["arguments"] + } + + if type_ == "type": + return get_class(dictionary["class_path"]) + + cls = get_class(dictionary["class_path"]) + + if cls is np.ndarray: + return nd_array_from_dict(dictionary) + if hasattr(cls, "from_dict"): + return cls.from_dict(dictionary, **kwargs) + + # noinspection PyArgumentList + return cls( + **{ + name: from_dict(value, **kwargs) + for name, value in dictionary["arguments"].items() + } + ) + + +def from_json(file_path: str): + """ + Load the dictable object to a .json file, whereby all attributes are converted from the .json file's dictionary + representation to create the instance of the object + + A json file of the instance can be created from the .json file via the `output_to_json` method. + + Parameters + ---------- + file_path + The path to the .json file that the dictionary representation of the object is loaded from. + """ + with open(file_path, "r+") as f: + cls_dict = json.load(f) + + return from_dict(cls_dict) + + +def output_to_json(obj, file_path: Union[Path, str]): + """ + Output the dictable object to a .json file, whereby all attributes are converted to a dictionary representation + first. + + An instance of the object can be created from the .json file via the `from_json` method. + + Parameters + ---------- + file_path + The path to the .json file that the dictionary representation of the object is written too. + """ + with open(file_path, "w+") as f: + json.dump(to_dict(obj), f, indent=4) diff --git a/autoconf/directory_config.py b/autoconf/directory_config.py index a419178..8250025 100644 --- a/autoconf/directory_config.py +++ b/autoconf/directory_config.py @@ -1,231 +1,231 @@ -import configparser -import os -from abc import abstractmethod, ABC -from pathlib import Path - -import yaml - -from autoconf import exc - - -class AbstractConfig(ABC): - @abstractmethod - def _getitem(self, item): - pass - - def __getitem__(self, item): - if isinstance(item, int): - return self.items()[item] - return self._getitem(item) - - def items(self): - return [(key, self[key]) for key in self.keys()] - - def __len__(self): - return len(self.items()) - - @abstractmethod - def keys(self): - pass - - def family(self, cls): - for cls in family(cls): - key = cls.__name__ - try: - return self[key] - except (KeyError, configparser.NoOptionError): - pass - raise KeyError(f"No configuration found for {cls.__name__}") - - def dict(self): - d = {} - for key in self.keys(): - value = self[key] - if isinstance(value, AbstractConfig): - value = value.dict() - d[key] = value - return d - - -class DictConfig(AbstractConfig): - def keys(self): - return self.d.keys() - - def __init__(self, d): - self.d = d - - def __getitem__(self, item): - value = self.d[item] - if isinstance(value, dict): - return DictConfig(value) - return value - - def _getitem(self, item): - return self[item] - - def items(self): - for key in self.d: - yield key, self[key] - - -class YAMLConfig(AbstractConfig): - def __init__(self, path): - with open(path) as f: - self._dict = yaml.safe_load(f) - - def _getitem(self, item): - value = self._dict[item] - if isinstance(value, dict): - return DictConfig(value) - return value - - def keys(self): - return self._dict.keys() - - -class SectionConfig(AbstractConfig): - def __init__(self, path, parser, section): - self.path = path - self.section = section - self.parser = parser - - def keys(self): - with open(self.path) as f: - string = f.read() - - lines = string.split("\n") - is_section = False - for line in lines: - if line == f"[{self.section}]": - is_section = True - continue - if line.startswith("["): - is_section = False - continue - if is_section and "=" in line: - yield line.split("=")[0] - - def _getitem(self, item): - try: - result = self.parser.get(self.section, item) - if result.lower() == "true": - return True - if result.lower() == "false": - return False - if result.lower() in ("none", "null"): - return None - if result.isdigit(): - return int(result) - try: - return float(result) - except ValueError: - return result - except (configparser.NoSectionError, configparser.NoOptionError): - raise KeyError(f"No configuration found for {item} at path {self.path}") - - -class NamedConfig(AbstractConfig): - def __init__(self, config_path): - """ - Parses generic config - - Parameters - ---------- - config_path - The path to the config file - """ - self.path = config_path - self.parser = configparser.ConfigParser() - self.parser.read(self.path) - - def keys(self): - return self.parser.sections() - - def _getitem(self, item): - return SectionConfig( - self.path, - self.parser, - item, - ) - - -class RecursiveConfig(AbstractConfig): - def keys(self): - try: - return [ - path.split(".")[0] - for path in os.listdir(self.path) - if all( - [ - path != "priors", - len(path.split(".")[0]) != 0, - os.path.isdir(f"{self.path}/{path}") - or path.endswith(".ini") - or path.endswith(".yaml") - or path.endswith(".yml"), - ] - ) - ] - except FileNotFoundError as e: - raise KeyError(f"No configuration found at {self.path}") from e - - def __init__(self, path): - self.path = Path(path) - - def __eq__(self, other): - return str(self) == str(other) - - def __str__(self): - return str(self.path) - - def __repr__(self): - return f"<{self.__class__.__name__} {self.path}>" - - def _getitem(self, item): - item_path = self.path / f"{item}" - file_path = f"{item_path}.ini" - if os.path.isfile(file_path): - return NamedConfig(file_path) - yml_path = item_path.with_suffix(".yml") - if yml_path.exists(): - return YAMLConfig(yml_path) - yaml_path = item_path.with_suffix(".yaml") - if yaml_path.exists(): - return YAMLConfig(yaml_path) - if os.path.isdir(item_path): - return RecursiveConfig(item_path) - raise KeyError(f"No configuration found for {item} at path {self.path}") - - -class PriorConfigWrapper: - def __init__(self, prior_configs): - self.prior_configs = prior_configs - - def for_class_and_suffix_path(self, cls, path): - for config in self.prior_configs: - try: - return config.for_class_and_suffix_path(cls, path) - except KeyError: - pass - directories = " ".join(str(config.directory) for config in self.prior_configs) - - print() - - raise exc.ConfigException( - f"No prior config found for class: \n\n" - f"{cls.__name__} \n\n" - f"For parameter name and path: \n\n " - f"{'.'.join(path)} \n\n " - f"In any of the following directories:\n\n" - f"{directories}\n\n" - f"Either add configuration for the parameter or a type annotation for a class with valid configuration.\n\n" - f"The following readthedocs page explains prior configuration files in PyAutoFit and will help you fix " - f"the error https://pyautofit.readthedocs.io/en/latest/general/adding_a_model_component.html" - ) - - -def family(current_class): - yield current_class - for next_class in current_class.__bases__: - for val in family(next_class): - yield val +import configparser +import os +from abc import abstractmethod, ABC +from pathlib import Path + +import yaml + +from autoconf import exc + + +class AbstractConfig(ABC): + @abstractmethod + def _getitem(self, item): + pass + + def __getitem__(self, item): + if isinstance(item, int): + return self.items()[item] + return self._getitem(item) + + def items(self): + return [(key, self[key]) for key in self.keys()] + + def __len__(self): + return len(self.items()) + + @abstractmethod + def keys(self): + pass + + def family(self, cls): + for cls in family(cls): + key = cls.__name__ + try: + return self[key] + except (KeyError, configparser.NoOptionError): + pass + raise KeyError(f"No configuration found for {cls.__name__}") + + def dict(self): + d = {} + for key in self.keys(): + value = self[key] + if isinstance(value, AbstractConfig): + value = value.dict() + d[key] = value + return d + + +class DictConfig(AbstractConfig): + def keys(self): + return self.d.keys() + + def __init__(self, d): + self.d = d + + def __getitem__(self, item): + value = self.d[item] + if isinstance(value, dict): + return DictConfig(value) + return value + + def _getitem(self, item): + return self[item] + + def items(self): + for key in self.d: + yield key, self[key] + + +class YAMLConfig(AbstractConfig): + def __init__(self, path): + with open(path) as f: + self._dict = yaml.safe_load(f) + + def _getitem(self, item): + value = self._dict[item] + if isinstance(value, dict): + return DictConfig(value) + return value + + def keys(self): + return self._dict.keys() + + +class SectionConfig(AbstractConfig): + def __init__(self, path, parser, section): + self.path = path + self.section = section + self.parser = parser + + def keys(self): + with open(self.path) as f: + string = f.read() + + lines = string.split("\n") + is_section = False + for line in lines: + if line == f"[{self.section}]": + is_section = True + continue + if line.startswith("["): + is_section = False + continue + if is_section and "=" in line: + yield line.split("=")[0] + + def _getitem(self, item): + try: + result = self.parser.get(self.section, item) + if result.lower() == "true": + return True + if result.lower() == "false": + return False + if result.lower() in ("none", "null"): + return None + if result.isdigit(): + return int(result) + try: + return float(result) + except ValueError: + return result + except (configparser.NoSectionError, configparser.NoOptionError): + raise KeyError(f"No configuration found for {item} at path {self.path}") + + +class NamedConfig(AbstractConfig): + def __init__(self, config_path): + """ + Parses generic config + + Parameters + ---------- + config_path + The path to the config file + """ + self.path = config_path + self.parser = configparser.ConfigParser() + self.parser.read(self.path) + + def keys(self): + return self.parser.sections() + + def _getitem(self, item): + return SectionConfig( + self.path, + self.parser, + item, + ) + + +class RecursiveConfig(AbstractConfig): + def keys(self): + try: + return [ + path.split(".")[0] + for path in os.listdir(self.path) + if all( + [ + path != "priors", + len(path.split(".")[0]) != 0, + os.path.isdir(f"{self.path}/{path}") + or path.endswith(".ini") + or path.endswith(".yaml") + or path.endswith(".yml"), + ] + ) + ] + except FileNotFoundError as e: + raise KeyError(f"No configuration found at {self.path}") from e + + def __init__(self, path): + self.path = Path(path) + + def __eq__(self, other): + return str(self) == str(other) + + def __str__(self): + return str(self.path) + + def __repr__(self): + return f"<{self.__class__.__name__} {self.path}>" + + def _getitem(self, item): + item_path = self.path / f"{item}" + file_path = f"{item_path}.ini" + if os.path.isfile(file_path): + return NamedConfig(file_path) + yml_path = item_path.with_suffix(".yml") + if yml_path.exists(): + return YAMLConfig(yml_path) + yaml_path = item_path.with_suffix(".yaml") + if yaml_path.exists(): + return YAMLConfig(yaml_path) + if os.path.isdir(item_path): + return RecursiveConfig(item_path) + raise KeyError(f"No configuration found for {item} at path {self.path}") + + +class PriorConfigWrapper: + def __init__(self, prior_configs): + self.prior_configs = prior_configs + + def for_class_and_suffix_path(self, cls, path): + for config in self.prior_configs: + try: + return config.for_class_and_suffix_path(cls, path) + except KeyError: + pass + directories = " ".join(str(config.directory) for config in self.prior_configs) + + print() + + raise exc.ConfigException( + f"No prior config found for class: \n\n" + f"{cls.__name__} \n\n" + f"For parameter name and path: \n\n " + f"{'.'.join(path)} \n\n " + f"In any of the following directories:\n\n" + f"{directories}\n\n" + f"Either add configuration for the parameter or a type annotation for a class with valid configuration.\n\n" + f"The following readthedocs page explains prior configuration files in PyAutoFit and will help you fix " + f"the error https://pyautofit.readthedocs.io/en/latest/general/adding_a_model_component.html" + ) + + +def family(current_class): + yield current_class + for next_class in current_class.__bases__: + for val in family(next_class): + yield val diff --git a/autoconf/output.py b/autoconf/output.py index 775ec4d..fc44fc6 100644 --- a/autoconf/output.py +++ b/autoconf/output.py @@ -1,64 +1,64 @@ -from functools import wraps -from typing import Callable -import logging - -from autoconf.conf import instance - -logger = logging.getLogger(__name__) - - -def should_output(name: str) -> bool: - """ - Determine whether a file with a given name (excluding extension) should be output. - - This is configured in config/output.yaml. If the file is not present in the config, the default value is used. - - Parameters - ---------- - name - The name of the file to be output, excluding extension. - - Returns - ------- - Whether the file should be output. - """ - output_config = instance["output"] - try: - return output_config[name] - except KeyError: - return output_config["default"] - - -def conditional_output(func: Callable): - """ - Decorator for functions that output files. If the file should not be output, the function is not called. - - Parameters - ---------- - func - A method where the first argument is the name of the file to be output. - - Returns - ------- - The decorated function. - """ - - @wraps(func) - def wrapper(self, name: str, *args, **kwargs): - """ - Conditionally call the decorated function if the file should be output according - to the config. - - Parameters - ---------- - self - name - The name of the file to be output, excluding extension. - args - kwargs - """ - if should_output(name): - return func(self, name, *args, **kwargs) - logger.info(f"Skipping output of {name}") - - return wrapper +from functools import wraps +from typing import Callable +import logging + +from autoconf.conf import instance + +logger = logging.getLogger(__name__) + + +def should_output(name: str) -> bool: + """ + Determine whether a file with a given name (excluding extension) should be output. + + This is configured in config/output.yaml. If the file is not present in the config, the default value is used. + + Parameters + ---------- + name + The name of the file to be output, excluding extension. + + Returns + ------- + Whether the file should be output. + """ + output_config = instance["output"] + try: + return output_config[name] + except KeyError: + return output_config["default"] + + +def conditional_output(func: Callable): + """ + Decorator for functions that output files. If the file should not be output, the function is not called. + + Parameters + ---------- + func + A method where the first argument is the name of the file to be output. + + Returns + ------- + The decorated function. + """ + + @wraps(func) + def wrapper(self, name: str, *args, **kwargs): + """ + Conditionally call the decorated function if the file should be output according + to the config. + + Parameters + ---------- + self + name + The name of the file to be output, excluding extension. + args + kwargs + """ + if should_output(name): + return func(self, name, *args, **kwargs) + logger.info(f"Skipping output of {name}") + + return wrapper diff --git a/autoconf/tools/decorators.py b/autoconf/tools/decorators.py index a1b5fb8..dc493a6 100644 --- a/autoconf/tools/decorators.py +++ b/autoconf/tools/decorators.py @@ -19,3 +19,68 @@ def __get__(self, obj, cls): cached_property = CachedProperty + + + + + +# The is sourced from: torch.distributions.util.py +# +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +from jax.core import Tracer +from functools import update_wrapper + +def not_jax_tracer(x): + """ + Checks if `x` is not an array generated inside `jit`, `pmap`, `vmap`, or `lax_control_flow`. + """ + return not isinstance(x, Tracer) + + +def identity(x, *args, **kwargs): + return x + +class lazy_property(object): + r""" + Used as a decorator for lazy loading of class attributes. This uses a + non-data descriptor that calls the wrapped method to compute the property on + first call; thereafter replacing the wrapped method into an instance + attribute. + """ + + def __init__(self, wrapped): + self.wrapped = wrapped + update_wrapper(self, wrapped) + + # This is to prevent warnings from sphinx + def __call__(self, *args, **kwargs): + return self.wrapped(*args, **kwargs) + + def __get__(self, instance, obj_type=None): + if instance is None: + return self + value = self.wrapped(instance) + if not_jax_tracer(value): + setattr(instance, self.wrapped.__name__, value) + return value \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2219694..ee75b39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -typing-inspect>=0.4.0 -pathlib -PyYAML>=6.0.1 -numpy>=1.24.0,<=2.0.1 +typing-inspect>=0.4.0 +pathlib +PyYAML>=6.0.1 +numpy>=1.24.0,<=2.0.1 diff --git a/scripts/convert_config.py b/scripts/convert_config.py index 4c8f6f6..f7abba7 100755 --- a/scripts/convert_config.py +++ b/scripts/convert_config.py @@ -1,37 +1,37 @@ -#!/usr/bin/env python -""" -Converts the configuration files and directories in a given directory into YAML configs. - -Usage: -./convert_config.py path/to/directory -""" - -import os -import shutil -import sys -from pathlib import Path - -import yaml - -from autoconf.directory_config import RecursiveConfig, YAMLConfig - -target_path = Path(sys.argv[1]) - -config = RecursiveConfig(str(target_path)) - -for key in config.keys(): - value = config[key] - if isinstance(value, YAMLConfig): - continue - - d = value.dict() - path = target_path / key - with open(path.with_suffix(".yaml"), "w") as f: - yaml.dump(d, f) - - try: - os.remove(path.with_suffix(".ini")) - except FileNotFoundError: - pass - - shutil.rmtree(path, ignore_errors=True) +#!/usr/bin/env python +""" +Converts the configuration files and directories in a given directory into YAML configs. + +Usage: +./convert_config.py path/to/directory +""" + +import os +import shutil +import sys +from pathlib import Path + +import yaml + +from autoconf.directory_config import RecursiveConfig, YAMLConfig + +target_path = Path(sys.argv[1]) + +config = RecursiveConfig(str(target_path)) + +for key in config.keys(): + value = config[key] + if isinstance(value, YAMLConfig): + continue + + d = value.dict() + path = target_path / key + with open(path.with_suffix(".yaml"), "w") as f: + yaml.dump(d, f) + + try: + os.remove(path.with_suffix(".ini")) + except FileNotFoundError: + pass + + shutil.rmtree(path, ignore_errors=True) diff --git a/scripts/convert_prior_configs.py b/scripts/convert_prior_configs.py index f105729..5cb4fbd 100755 --- a/scripts/convert_prior_configs.py +++ b/scripts/convert_prior_configs.py @@ -1,55 +1,55 @@ -#!/usr/bin/env python -""" -Converts JSON prior configs to YAML equivalent. - -Usage: -./convert_prior_configs.py /path/to/prior/directory -""" -import json -import os -import sys -from pathlib import Path - -import oyaml as yaml - -ORDER = [ - "type", - "mean", - "sigma", - "lower_limit", - "upper_limit", - "width_modifier", - "gaussian_limits", -] - -for path in Path(sys.argv[1]).rglob("*.json"): - with open(path) as f: - d = json.load(f) - - with open(path.with_suffix(".yaml"), "w") as f: - yaml.dump(d, f) - - os.remove(path) - - -def sort_dict(obj): - if isinstance(obj, dict): - return { - key: sort_dict(value) - for key, value in sorted( - obj.items(), - key=lambda item: ORDER.index(item[0]) if item[0] in ORDER else 999, - ) - } - if isinstance(obj, list): - return list(map(sort_dict, obj)) - return obj - - -for path in Path(sys.argv[1]).rglob("*.yaml"): - with open(path) as f: - d = yaml.safe_load(f) - - print(sort_dict(d)) - with open(path, "w") as f: - yaml.dump(sort_dict(d), f) +#!/usr/bin/env python +""" +Converts JSON prior configs to YAML equivalent. + +Usage: +./convert_prior_configs.py /path/to/prior/directory +""" +import json +import os +import sys +from pathlib import Path + +import oyaml as yaml + +ORDER = [ + "type", + "mean", + "sigma", + "lower_limit", + "upper_limit", + "width_modifier", + "gaussian_limits", +] + +for path in Path(sys.argv[1]).rglob("*.json"): + with open(path) as f: + d = json.load(f) + + with open(path.with_suffix(".yaml"), "w") as f: + yaml.dump(d, f) + + os.remove(path) + + +def sort_dict(obj): + if isinstance(obj, dict): + return { + key: sort_dict(value) + for key, value in sorted( + obj.items(), + key=lambda item: ORDER.index(item[0]) if item[0] in ORDER else 999, + ) + } + if isinstance(obj, list): + return list(map(sort_dict, obj)) + return obj + + +for path in Path(sys.argv[1]).rglob("*.yaml"): + with open(path) as f: + d = yaml.safe_load(f) + + print(sort_dict(d)) + with open(path, "w") as f: + yaml.dump(sort_dict(d), f) diff --git a/test_autoconf/files/config/embedded.yaml b/test_autoconf/files/config/embedded.yaml index 697fad1..9b93dd9 100644 --- a/test_autoconf/files/config/embedded.yaml +++ b/test_autoconf/files/config/embedded.yaml @@ -1,4 +1,4 @@ -first: - first_a: - first_a_a: one - first_a_c: three +first: + first_a: + first_a_a: one + first_a_c: three diff --git a/test_autoconf/files/config/logging.yaml b/test_autoconf/files/config/logging.yaml index 0e803c6..6bd8722 100644 --- a/test_autoconf/files/config/logging.yaml +++ b/test_autoconf/files/config/logging.yaml @@ -1,12 +1,12 @@ -name: config -version: 1 -disable_existing_loggers: false -handlers: - console: - class: logging.StreamHandler - level: WARN - stream: ext://sys.stdout - -root: - level: INFO - handlers: [ console ] +name: config +version: 1 +disable_existing_loggers: false +handlers: + console: + class: logging.StreamHandler + level: WARN + stream: ext://sys.stdout + +root: + level: INFO + handlers: [ console ] diff --git a/test_autoconf/files/config/one/two.ini b/test_autoconf/files/config/one/two.ini index 2277aba..727638f 100644 --- a/test_autoconf/files/config/one/two.ini +++ b/test_autoconf/files/config/one/two.ini @@ -1,3 +1,3 @@ -[three] -four=five +[three] +four=five six=seven \ No newline at end of file diff --git a/test_autoconf/files/config/output.yaml b/test_autoconf/files/config/output.yaml index 1362a9c..1525237 100644 --- a/test_autoconf/files/config/output.yaml +++ b/test_autoconf/files/config/output.yaml @@ -1,3 +1,3 @@ -should_output: true -should_not_output: false +should_output: true +should_not_output: false default: true \ No newline at end of file diff --git a/test_autoconf/files/config/priors/subdirectory/subconfig.yaml b/test_autoconf/files/config/priors/subdirectory/subconfig.yaml index 0dbe2c4..de92798 100644 --- a/test_autoconf/files/config/priors/subdirectory/subconfig.yaml +++ b/test_autoconf/files/config/priors/subdirectory/subconfig.yaml @@ -1,5 +1,5 @@ -SubClass: - variable: - type: Uniform - lower_limit: 0.0 +SubClass: + variable: + type: Uniform + lower_limit: 0.0 upper_limit: 3.0 \ No newline at end of file diff --git a/test_autoconf/files/config/priors/test_yaml_config.yaml b/test_autoconf/files/config/priors/test_yaml_config.yaml index 798c786..5ffbad5 100644 --- a/test_autoconf/files/config/priors/test_yaml_config.yaml +++ b/test_autoconf/files/config/priors/test_yaml_config.yaml @@ -1,5 +1,5 @@ -YAMLClass: - variable: - type: Uniform - lower_limit: 0.0 +YAMLClass: + variable: + type: Uniform + lower_limit: 0.0 upper_limit: 3.0 \ No newline at end of file diff --git a/test_autoconf/files/default/embedded.yaml b/test_autoconf/files/default/embedded.yaml index 6c56ed7..6e78077 100644 --- a/test_autoconf/files/default/embedded.yaml +++ b/test_autoconf/files/default/embedded.yaml @@ -1,4 +1,4 @@ -first: - first_a: - first_a_a: one +first: + first_a: + first_a_a: one first_a_b: two \ No newline at end of file diff --git a/test_autoconf/files/default/logging.yaml b/test_autoconf/files/default/logging.yaml index 8d93b7f..2b5c6ba 100644 --- a/test_autoconf/files/default/logging.yaml +++ b/test_autoconf/files/default/logging.yaml @@ -1,12 +1,12 @@ -name: default -version: 1 -disable_existing_loggers: false -handlers: - console: - class: logging.StreamHandler - level: WARN - stream: ext://sys.stdout - -root: - level: INFO - handlers: [ console ] +name: default +version: 1 +disable_existing_loggers: false +handlers: + console: + class: logging.StreamHandler + level: WARN + stream: ext://sys.stdout + +root: + level: INFO + handlers: [ console ] diff --git a/test_autoconf/files/default/one.yaml b/test_autoconf/files/default/one.yaml index 426842b..28c3c61 100644 --- a/test_autoconf/files/default/one.yaml +++ b/test_autoconf/files/default/one.yaml @@ -1,4 +1,4 @@ -two: - three: - four: five +two: + three: + four: five eight: nine \ No newline at end of file diff --git a/test_autoconf/json_prior/source_code/subdirectory/subconfig.py b/test_autoconf/json_prior/source_code/subdirectory/subconfig.py index fcaf5ca..15380e1 100644 --- a/test_autoconf/json_prior/source_code/subdirectory/subconfig.py +++ b/test_autoconf/json_prior/source_code/subdirectory/subconfig.py @@ -1,3 +1,3 @@ -class SubClass: - def __init__(self, variable: float): - self.variable = variable +class SubClass: + def __init__(self, variable: float): + self.variable = variable diff --git a/test_autoconf/json_prior/test_yaml_config.py b/test_autoconf/json_prior/test_yaml_config.py index cbbc20e..fb8994a 100644 --- a/test_autoconf/json_prior/test_yaml_config.py +++ b/test_autoconf/json_prior/test_yaml_config.py @@ -1,27 +1,27 @@ -from .source_code.subdirectory.subconfig import SubClass - - -class YAMLClass: - def __init__(self, variable: float): - self.variable = variable - - -def test_load_yaml_config(config): - assert config.prior_config.for_class_and_suffix_path(YAMLClass, ["variable"]) == { - "lower_limit": 0.0, - "type": "Uniform", - "upper_limit": 3.0, - } - - -def test_embedded_path(config): - path_value_map = config.prior_config.prior_configs[0].path_value_map - assert "subdirectory.subconfig.SubClass.variable.type" in path_value_map - - -def test_subdirectory(config): - assert config.prior_config.for_class_and_suffix_path(SubClass, ["variable"]) == { - "lower_limit": 0.0, - "type": "Uniform", - "upper_limit": 3.0, - } +from .source_code.subdirectory.subconfig import SubClass + + +class YAMLClass: + def __init__(self, variable: float): + self.variable = variable + + +def test_load_yaml_config(config): + assert config.prior_config.for_class_and_suffix_path(YAMLClass, ["variable"]) == { + "lower_limit": 0.0, + "type": "Uniform", + "upper_limit": 3.0, + } + + +def test_embedded_path(config): + path_value_map = config.prior_config.prior_configs[0].path_value_map + assert "subdirectory.subconfig.SubClass.variable.type" in path_value_map + + +def test_subdirectory(config): + assert config.prior_config.for_class_and_suffix_path(SubClass, ["variable"]) == { + "lower_limit": 0.0, + "type": "Uniform", + "upper_limit": 3.0, + } diff --git a/test_autoconf/test_decorator.py b/test_autoconf/test_decorator.py index e8ebfa9..44bac15 100644 --- a/test_autoconf/test_decorator.py +++ b/test_autoconf/test_decorator.py @@ -1,19 +1,19 @@ -import pytest - -from autoconf import conf -from autoconf.conf import with_config - - -@pytest.fixture(autouse=True) -def push_configs(files_directory): - conf.instance.push(files_directory / "config") - conf.instance.push(files_directory / "default") - - -@with_config("general", "output", "identifier_version", value=9) -def test_with_config(): - assert conf.instance["general"]["output"]["identifier_version"] == 9 - - -def test_config(): - assert conf.instance["general"]["output"]["identifier_version"] == 4 +import pytest + +from autoconf import conf +from autoconf.conf import with_config + + +@pytest.fixture(autouse=True) +def push_configs(files_directory): + conf.instance.push(files_directory / "config") + conf.instance.push(files_directory / "default") + + +@with_config("general", "output", "identifier_version", value=9) +def test_with_config(): + assert conf.instance["general"]["output"]["identifier_version"] == 9 + + +def test_config(): + assert conf.instance["general"]["output"]["identifier_version"] == 4 diff --git a/test_autoconf/test_default.py b/test_autoconf/test_default.py index 6dd6c37..821c6b9 100644 --- a/test_autoconf/test_default.py +++ b/test_autoconf/test_default.py @@ -1,88 +1,88 @@ -from autoconf.mock.mock_real import Redshift - - -def test_override_file(config): - hpc = config["general"]["hpc"] - - assert hpc["hpc_mode"] is False - assert hpc["default_field"] == "hello" - - -def test_logging_config(config, files_directory): - assert config.logging_config["name"] == "config" - - config.push(files_directory / "default") - assert config.logging_config["name"] == "default" - - -def test_push(config, files_directory): - assert len(config.configs) == 2 - assert config["general"]["hpc"]["hpc_mode"] is False - - config.push(files_directory / "default") - - assert len(config.configs) == 2 - assert config["general"]["hpc"]["hpc_mode"] is True - - config.push(files_directory / "config") - - assert len(config.configs) == 2 - assert config["general"]["hpc"]["hpc_mode"] is False - - -def test_keep_first(config, files_directory): - config.push(files_directory / "default", keep_first=True) - - assert config["general"]["hpc"]["hpc_mode"] is False - - -def test_override_in_directory(config): - superscript = config["text"]["label"]["superscript"] - - assert superscript["Galaxy"] == "g" - assert superscript["default_field"] == "label default" - - -def test_novel_directory(config): - assert config["default"]["other"]["section"]["key"] == "value" - - -def test_novel_file(config): - assert config["default_file"]["section"]["key"] == "file value" - - -def test_json(config): - assert ( - config.prior_config.for_class_and_suffix_path(Redshift, ["redshift"])[ - "upper_limit" - ] - == 3.0 - ) - assert ( - config.prior_config.for_class_and_suffix_path(Redshift, ["rodshift"])[ - "upper_limit" - ] - == 4.0 - ) - - -def test_embedded_yaml_default(config): - embedded_dict = config["embedded"]["first"]["first_a"] - - assert embedded_dict["first_a_a"] == "one" - assert embedded_dict["first_a_b"] == "two" - assert embedded_dict["first_a_c"] == "three" - - -def test_as_dict(config): - embedded_dict = config["embedded"]["first"]["first_a"] - - assert {**embedded_dict} - - -def test_mix_files(config): - embedded_dict = config["one"]["two"]["three"] - - assert embedded_dict["four"] == "five" - assert embedded_dict["six"] == "seven" - assert embedded_dict["eight"] == "nine" +from autoconf.mock.mock_real import Redshift + + +def test_override_file(config): + hpc = config["general"]["hpc"] + + assert hpc["hpc_mode"] is False + assert hpc["default_field"] == "hello" + + +def test_logging_config(config, files_directory): + assert config.logging_config["name"] == "config" + + config.push(files_directory / "default") + assert config.logging_config["name"] == "default" + + +def test_push(config, files_directory): + assert len(config.configs) == 2 + assert config["general"]["hpc"]["hpc_mode"] is False + + config.push(files_directory / "default") + + assert len(config.configs) == 2 + assert config["general"]["hpc"]["hpc_mode"] is True + + config.push(files_directory / "config") + + assert len(config.configs) == 2 + assert config["general"]["hpc"]["hpc_mode"] is False + + +def test_keep_first(config, files_directory): + config.push(files_directory / "default", keep_first=True) + + assert config["general"]["hpc"]["hpc_mode"] is False + + +def test_override_in_directory(config): + superscript = config["text"]["label"]["superscript"] + + assert superscript["Galaxy"] == "g" + assert superscript["default_field"] == "label default" + + +def test_novel_directory(config): + assert config["default"]["other"]["section"]["key"] == "value" + + +def test_novel_file(config): + assert config["default_file"]["section"]["key"] == "file value" + + +def test_json(config): + assert ( + config.prior_config.for_class_and_suffix_path(Redshift, ["redshift"])[ + "upper_limit" + ] + == 3.0 + ) + assert ( + config.prior_config.for_class_and_suffix_path(Redshift, ["rodshift"])[ + "upper_limit" + ] + == 4.0 + ) + + +def test_embedded_yaml_default(config): + embedded_dict = config["embedded"]["first"]["first_a"] + + assert embedded_dict["first_a_a"] == "one" + assert embedded_dict["first_a_b"] == "two" + assert embedded_dict["first_a_c"] == "three" + + +def test_as_dict(config): + embedded_dict = config["embedded"]["first"]["first_a"] + + assert {**embedded_dict} + + +def test_mix_files(config): + embedded_dict = config["one"]["two"]["three"] + + assert embedded_dict["four"] == "five" + assert embedded_dict["six"] == "seven" + assert embedded_dict["eight"] == "nine" diff --git a/test_autoconf/test_dictable.py b/test_autoconf/test_dictable.py index 5395542..d5e6378 100644 --- a/test_autoconf/test_dictable.py +++ b/test_autoconf/test_dictable.py @@ -1,201 +1,201 @@ -import json - -import numpy as np -import pytest -from pathlib import Path - -from autoconf.dictable import to_dict, from_dict, register_parser - - -@pytest.fixture(name="array_dict") -def make_array_dict(): - return {"array": [1.0], "dtype": "float64", "type": "ndarray"} - - -@pytest.fixture(name="array") -def make_array(): - return np.array([1.0]) - - -class ArrayImpl: - def __init__(self, array): - self.array = array - - @property - def dtype(self): - return self.array.dtype - - def tolist(self): - return self.array.tolist() - - def __array__(self): - return self.array - - -def test_array_impl(array): - assert to_dict(ArrayImpl(array)) == to_dict(array) - - -def test_array_as_dict(array_dict, array): - assert to_dict(array) == array_dict - - -def test_from_dict(array_dict, array): - assert from_dict(array_dict) == array - - -@pytest.mark.parametrize( - "array", - [ - np.array([True]), - np.array([[1.0]]), - np.array([[1.0, 2.0], [3.0, 4.0]]), - np.array([[1, 2], [3, 4]]), - ], -) -def test_multiple(array): - assert (from_dict(to_dict(array)) == array).all() - - -def test_as_json(array): - assert from_dict(json.loads(json.dumps(to_dict(array)))) == array - - -def test_with_type_attribute(): - float_dict = {"class_path": "float", "type": "type"} - assert to_dict(float) == float_dict - assert from_dict(float_dict) is float - - -def test_register_parser(): - register_parser("test", lambda x: x["value"]) - assert from_dict({"type": "test", "value": 1}) == 1 - - -def test_no_type(): - assert from_dict({"hi": "there"}) == {"hi": "there"} - - -def test_serialise_path(): - path = Path("/path/to/file.json") - path_dict = to_dict(path) - assert from_dict(path_dict) == path - - -class Parent: - def __init__(self, parent_arg): - self.parent_arg = parent_arg - - -class Child(Parent): - def __init__(self, child_arg, **kwargs): - super().__init__(**kwargs) - self.child_arg = child_arg - - -def test_serialise_kwargs(): - child = Child( - child_arg="child", - parent_arg="parent", - ) - child_dict = to_dict(child) - assert child_dict == { - "arguments": { - "child_arg": "child", - "parent_arg": "parent", - }, - "class_path": "test_autoconf.test_dictable.Child", - "type": "instance", - } - new_child = from_dict(child_dict) - assert new_child.child_arg == "child" - assert new_child.parent_arg == "parent" - - -class WithOptional: - def __init__(self, arg: int = 1): - self.arg = arg - - -def test_serialise_with_arg(): - assert to_dict(WithOptional()) == { - "arguments": {"arg": 1}, - "class_path": "test_autoconf.test_dictable.WithOptional", - "type": "instance", - } - - -def test_serialise_without_arg(): - assert to_dict(WithOptional(), filter_args=("arg",)) == { - "arguments": {}, - "class_path": "test_autoconf.test_dictable.WithOptional", - "type": "instance", - } - - -class C: - pass - - -def test_tuple(): - assert from_dict( - json.loads( - json.dumps( - to_dict( - (C, C), - ) - ) - ) - ) == (C, C) - - -def function(): - return 1 - - -def test_function(): - assert ( - from_dict( - json.loads( - json.dumps( - to_dict( - function, - ) - ) - ) - )() - == 1 - ) - - -def test_slice(): - s = slice(1, 2, 3) - - assert s == from_dict(to_dict(s)) - - -def test_int_64(): - i = np.int64(1) - result = from_dict(to_dict(i)) - - assert result == 1 - assert isinstance(result, np.int64) - - -def test_int64_slice(): - s = slice(np.int64(1), np.int64(2), np.int64(3)) - - assert s == from_dict( - json.loads( - json.dumps( - to_dict(s), - ) - ) - ) - - -def test_compound_key(): - d = {(1, 2): 1} - - string = json.dumps(to_dict(d)) - assert d == from_dict(json.loads(string)) +import json + +import numpy as np +import pytest +from pathlib import Path + +from autoconf.dictable import to_dict, from_dict, register_parser + + +@pytest.fixture(name="array_dict") +def make_array_dict(): + return {"array": [1.0], "dtype": "float64", "type": "ndarray"} + + +@pytest.fixture(name="array") +def make_array(): + return np.array([1.0]) + + +class ArrayImpl: + def __init__(self, array): + self.array = array + + @property + def dtype(self): + return self.array.dtype + + def tolist(self): + return self.array.tolist() + + def __array__(self): + return self.array + + +def test_array_impl(array): + assert to_dict(ArrayImpl(array)) == to_dict(array) + + +def test_array_as_dict(array_dict, array): + assert to_dict(array) == array_dict + + +def test_from_dict(array_dict, array): + assert from_dict(array_dict) == array + + +@pytest.mark.parametrize( + "array", + [ + np.array([True]), + np.array([[1.0]]), + np.array([[1.0, 2.0], [3.0, 4.0]]), + np.array([[1, 2], [3, 4]]), + ], +) +def test_multiple(array): + assert (from_dict(to_dict(array)) == array).all() + + +def test_as_json(array): + assert from_dict(json.loads(json.dumps(to_dict(array)))) == array + + +def test_with_type_attribute(): + float_dict = {"class_path": "float", "type": "type"} + assert to_dict(float) == float_dict + assert from_dict(float_dict) is float + + +def test_register_parser(): + register_parser("test", lambda x: x["value"]) + assert from_dict({"type": "test", "value": 1}) == 1 + + +def test_no_type(): + assert from_dict({"hi": "there"}) == {"hi": "there"} + + +def test_serialise_path(): + path = Path("/path/to/file.json") + path_dict = to_dict(path) + assert from_dict(path_dict) == path + + +class Parent: + def __init__(self, parent_arg): + self.parent_arg = parent_arg + + +class Child(Parent): + def __init__(self, child_arg, **kwargs): + super().__init__(**kwargs) + self.child_arg = child_arg + + +def test_serialise_kwargs(): + child = Child( + child_arg="child", + parent_arg="parent", + ) + child_dict = to_dict(child) + assert child_dict == { + "arguments": { + "child_arg": "child", + "parent_arg": "parent", + }, + "class_path": "test_autoconf.test_dictable.Child", + "type": "instance", + } + new_child = from_dict(child_dict) + assert new_child.child_arg == "child" + assert new_child.parent_arg == "parent" + + +class WithOptional: + def __init__(self, arg: int = 1): + self.arg = arg + + +def test_serialise_with_arg(): + assert to_dict(WithOptional()) == { + "arguments": {"arg": 1}, + "class_path": "test_autoconf.test_dictable.WithOptional", + "type": "instance", + } + + +def test_serialise_without_arg(): + assert to_dict(WithOptional(), filter_args=("arg",)) == { + "arguments": {}, + "class_path": "test_autoconf.test_dictable.WithOptional", + "type": "instance", + } + + +class C: + pass + + +def test_tuple(): + assert from_dict( + json.loads( + json.dumps( + to_dict( + (C, C), + ) + ) + ) + ) == (C, C) + + +def function(): + return 1 + + +def test_function(): + assert ( + from_dict( + json.loads( + json.dumps( + to_dict( + function, + ) + ) + ) + )() + == 1 + ) + + +def test_slice(): + s = slice(1, 2, 3) + + assert s == from_dict(to_dict(s)) + + +def test_int_64(): + i = np.int64(1) + result = from_dict(to_dict(i)) + + assert result == 1 + assert isinstance(result, np.int64) + + +def test_int64_slice(): + s = slice(np.int64(1), np.int64(2), np.int64(3)) + + assert s == from_dict( + json.loads( + json.dumps( + to_dict(s), + ) + ) + ) + + +def test_compound_key(): + d = {(1, 2): 1} + + string = json.dumps(to_dict(d)) + assert d == from_dict(json.loads(string)) diff --git a/test_autoconf/test_output_config.py b/test_autoconf/test_output_config.py index 59a7537..87adab1 100644 --- a/test_autoconf/test_output_config.py +++ b/test_autoconf/test_output_config.py @@ -1,46 +1,46 @@ -import pytest - -from autoconf import instance -from autoconf.conf import with_config -from autoconf.output import conditional_output - - -class OutputClass: - def __init__(self): - self.output_names = [] - - @conditional_output - def output_function(self, name): - self.output_names.append(name) - - -@pytest.fixture(name="output_class") -def make_output_class(): - return OutputClass() - - -@pytest.fixture(autouse=True) -def add_config(files_directory): - instance.push(files_directory / "config") - - -def test_output(output_class): - output_class.output_function("should_output") - assert output_class.output_names == ["should_output"] - - -def test_no_output(output_class): - output_class.output_function("should_not_output") - assert output_class.output_names == [] - - -@with_config("output", "default", value=True) -def test_default_true(output_class): - output_class.output_function("other") - assert output_class.output_names == ["other"] - - -@with_config("output", "default", value=False) -def test_default_false(output_class): - output_class.output_function("other") - assert output_class.output_names == [] +import pytest + +from autoconf import instance +from autoconf.conf import with_config +from autoconf.output import conditional_output + + +class OutputClass: + def __init__(self): + self.output_names = [] + + @conditional_output + def output_function(self, name): + self.output_names.append(name) + + +@pytest.fixture(name="output_class") +def make_output_class(): + return OutputClass() + + +@pytest.fixture(autouse=True) +def add_config(files_directory): + instance.push(files_directory / "config") + + +def test_output(output_class): + output_class.output_function("should_output") + assert output_class.output_names == ["should_output"] + + +def test_no_output(output_class): + output_class.output_function("should_not_output") + assert output_class.output_names == [] + + +@with_config("output", "default", value=True) +def test_default_true(output_class): + output_class.output_function("other") + assert output_class.output_names == ["other"] + + +@with_config("output", "default", value=False) +def test_default_false(output_class): + output_class.output_function("other") + assert output_class.output_names == [] From 10121db68b366f199e6785f5cf29fbc8da71404a Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 8 Apr 2025 17:47:40 +0100 Subject: [PATCH 2/2] convert to numpy array for fitsable output --- autoconf/fitsable.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/autoconf/fitsable.py b/autoconf/fitsable.py index 9901630..fae5784 100644 --- a/autoconf/fitsable.py +++ b/autoconf/fitsable.py @@ -109,6 +109,12 @@ def hdu_list_for_output_from( if ext_name_list is not None: header["EXTNAME"] = ext_name_list[i].upper() + # Convert from JAX + try: + values = np.array(values.array) + except AttributeError: + values = np.array(values) + values = flip_for_ds9_from(values) if i == 0: