Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 0 additions & 49 deletions pyaml/configuration/cfg_dict.py

This file was deleted.

71 changes: 47 additions & 24 deletions pyaml/configuration/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import fnmatch
import importlib
from threading import Lock
from typing import TypedDict, get_type_hints

from pydantic import ValidationError

Expand Down Expand Up @@ -49,11 +50,15 @@ def handle_validation_error(self, e, type_str: str, location_str: str, field_loc
# Discard pydantic stack trace
raise PyAMLConfigException(f"{globalMessage} for object: '{type_str}' {location_str}") from None

def build_object(self, d: dict, ignore_external: bool = False):
"""Build an object from the dict"""
def get_field_type(self, config_cls, field_name) -> type:
if config_cls is None:
return None
type_hints = get_type_hints(config_cls)
return type_hints[field_name] if field_name in type_hints else None

location = d.pop("__location__", None)
field_locations = d.pop("__fieldlocations__", None)
def get_infos(self, d, ignore_external: bool):
location = d["__location__"] if "__location__" in d else None
field_locations = d["__fieldlocations__"] if "__fieldlocations__" in d else None
location_str = ""
if location:
file, line, col = location
Expand All @@ -64,9 +69,9 @@ def build_object(self, d: dict, ignore_external: bool = False):
if "type" not in d:
raise PyAMLConfigException(f"No type specified for {str(type(d))}:{str(d)} {location_str}")

module_str = d.pop("type")
class_str = d.pop("class", None)
validation_class_str = d.pop("validation_class", "ConfigModel")
module_str = d["type"]
class_str = d["class"] if "class" in d else None
validation_class_str = d["validation_class"] if "validation_class" in d else "ConfigModel"

# Import the module
try:
Expand All @@ -88,41 +93,57 @@ def build_object(self, d: dict, ignore_external: bool = False):
f"PYAMLCLASS definition not found or class not specified in '{module_str}' {location_str}"
)

# Get the validation class
config_cls = getattr(module, validation_class_str, None)
if config_cls is None:
raise PyAMLConfigException(f"No validation class for '{module.__name__}.{class_str}' {location_str}")

return (module, config_cls, class_str, field_locations, location_str)

def build_object(self, d: dict, ignore_external: bool = False):
"""Build an object from the dict"""

(module, config_cls, class_str, field_locations, location_str) = self.get_infos(d, ignore_external)

# Clean up dict
d.pop("__location__", None)
d.pop("__fieldlocations__", None)
d.pop("type")
d.pop("class", None)
d.pop("validation_class", None)

control_modes = d.pop("control_modes", None)

if control_modes is None:
# Immediate contruction

# Get the validation class
config_cls = getattr(module, validation_class_str, None)
if config_cls is None:
raise PyAMLConfigException(f"No validation class for '{module_str}.{class_str}' {location_str}")

# Validate the model
try:
cfg = config_cls.model_validate(d)
except ValidationError as e:
self.handle_validation_error(e, module_str, location_str, field_locations)
self.handle_validation_error(e, module.__name__, location_str, field_locations)

elem_cls = getattr(module, class_str, None)
if elem_cls is None:
raise PyAMLConfigException(f"Unknown element class '{module_str}.{class_str}' {location_str}")
raise PyAMLConfigException(f"Unknown element class '{module.__name__}.{class_str}' {location_str}")

# Construct and return the object
try:
obj = elem_cls(cfg)
self.register_element(obj)
except Exception as e:
raise PyAMLConfigException(f"{str(e)} when creating '{module_str}.{class_str}' {location_str}") from e
raise PyAMLConfigException(
f"{str(e)} when creating '{module.__name__}.{class_str}' {location_str}"
) from e

else:
# Delayed construction
element_name = d.pop("name", None)
if element_name is None:
raise PyAMLConfigException(
f"Name not speficied for element class '{module_str}.{class_str}' {location_str}"
f"Name not speficied for element class '{module.__name__}.{class_str}' {location_str}"
)
obj = UnboundElement(element_name, class_str, module_str, control_modes, d)
obj = UnboundElement(element_name, class_str, module.__name__, control_modes, d)

return obj

