11# PyAML config file loader
22import logging
33import json
4- from typing import Union
4+ from typing import Union , TYPE_CHECKING
55from pathlib import Path
66import io
7+ import os
78
89import yaml
910from yaml .loader import SafeLoader
11+ from yaml import CLoader
1012from yaml .constructor import ConstructorError
1113import collections .abc
1214
13- from . import get_root_folder
1415from .. import PyAMLException
16+ from pyaml .configuration .factory import Factory
17+
18+ if TYPE_CHECKING :
19+ from pyaml .accelerator import Accelerator
20+
1521
1622logger = logging .getLogger (__name__ )
1723
1824accepted_suffixes = [".yaml" , ".yml" , ".json" ]
1925
26+
27+ ROOT = {"path" : Path .cwd ().resolve ()}
28+
29+
30+ def set_root_folder (path : Union [str , Path ]):
31+ """
32+ Set the root path for configuration files.
33+ """
34+ ROOT ["path" ] = Path (path )
35+
36+
37+ def get_root_folder () -> Path :
38+ """
39+ Get the root path for configuration files.
40+ """
41+ return ROOT ["path" ]
42+
2043class PyAMLConfigCyclingException (PyAMLException ):
2144
2245 def __init__ (self , error_filename :str , path_stack :list [Path ]):
@@ -25,12 +48,26 @@ def __init__(self, error_filename:str, path_stack:list[Path]):
2548 super ().__init__ (f"Circular file inclusion of { error_filename } . File list before reaching it: { parent_file_stack } " )
2649 pass
2750
28- def load (filename :str , paths_stack :list = None ) -> Union [dict ,list ]:
51+ def load_accelerator (filename :str , use_fast_loader :bool = False ) -> "Accelerator" :
52+ """ Load an accelerator from file."""
53+
54+ # Asume that all files are referenced from folder where main AML file is stored
55+ if not os .path .exists (filename ):
56+ raise PyAMLException (f"{ filename } file not found" )
57+ rootfolder = os .path .abspath (os .path .dirname (filename ))
58+ set_root_folder (rootfolder )
59+ config_dict = load (os .path .basename (filename ),None ,use_fast_loader )
60+ aml = Factory .depth_first_build (config_dict )
61+
62+ Factory .clear ()
63+ return aml
64+
65+ def load (filename :str , paths_stack :list = None , use_fast_loader :bool = False ) -> Union [dict ,list ]:
2966 """Load recursively a configuration setup"""
3067 if filename .endswith (".yaml" ) or filename .endswith (".yml" ):
31- l = YAMLLoader (filename , paths_stack )
68+ l = YAMLLoader (filename , paths_stack , use_fast_loader )
3269 elif filename .endswith (".json" ):
33- l = JSONLoader (filename , paths_stack )
70+ l = JSONLoader (filename , paths_stack , use_fast_loader )
3471 else :
3572 raise PyAMLException (f"{ filename } File format not supported (only .yaml .yml or .json)" )
3673 return l .load ()
@@ -58,7 +95,7 @@ def expand_dict(self,d:dict):
5895 for key , value in d .items ():
5996 try :
6097 if hasToExpand (value ):
61- d [key ] = load (value , self .files_stack )
98+ d [key ] = load (value , self .files_stack , self . use_fast_loader )
6299 else :
63100 self .expand (value )
64101 except PyAMLConfigCyclingException as pyaml_ex :
@@ -119,21 +156,24 @@ def construct_mapping(self, node, deep=False):
119156
120157# YAML loader
121158class YAMLLoader (Loader ):
122- def __init__ (self , filename : str , parent_paths_stack :list ):
159+ def __init__ (self , filename : str , parent_paths_stack :list , use_fast_loader : bool ):
123160 super ().__init__ (filename , parent_paths_stack )
161+ self ._loader = SafeLineLoader if not use_fast_loader else CLoader
162+ self .use_fast_loader = use_fast_loader
124163
125164 def load (self ) -> Union [dict ,list ]:
126165 logger .log (logging .DEBUG , f"Loading YAML file '{ self .path } '" )
127166 with open (self .path ) as file :
128167 try :
129- return self .expand (yaml .load (file ,Loader = SafeLineLoader ))
168+ return self .expand (yaml .load (file ,Loader = self . _loader ))
130169 except yaml .YAMLError as e :
131170 raise PyAMLException (str (self .path ) + ": " + str (e )) from e
132171
133172# JSON loader
134173class JSONLoader (Loader ):
135- def __init__ (self , filename : str , parent_paths_stack :list ):
174+ def __init__ (self , filename : str , parent_paths_stack :list , use_fast_loader : bool ):
136175 super ().__init__ (filename , parent_paths_stack )
176+ self .use_fast_loader = False
137177
138178 def load (self ) -> Union [dict ,list ]:
139179 logger .log (logging .DEBUG , f"Loading JSON file '{ self .path } '" )
0 commit comments