diff --git a/pyaml/configuration/cfg_dict.py b/pyaml/configuration/cfg_dict.py deleted file mode 100644 index d6bdcbd2..00000000 --- a/pyaml/configuration/cfg_dict.py +++ /dev/null @@ -1,49 +0,0 @@ -import json - -from pydantic import BaseModel, ConfigDict - -from .manager import ConfigurationManager - -# Define the main class name for this module -PYAMLCLASS = "CfgDict" - - -class ConfigModel(BaseModel): - """ - Configuration model for random dict - - Parameters - ---------- - cfg_dict : dict - The dict - """ - - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") - - cfg_dict: dict - - -class CfgDict(object): - """ - Class allowing to have a dict in a configuration of an object - """ - - def __init__(self, cfg: ConfigModel): - """ - Construct a ConfigDict - - Parameters - ---------- - config: dict - Configuration dict - """ - self._config = ConfigurationManager.strip_internal_metadata(cfg.cfg_dict) - - def get(self) -> dict: - """ - Returns config dict - """ - return self._config - - def __repr__(self): - return "%s(%s)" % (self.__class__.__name__, json.dumps(self._config)) diff --git a/pyaml/configuration/factory.py b/pyaml/configuration/factory.py index 7523c777..54f7af4f 100644 --- a/pyaml/configuration/factory.py +++ b/pyaml/configuration/factory.py @@ -2,6 +2,7 @@ import fnmatch import importlib from threading import Lock +from typing import TypedDict, get_type_hints from pydantic import ValidationError @@ -49,11 +50,15 @@ def handle_validation_error(self, e, type_str: str, location_str: str, field_loc # Discard pydantic stack trace raise PyAMLConfigException(f"{globalMessage} for object: '{type_str}' {location_str}") from None - def build_object(self, d: dict, ignore_external: bool = False): - """Build an object from the dict""" + def get_field_type(self, config_cls, field_name) -> type: + if config_cls is None: + return None + type_hints = get_type_hints(config_cls) + return type_hints[field_name] if field_name in type_hints else None - location = d.pop("__location__", None) - field_locations = d.pop("__fieldlocations__", None) + def get_infos(self, d, ignore_external: bool): + location = d["__location__"] if "__location__" in d else None + field_locations = d["__fieldlocations__"] if "__fieldlocations__" in d else None location_str = "" if location: file, line, col = location @@ -64,9 +69,9 @@ def build_object(self, d: dict, ignore_external: bool = False): if "type" not in d: raise PyAMLConfigException(f"No type specified for {str(type(d))}:{str(d)} {location_str}") - module_str = d.pop("type") - class_str = d.pop("class", None) - validation_class_str = d.pop("validation_class", "ConfigModel") + module_str = d["type"] + class_str = d["class"] if "class" in d else None + validation_class_str = d["validation_class"] if "validation_class" in d else "ConfigModel" # Import the module try: @@ -88,41 +93,57 @@ def build_object(self, d: dict, ignore_external: bool = False): f"PYAMLCLASS definition not found or class not specified in '{module_str}' {location_str}" ) + # Get the validation class + config_cls = getattr(module, validation_class_str, None) + if config_cls is None: + raise PyAMLConfigException(f"No validation class for '{module.__name__}.{class_str}' {location_str}") + + return (module, config_cls, class_str, field_locations, location_str) + + def build_object(self, d: dict, ignore_external: bool = False): + """Build an object from the dict""" + + (module, config_cls, class_str, field_locations, location_str) = self.get_infos(d, ignore_external) + + # Clean up dict + d.pop("__location__", None) + d.pop("__fieldlocations__", None) + d.pop("type") + d.pop("class", None) + d.pop("validation_class", None) + control_modes = d.pop("control_modes", None) if control_modes is None: # Immediate contruction - # Get the validation class - config_cls = getattr(module, validation_class_str, None) - if config_cls is None: - raise PyAMLConfigException(f"No validation class for '{module_str}.{class_str}' {location_str}") - # Validate the model try: cfg = config_cls.model_validate(d) except ValidationError as e: - self.handle_validation_error(e, module_str, location_str, field_locations) + self.handle_validation_error(e, module.__name__, location_str, field_locations) elem_cls = getattr(module, class_str, None) if elem_cls is None: - raise PyAMLConfigException(f"Unknown element class '{module_str}.{class_str}' {location_str}") + raise PyAMLConfigException(f"Unknown element class '{module.__name__}.{class_str}' {location_str}") # Construct and return the object try: obj = elem_cls(cfg) self.register_element(obj) except Exception as e: - raise PyAMLConfigException(f"{str(e)} when creating '{module_str}.{class_str}' {location_str}") from e + raise PyAMLConfigException( + f"{str(e)} when creating '{module.__name__}.{class_str}' {location_str}" + ) from e else: # Delayed construction element_name = d.pop("name", None) if element_name is None: raise PyAMLConfigException( - f"Name not speficied for element class '{module_str}.{class_str}' {location_str}" + f"Name not speficied for element class '{module.__name__}.{class_str}' {location_str}" ) - obj = UnboundElement(element_name, class_str, module_str, control_modes, d) + obj = UnboundElement(element_name, class_str, module.__name__, control_modes, d) return obj @@ -171,17 +192,19 @@ def depth_first_build(self, d, ignore_external: bool): return l elif isinstance(d, dict): - # Do not recurse CfgDict - if "type" in d: - if d["type"] == "pyaml.configuration.cfg_dict": - return self.build_object(d) + _, config_cls, *_ = self.get_infos(d, ignore_external) for key, value in d.items(): if not key == "__fieldlocations__": if isinstance(value, dict) or isinstance(value, list): - obj = self.depth_first_build(value, ignore_external) - # Replace the inner dict by the object itself - d[key] = obj + # Get the type of the field + fieldType = self.get_field_type(config_cls, key) + # Do not recurse dict defined in ConfigModel + # pydantic use TypedDict not usable with isinstance + if str(fieldType) != "": + obj = self.depth_first_build(value, ignore_external) + # Replace the inner dict by the object itself + d[key] = obj # We are now on leaf (no nested object), we can construct return self.build_object(d, ignore_external) diff --git a/tests/config/bad_conf.yml b/tests/config/bad_conf.yml index f204f1dd..178733b4 100644 --- a/tests/config/bad_conf.yml +++ b/tests/config/bad_conf.yml @@ -1,26 +1,24 @@ -type: pyaml.pyaml -instruments: - - type: pyaml.instrument - facility: ESRF - machine: sr - energy: 6e9 - simulators: - - type: pyalkqln # Error here - lattice: sr/lattices/ebs.mat - name: design - data_folder: /data/store - arrays: - - type: pyaml.arrays.magnet - name: HCORR - elements: - - SH1A-C01-H - - SH1A-C02-H - - type: pyaml.arrays.magnet - name: VCORR - elements: - - SH1A-C01-V - - SH1A-C02-V - devices: - - sr/quadrupoles/QF1AC01.yaml - - sr/correctors/SH1AC01.yaml - - sr/correctors/SH1AC02.yaml +type: pyaml.accelerator +facility: ESRF +machine: sr +energy: 6e9 +simulators: + - type: pyalkqln # Error here + lattice: sr/lattices/ebs.mat + name: design +data_folder: /data/store +arrays: + - type: pyaml.arrays.magnet + name: HCORR + elements: + - SH1A-C01-H + - SH1A-C02-H + - type: pyaml.arrays.magnet + name: VCORR + elements: + - SH1A-C01-V + - SH1A-C02-V +devices: + - sr/quadrupoles/QF1AC01.yaml + - sr/correctors/SH1AC01.yaml + - sr/correctors/SH1AC02.yaml diff --git a/tests/test_accelerator_load.py b/tests/test_accelerator_load.py index cdd5d2be..5e1e3d6a 100644 --- a/tests/test_accelerator_load.py +++ b/tests/test_accelerator_load.py @@ -4,7 +4,6 @@ from pyaml import PyAMLConfigException from pyaml.accelerator import Accelerator, ElementHolder from pyaml.common.element import __pyaml_repr__ -from pyaml.configuration.cfg_dict import CfgDict from pyaml.control.controlsystem import ControlSystemAdapter @@ -60,7 +59,7 @@ def test_accelerator_load_supports_remote_sources(http_config_server, ebs_lattic class MyControlSystemConfigModel(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") name: str - dconfig: CfgDict + dconfig: dict class MyControlSystem(ControlSystemAdapter): @@ -72,7 +71,7 @@ def name(self) -> str: return self._cfg.name def dconfig(self) -> dict: - return self._cfg.dconfig.get() + return self._cfg.dconfig def __repr__(self): return __pyaml_repr__(self) @@ -91,10 +90,7 @@ def test_config_dict(): "class": "MyControlSystem", "validation_class": "MyControlSystemConfigModel", "name": "live", - "dconfig": { - "type": "pyaml.configuration.cfg_dict", - "cfg_dict": {"prefix": "VA:", "info": {"param1": "Param1 value", "param2": 12345.0}}, - }, + "dconfig": {"prefix": "VA:", "info": {"param1": "Param1 value", "param2": 12345.0}}, } ], "devices": [], diff --git a/tests/test_factory.py b/tests/test_factory.py index 9949dbc8..e60afad5 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -34,7 +34,7 @@ def test_error_location(test_file): print(str(exc.value)) test_file_names = test_file.split("/") test_file_name = test_file_names[len(test_file_names) - 1] - assert f"{test_file_name} at line 8, column 9" in str(exc.value) + assert f"{test_file_name} at line 6, column 5" in str(exc.value) @pytest.mark.parametrize(