Expand Down Expand Up @@ -171,17 +192,19 @@ def depth_first_build(self, d, ignore_external: bool):
return l

elif isinstance(d, dict):
# Do not recurse CfgDict
if "type" in d:
if d["type"] == "pyaml.configuration.cfg_dict":
return self.build_object(d)
_, config_cls, *_ = self.get_infos(d, ignore_external)

for key, value in d.items():
if not key == "__fieldlocations__":
if isinstance(value, dict) or isinstance(value, list):
obj = self.depth_first_build(value, ignore_external)
# Replace the inner dict by the object itself
d[key] = obj
# Get the type of the field
fieldType = self.get_field_type(config_cls, key)
# Do not recurse dict defined in ConfigModel
# pydantic use TypedDict not usable with isinstance
if str(fieldType) != "<class 'dict'>":
obj = self.depth_first_build(value, ignore_external)
# Replace the inner dict by the object itself
d[key] = obj

# We are now on leaf (no nested object), we can construct
return self.build_object(d, ignore_external)
Expand Down
50 changes: 24 additions & 26 deletions tests/config/bad_conf.yml
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
type: pyaml.pyaml
instruments:
- type: pyaml.instrument
facility: ESRF
machine: sr
energy: 6e9
simulators:
- type: pyalkqln # Error here
lattice: sr/lattices/ebs.mat
name: design
data_folder: /data/store
arrays:
- type: pyaml.arrays.magnet
name: HCORR
elements:
- SH1A-C01-H
- SH1A-C02-H
- type: pyaml.arrays.magnet
name: VCORR
elements:
- SH1A-C01-V
- SH1A-C02-V
devices:
- sr/quadrupoles/QF1AC01.yaml
- sr/correctors/SH1AC01.yaml
- sr/correctors/SH1AC02.yaml
type: pyaml.accelerator
facility: ESRF
machine: sr
energy: 6e9
simulators:
- type: pyalkqln # Error here
lattice: sr/lattices/ebs.mat
name: design
data_folder: /data/store
arrays:
- type: pyaml.arrays.magnet
name: HCORR
elements:
- SH1A-C01-H
- SH1A-C02-H
- type: pyaml.arrays.magnet
name: VCORR
elements:
- SH1A-C01-V
- SH1A-C02-V
devices:
- sr/quadrupoles/QF1AC01.yaml
- sr/correctors/SH1AC01.yaml
- sr/correctors/SH1AC02.yaml
10 changes: 3 additions & 7 deletions tests/test_accelerator_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pyaml import PyAMLConfigException
from pyaml.accelerator import Accelerator, ElementHolder
from pyaml.common.element import __pyaml_repr__
from pyaml.configuration.cfg_dict import CfgDict
from pyaml.control.controlsystem import ControlSystemAdapter


Expand Down Expand Up @@ -60,7 +59,7 @@ def test_accelerator_load_supports_remote_sources(http_config_server, ebs_lattic
class MyControlSystemConfigModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
name: str
dconfig: CfgDict
dconfig: dict


class MyControlSystem(ControlSystemAdapter):
Expand All @@ -72,7 +71,7 @@ def name(self) -> str:
return self._cfg.name

def dconfig(self) -> dict:
return self._cfg.dconfig.get()
return self._cfg.dconfig

def __repr__(self):
return __pyaml_repr__(self)
Expand All @@ -91,10 +90,7 @@ def test_config_dict():
"class": "MyControlSystem",
"validation_class": "MyControlSystemConfigModel",
"name": "live",
"dconfig": {
"type": "pyaml.configuration.cfg_dict",
"cfg_dict": {"prefix": "VA:", "info": {"param1": "Param1 value", "param2": 12345.0}},
},
"dconfig": {"prefix": "VA:", "info": {"param1": "Param1 value", "param2": 12345.0}},
}
],
"devices": [],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_error_location(test_file):
print(str(exc.value))
test_file_names = test_file.split("/")
test_file_name = test_file_names[len(test_file_names) - 1]
assert f"{test_file_name} at line 8, column 9" in str(exc.value)
assert f"{test_file_name} at line 6, column 5" in str(exc.value)


@pytest.mark.parametrize(
Expand Down
Loading