diff --git a/pySC/configuration/bpm_system_conf.py b/pySC/configuration/bpm_system_conf.py index bc547172..f9a5452e 100644 --- a/pySC/configuration/bpm_system_conf.py +++ b/pySC/configuration/bpm_system_conf.py @@ -2,7 +2,6 @@ import logging from ..core.simulated_commissioning import SimulatedCommissioning -from ..core.bpm_system import BPM_FIELDS_TO_INITIALISE, BPM_FIELDS_TO_INITIALISE_ONES from .general import get_error, get_indices_and_names from .supports_conf import generate_element_misalignments @@ -87,11 +86,7 @@ def configure_bpms(SC: SimulatedCommissioning) -> None: SC.bpm_system.noise_tbt_x = np.array(bpms_tbt_noise) SC.bpm_system.noise_tbt_y = np.array(bpms_tbt_noise) - nbpm = len(bpms_indices) - for field in BPM_FIELDS_TO_INITIALISE: - setattr(SC.bpm_system, field, np.zeros(nbpm, dtype=float)) - for field in BPM_FIELDS_TO_INITIALISE_ONES: - setattr(SC.bpm_system, field, np.ones(nbpm, dtype=float)) + SC.bpm_system.initialize_empty_arrays() for index, bpm_category in zip(bpms_indices, bpms_categories): generate_element_misalignments(SC, index, bpms_conf[bpm_category]) diff --git a/pySC/core/bpm_system.py b/pySC/core/bpm_system.py index 281c58cb..6d8e80ce 100644 --- a/pySC/core/bpm_system.py +++ b/pySC/core/bpm_system.py @@ -12,9 +12,9 @@ def _rotation_matrix(a): return np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]]) -BPM_FIELDS_TO_INITIALISE = ['offsets_x', 'offsets_y', 'rolls', - 'bba_offsets_x', 'bba_offsets_y', - 'reference_x', 'reference_y'] +BPM_FIELDS_TO_INITIALISE_ZEROS = ['offsets_x', 'offsets_y', 'rolls', + 'bba_offsets_x', 'bba_offsets_y', + 'reference_x', 'reference_y'] # These fields are initialized to ones (not zeros) — multiplicative corrections BPM_FIELDS_TO_INITIALISE_ONES = ['gain_corrections_x', 'gain_corrections_y'] @@ -54,8 +54,19 @@ class BPMSystem(BaseModel, extra='forbid'): def initialize(self): if len(self.rolls) > 0: self.update_rot_matrices() + if len(self.indices): + self.initialize_empty_arrays() return self + def initialize_empty_arrays(self): + nbpm = len(self.indices) + for field in BPM_FIELDS_TO_INITIALISE_ZEROS: + if not len(getattr(self, field)): # array is empty + setattr(self, field, np.zeros(nbpm, dtype=float)) + for field in BPM_FIELDS_TO_INITIALISE_ONES: + if not len(getattr(self, field)): # array is empty + setattr(self, field, np.ones(nbpm, dtype=float)) + def update_rot_matrices(self): self._rot_matrices = _rotation_matrix(self.rolls) diff --git a/tests/core/test_bpm_system.py b/tests/core/test_bpm_system.py index c6013bfc..7f9d3984 100644 --- a/tests/core/test_bpm_system.py +++ b/tests/core/test_bpm_system.py @@ -69,6 +69,56 @@ def test_gain_corrections_default_ones(sc): np.testing.assert_array_equal(bpm.gain_corrections_y, np.ones(len(bpm.indices))) +# --------------------------------------------------------------------------- +# backwards-compatible field initialisation +# --------------------------------------------------------------------------- + +def test_bpm_system_initializes_missing_empty_arrays(): + """Older payloads with BPM indices get missing BPM arrays initialized.""" + bpm = BPMSystem( + indices=[10, 20, 30], + names=['BPM1', 'BPM2', 'BPM3'], + calibration_errors_x=[0.0, 0.0, 0.0], + calibration_errors_y=[0.0, 0.0, 0.0], + noise_co_x=[0.0, 0.0, 0.0], + noise_co_y=[0.0, 0.0, 0.0], + noise_tbt_x=[0.0, 0.0, 0.0], + noise_tbt_y=[0.0, 0.0, 0.0], + ) + + for field in ['offsets_x', 'offsets_y', 'rolls', 'bba_offsets_x', 'bba_offsets_y', + 'reference_x', 'reference_y']: + np.testing.assert_array_equal(getattr(bpm, field), np.zeros(3), err_msg=field) + + np.testing.assert_array_equal(bpm.gain_corrections_x, np.ones(3)) + np.testing.assert_array_equal(bpm.gain_corrections_y, np.ones(3)) + + +def test_bpm_system_initialize_empty_arrays_preserves_existing_values(): + """initialize_empty_arrays fills only empty fields and preserves supplied arrays.""" + bpm = BPMSystem( + indices=[10, 20], + offsets_x=[1.0, 2.0], + gain_corrections_x=[3.0, 4.0], + ) + + bpm.initialize_empty_arrays() + + np.testing.assert_array_equal(bpm.offsets_x, np.array([1.0, 2.0])) + np.testing.assert_array_equal(bpm.gain_corrections_x, np.array([3.0, 4.0])) + np.testing.assert_array_equal(bpm.offsets_y, np.zeros(2)) + np.testing.assert_array_equal(bpm.gain_corrections_y, np.ones(2)) + + +def test_bpm_system_validation_updates_rot_matrices_when_rolls_provided(): + """Validation still builds rotation matrices when rolls are present.""" + bpm = BPMSystem(indices=[10, 20], rolls=[0.0, 0.1]) + + assert bpm._rot_matrices is not None + assert bpm._rot_matrices.shape == (2, 2, 2) + np.testing.assert_array_almost_equal(bpm._rot_matrices[:, :, 0], np.eye(2)) + + # --------------------------------------------------------------------------- # einsum rotation contraction # ---------------------------------------------------------------------------