From 8e5a36ef3e4467244e0d664e57ee6090d9d066fa Mon Sep 17 00:00:00 2001 From: kewh5868 Date: Thu, 7 May 2026 16:42:51 -0600 Subject: [PATCH 1/7] feat(plotting): add reusable plot editor controls Add shared plotting helpers for Igor-style labels, editable line plots, heatmaps, and stacked histograms. Wire the heatmap editor into cluster dynamics views and vectorize bond/angle measurements used by representative-style plotting workflows. --- src/saxshell/bondanalysis/bondanalyzer.py | 210 +- src/saxshell/clusterdynamics/ui/plot_panel.py | 569 +++++- .../clusterdynamicsml/ui/main_window.py | 4 +- .../clusterdynamicsml/ui/plot_panel.py | 3 +- src/saxshell/plotting/__init__.py | 32 + src/saxshell/plotting/igor_inline.py | 394 ++++ src/saxshell/plotting/labels.py | 8 + src/saxshell/plotting/line_plot_editor.py | 819 ++++++++ src/saxshell/plotting/plot_editor.py | 1767 +++++++++++++++++ src/saxshell/plotting/stacked_histogram.py | 546 +++++ src/saxshell/saxs/ui/dream_tab.py | 3 +- tests/test_clusterdynamics.py | 370 ++++ tests/test_clusterdynamicsml.py | 141 ++ 13 files changed, 4753 insertions(+), 113 deletions(-) create mode 100644 src/saxshell/plotting/igor_inline.py create mode 100644 src/saxshell/plotting/labels.py create mode 100644 src/saxshell/plotting/line_plot_editor.py create mode 100644 src/saxshell/plotting/plot_editor.py create mode 100644 src/saxshell/plotting/stacked_histogram.py diff --git a/src/saxshell/bondanalysis/bondanalyzer.py b/src/saxshell/bondanalysis/bondanalyzer.py index e0812b3..ed00e65 100644 --- a/src/saxshell/bondanalysis/bondanalyzer.py +++ b/src/saxshell/bondanalysis/bondanalyzer.py @@ -2,6 +2,7 @@ import math import re +from collections import defaultdict from dataclasses import dataclass from pathlib import Path from typing import Iterable @@ -184,88 +185,153 @@ def measure_atoms( ) -> tuple[ dict[BondPairDefinition, list[float]], dict[AngleTripletDefinition, list[float]], + ]: + if not atoms: + return ( + {definition: [] for definition 
in self.bond_pairs}, + {definition: [] for definition in self.angle_triplets}, + ) + coords = np.asarray( + [[atom.x, atom.y, atom.z] for atom in atoms], dtype=float + ) + elements = [atom.element for atom in atoms] + return self.measure_structure_data(coords, elements) + + def measure_structure_data( + self, + coordinates: np.ndarray, + elements: Iterable[str], + ) -> tuple[ + dict[BondPairDefinition, list[float]], + dict[AngleTripletDefinition, list[float]], ]: bond_values = {definition: [] for definition in self.bond_pairs} angle_values = {definition: [] for definition in self.angle_triplets} - if not atoms: + + coords = np.asarray(coordinates, dtype=float) + normalized_elements = tuple( + _normalized_element_symbol(element) for element in elements + ) + if coords.size == 0 or not normalized_elements: return bond_values, angle_values + if coords.ndim != 2 or coords.shape[0] != len(normalized_elements): + raise ValueError( + "Coordinates and element symbols must describe the same atoms." 
+ ) - coords = np.array([[atom.x, atom.y, atom.z] for atom in atoms]) - elements = [atom.element for atom in atoms] tree = cKDTree(coords) + element_array = np.asarray(normalized_elements, dtype=object) + bond_groups: defaultdict[float, list[BondPairDefinition]] = ( + defaultdict(list) + ) for definition in self.bond_pairs: - expected = definition.normalized_pair - for index1, index2 in tree.query_pairs(definition.cutoff_angstrom): - actual = tuple(sorted((elements[index1], elements[index2]))) - if actual != expected: - continue - distance = float( - np.linalg.norm(coords[index1] - coords[index2]) - ) - bond_values[definition].append(distance) + bond_groups[float(definition.cutoff_angstrom)].append(definition) + for cutoff, definitions in bond_groups.items(): + raw_pairs = tree.query_pairs(cutoff) + if not raw_pairs: + continue + pair_indices = np.asarray(list(raw_pairs), dtype=int) + if pair_indices.size == 0: + continue + pair_indices = pair_indices.reshape(-1, 2) + left_elements = element_array[pair_indices[:, 0]] + right_elements = element_array[pair_indices[:, 1]] + distances = np.linalg.norm( + coords[pair_indices[:, 0]] - coords[pair_indices[:, 1]], + axis=1, + ) + for definition in definitions: + pair_a, pair_b = definition.normalized_pair + if pair_a == pair_b: + mask = (left_elements == pair_a) & ( + right_elements == pair_b + ) + else: + mask = ( + (left_elements == pair_a) & (right_elements == pair_b) + ) | ( + (left_elements == pair_b) & (right_elements == pair_a) + ) + if np.any(mask): + bond_values[definition].extend( + distances[mask].astype(float).tolist() + ) + angle_groups: defaultdict[ + tuple[str, float], list[AngleTripletDefinition] + ] = defaultdict(list) for definition in self.angle_triplets: - max_cutoff = max( - definition.cutoff1_angstrom, - definition.cutoff2_angstrom, - ) - for center_index, element in enumerate(elements): - if element != definition.vertex: + angle_groups[ + ( + definition.vertex, + max( + 
float(definition.cutoff1_angstrom), + float(definition.cutoff2_angstrom), + ), + ) + ].append(definition) + for (vertex, max_cutoff), definitions in angle_groups.items(): + center_indices = np.flatnonzero(element_array == vertex) + if center_indices.size == 0: + continue + for center_index in center_indices.tolist(): + neighbor_indices = np.asarray( + tree.query_ball_point(coords[center_index], r=max_cutoff), + dtype=int, + ) + if neighbor_indices.size == 0: continue - neighbor_indices = [ - index - for index in tree.query_ball_point( - coords[center_index], - r=max_cutoff, - ) - if index != center_index - ] - arm1_candidates = [ - index - for index in neighbor_indices - if elements[index] == definition.arm1 - and self._distance(coords, center_index, index) - <= definition.cutoff1_angstrom + neighbor_indices = neighbor_indices[ + neighbor_indices != center_index ] - arm2_candidates = [ - index - for index in neighbor_indices - if elements[index] == definition.arm2 - and self._distance(coords, center_index, index) - <= definition.cutoff2_angstrom - ] - if not arm1_candidates or not arm2_candidates: + if neighbor_indices.size == 0: continue - - if definition.arm1 == definition.arm2: - seen_pairs: set[tuple[int, int]] = set() - for arm1_index in arm1_candidates: - for arm2_index in arm2_candidates: - if arm1_index == arm2_index: - continue - pair = tuple(sorted((arm1_index, arm2_index))) - if pair in seen_pairs: + neighbor_vectors = ( + coords[neighbor_indices] - coords[center_index] + ) + neighbor_distances = np.linalg.norm(neighbor_vectors, axis=1) + valid_mask = neighbor_distances > 0.0 + if not np.any(valid_mask): + continue + neighbor_elements = element_array[neighbor_indices] + unit_vectors = np.zeros_like(neighbor_vectors) + unit_vectors[valid_mask] = ( + neighbor_vectors[valid_mask] + / neighbor_distances[valid_mask, np.newaxis] + ) + for definition in definitions: + arm1_positions = np.flatnonzero( + (neighbor_elements == definition.arm1) + & 
(neighbor_distances <= definition.cutoff1_angstrom) + & valid_mask + ) + arm2_positions = np.flatnonzero( + (neighbor_elements == definition.arm2) + & (neighbor_distances <= definition.cutoff2_angstrom) + & valid_mask + ) + if arm1_positions.size == 0 or arm2_positions.size == 0: + continue + if definition.arm1 == definition.arm2: + for offset, arm1_position in enumerate( + arm1_positions[:-1] + ): + other_positions = arm1_positions[offset + 1 :] + if other_positions.size == 0: continue - seen_pairs.add(pair) - angle = self._angle_between( - coords[arm1_index] - coords[center_index], - coords[arm2_index] - coords[center_index], + angles = self._angles_from_unit_vectors( + unit_vectors[arm1_position], + unit_vectors[other_positions], ) - if angle is not None: - angle_values[definition].append(angle) - continue - - for arm1_index in arm1_candidates: - for arm2_index in arm2_candidates: - if arm1_index == arm2_index: - continue - angle = self._angle_between( - coords[arm1_index] - coords[center_index], - coords[arm2_index] - coords[center_index], + angle_values[definition].extend(angles) + continue + for arm1_position in arm1_positions.tolist(): + angles = self._angles_from_unit_vectors( + unit_vectors[arm1_position], + unit_vectors[arm2_positions], ) - if angle is not None: - angle_values[definition].append(angle) + angle_values[definition].extend(angles) return bond_values, angle_values @@ -337,6 +403,18 @@ def _angle_between( cosine = float(np.dot(vector1, vector2) / (norm1 * norm2)) return float(math.degrees(math.acos(np.clip(cosine, -1.0, 1.0)))) + @staticmethod + def _angles_from_unit_vectors( + vector: np.ndarray, + other_vectors: np.ndarray, + ) -> list[float]: + vectors = np.asarray(other_vectors, dtype=float) + if vectors.size == 0: + return [] + dots = np.clip(vectors @ np.asarray(vector, dtype=float), -1.0, 1.0) + angles = np.degrees(np.arccos(dots)) + return np.asarray(angles, dtype=float).tolist() + @staticmethod def _dedupe_bond_pairs( bond_pairs: 
Iterable[BondPairDefinition], diff --git a/src/saxshell/clusterdynamics/ui/plot_panel.py b/src/saxshell/clusterdynamics/ui/plot_panel.py index 0392504..b856d06 100644 --- a/src/saxshell/clusterdynamics/ui/plot_panel.py +++ b/src/saxshell/clusterdynamics/ui/plot_panel.py @@ -1,6 +1,8 @@ from __future__ import annotations import math +from collections.abc import Mapping +from dataclasses import replace import numpy as np from matplotlib import colormaps @@ -10,16 +12,31 @@ NavigationToolbar2QT as NavigationToolbar, ) from matplotlib.figure import Figure +from matplotlib.ticker import MaxNLocator from PySide6.QtWidgets import ( QComboBox, QDoubleSpinBox, QHBoxLayout, QLabel, + QPushButton, + QSizePolicy, QVBoxLayout, QWidget, ) from saxshell.clusterdynamics.workflow import ClusterDynamicsResult +from saxshell.plotting.igor_inline import ( + apply_igor_inline_text_artist, + igor_inline_to_mathtext, + prepare_igor_inline_segments, +) +from saxshell.plotting.plot_editor import ( + HeatmapPlotDefaults, + HeatmapPlotEditorControls, + HeatmapPlotSettings, + PlotEditorWindow, +) +from saxshell.saxs.stoichiometry import format_stoich_for_axis PLOT_COLORMAPS = ("viridis", "magma", "cividis", "inferno", "turbo") DISPLAY_MODE_LABELS = { @@ -48,17 +65,41 @@ class ClusterDynamicsPlotPanel(QWidget): """Interactive time-binned cluster heatmap panel.""" - def __init__(self, parent: QWidget | None = None) -> None: + _MIN_PANEL_HEIGHT = 420 + _MIN_CANVAS_HEIGHT = 300 + + def __init__( + self, + parent: QWidget | None = None, + *, + enable_plot_editor: bool = False, + ) -> None: super().__init__(parent) + self._enable_plot_editor = bool(enable_plot_editor) self._result: ClusterDynamicsResult | None = None + self._plot_settings = HeatmapPlotSettings() + self._plot_editor_window: PlotEditorWindow | None = None + self._plot_editor_controls: HeatmapPlotEditorControls | None = None + self.plot_editor_button: QPushButton | None = None self._build_ui() self.refresh_plot() def _build_ui(self) 
-> None: + self.setMinimumHeight(self._MIN_PANEL_HEIGHT) root = QVBoxLayout(self) root.setContentsMargins(0, 0, 0, 0) root.setSpacing(8) + if self._enable_plot_editor: + editor_row = QHBoxLayout() + editor_row.setContentsMargins(0, 0, 0, 0) + editor_row.setSpacing(8) + self.plot_editor_button = QPushButton("Open Plot Editor") + self.plot_editor_button.clicked.connect(self.open_plot_editor) + editor_row.addWidget(self.plot_editor_button) + editor_row.addStretch(1) + root.addLayout(editor_row) + controls_widget = QWidget() controls = QHBoxLayout(controls_widget) controls.setContentsMargins(0, 0, 0, 0) @@ -128,6 +169,11 @@ def _build_ui(self) -> None: self.figure = Figure(figsize=(9.2, 7.2)) self.canvas = FigureCanvas(self.figure) + self.canvas.setMinimumHeight(self._MIN_CANVAS_HEIGHT) + self.canvas.setSizePolicy( + QSizePolicy.Policy.Expanding, + QSizePolicy.Policy.Expanding, + ) root.addWidget(NavigationToolbar(self.canvas, self)) root.addWidget(self.canvas, stretch=1) @@ -141,34 +187,205 @@ def set_result(self, result: ClusterDynamicsResult | None) -> None: self.overlay_combo.setCurrentIndex(0) self.refresh_plot() + def open_plot_editor(self) -> None: + if not self._enable_plot_editor: + return + if self._plot_editor_window is not None: + self._plot_editor_window.show() + self._plot_editor_window.raise_() + self._plot_editor_window.activateWindow() + self._plot_editor_window.refresh_preview() + return + + defaults = self._current_plot_defaults() + self._plot_editor_controls = HeatmapPlotEditorControls( + settings=self._plot_settings, + defaults=defaults, + parent=self, + ) + self._plot_editor_controls.settings_changed.connect(self.refresh_plot) + self._plot_editor_controls.x_axis_unit_changed.connect( + self._on_plot_editor_x_axis_unit_changed + ) + self._plot_editor_controls.colormap_changed.connect( + self._on_plot_editor_colormap_changed + ) + self._plot_editor_window = PlotEditorWindow( + window_title="Cluster Dynamics Colormap Editor", + 
controls_widget=self._plot_editor_controls, + render_preview=self._render_plot_figure, + pickle_state_provider=self._plot_editor_pickle_state, + apply_loaded_pickle_state=self._apply_loaded_plot_editor_pickle_state, + parent=self, + ) + self._plot_editor_window.closed.connect(self._on_plot_editor_closed) + self._plot_editor_window.refresh_preview() + self._plot_editor_window.show() + self._plot_editor_window.raise_() + self._plot_editor_window.activateWindow() + def refresh_plot(self) -> None: - self.figure.clear() + self._render_plot_figure(self.figure) + self.canvas.draw_idle() + if self._plot_editor_window is not None: + self._plot_editor_window.refresh_preview() + + def _on_plot_editor_closed(self) -> None: + self._plot_editor_window = None + self._plot_editor_controls = None + + def _on_plot_editor_colormap_changed(self, colormap_name: str) -> None: + index = self.colormap_combo.findData(colormap_name) + if index < 0 or index == self.colormap_combo.currentIndex(): + return + self.colormap_combo.setCurrentIndex(index) + + def _on_plot_editor_x_axis_unit_changed(self, unit_name: str) -> None: + index = self.time_unit_combo.findData(unit_name) + if index < 0 or index == self.time_unit_combo.currentIndex(): + return + self.time_unit_combo.setCurrentIndex(index) + + def _sync_plot_editor_defaults( + self, defaults: HeatmapPlotDefaults + ) -> None: + if ( + self._plot_editor_controls is not None + and self._plot_editor_controls.needs_default_sync(defaults) + ): + self._plot_editor_controls.sync_defaults(defaults) + + def _plot_editor_pickle_state(self) -> dict[str, object]: + return { + "plot_editor_state": { + "kind": "heatmap_plot_editor_state", + "version": 1, + "heatmap_settings": self._plot_settings.to_dict(), + "panel_state": { + "display_mode": self._display_mode(), + "time_unit": str(self.time_unit_combo.currentData() or ""), + "colormap_name": str( + self.colormap_combo.currentData() or "" + ), + "lower_quantile": float(self.lower_quantile_spin.value()), + 
"upper_quantile": float(self.upper_quantile_spin.value()), + "overlay_name": self.overlay_combo.currentData(), + }, + } + } + + def _apply_loaded_plot_editor_pickle_state( + self, + payload: Mapping[str, object], + ) -> bool: + editor_state = payload.get("plot_editor_state") + if not isinstance(editor_state, Mapping): + return False + if str(editor_state.get("kind")) != "heatmap_plot_editor_state": + return False + + heatmap_settings = editor_state.get("heatmap_settings") + if isinstance(heatmap_settings, Mapping): + self._plot_settings.update_from_dict(heatmap_settings) + + panel_state = editor_state.get("panel_state") + if isinstance(panel_state, Mapping): + self._apply_panel_state_from_pickle(panel_state) + + defaults = self._current_plot_defaults() + self._plot_settings.sync_labels( + defaults.raw_cluster_labels, + default_label_entries=defaults.default_label_entries, + ) + if self._plot_editor_controls is not None: + self._plot_editor_controls.sync_defaults(defaults) + self.refresh_plot() + return True + + def _apply_panel_state_from_pickle( + self, + panel_state: Mapping[str, object], + ) -> None: + self.display_mode_combo.blockSignals(True) + self.time_unit_combo.blockSignals(True) + self.colormap_combo.blockSignals(True) + self.lower_quantile_spin.blockSignals(True) + self.upper_quantile_spin.blockSignals(True) + self.overlay_combo.blockSignals(True) + try: + self._set_combo_data_if_present( + self.display_mode_combo, + panel_state.get("display_mode"), + ) + self._set_combo_data_if_present( + self.time_unit_combo, + panel_state.get("time_unit"), + ) + self._set_combo_data_if_present( + self.colormap_combo, + panel_state.get("colormap_name"), + ) + if "lower_quantile" in panel_state: + self.lower_quantile_spin.setValue( + float(panel_state["lower_quantile"]) + ) + if "upper_quantile" in panel_state: + self.upper_quantile_spin.setValue( + float(panel_state["upper_quantile"]) + ) + self._ensure_valid_quantiles() + self._set_combo_data_if_present( + 
self.overlay_combo, + panel_state.get("overlay_name"), + ) + finally: + self.display_mode_combo.blockSignals(False) + self.time_unit_combo.blockSignals(False) + self.colormap_combo.blockSignals(False) + self.lower_quantile_spin.blockSignals(False) + self.upper_quantile_spin.blockSignals(False) + self.overlay_combo.blockSignals(False) + + @staticmethod + def _set_combo_data_if_present(combo: QComboBox, value: object) -> None: + index = combo.findData(value) + if index >= 0: + combo.setCurrentIndex(index) + + def _render_plot_figure(self, figure: Figure) -> None: + defaults = self._current_plot_defaults() + self._plot_settings.sync_labels( + defaults.raw_cluster_labels, + default_label_entries=defaults.default_label_entries, + ) + + figure.clear() if self._result is None: - axis = self.figure.add_subplot(111) + self._sync_plot_editor_defaults(defaults) + axis = figure.add_subplot(111) self._draw_placeholder( axis, "Run the analysis to render the cluster-distribution heatmap.", ) - self.canvas.draw_idle() return if self._result.bin_count == 0: - axis = self.figure.add_subplot(111) + self._sync_plot_editor_defaults(defaults) + axis = figure.add_subplot(111) self._draw_placeholder( axis, "No time bins are available for the current selection.", ) - self.canvas.draw_idle() return matrix = self._result.matrix(self._display_mode()) if matrix.size == 0 or len(self._result.cluster_labels) == 0: - axis = self.figure.add_subplot(111) + self._sync_plot_editor_defaults(defaults) + axis = figure.add_subplot(111) self._draw_placeholder( axis, "No clusters were detected in the selected time window.", ) - self.canvas.draw_idle() return overlay_name = self.overlay_combo.currentData() @@ -177,65 +394,186 @@ def refresh_plot(self) -> None: ) if show_overlay: - grid = self.figure.add_gridspec( + grid = figure.add_gridspec( 2, 1, height_ratios=[4.0, 1.2], hspace=0.08, ) - heatmap_axis = self.figure.add_subplot(grid[0, 0]) - overlay_axis = self.figure.add_subplot( + heatmap_axis = 
figure.add_subplot(grid[0, 0]) + overlay_axis = figure.add_subplot( grid[1, 0], sharex=heatmap_axis, ) else: - heatmap_axis = self.figure.add_subplot(111) + heatmap_axis = figure.add_subplot(111) overlay_axis = None time_unit = self.time_unit_combo.currentData() time_edges = self._result.time_edges(time_unit) cmap = colormaps[self.colormap_combo.currentData()] - norm = self._quantile_norm(matrix) + ordered_labels = self._plot_settings.ordered_raw_labels(defaults) + ordered_index_lookup = { + str(label): index + for index, label in enumerate(self._result.cluster_labels) + } + ordered_indices = [ + ordered_index_lookup[label] + for label in ordered_labels + if label in ordered_index_lookup + ] + display_matrix = ( + matrix + if not ordered_indices + else np.asarray(matrix, dtype=float)[ordered_indices, :] + ) + auto_vmin, auto_vmax = self._auto_color_limits(display_matrix) + defaults = replace( + defaults, + auto_color_limit_min=auto_vmin, + auto_color_limit_max=auto_vmax, + ) + self._sync_plot_editor_defaults(defaults) + norm = self._heatmap_norm(defaults) image = heatmap_axis.imshow( - matrix, - aspect="auto", + display_matrix, + aspect=self._resolved_aspect(), origin="lower", interpolation="nearest", extent=( float(time_edges[0]), float(time_edges[-1]), -0.5, - len(self._result.cluster_labels) - 0.5, + len(ordered_labels) - 0.5, ), cmap=cmap, norm=norm, ) - colorbar = self.figure.colorbar(image, ax=heatmap_axis, pad=0.02) - colorbar.set_label(DISPLAY_MODE_COLORBAR_LABELS[self._display_mode()]) + colorbar = figure.colorbar(image, ax=heatmap_axis, pad=0.02) + colorbar.set_label( + self._plot_settings.resolve_colorbar_label(defaults), + fontsize=self._plot_settings.axis_label_font_size, + **self._font_kwargs(), + ) + apply_igor_inline_text_artist( + colorbar.ax.yaxis.label, + self._plot_settings.resolve_colorbar_label(defaults), + default_font_size=self._plot_settings.axis_label_font_size, + gid_prefix="heatmap-colorbar-label", + target_axes=colorbar.ax, + ) + 
colorbar.ax.tick_params( + labelsize=self._plot_settings.tick_label_font_size + ) + for tick_label in colorbar.ax.get_yticklabels(): + self._apply_font_to_text(tick_label) + label_count = len(ordered_labels) label_step = max( 1, - int(math.ceil(len(self._result.cluster_labels) / 24)), - ) - tick_positions = np.arange( - 0, len(self._result.cluster_labels), label_step + int( + math.ceil( + label_count / max(self._plot_settings.max_y_ticks, 1) + ) + ), ) + tick_positions = np.arange(0, len(ordered_labels), label_step) + y_tick_labels = [ + self._plot_settings.display_label(ordered_labels[index]) + for index in tick_positions + ] + rendered_y_tick_labels: list[str] = [] + composite_y_tick_labels: dict[int, str] = {} + for tick_index, tick_label in enumerate(y_tick_labels): + segments, has_markup = prepare_igor_inline_segments( + tick_label, + default_font_size=self._plot_settings.cluster_label_font_size, + ) + if not has_markup: + rendered_y_tick_labels.append(tick_label) + continue + if any( + not math.isclose( + segment.font_size, + self._plot_settings.cluster_label_font_size, + ) + for segment in segments + ): + rendered_y_tick_labels.append(" ") + composite_y_tick_labels[tick_index] = tick_label + continue + rendered_y_tick_labels.append( + igor_inline_to_mathtext( + tick_label, + default_font_size=self._plot_settings.cluster_label_font_size, + ) + ) heatmap_axis.set_yticks(tick_positions) - heatmap_axis.set_yticklabels( - [self._result.cluster_labels[index] for index in tick_positions] + heatmap_axis.set_yticklabels(rendered_y_tick_labels) + heatmap_axis.set_ylabel( + self._plot_settings.resolve_y_label(defaults), + fontsize=self._plot_settings.axis_label_font_size, + **self._font_kwargs(), + ) + apply_igor_inline_text_artist( + heatmap_axis.yaxis.label, + self._plot_settings.resolve_y_label(defaults), + default_font_size=self._plot_settings.axis_label_font_size, + gid_prefix="heatmap-y-label", + target_axes=heatmap_axis, ) - heatmap_axis.set_ylabel("Cluster 
label") heatmap_axis.set_xlim(float(time_edges[0]), float(time_edges[-1])) heatmap_axis.set_title( - "Time-Binned Cluster Distribution " - f"({DISPLAY_MODE_LABELS[self._display_mode()]})" + self._plot_settings.resolve_title(defaults), + y=self._plot_settings.resolve_title_position_y(defaults), + fontsize=self._plot_settings.title_font_size, + **self._font_kwargs(), + ) + heatmap_axis.title.set_x( + self._plot_settings.resolve_title_position_x(defaults) + ) + apply_igor_inline_text_artist( + heatmap_axis.title, + self._plot_settings.resolve_title(defaults), + default_font_size=self._plot_settings.title_font_size, + gid_prefix="heatmap-title", + target_axes=heatmap_axis, ) + heatmap_axis.xaxis.set_major_locator( + MaxNLocator(nbins=max(self._plot_settings.max_x_ticks, 2)) + ) + if overlay_axis is None: - heatmap_axis.set_xlabel(f"Time ({time_unit})") + heatmap_axis.set_xlabel( + self._plot_settings.resolve_x_label(defaults), + fontsize=self._plot_settings.axis_label_font_size, + **self._font_kwargs(), + ) + apply_igor_inline_text_artist( + heatmap_axis.xaxis.label, + self._plot_settings.resolve_x_label(defaults), + default_font_size=self._plot_settings.axis_label_font_size, + gid_prefix="heatmap-x-label", + target_axes=heatmap_axis, + ) else: heatmap_axis.tick_params(labelbottom=False) + self._style_heatmap_ticks(heatmap_axis) + for tick_index, tick_label in enumerate( + heatmap_axis.get_yticklabels() + ): + if tick_index not in composite_y_tick_labels: + continue + apply_igor_inline_text_artist( + tick_label, + composite_y_tick_labels[tick_index], + default_font_size=self._plot_settings.cluster_label_font_size, + gid_prefix=f"heatmap-y-tick-{tick_index}", + target_axes=heatmap_axis, + ) + if overlay_axis is not None and overlay_name is not None: x_values, y_values, y_label = self._result.energy_series( overlay_name, @@ -247,36 +585,179 @@ def refresh_plot(self) -> None: color=OVERLAY_COLORS.get(overlay_name, "#333333"), linewidth=1.5, ) - 
overlay_axis.set_ylabel(y_label) - overlay_axis.set_xlabel(f"Time ({time_unit})") + overlay_axis.set_ylabel( + y_label, + fontsize=self._plot_settings.axis_label_font_size, + **self._font_kwargs(), + ) + overlay_axis.set_xlabel( + self._plot_settings.resolve_x_label(defaults), + fontsize=self._plot_settings.axis_label_font_size, + **self._font_kwargs(), + ) + apply_igor_inline_text_artist( + overlay_axis.xaxis.label, + self._plot_settings.resolve_x_label(defaults), + default_font_size=self._plot_settings.axis_label_font_size, + gid_prefix="overlay-x-label", + target_axes=overlay_axis, + ) overlay_axis.grid(alpha=0.25, linestyle=":") + overlay_axis.xaxis.set_major_locator( + MaxNLocator(nbins=max(self._plot_settings.max_x_ticks, 2)) + ) + self._style_overlay_ticks(overlay_axis) - self.figure.tight_layout() - self.canvas.draw_idle() + figure.tight_layout() + + def _current_plot_defaults(self) -> HeatmapPlotDefaults: + time_unit = self.time_unit_combo.currentData() + raw_labels = ( + () + if self._result is None + else tuple(str(label) for label in self._result.cluster_labels) + ) + current_colormap = self.colormap_combo.currentData() + default_label_entries = tuple( + (raw_label, self._format_cluster_axis_label(raw_label)) + for raw_label in raw_labels + ) + return HeatmapPlotDefaults( + title=( + "Time-Binned Cluster Distribution " + f"({DISPLAY_MODE_LABELS[self._display_mode()]})" + ), + x_label=f"Time ({time_unit})", + y_label="Cluster label", + colorbar_label=DISPLAY_MODE_COLORBAR_LABELS[self._display_mode()], + default_x_axis_unit_name=( + "" if time_unit is None else str(time_unit) + ), + available_x_axis_unit_names=("fs", "ps"), + default_colormap_name=( + "" if current_colormap is None else str(current_colormap) + ), + available_colormap_names=tuple(PLOT_COLORMAPS), + raw_cluster_labels=raw_labels, + default_label_entries=default_label_entries, + ) + + def _resolved_aspect(self) -> str | float: + if self._plot_settings.aspect_mode == "equal": + return "equal" 
+ if self._plot_settings.aspect_mode == "custom": + return float(self._plot_settings.custom_aspect) + return "auto" + + def _font_kwargs(self) -> dict[str, str]: + if not self._plot_settings.font_family: + return {} + return {"fontfamily": self._plot_settings.font_family} + + @staticmethod + def _format_cluster_axis_label(label: str) -> str: + return format_stoich_for_axis(label) + + def _apply_font_to_text(self, text_artist) -> None: + if self._plot_settings.font_family: + text_artist.set_fontfamily(self._plot_settings.font_family) + + def _style_heatmap_ticks(self, axis) -> None: + axis.tick_params( + axis="x", + labelsize=self._plot_settings.tick_label_font_size, + labelrotation=self._plot_settings.x_tick_rotation, + ) + axis.tick_params( + axis="y", + labelsize=self._plot_settings.cluster_label_font_size, + labelrotation=self._plot_settings.y_tick_rotation, + ) + if ( + self._plot_settings.show_minor_x_ticks + or self._plot_settings.show_minor_y_ticks + ): + axis.minorticks_on() + else: + axis.minorticks_off() + axis.tick_params( + axis="x", + which="minor", + bottom=self._plot_settings.show_minor_x_ticks, + top=False, + ) + axis.tick_params( + axis="y", + which="minor", + left=self._plot_settings.show_minor_y_ticks, + right=False, + ) + for tick_label in axis.get_xticklabels(): + self._apply_font_to_text(tick_label) + for tick_label in axis.get_yticklabels(): + self._apply_font_to_text(tick_label) + + def _style_overlay_ticks(self, axis) -> None: + axis.tick_params( + axis="both", + labelsize=self._plot_settings.tick_label_font_size, + ) + axis.tick_params( + axis="x", + labelrotation=self._plot_settings.x_tick_rotation, + ) + if self._plot_settings.show_minor_x_ticks: + axis.minorticks_on() + else: + axis.minorticks_off() + axis.tick_params( + axis="x", + which="minor", + bottom=self._plot_settings.show_minor_x_ticks, + top=False, + ) + for tick_label in axis.get_xticklabels(): + self._apply_font_to_text(tick_label) + for tick_label in axis.get_yticklabels(): 
+ self._apply_font_to_text(tick_label) def _display_mode(self) -> str: value = self.display_mode_combo.currentData() return "fraction" if value is None else str(value) - def _on_quantile_changed(self) -> None: + def _ensure_valid_quantiles(self) -> None: lower = self.lower_quantile_spin.value() upper = self.upper_quantile_spin.value() if lower >= upper: - if self.sender() is self.lower_quantile_spin: - self.upper_quantile_spin.blockSignals(True) - self.upper_quantile_spin.setValue(min(lower + 0.05, 1.0)) - self.upper_quantile_spin.blockSignals(False) - else: - self.lower_quantile_spin.blockSignals(True) - self.lower_quantile_spin.setValue(max(upper - 0.05, 0.0)) - self.lower_quantile_spin.blockSignals(False) + self.upper_quantile_spin.blockSignals(True) + self.upper_quantile_spin.setValue(min(lower + 0.05, 1.0)) + self.upper_quantile_spin.blockSignals(False) + lower = self.lower_quantile_spin.value() + upper = self.upper_quantile_spin.value() + if lower >= upper: + self.lower_quantile_spin.blockSignals(True) + self.lower_quantile_spin.setValue(max(upper - 0.05, 0.0)) + self.lower_quantile_spin.blockSignals(False) + + def _on_quantile_changed(self) -> None: + self._ensure_valid_quantiles() self.refresh_plot() - def _quantile_norm(self, matrix: np.ndarray) -> mcolors.Normalize: + def _heatmap_norm( + self, + defaults: HeatmapPlotDefaults, + ) -> mcolors.Normalize: + vmin = float(self._plot_settings.resolve_color_limit_min(defaults)) + vmax = float(self._plot_settings.resolve_color_limit_max(defaults)) + if vmax <= vmin: + vmax = vmin + 1.0 + return mcolors.Normalize(vmin=vmin, vmax=vmax) + + def _auto_color_limits(self, matrix: np.ndarray) -> tuple[float, float]: values = np.asarray(matrix, dtype=float) finite = values[np.isfinite(values)] if finite.size == 0: - return mcolors.Normalize(vmin=0.0, vmax=1.0) + return (0.0, 1.0) positive = finite[finite > 0.0] if positive.size: @@ -291,7 +772,7 @@ def _quantile_norm(self, matrix: np.ndarray) -> mcolors.Normalize: vmax = 
float(np.max(finite)) if vmax <= vmin: vmax = vmin + 1.0 - return mcolors.Normalize(vmin=vmin, vmax=vmax) + return (vmin, vmax) @staticmethod def _draw_placeholder(axis, message: str) -> None: diff --git a/src/saxshell/clusterdynamicsml/ui/main_window.py b/src/saxshell/clusterdynamicsml/ui/main_window.py index c19a94c..c834960 100644 --- a/src/saxshell/clusterdynamicsml/ui/main_window.py +++ b/src/saxshell/clusterdynamicsml/ui/main_window.py @@ -877,7 +877,9 @@ def _build_ui(self) -> None: right_layout.setContentsMargins(0, 0, 0, 0) right_layout.setSpacing(0) - self.dynamics_plot_panel = ClusterDynamicsPlotPanel() + self.dynamics_plot_panel = ClusterDynamicsPlotPanel( + enable_plot_editor=True + ) self.right_splitter = QSplitter(Qt.Orientation.Vertical) self.right_splitter.setChildrenCollapsible(False) diff --git a/src/saxshell/clusterdynamicsml/ui/plot_panel.py b/src/saxshell/clusterdynamicsml/ui/plot_panel.py index 840d5af..7b86f04 100644 --- a/src/saxshell/clusterdynamicsml/ui/plot_panel.py +++ b/src/saxshell/clusterdynamicsml/ui/plot_panel.py @@ -25,6 +25,7 @@ ClusterDynamicsMLResult, _resolved_population_weights, ) +from saxshell.plotting import Q_A_INVERSE_LABEL from saxshell.saxs.debye.profiles import scan_structure_element_counts from saxshell.saxs.project_manager.prior_plot import ( list_secondary_filter_elements, @@ -701,7 +702,7 @@ def _apply_saxs_axis_style( axis.set_xscale("log" if self.log_x_checkbox.isChecked() else "linear") axis.set_yscale("log" if self.log_y_checkbox.isChecked() else "linear") if not is_model_axis or not has_separate_model_axis: - axis.set_xlabel("q (Å⁻¹)") + axis.set_xlabel(Q_A_INVERSE_LABEL) if not is_model_axis: axis.set_ylabel("Intensity (arb. 
units)") elif has_separate_model_axis: diff --git a/src/saxshell/plotting/__init__.py b/src/saxshell/plotting/__init__.py index 4bb6c0a..3a7f537 100644 --- a/src/saxshell/plotting/__init__.py +++ b/src/saxshell/plotting/__init__.py @@ -15,6 +15,38 @@ """Python package for analysis of small-angle scattering data from molecular dynamics derived liquid structures.""" +from saxshell.plotting.igor_inline import ( # noqa + IgorInlineSegment, + apply_igor_inline_text_artist, + has_igor_inline_markup, + igor_inline_to_mathtext, + prepare_igor_inline_segments, +) +from saxshell.plotting.labels import Q_A_INVERSE_LABEL # noqa +from saxshell.plotting.line_plot_editor import ( # noqa + LINE_PLOT_LEGEND_LOCATIONS, + LinePlotDefaults, + LinePlotEditorControls, + LinePlotSeriesDefaults, + LinePlotSettings, +) +from saxshell.plotting.plot_editor import ( # noqa + HeatmapPlotDefaults, + HeatmapPlotEditorControls, + HeatmapPlotSettings, + PlotEditorWindow, + StackedHistogramPlotDefaults, + StackedHistogramPlotEditorControls, + StackedHistogramPlotSettings, + load_pickled_plot_figure, + load_pickled_plot_payload, + save_pickled_plot_figure, +) +from saxshell.plotting.stacked_histogram import ( # noqa + STACKED_HISTOGRAM_LEGEND_LOCATIONS, + render_stacked_histogram_export_payload, +) + # package version from saxshell.version import __version__ # noqa diff --git a/src/saxshell/plotting/igor_inline.py b/src/saxshell/plotting/igor_inline.py new file mode 100644 index 0000000..700efa5 --- /dev/null +++ b/src/saxshell/plotting/igor_inline.py @@ -0,0 +1,394 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass + +from matplotlib.backends.backend_agg import RendererAgg +from matplotlib.font_manager import FontProperties +from matplotlib.text import Text +from matplotlib.transforms import IdentityTransform + + +@dataclass(frozen=True, slots=True) +class IgorInlineSegment: + text: str + bold: bool + italic: bool + font_size: float + + +def 
prepare_igor_inline_segments( + text: str, + *, + default_font_size: float, +) -> tuple[list[IgorInlineSegment], bool]: + if not text: + return ([], False) + + segments: list[IgorInlineSegment] = [] + has_markup = False + bold = False + italic = False + font_size = float(default_font_size) + buffer: list[str] = [] + + def flush() -> None: + if not buffer: + return + segment_text = "".join(buffer) + buffer.clear() + if segment_text: + segments.append( + IgorInlineSegment( + text=segment_text, + bold=bold, + italic=italic, + font_size=font_size, + ) + ) + + for in_math, chunk in _split_math_sections(text): + if in_math: + buffer.append(f"${chunk}$") + continue + + index = 0 + while index < len(chunk): + if chunk.startswith(r"\f", index) and index + 4 <= len(chunk): + code = chunk[index + 2 : index + 4] + if code in {"00", "01", "02"}: + flush() + has_markup = True + if code == "00": + bold = False + italic = False + elif code == "01": + bold = True + else: + italic = True + index += 4 + continue + if chunk.startswith(r"\Z", index): + size_text, consumed = _parse_igor_font_size(chunk[index + 2 :]) + if consumed > 0 and size_text is not None: + flush() + has_markup = True + size_value = float(size_text) + font_size = ( + float(default_font_size) + if size_value <= 0.0 + else size_value + ) + index += 2 + consumed + continue + buffer.append(chunk[index]) + index += 1 + + flush() + return (segments, has_markup) + + +def has_igor_inline_markup(text: str) -> bool: + _segments, has_markup = prepare_igor_inline_segments( + text, + default_font_size=12.0, + ) + return has_markup + + +def igor_inline_to_mathtext( + text: str, + *, + default_font_size: float, +) -> str: + segments, has_markup = prepare_igor_inline_segments( + text, + default_font_size=default_font_size, + ) + if not has_markup: + return text + if any( + not math.isclose(segment.font_size, default_font_size) + for segment in segments + ): + raise ValueError("Inline font-size changes require composite drawing") + 
return _segments_to_mathtext(segments) + + +def apply_igor_inline_text_artist( + text_artist: Text, + raw_text: str, + *, + default_font_size: float, + gid_prefix: str, + target_axes=None, +) -> list[Text]: + segments, has_markup = prepare_igor_inline_segments( + raw_text, + default_font_size=default_font_size, + ) + if not has_markup: + text_artist.set_text(raw_text) + return [] + + requires_composite = any( + not math.isclose(segment.font_size, default_font_size) + for segment in segments + ) + if not requires_composite: + text_artist.set_text(_segments_to_mathtext(segments)) + return [] + + text_artist.set_text(" ") + return _compose_inline_segments( + text_artist, + segments, + gid_prefix=gid_prefix, + target_axes=target_axes, + ) + + +def _parse_igor_font_size(chunk: str) -> tuple[str | None, int]: + if not chunk: + return (None, 0) + if chunk.startswith("<"): + end = chunk.find(">") + if end <= 1: + return (None, 0) + value = chunk[1:end] + try: + float(value) + except ValueError: + return (None, 0) + return (value, end + 1) + + end = 0 + while end < len(chunk) and (chunk[end].isdigit() or chunk[end] == "."): + end += 1 + if end == 0: + return (None, 0) + value = chunk[:end] + try: + float(value) + except ValueError: + return (None, 0) + return (value, end) + + +def _split_math_sections(text: str) -> list[tuple[bool, str]]: + sections: list[tuple[bool, str]] = [] + buffer: list[str] = [] + in_math = False + + def flush() -> None: + if buffer: + sections.append((in_math, "".join(buffer))) + buffer.clear() + + index = 0 + while index < len(text): + char = text[index] + if char == "$" and (index == 0 or text[index - 1] != "\\"): + flush() + in_math = not in_math + index += 1 + continue + buffer.append(char) + index += 1 + flush() + return sections + + +def _segments_to_mathtext(segments: list[IgorInlineSegment]) -> str: + body = "".join(_segment_mathtext_body(segment) for segment in segments) + return f"${body}$" + + +def _segment_mathtext_body(segment: 
IgorInlineSegment) -> str: + pieces: list[str] = [] + for in_math, chunk in _split_math_sections(segment.text): + if not chunk: + continue + if in_math: + if chunk.startswith("^") or chunk.startswith("_"): + pieces.append(chunk) + else: + pieces.append( + _wrap_with_math_style( + chunk, + bold=segment.bold, + italic=segment.italic, + ) + ) + continue + pieces.append( + _wrap_with_math_style( + _escape_literal_for_mathtext(chunk), + bold=segment.bold, + italic=segment.italic, + ) + ) + return "".join(pieces) + + +def _wrap_with_math_style( + text: str, + *, + bold: bool, + italic: bool, +) -> str: + if not text: + return "" + if bold and italic: + return rf"\mathbf{{\mathit{{{text}}}}}" + if bold: + return rf"\mathbf{{{text}}}" + if italic: + return rf"\mathit{{{text}}}" + return rf"\mathregular{{{text}}}" + + +def _escape_literal_for_mathtext(text: str) -> str: + escaped = ( + text.replace("\\", r"\backslash ") + .replace("{", r"\{") + .replace("}", r"\}") + .replace("_", r"\_") + .replace("^", r"\^{}") + .replace("%", r"\%") + .replace("&", r"\&") + .replace("#", r"\#") + .replace("$", r"\$") + ) + return escaped.replace(" ", r"\ ") + + +def _compose_inline_segments( + text_artist: Text, + segments: list[IgorInlineSegment], + *, + gid_prefix: str, + target_axes=None, +) -> list[Text]: + if not segments: + return [] + + figure = text_artist.figure + if figure is None: + return [] + renderer = _measurement_renderer(figure) + anchor_transform = text_artist.get_transform() + anchor = anchor_transform.transform(text_artist.get_position()) + rotation = float(text_artist.get_rotation()) + ha = str(text_artist.get_ha()) + va = str(text_artist.get_va()) + color = text_artist.get_color() + zorder = text_artist.get_zorder() + font_family = _font_family_name(text_artist) + axes = target_axes if target_axes is not None else text_artist.axes + if axes is None: + return [] + + metrics: list[tuple[IgorInlineSegment, str, float, float, float]] = [] + total_width = 0.0 + max_ascent 
= 0.0 + max_descent = 0.0 + for segment in segments: + mathtext = f"${_segment_mathtext_body(segment)}$" + font_properties = FontProperties(size=segment.font_size) + if font_family: + font_properties.set_family(font_family) + width, height, descent = renderer.get_text_width_height_descent( + mathtext, + font_properties, + ismath=True, + ) + ascent = height - descent + total_width += width + max_ascent = max(max_ascent, ascent) + max_descent = max(max_descent, descent) + metrics.append((segment, mathtext, width, ascent, descent)) + + x_offset = _horizontal_alignment_offset(total_width, ha) + y_offset = _vertical_alignment_offset(max_ascent, max_descent, va) + angle = math.radians(rotation) + cos_angle = math.cos(angle) + sin_angle = math.sin(angle) + current_x = x_offset + + artists: list[Text] = [] + for index, (segment, mathtext, width, _ascent, _descent) in enumerate( + metrics + ): + display_x = ( + anchor[0] + (current_x * cos_angle) - (y_offset * sin_angle) + ) + display_y = ( + anchor[1] + (current_x * sin_angle) + (y_offset * cos_angle) + ) + artist = axes.text( + display_x, + display_y, + mathtext, + transform=IdentityTransform(), + ha="left", + va="baseline", + rotation=rotation, + rotation_mode="anchor", + clip_on=False, + fontsize=segment.font_size, + color=color, + zorder=zorder, + ) + if font_family: + artist.set_fontfamily(font_family) + artist.set_gid(f"{gid_prefix}-{index}") + artists.append(artist) + current_x += width + return artists + + +def _measurement_renderer(figure) -> RendererAgg: + width = max(1, int(math.ceil(figure.bbox.width))) + height = max(1, int(math.ceil(figure.bbox.height))) + return RendererAgg(width, height, figure.dpi) + + +def _font_family_name(text_artist: Text) -> str: + family = text_artist.get_fontfamily() + if isinstance(family, str): + return family + if family: + return str(family[0]) + return "" + + +def _horizontal_alignment_offset(total_width: float, ha: str) -> float: + if ha == "center": + return -(0.5 * 
total_width) + if ha == "right": + return -total_width + return 0.0 + + +def _vertical_alignment_offset( + max_ascent: float, + max_descent: float, + va: str, +) -> float: + if va in {"center", "center_baseline"}: + return -((max_ascent - max_descent) * 0.5) + if va == "top": + return -max_ascent + if va == "bottom": + return max_descent + return 0.0 + + +__all__ = [ + "IgorInlineSegment", + "apply_igor_inline_text_artist", + "has_igor_inline_markup", + "igor_inline_to_mathtext", + "prepare_igor_inline_segments", +] diff --git a/src/saxshell/plotting/labels.py b/src/saxshell/plotting/labels.py new file mode 100644 index 0000000..052e301 --- /dev/null +++ b/src/saxshell/plotting/labels.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +# Use mathtext for the inverse-Angstrom exponent so Matplotlib does not rely +# on the selected UI font having the Unicode superscript minus glyph. +Q_A_INVERSE_LABEL = "q (Å$^{-1}$)" + + +__all__ = ["Q_A_INVERSE_LABEL"] diff --git a/src/saxshell/plotting/line_plot_editor.py b/src/saxshell/plotting/line_plot_editor.py new file mode 100644 index 0000000..b297099 --- /dev/null +++ b/src/saxshell/plotting/line_plot_editor.py @@ -0,0 +1,819 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field + +from PySide6.QtCore import Qt, Signal +from PySide6.QtGui import QFont +from PySide6.QtWidgets import ( + QAbstractItemView, + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFontComboBox, + QFormLayout, + QGroupBox, + QHeaderView, + QLabel, + QLineEdit, + QPushButton, + QTableWidget, + QTableWidgetItem, + QVBoxLayout, + QWidget, +) + +LINE_PLOT_LEGEND_LOCATIONS = ( + ("Best", "best"), + ("Upper Right", "upper right"), + ("Upper Left", "upper left"), + ("Lower Right", "lower right"), + ("Lower Left", "lower left"), +) + + +@dataclass(slots=True) +class LinePlotSeriesDefaults: + key: str + label: str + axis_label: str = "Main" + + +@dataclass(slots=True) +class 
LinePlotDefaults: + title: str + x_label: str + primary_y_label: str + secondary_y_label: str = "" + residual_y_label: str = "" + title_position_x: float = 0.5 + title_position_y: float = 1.0 + has_secondary_y_axis: bool = False + has_residual_y_axis: bool = False + has_annotation: bool = False + default_show_legend: bool = True + default_legend_location: str = "best" + default_show_annotation: bool = True + series_defaults: tuple[LinePlotSeriesDefaults, ...] = () + + +@dataclass(slots=True) +class LinePlotSettings: + title: str | None = None + x_label: str | None = None + primary_y_label: str | None = None + secondary_y_label: str | None = None + residual_y_label: str | None = None + title_position_x: float | None = None + title_position_y: float | None = None + font_family: str = "" + title_font_size: float = 12.0 + axis_label_font_size: float = 11.0 + tick_label_font_size: float = 9.0 + primary_axis_label_font_size: float | None = None + primary_tick_label_font_size: float | None = None + secondary_axis_label_font_size: float | None = None + secondary_tick_label_font_size: float | None = None + legend_font_size: float = 9.0 + annotation_font_size: float = 9.0 + show_legend: bool | None = None + legend_location: str | None = None + show_annotation: bool | None = None + series_label_map: dict[str, str] = field(default_factory=dict) + + def resolve_title(self, defaults: LinePlotDefaults) -> str: + return defaults.title if self.title is None else self.title + + def resolve_x_label(self, defaults: LinePlotDefaults) -> str: + return defaults.x_label if self.x_label is None else self.x_label + + def resolve_primary_y_label(self, defaults: LinePlotDefaults) -> str: + return ( + defaults.primary_y_label + if self.primary_y_label is None + else self.primary_y_label + ) + + def resolve_secondary_y_label(self, defaults: LinePlotDefaults) -> str: + return ( + defaults.secondary_y_label + if self.secondary_y_label is None + else self.secondary_y_label + ) + + def 
resolve_residual_y_label(self, defaults: LinePlotDefaults) -> str: + return ( + defaults.residual_y_label + if self.residual_y_label is None + else self.residual_y_label + ) + + def resolve_title_position_x(self, defaults: LinePlotDefaults) -> float: + return ( + defaults.title_position_x + if self.title_position_x is None + else self.title_position_x + ) + + def resolve_title_position_y(self, defaults: LinePlotDefaults) -> float: + return ( + defaults.title_position_y + if self.title_position_y is None + else self.title_position_y + ) + + def resolve_show_legend(self, defaults: LinePlotDefaults) -> bool: + return ( + defaults.default_show_legend + if self.show_legend is None + else bool(self.show_legend) + ) + + def resolve_primary_axis_label_font_size( + self, + defaults: LinePlotDefaults, + ) -> float: + del defaults + return ( + self.axis_label_font_size + if self.primary_axis_label_font_size is None + else float(self.primary_axis_label_font_size) + ) + + def resolve_primary_tick_label_font_size( + self, + defaults: LinePlotDefaults, + ) -> float: + del defaults + return ( + self.tick_label_font_size + if self.primary_tick_label_font_size is None + else float(self.primary_tick_label_font_size) + ) + + def resolve_secondary_axis_label_font_size( + self, + defaults: LinePlotDefaults, + ) -> float: + del defaults + return ( + self.axis_label_font_size + if self.secondary_axis_label_font_size is None + else float(self.secondary_axis_label_font_size) + ) + + def resolve_secondary_tick_label_font_size( + self, + defaults: LinePlotDefaults, + ) -> float: + del defaults + return ( + self.tick_label_font_size + if self.secondary_tick_label_font_size is None + else float(self.secondary_tick_label_font_size) + ) + + def resolve_legend_location(self, defaults: LinePlotDefaults) -> str: + return ( + defaults.default_legend_location + if self.legend_location is None + else str(self.legend_location) + ) + + def resolve_show_annotation(self, defaults: LinePlotDefaults) -> 
bool: + return ( + defaults.default_show_annotation + if self.show_annotation is None + else bool(self.show_annotation) + ) + + def sync_series( + self, + series_defaults: Sequence[LinePlotSeriesDefaults], + ) -> None: + default_map = { + str(series.key): str(series.label) for series in series_defaults + } + existing = dict(self.series_label_map) + self.series_label_map = { + key: existing.get(key, label) for key, label in default_map.items() + } + + def display_series_label(self, series_key: str, fallback: str) -> str: + return self.series_label_map.get(series_key, fallback) + + def to_dict(self) -> dict[str, object]: + return { + "title": self.title, + "x_label": self.x_label, + "primary_y_label": self.primary_y_label, + "secondary_y_label": self.secondary_y_label, + "residual_y_label": self.residual_y_label, + "title_position_x": self.title_position_x, + "title_position_y": self.title_position_y, + "font_family": self.font_family, + "title_font_size": self.title_font_size, + "axis_label_font_size": self.axis_label_font_size, + "tick_label_font_size": self.tick_label_font_size, + "primary_axis_label_font_size": self.primary_axis_label_font_size, + "primary_tick_label_font_size": self.primary_tick_label_font_size, + "secondary_axis_label_font_size": self.secondary_axis_label_font_size, + "secondary_tick_label_font_size": self.secondary_tick_label_font_size, + "legend_font_size": self.legend_font_size, + "annotation_font_size": self.annotation_font_size, + "show_legend": self.show_legend, + "legend_location": self.legend_location, + "show_annotation": self.show_annotation, + "series_label_map": dict(self.series_label_map), + } + + def update_from_dict(self, payload: Mapping[str, object]) -> None: + optional_float_fields = { + "primary_axis_label_font_size", + "primary_tick_label_font_size", + "secondary_axis_label_font_size", + "secondary_tick_label_font_size", + } + for field_name in ( + "title", + "x_label", + "primary_y_label", + "secondary_y_label", + 
"residual_y_label", + "title_position_x", + "title_position_y", + "show_legend", + "legend_location", + "show_annotation", + ): + if field_name in payload: + setattr(self, field_name, payload[field_name]) + if "font_family" in payload: + self.font_family = str(payload["font_family"] or "") + for field_name in ( + "title_font_size", + "axis_label_font_size", + "tick_label_font_size", + "primary_axis_label_font_size", + "primary_tick_label_font_size", + "secondary_axis_label_font_size", + "secondary_tick_label_font_size", + "legend_font_size", + "annotation_font_size", + ): + if field_name in payload: + value = payload[field_name] + if value is None: + if field_name in optional_float_fields: + setattr(self, field_name, None) + continue + setattr(self, field_name, float(value)) + if "series_label_map" in payload: + series_label_map = payload["series_label_map"] + if isinstance(series_label_map, Mapping): + self.series_label_map = { + str(key): str(value) + for key, value in series_label_map.items() + } + + +class LinePlotEditorControls(QWidget): + """Editable controls for reusable line-plot settings.""" + + settings_changed = Signal() + label_settings_changed = Signal() + + def __init__( + self, + *, + settings: LinePlotSettings, + defaults: LinePlotDefaults, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._settings = settings + self._defaults = defaults + self._last_synced_defaults: LinePlotDefaults | None = None + self._syncing = False + self._build_ui() + self.sync_defaults(defaults) + + def needs_default_sync(self, defaults: LinePlotDefaults) -> bool: + return self._last_synced_defaults != defaults + + def sync_defaults(self, defaults: LinePlotDefaults) -> None: + self._defaults = defaults + self._settings.sync_series(defaults.series_defaults) + self._syncing = True + try: + self.title_edit.setText(self._settings.resolve_title(defaults)) + self.x_label_edit.setText(self._settings.resolve_x_label(defaults)) + 
def sync_defaults(self, defaults: LinePlotDefaults) -> None:
    """Load *defaults* and the current settings into every control.

    The ``_syncing`` flag is raised for the duration so the change handlers
    ignore the programmatic updates; it is lowered, and *defaults* recorded
    as the last synced set, even if an update fails.
    """
    settings = self._settings
    self._defaults = defaults
    settings.sync_series(defaults.series_defaults)
    self._syncing = True
    try:
        for apply_value, resolve in (
            (self.title_edit.setText, settings.resolve_title),
            (self.x_label_edit.setText, settings.resolve_x_label),
            (
                self.primary_y_label_edit.setText,
                settings.resolve_primary_y_label,
            ),
            (
                self.secondary_y_label_edit.setText,
                settings.resolve_secondary_y_label,
            ),
            (
                self.residual_y_label_edit.setText,
                settings.resolve_residual_y_label,
            ),
            (
                self.title_position_x_spin.setValue,
                settings.resolve_title_position_x,
            ),
            (
                self.title_position_y_spin.setValue,
                settings.resolve_title_position_y,
            ),
        ):
            apply_value(resolve(defaults))

        # An empty family means "keep whatever font the combo shows".
        if settings.font_family:
            self.font_combo.setCurrentFont(QFont(settings.font_family))

        self.title_font_spin.setValue(settings.title_font_size)
        self.axis_label_font_spin.setValue(settings.axis_label_font_size)
        self.tick_label_font_spin.setValue(settings.tick_label_font_size)
        for apply_value, resolve in (
            (
                self.primary_axis_label_font_spin.setValue,
                settings.resolve_primary_axis_label_font_size,
            ),
            (
                self.primary_tick_label_font_spin.setValue,
                settings.resolve_primary_tick_label_font_size,
            ),
            (
                self.secondary_axis_label_font_spin.setValue,
                settings.resolve_secondary_axis_label_font_size,
            ),
            (
                self.secondary_tick_label_font_spin.setValue,
                settings.resolve_secondary_tick_label_font_size,
            ),
        ):
            apply_value(resolve(defaults))
        self.legend_font_spin.setValue(settings.legend_font_size)
        self.annotation_font_spin.setValue(settings.annotation_font_size)

        self.show_legend_checkbox.setChecked(
            settings.resolve_show_legend(defaults)
        )
        # findData returns -1 for an unknown location; fall back to row 0.
        location_row = self.legend_location_combo.findData(
            settings.resolve_legend_location(defaults)
        )
        self.legend_location_combo.setCurrentIndex(max(0, location_row))
        self.show_annotation_checkbox.setChecked(
            settings.resolve_show_annotation(defaults)
        )

        self._sync_label_table()
        self._update_dynamic_field_visibility()
        self._update_display_state()
    finally:
        self._last_synced_defaults = defaults
        self._syncing = False


def _build_ui(self) -> None:
    """Create the Text, Style, Display and Series Labels groups."""
    root = QVBoxLayout(self)
    root.setContentsMargins(0, 0, 0, 0)
    root.setSpacing(10)

    note = QLabel(
        "Edit line-plot titles, axis labels, legend layout, and series "
        "display labels. Axis-specific fields appear only when the "
        "current plot uses them."
    )
    note.setWordWrap(True)
    root.addWidget(note)

    # --- Text group -------------------------------------------------------
    text_group = QGroupBox("Text")
    text_form = QFormLayout(text_group)

    def add_text_row(caption: str, slot) -> QLineEdit:
        editor = QLineEdit()
        editor.textChanged.connect(slot)
        text_form.addRow(caption, editor)
        return editor

    self.title_edit = add_text_row("Title", self._on_title_changed)
    self.x_label_edit = add_text_row("X Label", self._on_x_label_changed)
    self.primary_y_label_edit = add_text_row(
        "Y Label", self._on_primary_y_label_changed
    )
    self.secondary_y_label_edit = add_text_row(
        "Secondary Y", self._on_secondary_y_label_changed
    )
    self.residual_y_label_edit = add_text_row(
        "Residual Y", self._on_residual_y_label_changed
    )
    self.reset_text_button = QPushButton("Reset Text Defaults")
    self.reset_text_button.clicked.connect(self._reset_text_defaults)
    text_form.addRow(self.reset_text_button)
    root.addWidget(text_group)

    # --- Style group ------------------------------------------------------
    style_group = QGroupBox("Style")
    style_form = QFormLayout(style_group)

    def add_spin_row(caption: str, spin: QDoubleSpinBox, slot) -> QDoubleSpinBox:
        spin.valueChanged.connect(slot)
        style_form.addRow(caption, spin)
        return spin

    self.font_combo = QFontComboBox()
    self.font_combo.currentFontChanged.connect(self._on_font_changed)
    style_form.addRow("Font", self.font_combo)
    self.title_font_spin = add_spin_row(
        "Title Size",
        self._build_font_spin(),
        self._on_title_font_size_changed,
    )
    self.title_position_x_spin = add_spin_row(
        "Title X",
        self._build_position_spin(),
        self._on_title_position_x_changed,
    )
    self.title_position_y_spin = add_spin_row(
        "Title Y",
        self._build_position_spin(),
        self._on_title_position_y_changed,
    )
    self.axis_label_font_spin = add_spin_row(
        "X Label Size",
        self._build_font_spin(),
        self._on_axis_label_font_size_changed,
    )
    self.tick_label_font_spin = add_spin_row(
        "X Tick Size",
        self._build_font_spin(),
        self._on_tick_label_font_size_changed,
    )
    self.primary_axis_label_font_spin = add_spin_row(
        "Primary Y Label Size",
        self._build_font_spin(),
        self._on_primary_axis_label_font_size_changed,
    )
    self.primary_tick_label_font_spin = add_spin_row(
        "Primary Y Tick Size",
        self._build_font_spin(),
        self._on_primary_tick_label_font_size_changed,
    )
    self.secondary_axis_label_font_spin = add_spin_row(
        "Secondary Y Label Size",
        self._build_font_spin(),
        self._on_secondary_axis_label_font_size_changed,
    )
    self.secondary_tick_label_font_spin = add_spin_row(
        "Secondary Y Tick Size",
        self._build_font_spin(),
        self._on_secondary_tick_label_font_size_changed,
    )
    self.legend_font_spin = add_spin_row(
        "Legend Size",
        self._build_font_spin(),
        self._on_legend_font_size_changed,
    )
    self.annotation_font_spin = add_spin_row(
        "Annotation Size",
        self._build_font_spin(),
        self._on_annotation_font_size_changed,
    )
    root.addWidget(style_group)

    # --- Display group ----------------------------------------------------
    display_group = QGroupBox("Display")
    display_form = QFormLayout(display_group)
    self.show_legend_checkbox = QCheckBox("Show Legend")
    self.show_legend_checkbox.toggled.connect(self._on_show_legend_changed)
    display_form.addRow(self.show_legend_checkbox)

    location_combo = QComboBox()
    self.legend_location_combo = location_combo
    for caption, location in LINE_PLOT_LEGEND_LOCATIONS:
        location_combo.addItem(caption, location)
    # Connected only after the items exist, so filling the combo does not
    # invoke the handler.
    location_combo.currentIndexChanged.connect(
        self._on_legend_location_changed
    )
    display_form.addRow("Legend Position", location_combo)

    self.show_annotation_checkbox = QCheckBox("Show Annotation")
    self.show_annotation_checkbox.toggled.connect(
        self._on_show_annotation_changed
    )
    display_form.addRow(self.show_annotation_checkbox)
    self.reset_display_button = QPushButton("Reset Display Defaults")
    self.reset_display_button.clicked.connect(self._reset_display_defaults)
    display_form.addRow(self.reset_display_button)
    root.addWidget(display_group)

    # --- Series Labels group ----------------------------------------------
    labels_group = QGroupBox("Series Labels")
    labels_layout = QVBoxLayout(labels_group)
    labels_layout.setContentsMargins(8, 8, 8, 8)
    labels_layout.setSpacing(8)
    labels_note = QLabel(
        "Edit display labels for the traces that are currently plotted."
    )
    labels_note.setWordWrap(True)
    labels_layout.addWidget(labels_note)

    table = QTableWidget(0, 3)
    self.label_table = table
    table.setHorizontalHeaderLabels(["Series", "Axis", "Display Label"])
    for column in (0, 1):
        table.horizontalHeader().setSectionResizeMode(
            column, QHeaderView.ResizeMode.ResizeToContents
        )
    table.horizontalHeader().setStretchLastSection(True)
    table.setSelectionBehavior(
        QAbstractItemView.SelectionBehavior.SelectRows
    )
    table.setSelectionMode(
        QAbstractItemView.SelectionMode.SingleSelection
    )
    table.itemChanged.connect(self._on_label_item_changed)
    labels_layout.addWidget(table)

    self.reset_labels_button = QPushButton("Reset Series Labels")
    self.reset_labels_button.clicked.connect(self._reset_labels)
    labels_layout.addWidget(self.reset_labels_button)
    root.addWidget(labels_group, stretch=1)
    root.addStretch(1)
+ ) + labels_note.setWordWrap(True) + labels_layout.addWidget(labels_note) + self.label_table = QTableWidget(0, 3) + self.label_table.setHorizontalHeaderLabels( + ["Series", "Axis", "Display Label"] + ) + self.label_table.horizontalHeader().setSectionResizeMode( + 0, QHeaderView.ResizeMode.ResizeToContents + ) + self.label_table.horizontalHeader().setSectionResizeMode( + 1, QHeaderView.ResizeMode.ResizeToContents + ) + self.label_table.horizontalHeader().setStretchLastSection(True) + self.label_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.label_table.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.label_table.itemChanged.connect(self._on_label_item_changed) + labels_layout.addWidget(self.label_table) + self.reset_labels_button = QPushButton("Reset Series Labels") + self.reset_labels_button.clicked.connect(self._reset_labels) + labels_layout.addWidget(self.reset_labels_button) + root.addWidget(labels_group, stretch=1) + root.addStretch(1) + + @staticmethod + def _build_font_spin() -> QDoubleSpinBox: + spin = QDoubleSpinBox() + spin.setDecimals(1) + spin.setRange(6.0, 40.0) + spin.setSingleStep(0.5) + return spin + + @staticmethod + def _build_position_spin() -> QDoubleSpinBox: + spin = QDoubleSpinBox() + spin.setDecimals(2) + spin.setRange(-1.0, 2.0) + spin.setSingleStep(0.05) + return spin + + def _set_form_row_visible(self, field: QWidget, visible: bool) -> None: + label = self._form_label_for_field(field) + if label is not None: + label.setVisible(visible) + field.setVisible(visible) + + def _form_label_for_field(self, field: QWidget) -> QWidget | None: + for group in self.findChildren(QGroupBox): + layout = group.layout() + if isinstance(layout, QFormLayout): + label = layout.labelForField(field) + if label is not None: + return label + return None + + def _update_dynamic_field_visibility(self) -> None: + self._set_form_row_visible( + self.secondary_y_label_edit, + 
self._defaults.has_secondary_y_axis, + ) + self._set_form_row_visible( + self.secondary_axis_label_font_spin, + self._defaults.has_secondary_y_axis, + ) + self._set_form_row_visible( + self.secondary_tick_label_font_spin, + self._defaults.has_secondary_y_axis, + ) + self._set_form_row_visible( + self.residual_y_label_edit, + self._defaults.has_residual_y_axis, + ) + self._set_form_row_visible( + self.annotation_font_spin, + self._defaults.has_annotation, + ) + self.show_annotation_checkbox.setVisible(self._defaults.has_annotation) + + def _update_display_state(self) -> None: + self.legend_location_combo.setEnabled( + self.show_legend_checkbox.isChecked() + ) + self.annotation_font_spin.setEnabled( + self._defaults.has_annotation + and self.show_annotation_checkbox.isChecked() + ) + + def _sync_label_table(self) -> None: + self.label_table.blockSignals(True) + try: + self.label_table.setRowCount(len(self._defaults.series_defaults)) + for row, series in enumerate(self._defaults.series_defaults): + raw_item = QTableWidgetItem(series.label) + raw_item.setFlags( + raw_item.flags() & ~Qt.ItemFlag.ItemIsEditable + ) + raw_item.setData(Qt.ItemDataRole.UserRole, series.key) + axis_item = QTableWidgetItem(series.axis_label) + axis_item.setFlags( + axis_item.flags() & ~Qt.ItemFlag.ItemIsEditable + ) + label_item = QTableWidgetItem( + self._settings.display_series_label( + series.key, + series.label, + ) + ) + self.label_table.setItem(row, 0, raw_item) + self.label_table.setItem(row, 1, axis_item) + self.label_table.setItem(row, 2, label_item) + if ( + self._defaults.series_defaults + and self.label_table.currentRow() < 0 + ): + self.label_table.selectRow(0) + finally: + self.label_table.blockSignals(False) + + def _emit_settings_changed(self) -> None: + if not self._syncing: + self.settings_changed.emit() + + def _emit_label_settings_changed(self) -> None: + if not self._syncing: + self.label_settings_changed.emit() + + def _on_title_changed(self, text: str) -> None: + if 
self._syncing: + return + self._settings.title = text + self._emit_settings_changed() + + def _on_x_label_changed(self, text: str) -> None: + if self._syncing: + return + self._settings.x_label = text + self._emit_settings_changed() + + def _on_primary_y_label_changed(self, text: str) -> None: + if self._syncing: + return + self._settings.primary_y_label = text + self._emit_settings_changed() + + def _on_secondary_y_label_changed(self, text: str) -> None: + if self._syncing: + return + self._settings.secondary_y_label = text + self._emit_settings_changed() + + def _on_residual_y_label_changed(self, text: str) -> None: + if self._syncing: + return + self._settings.residual_y_label = text + self._emit_settings_changed() + + def _on_font_changed(self, font: QFont) -> None: + if self._syncing: + return + self._settings.font_family = font.family() + self._emit_settings_changed() + + def _on_title_font_size_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.title_font_size = float(value) + self._emit_settings_changed() + + def _on_title_position_x_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.title_position_x = float(value) + self._emit_settings_changed() + + def _on_title_position_y_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.title_position_y = float(value) + self._emit_settings_changed() + + def _on_axis_label_font_size_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.axis_label_font_size = float(value) + self._emit_settings_changed() + + def _on_tick_label_font_size_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.tick_label_font_size = float(value) + self._emit_settings_changed() + + def _on_primary_axis_label_font_size_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.primary_axis_label_font_size = float(value) + self._emit_settings_changed() + + def 
_on_primary_tick_label_font_size_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.primary_tick_label_font_size = float(value) + self._emit_settings_changed() + + def _on_secondary_axis_label_font_size_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.secondary_axis_label_font_size = float(value) + self._emit_settings_changed() + + def _on_secondary_tick_label_font_size_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.secondary_tick_label_font_size = float(value) + self._emit_settings_changed() + + def _on_legend_font_size_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.legend_font_size = float(value) + self._emit_settings_changed() + + def _on_annotation_font_size_changed(self, value: float) -> None: + if self._syncing: + return + self._settings.annotation_font_size = float(value) + self._emit_settings_changed() + + def _on_show_legend_changed(self, checked: bool) -> None: + if self._syncing: + return + self._settings.show_legend = bool(checked) + self._update_display_state() + self._emit_settings_changed() + + def _on_legend_location_changed(self) -> None: + if self._syncing: + return + current = self.legend_location_combo.currentData() + self._settings.legend_location = ( + self._defaults.default_legend_location + if current is None + else str(current) + ) + self._emit_settings_changed() + + def _on_show_annotation_changed(self, checked: bool) -> None: + if self._syncing: + return + self._settings.show_annotation = bool(checked) + self._update_display_state() + self._emit_settings_changed() + + def _on_label_item_changed(self, item: QTableWidgetItem) -> None: + if self._syncing or item.column() != 2: + return + raw_item = self.label_table.item(item.row(), 0) + if raw_item is None: + return + series_key = str(raw_item.data(Qt.ItemDataRole.UserRole) or "") + if not series_key: + return + self._settings.series_label_map[series_key] 
= item.text() + self._emit_label_settings_changed() + self._emit_settings_changed() + + def _reset_text_defaults(self) -> None: + self._settings.title = None + self._settings.x_label = None + self._settings.primary_y_label = None + self._settings.secondary_y_label = None + self._settings.residual_y_label = None + self._settings.title_position_x = None + self._settings.title_position_y = None + self.sync_defaults(self._defaults) + self._emit_settings_changed() + + def _reset_display_defaults(self) -> None: + self._settings.show_legend = None + self._settings.legend_location = None + self._settings.show_annotation = None + self.sync_defaults(self._defaults) + self._emit_settings_changed() + + def _reset_labels(self) -> None: + self._settings.series_label_map = {} + self._settings.sync_series(self._defaults.series_defaults) + self.sync_defaults(self._defaults) + self._emit_label_settings_changed() + self._emit_settings_changed() + + +__all__ = [ + "LINE_PLOT_LEGEND_LOCATIONS", + "LinePlotDefaults", + "LinePlotEditorControls", + "LinePlotSeriesDefaults", + "LinePlotSettings", +] diff --git a/src/saxshell/plotting/plot_editor.py b/src/saxshell/plotting/plot_editor.py new file mode 100644 index 0000000..d4dd78d --- /dev/null +++ b/src/saxshell/plotting/plot_editor.py @@ -0,0 +1,1767 @@ +from __future__ import annotations + +import pickle +import re +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass, field +from pathlib import Path + +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.backends.backend_qtagg import ( + NavigationToolbar2QT as NavigationToolbar, +) +from matplotlib.figure import Figure +from PySide6.QtCore import Qt, Signal +from PySide6.QtGui import QFont +from PySide6.QtWidgets import ( + QAbstractItemView, + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QFontComboBox, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QMessageBox, + 
def load_pickled_plot_payload(source: str | Path) -> dict[str, object]:
    """Load a pickled plot file and normalise it into a payload dict.

    Three on-disk layouts are accepted:

    * a bare pickled ``Figure`` (wrapped as a
      ``"legacy_matplotlib_figure"`` payload),
    * a dict holding the figure under ``"figure"`` (returned as a copy),
    * a dict holding the figure under ``"fig"`` (returned as a copy with
      the figure additionally stored under ``"figure"``).

    SECURITY: the file is read with ``pickle.load``, which can execute
    arbitrary code embedded in the file. Only load pickled plots that come
    from a trusted source.

    Args:
        source: Path of the pickle file to read.

    Returns:
        A new dict whose ``"figure"`` entry is a Matplotlib ``Figure``.

    Raises:
        ValueError: If the unpickled object holds no Matplotlib figure in
            any of the accepted layouts.
    """
    path = Path(source)
    with path.open("rb") as handle:
        # NOTE(security): unpickling runs code chosen by whoever wrote the
        # file; there is no validation before this point. Kept as-is for
        # compatibility with existing .pkl plots -- do not feed untrusted
        # files through this function.
        payload = pickle.load(handle)
    if isinstance(payload, Figure):
        return {
            "kind": "legacy_matplotlib_figure",
            "figure": payload,
        }
    if isinstance(payload, dict):
        figure = payload.get("figure")
        if isinstance(figure, Figure):
            return dict(payload)
        # Older payloads stored the figure under "fig"; expose it under the
        # current key without mutating the unpickled dict.
        legacy_figure = payload.get("fig")
        if isinstance(legacy_figure, Figure):
            updated_payload = dict(payload)
            updated_payload["figure"] = legacy_figure
            return updated_payload
    raise ValueError(f"{path} does not contain a pickled Matplotlib figure")
ValueError(f"{source} does not contain a pickled Matplotlib figure") + + +@dataclass(slots=True) +class HeatmapPlotDefaults: + title: str + x_label: str + y_label: str + colorbar_label: str + title_position_x: float = 0.5 + title_position_y: float = 1.0 + default_x_axis_unit_name: str = "" + available_x_axis_unit_names: tuple[str, ...] = () + default_colormap_name: str = "" + available_colormap_names: tuple[str, ...] = () + auto_color_limit_min: float = 0.0 + auto_color_limit_max: float = 1.0 + raw_cluster_labels: tuple[str, ...] = () + default_label_entries: tuple[tuple[str, str], ...] = () + + +@dataclass(slots=True) +class HeatmapPlotSettings: + title: str | None = None + x_label: str | None = None + y_label: str | None = None + colorbar_label: str | None = None + title_position_x: float | None = None + title_position_y: float | None = None + color_limit_min: float | None = None + color_limit_max: float | None = None + font_family: str = "" + title_font_size: float = 12.0 + axis_label_font_size: float = 11.0 + tick_label_font_size: float = 9.0 + cluster_label_font_size: float = 9.0 + aspect_mode: str = "auto" + custom_aspect: float = 1.0 + max_x_ticks: int = 8 + max_y_ticks: int = 24 + x_tick_rotation: int = 0 + y_tick_rotation: int = 0 + show_minor_x_ticks: bool = False + show_minor_y_ticks: bool = False + label_order: list[str] = field(default_factory=list) + label_map: dict[str, str] = field(default_factory=dict) + + def resolve_title(self, defaults: HeatmapPlotDefaults) -> str: + return defaults.title if self.title is None else self.title + + def resolve_x_label(self, defaults: HeatmapPlotDefaults) -> str: + return defaults.x_label if self.x_label is None else self.x_label + + def resolve_y_label(self, defaults: HeatmapPlotDefaults) -> str: + return defaults.y_label if self.y_label is None else self.y_label + + def resolve_colorbar_label(self, defaults: HeatmapPlotDefaults) -> str: + return ( + defaults.colorbar_label + if self.colorbar_label is None + else 
self.colorbar_label + ) + + def resolve_title_position_x(self, defaults: HeatmapPlotDefaults) -> float: + return ( + defaults.title_position_x + if self.title_position_x is None + else self.title_position_x + ) + + def resolve_title_position_y(self, defaults: HeatmapPlotDefaults) -> float: + return ( + defaults.title_position_y + if self.title_position_y is None + else self.title_position_y + ) + + def has_manual_color_limits(self) -> bool: + return ( + self.color_limit_min is not None + or self.color_limit_max is not None + ) + + def resolve_color_limit_min(self, defaults: HeatmapPlotDefaults) -> float: + return ( + defaults.auto_color_limit_min + if self.color_limit_min is None + else self.color_limit_min + ) + + def resolve_color_limit_max(self, defaults: HeatmapPlotDefaults) -> float: + return ( + defaults.auto_color_limit_max + if self.color_limit_max is None + else self.color_limit_max + ) + + def sync_labels( + self, + raw_labels: Sequence[str], + *, + default_label_entries: Sequence[tuple[str, str]] | None = None, + ) -> None: + default_entries = ( + [ + (str(raw_label), format_stoich_for_axis(str(raw_label))) + for raw_label in raw_labels + ] + if default_label_entries is None + else [ + (str(raw_label), str(display_label)) + for raw_label, display_label in default_label_entries + ] + ) + default_map = { + raw_label: display_label + for raw_label, display_label in default_entries + } + existing = dict(self.label_map) + preserved_order = [ + raw_label + for raw_label in self.label_order + if raw_label in default_map + ] + remaining = [ + raw_label + for raw_label, _display_label in default_entries + if raw_label not in preserved_order + ] + self.label_order = preserved_order + remaining + self.label_map = { + raw_label: existing.get(raw_label, default_map[raw_label]) + for raw_label in self.label_order + } + + def display_label(self, raw_label: str) -> str: + return self.label_map.get(raw_label, raw_label) + + def ordered_raw_labels( + self, + defaults: 
HeatmapPlotDefaults, + ) -> list[str]: + if self.label_order: + available = set(defaults.raw_cluster_labels) + ordered = [raw for raw in self.label_order if raw in available] + remaining = [ + raw + for raw in defaults.raw_cluster_labels + if raw not in ordered + ] + return ordered + remaining + return list(defaults.raw_cluster_labels) + + def to_dict(self) -> dict[str, object]: + return { + "title": self.title, + "x_label": self.x_label, + "y_label": self.y_label, + "colorbar_label": self.colorbar_label, + "title_position_x": self.title_position_x, + "title_position_y": self.title_position_y, + "color_limit_min": self.color_limit_min, + "color_limit_max": self.color_limit_max, + "font_family": self.font_family, + "title_font_size": self.title_font_size, + "axis_label_font_size": self.axis_label_font_size, + "tick_label_font_size": self.tick_label_font_size, + "cluster_label_font_size": self.cluster_label_font_size, + "aspect_mode": self.aspect_mode, + "custom_aspect": self.custom_aspect, + "max_x_ticks": self.max_x_ticks, + "max_y_ticks": self.max_y_ticks, + "x_tick_rotation": self.x_tick_rotation, + "y_tick_rotation": self.y_tick_rotation, + "show_minor_x_ticks": self.show_minor_x_ticks, + "show_minor_y_ticks": self.show_minor_y_ticks, + "label_order": list(self.label_order), + "label_map": dict(self.label_map), + } + + def update_from_dict(self, payload: Mapping[str, object]) -> None: + for field_name in ( + "title", + "x_label", + "y_label", + "colorbar_label", + "title_position_x", + "title_position_y", + "color_limit_min", + "color_limit_max", + ): + if field_name in payload: + setattr(self, field_name, payload[field_name]) + if "font_family" in payload: + self.font_family = str(payload["font_family"] or "") + for field_name in ( + "title_font_size", + "axis_label_font_size", + "tick_label_font_size", + "cluster_label_font_size", + "custom_aspect", + ): + if field_name in payload: + setattr(self, field_name, float(payload[field_name])) + if "aspect_mode" in 
payload: + self.aspect_mode = str(payload["aspect_mode"]) + for field_name in ( + "max_x_ticks", + "max_y_ticks", + "x_tick_rotation", + "y_tick_rotation", + ): + if field_name in payload: + setattr(self, field_name, int(payload[field_name])) + for field_name in ("show_minor_x_ticks", "show_minor_y_ticks"): + if field_name in payload: + setattr(self, field_name, bool(payload[field_name])) + if "label_order" in payload: + self.label_order = [str(value) for value in payload["label_order"]] + if "label_map" in payload: + label_map = payload["label_map"] + if isinstance(label_map, Mapping): + self.label_map = { + str(key): str(value) for key, value in label_map.items() + } + + +class PlotEditorWindow(QWidget): + """Reusable popup shell for plot editors with a live Matplotlib + preview.""" + + closed = Signal() + + def __init__( + self, + *, + window_title: str, + controls_widget: QWidget | None, + render_preview: Callable[[Figure], None] | None, + pickle_default_name: str | None = None, + pickle_state_provider: ( + Callable[[], Mapping[str, object] | None] | None + ) = None, + apply_loaded_pickle_state: ( + Callable[[Mapping[str, object]], bool] | None + ) = None, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent, Qt.WindowType.Window) + self._render_preview = render_preview + self._pickle_state_provider = pickle_state_provider + self._apply_loaded_pickle_state = apply_loaded_pickle_state + self._showing_pickled_plot = False + self._last_pickle_path: Path | None = None + self._pickle_default_name = ( + _default_pickled_plot_name(window_title) + if pickle_default_name is None + else pickle_default_name + ) + self._preview_toolbar: NavigationToolbar | None = None + self.canvas: FigureCanvas | None = None + self.figure = Figure(figsize=(7.8, 6.2)) + self.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose, True) + self.setWindowTitle(window_title) + self.resize(1260, 760) + + root = QHBoxLayout(self) + root.setContentsMargins(10, 10, 10, 10) + 
root.setSpacing(10) + + if controls_widget is not None: + controls_scroll = QScrollArea() + controls_scroll.setWidgetResizable(True) + controls_scroll.setMinimumWidth(410) + controls_scroll.setWidget(controls_widget) + root.addWidget(controls_scroll, stretch=0) + + preview_panel = QWidget() + preview_layout = QVBoxLayout(preview_panel) + preview_layout.setContentsMargins(0, 0, 0, 0) + preview_layout.setSpacing(6) + preview_layout.addWidget(QLabel("Preview")) + preview_button_row = QHBoxLayout() + preview_button_row.setContentsMargins(0, 0, 0, 0) + preview_button_row.setSpacing(6) + self.save_pickle_button = QPushButton("Save Pickled Plot") + self.save_pickle_button.clicked.connect(self.save_pickled_plot_as) + preview_button_row.addWidget(self.save_pickle_button) + self.load_pickle_button = QPushButton("Load Pickled Plot") + self.load_pickle_button.clicked.connect(self.load_pickled_plot_as) + preview_button_row.addWidget(self.load_pickle_button) + self.show_live_preview_button = QPushButton("Show Live Plot") + self.show_live_preview_button.clicked.connect(self.show_live_preview) + self.show_live_preview_button.setEnabled(False) + preview_button_row.addWidget(self.show_live_preview_button) + preview_button_row.addStretch(1) + preview_layout.addLayout(preview_button_row) + self._preview_canvas_layout = QVBoxLayout() + self._preview_canvas_layout.setContentsMargins(0, 0, 0, 0) + self._preview_canvas_layout.setSpacing(0) + preview_layout.addLayout(self._preview_canvas_layout, stretch=1) + self._set_preview_figure(self.figure) + root.addWidget(preview_panel, stretch=1) + + def _set_preview_figure(self, figure: Figure) -> None: + if self._preview_toolbar is not None: + self._preview_canvas_layout.removeWidget(self._preview_toolbar) + self._preview_toolbar.setParent(None) + self._preview_toolbar.deleteLater() + self._preview_toolbar = None + if self.canvas is not None: + self._preview_canvas_layout.removeWidget(self.canvas) + self.canvas.setParent(None) + 
def refresh_preview(self, *, force: bool = False) -> None:
    """Re-render the live preview and schedule a canvas redraw.

    While a loaded pickled plot is on display the call is ignored unless
    ``force`` is true. A forced refresh clears the pickled-plot flag and
    disables the "Show Live Plot" button before rendering. The render
    callback, when one was supplied, is invoked with the current figure;
    the canvas, when present, is then asked to redraw.
    """
    if force:
        self._showing_pickled_plot = False
        self.show_live_preview_button.setEnabled(False)
    elif self._showing_pickled_plot:
        return

    render = self._render_preview
    if render is not None:
        render(self.figure)

    canvas = self.canvas
    if canvas is not None:
        canvas.draw_idle()
Figure: + payload = load_pickled_plot_payload(source) + self._last_pickle_path = Path(source) + if ( + self._apply_loaded_pickle_state is not None + and self._apply_loaded_pickle_state(payload) + ): + self.refresh_preview(force=True) + return self.figure + + figure = payload.get("figure") + if not isinstance(figure, Figure): + raise ValueError( + f"{source} does not contain a pickled Matplotlib figure" + ) + self._set_preview_figure(figure) + self._showing_pickled_plot = True + self.show_live_preview_button.setEnabled( + self._render_preview is not None + ) + if self.canvas is not None: + self.canvas.draw_idle() + return figure + + def load_pickled_plot_as(self) -> Path | None: + selected_path, _selected_filter = QFileDialog.getOpenFileName( + self, + "Load Pickled Plot", + str(self._default_pickle_path()), + PICKLED_PLOT_FILE_FILTER, + ) + if not selected_path: + return None + source = Path(selected_path) + try: + self.load_pickled_plot(source) + except Exception as exc: # pragma: no cover - defensive UI guard + QMessageBox.warning( + self, + "Load Pickled Plot", + f"Could not load the pickled plot:\n{exc}", + ) + return None + return source + + def show_live_preview(self) -> None: + self.refresh_preview(force=True) + + def closeEvent(self, event) -> None: # noqa: N802 + self.closed.emit() + super().closeEvent(event) + + +class HeatmapPlotEditorControls(QWidget): + """Editable controls for reusable heatmap/colormap plot settings.""" + + settings_changed = Signal() + x_axis_unit_changed = Signal(str) + colormap_changed = Signal(str) + + def __init__( + self, + *, + settings: HeatmapPlotSettings, + defaults: HeatmapPlotDefaults, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self._settings = settings + self._defaults = defaults + self._last_synced_defaults: HeatmapPlotDefaults | None = None + self._syncing = False + self._build_ui() + self.sync_defaults(defaults) + + def needs_default_sync(self, defaults: HeatmapPlotDefaults) -> bool: + 
return self._last_synced_defaults != defaults + + def sync_defaults(self, defaults: HeatmapPlotDefaults) -> None: + self._defaults = defaults + self._settings.sync_labels( + defaults.raw_cluster_labels, + default_label_entries=defaults.default_label_entries, + ) + self._syncing = True + try: + self.title_edit.setText(self._settings.resolve_title(defaults)) + self.x_label_edit.setText(self._settings.resolve_x_label(defaults)) + self.y_label_edit.setText(self._settings.resolve_y_label(defaults)) + self.colorbar_label_edit.setText( + self._settings.resolve_colorbar_label(defaults) + ) + self.title_position_x_spin.setValue( + self._settings.resolve_title_position_x(defaults) + ) + self.title_position_y_spin.setValue( + self._settings.resolve_title_position_y(defaults) + ) + self.x_axis_unit_combo.clear() + for unit_name in defaults.available_x_axis_unit_names: + self.x_axis_unit_combo.addItem(unit_name, unit_name) + if defaults.available_x_axis_unit_names: + self.x_axis_unit_combo.setCurrentIndex( + max( + 0, + self.x_axis_unit_combo.findData( + defaults.default_x_axis_unit_name + ), + ) + ) + self.x_axis_unit_combo.setEnabled( + bool(defaults.available_x_axis_unit_names) + ) + self.colormap_combo.clear() + for colormap_name in defaults.available_colormap_names: + self.colormap_combo.addItem(colormap_name, colormap_name) + if defaults.available_colormap_names: + self.colormap_combo.setCurrentIndex( + max( + 0, + self.colormap_combo.findData( + defaults.default_colormap_name + ), + ) + ) + self.colormap_combo.setEnabled( + bool(defaults.available_colormap_names) + ) + self.color_limit_min_spin.setValue( + self._settings.resolve_color_limit_min(defaults) + ) + self.color_limit_max_spin.setValue( + self._settings.resolve_color_limit_max(defaults) + ) + self._update_color_limit_reset_state() + if self._settings.font_family: + self.font_combo.setCurrentFont( + QFont(self._settings.font_family) + ) + self.title_font_spin.setValue(self._settings.title_font_size) + 
self.axis_label_font_spin.setValue( + self._settings.axis_label_font_size + ) + self.tick_label_font_spin.setValue( + self._settings.tick_label_font_size + ) + self.cluster_label_font_spin.setValue( + self._settings.cluster_label_font_size + ) + self.aspect_combo.setCurrentIndex( + max( + 0, + self.aspect_combo.findData(self._settings.aspect_mode), + ) + ) + self.aspect_value_spin.setValue(self._settings.custom_aspect) + self.max_x_ticks_spin.setValue(self._settings.max_x_ticks) + self.max_y_ticks_spin.setValue(self._settings.max_y_ticks) + self.x_tick_rotation_spin.setValue(self._settings.x_tick_rotation) + self.y_tick_rotation_spin.setValue(self._settings.y_tick_rotation) + self.minor_x_ticks_checkbox.setChecked( + self._settings.show_minor_x_ticks + ) + self.minor_y_ticks_checkbox.setChecked( + self._settings.show_minor_y_ticks + ) + self._sync_label_table() + self._update_aspect_state() + finally: + self._last_synced_defaults = defaults + self._syncing = False + + def _build_ui(self) -> None: + root = QVBoxLayout(self) + root.setContentsMargins(0, 0, 0, 0) + root.setSpacing(10) + + note = QLabel( + "Edit the heatmap title, labels, font, aspect ratio, and tick " + "density. Use $_{n}$ for subscript and $^{n}$ for superscript " + "(matplotlib mathtext). Igor-style inline text is also " + "supported: \\f01 bold, \\f02 italics, \\f00 reset, and " + "\\Z inline font size." 
+ ) + note.setWordWrap(True) + root.addWidget(note) + + text_group = QGroupBox("Text") + text_form = QFormLayout(text_group) + self.title_edit = QLineEdit() + self.title_edit.textChanged.connect(self._on_title_changed) + text_form.addRow("Title", self.title_edit) + self.x_axis_unit_combo = QComboBox() + self.x_axis_unit_combo.currentIndexChanged.connect( + self._on_x_axis_unit_changed + ) + text_form.addRow("X Unit", self.x_axis_unit_combo) + self.x_label_edit = QLineEdit() + self.x_label_edit.textChanged.connect(self._on_x_label_changed) + text_form.addRow("X Label", self.x_label_edit) + self.y_label_edit = QLineEdit() + self.y_label_edit.textChanged.connect(self._on_y_label_changed) + text_form.addRow("Y Label", self.y_label_edit) + self.colorbar_label_edit = QLineEdit() + self.colorbar_label_edit.textChanged.connect( + self._on_colorbar_label_changed + ) + text_form.addRow("Colorbar", self.colorbar_label_edit) + self.colormap_combo = QComboBox() + self.colormap_combo.currentIndexChanged.connect( + self._on_colormap_changed + ) + text_form.addRow("Colormap", self.colormap_combo) + self.reset_text_button = QPushButton("Reset Text Defaults") + self.reset_text_button.clicked.connect(self._reset_text_defaults) + text_form.addRow(self.reset_text_button) + root.addWidget(text_group) + + color_group = QGroupBox("Color Scale") + color_form = QFormLayout(color_group) + self.color_limit_min_spin = self._build_color_limit_spin() + self.color_limit_min_spin.valueChanged.connect( + self._on_color_limit_min_changed + ) + color_form.addRow("Min", self.color_limit_min_spin) + self.color_limit_max_spin = self._build_color_limit_spin() + self.color_limit_max_spin.valueChanged.connect( + self._on_color_limit_max_changed + ) + color_form.addRow("Max", self.color_limit_max_spin) + self.reset_color_limits_button = QPushButton("Reset to Auto Limits") + self.reset_color_limits_button.clicked.connect( + self._reset_color_limits + ) + color_form.addRow(self.reset_color_limits_button) + 
root.addWidget(color_group) + + style_group = QGroupBox("Style") + style_form = QFormLayout(style_group) + self.font_combo = QFontComboBox() + self.font_combo.currentFontChanged.connect(self._on_font_changed) + style_form.addRow("Font", self.font_combo) + self.title_font_spin = self._build_font_spin() + self.title_font_spin.valueChanged.connect( + self._on_title_font_size_changed + ) + style_form.addRow("Title Size", self.title_font_spin) + self.title_position_x_spin = self._build_position_spin() + self.title_position_x_spin.valueChanged.connect( + self._on_title_position_x_changed + ) + style_form.addRow("Title X", self.title_position_x_spin) + self.title_position_y_spin = self._build_position_spin() + self.title_position_y_spin.valueChanged.connect( + self._on_title_position_y_changed + ) + style_form.addRow("Title Y", self.title_position_y_spin) + self.axis_label_font_spin = self._build_font_spin() + self.axis_label_font_spin.valueChanged.connect( + self._on_axis_label_font_size_changed + ) + style_form.addRow("Axis Label Size", self.axis_label_font_spin) + self.tick_label_font_spin = self._build_font_spin() + self.tick_label_font_spin.valueChanged.connect( + self._on_tick_label_font_size_changed + ) + style_form.addRow("Tick Label Size", self.tick_label_font_spin) + self.cluster_label_font_spin = self._build_font_spin() + self.cluster_label_font_spin.valueChanged.connect( + self._on_cluster_label_font_size_changed + ) + style_form.addRow("Cluster Label Size", self.cluster_label_font_spin) + + self.aspect_combo = QComboBox() + self.aspect_combo.addItem("Auto", "auto") + self.aspect_combo.addItem("Equal", "equal") + self.aspect_combo.addItem("Custom Ratio", "custom") + self.aspect_combo.currentIndexChanged.connect( + self._on_aspect_mode_changed + ) + style_form.addRow("Aspect", self.aspect_combo) + self.aspect_value_spin = QDoubleSpinBox() + self.aspect_value_spin.setDecimals(2) + self.aspect_value_spin.setRange(0.1, 10.0) + 
self.aspect_value_spin.setSingleStep(0.1) + self.aspect_value_spin.valueChanged.connect( + self._on_aspect_value_changed + ) + style_form.addRow("Aspect Ratio", self.aspect_value_spin) + root.addWidget(style_group) + + ticks_group = QGroupBox("Ticks") + ticks_form = QFormLayout(ticks_group) + self.max_x_ticks_spin = QSpinBox() + self.max_x_ticks_spin.setRange(2, 20) + self.max_x_ticks_spin.valueChanged.connect( + self._on_max_x_ticks_changed + ) + ticks_form.addRow("Max X Ticks", self.max_x_ticks_spin) + self.max_y_ticks_spin = QSpinBox() + self.max_y_ticks_spin.setRange(1, 60) + self.max_y_ticks_spin.valueChanged.connect( + self._on_max_y_ticks_changed + ) + ticks_form.addRow("Max Y Ticks", self.max_y_ticks_spin) + self.x_tick_rotation_spin = QSpinBox() + self.x_tick_rotation_spin.setRange(-180, 180) + self.x_tick_rotation_spin.valueChanged.connect( + self._on_x_tick_rotation_changed + ) + ticks_form.addRow("X Tick Rotation", self.x_tick_rotation_spin) + self.y_tick_rotation_spin = QSpinBox() + self.y_tick_rotation_spin.setRange(-180, 180) + self.y_tick_rotation_spin.valueChanged.connect( + self._on_y_tick_rotation_changed + ) + ticks_form.addRow("Y Tick Rotation", self.y_tick_rotation_spin) + self.minor_x_ticks_checkbox = QCheckBox("Show X Minor Ticks") + self.minor_x_ticks_checkbox.toggled.connect( + self._on_minor_x_ticks_changed + ) + ticks_form.addRow(self.minor_x_ticks_checkbox) + self.minor_y_ticks_checkbox = QCheckBox("Show Y Minor Ticks") + self.minor_y_ticks_checkbox.toggled.connect( + self._on_minor_y_ticks_changed + ) + ticks_form.addRow(self.minor_y_ticks_checkbox) + root.addWidget(ticks_group) + + labels_group = QGroupBox("Axis Labels") + labels_layout = QVBoxLayout(labels_group) + labels_layout.setContentsMargins(8, 8, 8, 8) + labels_layout.setSpacing(8) + labels_note = QLabel( + "Rearrange rows to control axis-bin order and edit Display Label " + "to customise tick text. 
The raw label column stays fixed so you " + "can round-trip custom formatting." + ) + labels_note.setWordWrap(True) + labels_layout.addWidget(labels_note) + self.label_table = QTableWidget(0, 2) + self.label_table.setHorizontalHeaderLabels( + ["Raw Label", "Display Label"] + ) + self.label_table.horizontalHeader().setStretchLastSection(True) + self.label_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.label_table.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.label_table.itemChanged.connect(self._on_label_item_changed) + content_row = QHBoxLayout() + content_row.addWidget(self.label_table, stretch=1) + move_column = QVBoxLayout() + self.move_label_up_button = QPushButton("Up") + self.move_label_up_button.clicked.connect(self._move_label_up) + move_column.addWidget(self.move_label_up_button) + self.move_label_down_button = QPushButton("Down") + self.move_label_down_button.clicked.connect(self._move_label_down) + move_column.addWidget(self.move_label_down_button) + move_column.addStretch(1) + content_row.addLayout(move_column) + labels_layout.addLayout(content_row, stretch=1) + button_row = QHBoxLayout() + self.histogram_order_button = QPushButton("Sort Like Histogram") + self.histogram_order_button.clicked.connect( + self._apply_histogram_order + ) + button_row.addWidget(self.histogram_order_button) + self.auto_subscript_button = QPushButton("Auto Stoich Subscripts") + self.auto_subscript_button.clicked.connect( + self._apply_stoich_subscripts + ) + button_row.addWidget(self.auto_subscript_button) + self.reset_labels_button = QPushButton("Reset Labels") + self.reset_labels_button.clicked.connect(self._reset_labels) + button_row.addWidget(self.reset_labels_button) + button_row.addStretch(1) + labels_layout.addLayout(button_row) + root.addWidget(labels_group, stretch=1) + root.addStretch(1) + + @staticmethod + def _build_font_spin() -> QDoubleSpinBox: + spin = QDoubleSpinBox() + 
spin.setDecimals(1) + spin.setRange(6.0, 40.0) + spin.setSingleStep(0.5) + return spin + + @staticmethod + def _build_position_spin() -> QDoubleSpinBox: + spin = QDoubleSpinBox() + spin.setDecimals(2) + spin.setRange(-1.0, 2.0) + spin.setSingleStep(0.05) + return spin + + @staticmethod + def _build_color_limit_spin() -> QDoubleSpinBox: + spin = QDoubleSpinBox() + spin.setDecimals(6) + spin.setRange(-1_000_000_000.0, 1_000_000_000.0) + spin.setSingleStep(0.1) + spin.setAccelerated(True) + return spin + + def _sync_label_table(self) -> None: + self.label_table.blockSignals(True) + try: + raw_labels = self._settings.ordered_raw_labels(self._defaults) + self.label_table.setRowCount(len(raw_labels)) + for row, raw_label in enumerate(raw_labels): + raw_item = QTableWidgetItem(raw_label) + raw_item.setFlags( + raw_item.flags() & ~Qt.ItemFlag.ItemIsEditable + ) + self.label_table.setItem(row, 0, raw_item) + self.label_table.setItem( + row, + 1, + QTableWidgetItem(self._settings.display_label(raw_label)), + ) + if raw_labels: + self.label_table.resizeColumnToContents(0) + if self.label_table.currentRow() < 0: + self.label_table.selectRow(0) + finally: + self.label_table.blockSignals(False) + + def _update_aspect_state(self) -> None: + aspect_mode = self.aspect_combo.currentData() + self.aspect_value_spin.setEnabled(aspect_mode == "custom") + + def _update_color_limit_reset_state(self) -> None: + self.reset_color_limits_button.setEnabled( + self._settings.has_manual_color_limits() + ) + + def _emit_settings_changed(self) -> None: + if not self._syncing: + self.settings_changed.emit() + + def _on_title_changed(self, text: str) -> None: + if self._syncing: + return + self._settings.title = text + self._emit_settings_changed() + + def _on_x_label_changed(self, text: str) -> None: + if self._syncing: + return + self._settings.x_label = text + self._emit_settings_changed() + + def _on_x_axis_unit_changed(self) -> None: + if self._syncing: + return + current = 
def _set_manual_color_limits(
    self,
    *,
    minimum: float | None = None,
    maximum: float | None = None,
    changed: str,
) -> None:
    """Store manual colour limits, keeping ``min`` strictly below ``max``.

    Args:
        minimum: New lower limit, or ``None`` to keep the resolved one.
        maximum: New upper limit, or ``None`` to keep the resolved one.
        changed: ``"min"`` when the lower limit was edited; any other
            value is treated as an edit of the upper limit. The edited
            side is kept and the opposite side is pushed away by
            ``_color_limit_delta`` when the pair would collapse.
    """
    resolved_low = self._settings.resolve_color_limit_min(self._defaults)
    resolved_high = self._settings.resolve_color_limit_max(self._defaults)
    low = float(minimum) if minimum is not None else resolved_low
    high = float(maximum) if maximum is not None else resolved_high

    if high <= low:
        if changed == "min":
            high = low + self._color_limit_delta(low)
        else:
            low = high - self._color_limit_delta(high)

    low = float(low)
    high = float(high)
    self._settings.color_limit_min = low
    self._settings.color_limit_max = high

    # Push the (possibly adjusted) pair back into the spin boxes with the
    # sync guard raised so their change handlers return early.
    self._syncing = True
    try:
        self.color_limit_min_spin.setValue(low)
        self.color_limit_max_spin.setValue(high)
        self._update_color_limit_reset_state()
    finally:
        self._syncing = False

    self._emit_settings_changed()
def _on_aspect_mode_changed(self) -> None:
    """Store the aspect mode chosen in the combo box and notify.

    Does nothing while a programmatic sync is in progress. A combo entry
    without item data is stored as ``"auto"``.
    """
    if self._syncing:
        return
    selected = self.aspect_combo.currentData()
    if selected is None:
        self._settings.aspect_mode = "auto"
    else:
        self._settings.aspect_mode = str(selected)
    self._update_aspect_state()
    self._emit_settings_changed()
def _swap_label_rows(self, row_a: int, row_b: int) -> None:
    """Exchange the cell items of two label-table rows, column by column.

    Table signals are blocked while items are moved so that re-inserting
    them does not fire ``itemChanged``.
    """
    table = self.label_table
    table.blockSignals(True)
    try:
        for col in range(table.columnCount()):
            # Detach both cells first, then re-insert each in the other row.
            first = table.takeItem(row_a, col)
            second = table.takeItem(row_b, col)
            for target_row, cell in ((row_b, first), (row_a, second)):
                if cell is not None:
                    table.setItem(target_row, col, cell)
    finally:
        table.blockSignals(False)
class StackedHistogramPlotEditorControls(QWidget):
    """Editable controls for reusable stacked-histogram plot settings.

    The widget edits a ``StackedHistogramPlotSettings`` instance in place
    and reports edits through signals:

    * ``settings_changed`` -- emitted after any setting is modified.
    * ``colormap_changed`` -- emitted with the colormap name picked in the
      combo box; the name is not stored on the settings object here.
    * ``label_settings_changed`` -- emitted (before ``settings_changed``)
      whenever the category order or a display label is modified.

    While ``_syncing`` is true the widget is being populated
    programmatically; every change handler returns early and no signal is
    emitted.
    """

    settings_changed = Signal()
    colormap_changed = Signal(str)
    label_settings_changed = Signal()

    def __init__(
        self,
        *,
        settings: StackedHistogramPlotSettings,
        defaults: StackedHistogramPlotDefaults,
        parent: QWidget | None = None,
    ) -> None:
        """Build the widgets and populate them from *settings*/*defaults*."""
        super().__init__(parent)
        self._settings = settings
        self._defaults = defaults
        self._last_synced_defaults: StackedHistogramPlotDefaults | None = None
        self._syncing = False
        self._build_ui()
        self.sync_defaults(defaults)

    def needs_default_sync(
        self, defaults: StackedHistogramPlotDefaults
    ) -> bool:
        """Return True when *defaults* differ from the last synced ones."""
        return self._last_synced_defaults != defaults

    def sync_defaults(self, defaults: StackedHistogramPlotDefaults) -> None:
        """Refresh every control from the settings resolved against *defaults*.

        The label bookkeeping on the settings object is synchronised
        first, then all widgets are updated with the ``_syncing`` guard
        raised so no change signal is emitted.
        """
        self._defaults = defaults
        self._settings.sync_labels(
            defaults.raw_category_labels,
            default_label_entries=defaults.default_label_entries,
        )
        self._syncing = True
        try:
            self.title_edit.setText(self._settings.resolve_title(defaults))
            self.x_label_edit.setText(self._settings.resolve_x_label(defaults))
            self.y_label_edit.setText(self._settings.resolve_y_label(defaults))
            self.legend_title_edit.setText(
                self._settings.resolve_legend_title(defaults)
            )
            self.title_position_x_spin.setValue(
                self._settings.resolve_title_position_x(defaults)
            )
            self.title_position_y_spin.setValue(
                self._settings.resolve_title_position_y(defaults)
            )
            # The colormap list comes from the defaults, so it is rebuilt
            # on every sync and the default entry is re-selected.
            self.colormap_combo.clear()
            for colormap_name in defaults.available_colormap_names:
                self.colormap_combo.addItem(colormap_name, colormap_name)
            if defaults.available_colormap_names:
                self.colormap_combo.setCurrentIndex(
                    max(
                        0,
                        self.colormap_combo.findData(
                            defaults.default_colormap_name
                        ),
                    )
                )
            self.colormap_combo.setEnabled(
                bool(defaults.available_colormap_names)
            )
            # An empty family leaves the font combo at its current font.
            if self._settings.font_family:
                self.font_combo.setCurrentFont(
                    QFont(self._settings.font_family)
                )
            self.title_font_spin.setValue(self._settings.title_font_size)
            self.axis_label_font_spin.setValue(
                self._settings.axis_label_font_size
            )
            self.tick_label_font_spin.setValue(
                self._settings.tick_label_font_size
            )
            self.legend_font_spin.setValue(self._settings.legend_font_size)
            self.annotation_font_spin.setValue(
                self._settings.annotation_font_size
            )
            self.show_total_annotations_checkbox.setChecked(
                self._settings.show_total_annotations
            )
            self.show_legend_checkbox.setChecked(self._settings.show_legend)
            # findData returns -1 for an unknown value; fall back to the
            # first entry in that case.
            self.legend_location_combo.setCurrentIndex(
                max(
                    0,
                    self.legend_location_combo.findData(
                        self._settings.legend_location
                    ),
                )
            )
            self.max_y_ticks_spin.setValue(self._settings.max_y_ticks)
            self.x_tick_rotation_spin.setValue(self._settings.x_tick_rotation)
            self.y_tick_rotation_spin.setValue(self._settings.y_tick_rotation)
            self.minor_y_ticks_checkbox.setChecked(
                self._settings.show_minor_y_ticks
            )
            self._sync_label_table()
            self._update_legend_state()
        finally:
            self._syncing = False
        # Record the defaults only after a complete sync: if anything above
        # raised, needs_default_sync() must keep reporting True so the
        # caller can retry instead of leaving the controls half-populated.
        self._last_synced_defaults = defaults

    def _build_ui(self) -> None:
        """Create all child widgets, lay them out and connect signals."""
        root = QVBoxLayout(self)
        root.setContentsMargins(0, 0, 0, 0)
        root.setSpacing(10)

        note = QLabel(
            "Edit stacked-histogram text, legend layout, tick styling, and "
            "category order. Use $_{n}$ for subscript and $^{n}$ for "
            "superscript (matplotlib mathtext). Igor-style inline text is "
            "also supported: \\f01 bold, \\f02 italics, \\f00 reset, and "
            "\\Z inline font size."
        )
        note.setWordWrap(True)
        root.addWidget(note)

        # --- Text: title, axis labels, legend title, colormap ---------
        text_group = QGroupBox("Text")
        text_form = QFormLayout(text_group)
        self.title_edit = QLineEdit()
        self.title_edit.textChanged.connect(self._on_title_changed)
        text_form.addRow("Title", self.title_edit)
        self.x_label_edit = QLineEdit()
        self.x_label_edit.textChanged.connect(self._on_x_label_changed)
        text_form.addRow("X Label", self.x_label_edit)
        self.y_label_edit = QLineEdit()
        self.y_label_edit.textChanged.connect(self._on_y_label_changed)
        text_form.addRow("Y Label", self.y_label_edit)
        self.legend_title_edit = QLineEdit()
        self.legend_title_edit.textChanged.connect(
            self._on_legend_title_changed
        )
        text_form.addRow("Legend Title", self.legend_title_edit)
        self.colormap_combo = QComboBox()
        self.colormap_combo.currentIndexChanged.connect(
            self._on_colormap_changed
        )
        text_form.addRow("Colormap", self.colormap_combo)
        self.reset_text_button = QPushButton("Reset Text Defaults")
        self.reset_text_button.clicked.connect(self._reset_text_defaults)
        text_form.addRow(self.reset_text_button)
        root.addWidget(text_group)

        # --- Style: font family, sizes and title position --------------
        # The spin boxes reuse the factory helpers of the heatmap editor.
        style_group = QGroupBox("Style")
        style_form = QFormLayout(style_group)
        self.font_combo = QFontComboBox()
        self.font_combo.currentFontChanged.connect(self._on_font_changed)
        style_form.addRow("Font", self.font_combo)
        self.title_font_spin = HeatmapPlotEditorControls._build_font_spin()
        self.title_font_spin.valueChanged.connect(
            self._on_title_font_size_changed
        )
        style_form.addRow("Title Size", self.title_font_spin)
        self.title_position_x_spin = (
            HeatmapPlotEditorControls._build_position_spin()
        )
        self.title_position_x_spin.valueChanged.connect(
            self._on_title_position_x_changed
        )
        style_form.addRow("Title X", self.title_position_x_spin)
        self.title_position_y_spin = (
            HeatmapPlotEditorControls._build_position_spin()
        )
        self.title_position_y_spin.valueChanged.connect(
            self._on_title_position_y_changed
        )
        style_form.addRow("Title Y", self.title_position_y_spin)
        self.axis_label_font_spin = (
            HeatmapPlotEditorControls._build_font_spin()
        )
        self.axis_label_font_spin.valueChanged.connect(
            self._on_axis_label_font_size_changed
        )
        style_form.addRow("Axis Label Size", self.axis_label_font_spin)
        self.tick_label_font_spin = (
            HeatmapPlotEditorControls._build_font_spin()
        )
        self.tick_label_font_spin.valueChanged.connect(
            self._on_tick_label_font_size_changed
        )
        style_form.addRow("Tick Label Size", self.tick_label_font_spin)
        self.legend_font_spin = HeatmapPlotEditorControls._build_font_spin()
        self.legend_font_spin.valueChanged.connect(
            self._on_legend_font_size_changed
        )
        style_form.addRow("Legend Size", self.legend_font_spin)
        self.annotation_font_spin = (
            HeatmapPlotEditorControls._build_font_spin()
        )
        self.annotation_font_spin.valueChanged.connect(
            self._on_annotation_font_size_changed
        )
        style_form.addRow("Totals Size", self.annotation_font_spin)
        root.addWidget(style_group)

        # --- Display: total annotations and legend ---------------------
        display_group = QGroupBox("Display")
        display_form = QFormLayout(display_group)
        self.show_total_annotations_checkbox = QCheckBox(
            "Show Total Percent Labels"
        )
        self.show_total_annotations_checkbox.toggled.connect(
            self._on_show_total_annotations_changed
        )
        display_form.addRow(self.show_total_annotations_checkbox)
        self.show_legend_checkbox = QCheckBox("Show Legend")
        self.show_legend_checkbox.toggled.connect(self._on_show_legend_changed)
        display_form.addRow(self.show_legend_checkbox)
        self.legend_location_combo = QComboBox()
        # Populate before connecting so filling the combo fires no handler.
        for label, value in STACKED_HISTOGRAM_LEGEND_LOCATIONS:
            self.legend_location_combo.addItem(label, value)
        self.legend_location_combo.currentIndexChanged.connect(
            self._on_legend_location_changed
        )
        display_form.addRow("Legend Position", self.legend_location_combo)
        root.addWidget(display_group)

        # --- Ticks ------------------------------------------------------
        ticks_group = QGroupBox("Ticks")
        ticks_form = QFormLayout(ticks_group)
        self.max_y_ticks_spin = QSpinBox()
        self.max_y_ticks_spin.setRange(2, 20)
        self.max_y_ticks_spin.valueChanged.connect(
            self._on_max_y_ticks_changed
        )
        ticks_form.addRow("Max Y Ticks", self.max_y_ticks_spin)
        self.x_tick_rotation_spin = QSpinBox()
        self.x_tick_rotation_spin.setRange(-180, 180)
        self.x_tick_rotation_spin.valueChanged.connect(
            self._on_x_tick_rotation_changed
        )
        ticks_form.addRow("X Tick Rotation", self.x_tick_rotation_spin)
        self.y_tick_rotation_spin = QSpinBox()
        self.y_tick_rotation_spin.setRange(-180, 180)
        self.y_tick_rotation_spin.valueChanged.connect(
            self._on_y_tick_rotation_changed
        )
        ticks_form.addRow("Y Tick Rotation", self.y_tick_rotation_spin)
        self.minor_y_ticks_checkbox = QCheckBox("Show Y Minor Ticks")
        self.minor_y_ticks_checkbox.toggled.connect(
            self._on_minor_y_ticks_changed
        )
        ticks_form.addRow(self.minor_y_ticks_checkbox)
        root.addWidget(ticks_group)

        # --- Axis labels: order and display text per category ----------
        labels_group = QGroupBox("Axis Labels")
        labels_layout = QVBoxLayout(labels_group)
        labels_layout.setContentsMargins(8, 8, 8, 8)
        labels_layout.setSpacing(8)
        labels_note = QLabel(
            "Rearrange rows to control x-axis order and edit Display Label "
            "to customise tick text. The raw label column stays fixed so you "
            "can round-trip custom formatting."
        )
        labels_note.setWordWrap(True)
        labels_layout.addWidget(labels_note)
        self.label_table = QTableWidget(0, 2)
        self.label_table.setHorizontalHeaderLabels(
            ["Raw Label", "Display Label"]
        )
        self.label_table.horizontalHeader().setStretchLastSection(True)
        self.label_table.setSelectionBehavior(
            QAbstractItemView.SelectionBehavior.SelectRows
        )
        self.label_table.setSelectionMode(
            QAbstractItemView.SelectionMode.SingleSelection
        )
        self.label_table.itemChanged.connect(self._on_label_item_changed)
        content_row = QHBoxLayout()
        content_row.addWidget(self.label_table, stretch=1)
        move_column = QVBoxLayout()
        self.move_label_up_button = QPushButton("Up")
        self.move_label_up_button.clicked.connect(self._move_label_up)
        move_column.addWidget(self.move_label_up_button)
        self.move_label_down_button = QPushButton("Down")
        self.move_label_down_button.clicked.connect(self._move_label_down)
        move_column.addWidget(self.move_label_down_button)
        move_column.addStretch(1)
        content_row.addLayout(move_column)
        labels_layout.addLayout(content_row, stretch=1)
        button_row = QHBoxLayout()
        self.histogram_order_button = QPushButton("Sort Like Histogram")
        self.histogram_order_button.clicked.connect(
            self._apply_histogram_order
        )
        button_row.addWidget(self.histogram_order_button)
        self.auto_subscript_button = QPushButton("Auto Stoich Subscripts")
        self.auto_subscript_button.clicked.connect(
            self._apply_stoich_subscripts
        )
        button_row.addWidget(self.auto_subscript_button)
        self.reset_labels_button = QPushButton("Reset Labels")
        self.reset_labels_button.clicked.connect(self._reset_labels)
        button_row.addWidget(self.reset_labels_button)
        button_row.addStretch(1)
        labels_layout.addLayout(button_row)
        root.addWidget(labels_group, stretch=1)
        root.addStretch(1)

    def _sync_label_table(self) -> None:
        """Repopulate the label table from the settings' order and map."""
        # Block table signals so rebuilding cells does not fire itemChanged.
        self.label_table.blockSignals(True)
        try:
            raw_labels = self._settings.ordered_raw_labels(self._defaults)
            self.label_table.setRowCount(len(raw_labels))
            for row, raw_label in enumerate(raw_labels):
                raw_item = QTableWidgetItem(raw_label)
                # The raw label is the lookup key, so it is read-only.
                raw_item.setFlags(
                    raw_item.flags() & ~Qt.ItemFlag.ItemIsEditable
                )
                self.label_table.setItem(row, 0, raw_item)
                self.label_table.setItem(
                    row,
                    1,
                    QTableWidgetItem(self._settings.display_label(raw_label)),
                )
            if raw_labels:
                self.label_table.resizeColumnToContents(0)
                if self.label_table.currentRow() < 0:
                    self.label_table.selectRow(0)
        finally:
            self.label_table.blockSignals(False)

    def _update_legend_state(self) -> None:
        """Enable the legend-position combo only while the legend is shown."""
        self.legend_location_combo.setEnabled(
            self.show_legend_checkbox.isChecked()
        )

    def _emit_settings_changed(self) -> None:
        """Emit ``settings_changed`` unless a programmatic sync is running."""
        if not self._syncing:
            self.settings_changed.emit()

    def _emit_label_settings_changed(self) -> None:
        """Emit ``label_settings_changed`` unless a sync is running."""
        if not self._syncing:
            self.label_settings_changed.emit()

    def _on_title_changed(self, text: str) -> None:
        """Store the edited title."""
        if self._syncing:
            return
        self._settings.title = text
        self._emit_settings_changed()

    def _on_x_label_changed(self, text: str) -> None:
        """Store the edited x-axis label."""
        if self._syncing:
            return
        self._settings.x_label = text
        self._emit_settings_changed()

    def _on_y_label_changed(self, text: str) -> None:
        """Store the edited y-axis label."""
        if self._syncing:
            return
        self._settings.y_label = text
        self._emit_settings_changed()

    def _on_legend_title_changed(self, text: str) -> None:
        """Store the edited legend title."""
        if self._syncing:
            return
        self._settings.legend_title = text
        self._emit_settings_changed()

    def _on_colormap_changed(self) -> None:
        """Forward the selected colormap name via ``colormap_changed``."""
        if self._syncing:
            return
        current = self.colormap_combo.currentData()
        if current is None:
            return
        self.colormap_changed.emit(str(current))

    def _on_font_changed(self, font: QFont) -> None:
        """Store the family name of the selected font."""
        if self._syncing:
            return
        self._settings.font_family = font.family()
        self._emit_settings_changed()

    def _on_title_font_size_changed(self, value: float) -> None:
        """Store the title font size."""
        if self._syncing:
            return
        self._settings.title_font_size = float(value)
        self._emit_settings_changed()

    def _on_title_position_x_changed(self, value: float) -> None:
        """Store the title x position."""
        if self._syncing:
            return
        self._settings.title_position_x = float(value)
        self._emit_settings_changed()

    def _on_title_position_y_changed(self, value: float) -> None:
        """Store the title y position."""
        if self._syncing:
            return
        self._settings.title_position_y = float(value)
        self._emit_settings_changed()

    def _on_axis_label_font_size_changed(self, value: float) -> None:
        """Store the axis-label font size."""
        if self._syncing:
            return
        self._settings.axis_label_font_size = float(value)
        self._emit_settings_changed()

    def _on_tick_label_font_size_changed(self, value: float) -> None:
        """Store the tick-label font size."""
        if self._syncing:
            return
        self._settings.tick_label_font_size = float(value)
        self._emit_settings_changed()

    def _on_legend_font_size_changed(self, value: float) -> None:
        """Store the legend font size."""
        if self._syncing:
            return
        self._settings.legend_font_size = float(value)
        self._emit_settings_changed()

    def _on_annotation_font_size_changed(self, value: float) -> None:
        """Store the font size of the total annotations."""
        if self._syncing:
            return
        self._settings.annotation_font_size = float(value)
        self._emit_settings_changed()

    def _on_show_total_annotations_changed(self, checked: bool) -> None:
        """Store whether total annotations are shown."""
        if self._syncing:
            return
        self._settings.show_total_annotations = bool(checked)
        self._emit_settings_changed()

    def _on_show_legend_changed(self, checked: bool) -> None:
        """Store legend visibility and refresh the position combo state."""
        if self._syncing:
            return
        self._settings.show_legend = bool(checked)
        self._update_legend_state()
        self._emit_settings_changed()

    def _on_legend_location_changed(self) -> None:
        """Store the selected legend location.

        A combo entry without item data is stored as
        ``"outside_upper_right"``.
        """
        if self._syncing:
            return
        current = self.legend_location_combo.currentData()
        self._settings.legend_location = (
            "outside_upper_right" if current is None else str(current)
        )
        self._emit_settings_changed()

    def _on_max_y_ticks_changed(self, value: int) -> None:
        """Store the maximum number of y ticks."""
        if self._syncing:
            return
        self._settings.max_y_ticks = int(value)
        self._emit_settings_changed()

    def _on_x_tick_rotation_changed(self, value: int) -> None:
        """Store the x tick-label rotation."""
        if self._syncing:
            return
        self._settings.x_tick_rotation = int(value)
        self._emit_settings_changed()

    def _on_y_tick_rotation_changed(self, value: int) -> None:
        """Store the y tick-label rotation."""
        if self._syncing:
            return
        self._settings.y_tick_rotation = int(value)
        self._emit_settings_changed()

    def _on_minor_y_ticks_changed(self, checked: bool) -> None:
        """Store whether minor y ticks are shown."""
        if self._syncing:
            return
        self._settings.show_minor_y_ticks = bool(checked)
        self._emit_settings_changed()

    def _on_label_item_changed(self, item: QTableWidgetItem) -> None:
        """Store an edited display label under its raw-label key."""
        # Only column 1 (display label) carries user edits.
        if self._syncing or item.column() != 1:
            return
        raw_item = self.label_table.item(item.row(), 0)
        if raw_item is None:
            return
        self._settings.label_map[raw_item.text()] = item.text()
        self._emit_label_settings_changed()
        self._emit_settings_changed()

    def _move_label_up(self) -> None:
        """Move the current table row one position up."""
        row = self.label_table.currentRow()
        # Covers both "no current row" (-1) and "already first" (0).
        if row <= 0:
            return
        self._swap_label_rows(row, row - 1)
        self.label_table.selectRow(row - 1)
        self._store_label_order_from_table()
        self._emit_label_settings_changed()
        self._emit_settings_changed()

    def _move_label_down(self) -> None:
        """Move the current table row one position down."""
        row = self.label_table.currentRow()
        if row < 0 or row >= self.label_table.rowCount() - 1:
            return
        self._swap_label_rows(row, row + 1)
        self.label_table.selectRow(row + 1)
        self._store_label_order_from_table()
        self._emit_label_settings_changed()
        self._emit_settings_changed()

    def _swap_label_rows(self, row_a: int, row_b: int) -> None:
        """Exchange the cell items of two table rows, column by column."""
        self.label_table.blockSignals(True)
        try:
            for column in range(self.label_table.columnCount()):
                item_a = self.label_table.takeItem(row_a, column)
                item_b = self.label_table.takeItem(row_b, column)
                if item_a is not None:
                    self.label_table.setItem(row_b, column, item_a)
                if item_b is not None:
                    self.label_table.setItem(row_a, column, item_b)
        finally:
            self.label_table.blockSignals(False)

    def _store_label_order_from_table(self) -> None:
        """Copy the table's row order and display labels into the settings."""
        self._settings.label_order = []
        for row in range(self.label_table.rowCount()):
            raw_item = self.label_table.item(row, 0)
            display_item = self.label_table.item(row, 1)
            if raw_item is None:
                continue
            raw_label = raw_item.text()
            self._settings.label_order.append(raw_label)
            # A missing display cell falls back to the raw label.
            self._settings.label_map[raw_label] = (
                display_item.text() if display_item is not None else raw_label
            )

    def _reset_text_defaults(self) -> None:
        """Clear text and title-position overrides, then resync."""
        self._settings.title = None
        self._settings.x_label = None
        self._settings.y_label = None
        self._settings.legend_title = None
        self._settings.title_position_x = None
        self._settings.title_position_y = None
        self.sync_defaults(self._defaults)
        self._emit_settings_changed()

    def _reset_labels(self) -> None:
        """Restore the default category order and display labels."""
        self._settings.label_order = list(self._defaults.raw_category_labels)
        default_map = dict(self._defaults.default_label_entries)
        self._settings.label_map = {
            raw_label: default_map.get(raw_label, raw_label)
            for raw_label in self._defaults.raw_category_labels
        }
        self.sync_defaults(self._defaults)
        self._emit_label_settings_changed()
        self._emit_settings_changed()

    def _apply_stoich_subscripts(self) -> None:
        """Replace every display label with ``format_stoich_for_axis``."""
        self._settings.label_map = {
            raw_label: format_stoich_for_axis(raw_label)
            for raw_label in self._settings.ordered_raw_labels(self._defaults)
        }
        self.sync_defaults(self._defaults)
        self._emit_label_settings_changed()
        self._emit_settings_changed()

    def _apply_histogram_order(self) -> None:
        """Reorder the categories with ``sort_stoich_labels``."""
        ordered = sort_stoich_labels(self._defaults.raw_category_labels)
        self._settings.label_order = list(ordered)
        self.sync_defaults(self._defaults)
        self._emit_label_settings_changed()
        self._emit_settings_changed()
b/src/saxshell/plotting/stacked_histogram.py new file mode 100644 index 0000000..7d89fcb --- /dev/null +++ b/src/saxshell/plotting/stacked_histogram.py @@ -0,0 +1,546 @@ +from __future__ import annotations + +import math +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.lines import Line2D +from matplotlib.ticker import AutoMinorLocator, MaxNLocator + +from saxshell.plotting.igor_inline import ( + apply_igor_inline_text_artist, + igor_inline_to_mathtext, + prepare_igor_inline_segments, +) +from saxshell.saxs.stoichiometry import ( + format_stoich_for_axis, + sort_stoich_labels, +) + +STACKED_HISTOGRAM_LEGEND_LOCATIONS = ( + ("Outside Upper Right", "outside_upper_right"), + ("Upper Right", "upper_right"), + ("Upper Left", "upper_left"), + ("Lower Right", "lower_right"), + ("Lower Left", "lower_left"), + ("Best", "best"), +) + + +@dataclass(slots=True) +class StackedHistogramPlotDefaults: + title: str + x_label: str + y_label: str + legend_title: str + title_position_x: float = 0.5 + title_position_y: float = 1.0 + default_colormap_name: str = "" + available_colormap_names: tuple[str, ...] = () + raw_category_labels: tuple[str, ...] = () + default_label_entries: tuple[tuple[str, str], ...] 
= () + + +@dataclass(slots=True) +class StackedHistogramPlotSettings: + title: str | None = None + x_label: str | None = None + y_label: str | None = None + legend_title: str | None = None + title_position_x: float | None = None + title_position_y: float | None = None + font_family: str = "" + title_font_size: float = 12.0 + axis_label_font_size: float = 11.0 + tick_label_font_size: float = 9.0 + legend_font_size: float = 8.5 + annotation_font_size: float = 9.0 + max_y_ticks: int = 8 + x_tick_rotation: int = 45 + y_tick_rotation: int = 0 + show_minor_y_ticks: bool = False + show_total_annotations: bool = True + show_legend: bool = True + legend_location: str = "outside_upper_right" + label_order: list[str] = field(default_factory=list) + label_map: dict[str, str] = field(default_factory=dict) + + def resolve_title(self, defaults: StackedHistogramPlotDefaults) -> str: + return defaults.title if self.title is None else self.title + + def resolve_x_label(self, defaults: StackedHistogramPlotDefaults) -> str: + return defaults.x_label if self.x_label is None else self.x_label + + def resolve_y_label(self, defaults: StackedHistogramPlotDefaults) -> str: + return defaults.y_label if self.y_label is None else self.y_label + + def resolve_legend_title( + self, + defaults: StackedHistogramPlotDefaults, + ) -> str: + return ( + defaults.legend_title + if self.legend_title is None + else self.legend_title + ) + + def resolve_title_position_x( + self, + defaults: StackedHistogramPlotDefaults, + ) -> float: + return ( + defaults.title_position_x + if self.title_position_x is None + else self.title_position_x + ) + + def resolve_title_position_y( + self, + defaults: StackedHistogramPlotDefaults, + ) -> float: + return ( + defaults.title_position_y + if self.title_position_y is None + else self.title_position_y + ) + + def sync_labels( + self, + raw_labels: Sequence[str], + *, + default_label_entries: Sequence[tuple[str, str]] | None = None, + ) -> None: + default_entries = ( + [ 
+ (str(raw_label), format_stoich_for_axis(str(raw_label))) + for raw_label in raw_labels + ] + if default_label_entries is None + else [ + (str(raw_label), str(display_label)) + for raw_label, display_label in default_label_entries + ] + ) + default_map = { + raw_label: display_label + for raw_label, display_label in default_entries + } + existing = dict(self.label_map) + preserved_order = [ + raw_label + for raw_label in self.label_order + if raw_label in default_map + ] + remaining = [ + raw_label + for raw_label, _display_label in default_entries + if raw_label not in preserved_order + ] + self.label_order = preserved_order + remaining + self.label_map = { + raw_label: existing.get(raw_label, default_map[raw_label]) + for raw_label in self.label_order + } + + def display_label(self, raw_label: str) -> str: + return self.label_map.get(raw_label, raw_label) + + def ordered_raw_labels( + self, + defaults: StackedHistogramPlotDefaults, + ) -> list[str]: + if self.label_order: + available = set(defaults.raw_category_labels) + ordered = [raw for raw in self.label_order if raw in available] + remaining = [ + raw + for raw in defaults.raw_category_labels + if raw not in ordered + ] + return ordered + remaining + return list(defaults.raw_category_labels) + + def ordered_label_entries( + self, + defaults: StackedHistogramPlotDefaults, + ) -> list[tuple[str, str]]: + return [ + (raw_label, self.display_label(raw_label)) + for raw_label in self.ordered_raw_labels(defaults) + ] + + def to_dict(self) -> dict[str, object]: + return { + "title": self.title, + "x_label": self.x_label, + "y_label": self.y_label, + "legend_title": self.legend_title, + "title_position_x": self.title_position_x, + "title_position_y": self.title_position_y, + "font_family": self.font_family, + "title_font_size": self.title_font_size, + "axis_label_font_size": self.axis_label_font_size, + "tick_label_font_size": self.tick_label_font_size, + "legend_font_size": self.legend_font_size, + 
"annotation_font_size": self.annotation_font_size, + "max_y_ticks": self.max_y_ticks, + "x_tick_rotation": self.x_tick_rotation, + "y_tick_rotation": self.y_tick_rotation, + "show_minor_y_ticks": self.show_minor_y_ticks, + "show_total_annotations": self.show_total_annotations, + "show_legend": self.show_legend, + "legend_location": self.legend_location, + "label_order": list(self.label_order), + "label_map": dict(self.label_map), + } + + def update_from_dict(self, payload: Mapping[str, object]) -> None: + for field_name in ( + "title", + "x_label", + "y_label", + "legend_title", + "title_position_x", + "title_position_y", + ): + if field_name in payload: + setattr(self, field_name, payload[field_name]) + if "font_family" in payload: + self.font_family = str(payload["font_family"] or "") + for field_name in ( + "title_font_size", + "axis_label_font_size", + "tick_label_font_size", + "legend_font_size", + "annotation_font_size", + ): + if field_name in payload: + setattr(self, field_name, float(payload[field_name])) + for field_name in ( + "max_y_ticks", + "x_tick_rotation", + "y_tick_rotation", + ): + if field_name in payload: + setattr(self, field_name, int(payload[field_name])) + for field_name in ( + "show_minor_y_ticks", + "show_total_annotations", + "show_legend", + ): + if field_name in payload: + setattr(self, field_name, bool(payload[field_name])) + if "legend_location" in payload: + self.legend_location = str(payload["legend_location"] or "best") + if "label_order" in payload: + self.label_order = [str(value) for value in payload["label_order"]] + if "label_map" in payload: + label_map = payload["label_map"] + if isinstance(label_map, Mapping): + self.label_map = { + str(key): str(value) for key, value in label_map.items() + } + + +def render_stacked_histogram_export_payload( + export_payload: Mapping[str, object], + *, + ax, + defaults: StackedHistogramPlotDefaults, + settings: StackedHistogramPlotSettings | None = None, + cmap: str | None = None, + 
structure_segment_colors: Mapping[str, str] | None = None, + show_percent: bool = True, +): + resolved_settings = ( + StackedHistogramPlotSettings() if settings is None else settings + ) + fig = ax.figure + ax.clear() + + labels = [str(label) for label in export_payload.get("labels", ())] + axis_labels = [ + str(label) for label in export_payload.get("axis_labels", ()) + ] + segments = [str(segment) for segment in export_payload.get("segments", ())] + segment_labels = [ + str(label) for label in export_payload.get("segment_labels", ()) + ] + plot_mode = str(export_payload.get("plot_mode", "structure_fraction")) + matrix = np.asarray(export_payload.get("matrix", []), dtype=float) + color_keys = [ + [None if key is None else str(key) for key in row] + for row in export_payload.get("color_keys", []) + ] + + if not labels: + ax.set_title("No prior-weight data available") + ax.set_xlabel(resolved_settings.resolve_x_label(defaults)) + ax.set_ylabel(resolved_settings.resolve_y_label(defaults)) + return fig, ax + + cmap_name = str(cmap or defaults.default_colormap_name or "summer") + colors = plt.get_cmap(cmap_name)( + np.linspace(0.1, 0.9, max(len(segment_labels), 1), endpoint=True) + ) + + x_positions = np.arange(len(labels), dtype=float) + bottoms = np.zeros(len(labels), dtype=float) + for index, segment_label in enumerate(segment_labels): + heights_array = matrix[:, index] + bar_colors = colors[index] + if structure_segment_colors and not plot_mode.startswith( + "solvent_sort" + ): + bar_colors = [ + structure_segment_colors.get( + ( + color_keys[row_index][index] + if row_index < len(color_keys) + and index < len(color_keys[row_index]) + else f"{label}_{segments[index]}" + ), + fallback_color, + ) + for row_index, (label, fallback_color) in enumerate( + zip(labels, [colors[index]] * len(labels), strict=False) + ) + ] + ax.bar( + x_positions, + heights_array, + bottom=bottoms, + label=segment_label, + color=bar_colors, + edgecolor="white", + width=0.8, + ) + bottoms += 
heights_array + + showed_small_total_marker = False + if show_percent and resolved_settings.show_total_annotations: + for index, total in enumerate(bottoms): + if total >= 1.0: + ax.text( + x_positions[index], + total + 1.0, + f"{total:.1f}%", + ha="center", + va="bottom", + fontsize=resolved_settings.annotation_font_size, + **_font_kwargs(resolved_settings.font_family), + ) + else: + ax.scatter( + x_positions[index], + total + 1.0, + color="red", + s=16, + zorder=4, + ) + showed_small_total_marker = True + + max_total = float(np.max(bottoms)) if bottoms.size else 0.0 + ax.set_ylim(0.0, max(max_total + 4.0, 10.0)) + ax.set_xlim(-0.5, len(labels) - 0.5) + ax.set_xlabel( + resolved_settings.resolve_x_label(defaults), + fontsize=resolved_settings.axis_label_font_size, + **_font_kwargs(resolved_settings.font_family), + ) + apply_igor_inline_text_artist( + ax.xaxis.label, + resolved_settings.resolve_x_label(defaults), + default_font_size=resolved_settings.axis_label_font_size, + gid_prefix="stacked-histogram-x-label", + target_axes=ax, + ) + ax.set_ylabel( + resolved_settings.resolve_y_label(defaults), + fontsize=resolved_settings.axis_label_font_size, + **_font_kwargs(resolved_settings.font_family), + ) + apply_igor_inline_text_artist( + ax.yaxis.label, + resolved_settings.resolve_y_label(defaults), + default_font_size=resolved_settings.axis_label_font_size, + gid_prefix="stacked-histogram-y-label", + target_axes=ax, + ) + ax.set_title( + resolved_settings.resolve_title(defaults), + y=resolved_settings.resolve_title_position_y(defaults), + fontsize=resolved_settings.title_font_size, + **_font_kwargs(resolved_settings.font_family), + ) + ax.title.set_x(resolved_settings.resolve_title_position_x(defaults)) + apply_igor_inline_text_artist( + ax.title, + resolved_settings.resolve_title(defaults), + default_font_size=resolved_settings.title_font_size, + gid_prefix="stacked-histogram-title", + target_axes=ax, + ) + + rendered_x_tick_labels: list[str] = [] + 
composite_x_tick_labels: dict[int, str] = {} + for tick_index, axis_label in enumerate(axis_labels): + segments_for_label, has_markup = prepare_igor_inline_segments( + axis_label, + default_font_size=resolved_settings.tick_label_font_size, + ) + if not has_markup: + rendered_x_tick_labels.append(axis_label) + continue + if any( + not math.isclose( + segment.font_size, + resolved_settings.tick_label_font_size, + ) + for segment in segments_for_label + ): + rendered_x_tick_labels.append(" ") + composite_x_tick_labels[tick_index] = axis_label + continue + rendered_x_tick_labels.append( + igor_inline_to_mathtext( + axis_label, + default_font_size=resolved_settings.tick_label_font_size, + ) + ) + ax.set_xticks(x_positions) + ax.set_xticklabels( + rendered_x_tick_labels, + rotation=resolved_settings.x_tick_rotation, + ha="right", + ) + ax.tick_params(axis="x", labelsize=resolved_settings.tick_label_font_size) + ax.tick_params(axis="y", labelsize=resolved_settings.tick_label_font_size) + ax.yaxis.set_major_locator( + MaxNLocator(nbins=max(resolved_settings.max_y_ticks, 2)) + ) + if resolved_settings.show_minor_y_ticks: + ax.yaxis.set_minor_locator(AutoMinorLocator(2)) + else: + ax.minorticks_off() + + for tick_label in ax.get_xticklabels(): + _apply_font_to_text( + tick_label, + font_family=resolved_settings.font_family, + rotation=resolved_settings.x_tick_rotation, + ) + for tick_label in ax.get_yticklabels(): + _apply_font_to_text( + tick_label, + font_family=resolved_settings.font_family, + rotation=resolved_settings.y_tick_rotation, + ) + for tick_index, tick_label in enumerate(ax.get_xticklabels()): + if tick_index not in composite_x_tick_labels: + continue + apply_igor_inline_text_artist( + tick_label, + composite_x_tick_labels[tick_index], + default_font_size=resolved_settings.tick_label_font_size, + gid_prefix=f"stacked-histogram-x-tick-{tick_index}", + target_axes=ax, + ) + + legend_handles, legend_labels = ax.get_legend_handles_labels() + if ( + show_percent + 
and resolved_settings.show_total_annotations + and showed_small_total_marker + ): + legend_handles.append( + Line2D( + [], + [], + marker="o", + color="red", + linestyle="None", + markersize=5, + ) + ) + legend_labels.append("< 1% total") + + if resolved_settings.show_legend and legend_handles: + legend = ax.legend( + legend_handles, + legend_labels, + title=resolved_settings.resolve_legend_title(defaults), + fontsize=resolved_settings.legend_font_size, + **_legend_kwargs(resolved_settings.legend_location), + ) + for legend_text in legend.get_texts(): + _apply_font_to_text( + legend_text, + font_family=resolved_settings.font_family, + ) + _apply_font_to_text( + legend.get_title(), + font_family=resolved_settings.font_family, + ) + apply_igor_inline_text_artist( + legend.get_title(), + resolved_settings.resolve_legend_title(defaults), + default_font_size=resolved_settings.legend_font_size, + gid_prefix="stacked-histogram-legend-title", + target_axes=ax, + ) + + fig.tight_layout() + return fig, ax + + +def default_histogram_label_entries( + raw_labels: Sequence[str], +) -> tuple[tuple[str, str], ...]: + return tuple( + (str(raw_label), format_stoich_for_axis(str(raw_label))) + for raw_label in raw_labels + ) + + +def sorted_histogram_label_entries( + raw_labels: Sequence[str], +) -> list[tuple[str, str]]: + return [ + (raw_label, format_stoich_for_axis(raw_label)) + for raw_label in sort_stoich_labels(raw_labels) + ] + + +def _font_kwargs(font_family: str) -> dict[str, str]: + return {} if not font_family else {"fontfamily": font_family} + + +def _apply_font_to_text( + text_artist, + *, + font_family: str, + rotation: float | None = None, +) -> None: + if font_family: + text_artist.set_fontfamily(font_family) + if rotation is not None: + text_artist.set_rotation(rotation) + + +def _legend_kwargs(location: str) -> dict[str, object]: + if location == "outside_upper_right": + return {"bbox_to_anchor": (1.02, 1.0), "loc": "upper left"} + return {"loc": 
location.replace("_", " ")} + + +__all__ = [ + "STACKED_HISTOGRAM_LEGEND_LOCATIONS", + "StackedHistogramPlotDefaults", + "StackedHistogramPlotSettings", + "default_histogram_label_entries", + "render_stacked_histogram_export_payload", + "sorted_histogram_label_entries", +] diff --git a/src/saxshell/saxs/ui/dream_tab.py b/src/saxshell/saxs/ui/dream_tab.py index 8b33398..8be74e8 100644 --- a/src/saxshell/saxs/ui/dream_tab.py +++ b/src/saxshell/saxs/ui/dream_tab.py @@ -39,6 +39,7 @@ QWidget, ) +from saxshell.plotting import Q_A_INVERSE_LABEL from saxshell.saxs.dream import ( DreamModelPlotData, DreamRunSettings, @@ -2062,7 +2063,7 @@ def plot_model_fit(self, plot_data: DreamModelPlotData | None) -> None: bottom_axis.set_xscale( "log" if self.model_log_x_checkbox.isChecked() else "linear" ) - bottom_axis.set_xlabel("q (Å⁻¹)") + bottom_axis.set_xlabel(Q_A_INVERSE_LABEL) bottom_axis.set_ylabel("Residual") if plotted_lines: self._build_interactive_model_legend(top_axis, plotted_lines) diff --git a/tests/test_clusterdynamics.py b/tests/test_clusterdynamics.py index d63a9f9..31f4467 100644 --- a/tests/test_clusterdynamics.py +++ b/tests/test_clusterdynamics.py @@ -17,10 +17,17 @@ save_cluster_dynamics_dataset, ) from saxshell.clusterdynamics.ui.main_window import ClusterDynamicsMainWindow +from saxshell.clusterdynamics.ui.plot_panel import ClusterDynamicsPlotPanel +from saxshell.plotting import ( + igor_inline_to_mathtext, + load_pickled_plot_figure, + prepare_igor_inline_segments, +) from saxshell.saxs.project_manager import ( SAXSProjectManager, build_project_paths, ) +from saxshell.saxs.stoichiometry import format_stoich_for_axis ATOM_TYPE_DEFINITIONS = { "node": [("Pb", None)], @@ -501,6 +508,369 @@ def test_cluster_dynamics_main_window_exports_colormap_and_lifetime_csv( window.close() +def test_cluster_dynamics_plot_panel_applies_heatmap_editor_settings( + qapp, + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + result = ClusterDynamicsWorkflow( + frames_dir, 
+ atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + shell_levels=(1,), + frame_timestep_fs=10.0, + frames_per_colormap_timestep=2, + ).analyze() + + panel = ClusterDynamicsPlotPanel(enable_plot_editor=True) + panel.set_result(result) + panel.show() + qapp.processEvents() + + assert panel.plot_editor_button is not None + assert len(result.cluster_labels) > 1 + + axis = panel.figure.axes[0] + default_tick_texts = [ + label.get_text() + for label in axis.get_yticklabels() + if label.get_text() + ] + expected_default_texts = [ + format_stoich_for_axis(label) for label in result.cluster_labels + ] + assert default_tick_texts == expected_default_texts + + panel.plot_editor_button.click() + qapp.processEvents() + + assert panel._plot_editor_window is not None + assert panel._plot_editor_controls is not None + + controls = panel._plot_editor_controls + assert controls.x_axis_unit_combo.currentData() == "fs" + auto_vmin = float(axis.images[0].norm.vmin) + auto_vmax = float(axis.images[0].norm.vmax) + auto_range = auto_vmax - auto_vmin + if auto_range > 0.4: + manual_vmin = auto_vmin + (0.2 * auto_range) + manual_vmax = auto_vmax - (0.2 * auto_range) + else: + manual_vmin = auto_vmin + manual_vmax = auto_vmax + (0.5 * max(auto_range, 1.0)) + + assert controls.color_limit_min_spin.value() == pytest.approx(auto_vmin) + assert controls.color_limit_max_spin.value() == pytest.approx(auto_vmax) + assert not controls.reset_color_limits_button.isEnabled() + + controls.x_axis_unit_combo.setCurrentIndex( + controls.x_axis_unit_combo.findData("ps") + ) + qapp.processEvents() + + axis = panel.figure.axes[0] + assert controls.x_axis_unit_combo.currentData() == "ps" + assert panel.time_unit_combo.currentData() == "ps" + assert axis.get_xlabel() == "Time (ps)" + + controls.label_table.setCurrentCell(1, 1) + controls._move_label_up() + qapp.processEvents() + + controls.title_edit.setText("Edited Heatmap") + controls.x_label_edit.setText("Time $^{2}$") + 
controls.y_label_edit.setText("Edited Clusters") + controls.colorbar_label_edit.setText("Edited Scale") + controls.title_position_x_spin.setValue(0.28) + controls.title_position_y_spin.setValue(1.11) + controls.colormap_combo.setCurrentIndex( + controls.colormap_combo.findData("magma") + ) + controls.aspect_combo.setCurrentIndex( + controls.aspect_combo.findData("custom") + ) + controls.aspect_value_spin.setValue(1.5) + controls.max_x_ticks_spin.setValue(3) + controls.max_y_ticks_spin.setValue(3) + controls.x_tick_rotation_spin.setValue(35) + controls.y_tick_rotation_spin.setValue(15) + controls.tick_label_font_spin.setValue(13.0) + controls.cluster_label_font_spin.setValue(15.0) + controls.color_limit_min_spin.setValue(manual_vmin) + controls.color_limit_max_spin.setValue(manual_vmax) + controls.label_table.item(0, 1).setText("Pb$_{2}$I$^{+}$") + qapp.processEvents() + panel.canvas.draw() + + axis = panel.figure.axes[0] + colorbar_axis = panel.figure.axes[1] + tick_texts = [ + label.get_text() + for label in axis.get_yticklabels() + if label.get_text() + ] + + assert axis.get_title() == "Edited Heatmap" + assert axis.title.get_position()[0] == pytest.approx(0.28) + assert axis.title.get_position()[1] == pytest.approx(1.11) + assert axis.get_xlabel() == "Time $^{2}$" + assert axis.get_ylabel() == "Edited Clusters" + assert float(axis.get_aspect()) == pytest.approx(1.5) + assert colorbar_axis.get_ylabel() == "Edited Scale" + assert panel.colormap_combo.currentData() == "magma" + assert axis.images[0].get_cmap().name == "magma" + assert axis.images[0].norm.vmin == pytest.approx(manual_vmin) + assert axis.images[0].norm.vmax == pytest.approx(manual_vmax) + assert controls.reset_color_limits_button.isEnabled() + assert tick_texts[0] == "Pb$_{2}$I$^{+}$" + assert tick_texts[1] == expected_default_texts[0] + assert axis.get_xticklabels()[0].get_rotation() == pytest.approx(35.0) + assert axis.get_yticklabels()[0].get_rotation() == pytest.approx(15.0) + assert 
axis.get_xticklabels()[0].get_fontsize() == pytest.approx(13.0) + assert axis.get_yticklabels()[0].get_fontsize() == pytest.approx(15.0) + + controls.reset_color_limits_button.click() + qapp.processEvents() + panel.canvas.draw() + + axis = panel.figure.axes[0] + assert controls.color_limit_min_spin.value() == pytest.approx(auto_vmin) + assert controls.color_limit_max_spin.value() == pytest.approx(auto_vmax) + assert axis.images[0].norm.vmin == pytest.approx(auto_vmin) + assert axis.images[0].norm.vmax == pytest.approx(auto_vmax) + assert not controls.reset_color_limits_button.isEnabled() + + panel._plot_editor_window.close() + panel.close() + + +def test_igor_inline_text_supports_bold_italic_and_size_segments(): + segments, has_markup = prepare_igor_inline_segments( + r"\f01\f02Hello\f00 World \Z<16>Big", + default_font_size=11.0, + ) + + assert has_markup + assert segments[0].text == "Hello" + assert segments[0].bold is True + assert segments[0].italic is True + assert segments[0].font_size == pytest.approx(11.0) + assert segments[1].text == " World " + assert segments[1].bold is False + assert segments[1].italic is False + assert segments[1].font_size == pytest.approx(11.0) + assert segments[2].text == "Big" + assert segments[2].font_size == pytest.approx(16.0) + + assert ( + igor_inline_to_mathtext( + r"\f01 Hello \f00World", + default_font_size=11.0, + ) + == r"$\mathbf{\ Hello\ }\mathregular{World}$" + ) + + +def test_cluster_dynamics_plot_editor_does_not_resync_defaults_while_editing_labels( + qapp, + tmp_path, + monkeypatch, +): + frames_dir = _build_frames_dir(tmp_path) + result = ClusterDynamicsWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + shell_levels=(1,), + frame_timestep_fs=10.0, + frames_per_colormap_timestep=2, + ).analyze() + + panel = ClusterDynamicsPlotPanel(enable_plot_editor=True) + panel.set_result(result) + panel.show() + qapp.processEvents() + panel.plot_editor_button.click() 
+ qapp.processEvents() + + assert panel._plot_editor_controls is not None + controls = panel._plot_editor_controls + + sync_calls: list[object] = [] + original_sync_defaults = controls.sync_defaults + + def _tracking_sync_defaults(defaults): + sync_calls.append(defaults) + return original_sync_defaults(defaults) + + monkeypatch.setattr(controls, "sync_defaults", _tracking_sync_defaults) + + controls.label_table.item(0, 1).setText("I$^{+}$") + qapp.processEvents() + + assert sync_calls == [] + + panel._plot_editor_window.close() + panel.close() + + +def test_cluster_dynamics_plot_editor_can_save_and_load_pickled_plots( + qapp, + tmp_path, + monkeypatch, +): + frames_dir = _build_frames_dir(tmp_path) + result = ClusterDynamicsWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + shell_levels=(1,), + frame_timestep_fs=10.0, + frames_per_colormap_timestep=2, + ).analyze() + + panel = ClusterDynamicsPlotPanel(enable_plot_editor=True) + panel.set_result(result) + panel.show() + qapp.processEvents() + panel.plot_editor_button.click() + qapp.processEvents() + + assert panel._plot_editor_window is not None + assert panel._plot_editor_controls is not None + + editor = panel._plot_editor_window + controls = panel._plot_editor_controls + pickle_path = tmp_path / "saved_colormap.pkl" + + monkeypatch.setattr( + "saxshell.plotting.plot_editor.QFileDialog.getSaveFileName", + lambda *args, **kwargs: ( + str(pickle_path), + "Pickled Plot Files (*.pkl)", + ), + ) + monkeypatch.setattr( + "saxshell.plotting.plot_editor.QFileDialog.getOpenFileName", + lambda *args, **kwargs: ( + str(pickle_path), + "Pickled Plot Files (*.pkl)", + ), + ) + + controls.title_edit.setText("Pickled Heatmap") + controls.title_position_x_spin.setValue(0.23) + controls.title_position_y_spin.setValue(1.14) + controls.colormap_combo.setCurrentIndex( + controls.colormap_combo.findData("magma") + ) + controls.x_axis_unit_combo.setCurrentIndex( + 
controls.x_axis_unit_combo.findData("ps") + ) + qapp.processEvents() + + assert editor.save_pickled_plot_as() == pickle_path + assert pickle_path.is_file() + + pickled_figure = load_pickled_plot_figure(pickle_path) + assert pickled_figure.axes[0].get_title() == "Pickled Heatmap" + assert pickled_figure.axes[0].images[0].get_cmap().name == "magma" + assert pickled_figure.axes[0].title.get_position()[0] == pytest.approx( + 0.23 + ) + assert pickled_figure.axes[0].title.get_position()[1] == pytest.approx( + 1.14 + ) + + controls.title_edit.setText("Live Heatmap") + controls.colormap_combo.setCurrentIndex( + controls.colormap_combo.findData("viridis") + ) + controls.x_axis_unit_combo.setCurrentIndex( + controls.x_axis_unit_combo.findData("fs") + ) + qapp.processEvents() + assert panel.figure.axes[0].get_title() == "Live Heatmap" + assert editor.figure.axes[0].get_title() == "Live Heatmap" + assert panel.colormap_combo.currentData() == "viridis" + assert panel.time_unit_combo.currentData() == "fs" + + assert editor.load_pickled_plot_as() == pickle_path + qapp.processEvents() + assert not editor.is_showing_pickled_plot() + assert editor.figure.axes[0].get_title() == "Pickled Heatmap" + assert panel.figure.axes[0].get_title() == "Pickled Heatmap" + assert controls.title_edit.text() == "Pickled Heatmap" + assert controls.title_position_x_spin.value() == pytest.approx(0.23) + assert controls.title_position_y_spin.value() == pytest.approx(1.14) + assert panel.figure.axes[0].title.get_position()[0] == pytest.approx(0.23) + assert panel.figure.axes[0].title.get_position()[1] == pytest.approx(1.14) + assert controls.colormap_combo.currentData() == "magma" + assert controls.x_axis_unit_combo.currentData() == "ps" + assert panel.colormap_combo.currentData() == "magma" + assert panel.time_unit_combo.currentData() == "ps" + assert not editor.show_live_preview_button.isEnabled() + + controls.title_edit.setText("Editable After Load") + qapp.processEvents() + assert 
panel.figure.axes[0].get_title() == "Editable After Load" + assert editor.figure.axes[0].get_title() == "Editable After Load" + + panel._plot_editor_window.close() + panel.close() + + +def test_cluster_dynamics_plot_editor_supports_igor_inline_label_markup( + qapp, + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + result = ClusterDynamicsWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + shell_levels=(1,), + frame_timestep_fs=10.0, + frames_per_colormap_timestep=2, + ).analyze() + + panel = ClusterDynamicsPlotPanel(enable_plot_editor=True) + panel.set_result(result) + panel.show() + qapp.processEvents() + panel.plot_editor_button.click() + qapp.processEvents() + + assert panel._plot_editor_controls is not None + controls = panel._plot_editor_controls + + controls.title_edit.setText(r"\f01\f02Rich\f00 Title") + controls.x_label_edit.setText(r"Time \Z<16>\f01fs") + controls.label_table.item(0, 1).setText(r"\f01Pb$_{2}$\f00I$^{+}$") + qapp.processEvents() + panel.canvas.draw() + + axis = panel.figure.axes[0] + assert axis.get_title() == r"$\mathbf{\mathit{Rich}}\mathregular{\ Title}$" + assert ( + axis.get_yticklabels()[0].get_text() + == r"$\mathbf{Pb}_{2}\mathregular{I}^{+}$" + ) + + x_label_segments = [ + text + for text in axis.texts + if str(text.get_gid()).startswith("heatmap-x-label-") + ] + assert len(x_label_segments) == 2 + assert x_label_segments[0].get_text() == r"$\mathregular{Time\ }$" + assert x_label_segments[1].get_text() == r"$\mathbf{fs}$" + assert x_label_segments[1].get_fontsize() == pytest.approx(16.0) + + panel._plot_editor_window.close() + panel.close() + + def test_cluster_dynamics_main_window_registers_frames_dir_with_project( qapp, tmp_path, diff --git a/tests/test_clusterdynamicsml.py b/tests/test_clusterdynamicsml.py index ddf2cdf..1c1b067 100644 --- a/tests/test_clusterdynamicsml.py +++ b/tests/test_clusterdynamicsml.py @@ -47,6 +47,7 @@ 
build_prior_histogram_export_payload, list_secondary_filter_elements, ) +from saxshell.saxs.stoichiometry import format_stoich_for_axis from saxshell.structure import PDBAtom, PDBStructure from saxshell.xyz2pdb import resolve_reference_path from saxshell.xyz2pdb.workflow import rotation_matrix_from_to @@ -2649,6 +2650,146 @@ def test_clusterdynamicsml_window_exports_colormap_and_lifetime_csv( window.close() +def test_clusterdynamicsml_window_opens_heatmap_plot_editor_popup( + qapp, + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + window = ClusterDynamicsMLMainWindow(initial_frames_dir=frames_dir) + window.show() + window.dynamics_plot_panel.set_result(result.dynamics_result) + window.dynamics_plot_panel.display_mode_combo.setCurrentIndex( + window.dynamics_plot_panel.display_mode_combo.findData("count") + ) + window.dynamics_plot_panel.time_unit_combo.setCurrentIndex( + window.dynamics_plot_panel.time_unit_combo.findData("ps") + ) + qapp.processEvents() + + assert window.dynamics_plot_panel.plot_editor_button is not None + + window.dynamics_plot_panel.plot_editor_button.click() + qapp.processEvents() + + assert window.dynamics_plot_panel._plot_editor_window is not None + assert window.dynamics_plot_panel._plot_editor_window.isVisible() + assert ( + window.dynamics_plot_panel._plot_editor_window.windowTitle() + == "Cluster Dynamics Colormap Editor" + ) + assert ( + window.dynamics_plot_panel._plot_editor_window.save_pickle_button.text() + == "Save Pickled Plot" + ) + assert ( + 
window.dynamics_plot_panel._plot_editor_window.load_pickle_button.text() + == "Load Pickled Plot" + ) + assert window.dynamics_plot_panel._plot_editor_controls is not None + assert ( + window.dynamics_plot_panel._plot_editor_controls.x_axis_unit_combo.currentData() + == "ps" + ) + assert ( + window.dynamics_plot_panel._plot_editor_controls.title_edit.text() + == "Time-Binned Cluster Distribution (Counts / bin)" + ) + assert ( + window.dynamics_plot_panel._plot_editor_controls.x_label_edit.text() + == "Time (ps)" + ) + assert ( + window.dynamics_plot_panel._plot_editor_controls.y_label_edit.text() + == "Cluster label" + ) + window.dynamics_plot_panel._plot_editor_controls.x_axis_unit_combo.setCurrentIndex( + window.dynamics_plot_panel._plot_editor_controls.x_axis_unit_combo.findData( + "fs" + ) + ) + qapp.processEvents() + assert window.dynamics_plot_panel.time_unit_combo.currentData() == "fs" + assert ( + window.dynamics_plot_panel._plot_editor_controls.x_label_edit.text() + == "Time (fs)" + ) + first_raw_label = ( + window.dynamics_plot_panel._plot_editor_controls.label_table.item( + 0, 0 + ).text() + ) + assert window.dynamics_plot_panel._plot_editor_controls.label_table.item( + 0, 1 + ).text() == format_stoich_for_axis(first_raw_label) + + window.dynamics_plot_panel._plot_editor_window.close() + window.close() + + +def test_clusterdynamicsml_window_keeps_live_heatmap_readable_after_plot_edits( + qapp, + tmp_path, +): + frames_dir = _build_frames_dir(tmp_path) + clusters_dir = _build_clusters_dir(tmp_path) + experimental_data_file = _write_experimental_data_file(tmp_path) + + result = ClusterDynamicsMLWorkflow( + frames_dir, + atom_type_definitions=ATOM_TYPE_DEFINITIONS, + pair_cutoff_definitions=PAIR_CUTOFFS, + clusters_dir=clusters_dir, + experimental_data_file=experimental_data_file, + frame_timestep_fs=10.0, + frames_per_colormap_timestep=1, + target_node_counts=(4, 5), + ).analyze() + + window = ClusterDynamicsMLMainWindow(initial_frames_dir=frames_dir) + 
window.show() + window.dynamics_plot_panel.set_result(result.dynamics_result) + qapp.processEvents() + + assert window.dynamics_plot_panel.plot_editor_button is not None + window.dynamics_plot_panel.plot_editor_button.click() + qapp.processEvents() + + assert window.dynamics_plot_panel._plot_editor_window is not None + assert window.dynamics_plot_panel._plot_editor_controls is not None + + controls = window.dynamics_plot_panel._plot_editor_controls + controls.title_edit.setText("Edited Heatmap") + controls.title_font_spin.setValue(16.0) + controls.axis_label_font_spin.setValue(14.0) + controls.tick_label_font_spin.setValue(12.0) + controls.cluster_label_font_spin.setValue(12.0) + controls.title_position_y_spin.setValue(1.12) + qapp.processEvents() + window.dynamics_plot_panel.canvas.draw() + + axis = window.dynamics_plot_panel.figure.axes[0] + assert window.dynamics_plot_panel.canvas.height() >= 250 + assert axis.get_title() == "Edited Heatmap" + assert axis.get_position().height > 0.3 + + window.dynamics_plot_panel._plot_editor_window.close() + window.close() + + def test_clusterdynamicsml_window_shows_observed_lifetime_tab( qapp, tmp_path, From 2273bcb43178b74286668f0af144a423b8db078e Mon Sep 17 00:00:00 2001 From: kewh5868 Date: Thu, 7 May 2026 16:45:33 -0600 Subject: [PATCH 2/7] feat(fullrmc): expand Packmol solvent setup workflow Add Packmol Docker linking, solvent-shell analysis/building, free-solvent allocation, structure-mode aware Packmol setup, and constraints preview/open helpers. Update the fullrmc UI and tests around the integrated RMC setup flow. 
--- src/saxshell/fullrmc/__init__.py | 66 + src/saxshell/fullrmc/packmol_docker.py | 626 ++++ src/saxshell/fullrmc/packmol_planning.py | 536 ++- src/saxshell/fullrmc/packmol_setup.py | 247 +- src/saxshell/fullrmc/project_loader.py | 52 +- src/saxshell/fullrmc/project_model.py | 54 +- src/saxshell/fullrmc/representatives.py | 320 +- src/saxshell/fullrmc/solvent_handling.py | 1038 +++++- src/saxshell/fullrmc/solvent_shell_builder.py | 2931 +++++++++++++++++ src/saxshell/fullrmc/ui/__init__.py | 10 + .../fullrmc/ui/constraints_preview_window.py | 72 + src/saxshell/fullrmc/ui/main_window.py | 2356 +++++++++++-- .../fullrmc/ui/packmol_docker_dialog.py | 487 +++ .../ui/solvent_shell_builder_window.py | 1316 ++++++++ tests/test_fullrmc_cli.py | 2708 +++++++++++++-- 15 files changed, 12265 insertions(+), 554 deletions(-) create mode 100644 src/saxshell/fullrmc/packmol_docker.py create mode 100644 src/saxshell/fullrmc/solvent_shell_builder.py create mode 100644 src/saxshell/fullrmc/ui/constraints_preview_window.py create mode 100644 src/saxshell/fullrmc/ui/packmol_docker_dialog.py create mode 100644 src/saxshell/fullrmc/ui/solvent_shell_builder_window.py diff --git a/src/saxshell/fullrmc/__init__.py b/src/saxshell/fullrmc/__init__.py index 1610bc1..75e821b 100644 --- a/src/saxshell/fullrmc/__init__.py +++ b/src/saxshell/fullrmc/__init__.py @@ -10,6 +10,19 @@ load_constraint_generation_metadata, save_constraint_generation_metadata, ) +from .packmol_docker import ( + DEFAULT_PACKMOL_CONTAINER_ROOT, + PackmolDockerClient, + PackmolDockerContainerRecord, + PackmolDockerDirectoryEntry, + PackmolDockerLink, + PackmolDockerSyncResult, + PackmolDockerValidationResult, + container_project_root_is_valid, + load_packmol_docker_link_metadata, + normalize_container_directory, + save_packmol_docker_link_metadata, +) from .packmol_planning import ( PackmolPlanningEntry, PackmolPlanningMetadata, @@ -86,16 +99,34 @@ load_solvent_handling_metadata, save_solvent_handling_metadata, ) +from 
.solvent_shell_builder import ( + DEFAULT_REFERENCE_MATCH_TOLERANCE_A, + SolventShellAnalysisResult, + SolventShellBuildResult, + SolventShellResidueMismatchSummary, + SolventShellResidueSummary, + analyze_solvent_shell, + build_solvent_shell_output, + default_director_atom_name, + reference_atom_choices, +) if TYPE_CHECKING: from .ui.main_window import RMCSetupMainWindow, launch_rmcsetup_ui from .ui.representative_preview_window import RepresentativePreviewWindow + from .ui.solvent_shell_builder_window import ( + SolventShellBuilderMainWindow, + launch_solvent_shell_builder_ui, + ) __all__ = [ "ClusterSourceValidationResult", "ConstraintGenerationEntry", "ConstraintGenerationMetadata", "ConstraintGenerationSettings", + "DEFAULT_REFERENCE_MATCH_TOLERANCE_A", + "DEFAULT_PACKMOL_CONTAINER_ROOT", + "container_project_root_is_valid", "RMCDreamProjectSource", "RMCDreamRunRecord", "RMCSetupPaths", @@ -103,6 +134,12 @@ "collect_cluster_count_rows", "DistributionSelectionEntry", "DistributionSelectionMetadata", + "PackmolDockerClient", + "PackmolDockerContainerRecord", + "PackmolDockerDirectoryEntry", + "PackmolDockerLink", + "PackmolDockerSyncResult", + "PackmolDockerValidationResult", "PackmolPlanningEntry", "PackmolPlanningMetadata", "PackmolPlanningSettings", @@ -115,6 +152,10 @@ "RepresentativePreviewWindow", "RepresentativeSelectionEntry", "RepresentativeSelectionIssue", + "SolventShellBuildResult", + "build_solvent_shell_output", + "default_director_atom_name", + "reference_atom_choices", "RepresentativeSelectionMetadata", "RepresentativeSelectionSettings", "SolutionProperties", @@ -122,9 +163,14 @@ "SolutionPropertiesPreset", "SolutionPropertiesResult", "SolutionPropertiesSettings", + "SolventShellAnalysisResult", + "SolventShellBuilderMainWindow", + "SolventShellResidueMismatchSummary", + "SolventShellResidueSummary", "SolventHandlingEntry", "SolventHandlingMetadata", "SolventHandlingSettings", + "analyze_solvent_shell", "build_constraint_generation", 
"build_distribution_selection", "build_packmol_setup", @@ -135,8 +181,10 @@ "discover_valid_dream_runs", "ensure_rmcsetup_structure", "expected_cluster_inventory_rows", + "launch_solvent_shell_builder_ui", "launch_rmcsetup_ui", "load_constraint_generation_metadata", + "load_packmol_docker_link_metadata", "load_distribution_selection_metadata", "load_packmol_setup_metadata", "load_packmol_planning_metadata", @@ -147,9 +195,11 @@ "load_solvent_handling_metadata", "list_solvent_reference_presets", "ordered_solution_property_preset_names", + "normalize_container_directory", "parse_angle_triplet_text", "parse_bond_pair_text", "save_constraint_generation_metadata", + "save_packmol_docker_link_metadata", "save_distribution_selection_metadata", "save_packmol_setup_metadata", "save_packmol_planning_metadata", @@ -173,6 +223,22 @@ def __getattr__(name: str): "launch_rmcsetup_ui": launch_rmcsetup_ui, } return exports[name] + if name in { + "SolventShellBuilderMainWindow", + "launch_solvent_shell_builder_ui", + }: + from .ui.solvent_shell_builder_window import ( + SolventShellBuilderMainWindow, + launch_solvent_shell_builder_ui, + ) + + exports = { + "SolventShellBuilderMainWindow": SolventShellBuilderMainWindow, + "launch_solvent_shell_builder_ui": ( + launch_solvent_shell_builder_ui + ), + } + return exports[name] if name == "RepresentativePreviewWindow": from .ui.representative_preview_window import ( RepresentativePreviewWindow, diff --git a/src/saxshell/fullrmc/packmol_docker.py b/src/saxshell/fullrmc/packmol_docker.py new file mode 100644 index 0000000..2a23e54 --- /dev/null +++ b/src/saxshell/fullrmc/packmol_docker.py @@ -0,0 +1,626 @@ +from __future__ import annotations + +import json +import shlex +import subprocess +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path, PurePosixPath + +if False: # pragma: no cover + from .packmol_setup import PackmolSetupMetadata + +DEFAULT_PACKMOL_CONTAINER_ROOT = "/packmol_input_files" 
_DOCKER_DAEMON_UNAVAILABLE_HINT = (
    "Docker Desktop or the Docker daemon does not appear to be running. "
    "Start Docker Desktop (or another Docker runtime such as OrbStack or "
    "Colima), wait for `docker info` to succeed, and retry."
)
# Lower-case fragments of Docker CLI error output that indicate the daemon
# itself is unreachable (as opposed to a problem with one container).
_DOCKER_DAEMON_UNAVAILABLE_PATTERNS = (
    "cannot connect to the docker daemon",
    "is the docker daemon running",
    "error during connect",
)


@dataclass(slots=True)
class PackmolDockerLink:
    """Saved description of a Docker container that provides Packmol.

    The first five fields are the user-editable preset; the remaining
    optional fields record what was observed the last time the link was
    verified or synced.
    """

    display_name: str
    container_name: str
    container_project_root: str = DEFAULT_PACKMOL_CONTAINER_ROOT
    packmol_command: str = "packmol"
    shell_command: str = "sh"
    packmol_version: str | None = None
    linked_at: str | None = None
    last_verified_at: str | None = None
    container_id: str | None = None
    image_name: str | None = None
    packmol_command_path: str | None = None
    last_sync_at: str | None = None
    last_sync_status: str | None = None
    last_sync_message: str | None = None

    def to_dict(self) -> dict[str, object]:
        """Return every field as a dict, with the preset fields normalized."""
        payload = asdict(self)
        payload["display_name"] = self.resolved_display_name
        payload["container_name"] = self.container_name.strip()
        payload["container_project_root"] = normalize_container_directory(
            self.container_project_root
        )
        payload["packmol_command"] = self.packmol_command.strip() or "packmol"
        payload["shell_command"] = self.shell_command.strip() or "sh"
        return payload

    def to_preset_dict(self) -> dict[str, object]:
        """Return only the normalized preset fields (no verification state)."""
        return {
            "display_name": self.resolved_display_name,
            "container_name": self.container_name.strip(),
            "container_project_root": normalize_container_directory(
                self.container_project_root
            ),
            "packmol_command": self.packmol_command.strip() or "packmol",
            "shell_command": self.shell_command.strip() or "sh",
        }

    @classmethod
    def from_dict(
        cls,
        payload: dict[str, object] | None,
    ) -> "PackmolDockerLink | None":
        """Build a link from a decoded payload.

        Returns ``None`` when the payload is empty or names no container.
        Values that are missing, ``None`` (JSON ``null``) or blank all fall
        back to the field default; previously a ``null`` was stringified
        into the literal text ``"None"``.
        """
        if not payload:
            return None
        container_name = _optional_text(payload.get("container_name"))
        if container_name is None:
            return None
        return cls(
            display_name=(
                _optional_text(payload.get("display_name")) or container_name
            ),
            container_name=container_name,
            container_project_root=normalize_container_directory(
                payload.get("container_project_root")
            ),
            packmol_command=(
                _optional_text(payload.get("packmol_command")) or "packmol"
            ),
            shell_command=_optional_text(payload.get("shell_command")) or "sh",
            packmol_version=_optional_text(payload.get("packmol_version")),
            linked_at=_optional_text(payload.get("linked_at")),
            last_verified_at=_optional_text(payload.get("last_verified_at")),
            container_id=_optional_text(payload.get("container_id")),
            image_name=_optional_text(payload.get("image_name")),
            packmol_command_path=_optional_text(
                payload.get("packmol_command_path")
            ),
            last_sync_at=_optional_text(payload.get("last_sync_at")),
            last_sync_status=_optional_text(payload.get("last_sync_status")),
            last_sync_message=_optional_text(payload.get("last_sync_message")),
        )

    @property
    def resolved_display_name(self) -> str:
        """Stripped display name, falling back to the container name."""
        return self.display_name.strip() or self.container_name.strip()

    def remote_rmcsetup_dir(self) -> PurePosixPath:
        """Return ``<container project root>/rmcsetup``."""
        return (
            PurePosixPath(
                normalize_container_directory(self.container_project_root)
            )
            / "rmcsetup"
        )

    def remote_packmol_inputs_dir(self) -> PurePosixPath:
        """Return the ``packmol_inputs`` folder under the rmcsetup dir."""
        return self.remote_rmcsetup_dir() / "packmol_inputs"

    def remote_packmol_input_path(
        self,
        packmol_setup_metadata: "PackmolSetupMetadata | None" = None,
    ) -> str:
        """Return the in-container path of the Packmol input file.

        Only the basename of ``packmol_input_path`` is used; the default
        name applies when no metadata (or no usable name) is available.
        """
        input_name = "packmol_combined.inp"
        if (
            packmol_setup_metadata is not None
            and packmol_setup_metadata.packmol_input_path
        ):
            input_name = (
                Path(packmol_setup_metadata.packmol_input_path).name
                or input_name
            )
        return str(self.remote_packmol_inputs_dir() / input_name)

    def remote_packmol_output_path(
        self,
        packmol_setup_metadata: "PackmolSetupMetadata | None" = None,
    ) -> str:
        """Return the in-container path of the packed output file.

        Only the basename of ``packed_output_filename`` is used, matching
        the input path: joining an absolute or nested value onto a
        ``PurePosixPath`` would otherwise leave the inputs folder.
        """
        output_name = "packed_combined.pdb"
        if (
            packmol_setup_metadata is not None
            and packmol_setup_metadata.packed_output_filename
        ):
            output_name = (
                Path(packmol_setup_metadata.packed_output_filename).name
                or output_name
            )
        return str(self.remote_packmol_inputs_dir() / output_name)

    def summary_text(
        self,
        *,
        packmol_setup_metadata: "PackmolSetupMetadata | None" = None,
    ) -> str:
        """Return a multi-line, human-readable description of the link."""
        lines = [
            f"Preset: {self.resolved_display_name}",
            f"Container: {self.container_name}",
            (
                "Container project root: "
                f"{normalize_container_directory(self.container_project_root)}"
            ),
            f"Packmol command: {self.packmol_command}",
            (
                "Resolved Packmol binary: "
                f"{self.packmol_command_path or '(not verified yet)'}"
            ),
            f"Packmol version: {self.packmol_version or '(not verified yet)'}",
            f"Shell command: {self.shell_command}",
            f"Image: {self.image_name or '(not verified yet)'}",
            f"Container ID: {self.container_id or '(not verified yet)'}",
            f"Linked at: {self.linked_at or '(not linked yet)'}",
            f"Last verified: {self.last_verified_at or '(not verified yet)'}",
            (
                "Remote Packmol inputs folder: "
                f"{self.remote_packmol_inputs_dir()}"
            ),
        ]
        if packmol_setup_metadata is not None:
            lines.extend(
                [
                    (
                        "Remote Packmol input file: "
                        f"{self.remote_packmol_input_path(packmol_setup_metadata)}"
                    ),
                    (
                        "Remote packed output path: "
                        f"{self.remote_packmol_output_path(packmol_setup_metadata)}"
                    ),
                ]
            )
        if self.last_sync_status:
            lines.append(
                f"Last sync status: {self.last_sync_status} at "
                f"{self.last_sync_at or '(unknown time)'}"
            )
        if self.last_sync_message:
            lines.append(f"Last sync details: {self.last_sync_message}")
        return "\n".join(lines)
details: + return self.name + return f"{self.name} ({' | '.join(details)})" + + +@dataclass(slots=True, frozen=True) +class PackmolDockerValidationResult: + verified_at: str + container_id: str + image_name: str + packmol_command_path: str + packmol_version: str + container_project_root: str + + def summary_text(self, link: PackmolDockerLink) -> str: + lines = [ + "Docker validation succeeded.", + "", + f"Preset: {link.resolved_display_name}", + f"Container: {link.container_name}", + f"Container ID: {self.container_id}", + f"Image: {self.image_name}", + f"Packmol command: {link.packmol_command}", + f"Resolved Packmol binary: {self.packmol_command_path}", + f"Packmol version: {self.packmol_version}", + f"Verified workspace root: {self.container_project_root}", + ( + "Remote Packmol sync folder: " + f"{link.remote_packmol_inputs_dir()}" + ), + f"Verified at: {self.verified_at}", + ] + return "\n".join(lines) + + +@dataclass(slots=True, frozen=True) +class PackmolDockerSyncResult: + synced_at: str + remote_packmol_inputs_dir: str + remote_packmol_input_path: str + remote_packed_output_path: str + synced_file_count: int + + def summary_text(self) -> str: + return ( + f"Synced {self.synced_file_count} file(s) to " + f"{self.remote_packmol_inputs_dir} at {self.synced_at}." 
class PackmolDockerClient:
    """Thin wrapper around the ``docker`` CLI for Packmol containers.

    Every operation shells out through :meth:`_run_docker`; any CLI
    failure is reported as ``RuntimeError`` carrying the CLI's own message.
    """

    def __init__(self, *, docker_executable: str = "docker") -> None:
        self.docker_executable = docker_executable

    def list_containers(self) -> list[PackmolDockerContainerRecord]:
        """Return all containers (running or not), sorted by name.

        Parses tab-separated ``docker ps -a`` output; lines without a tab
        and repeated names are ignored.
        """
        stdout = self._run_docker(
            "ps",
            "-a",
            "--format",
            "{{.Names}}\t{{.Image}}\t{{.Status}}",
        ).stdout
        records: list[PackmolDockerContainerRecord] = []
        seen_names: set[str] = set()
        for raw_line in stdout.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            name, separator, remainder = line.partition("\t")
            if not separator:
                continue
            image_name, _, status = remainder.partition("\t")
            normalized_name = name.strip()
            if not normalized_name or normalized_name in seen_names:
                continue
            seen_names.add(normalized_name)
            records.append(
                PackmolDockerContainerRecord(
                    name=normalized_name,
                    image_name=image_name.strip(),
                    status=status.strip(),
                )
            )
        records.sort(key=lambda record: record.name.lower())
        return records

    def verify_link(
        self,
        link: PackmolDockerLink,
    ) -> PackmolDockerValidationResult:
        """Check that ``link`` points at a usable Packmol container.

        Runs ``docker info``, starts the container if it is not running,
        confirms the project root exists inside it and resolves the
        Packmol binary and its version.

        Raises:
            RuntimeError: if any docker command fails, the container does
                not stay running, or the project root is not allowed.
        """
        self._run_docker("info")
        if not self._container_is_running(link.container_name):
            self._run_docker("start", link.container_name)
        if not self._container_is_running(link.container_name):
            # The hint names the container so the command can be pasted
            # as-is; it previously ended in a bare `docker start -i `.
            raise RuntimeError(
                "Docker could not keep the selected container running. If "
                "your container requires an attached shell, start it "
                f"manually with `docker start -i {link.container_name}` "
                "and try linking it again."
            )
        container_id = self._run_docker(
            "inspect",
            "--format",
            "{{.Id}}",
            link.container_name,
        ).stdout.strip()
        image_name = self._run_docker(
            "inspect",
            "--format",
            "{{.Config.Image}}",
            link.container_name,
        ).stdout.strip()
        project_root = normalize_container_directory(
            link.container_project_root
        )
        _validate_container_project_root(project_root)
        self._run_in_container(
            link,
            f"target={shlex.quote(project_root)}; "
            '[ -d "$target" ] || exit 3',
        )
        packmol_command_path = self._resolve_packmol_command_path(link)
        packmol_version = self._resolve_packmol_version(
            link,
            packmol_command_path,
        )
        return PackmolDockerValidationResult(
            verified_at=datetime.now().isoformat(timespec="seconds"),
            container_id=container_id,
            image_name=image_name,
            packmol_command_path=packmol_command_path,
            packmol_version=packmol_version,
            container_project_root=project_root,
        )

    def list_directories(
        self,
        link: PackmolDockerLink,
        directory: str,
    ) -> list[PackmolDockerDirectoryEntry]:
        """List the immediate sub-directories of ``directory`` in the container.

        The shell glob ``*`` is used, so dot-directories are not reported.
        Entries are sorted case-insensitively by name.
        """
        normalized = normalize_container_directory(directory)
        stdout = self._run_in_container(
            link,
            f"target={shlex.quote(normalized)}; "
            '[ -d "$target" ] || exit 3; '
            'for entry in "$target"/*; do '
            ' [ -d "$entry" ] || continue; '
            ' name=$(basename "$entry"); '
            ' printf "%s\\t%s\\n" "$name" "$entry"; '
            "done",
        )
        entries: list[PackmolDockerDirectoryEntry] = []
        for raw_line in stdout.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            name, separator, path = line.partition("\t")
            if not separator:
                continue
            entries.append(
                PackmolDockerDirectoryEntry(
                    name=name.strip(),
                    path=normalize_container_directory(path),
                )
            )
        entries.sort(key=lambda entry: entry.name.lower())
        return entries

    def sync_packmol_inputs(
        self,
        link: PackmolDockerLink,
        local_packmol_inputs_dir: str | Path,
        *,
        packmol_setup_metadata: "PackmolSetupMetadata | None" = None,
    ) -> PackmolDockerSyncResult:
        """Copy the local Packmol inputs folder into the container.

        The link is re-verified first, the remote folder is created with
        ``mkdir -p`` and the folder contents are copied with ``docker cp``.
        ``synced_file_count`` counts only regular files directly inside
        the local folder.

        Raises:
            ValueError: if the local directory does not exist.
            RuntimeError: if verification or any docker command fails.
        """
        local_dir = Path(local_packmol_inputs_dir).expanduser().resolve()
        if not local_dir.is_dir():
            raise ValueError(
                f"Local Packmol inputs directory does not exist: {local_dir}"
            )
        verified = self.verify_link(link)
        remote_inputs_dir = str(link.remote_packmol_inputs_dir())
        self._run_in_container(
            link,
            f'target={shlex.quote(remote_inputs_dir)}; mkdir -p "$target"',
        )
        # A trailing "/." makes `docker cp` copy the directory contents
        # rather than nesting the directory itself under the target.
        self._run_docker(
            "cp",
            str(local_dir) + "/.",
            f"{link.container_name}:{remote_inputs_dir}",
        )
        synced_file_count = sum(
            1 for path in local_dir.iterdir() if path.is_file()
        )
        return PackmolDockerSyncResult(
            synced_at=verified.verified_at,
            remote_packmol_inputs_dir=remote_inputs_dir,
            remote_packmol_input_path=link.remote_packmol_input_path(
                packmol_setup_metadata
            ),
            remote_packed_output_path=link.remote_packmol_output_path(
                packmol_setup_metadata
            ),
            synced_file_count=synced_file_count,
        )

    def _container_is_running(self, container_name: str) -> bool:
        """Return whether ``docker inspect`` reports the container running."""
        stdout = self._run_docker(
            "inspect",
            "--format",
            "{{.State.Running}}",
            container_name,
        ).stdout.strip()
        return stdout.lower() == "true"

    def _resolve_packmol_command_path(self, link: PackmolDockerLink) -> str:
        """Resolve the Packmol command to a path inside the container.

        Tries ``command -v`` first, then accepts the candidate itself if it
        is an executable file; otherwise the script exits 127, which
        surfaces as ``RuntimeError``.
        """
        candidate = link.packmol_command.strip() or "packmol"
        stdout = self._run_in_container(
            link,
            f"candidate={shlex.quote(candidate)}; "
            'if command -v "$candidate" >/dev/null 2>&1; then '
            ' command -v "$candidate"; '
            'elif [ -x "$candidate" ]; then '
            ' printf "%s\\n" "$candidate"; '
            "else "
            " exit 127; "
            "fi",
        )
        resolved = stdout.strip()
        if not resolved:
            raise RuntimeError(
                f"Unable to resolve Packmol command inside container: {candidate}"
            )
        return resolved

    def _resolve_packmol_version(
        self,
        link: PackmolDockerLink,
        command_path: str,
    ) -> str:
        """Return the first non-empty line Packmol prints about itself.

        Tries ``--version``, ``-version``, ``-v`` and ``-h`` in turn until
        one produces output.
        """
        stdout = self._run_in_container(
            link,
            f"candidate={shlex.quote(command_path)}; "
            'version_output=$("$candidate" --version 2>&1); '
            'if [ -z "$version_output" ]; then '
            ' version_output=$("$candidate" -version 2>&1); '
            "fi; "
            'if [ -z "$version_output" ]; then '
            ' version_output=$("$candidate" -v 2>&1); '
            "fi; "
            'if [ -z "$version_output" ]; then '
            ' version_output=$("$candidate" -h 2>&1); '
            "fi; "
            '[ -n "$version_output" ] || exit 125; '
            'printf "%s\\n" "$version_output" | sed -n \'/./{p;q;}\'',
        )
        version_line = stdout.strip()
        if not version_line:
            raise RuntimeError(
                "Packmol executable was found inside the container, but its "
                "version output could not be read."
            )
        return version_line

    def _run_in_container(
        self,
        link: PackmolDockerLink,
        script: str,
    ) -> str:
        """Run ``script`` via ``docker exec <container> <shell> -lc``.

        Returns the script's stdout.
        """
        completed = self._run_docker(
            "exec",
            link.container_name,
            link.shell_command.strip() or "sh",
            "-lc",
            script,
        )
        return completed.stdout

    def _run_docker(self, *args: str) -> subprocess.CompletedProcess[str]:
        """Run the docker CLI with ``args`` and return the completed process.

        Raises:
            RuntimeError: if the executable is missing, or the command
                exits non-zero (message is stderr, else stdout, else a
                generic fallback).
        """
        try:
            completed = subprocess.run(
                [self.docker_executable, *args],
                check=False,
                capture_output=True,
                text=True,
                encoding="utf-8",
                # Container output is not guaranteed to be valid UTF-8;
                # substitute undecodable bytes instead of letting a
                # UnicodeDecodeError escape this method.
                errors="replace",
            )
        except FileNotFoundError as exc:
            raise RuntimeError(
                "Docker CLI was not found on PATH. Install Docker and make "
                "sure the `docker` command is available before linking a "
                "Packmol container."
            ) from exc
        if completed.returncode != 0:
            message = (
                completed.stderr.strip()
                or completed.stdout.strip()
                or "Docker command failed."
            )
            raise RuntimeError(message)
        return completed
def load_packmol_docker_link_metadata(
    metadata_path: str | Path,
) -> PackmolDockerLink | None:
    """Load a saved Docker link, or ``None`` when none can be read.

    The file may hold either the link payload itself or a wrapper object
    with the payload under the ``"link"`` key. A missing, unreadable,
    undecodable or malformed file yields ``None``.
    """
    path = Path(metadata_path).expanduser().resolve()
    try:
        # EAFP: reading directly avoids the gap between an is_file() check
        # and the read. OSError covers missing paths and directories;
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        return None
    if isinstance(payload, dict) and "link" in payload:
        payload = payload.get("link")
    return PackmolDockerLink.from_dict(
        payload if isinstance(payload, dict) else None
    )
saxshell.fullrmc.representatives import RepresentativeSelectionMetadata +from saxshell.fullrmc.representatives import ( + RepresentativeSelectionMetadata, + validate_representative_selection_covers_distribution, +) from saxshell.fullrmc.solution_properties import ( SolutionProperties, SolutionPropertiesMetadata, ) -from saxshell.fullrmc.solvent_handling import SolventHandlingMetadata -from saxshell.structure import PDBStructure +from saxshell.fullrmc.solvent_handling import ( + RepresentativeSolventDistributionAnalysis, + SolventHandlingMetadata, + SolventHandlingSettings, + analyze_representative_solvent_distribution, + representative_source_solvent_mode_to_variant, +) +from saxshell.saxs.debye import load_structure_file +from saxshell.structure import PDBAtom, PDBStructure +from saxshell.xyz2pdb import resolve_reference_path if False: # pragma: no cover from .project_loader import RMCDreamProjectSource @@ -24,6 +35,7 @@ class PackmolPlanningSettings: planning_mode: str = "per_element" box_side_length_a: float = 100.0 + free_solvent_reference: str | None = None def to_dict(self) -> dict[str, object]: return asdict(self) @@ -44,6 +56,88 @@ def from_dict( return cls( planning_mode=mode, box_side_length_a=max(box_side_length_a, 1.0), + free_solvent_reference=_optional_text( + source.get("free_solvent_reference") + ), + ) + + +@dataclass(slots=True) +class PackmolSolventAllocationEntry: + structure: str + motif: str + param: str + planned_count: int + solvent_molecules_per_cluster: int + solvent_molecules_total: int + + def to_dict(self) -> dict[str, object]: + return asdict(self) + + @classmethod + def from_dict( + cls, + payload: dict[str, object], + ) -> "PackmolSolventAllocationEntry": + return cls( + structure=str(payload.get("structure", "")).strip(), + motif=str(payload.get("motif", "no_motif")).strip() or "no_motif", + param=str(payload.get("param", "")).strip(), + planned_count=int(payload.get("planned_count", 0)), + solvent_molecules_per_cluster=int( + 
payload.get("solvent_molecules_per_cluster", 0) + ), + solvent_molecules_total=int( + payload.get("solvent_molecules_total", 0) + ), + ) + + +@dataclass(slots=True) +class PackmolSolventAllocation: + reference_name: str | None + reference_path: str | None + target_solvent_molecules: int + solvent_molecules_in_clusters: int + free_solvent_molecules: int + entries: list[PackmolSolventAllocationEntry] + + def to_dict(self) -> dict[str, object]: + return { + "reference_name": self.reference_name, + "reference_path": self.reference_path, + "target_solvent_molecules": self.target_solvent_molecules, + "solvent_molecules_in_clusters": ( + self.solvent_molecules_in_clusters + ), + "free_solvent_molecules": self.free_solvent_molecules, + "entries": [entry.to_dict() for entry in self.entries], + } + + @classmethod + def from_dict( + cls, + payload: dict[str, object] | None, + ) -> "PackmolSolventAllocation | None": + if not payload: + return None + return cls( + reference_name=_optional_text(payload.get("reference_name")), + reference_path=_optional_text(payload.get("reference_path")), + target_solvent_molecules=int( + payload.get("target_solvent_molecules", 0) + ), + solvent_molecules_in_clusters=int( + payload.get("solvent_molecules_in_clusters", 0) + ), + free_solvent_molecules=int( + payload.get("free_solvent_molecules", 0) + ), + entries=[ + PackmolSolventAllocationEntry.from_dict(dict(entry)) + for entry in payload.get("entries", []) + if isinstance(entry, dict) + ], ) @@ -101,6 +195,7 @@ class PackmolPlanningMetadata: target_element_number_density_a3: dict[str, float] achieved_total_number_density_a3: float achieved_element_number_density_a3: dict[str, float] + solvent_allocation: PackmolSolventAllocation | None entries: list[PackmolPlanningEntry] report_text: str @@ -125,6 +220,11 @@ def to_dict(self) -> dict[str, object]: "achieved_element_number_density_a3": dict( self.achieved_element_number_density_a3 ), + "solvent_allocation": ( + None + if 
self.solvent_allocation is None + else self.solvent_allocation.to_dict() + ), "entries": [entry.to_dict() for entry in self.entries], "report_text": self.report_text, } @@ -167,6 +267,11 @@ def from_dict( payload.get("achieved_element_number_density_a3", {}) ).items() }, + solvent_allocation=PackmolSolventAllocation.from_dict( + payload.get("solvent_allocation") + if isinstance(payload.get("solvent_allocation"), dict) + else None + ), entries=[ PackmolPlanningEntry.from_dict(dict(entry)) for entry in payload.get("entries", []) @@ -181,15 +286,41 @@ def summary_text(self) -> str: f"Box side: {self.settings.box_side_length_a:.3f} A", f"Saved at: {self.updated_at}", f"Planned clusters: {sum(entry.planned_count for entry in self.entries)}", - ( - "Target total number density: " - f"{self.target_total_number_density_a3:.6g} atoms/A^3" - ), - ( - "Achieved total number density: " - f"{self.achieved_total_number_density_a3:.6g} atoms/A^3" - ), ] + if self.solvent_allocation is not None: + if self.solvent_allocation.reference_name: + lines.append( + "Free solvent structure: " + f"{self.solvent_allocation.reference_name}" + ) + lines.extend( + [ + ( + "Total solvent molecules: " + f"{self.solvent_allocation.target_solvent_molecules}" + ), + ( + "Cluster solvent molecules: " + f"{self.solvent_allocation.solvent_molecules_in_clusters}" + ), + ( + "Free solvent molecules: " + f"{self.solvent_allocation.free_solvent_molecules}" + ), + ] + ) + lines.extend( + [ + ( + "Target total number density: " + f"{self.target_total_number_density_a3:.6g} atoms/A^3" + ), + ( + "Achieved total number density: " + f"{self.achieved_total_number_density_a3:.6g} atoms/A^3" + ), + ] + ) if self.entries: first = self.entries[0] lines.extend( @@ -221,8 +352,11 @@ def build_packmol_plan( or not active_representatives.representative_entries ): raise ValueError( - "Compute representative clusters before planning the Packmol box." + "Save representative structures before planning the Packmol box." 
) + validate_representative_selection_covers_distribution( + active_representatives + ) active_solution = solution_metadata or project_source.solution_properties if active_solution.result is None: raise ValueError( @@ -237,9 +371,18 @@ def build_packmol_plan( target_total_nd = float(solution.number_density_A3) target_element_nd = _element_number_density(solution) + active_solvent = solvent_metadata or project_source.solvent_handling + solvent_analysis = _build_packmol_solvent_analysis( + project_source, + settings, + active_representatives, + active_solvent, + ) + composition_lookup = _build_composition_lookup( active_representatives, - solvent_metadata or project_source.solvent_handling, + active_solvent, + solvent_analysis=solvent_analysis, ) keys: list[tuple[str, str, str]] = [] weights: list[float] = [] @@ -320,6 +463,14 @@ def build_packmol_plan( ) ) + solvent_allocation = _build_solvent_allocation( + settings=settings, + box_targets=box_targets, + representative_metadata=active_representatives, + planning_entries=entries, + solvent_metadata=active_solvent, + solvent_analysis=solvent_analysis, + ) report_text = _build_plan_report( settings=settings, box_targets=box_targets, @@ -328,6 +479,7 @@ def build_packmol_plan( target_element_nd=target_element_nd, achieved_total_nd=achieved_total_nd, achieved_element_nd=achieved_element_nd, + solvent_allocation=solvent_allocation, ) metadata = PackmolPlanningMetadata( settings=settings, @@ -338,6 +490,7 @@ def build_packmol_plan( target_element_number_density_a3=target_element_nd, achieved_total_number_density_a3=achieved_total_nd, achieved_element_number_density_a3=achieved_element_nd, + solvent_allocation=solvent_allocation, entries=entries, report_text=report_text, ) @@ -389,14 +542,22 @@ def _element_number_density(solution: SolutionProperties) -> dict[str, float]: def _build_composition_lookup( metadata: RepresentativeSelectionMetadata, solvent_metadata: SolventHandlingMetadata | None, + *, + solvent_analysis: 
RepresentativeSolventDistributionAnalysis | None = None, ) -> dict[tuple[str, str, str], tuple[dict[str, int], int, str]]: lookup: dict[tuple[str, str, str], tuple[dict[str, int], int, str]] = {} solvent_lookup: dict[tuple[str, str, str], Path] = {} + analysis_lookup: dict[tuple[str, str, str], object] = {} if solvent_metadata is not None: for entry in solvent_metadata.entries: solvent_lookup[(entry.structure, entry.motif, entry.param)] = Path( entry.no_solvent_pdb ) + if solvent_analysis is not None: + analysis_lookup = { + (entry.structure, entry.motif, entry.param): entry + for entry in solvent_analysis.entries + } for entry in metadata.representative_entries: key = (entry.structure, entry.motif, entry.param) @@ -412,6 +573,22 @@ def _build_composition_lookup( "pdb_no_solvent", ) continue + analysis_entry = analysis_lookup.get(key) + if analysis_entry is not None: + structure = _strip_detected_solvent_atoms( + _load_structure_as_pdb( + entry.source_file, + structure_label=entry.structure, + ), + analysis_entry.analysis_result, + ) + counts = _count_elements(structure) + lookup[key] = ( + counts, + len(structure.atoms), + "analyzed_source_no_solvent", + ) + continue lookup[key] = ( dict(entry.element_counts), int(entry.atom_count), @@ -520,6 +697,7 @@ def _build_plan_report( target_element_nd: dict[str, float], achieved_total_nd: float, achieved_element_nd: dict[str, float], + solvent_allocation: PackmolSolventAllocation | None, ) -> str: lines = [ "== Packmol Planning ==", @@ -530,19 +708,56 @@ def _build_plan_report( f"{int(box_targets.get('solute_molecules', 0))} solute molecules, " f"{int(box_targets.get('solvent_molecules', 0))} solvent molecules" ), - ("Target total number density: " f"{target_total_nd:.6f} atoms/A^3"), - ( - "Achieved total number density (cluster plan): " - f"{achieved_total_nd:.6f} atoms/A^3" - ), - "", - "Counts per cluster bin:", ] + if solvent_allocation is not None: + if solvent_allocation.reference_name: + lines.append( + "Free 
def _build_packmol_solvent_analysis(
    project_source: "RMCDreamProjectSource",
    settings: PackmolPlanningSettings,
    representative_metadata: RepresentativeSelectionMetadata,
    solvent_metadata: SolventHandlingMetadata | None,
) -> RepresentativeSolventDistributionAnalysis | None:
    """Analyze solvent embedded in representatives, when that is needed.

    Returns ``None`` when solvent-handling metadata was supplied or no
    representative is flagged as containing solvent. Otherwise the
    representatives are analyzed against the free-solvent reference from
    ``settings``.

    Raises:
        ValueError: if an analysis is needed but no reference is set.
    """
    analysis_not_needed = (
        solvent_metadata is not None
        or not _representatives_contain_solvent(representative_metadata)
    )
    if analysis_not_needed:
        return None

    chosen_reference = _optional_text(settings.free_solvent_reference)
    if chosen_reference is not None:
        reference_settings = _solvent_settings_for_reference(chosen_reference)
        return analyze_representative_solvent_distribution(
            project_source,
            reference_settings,
            representative_metadata=representative_metadata,
        )

    raise ValueError(
        "Choose a free-solvent structure before planning counts for "
        "representative files that already contain solvent, or build "
        "solvent-handling outputs first."
    )
"representative cluster files. Choose a free-solvent structure " + "or build solvent-handling outputs first." + ) + + solvent_molecules_in_clusters = 0 + for entry in planning_entries: + key = (entry.structure, entry.motif, entry.param) + per_cluster = int(counts_by_key.get(key, 0)) + if per_cluster <= 0 and representative_variants.get(key) in { + "full_solvent", + "partial_solvent", + }: + raise ValueError( + "Unable to determine the embedded solvent count for " + f"{entry.structure}/{entry.motif}. Choose a free-solvent " + "structure or build solvent-handling outputs first." + ) + total = int(entry.planned_count) * per_cluster + solvent_molecules_in_clusters += total + allocation_entries.append( + PackmolSolventAllocationEntry( + structure=entry.structure, + motif=entry.motif, + param=entry.param, + planned_count=int(entry.planned_count), + solvent_molecules_per_cluster=per_cluster, + solvent_molecules_total=total, + ) + ) + + free_solvent_molecules = max( + 0, + target_solvent_molecules - solvent_molecules_in_clusters, + ) + return PackmolSolventAllocation( + reference_name=reference_name, + reference_path=reference_path, + target_solvent_molecules=target_solvent_molecules, + solvent_molecules_in_clusters=solvent_molecules_in_clusters, + free_solvent_molecules=free_solvent_molecules, + entries=allocation_entries, + ) + + +def _representatives_contain_solvent( + representative_metadata: RepresentativeSelectionMetadata, +) -> bool: + return any( + representative_source_solvent_mode_to_variant( + entry.source_solvent_mode + ) + in {"full_solvent", "partial_solvent"} + for entry in representative_metadata.representative_entries + ) + + +def _completed_solvent_count(entry: object) -> int: + detected_status = str(getattr(entry, "detected_source_status", "")).strip() + if detected_status == "partial_solvent": + return max( + int(getattr(entry, "detected_partial_solvent_count", 0)), + 0, + ) + max(int(getattr(entry, "solvent_molecules_added", 0)), 0) + if 
detected_status == "complete_solvent": + return max( + int(getattr(entry, "detected_complete_solvent_count", 0)), + int(getattr(entry, "solvent_molecules_added", 0)), + ) + return max(int(getattr(entry, "solvent_molecules_added", 0)), 0) + + +def _source_solvent_count(entry: object) -> int: + source_status = str(getattr(entry, "source_status", "")).strip() + analysis_result = getattr(entry, "analysis_result", None) + if analysis_result is None: + return 0 + if source_status == "complete_solvent": + return max( + int( + getattr(analysis_result, "complete_solvent_molecule_count", 0) + ), + 0, + ) + if source_status == "partial_solvent": + return max( + int(getattr(analysis_result, "partial_solvent_molecule_count", 0)), + 0, + ) + return 0 + + +def _solvent_settings_for_reference( + reference_identifier: str, +) -> SolventHandlingSettings: + candidate = Path(reference_identifier).expanduser() + if candidate.is_file(): + return SolventHandlingSettings( + reference_source="custom", + custom_reference_path=str(candidate.resolve()), + ) + return SolventHandlingSettings( + reference_source="preset", + preset_name=str(reference_identifier).strip(), + ) + + +def _selected_reference_details( + reference_identifier: str | None, +) -> tuple[str | None, str | None]: + identifier = _optional_text(reference_identifier) + if identifier is None: + return None, None + resolved_reference = resolve_reference_path(identifier).expanduser() + return resolved_reference.stem, str(resolved_reference.resolve()) + + +def _strip_detected_solvent_atoms( + structure: PDBStructure, + analysis_result: object, +) -> PDBStructure: + stripped_atom_ids = { + int(atom_id) + for atom_id in getattr( + analysis_result, + "complete_solvent_source_atom_ids", + (), + ) + } + stripped_atom_ids.update( + int(atom_id) + for atom_id in getattr( + analysis_result, + "partial_solvent_source_atom_ids", + (), + ) + ) + stripped_atoms = [ + atom.copy() + for atom in structure.atoms + if int(atom.atom_id) not in 
stripped_atom_ids + ] + for index, atom in enumerate(stripped_atoms, start=1): + atom.atom_id = index + return PDBStructure( + atoms=stripped_atoms, + source_name=structure.source_name, + ) + + +def _load_structure_as_pdb( + source_file: str | Path, + *, + structure_label: str, +) -> PDBStructure: + path = Path(source_file).expanduser().resolve() + if path.suffix.lower() == ".pdb": + return PDBStructure.from_file(path) + positions, elements = load_structure_file(path) + counters: dict[str, int] = {} + atoms: list[PDBAtom] = [] + residue_name = _normalized_residue_name(structure_label) + for index, (coordinates, element) in enumerate( + zip(positions, elements, strict=True), + start=1, + ): + counters[element] = counters.get(element, 0) + 1 + atoms.append( + PDBAtom( + atom_id=index, + atom_name=f"{element}{counters[element]}", + residue_name=residue_name, + residue_number=1, + coordinates=np.asarray(coordinates, dtype=float), + element=str(element), + ) + ) + return PDBStructure(atoms=atoms, source_name=path.stem) + + +def _count_elements(structure: PDBStructure) -> dict[str, int]: + counts: dict[str, int] = {} + for atom in structure.atoms: + counts[atom.element] = counts.get(atom.element, 0) + 1 + return counts + + +def _normalized_residue_name(text: str) -> str: + collapsed = "".join(char for char in str(text).upper() if char.isalnum()) + return (collapsed or "CLU")[:3] + + +def _optional_text(value: object) -> str | None: + if value is None: + return None + text = str(value).strip() + return text or None + + __all__ = [ "PackmolPlanningEntry", "PackmolPlanningMetadata", "PackmolPlanningSettings", + "PackmolSolventAllocation", + "PackmolSolventAllocationEntry", "build_packmol_plan", "load_packmol_planning_metadata", "save_packmol_planning_metadata", diff --git a/src/saxshell/fullrmc/packmol_setup.py b/src/saxshell/fullrmc/packmol_setup.py index 4da0da5..41cc50e 100644 --- a/src/saxshell/fullrmc/packmol_setup.py +++ b/src/saxshell/fullrmc/packmol_setup.py @@ 
-13,10 +13,16 @@ from saxshell.fullrmc.representatives import ( RepresentativeSelectionEntry, RepresentativeSelectionMetadata, + validate_representative_selection_covers_distribution, +) +from saxshell.fullrmc.solvent_handling import ( + SolventHandlingMetadata, + representative_structure_mode_label, + resolved_representative_structure_mode, ) -from saxshell.fullrmc.solvent_handling import SolventHandlingMetadata from saxshell.saxs.debye import load_structure_file from saxshell.structure import PDBAtom, PDBStructure +from saxshell.xyz2pdb import resolve_reference_path if False: # pragma: no cover from .project_loader import RMCDreamProjectSource @@ -29,6 +35,7 @@ class PackmolSetupSettings: packed_output_filename: str = "packed_combined.pdb" use_completed_representatives: bool = True include_free_solvent: bool = True + free_solvent_reference: str | None = None def to_dict(self) -> dict[str, object]: return asdict(self) @@ -61,6 +68,9 @@ def from_dict( include_free_solvent=bool( source.get("include_free_solvent", True) ), + free_solvent_reference=_optional_text( + source.get("free_solvent_reference") + ), ) @@ -109,10 +119,15 @@ class PackmolSetupMetadata: updated_at: str planning_mode: str representative_selection_mode: str + representative_structure_mode: str box_side_length_a: float packmol_input_path: str packed_output_filename: str solvent_pdb_path: str | None + free_solvent_reference_name: str | None + free_solvent_reference_path: str | None + target_solvent_molecules: int + solvent_molecules_in_clusters: int free_solvent_molecules: int audit_report_path: str entries: list[PackmolSetupEntry] @@ -126,10 +141,19 @@ def to_dict(self) -> dict[str, object]: "representative_selection_mode": ( self.representative_selection_mode ), + "representative_structure_mode": ( + self.representative_structure_mode + ), "box_side_length_a": self.box_side_length_a, "packmol_input_path": self.packmol_input_path, "packed_output_filename": self.packed_output_filename, 
"solvent_pdb_path": self.solvent_pdb_path, + "free_solvent_reference_name": self.free_solvent_reference_name, + "free_solvent_reference_path": self.free_solvent_reference_path, + "target_solvent_molecules": self.target_solvent_molecules, + "solvent_molecules_in_clusters": ( + self.solvent_molecules_in_clusters + ), "free_solvent_molecules": self.free_solvent_molecules, "audit_report_path": self.audit_report_path, "entries": [entry.to_dict() for entry in self.entries], @@ -153,6 +177,9 @@ def from_dict( representative_selection_mode=str( payload.get("representative_selection_mode", "") ).strip(), + representative_structure_mode=str( + payload.get("representative_structure_mode", "") + ).strip(), box_side_length_a=float(payload.get("box_side_length_a", 0.0)), packmol_input_path=str( payload.get("packmol_input_path", "") @@ -161,6 +188,18 @@ def from_dict( payload.get("packed_output_filename", "") ).strip(), solvent_pdb_path=_optional_text(payload.get("solvent_pdb_path")), + free_solvent_reference_name=_optional_text( + payload.get("free_solvent_reference_name") + ), + free_solvent_reference_path=_optional_text( + payload.get("free_solvent_reference_path") + ), + target_solvent_molecules=int( + payload.get("target_solvent_molecules", 0) + ), + solvent_molecules_in_clusters=int( + payload.get("solvent_molecules_in_clusters", 0) + ), free_solvent_molecules=int( payload.get("free_solvent_molecules", 0) ), @@ -178,12 +217,31 @@ def summary_text(self) -> str: lines = [ f"Planning mode: {self.planning_mode}", f"Representative mode: {self.representative_selection_mode}", + ( + "Representative structure set: " + f"{representative_structure_mode_label(self.representative_structure_mode)}" + ), f"Saved at: {self.updated_at}", f"Box side: {self.box_side_length_a:.3f} A", + f"Packmol tolerance: {self.settings.tolerance_angstrom:.3f} A", f"Packmol input: {Path(self.packmol_input_path).name}", f"Representative PDBs copied: {len(self.entries)}", - f"Free solvent molecules: 
{self.free_solvent_molecules}", ] + if self.free_solvent_reference_name: + lines.append( + "Free solvent structure: " + f"{self.free_solvent_reference_name}" + ) + lines.extend( + [ + f"Total solvent molecules: {self.target_solvent_molecules}", + ( + "Cluster solvent molecules: " + f"{self.solvent_molecules_in_clusters}" + ), + f"Free solvent molecules: {self.free_solvent_molecules}", + ] + ) if self.entries: first = self.entries[0] lines.extend( @@ -223,18 +281,27 @@ def build_packmol_setup( or not active_representatives.representative_entries ): raise ValueError( - "Compute representative clusters before building Packmol setup." + "Save representative structures before building Packmol setup." ) + validate_representative_selection_covers_distribution( + active_representatives + ) active_solvent = solvent_metadata or project_source.solvent_handling + free_solvent_reference_name: str | None = None + free_solvent_reference_path: str | None = None if active_settings.include_free_solvent: - if active_solvent is None: - raise ValueError( - "Build representative solvent outputs before generating Packmol inputs." - ) - if not active_solvent.reference_path: + ( + free_solvent_reference_name, + free_solvent_reference_path, + ) = _resolve_free_solvent_reference( + active_settings, + active_plan, + active_solvent, + ) + if free_solvent_reference_path is None: raise ValueError( - "No solvent reference PDB is available for Packmol input generation." + "Choose a free-solvent structure before generating Packmol inputs." 
) representative_lookup = { @@ -247,6 +314,10 @@ def build_packmol_setup( (entry.structure, entry.motif, entry.param): entry for entry in active_solvent.entries } + representative_structure_mode = resolved_representative_structure_mode( + active_representatives, + active_solvent, + ) entries: list[PackmolSetupEntry] = [] box_side_length_a = active_plan.settings.box_side_length_a @@ -263,6 +334,7 @@ def build_packmol_setup( source_structure, source_pdb_path = _resolve_structure_for_packmol( representative_entry, solvent_lookup.get(key), + representative_structure_mode=representative_structure_mode, use_completed=active_settings.use_completed_representatives, ) residue_name = _packmol_residue_code(index) @@ -302,24 +374,32 @@ def build_packmol_setup( ) solvent_pdb_path: str | None = None - free_solvent_molecules = 0 - if active_settings.include_free_solvent and active_solvent is not None: - source_solvent = ( - Path(active_solvent.reference_path).expanduser().resolve() - ) - free_solvent_molecules = int( - round( - float( - active_plan.target_box_composition.get( - "solvent_molecules", - 0, - ) - ) + solvent_allocation = active_plan.solvent_allocation + target_solvent_molecules = int( + round( + float( + active_plan.target_box_composition.get("solvent_molecules", 0) ) ) - solvent_copy_name = ( - f"{_safe_name(active_solvent.reference_name)}_single.pdb" + ) + solvent_molecules_in_clusters = 0 + free_solvent_molecules = target_solvent_molecules + if solvent_allocation is not None: + target_solvent_molecules = int( + solvent_allocation.target_solvent_molecules + ) + solvent_molecules_in_clusters = int( + solvent_allocation.solvent_molecules_in_clusters ) + free_solvent_molecules = int(solvent_allocation.free_solvent_molecules) + if ( + active_settings.include_free_solvent + and free_solvent_reference_path is not None + ): + source_solvent = ( + Path(free_solvent_reference_path).expanduser().resolve() + ) + solvent_copy_name = f"{_safe_name(free_solvent_reference_name or 
source_solvent.stem)}_single.pdb" destination = ( project_source.rmcsetup_paths.packmol_inputs_dir / solvent_copy_name @@ -341,6 +421,10 @@ def build_packmol_setup( entries, input_path=input_path, solvent_pdb_path=solvent_pdb_path, + free_solvent_reference_name=free_solvent_reference_name, + free_solvent_reference_path=free_solvent_reference_path, + target_solvent_molecules=target_solvent_molecules, + solvent_molecules_in_clusters=solvent_molecules_in_clusters, free_solvent_molecules=free_solvent_molecules, ) metadata = PackmolSetupMetadata( @@ -348,10 +432,15 @@ def build_packmol_setup( updated_at=datetime.now().isoformat(timespec="seconds"), planning_mode=active_plan.settings.planning_mode, representative_selection_mode=active_representatives.selection_mode, + representative_structure_mode=representative_structure_mode, box_side_length_a=box_side_length_a, packmol_input_path=str(input_path), packed_output_filename=active_settings.packed_output_filename, solvent_pdb_path=solvent_pdb_path, + free_solvent_reference_name=free_solvent_reference_name, + free_solvent_reference_path=free_solvent_reference_path, + target_solvent_molecules=target_solvent_molecules, + solvent_molecules_in_clusters=solvent_molecules_in_clusters, free_solvent_molecules=free_solvent_molecules, audit_report_path=str(audit_path), entries=entries, @@ -393,38 +482,50 @@ def _resolve_structure_for_packmol( representative_entry: RepresentativeSelectionEntry, solvent_entry: object | None, *, + representative_structure_mode: str, use_completed: bool, ) -> tuple[PDBStructure, Path]: - candidate_path: Path | None = None - if solvent_entry is not None and use_completed: - candidate_path = Path( + candidate_paths: list[Path] = [] + if solvent_entry is not None: + completed_path = Path( getattr(solvent_entry, "completed_pdb", "") ).expanduser() - if ( - candidate_path is None - or not str(candidate_path) - or not candidate_path.is_file() - ) and solvent_entry is not None: - candidate_path = Path( + 
no_solvent_path = Path( getattr(solvent_entry, "no_solvent_pdb", "") ).expanduser() - if ( - candidate_path is None - or not str(candidate_path) - or not candidate_path.is_file() - ): - source_path = ( - Path(representative_entry.source_file).expanduser().resolve() - ) - return ( - _load_structure_as_pdb( - source_path, - structure_label=representative_entry.structure, - ), + if representative_structure_mode == "full_solvent": + candidate_paths.extend([completed_path, no_solvent_path]) + elif representative_structure_mode == "no_solvent": + candidate_paths.extend([no_solvent_path, completed_path]) + elif representative_structure_mode == "partial_solvent": + candidate_paths.extend( + [Path(representative_entry.source_file).expanduser()] + ) + elif use_completed: + candidate_paths.extend([completed_path, no_solvent_path]) + else: + candidate_paths.extend([no_solvent_path, completed_path]) + candidate_paths.append(Path(representative_entry.source_file).expanduser()) + source_path = Path(representative_entry.source_file).expanduser().resolve() + for candidate_path in candidate_paths: + if str(candidate_path).strip() and candidate_path.is_file(): + resolved = candidate_path.resolve() + if resolved == source_path: + return ( + _load_structure_as_pdb( + resolved, + structure_label=representative_entry.structure, + ), + resolved, + ) + return PDBStructure.from_file(resolved), resolved + return ( + _load_structure_as_pdb( source_path, - ) - resolved = candidate_path.resolve() - return PDBStructure.from_file(resolved), resolved + structure_label=representative_entry.structure, + ), + source_path, + ) def _prepare_packmol_structure( @@ -485,6 +586,10 @@ def _write_packmol_audit_report( *, input_path: Path, solvent_pdb_path: str | None, + free_solvent_reference_name: str | None, + free_solvent_reference_path: str | None, + target_solvent_molecules: int, + solvent_molecules_in_clusters: int, free_solvent_molecules: int, ) -> Path: lines = [ @@ -499,6 +604,16 @@ def 
_write_packmol_audit_report( "- Solvent input: " f"{solvent_pdb_path if solvent_pdb_path is not None else '(none)'}" ), + ( + "- Free solvent structure: " + f"{free_solvent_reference_name or '(none)'}" + ), + ( + "- Free solvent source path: " + f"{free_solvent_reference_path if free_solvent_reference_path is not None else '(none)'}" + ), + f"- Target solvent molecules: {target_solvent_molecules}", + ("- Cluster solvent molecules: " f"{solvent_molecules_in_clusters}"), f"- Free solvent molecules: {free_solvent_molecules}", "", "## Planned Clusters", @@ -529,8 +644,8 @@ def _write_packmol_audit_report( "", "## Notes", "- Cluster PDBs were rewritten with unique residue names for Packmol use.", - "- Free solvent counts currently follow the target solvent molecules from the solution-properties box composition.", - "- Coordinated-solvent completion assumptions come from the saved solvent-handling step when available.", + "- Free solvent counts subtract solvent molecules already present in the cluster files from the bulk-solvent target.", + "- If solvent-handling outputs are available, the completed full-solvent representative PDBs define the embedded cluster solvent counts.", ] ) audit_path = project_source.rmcsetup_paths.packmol_audit_report_path @@ -608,6 +723,36 @@ def _safe_filename(text: str) -> str: return name or "item" +def _resolve_free_solvent_reference( + settings: PackmolSetupSettings, + plan_metadata: PackmolPlanningMetadata, + solvent_metadata: SolventHandlingMetadata | None, +) -> tuple[str | None, str | None]: + candidates = [ + settings.free_solvent_reference, + plan_metadata.settings.free_solvent_reference, + ( + None + if plan_metadata.solvent_allocation is None + else plan_metadata.solvent_allocation.reference_path + ), + ( + None + if solvent_metadata is None + else solvent_metadata.reference_path + ), + ] + for candidate in candidates: + reference_identifier = _optional_text(candidate) + if reference_identifier is None: + continue + 
resolved_reference = resolve_reference_path( + reference_identifier + ).expanduser() + return resolved_reference.stem, str(resolved_reference.resolve()) + return None, None + + __all__ = [ "PackmolSetupEntry", "PackmolSetupMetadata", diff --git a/src/saxshell/fullrmc/project_loader.py b/src/saxshell/fullrmc/project_loader.py index e3c817f..98f8d1b 100644 --- a/src/saxshell/fullrmc/project_loader.py +++ b/src/saxshell/fullrmc/project_loader.py @@ -11,12 +11,17 @@ ProjectSettings, SAXSProjectManager, build_project_paths, + project_artifact_paths, ) from .constraint_generation import ( ConstraintGenerationMetadata, load_constraint_generation_metadata, ) +from .packmol_docker import ( + PackmolDockerLink, + load_packmol_docker_link_metadata, +) from .packmol_planning import ( PackmolPlanningMetadata, load_packmol_planning_metadata, @@ -64,6 +69,7 @@ class RMCDreamProjectSource: solution_properties: SolutionPropertiesMetadata representative_selection: RepresentativeSelectionMetadata | None solvent_handling: SolventHandlingMetadata | None + packmol_docker_link: PackmolDockerLink | None packmol_planning: PackmolPlanningMetadata | None packmol_setup: PackmolSetupMetadata | None constraint_generation: ConstraintGenerationMetadata | None @@ -91,12 +97,16 @@ def load_rmc_project_source( manager = SAXSProjectManager() settings = manager.load_project(project_dir) paths = build_project_paths(settings.project_dir) + artifact_paths = project_artifact_paths(settings) rmcsetup_paths = ensure_rmcsetup_structure(paths) return RMCDreamProjectSource( settings=settings, paths=paths, rmcsetup_paths=rmcsetup_paths, - valid_runs=discover_valid_dream_runs(paths), + valid_runs=_discover_project_valid_dream_runs( + paths, + artifact_paths.dream_dir, + ), cluster_validation=validate_cluster_source( settings, project_paths=paths, @@ -110,6 +120,9 @@ def load_rmc_project_source( solvent_handling=load_solvent_handling_metadata( rmcsetup_paths.solvent_handling_path ), + 
packmol_docker_link=load_packmol_docker_link_metadata( + rmcsetup_paths.packmol_docker_link_path + ), packmol_planning=load_packmol_planning_metadata( rmcsetup_paths.packmol_plan_path ), @@ -131,9 +144,42 @@ def discover_valid_dream_runs( paths = paths_or_dir else: paths = build_project_paths(paths_or_dir) + return _discover_runs_in_dir( + paths.dream_dir, + project_dir=paths.project_dir, + ) + + +def _discover_project_valid_dream_runs( + paths: ProjectPaths, + active_dream_dir: Path, +) -> list[RMCDreamRunRecord]: + active_records = _discover_runs_in_dir( + active_dream_dir, + project_dir=paths.project_dir, + ) + if active_records: + return active_records + if active_dream_dir.resolve() == paths.dream_dir.resolve(): + return active_records + # Prefer the active artifact root used by the SAXS UI, but keep + # legacy project-root DREAM runs available when no scoped runs exist. + return _discover_runs_in_dir( + paths.dream_dir, + project_dir=paths.project_dir, + ) + + +def _discover_runs_in_dir( + dream_dir: Path, + *, + project_dir: Path, +) -> list[RMCDreamRunRecord]: + if not dream_dir.is_dir(): + return [] records: list[RMCDreamRunRecord] = [] for metadata_path in sorted( - paths.dream_dir.rglob("dream_runtime_metadata.json") + dream_dir.rglob("dream_runtime_metadata.json") ): run_dir = metadata_path.parent if not _is_valid_run_dir(run_dir): @@ -146,7 +192,7 @@ def discover_valid_dream_runs( settings = DreamRunSettings.from_dict( dict(metadata.get("settings", {})) ) - relative_path = str(run_dir.relative_to(paths.project_dir)) + relative_path = str(run_dir.relative_to(project_dir)) template_name = _optional_text(metadata.get("template_name")) records.append( RMCDreamRunRecord( diff --git a/src/saxshell/fullrmc/project_model.py b/src/saxshell/fullrmc/project_model.py index 2364556..1aad673 100644 --- a/src/saxshell/fullrmc/project_model.py +++ b/src/saxshell/fullrmc/project_model.py @@ -9,6 +9,7 @@ ProjectPaths, ProjectSettings, build_project_paths, + 
project_artifact_paths, ) _STRUCTURE_SUFFIXES = {".pdb", ".xyz"} @@ -20,6 +21,7 @@ class RMCSetupPaths: rmcsetup_dir: Path representative_clusters_dir: Path representative_selection_path: Path + representative_partial_solvent_dir: Path pdb_no_solvent_dir: Path pdb_with_solvent_dir: Path packmol_inputs_dir: Path @@ -28,6 +30,7 @@ class RMCSetupPaths: distribution_selection_path: Path solution_properties_path: Path solvent_handling_path: Path + packmol_docker_link_path: Path packmol_plan_path: Path packmol_setup_path: Path constraint_generation_path: Path @@ -61,15 +64,17 @@ def build_rmcsetup_paths( else: project_dir = Path(project_dir_or_paths).expanduser().resolve() rmcsetup_dir = project_dir / "rmcsetup" + representative_root_dir = rmcsetup_dir / "representative_structures" return RMCSetupPaths( project_dir=project_dir, rmcsetup_dir=rmcsetup_dir, - representative_clusters_dir=rmcsetup_dir / "representative_clusters", - representative_selection_path=rmcsetup_dir - / "representative_clusters" + representative_clusters_dir=representative_root_dir, + representative_selection_path=representative_root_dir / "representative_selection.json", - pdb_no_solvent_dir=rmcsetup_dir / "pdb_no_solvent", - pdb_with_solvent_dir=rmcsetup_dir / "pdb_with_solvent", + representative_partial_solvent_dir=representative_root_dir + / "partialsolv", + pdb_no_solvent_dir=representative_root_dir / "nosolv", + pdb_with_solvent_dir=representative_root_dir / "fullsolv", packmol_inputs_dir=rmcsetup_dir / "packmol_inputs", constraints_dir=rmcsetup_dir / "constraints", reports_dir=rmcsetup_dir / "reports", @@ -77,6 +82,7 @@ def build_rmcsetup_paths( / "distribution_selection.json", solution_properties_path=rmcsetup_dir / "solution_properties.json", solvent_handling_path=rmcsetup_dir / "solvent_handling.json", + packmol_docker_link_path=rmcsetup_dir / "packmol_docker_link.json", packmol_plan_path=rmcsetup_dir / "packmol_plan.json", packmol_setup_path=rmcsetup_dir / "packmol_setup.json", 
constraint_generation_path=rmcsetup_dir / "constraints.json", @@ -114,6 +120,7 @@ def ensure_rmcsetup_structure( for directory in ( paths.rmcsetup_dir, paths.representative_clusters_dir, + paths.representative_partial_solvent_dir, paths.pdb_no_solvent_dir, paths.pdb_with_solvent_dir, paths.packmol_inputs_dir, @@ -125,6 +132,7 @@ def ensure_rmcsetup_structure( _ensure_json_file(paths.distribution_selection_path) _ensure_json_file(paths.solution_properties_path) _ensure_json_file(paths.solvent_handling_path) + _ensure_json_file(paths.packmol_docker_link_path) _ensure_json_file(paths.packmol_plan_path) _ensure_json_file(paths.packmol_setup_path) _ensure_json_file(paths.constraint_generation_path) @@ -252,14 +260,39 @@ def expected_cluster_inventory_rows( *, project_paths: ProjectPaths | None = None, ) -> list[dict[str, object]]: - if settings.cluster_inventory_rows: - return [dict(row) for row in settings.cluster_inventory_rows] paths = ( project_paths if project_paths is not None else build_project_paths(settings.project_dir) ) - prior_weights_path = paths.project_dir / "md_prior_weights.json" + settings_rows = [dict(row) for row in settings.cluster_inventory_rows] + artifact_paths = project_artifact_paths(settings) + artifact_rows = _load_cluster_inventory_rows_from_prior_weights( + artifact_paths.prior_weights_file + ) + if artifact_rows and ( + artifact_paths.uses_distribution_storage + or settings.use_predicted_structure_weights + ): + return artifact_rows + if settings_rows: + return settings_rows + if artifact_rows: + return artifact_rows + legacy_prior_weights_path = paths.project_dir / "md_prior_weights.json" + if ( + legacy_prior_weights_path.resolve() + == artifact_paths.prior_weights_file.resolve() + ): + return [] + return _load_cluster_inventory_rows_from_prior_weights( + legacy_prior_weights_path + ) + + +def _load_cluster_inventory_rows_from_prior_weights( + prior_weights_path: Path, +) -> list[dict[str, object]]: if not prior_weights_path.is_file(): 
return [] try: @@ -362,7 +395,10 @@ def _count_structure_files_in_dir(directory: Path) -> int: def _row_uses_cluster_directory(row: dict[str, object]) -> bool: source_kind = _optional_text(row.get("source_kind")) - if source_kind == "single_structure_file": + if source_kind in { + "single_structure_file", + "predicted_structure", + }: return False source_dir = _optional_text(row.get("source_dir")) source_file = _optional_text(row.get("source_file")) diff --git a/src/saxshell/fullrmc/representatives.py b/src/saxshell/fullrmc/representatives.py index fadeb55..d37ed06 100644 --- a/src/saxshell/fullrmc/representatives.py +++ b/src/saxshell/fullrmc/representatives.py @@ -6,7 +6,7 @@ from dataclasses import asdict, dataclass from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Iterable import numpy as np @@ -29,6 +29,16 @@ _STRUCTURE_SUFFIXES = {".pdb", ".xyz"} _DEFAULT_QUANTILES = tuple(np.linspace(0.0, 1.0, 11).tolist()) +_PROJECT_REPRESENTATIVE_SOURCE_SOLVENT_MODES = ( + "nosolv", + "partialsolv", + "fullsolv", +) +_SOURCE_SOLVENT_MODE_BY_VARIANT = { + "no_solvent": "nosolv", + "partial_solvent": "partialsolv", + "full_solvent": "fullsolv", +} RepresentativeSelectionProgressCallback = Callable[[int, int, str], None] RepresentativeSelectionLogCallback = Callable[[str], None] @@ -252,11 +262,13 @@ class RepresentativeSelectionEntry: source_file_name: str atom_count: int element_counts: dict[str, int] + source_solvent_mode: str = "unknown" analysis_source: str = "first_valid_file" score_total: float | None = None score_bond: float | None = None score_angle: float | None = None cached_results_path: str | None = None + project_cached_results_path: str | None = None def to_dict(self) -> dict[str, object]: return asdict(self) @@ -266,14 +278,24 @@ def from_dict( cls, payload: dict[str, object], ) -> "RepresentativeSelectionEntry": + source_dir = str(payload.get("source_dir", 
"")).strip() + source_file = str(payload.get("source_file", "")).strip() + source_solvent_mode = normalize_representative_source_solvent_mode( + payload.get("source_solvent_mode") + ) + if source_solvent_mode == "unknown": + source_solvent_mode = _infer_representative_source_solvent_mode( + source_file=source_file, + source_dir=source_dir, + ) return cls( structure=str(payload.get("structure", "")).strip(), motif=str(payload.get("motif", "no_motif")).strip() or "no_motif", param=str(payload.get("param", "")).strip(), selected_weight=float(payload.get("selected_weight", 0.0)), cluster_count=int(payload.get("cluster_count", 0)), - source_dir=str(payload.get("source_dir", "")).strip(), - source_file=str(payload.get("source_file", "")).strip(), + source_dir=source_dir, + source_file=source_file, source_file_name=str(payload.get("source_file_name", "")).strip(), atom_count=int(payload.get("atom_count", 0)), element_counts={ @@ -282,6 +304,7 @@ def from_dict( payload.get("element_counts", {}) ).items() }, + source_solvent_mode=source_solvent_mode, analysis_source=str( payload.get("analysis_source", "first_valid_file") ).strip() @@ -292,6 +315,9 @@ def from_dict( cached_results_path=_optional_text( payload.get("cached_results_path") ), + project_cached_results_path=_optional_text( + payload.get("project_cached_results_path") + ), ) @@ -473,6 +499,154 @@ def all_series(self) -> tuple[RepresentativePreviewSeries, ...]: return self.bond_series + self.angle_series +def validate_representative_selection_covers_distribution( + metadata: RepresentativeSelectionMetadata, + *, + require_source_files: bool = True, +) -> None: + """Ensure one saved representative exists for each positive + weight.""" + + expected_entries = [ + entry + for entry in metadata.distribution_selection.entries + if float(entry.selected_weight) > 0.0 + ] + representative_entries = list(metadata.representative_entries) + representative_lookup: dict[ + tuple[str, str, str], list[RepresentativeSelectionEntry] 
+ ] = {} + for entry in representative_entries: + representative_lookup.setdefault( + _representative_entry_key(entry), + [], + ).append(entry) + + errors: list[str] = [] + if expected_entries: + expected_keys = { + _distribution_entry_key(entry) for entry in expected_entries + } + missing_entries = [ + entry + for entry in expected_entries + if _distribution_entry_key(entry) not in representative_lookup + ] + duplicate_keys = [ + key + for key in sorted( + representative_lookup, + key=lambda item: ( + _natural_sort_key(item[0]), + _natural_sort_key(item[1]), + _natural_sort_key(item[2]), + ), + ) + if key in expected_keys and len(representative_lookup[key]) != 1 + ] + extra_entries = [ + entry + for key, entries in representative_lookup.items() + if key not in expected_keys + for entry in entries + ] + expected_issue_keys = set(expected_keys) + else: + missing_entries = [] + duplicate_keys = [ + key + for key in sorted( + representative_lookup, + key=lambda item: ( + _natural_sort_key(item[0]), + _natural_sort_key(item[1]), + _natural_sort_key(item[2]), + ), + ) + if len(representative_lookup[key]) != 1 + ] + extra_entries = [] + expected_issue_keys = set() + + if missing_entries: + errors.append( + "Missing representatives: " + + _format_coverage_labels( + _coverage_key_label(_distribution_entry_key(entry)) + for entry in missing_entries + ) + ) + if duplicate_keys: + errors.append( + "Duplicate representatives: " + + _format_coverage_labels( + ( + f"{_coverage_key_label(key)} x" + f"{len(representative_lookup[key])}" + ) + for key in duplicate_keys + ) + ) + if extra_entries: + errors.append( + "Representatives not present in the selected weight distribution: " + + _format_coverage_labels( + _representative_entry_label(entry) for entry in extra_entries + ) + ) + + if require_source_files: + missing_source_labels: list[str] = [] + for entry in representative_entries: + source_file = _optional_text(entry.source_file) + if source_file is None: + 
missing_source_labels.append( + f"{_representative_entry_label(entry)} has no source file" + ) + continue + source_path = Path(source_file).expanduser().resolve() + if not source_path.is_file(): + missing_source_labels.append( + f"{_representative_entry_label(entry)} -> {source_path}" + ) + if missing_source_labels: + errors.append( + "Representative source files not found: " + + _format_coverage_labels(missing_source_labels) + ) + + missing_issue_labels = [ + _representative_issue_label(issue) + for issue in metadata.missing_bins + if not expected_issue_keys + or _representative_issue_key(issue) in expected_issue_keys + ] + invalid_issue_labels = [ + _representative_issue_label(issue) + for issue in metadata.invalid_bins + if not expected_issue_keys + or _representative_issue_key(issue) in expected_issue_keys + ] + if missing_issue_labels: + errors.append( + "Representative selection still reports missing bins: " + + _format_coverage_labels(missing_issue_labels) + ) + if invalid_issue_labels: + errors.append( + "Representative selection still reports invalid bins: " + + _format_coverage_labels(invalid_issue_labels) + ) + + if errors: + raise ValueError( + "Representative structures do not match the selected weight " + "distribution. Save exactly one representative structure for " + "every positive weight before planning or building Packmol. 
" + + " ".join(errors) + ) + + def build_distribution_selection( project_source: RMCDreamProjectSource, selection: DreamBestFitSelection, @@ -810,6 +984,10 @@ def select_first_file_representatives( source_file_name=selected_file.name, atom_count=len(selected_elements), element_counts=dict(Counter(selected_elements)), + source_solvent_mode=_infer_representative_source_solvent_mode( + source_file=selected_file, + source_dir=source_dir or selected_file.parent, + ), ) ) processed_work += work_units @@ -1146,6 +1324,10 @@ def select_distribution_representatives( source_file_name=best_candidate.path.name, atom_count=best_candidate.atom_count, element_counts=dict(best_candidate.element_counts), + source_solvent_mode=_infer_representative_source_solvent_mode( + source_file=best_candidate.path, + source_dir=source_dir or best_candidate.path.parent, + ), analysis_source=analysis_source, score_total=score_total, score_bond=score_bond, @@ -1838,6 +2020,51 @@ def _natural_sort_key(value: str) -> list[object]: ] +def _representative_entry_key( + entry: RepresentativeSelectionEntry, +) -> tuple[str, str, str]: + return ( + str(entry.structure).strip(), + str(entry.motif).strip() or "no_motif", + str(entry.param).strip(), + ) + + +def _representative_issue_key( + issue: RepresentativeSelectionIssue, +) -> tuple[str, str, str]: + return ( + str(issue.structure).strip(), + str(issue.motif).strip() or "no_motif", + str(issue.param).strip(), + ) + + +def _coverage_key_label(key: tuple[str, str, str]) -> str: + structure, motif, param = key + if motif == "no_motif": + return f"{structure} ({param})" + return f"{structure}/{motif} ({param})" + + +def _representative_entry_label(entry: RepresentativeSelectionEntry) -> str: + return _coverage_key_label(_representative_entry_key(entry)) + + +def _representative_issue_label(issue: RepresentativeSelectionIssue) -> str: + return _coverage_key_label(_representative_issue_key(issue)) + + +def _format_coverage_labels(labels: Iterable[str]) -> 
str: + values = [str(label) for label in labels if str(label).strip()] + if not values: + return "none" + shown = values[:6] + if len(values) > len(shown): + shown.append(f"... +{len(values) - len(shown)} more") + return "; ".join(shown) + + def _optional_text(value: object) -> str | None: if value is None: return None @@ -1860,6 +2087,88 @@ def _definition_chunks(text: str) -> list[str]: ] +def normalize_representative_source_solvent_mode( + value: object, +) -> str: + text = str(value or "").strip().lower() + if text in _PROJECT_REPRESENTATIVE_SOURCE_SOLVENT_MODES: + return text + return "unknown" + + +def _infer_representative_source_solvent_mode( + *, + source_file: object = None, + source_dir: object = None, +) -> str: + for candidate in (source_file, source_dir): + text = _optional_text(candidate) + if text is None: + continue + parts = Path(text).expanduser().resolve().parts + for part in reversed(parts): + normalized = normalize_representative_source_solvent_mode(part) + if normalized != "unknown": + return normalized + return "unknown" + + +def representative_source_solvent_mode_to_variant( + value: object, +) -> str | None: + normalized = normalize_representative_source_solvent_mode(value) + if normalized == "nosolv": + return "no_solvent" + if normalized == "partialsolv": + return "partial_solvent" + if normalized == "fullsolv": + return "full_solvent" + return None + + +def representative_variant_to_source_solvent_mode( + value: object, +) -> str | None: + normalized = str(value or "").strip().lower() + if normalized in _PROJECT_REPRESENTATIVE_SOURCE_SOLVENT_MODES: + return normalized + return _SOURCE_SOLVENT_MODE_BY_VARIANT.get(normalized) + + +def representative_structure_variant_path( + source_file: str | Path, + variant: object, +) -> Path | None: + resolved_source = Path(source_file).expanduser().resolve() + target_mode = representative_variant_to_source_solvent_mode(variant) + if target_mode is None: + return None + if ( + 
normalize_representative_source_solvent_mode( + resolved_source.parent.parent.name + ) + == target_mode + and resolved_source.is_file() + ): + return resolved_source + if ( + normalize_representative_source_solvent_mode( + resolved_source.parent.parent.name + ) + == "unknown" + ): + return None + candidate = ( + resolved_source.parent.parent.parent + / target_mode + / resolved_source.parent.name + / resolved_source.name + ) + if candidate.is_file(): + return candidate.resolve() + return None + + __all__ = [ "DistributionSelectionEntry", "DistributionSelectionMetadata", @@ -1873,10 +2182,15 @@ def _definition_chunks(text: str) -> list[str]: "build_representative_preview_clusters", "load_distribution_selection_metadata", "load_representative_selection_metadata", + "normalize_representative_source_solvent_mode", "parse_angle_triplet_text", "parse_bond_pair_text", + "representative_source_solvent_mode_to_variant", + "representative_structure_variant_path", + "representative_variant_to_source_solvent_mode", "save_distribution_selection_metadata", "save_representative_selection_metadata", "select_distribution_representatives", "select_first_file_representatives", + "validate_representative_selection_covers_distribution", ] diff --git a/src/saxshell/fullrmc/solvent_handling.py b/src/saxshell/fullrmc/solvent_handling.py index 9d7ff5f..9206c33 100644 --- a/src/saxshell/fullrmc/solvent_handling.py +++ b/src/saxshell/fullrmc/solvent_handling.py @@ -3,13 +3,19 @@ import json import re from collections import Counter -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path +from typing import TYPE_CHECKING import numpy as np -from saxshell.fullrmc.representatives import RepresentativeSelectionMetadata +from saxshell.fullrmc.representatives import ( + RepresentativeSelectionEntry, + RepresentativeSelectionMetadata, + representative_source_solvent_mode_to_variant, + 
representative_structure_variant_path, +) from saxshell.saxs.debye import load_structure_file from saxshell.structure import PDBAtom, PDBStructure from saxshell.xyz2pdb import ( @@ -18,8 +24,9 @@ resolve_reference_path, ) -if False: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from .project_loader import RMCDreamProjectSource + from .solvent_shell_builder import SolventShellAnalysisResult _ANCHOR_ELEMENT_PRIORITY = ( "O", @@ -46,18 +53,71 @@ "I": 126.90447, "Pb": 207.2, } +_DEFAULT_REFERENCE_MATCH_TOLERANCE_A = 0.25 +_REPRESENTATIVE_STRUCTURE_MODE_LABELS = { + "source": "Selected representative source files", + "no_solvent": "No solvent", + "partial_solvent": "Partial solvent", + "full_solvent": "Full solvent", +} + + +@dataclass(slots=True) +class SoluteAtomBuildSetting: + coordination_center: bool = False + target_coordination_number: float = 0.0 + director_distance_cutoff_a: float = 2.5 + + def to_dict(self) -> dict[str, object]: + return { + "coordination_center": bool(self.coordination_center), + "target_coordination_number": float( + self.target_coordination_number + ), + "director_distance_cutoff_a": float( + self.director_distance_cutoff_a + ), + } + + @classmethod + def from_dict( + cls, + payload: dict[str, object] | None, + ) -> "SoluteAtomBuildSetting": + source = dict(payload or {}) + return cls( + coordination_center=bool(source.get("coordination_center", False)), + target_coordination_number=max( + 0.0, + _float_value(source.get("target_coordination_number"), 0.0), + ), + director_distance_cutoff_a=max( + 0.0, + _float_value(source.get("director_distance_cutoff_a"), 2.5), + ), + ) @dataclass(slots=True) class SolventHandlingSettings: - coordinated_solvent_mode: str = "no_coordinated_solvent" + coordinated_solvent_mode: str = "automatic_detection" reference_source: str = "preset" preset_name: str = "dmf" custom_reference_path: str | None = None + reference_match_tolerance_a: float = _DEFAULT_REFERENCE_MATCH_TOLERANCE_A + 
director_atom_name: str | None = None minimum_solvent_atom_separation_a: float = 1.2 + solute_atom_settings: dict[str, SoluteAtomBuildSetting] = field( + default_factory=dict + ) def to_dict(self) -> dict[str, object]: - return asdict(self) + payload = asdict(self) + payload["solute_atom_settings"] = { + str(element): setting.to_dict() + for element, setting in sorted(self.solute_atom_settings.items()) + } + return payload @classmethod def from_dict( @@ -69,10 +129,10 @@ def from_dict( coordinated_solvent_mode=str( source.get( "coordinated_solvent_mode", - "no_coordinated_solvent", + "automatic_detection", ) ).strip() - or "no_coordinated_solvent", + or "automatic_detection", reference_source=str( source.get("reference_source", "preset") ).strip() @@ -81,6 +141,16 @@ def from_dict( custom_reference_path=_optional_text( source.get("custom_reference_path") ), + reference_match_tolerance_a=max( + 0.0, + _float_value( + source.get("reference_match_tolerance_a"), + _DEFAULT_REFERENCE_MATCH_TOLERANCE_A, + ), + ), + director_atom_name=_optional_text( + source.get("director_atom_name") + ), minimum_solvent_atom_separation_a=max( 0.0, _float_value( @@ -88,6 +158,15 @@ def from_dict( 1.2, ), ), + solute_atom_settings={ + str(element): SoluteAtomBuildSetting.from_dict(dict(entry)) + for element, entry in ( + dict(source.get("solute_atom_settings", {})).items() + if isinstance(source.get("solute_atom_settings"), dict) + else [] + ) + if isinstance(entry, dict) + }, ) @@ -106,9 +185,22 @@ class SolventHandlingEntry: solvent_mode: str completion_strategy: str heuristic_note: str + detected_source_status: str = "unknown" + detected_complete_solvent_count: int = 0 + detected_partial_solvent_count: int = 0 + source_input_format: str = "" + matched_atom_count: int = 0 + unmatched_atom_count: int = 0 + solute_element_counts: dict[str, int] = field(default_factory=dict) + analysis_summary: str = "" + build_summary: str = "" def to_dict(self) -> dict[str, object]: - return asdict(self) 
+ payload = asdict(self) + payload["solute_element_counts"] = dict( + sorted(self.solute_element_counts.items()) + ) + return payload @classmethod def from_dict( @@ -133,8 +225,43 @@ def from_dict( payload.get("completion_strategy", "") ).strip(), heuristic_note=str(payload.get("heuristic_note", "")).strip(), + detected_source_status=str( + payload.get("detected_source_status", "unknown") + ).strip() + or "unknown", + detected_complete_solvent_count=int( + payload.get("detected_complete_solvent_count", 0) + ), + detected_partial_solvent_count=int( + payload.get("detected_partial_solvent_count", 0) + ), + source_input_format=str( + payload.get("source_input_format", "") + ).strip(), + matched_atom_count=int(payload.get("matched_atom_count", 0)), + unmatched_atom_count=int(payload.get("unmatched_atom_count", 0)), + solute_element_counts={ + str(element): int(count) + for element, count in ( + dict(payload.get("solute_element_counts", {})).items() + if isinstance(payload.get("solute_element_counts"), dict) + else [] + ) + }, + analysis_summary=str(payload.get("analysis_summary", "")).strip(), + build_summary=str(payload.get("build_summary", "")).strip(), ) + @property + def representative_label(self) -> str: + if self.motif == "no_motif": + return self.structure + return f"{self.structure}/{self.motif}" + + @property + def detected_source_status_text(self) -> str: + return _solvent_state_text(self.detected_source_status) + @dataclass(slots=True) class SolventHandlingMetadata: @@ -144,6 +271,9 @@ class SolventHandlingMetadata: reference_residue_name: str updated_at: str representative_selection_mode: str + detected_distribution_status: str + detected_distribution_note: str + aggregate_solute_element_counts: dict[str, int] entries: list[SolventHandlingEntry] def to_dict(self) -> dict[str, object]: @@ -157,6 +287,11 @@ def to_dict(self) -> dict[str, object]: "representative_selection_mode": ( self.representative_selection_mode ), + "detected_distribution_status": 
self.detected_distribution_status, + "detected_distribution_note": self.detected_distribution_note, + "aggregate_solute_element_counts": dict( + sorted(self.aggregate_solute_element_counts.items()) + ), "entries": [entry.to_dict() for entry in self.entries], } @@ -182,6 +317,26 @@ def from_dict( representative_selection_mode=str( payload.get("representative_selection_mode", "") ).strip(), + detected_distribution_status=str( + payload.get("detected_distribution_status", "unknown") + ).strip() + or "unknown", + detected_distribution_note=str( + payload.get("detected_distribution_note", "") + ).strip(), + aggregate_solute_element_counts={ + str(element): int(count) + for element, count in ( + dict( + payload.get("aggregate_solute_element_counts", {}) + ).items() + if isinstance( + payload.get("aggregate_solute_element_counts"), + dict, + ) + else [] + ) + }, entries=[ SolventHandlingEntry.from_dict(dict(entry)) for entry in payload.get("entries", []) @@ -190,27 +345,62 @@ def from_dict( ) def summary_text(self) -> str: + active_mode = resolved_representative_structure_mode( + representative_metadata=None, + solvent_metadata=self, + ) lines = [ - f"Coordinated solvent mode: {self.settings.coordinated_solvent_mode}", f"Reference source: {self.settings.reference_source}", f"Reference molecule: {self.reference_name}", f"Reference residue: {self.reference_residue_name}", + ( + "Active representative structure set: " + f"{representative_structure_mode_label(active_mode)}" + ), + ( + "Reference match tolerance: " + f"{self.settings.reference_match_tolerance_a:.3g} A" + ), + ( + "Director atom: " + f"{self.settings.director_atom_name or 'auto-selected'}" + ), ( "Minimum solvent atom separation: " f"{self.settings.minimum_solvent_atom_separation_a:.3g} A" ), + ( + "Detected representative distribution state: " + f"{_solvent_state_text(self.detected_distribution_status)}" + ), f"Saved at: {self.updated_at}", f"Representative entries exported: {len(self.entries)}", ] + if 
self.detected_distribution_note: + lines.append(self.detected_distribution_note) + if self.aggregate_solute_element_counts: + lines.append( + "Recognized solute elements: " + + ", ".join( + f"{element}:{count}" + for element, count in sorted( + self.aggregate_solute_element_counts.items() + ) + ) + ) if self.entries: first = self.entries[0] lines.extend( [ "", "Example exported representative:", - f" {first.structure}/{first.motif}", + f" {first.representative_label}", + ( + " detected source status: " + f"{first.detected_source_status_text}" + ), f" no-solvent PDB: {Path(first.no_solvent_pdb).name}", - f" completed PDB: {Path(first.completed_pdb).name}", + f" decorated PDB: {Path(first.completed_pdb).name}", f" solvent atoms added: {first.solvent_atoms_added}", f" strategy: {first.completion_strategy}", ] @@ -218,6 +408,75 @@ def summary_text(self) -> str: return "\n".join(lines) +@dataclass(slots=True) +class RepresentativeSolventAnalysisEntry: + structure: str + motif: str + param: str + source_file: str + source_status: str + analysis_result: "SolventShellAnalysisResult" + + @property + def representative_label(self) -> str: + if self.motif == "no_motif": + return self.structure + return f"{self.structure}/{self.motif}" + + @property + def source_status_text(self) -> str: + return _solvent_state_text(self.source_status) + + +@dataclass(slots=True) +class RepresentativeSolventDistributionAnalysis: + reference_name: str + reference_path: str + reference_residue_name: str + representative_selection_mode: str + match_tolerance_a: float + distribution_status: str + distribution_note: str + aggregate_solute_element_counts: dict[str, int] + entries: list[RepresentativeSolventAnalysisEntry] + + @property + def build_required(self) -> bool: + return self.distribution_status != "complete_solvent" + + def summary_text(self) -> str: + lines = [ + f"Reference molecule: {self.reference_name}", + f"Reference residue: {self.reference_residue_name}", + f"Reference match 
tolerance: {self.match_tolerance_a:.3g} A", + ( + "Detected representative distribution state: " + f"{_solvent_state_text(self.distribution_status)}" + ), + f"Representative entries analyzed: {len(self.entries)}", + ] + if self.distribution_note: + lines.append(self.distribution_note) + if self.aggregate_solute_element_counts: + lines.append( + "Recognized solute elements: " + + ", ".join( + f"{element}:{count}" + for element, count in sorted( + self.aggregate_solute_element_counts.items() + ) + ) + ) + if self.entries: + lines.extend(["", "Detected representative states:"]) + for entry in self.entries: + lines.append( + f" {entry.representative_label}: " + f"{entry.source_status_text}" + ) + return "\n".join(lines) + + @dataclass(slots=True) class GeneratedPDBResidueSummary: residue_name: str @@ -378,80 +637,511 @@ def list_solvent_reference_presets() -> list[object]: return list_reference_library(default_reference_library_dir()) -def build_representative_solvent_outputs( +def representative_structure_mode_label(mode: str) -> str: + normalized = str(mode).strip() or "source" + return _REPRESENTATIVE_STRUCTURE_MODE_LABELS.get( + normalized, + normalized.replace("_", " "), + ) + + +def representative_structure_entry_key( + entry: RepresentativeSelectionEntry | SolventHandlingEntry, +) -> tuple[str, str, str]: + return ( + str(getattr(entry, "structure", "")).strip(), + str(getattr(entry, "motif", "no_motif")).strip() or "no_motif", + str(getattr(entry, "param", "")).strip(), + ) + + +def solvent_entry_lookup_for_representatives( + representative_metadata: RepresentativeSelectionMetadata | None, + solvent_metadata: SolventHandlingMetadata | None, +) -> dict[tuple[str, str, str], SolventHandlingEntry]: + if representative_metadata is None or solvent_metadata is None: + return {} + lookup = { + representative_structure_entry_key(entry): entry + for entry in solvent_metadata.entries + } + expected_keys = { + representative_structure_entry_key(entry) + for entry in 
representative_metadata.representative_entries + } + if not expected_keys or not expected_keys.issubset(lookup): + return {} + return {key: lookup[key] for key in expected_keys} + + +def _uniform_source_variant_mode( + representative_entries: list[object], +) -> str | None: + if not representative_entries: + return None + source_paths = [ + Path(entry.source_file).expanduser().resolve() + for entry in representative_entries + if str(entry.source_file).strip() + ] + if len(source_paths) != len(representative_entries) or not all( + path.is_file() for path in source_paths + ): + return None + variants = { + _representative_entry_source_variant(entry) + for entry in representative_entries + } + variants.discard(None) + if len(variants) == 1: + return next(iter(variants)) + return None + + +def _uniform_available_source_variants( + representative_entries: list[object], +) -> list[str]: + if not representative_entries: + return [] + mode_order = ("no_solvent", "partial_solvent", "full_solvent") + available_by_mode: list[str] = [] + for mode in mode_order: + if all( + _representative_entry_variant_path(entry, mode) is not None + for entry in representative_entries + ): + available_by_mode.append(mode) + return available_by_mode + + +def _representative_entry_variant_path( + entry: object, + mode: str, +) -> Path | None: + source_file = getattr(entry, "source_file", None) + if not str(source_file or "").strip(): + return None + return representative_structure_variant_path(source_file, mode) + + +def _representative_entry_source_variant(entry: object) -> str | None: + try: + source_solvent_mode = getattr(entry, "source_solvent_mode") + except AttributeError: + source_solvent_mode = None + variant = representative_source_solvent_mode_to_variant( + source_solvent_mode + ) + if variant is not None: + return variant + try: + detected_source_status = str( + getattr(entry, "detected_source_status") + ).strip() + except AttributeError: + detected_source_status = "" + if 
detected_source_status == "complete_solvent": + return "full_solvent" + if detected_source_status in {"partial_solvent", "no_solvent"}: + return detected_source_status + return None + + +def available_representative_structure_modes( + representative_metadata: RepresentativeSelectionMetadata | None, + solvent_metadata: SolventHandlingMetadata | None, +) -> list[str]: + representative_entries = ( + list(representative_metadata.representative_entries) + if representative_metadata is not None + else [] + ) + if not representative_entries: + representative_entries = ( + list(solvent_metadata.entries) + if solvent_metadata is not None + else [] + ) + if not representative_entries: + return [] + source_variant = _uniform_source_variant_mode(representative_entries) + + if representative_metadata is None or solvent_metadata is None: + if solvent_metadata is None: + available = _uniform_available_source_variants( + representative_entries + ) + if source_variant is not None and source_variant not in available: + available.append(source_variant) + return available or ["source"] + ordered_entries = list(solvent_metadata.entries) + else: + lookup = solvent_entry_lookup_for_representatives( + representative_metadata, + solvent_metadata, + ) + if not lookup: + return ( + [source_variant] if source_variant is not None else ["source"] + ) + ordered_entries = list(lookup.values()) + + available: list[str] = [] + if all( + Path(entry.no_solvent_pdb).expanduser().is_file() + for entry in ordered_entries + ): + available.append("no_solvent") + if ( + solvent_metadata.detected_distribution_status == "partial_solvent" + and all( + Path(entry.source_file).expanduser().is_file() + for entry in representative_entries + ) + ): + available.append("partial_solvent") + if all( + Path(entry.completed_pdb).expanduser().is_file() + for entry in ordered_entries + ): + available.append("full_solvent") + if source_variant is not None and source_variant not in available: + 
available.append(source_variant) + return available or ["source"] + + +def resolved_representative_structure_mode( + representative_metadata: RepresentativeSelectionMetadata | None, + solvent_metadata: SolventHandlingMetadata | None, + *, + preferred_mode: str | None = None, +) -> str: + available = available_representative_structure_modes( + representative_metadata, + solvent_metadata, + ) + if not available: + return "source" + requested = str( + preferred_mode + if preferred_mode is not None + else ( + solvent_metadata.settings.coordinated_solvent_mode + if solvent_metadata is not None + else "source" + ) + ).strip() + if requested in available: + return requested + for candidate in ( + "full_solvent", + "partial_solvent", + "no_solvent", + "source", + ): + if candidate in available: + return candidate + return available[0] + + +def representative_structure_mode_is_ready( + representative_metadata: RepresentativeSelectionMetadata | None, + solvent_metadata: SolventHandlingMetadata | None, +) -> bool: + return ( + resolved_representative_structure_mode( + representative_metadata, + solvent_metadata, + ) + == "full_solvent" + ) + + +def representative_structure_path_for_mode( + representative_entry: RepresentativeSelectionEntry, + solvent_entry: SolventHandlingEntry | None, + mode: str, +) -> Path: + normalized = str(mode).strip() or "source" + if normalized == "no_solvent" and solvent_entry is not None: + return Path(solvent_entry.no_solvent_pdb).expanduser().resolve() + if normalized == "full_solvent" and solvent_entry is not None: + return Path(solvent_entry.completed_pdb).expanduser().resolve() + mirrored_variant_path = _representative_entry_variant_path( + representative_entry, + normalized, + ) + if mirrored_variant_path is not None: + return mirrored_variant_path + if normalized == representative_source_solvent_mode_to_variant( + representative_entry.source_solvent_mode + ): + return Path(representative_entry.source_file).expanduser().resolve() + return 
Path(representative_entry.source_file).expanduser().resolve() + + +def analyze_representative_solvent_distribution( project_source: "RMCDreamProjectSource", settings: SolventHandlingSettings, *, representative_metadata: RepresentativeSelectionMetadata | None = None, -) -> SolventHandlingMetadata: +) -> RepresentativeSolventDistributionAnalysis: metadata = ( representative_metadata or project_source.representative_selection ) if metadata is None or not metadata.representative_entries: raise ValueError( - "Compute representative clusters before building solvent-aware representative PDBs." + "Save representative structures before analyzing representative solvent states." ) + from .solvent_shell_builder import analyze_solvent_shell + + reference_identifier = _reference_identifier(settings) reference_path = _resolve_reference_path(settings) reference_structure = PDBStructure.from_file(reference_path) if not reference_structure.atoms: raise ValueError( f"The solvent reference PDB has no atoms: {reference_path}" ) - reference_name = reference_path.stem - reference_residue = reference_structure.atoms[0].residue_name or "SOL" - source_kind_lookup = _representative_source_kind_lookup(metadata) - entries: list[SolventHandlingEntry] = [] + entries: list[RepresentativeSolventAnalysisEntry] = [] + aggregate_solute_counts: Counter[str] = Counter() for representative_entry in metadata.representative_entries: - key = ( - representative_entry.structure, - representative_entry.motif, - representative_entry.param, + analysis_result = analyze_solvent_shell( + representative_entry.source_file, + reference_identifier, + reference_match_tolerance_a=settings.reference_match_tolerance_a, + ) + source_status = _classify_source_solvent_status(analysis_result) + entries.append( + RepresentativeSolventAnalysisEntry( + structure=representative_entry.structure, + motif=representative_entry.motif, + param=representative_entry.param, + source_file=representative_entry.source_file, + 
source_status=source_status, + analysis_result=analysis_result, + ) + ) + aggregate_solute_counts.update(analysis_result.solute_element_counts) + distribution_status, distribution_note = _resolve_distribution_status( + entries + ) + return RepresentativeSolventDistributionAnalysis( + reference_name=reference_path.stem, + reference_path=str(reference_path), + reference_residue_name=reference_structure.atoms[0].residue_name + or "SOL", + representative_selection_mode=metadata.selection_mode, + match_tolerance_a=settings.reference_match_tolerance_a, + distribution_status=distribution_status, + distribution_note=distribution_note, + aggregate_solute_element_counts=dict( + sorted(aggregate_solute_counts.items()) + ), + entries=entries, + ) + + +def build_representative_solvent_outputs( + project_source: "RMCDreamProjectSource", + settings: SolventHandlingSettings, + *, + representative_metadata: RepresentativeSelectionMetadata | None = None, + distribution_analysis: ( + RepresentativeSolventDistributionAnalysis | None + ) = None, +) -> SolventHandlingMetadata: + metadata = ( + representative_metadata or project_source.representative_selection + ) + if metadata is None or not metadata.representative_entries: + raise ValueError( + "Save representative structures before building " + "solvent-decorated representative PDBs." 
+ ) + analysis = ( + distribution_analysis + or analyze_representative_solvent_distribution( + project_source, + settings, + representative_metadata=metadata, + ) + ) + reference_path = _resolve_reference_path(settings) + reference_structure = PDBStructure.from_file(reference_path) + if not reference_structure.atoms: + raise ValueError( + f"The solvent reference PDB has no atoms: {reference_path}" ) - source_kind = source_kind_lookup.get(key) + reference_identifier = _reference_identifier(settings) + director_atom_name = ( + settings.director_atom_name + or _default_director_atom_name_for_settings(settings) + ) + if not director_atom_name: + raise ValueError( + "Select a solvent director atom before building representative solvent outputs." + ) + + entries: list[SolventHandlingEntry] = [] + for representative_entry, analysis_entry in zip( + metadata.representative_entries, + analysis.entries, + strict=True, + ): cluster_structure = _load_cluster_structure_as_pdb( representative_entry.source_file, structure_label=representative_entry.structure, ) - no_solvent_path = ( - project_source.rmcsetup_paths.pdb_no_solvent_dir - / _representative_pdb_name(representative_entry) + no_solvent_structure = _strip_detected_solvent_atoms( + cluster_structure, + analysis_entry.analysis_result, + ) + no_solvent_path = _representative_pdb_output_path( + project_source.rmcsetup_paths.pdb_no_solvent_dir, + representative_entry, ) - cluster_structure.write_pdb_file(no_solvent_path) + no_solvent_structure.write_pdb_file(no_solvent_path) - completed_structure = PDBStructure( - atoms=[atom.copy() for atom in cluster_structure.atoms], - source_name=cluster_structure.source_name, + completed_path = _representative_pdb_output_path( + project_source.rmcsetup_paths.pdb_with_solvent_dir, + representative_entry, ) + completed_structure: PDBStructure solvent_atoms_added = 0 solvent_molecules_added = 0 - completion_strategy = "copied_without_addition" - if source_kind == "single_structure_file": - 
completion_strategy = "preserved_single_structure_file" - elif ( - settings.coordinated_solvent_mode == "partial_coordinated_solvent" - ): - ( - completed_structure, - solvent_atoms_added, - solvent_molecules_added, - completion_strategy, - ) = _build_partial_coordinated_solvent_structure( - cluster_structure, - reference_structure, - minimum_atom_separation_a=( + completion_strategy = "" + build_summary = "" + if analysis.distribution_status == "complete_solvent": + completed_structure = _decorated_source_structure_as_pdb( + representative_entry.source_file, + structure_label=representative_entry.structure, + analysis_result=analysis_entry.analysis_result, + reference_residue_name=analysis.reference_residue_name, + ) + completed_structure.write_pdb_file(completed_path) + solvent_molecules_added = int( + analysis_entry.analysis_result.complete_solvent_molecule_count + ) + solvent_atoms_added = max( + len(completed_structure.atoms) + - len(no_solvent_structure.atoms), + 0, + ) + completion_strategy = "preserved_detected_complete_solvent" + build_summary = ( + "Completed representative PDB was passed through because " + "all representative structures already contained complete " + "solvent molecules." 
+ ) + else: + from .solvent_shell_builder import build_solvent_shell_output + + build_input_path = no_solvent_path + build_analysis_result = None + solute_distance_cutoffs = _solute_distance_cutoffs_for_analysis( + settings, + analysis_entry.analysis_result, + ) + coordinating_center_elements = ( + _coordinating_center_elements_for_analysis( + settings, + analysis_entry.analysis_result, + ) + ) + target_coordination_numbers = ( + _target_coordination_numbers_for_analysis( + settings, + analysis_entry.analysis_result, + ) + ) + if analysis.distribution_status == "partial_solvent": + build_input_path = Path(representative_entry.source_file) + build_analysis_result = analysis_entry.analysis_result + elif ( + not coordinating_center_elements + or not target_coordination_numbers + ): + completed_structure = PDBStructure( + atoms=[atom.copy() for atom in no_solvent_structure.atoms], + source_name=no_solvent_structure.source_name, + ) + completed_structure.write_pdb_file(completed_path) + completion_strategy = ( + "preserved_without_matching_coordination_settings" + ) + build_summary = ( + "No matching coordination-center settings were " + "selected for this representative structure, so the " + "stripped no-solvent PDB was preserved without adding " + "solvent molecules." 
+ ) + entries.append( + SolventHandlingEntry( + structure=representative_entry.structure, + motif=representative_entry.motif, + param=representative_entry.param, + source_file=representative_entry.source_file, + no_solvent_pdb=str(no_solvent_path), + completed_pdb=str(completed_path), + atom_count_no_solvent=len(no_solvent_structure.atoms), + atom_count_completed=len(completed_structure.atoms), + solvent_atoms_added=0, + solvent_molecules_added=0, + solvent_mode=analysis.distribution_status, + completion_strategy=completion_strategy, + heuristic_note=analysis_entry.analysis_result.cluster_solvent_status_text, + detected_source_status=analysis_entry.source_status, + detected_complete_solvent_count=int( + analysis_entry.analysis_result.complete_solvent_molecule_count + ), + detected_partial_solvent_count=int( + analysis_entry.analysis_result.partial_solvent_molecule_count + ), + source_input_format=analysis_entry.analysis_result.input_format, + matched_atom_count=int( + analysis_entry.analysis_result.matched_atom_count + ), + unmatched_atom_count=int( + analysis_entry.analysis_result.unmatched_atom_count + ), + solute_element_counts=dict( + sorted( + analysis_entry.analysis_result.solute_element_counts.items() + ) + ), + analysis_summary=analysis_entry.analysis_result.summary_text(), + build_summary=build_summary, + ) + ) + continue + build_result = build_solvent_shell_output( + build_input_path, + reference_identifier, + output_path=completed_path, + director_atom_name=director_atom_name, + minimum_solvent_atom_separation_a=( settings.minimum_solvent_atom_separation_a ), + solute_distance_cutoffs_a=solute_distance_cutoffs, + coordinating_center_elements=coordinating_center_elements, + target_average_coordination_numbers=target_coordination_numbers, + reference_match_tolerance_a=( + settings.reference_match_tolerance_a + ), + analysis_result=build_analysis_result, ) - elif settings.coordinated_solvent_mode == "full_coordinated_solvent": - completion_strategy = 
"assumed_source_already_complete" - - completed_path = ( - project_source.rmcsetup_paths.pdb_with_solvent_dir - / _representative_pdb_name(representative_entry) - ) - completed_structure.write_pdb_file(completed_path) + completed_structure = PDBStructure.from_file(completed_path) + solvent_atoms_added = int(build_result.solvent_atoms_added) + solvent_molecules_added = int(build_result.solvent_molecules_added) + completion_strategy = str(build_result.build_mode) + if ( + analysis.distribution_status == "no_solvent" + and analysis.distribution_note + ): + completion_strategy = "rebuilt_from_no_solvent_distribution" + build_summary = build_result.summary_text() entries.append( SolventHandlingEntry( @@ -461,26 +1151,51 @@ def build_representative_solvent_outputs( source_file=representative_entry.source_file, no_solvent_pdb=str(no_solvent_path), completed_pdb=str(completed_path), - atom_count_no_solvent=len(cluster_structure.atoms), + atom_count_no_solvent=len(no_solvent_structure.atoms), atom_count_completed=len(completed_structure.atoms), solvent_atoms_added=solvent_atoms_added, solvent_molecules_added=solvent_molecules_added, - solvent_mode=settings.coordinated_solvent_mode, + solvent_mode=analysis.distribution_status, completion_strategy=completion_strategy, - heuristic_note=_solvent_heuristic_note( - cluster_structure, - reference_structure, + heuristic_note=analysis_entry.analysis_result.cluster_solvent_status_text, + detected_source_status=analysis_entry.source_status, + detected_complete_solvent_count=int( + analysis_entry.analysis_result.complete_solvent_molecule_count + ), + detected_partial_solvent_count=int( + analysis_entry.analysis_result.partial_solvent_molecule_count ), + source_input_format=analysis_entry.analysis_result.input_format, + matched_atom_count=int( + analysis_entry.analysis_result.matched_atom_count + ), + unmatched_atom_count=int( + analysis_entry.analysis_result.unmatched_atom_count + ), + solute_element_counts=dict( + sorted( + 
analysis_entry.analysis_result.solute_element_counts.items() + ) + ), + analysis_summary=analysis_entry.analysis_result.summary_text(), + build_summary=build_summary, ) ) + saved_settings = SolventHandlingSettings.from_dict(settings.to_dict()) + saved_settings.coordinated_solvent_mode = "full_solvent" solvent_metadata = SolventHandlingMetadata( - settings=settings, + settings=saved_settings, reference_path=str(reference_path), - reference_name=reference_name, - reference_residue_name=reference_residue, + reference_name=analysis.reference_name, + reference_residue_name=analysis.reference_residue_name, updated_at=datetime.now().isoformat(timespec="seconds"), representative_selection_mode=metadata.selection_mode, + detected_distribution_status=analysis.distribution_status, + detected_distribution_note=analysis.distribution_note, + aggregate_solute_element_counts=dict( + sorted(analysis.aggregate_solute_element_counts.items()) + ), entries=entries, ) save_solvent_handling_metadata( @@ -503,6 +1218,189 @@ def _representative_source_kind_lookup( } +def _reference_identifier(settings: SolventHandlingSettings) -> str: + if ( + settings.reference_source == "custom" + and settings.custom_reference_path + ): + return str(Path(settings.custom_reference_path).expanduser().resolve()) + return settings.preset_name + + +def _default_director_atom_name_for_settings( + settings: SolventHandlingSettings, +) -> str | None: + from .solvent_shell_builder import default_director_atom_name + + return default_director_atom_name(_reference_identifier(settings)) + + +def _classify_source_solvent_status( + analysis_result: "SolventShellAnalysisResult", +) -> str: + if analysis_result.complete_solvent_molecule_count > 0: + if analysis_result.partial_solvent_molecule_count > 0: + return "mixed_complete_and_partial" + return "complete_solvent" + if analysis_result.partial_solvent_molecule_count > 0: + return "partial_solvent" + return "no_solvent" + + +def _resolve_distribution_status( + 
entries: list[RepresentativeSolventAnalysisEntry], +) -> tuple[str, str]: + statuses = { + entry.source_status + for entry in entries + if str(entry.source_status).strip() + } + if not statuses: + return ( + "unknown", + "No representative solvent states were available.", + ) + if statuses == {"complete_solvent"}: + return ( + "complete_solvent", + "Every representative structure already contains complete solvent molecules, so the existing solvent-decorated structures can be passed through.", + ) + if statuses == {"partial_solvent"}: + return ( + "partial_solvent", + "Every representative structure contains partial solvent molecules, so the saved anchors will be used to rebuild complete solvent molecules.", + ) + if statuses == {"no_solvent"}: + return ( + "no_solvent", + "No representative structure contains coordinated solvent molecules, so solvent shells will be built from the stripped solute structures.", + ) + return ( + "no_solvent", + "Representative solvent detections were inconsistent across the saved structures. 
Following the conservative workflow rule, the current cluster distribution is treated as having no coordinated solvent.", + ) + + +def _solvent_state_text(status: str) -> str: + mapping = { + "complete_solvent": "Complete solvent molecules detected", + "partial_solvent": "Partial solvent molecules detected", + "no_solvent": "No solvent molecules detected", + "mixed_complete_and_partial": ( + "Complete and partial solvent molecules detected" + ), + "unknown": "Unknown solvent state", + } + return mapping.get(str(status).strip(), str(status).replace("_", " ")) + + +def _strip_detected_solvent_atoms( + structure: PDBStructure, + analysis_result: "SolventShellAnalysisResult", +) -> PDBStructure: + stripped_atom_ids = { + int(atom_id) + for atom_id in analysis_result.complete_solvent_source_atom_ids + } + stripped_atom_ids.update( + int(atom_id) + for atom_id in analysis_result.partial_solvent_source_atom_ids + ) + stripped_atoms = [ + atom.copy() + for atom in structure.atoms + if int(atom.atom_id) not in stripped_atom_ids + ] + for index, atom in enumerate(stripped_atoms, start=1): + atom.atom_id = index + return PDBStructure( + atoms=stripped_atoms, + source_name=structure.source_name, + ) + + +def _decorated_source_structure_as_pdb( + source_file: str | Path, + *, + structure_label: str, + analysis_result: "SolventShellAnalysisResult", + reference_residue_name: str, +) -> PDBStructure: + structure = _load_cluster_structure_as_pdb( + source_file, + structure_label=structure_label, + ) + path = Path(source_file).expanduser().resolve() + if path.suffix.lower() == ".pdb": + return structure + + group_lookup = { + int(atom_id): group_index + for group_index, atom_group in enumerate( + analysis_result.complete_solvent_source_atom_groups, + start=1, + ) + for atom_id in atom_group + } + solute_residue_name = _normalized_residue_name(structure_label) + next_residue_number = 2 + solvent_residue_numbers: dict[int, int] = {} + for atom in structure.atoms: + group_index = 
group_lookup.get(int(atom.atom_id)) + if group_index is None: + atom.residue_name = solute_residue_name + atom.residue_number = 1 + continue + residue_number = solvent_residue_numbers.get(group_index) + if residue_number is None: + residue_number = next_residue_number + solvent_residue_numbers[group_index] = residue_number + next_residue_number += 1 + atom.residue_name = reference_residue_name + atom.residue_number = residue_number + return structure + + +def _solute_distance_cutoffs_for_analysis( + settings: SolventHandlingSettings, + analysis_result: "SolventShellAnalysisResult", +) -> dict[str, float]: + return { + str(element): float(setting.director_distance_cutoff_a) + for element, setting in settings.solute_atom_settings.items() + if element in analysis_result.solute_element_counts + and float(setting.director_distance_cutoff_a) > 0.0 + } + + +def _coordinating_center_elements_for_analysis( + settings: SolventHandlingSettings, + analysis_result: "SolventShellAnalysisResult", +) -> tuple[str, ...]: + return tuple( + sorted( + element + for element, setting in settings.solute_atom_settings.items() + if element in analysis_result.solute_element_counts + and setting.coordination_center + and float(setting.target_coordination_number) > 0.0 + ) + ) + + +def _target_coordination_numbers_for_analysis( + settings: SolventHandlingSettings, + analysis_result: "SolventShellAnalysisResult", +) -> dict[str, float]: + return { + str(element): float(setting.target_coordination_number) + for element, setting in settings.solute_atom_settings.items() + if element in analysis_result.solute_element_counts + and setting.coordination_center + and float(setting.target_coordination_number) > 0.0 + } + + def save_solvent_handling_metadata( output_path: str | Path, metadata: SolventHandlingMetadata, @@ -1094,6 +1992,20 @@ def _representative_pdb_name(representative_entry: object) -> str: return f"{structure}_{motif}_{source_name}.pdb" +def _representative_pdb_output_path( + root_dir: 
str | Path, + representative_entry: object, +) -> Path: + structure = _safe_name( + getattr(representative_entry, "structure", "cluster") + ) + return ( + Path(root_dir).expanduser().resolve() + / structure + / _representative_pdb_name(representative_entry) + ) + + def _safe_name(text: str) -> str: collapsed = re.sub(r"[^0-9A-Za-z]+", "_", str(text).strip()) collapsed = re.sub(r"_+", "_", collapsed).strip("_") @@ -1245,9 +2157,13 @@ def _float_value(value: object, default: float) -> float: __all__ = [ "GeneratedPDBInspection", "GeneratedPDBResidueSummary", + "RepresentativeSolventAnalysisEntry", + "RepresentativeSolventDistributionAnalysis", + "SoluteAtomBuildSetting", "SolventHandlingEntry", "SolventHandlingMetadata", "SolventHandlingSettings", + "analyze_representative_solvent_distribution", "build_generated_pdb_inspections", "build_representative_solvent_outputs", "list_solvent_reference_presets", diff --git a/src/saxshell/fullrmc/solvent_shell_builder.py b/src/saxshell/fullrmc/solvent_shell_builder.py new file mode 100644 index 0000000..f1225c6 --- /dev/null +++ b/src/saxshell/fullrmc/solvent_shell_builder.py @@ -0,0 +1,2931 @@ +from __future__ import annotations + +from collections import Counter, defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Sequence + +import numpy as np + +from saxshell.fullrmc.solvent_handling import ( + _best_solvent_orientation, + _next_residue_number, + _normalize_vector, + _normalized_residue_name, + _rotate_points, + _rotation_between_vectors, + _weighted_center, +) +from saxshell.saxs.debye import load_structure_file +from saxshell.structure import PDBAtom, PDBStructure +from saxshell.xyz2pdb import ( + AnchorPairDefinition, + MoleculeDefinition, + ReferenceLibraryEntry, + XYZToPDBConfiguration, + XYZToPDBWorkflow, + default_reference_library_dir, + list_reference_library, +) +from saxshell.xyz2pdb.workflow import _normalized_atom_name + +_DEFAULT_REFERENCE_MATCH_TOLERANCE_A = 0.25 
+DEFAULT_REFERENCE_MATCH_TOLERANCE_A = _DEFAULT_REFERENCE_MATCH_TOLERANCE_A + + +@dataclass(slots=True) +class SolventShellResidueSummary: + residue_name: str + molecule_count: int + residue_numbers: tuple[int, ...] + atom_count: int + element_counts: dict[str, int] + + @property + def residue_numbers_text(self) -> str: + if not self.residue_numbers: + return "n/a" + return ", ".join(str(number) for number in self.residue_numbers) + + @property + def element_counts_text(self) -> str: + if not self.element_counts: + return "none" + return ", ".join( + f"{element}:{count}" + for element, count in sorted(self.element_counts.items()) + ) + + +@dataclass(slots=True) +class SolventShellResidueMismatchSummary: + residue_name: str + residue_number: int + observed_atom_count: int + common_atom_count: int + reference_atom_count: int + missing_atom_names: tuple[str, ...] + extra_atom_names: tuple[str, ...] + distance_pair_count: int + distribution_rmsd_a: float + max_distance_delta_a: float + mismatch_reason: str + source_atom_ids: tuple[int, ...] 
= () + + @property + def residue_label(self) -> str: + return f"{self.residue_name} {self.residue_number}" + + @property + def missing_atom_names_text(self) -> str: + if not self.missing_atom_names: + return "none" + return ", ".join(self.missing_atom_names) + + @property + def extra_atom_names_text(self) -> str: + if not self.extra_atom_names: + return "none" + return ", ".join(self.extra_atom_names) + + @property + def matched_atom_ratio_text(self) -> str: + return f"{self.common_atom_count}/{self.reference_atom_count}" + + @property + def source_atom_ids_text(self) -> str: + if not self.source_atom_ids: + return "n/a" + return ", ".join(str(value) for value in self.source_atom_ids) + + +@dataclass(slots=True) +class SolventShellAnalysisResult: + input_path: Path + input_format: str + reference_name: str + reference_path: Path + reference_residue_name: str + reference_atom_count: int + detected_solvent_molecules: int + matched_atom_count: int + unmatched_atom_count: int + total_atoms: int + match_tolerance_a: float + solute_element_counts: dict[str, int] + complete_solvent_source_atom_ids: tuple[int, ...] = () + complete_solvent_source_atom_groups: tuple[tuple[int, ...], ...] = () + partial_solvent_source_atom_ids: tuple[int, ...] = () + matched_residue_summaries: tuple[SolventShellResidueSummary, ...] = () + residue_mismatch_summaries: tuple[ + SolventShellResidueMismatchSummary, ... + ] = () + notes: tuple[str, ...] 
= () + + @property + def has_solvent_molecules(self) -> bool: + return ( + self.complete_solvent_molecule_count > 0 + or self.partial_solvent_molecule_count > 0 + ) + + @property + def complete_solvent_molecule_count(self) -> int: + return int(self.detected_solvent_molecules) + + @property + def partial_solvent_molecule_count(self) -> int: + return int(len(self.residue_mismatch_summaries)) + + @property + def partial_solvent_status_supported(self) -> bool: + return True + + @property + def complete_solvent_status_text(self) -> str: + return "yes" if self.complete_solvent_molecule_count > 0 else "no" + + @property + def partial_solvent_status_text(self) -> str: + return "yes" if self.partial_solvent_molecule_count > 0 else "no" + + @property + def no_solvent_status_text(self) -> str: + return "yes" if not self.has_solvent_molecules else "no" + + @property + def cluster_solvent_status_text(self) -> str: + if self.complete_solvent_molecule_count > 0: + if self.partial_solvent_molecule_count > 0: + return "Complete and partial solvent molecules detected." + return "Complete solvent molecules detected." + if self.partial_solvent_molecule_count > 0: + return "Partial solvent molecules detected." + return "No solvent molecules detected." 
+ + @property + def solvent_presence_text(self) -> str: + return "yes" if self.has_solvent_molecules else "no" + + @property + def solute_elements_text(self) -> str: + if not self.solute_element_counts: + return "none" + return ", ".join( + f"{element}:{count}" + for element, count in sorted(self.solute_element_counts.items()) + ) + + def status_statistics_text(self) -> str: + lines = [ + f"No solvent molecules: {self.no_solvent_status_text}", + f"Partial solvent molecules: {self.partial_solvent_status_text}", + f"Complete solvent molecules: {self.complete_solvent_status_text}", + f"Complete solvent count: {self.complete_solvent_molecule_count}", + ] + if self.input_format == "pdb": + lines.append( + "Partial solvent residue count: " + f"{self.partial_solvent_molecule_count}" + ) + else: + lines.append( + "Partial solvent candidate count: " + f"{self.partial_solvent_molecule_count}" + ) + lines.extend( + [ + f"Recognized solute elements: {self.solute_elements_text}", + f"Matched atoms: {self.matched_atom_count}/{self.total_atoms}", + f"Unmatched atoms: {self.unmatched_atom_count}", + ] + ) + return "\n".join(lines) + + def summary_text(self) -> str: + lines = [ + f"Input file: {self.input_path}", + f"Input format: {self.input_format.upper()}", + f"Reference molecule: {self.reference_name}", + f"Reference residue: {self.reference_residue_name}", + f"Reference atom count: {self.reference_atom_count}", + f"Reference match tolerance: {self.match_tolerance_a:.3g} A", + f"Total atoms: {self.total_atoms}", + f"Cluster solvent status: {self.cluster_solvent_status_text}", + f"No solvent molecules: {self.no_solvent_status_text}", + f"Partial solvent molecules: {self.partial_solvent_status_text}", + f"Complete solvent molecules: {self.complete_solvent_status_text}", + f"Complete solvent count: {self.complete_solvent_molecule_count}", + f"Solvent molecules detected: {self.detected_solvent_molecules}", + f"Solvent present: {self.solvent_presence_text}", + f"Recognized solute 
elements: {self.solute_elements_text}", + f"Matched atoms: {self.matched_atom_count}", + f"Unmatched atoms: {self.unmatched_atom_count}", + ] + if self.input_format == "pdb": + lines.append( + "Partial solvent residue count: " + f"{self.partial_solvent_molecule_count}" + ) + lines.append( + "Matched residue types: " + f"{len(self.matched_residue_summaries)}" + ) + lines.append( + "Residue mismatches preserved: " + f"{len(self.residue_mismatch_summaries)}" + ) + else: + lines.append( + "Partial solvent candidate count: " + f"{self.partial_solvent_molecule_count}" + ) + lines.append("Matched residue types: n/a for XYZ inputs") + lines.append( + "Partial solvent candidates inferred: " + f"{len(self.residue_mismatch_summaries)}" + ) + if self.matched_residue_summaries: + lines.extend( + [ + "", + "PDB residue matches:", + ] + ) + for summary in self.matched_residue_summaries: + lines.append( + f" {summary.residue_name}: {summary.molecule_count} " + f"molecule(s) in residue(s) {summary.residue_numbers_text}" + ) + if self.residue_mismatch_summaries: + lines.extend( + [ + "", + ( + "PDB residue mismatches:" + if self.input_format == "pdb" + else "XYZ partial solvent candidates:" + ), + ] + ) + for summary in self.residue_mismatch_summaries: + detail_parts = [ + f"matched {summary.matched_atom_ratio_text} reference atom(s)", + ] + if summary.missing_atom_names: + detail_parts.append( + f"missing {summary.missing_atom_names_text}" + ) + if summary.extra_atom_names: + detail_parts.append( + f"extra {summary.extra_atom_names_text}" + ) + if summary.source_atom_ids: + detail_parts.append( + f"source atom ids {summary.source_atom_ids_text}" + ) + lines.append( + f" {summary.residue_label}: {summary.mismatch_reason}; " + + ", ".join(detail_parts) + ) + if self.notes: + lines.extend(["", "Notes:"]) + lines.extend(f" - {note}" for note in self.notes) + return "\n".join(lines) + + +@dataclass(slots=True) +class SolventShellBuildResult: + input_path: Path + output_path: Path + 
input_format: str + reference_name: str + reference_residue_name: str + director_atom_name: str + build_mode: str + solvent_molecules_added: int + solvent_atoms_added: int + partial_candidates_completed: int + replaced_source_atom_count: int + minimum_solvent_atom_separation_a: float + solute_distance_cutoffs_a: dict[str, float] + coordinating_center_elements: tuple[str, ...] = () + target_average_coordination_numbers: dict[str, float] | None = None + achieved_average_coordination_numbers: dict[str, float] | None = None + + def summary_text(self) -> str: + cutoff_text = ( + ", ".join( + f"{element}:{distance:.3g}" + for element, distance in sorted( + self.solute_distance_cutoffs_a.items() + ) + ) + if self.solute_distance_cutoffs_a + else "none" + ) + center_text = ( + ", ".join(self.coordinating_center_elements) + if self.coordinating_center_elements + else "none" + ) + target_coordination_text = ( + ", ".join( + f"{element}:{value:.3g}" + for element, value in sorted( + (self.target_average_coordination_numbers or {}).items() + ) + ) + if self.target_average_coordination_numbers + else "none" + ) + achieved_coordination_text = ( + ", ".join( + f"{element}:{value:.3g}" + for element, value in sorted( + (self.achieved_average_coordination_numbers or {}).items() + ) + ) + if self.achieved_average_coordination_numbers + else "none" + ) + return "\n".join( + [ + f"Output file: {self.output_path}", + f"Build mode: {self.build_mode}", + f"Input format: {self.input_format.upper()}", + f"Reference molecule: {self.reference_name}", + f"Reference residue: {self.reference_residue_name}", + f"Director atom: {self.director_atom_name}", + ( + "Minimum solvent atom separation: " + f"{self.minimum_solvent_atom_separation_a:.3g} A" + ), + f"Solute distance cutoffs: {cutoff_text}", + f"Coordinating center elements: {center_text}", + f"Target average coordination: {target_coordination_text}", + f"Achieved average coordination: {achieved_coordination_text}", + f"Solvent molecules 
added: {self.solvent_molecules_added}", + f"Solvent atoms added: {self.solvent_atoms_added}", + ( + "Partial solvent candidates completed: " + f"{self.partial_candidates_completed}" + ), + ( + "Source atoms replaced during completion: " + f"{self.replaced_source_atom_count}" + ), + ] + ) + + +@dataclass(slots=True) +class _MatchingAtom: + atom_id: int + element: str + coordinates: np.ndarray + + +@dataclass(slots=True) +class _MatchingFrame: + filepath: Path + atoms: list[_MatchingAtom] + + +def analyze_solvent_shell( + input_path: str | Path, + reference_name: str, + *, + reference_library_dir: str | Path | None = None, + reference_match_tolerance_a: float = _DEFAULT_REFERENCE_MATCH_TOLERANCE_A, +) -> SolventShellAnalysisResult: + resolved_input = Path(input_path).expanduser().resolve() + if not resolved_input.is_file(): + raise FileNotFoundError( + f"Input structure file was not found: {resolved_input}" + ) + + resolved_library_dir = ( + default_reference_library_dir() + if reference_library_dir is None + else Path(reference_library_dir).expanduser().resolve() + ) + reference_entry = _resolve_reference_entry( + reference_name, + library_dir=resolved_library_dir, + ) + reference_path = reference_entry.path.expanduser().resolve() + reference_structure = PDBStructure.from_file(reference_path) + reference_atoms = tuple(atom.copy() for atom in reference_structure.atoms) + if not reference_atoms: + raise ValueError( + f"Reference molecule has no atoms: {reference_entry.name}" + ) + + if resolved_input.suffix.lower() == ".pdb": + return _analyze_pdb_input( + resolved_input, + reference_entry=reference_entry, + reference_atoms=reference_atoms, + reference_library_dir=resolved_library_dir, + reference_match_tolerance_a=reference_match_tolerance_a, + ) + if resolved_input.suffix.lower() == ".xyz": + return _analyze_xyz_input( + resolved_input, + reference_entry=reference_entry, + reference_atoms=reference_atoms, + reference_library_dir=resolved_library_dir, + 
reference_match_tolerance_a=reference_match_tolerance_a, + ) + raise ValueError( + "Solvent Shell Builder supports only PDB and XYZ input files." + ) + + +def build_solvent_shell_output( + input_path: str | Path, + reference_name: str, + *, + output_path: str | Path, + director_atom_name: str, + minimum_solvent_atom_separation_a: float, + solute_distance_cutoffs_a: dict[str, float], + coordinating_center_elements: Sequence[str] | None = None, + target_average_coordination_numbers: dict[str, float] | None = None, + reference_library_dir: str | Path | None = None, + reference_match_tolerance_a: float = _DEFAULT_REFERENCE_MATCH_TOLERANCE_A, + analysis_result: SolventShellAnalysisResult | None = None, +) -> SolventShellBuildResult: + resolved_input = Path(input_path).expanduser().resolve() + resolved_output = Path(output_path).expanduser().resolve() + if analysis_result is None: + analysis_result = analyze_solvent_shell( + resolved_input, + reference_name, + reference_library_dir=reference_library_dir, + reference_match_tolerance_a=reference_match_tolerance_a, + ) + + resolved_library_dir = ( + default_reference_library_dir() + if reference_library_dir is None + else Path(reference_library_dir).expanduser().resolve() + ) + reference_entry = _resolve_reference_entry( + reference_name, + library_dir=resolved_library_dir, + ) + reference_path = reference_entry.path.expanduser().resolve() + reference_structure = PDBStructure.from_file(reference_path) + if not reference_structure.atoms: + raise ValueError( + f"Reference molecule has no atoms: {reference_entry.name}" + ) + director_atom_index = _resolve_reference_director_atom_index( + reference_structure, + director_atom_name=director_atom_name, + ) + + input_structure = _load_input_structure_as_pdb( + resolved_input, + structure_label=resolved_input.stem, + ) + atoms_to_replace = { + int(atom_id) + for atom_id in analysis_result.partial_solvent_source_atom_ids + } + build_mode = "partial_solvent_completion" + 
partial_anchor_positions = _partial_candidate_anchor_positions( + input_structure=input_structure, + analysis_result=analysis_result, + reference_structure=reference_structure, + director_atom_index=director_atom_index, + ) + partial_candidate_count = len(partial_anchor_positions) + if not partial_anchor_positions: + if analysis_result.complete_solvent_molecule_count > 0: + raise ValueError( + "The analyzed structure already contains complete solvent " + "molecules and does not expose partial solvent candidates to rebuild." + ) + build_mode = "no_solvent_shell_build" + + remaining_atoms = [ + atom.copy() + for atom in input_structure.atoms + if int(atom.atom_id) not in atoms_to_replace + ] + selected_coordination_elements = tuple( + sorted( + { + str(element) + for element in (coordinating_center_elements or ()) + if str(element).strip() + } + ) + ) + target_coordination_by_element = { + str(element): max(float(value), 0.0) + for element, value in ( + target_average_coordination_numbers or {} + ).items() + if str(element).strip() and max(float(value), 0.0) > 0.0 + } + placed_partial_atoms = _place_anchor_positions( + reference_structure=reference_structure, + director_atom_index=director_atom_index, + anchor_positions=partial_anchor_positions, + solute_atoms=remaining_atoms, + occupied_atoms=remaining_atoms, + starting_atom_id=len(remaining_atoms) + 1, + starting_residue_number=_next_residue_number(remaining_atoms), + minimum_atom_separation_a=max( + float(minimum_solvent_atom_separation_a), + 0.0, + ), + require_clearance=False, + ) + coordination_center_atoms = _coordination_center_atoms( + input_structure=input_structure, + analysis_result=analysis_result, + coordinating_center_elements=selected_coordination_elements, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + ) + anchor_positions_for_counting = [ + position.copy() for position in partial_anchor_positions + ] + additional_placed_atoms = _build_coordination_target_solvent_atoms( + 
reference_structure=reference_structure, + director_atom_index=director_atom_index, + solute_atoms=remaining_atoms, + occupied_atoms=remaining_atoms + placed_partial_atoms, + center_atoms=coordination_center_atoms, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + target_average_coordination_numbers=target_coordination_by_element, + existing_anchor_positions=anchor_positions_for_counting, + starting_atom_id=len(remaining_atoms) + len(placed_partial_atoms) + 1, + starting_residue_number=( + _next_residue_number(remaining_atoms + placed_partial_atoms) + ), + minimum_atom_separation_a=max( + float(minimum_solvent_atom_separation_a), + 0.0, + ), + ) + if ( + not partial_anchor_positions + and not additional_placed_atoms + and not coordination_center_atoms + ): + raise ValueError( + "No solvent anchor positions could be determined. Select at least " + "one coordinating center element, provide its director-distance " + "cutoff, and set a target average coordination number." + ) + if ( + not partial_anchor_positions + and not additional_placed_atoms + and coordination_center_atoms + ): + raise ValueError( + "No solvent molecules could be placed from the selected " + "coordination targets while respecting the current cutoff and " + "minimum-separation settings." 
+ ) + placed_atoms = placed_partial_atoms + additional_placed_atoms + completed_atoms = remaining_atoms + placed_atoms + for atom_id, atom in enumerate(completed_atoms, start=1): + atom.atom_id = atom_id + output_structure = PDBStructure( + atoms=completed_atoms, + source_name=resolved_output.stem, + ) + resolved_output.parent.mkdir(parents=True, exist_ok=True) + output_structure.write_pdb_file(resolved_output) + placed_molecule_count = ( + int(len(placed_atoms) / max(len(reference_structure.atoms), 1)) + if reference_structure.atoms + else 0 + ) + achieved_coordination_by_element = _average_coordination_by_element( + center_atoms=coordination_center_atoms, + anchor_positions=anchor_positions_for_counting, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + ) + return SolventShellBuildResult( + input_path=resolved_input, + output_path=resolved_output, + input_format=analysis_result.input_format, + reference_name=reference_entry.name, + reference_residue_name=reference_entry.residue_name, + director_atom_name=str( + reference_structure.atoms[director_atom_index].atom_name + ), + build_mode=build_mode, + solvent_molecules_added=placed_molecule_count, + solvent_atoms_added=max(len(placed_atoms) - len(atoms_to_replace), 0), + partial_candidates_completed=partial_candidate_count, + replaced_source_atom_count=len(atoms_to_replace), + minimum_solvent_atom_separation_a=float( + minimum_solvent_atom_separation_a + ), + solute_distance_cutoffs_a=dict( + sorted( + ( + str(element), + float(distance), + ) + for element, distance in solute_distance_cutoffs_a.items() + if float(distance) > 0.0 + ) + ), + coordinating_center_elements=selected_coordination_elements, + target_average_coordination_numbers=dict( + sorted(target_coordination_by_element.items()) + ), + achieved_average_coordination_numbers=achieved_coordination_by_element, + ) + + +def reference_atom_choices( + reference_name: str, + *, + reference_library_dir: str | Path | None = None, +) -> tuple[str, ...]: + 
resolved_library_dir = ( + default_reference_library_dir() + if reference_library_dir is None + else Path(reference_library_dir).expanduser().resolve() + ) + reference_entry = _resolve_reference_entry( + reference_name, + library_dir=resolved_library_dir, + ) + structure = PDBStructure.from_file( + reference_entry.path.expanduser().resolve() + ) + return tuple(str(atom.atom_name) for atom in structure.atoms) + + +def default_director_atom_name( + reference_name: str, + *, + reference_library_dir: str | Path | None = None, +) -> str | None: + resolved_library_dir = ( + default_reference_library_dir() + if reference_library_dir is None + else Path(reference_library_dir).expanduser().resolve() + ) + reference_entry = _resolve_reference_entry( + reference_name, + library_dir=resolved_library_dir, + ) + structure = PDBStructure.from_file( + reference_entry.path.expanduser().resolve() + ) + if not structure.atoms: + return None + director_index = _default_director_atom_index(structure) + if director_index is None: + return None + return str(structure.atoms[director_index].atom_name) + + +def _load_input_structure_as_pdb( + input_path: Path, + *, + structure_label: str, +) -> PDBStructure: + if input_path.suffix.lower() == ".pdb": + source_structure = PDBStructure.from_file(input_path) + copied_atoms = [atom.copy() for atom in source_structure.atoms] + for index, atom in enumerate(copied_atoms, start=1): + atom.atom_id = index + return PDBStructure( + atoms=copied_atoms, + source_name=input_path.stem, + ) + + positions, elements = load_structure_file(input_path) + residue_name = _normalized_residue_name(structure_label) + counters: dict[str, int] = {} + atoms: list[object] = [] + for index, (coordinates, element) in enumerate( + zip(positions, elements, strict=True), + start=1, + ): + counters[str(element)] = counters.get(str(element), 0) + 1 + atoms.append( + PDBAtom( + atom_id=index, + atom_name=f"{element}{counters[str(element)]}", + residue_name=residue_name, + 
residue_number=1, + coordinates=np.asarray(coordinates, dtype=float), + element=str(element), + ) + ) + return PDBStructure(atoms=atoms, source_name=input_path.stem) + + +def _resolve_reference_director_atom_index( + reference_structure: PDBStructure, + *, + director_atom_name: str, +) -> int: + normalized_name = _normalized_atom_name( + director_atom_name, + fallback="DIR1", + ) + for index, atom in enumerate(reference_structure.atoms): + if ( + _normalized_atom_name( + str(atom.atom_name), + fallback=f"{atom.element}{index + 1}", + ) + == normalized_name + ): + return index + raise ValueError( + f"Director atom {director_atom_name!r} was not found in the selected solvent reference." + ) + + +def _default_director_atom_index( + reference_structure: PDBStructure, +) -> int | None: + if not reference_structure.atoms: + return None + oxygen_indices = [ + index + for index, atom in enumerate(reference_structure.atoms) + if str(atom.element).upper() == "O" + ] + if len(oxygen_indices) == 1: + return oxygen_indices[0] + available_elements = { + str(atom.element).upper() for atom in reference_structure.atoms + } + return _select_partial_anchor_index( + reference_structure.atoms, + available_elements=available_elements, + ) + + +def _partial_candidate_anchor_positions( + *, + input_structure: PDBStructure, + analysis_result: SolventShellAnalysisResult, + reference_structure: PDBStructure, + director_atom_index: int, +) -> list[np.ndarray]: + if not analysis_result.residue_mismatch_summaries: + return [] + atoms_by_id = {int(atom.atom_id): atom for atom in input_structure.atoms} + director_atom = reference_structure.atoms[director_atom_index] + director_name = _normalized_atom_name( + str(director_atom.atom_name), + fallback=f"{director_atom.element}{director_atom_index + 1}", + ) + director_element = str(director_atom.element).upper() + positions: list[np.ndarray] = [] + for summary in analysis_result.residue_mismatch_summaries: + candidate_atoms = [ + 
atoms_by_id[int(atom_id)] + for atom_id in summary.source_atom_ids + if int(atom_id) in atoms_by_id + ] + if not candidate_atoms: + continue + director_matches = [ + atom + for atom in candidate_atoms + if _normalized_atom_name( + str(atom.atom_name), + fallback=f"{atom.element}{int(atom.atom_id)}", + ) + == director_name + ] + if director_matches: + positions.append(director_matches[0].coordinates.copy()) + continue + element_matches = [ + atom + for atom in candidate_atoms + if str(atom.element).upper() == director_element + ] + if element_matches: + positions.append(element_matches[0].coordinates.copy()) + continue + positions.append( + np.mean( + np.asarray( + [atom.coordinates for atom in candidate_atoms], + dtype=float, + ), + axis=0, + ) + ) + return positions + + +def _no_solvent_anchor_positions( + *, + input_structure: PDBStructure, + analysis_result: SolventShellAnalysisResult, + solute_distance_cutoffs_a: dict[str, float], +) -> list[np.ndarray]: + cutoff_by_element = { + str(element): max(float(distance), 0.0) + for element, distance in solute_distance_cutoffs_a.items() + if max(float(distance), 0.0) > 0.0 + } + if not cutoff_by_element: + return [] + solvent_like_ids = { + int(atom_id) + for atom_id in analysis_result.complete_solvent_source_atom_ids + }.union( + int(atom_id) + for atom_id in analysis_result.partial_solvent_source_atom_ids + ) + solute_atoms = [ + atom.copy() + for atom in input_structure.atoms + if int(atom.atom_id) not in solvent_like_ids + and str(atom.element) in cutoff_by_element + ] + if not solute_atoms: + return [] + cluster_center = _weighted_center(solute_atoms) + positions: list[np.ndarray] = [] + for atom in solute_atoms: + outward_vector = _normalize_vector( + atom.coordinates - cluster_center, + fallback=np.array([1.0, 0.0, 0.0], dtype=float), + ) + cutoff_distance = cutoff_by_element.get(str(atom.element)) + if cutoff_distance is None or cutoff_distance <= 0.0: + continue + positions.append( + 
np.asarray(atom.coordinates, dtype=float) + + outward_vector * float(cutoff_distance) + ) + return positions + + +def _place_anchor_positions( + *, + reference_structure: PDBStructure, + director_atom_index: int, + anchor_positions: Sequence[np.ndarray], + solute_atoms: list[PDBAtom], + occupied_atoms: list[PDBAtom], + starting_atom_id: int, + starting_residue_number: int, + minimum_atom_separation_a: float, + require_clearance: bool, +) -> list[PDBAtom]: + placed_atoms: list[PDBAtom] = [] + next_atom_id = int(starting_atom_id) + next_residue_number = int(starting_residue_number) + current_occupied_atoms = [atom.copy() for atom in occupied_atoms] + for anchor_position in anchor_positions: + trial_atoms, clearance_met, _min_distance = ( + _trial_place_solvent_molecule( + reference_structure=reference_structure, + director_atom_index=director_atom_index, + anchor_position=np.asarray( + anchor_position, dtype=float + ).copy(), + solute_atoms=solute_atoms, + occupied_atoms=current_occupied_atoms, + starting_atom_id=next_atom_id, + residue_number=next_residue_number, + minimum_atom_separation_a=minimum_atom_separation_a, + ) + ) + if require_clearance and not clearance_met: + continue + placed_atoms.extend(trial_atoms) + current_occupied_atoms.extend(atom.copy() for atom in trial_atoms) + next_atom_id += len(trial_atoms) + next_residue_number += 1 + return placed_atoms + + +def _trial_place_solvent_molecule( + *, + reference_structure: PDBStructure, + director_atom_index: int, + anchor_position: np.ndarray, + solute_atoms: list[PDBAtom], + occupied_atoms: list[PDBAtom], + starting_atom_id: int, + residue_number: int, + minimum_atom_separation_a: float, +) -> tuple[list[PDBAtom], bool, float]: + reference_atoms = [atom.copy() for atom in reference_structure.atoms] + reference_anchor = reference_atoms[director_atom_index] + reference_anchor_coord = reference_anchor.coordinates.copy() + reference_offsets = np.asarray( + [ + atom.coordinates - reference_anchor_coord + for 
atom in reference_atoms + ], + dtype=float, + ) + reference_body_atoms = [ + atom + for index, atom in enumerate(reference_atoms) + if index != director_atom_index + ] + reference_body_center = _weighted_center(reference_body_atoms) + reference_body_vector = reference_body_center - reference_anchor_coord + if np.linalg.norm(reference_body_vector) <= 1e-8 and reference_body_atoms: + reference_body_vector = ( + reference_body_atoms[0].coordinates - reference_anchor_coord + ) + if np.linalg.norm(reference_body_vector) <= 1e-8: + reference_body_vector = np.array([1.0, 0.0, 0.0], dtype=float) + + solute_center = _weighted_center( + solute_atoms if solute_atoms else occupied_atoms + ) + outward_vector = _normalize_vector( + np.asarray(anchor_position, dtype=float) - solute_center, + fallback=np.array([1.0, 0.0, 0.0], dtype=float), + ) + alignment = _rotation_between_vectors( + reference_body_vector, + outward_vector, + ) + aligned_offsets = _rotate_points(reference_offsets, alignment) + occupied_coords = [atom.coordinates.copy() for atom in occupied_atoms] + candidate_coords, min_distance, clearance_met = _best_solvent_orientation( + aligned_offsets, + anchor_position=np.asarray(anchor_position, dtype=float), + outward_vector=outward_vector, + occupied_coords=occupied_coords, + excluded_index=director_atom_index, + minimum_atom_separation_a=minimum_atom_separation_a, + ) + placed_atoms: list[PDBAtom] = [] + next_atom_id = int(starting_atom_id) + for reference_atom, coordinate in zip( + reference_atoms, + np.asarray(candidate_coords, dtype=float), + strict=True, + ): + atom = reference_atom.copy() + atom.atom_id = next_atom_id + atom.residue_number = int(residue_number) + atom.coordinates = np.asarray(coordinate, dtype=float).copy() + placed_atoms.append(atom) + next_atom_id += 1 + return placed_atoms, bool(clearance_met), float(min_distance) + + +def _coordination_center_atoms( + *, + input_structure: PDBStructure, + analysis_result: SolventShellAnalysisResult, + 
coordinating_center_elements: Sequence[str], + solute_distance_cutoffs_a: dict[str, float], +) -> list[PDBAtom]: + selected_elements = { + str(element) + for element in coordinating_center_elements + if str(element).strip() + and float(solute_distance_cutoffs_a.get(str(element), 0.0)) > 0.0 + } + if not selected_elements: + return [] + solvent_like_ids = { + int(atom_id) + for atom_id in analysis_result.complete_solvent_source_atom_ids + }.union( + int(atom_id) + for atom_id in analysis_result.partial_solvent_source_atom_ids + ) + return [ + atom.copy() + for atom in input_structure.atoms + if int(atom.atom_id) not in solvent_like_ids + and str(atom.element) in selected_elements + ] + + +def _build_coordination_target_solvent_atoms( + *, + reference_structure: PDBStructure, + director_atom_index: int, + solute_atoms: list[PDBAtom], + occupied_atoms: list[PDBAtom], + center_atoms: list[PDBAtom], + solute_distance_cutoffs_a: dict[str, float], + target_average_coordination_numbers: dict[str, float], + existing_anchor_positions: list[np.ndarray], + starting_atom_id: int, + starting_residue_number: int, + minimum_atom_separation_a: float, +) -> list[PDBAtom]: + if not center_atoms or not target_average_coordination_numbers: + return [] + current_counts = _coordination_count_by_center_atom( + center_atoms=center_atoms, + anchor_positions=existing_anchor_positions, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + ) + placed_atoms: list[PDBAtom] = [] + occupied_atom_copies = [atom.copy() for atom in occupied_atoms] + next_atom_id = int(starting_atom_id) + next_residue_number = int(starting_residue_number) + used_candidate_keys: set[tuple[float, float, float]] = set() + while _coordination_targets_unmet( + center_atoms=center_atoms, + current_counts=current_counts, + target_average_coordination_numbers=target_average_coordination_numbers, + ): + candidate_positions = _coordination_candidate_positions( + center_atoms=center_atoms, + solute_atoms=solute_atoms, + 
existing_anchor_positions=existing_anchor_positions, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + ) + if not candidate_positions: + break + best_choice: ( + tuple[ + float, + int, + list[PDBAtom], + np.ndarray, + tuple[int, ...], + float, + ] + | None + ) = None + for candidate_position in candidate_positions: + candidate_key = tuple( + float(value) + for value in np.round(candidate_position, decimals=4) + ) + if candidate_key in used_candidate_keys: + continue + coordinated_center_ids = _coordinated_center_atom_ids( + center_atoms=center_atoms, + anchor_position=candidate_position, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + ) + if not coordinated_center_ids: + continue + benefit_score = _coordination_candidate_benefit( + center_atoms=center_atoms, + current_counts=current_counts, + coordinated_center_ids=coordinated_center_ids, + target_average_coordination_numbers=target_average_coordination_numbers, + ) + if benefit_score <= 1e-8: + continue + ( + refined_position, + trial_atoms, + clearance_met, + min_distance, + ) = _refine_anchor_position( + candidate_position=candidate_position, + center_atoms=center_atoms, + coordinated_center_ids=coordinated_center_ids, + solute_atoms=solute_atoms, + occupied_atoms=occupied_atom_copies, + reference_structure=reference_structure, + director_atom_index=director_atom_index, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + minimum_atom_separation_a=minimum_atom_separation_a, + ) + refined_center_ids = _coordinated_center_atom_ids( + center_atoms=center_atoms, + anchor_position=refined_position, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + ) + if not refined_center_ids: + continue + benefit_score = _coordination_candidate_benefit( + center_atoms=center_atoms, + current_counts=current_counts, + coordinated_center_ids=refined_center_ids, + target_average_coordination_numbers=target_average_coordination_numbers, + ) + if benefit_score <= 1e-8: + continue + if not clearance_met: + 
continue + candidate_rank = ( + benefit_score, + len(refined_center_ids), + min_distance, + ) + if best_choice is None or candidate_rank > ( + best_choice[0], + best_choice[1], + best_choice[5], + ): + best_choice = ( + benefit_score, + len(refined_center_ids), + trial_atoms, + refined_position.copy(), + refined_center_ids, + min_distance, + ) + if best_choice is None: + break + ( + _benefit, + _coordination_count, + accepted_atoms, + accepted_position, + center_ids, + _accepted_min_distance, + ) = best_choice + placed_atoms.extend(accepted_atoms) + occupied_atom_copies.extend(atom.copy() for atom in accepted_atoms) + existing_anchor_positions.append(accepted_position.copy()) + for center_id in center_ids: + current_counts[int(center_id)] = ( + current_counts.get(int(center_id), 0) + 1 + ) + used_candidate_keys.add( + tuple(float(value) for value in np.round(accepted_position, 4)) + ) + next_atom_id += len(accepted_atoms) + next_residue_number += 1 + return placed_atoms + + +def _coordination_candidate_positions( + *, + center_atoms: list[PDBAtom], + solute_atoms: list[PDBAtom], + existing_anchor_positions: Sequence[np.ndarray], + solute_distance_cutoffs_a: dict[str, float], +) -> list[np.ndarray]: + if not center_atoms: + return [] + cluster_center = _weighted_center( + solute_atoms if solute_atoms else center_atoms + ) + candidates: list[np.ndarray] = [] + seen_keys: set[tuple[float, float, float]] = set() + for center_atom in center_atoms: + cutoff_distance = max( + float( + solute_distance_cutoffs_a.get(str(center_atom.element), 0.0) + ), + 0.0, + ) + if cutoff_distance <= 0.0: + continue + for candidate_position in _octahedral_candidate_positions_for_center( + center_atom=center_atom, + cutoff_distance=cutoff_distance, + solute_atoms=solute_atoms, + existing_anchor_positions=existing_anchor_positions, + cluster_center=cluster_center, + ): + candidate_key = tuple( + float(value) + for value in np.round(candidate_position, decimals=4) + ) + if candidate_key in 
seen_keys: + continue + seen_keys.add(candidate_key) + candidates.append(candidate_position) + for direction in _single_center_directions( + center_coordinates=center_atom.coordinates, + cluster_center=cluster_center, + ): + candidate_position = ( + np.asarray(center_atom.coordinates, dtype=float) + + direction * cutoff_distance + ) + candidate_key = tuple( + float(value) + for value in np.round(candidate_position, decimals=4) + ) + if candidate_key in seen_keys: + continue + seen_keys.add(candidate_key) + candidates.append(candidate_position) + for index, center_atom in enumerate(center_atoms): + cutoff_a = max( + float( + solute_distance_cutoffs_a.get(str(center_atom.element), 0.0) + ), + 0.0, + ) + if cutoff_a <= 0.0: + continue + for other_atom in center_atoms[index + 1 :]: + cutoff_b = max( + float( + solute_distance_cutoffs_a.get(str(other_atom.element), 0.0) + ), + 0.0, + ) + if cutoff_b <= 0.0: + continue + for candidate_position in _pair_center_intersection_positions( + center_a=center_atom.coordinates, + cutoff_a=cutoff_a, + center_b=other_atom.coordinates, + cutoff_b=cutoff_b, + cluster_center=cluster_center, + ): + candidate_key = tuple( + float(value) + for value in np.round(candidate_position, decimals=4) + ) + if candidate_key in seen_keys: + continue + seen_keys.add(candidate_key) + candidates.append(candidate_position) + return candidates + + +def _refine_anchor_position( + *, + candidate_position: np.ndarray, + center_atoms: list[PDBAtom], + coordinated_center_ids: Sequence[int], + solute_atoms: list[PDBAtom], + occupied_atoms: list[PDBAtom], + reference_structure: PDBStructure, + director_atom_index: int, + solute_distance_cutoffs_a: dict[str, float], + minimum_atom_separation_a: float, +) -> tuple[np.ndarray, list[PDBAtom], bool, float]: + center_by_id = {int(atom.atom_id): atom for atom in center_atoms} + coordinated_centers = [ + center_by_id[int(center_id)] + for center_id in coordinated_center_ids + if int(center_id) in center_by_id + ] + 
best_position = _project_anchor_position_to_coordination_shells( + anchor_position=np.asarray(candidate_position, dtype=float).copy(), + coordinated_centers=coordinated_centers, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + ) + best_atoms, best_clearance, best_min_distance = ( + _trial_place_solvent_molecule( + reference_structure=reference_structure, + director_atom_index=director_atom_index, + anchor_position=best_position, + solute_atoms=solute_atoms, + occupied_atoms=occupied_atoms, + starting_atom_id=1, + residue_number=1, + minimum_atom_separation_a=minimum_atom_separation_a, + ) + ) + best_score = _anchor_refinement_score( + anchor_position=best_position, + coordinated_centers=coordinated_centers, + clearance_met=best_clearance, + min_distance=best_min_distance, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + minimum_atom_separation_a=minimum_atom_separation_a, + ) + radial_direction = _normalize_vector( + best_position - _weighted_center(coordinated_centers), + fallback=np.array([1.0, 0.0, 0.0], dtype=float), + ) + tangent_a = _orthogonal_unit_vector(radial_direction) + tangent_b = _normalize_vector( + np.cross(radial_direction, tangent_a), + fallback=np.array([0.0, 1.0, 0.0], dtype=float), + ) + search_directions = ( + radial_direction, + -radial_direction, + tangent_a, + -tangent_a, + tangent_b, + -tangent_b, + _normalize_vector( + radial_direction + tangent_a, + fallback=radial_direction, + ), + _normalize_vector( + radial_direction - tangent_a, + fallback=radial_direction, + ), + _normalize_vector( + radial_direction + tangent_b, + fallback=radial_direction, + ), + _normalize_vector( + radial_direction - tangent_b, + fallback=radial_direction, + ), + ) + for step_size in (0.35, 0.18, 0.08): + improved = True + while improved: + improved = False + for direction in search_directions: + trial_position = ( + _project_anchor_position_to_coordination_shells( + anchor_position=best_position + + np.asarray(direction, dtype=float) + * 
float(step_size), + coordinated_centers=coordinated_centers, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + ) + ) + ( + trial_atoms, + trial_clearance, + trial_min_distance, + ) = _trial_place_solvent_molecule( + reference_structure=reference_structure, + director_atom_index=director_atom_index, + anchor_position=trial_position, + solute_atoms=solute_atoms, + occupied_atoms=occupied_atoms, + starting_atom_id=1, + residue_number=1, + minimum_atom_separation_a=minimum_atom_separation_a, + ) + trial_score = _anchor_refinement_score( + anchor_position=trial_position, + coordinated_centers=coordinated_centers, + clearance_met=trial_clearance, + min_distance=trial_min_distance, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + minimum_atom_separation_a=minimum_atom_separation_a, + ) + if trial_score + 1e-6 >= best_score: + continue + best_position = trial_position + best_atoms = trial_atoms + best_clearance = trial_clearance + best_min_distance = trial_min_distance + best_score = trial_score + improved = True + break + return best_position, best_atoms, best_clearance, best_min_distance + + +def _project_anchor_position_to_coordination_shells( + *, + anchor_position: np.ndarray, + coordinated_centers: Sequence[PDBAtom], + solute_distance_cutoffs_a: dict[str, float], + iterations: int = 8, +) -> np.ndarray: + projected = np.asarray(anchor_position, dtype=float).copy() + if not coordinated_centers: + return projected + for _ in range(max(int(iterations), 1)): + correction = np.zeros(3, dtype=float) + contributing_centers = 0 + for center_atom in coordinated_centers: + cutoff_distance = max( + float( + solute_distance_cutoffs_a.get( + str(center_atom.element), 0.0 + ) + ), + 0.0, + ) + if cutoff_distance <= 0.0: + continue + displacement = projected - np.asarray( + center_atom.coordinates, dtype=float + ) + observed_distance = float(np.linalg.norm(displacement)) + if observed_distance <= 1e-8: + displacement = np.array([1.0, 0.0, 0.0], dtype=float) + 
observed_distance = 1.0 + correction += ( + (cutoff_distance - observed_distance) + * displacement + / observed_distance + ) + contributing_centers += 1 + if contributing_centers == 0: + break + projected += correction / float(contributing_centers) + if float(np.linalg.norm(correction)) <= 1e-6: + break + return projected + + +def _anchor_refinement_score( + *, + anchor_position: np.ndarray, + coordinated_centers: Sequence[PDBAtom], + clearance_met: bool, + min_distance: float, + solute_distance_cutoffs_a: dict[str, float], + minimum_atom_separation_a: float, +) -> float: + anchor = np.asarray(anchor_position, dtype=float) + score = 0.0 + for center_atom in coordinated_centers: + cutoff_distance = max( + float( + solute_distance_cutoffs_a.get(str(center_atom.element), 0.0) + ), + 0.0, + ) + if cutoff_distance <= 0.0: + continue + observed_distance = float( + np.linalg.norm( + anchor - np.asarray(center_atom.coordinates, dtype=float) + ) + ) + normalized_error = (observed_distance - cutoff_distance) / max( + cutoff_distance, + 1e-6, + ) + score += normalized_error * normalized_error + if observed_distance > cutoff_distance + 0.2: + score += 2.0 * (observed_distance - cutoff_distance) + clearance_gap = max( + float(minimum_atom_separation_a) - float(min_distance), 0.0 + ) + score += 20.0 * clearance_gap * clearance_gap + if not clearance_met: + score += 5.0 + score -= 0.05 * max(float(min_distance), 0.0) + return score + + +def _octahedral_candidate_positions_for_center( + *, + center_atom: PDBAtom, + cutoff_distance: float, + solute_atoms: list[PDBAtom], + existing_anchor_positions: Sequence[np.ndarray], + cluster_center: np.ndarray, +) -> tuple[np.ndarray, ...]: + neighbor_vectors = _existing_first_shell_neighbor_vectors( + center_atom=center_atom, + solute_atoms=solute_atoms, + existing_anchor_positions=existing_anchor_positions, + cutoff_distance=cutoff_distance, + ) + octahedral_directions = _octahedral_direction_frame( + 
existing_neighbor_vectors=neighbor_vectors, + preferred_direction=_normalize_vector( + np.asarray(center_atom.coordinates, dtype=float) + - np.asarray(cluster_center, dtype=float), + fallback=np.array([1.0, 0.0, 0.0], dtype=float), + ), + ) + occupied_indices = _occupied_octahedral_direction_indices( + existing_neighbor_vectors=neighbor_vectors, + octahedral_directions=octahedral_directions, + ) + return tuple( + np.asarray(center_atom.coordinates, dtype=float) + + np.asarray(direction, dtype=float) * float(cutoff_distance) + for index, direction in enumerate(octahedral_directions) + if index not in occupied_indices + ) + + +def _existing_first_shell_neighbor_vectors( + *, + center_atom: PDBAtom, + solute_atoms: list[PDBAtom], + existing_anchor_positions: Sequence[np.ndarray], + cutoff_distance: float, +) -> tuple[np.ndarray, ...]: + coordination_radius = float(cutoff_distance) + 0.35 + center_coordinates = np.asarray(center_atom.coordinates, dtype=float) + neighbor_vectors: list[np.ndarray] = [] + for atom in solute_atoms: + if int(atom.atom_id) == int(center_atom.atom_id): + continue + displacement = ( + np.asarray(atom.coordinates, dtype=float) - center_coordinates + ) + distance = float(np.linalg.norm(displacement)) + if distance <= 1e-8 or distance > coordination_radius: + continue + neighbor_vectors.append( + _normalize_vector( + displacement, + fallback=np.array([1.0, 0.0, 0.0], dtype=float), + ) + ) + for anchor_position in existing_anchor_positions: + displacement = ( + np.asarray(anchor_position, dtype=float) - center_coordinates + ) + distance = float(np.linalg.norm(displacement)) + if distance <= 1e-8 or distance > coordination_radius: + continue + neighbor_vectors.append( + _normalize_vector( + displacement, + fallback=np.array([1.0, 0.0, 0.0], dtype=float), + ) + ) + unique_vectors: list[np.ndarray] = [] + for vector in neighbor_vectors: + if any( + float(np.dot(vector, other)) > 0.95 for other in unique_vectors + ): + continue + 
unique_vectors.append(vector) + return tuple(unique_vectors) + + +def _octahedral_direction_frame( + *, + existing_neighbor_vectors: Sequence[np.ndarray], + preferred_direction: np.ndarray, +) -> tuple[np.ndarray, ...]: + if existing_neighbor_vectors: + primary_axis = _normalize_vector( + np.asarray(existing_neighbor_vectors[0], dtype=float), + fallback=preferred_direction, + ) + else: + primary_axis = _normalize_vector( + preferred_direction, + fallback=np.array([1.0, 0.0, 0.0], dtype=float), + ) + secondary_seed: np.ndarray | None = None + for vector in existing_neighbor_vectors[1:]: + projected = np.asarray(vector, dtype=float) - primary_axis * float( + np.dot(np.asarray(vector, dtype=float), primary_axis) + ) + if float(np.linalg.norm(projected)) > 0.25: + secondary_seed = projected + break + if secondary_seed is None: + preferred_projected = np.asarray( + preferred_direction, dtype=float + ) - primary_axis * float( + np.dot(np.asarray(preferred_direction, dtype=float), primary_axis) + ) + if float(np.linalg.norm(preferred_projected)) > 0.25: + secondary_seed = preferred_projected + if secondary_seed is None: + secondary_axis = _orthogonal_unit_vector(primary_axis) + else: + secondary_axis = _normalize_vector( + secondary_seed, + fallback=_orthogonal_unit_vector(primary_axis), + ) + tertiary_axis = _normalize_vector( + np.cross(primary_axis, secondary_axis), + fallback=np.array([0.0, 0.0, 1.0], dtype=float), + ) + secondary_axis = _normalize_vector( + np.cross(tertiary_axis, primary_axis), + fallback=secondary_axis, + ) + return ( + primary_axis, + -primary_axis, + secondary_axis, + -secondary_axis, + tertiary_axis, + -tertiary_axis, + ) + + +def _occupied_octahedral_direction_indices( + *, + existing_neighbor_vectors: Sequence[np.ndarray], + octahedral_directions: Sequence[np.ndarray], +) -> set[int]: + occupied_indices: set[int] = set() + for vector in existing_neighbor_vectors: + normalized_vector = _normalize_vector( + np.asarray(vector, dtype=float), + 
fallback=np.array([1.0, 0.0, 0.0], dtype=float), + ) + best_index = max( + range(len(octahedral_directions)), + key=lambda index: float( + np.dot( + normalized_vector, + np.asarray(octahedral_directions[index], dtype=float), + ) + ), + ) + if ( + float( + np.dot( + normalized_vector, + np.asarray(octahedral_directions[best_index], dtype=float), + ) + ) + < 0.55 + ): + continue + occupied_indices.add(int(best_index)) + return occupied_indices + + +def _single_center_directions( + *, + center_coordinates: np.ndarray, + cluster_center: np.ndarray, +) -> tuple[np.ndarray, ...]: + preferred = _normalize_vector( + np.asarray(center_coordinates, dtype=float) + - np.asarray(cluster_center, dtype=float), + fallback=np.array([1.0, 0.0, 0.0], dtype=float), + ) + basis1 = _orthogonal_unit_vector(preferred) + basis2 = _normalize_vector( + np.cross(preferred, basis1), + fallback=np.array([0.0, 1.0, 0.0], dtype=float), + ) + direction_vectors = [ + preferred, + preferred + 0.6 * basis1, + preferred - 0.6 * basis1, + preferred + 0.6 * basis2, + preferred - 0.6 * basis2, + preferred + 0.45 * basis1 + 0.45 * basis2, + preferred - 0.45 * basis1 + 0.45 * basis2, + preferred + 0.45 * basis1 - 0.45 * basis2, + preferred - 0.45 * basis1 - 0.45 * basis2, + ] + return tuple( + _normalize_vector( + np.asarray(direction, dtype=float), + fallback=preferred, + ) + for direction in direction_vectors + ) + + +def _pair_center_intersection_positions( + *, + center_a: np.ndarray, + cutoff_a: float, + center_b: np.ndarray, + cutoff_b: float, + cluster_center: np.ndarray, +) -> tuple[np.ndarray, ...]: + center_a = np.asarray(center_a, dtype=float) + center_b = np.asarray(center_b, dtype=float) + axis_vector = center_b - center_a + axis_distance = float(np.linalg.norm(axis_vector)) + if axis_distance <= 1e-8: + return () + if axis_distance > float(cutoff_a + cutoff_b) + 1e-6: + return () + if axis_distance < abs(float(cutoff_a - cutoff_b)) - 1e-6: + return () + axis_unit = axis_vector / axis_distance 
+ offset_along_axis = ( + axis_distance * axis_distance + - float(cutoff_b) * float(cutoff_b) + + float(cutoff_a) * float(cutoff_a) + ) / (2.0 * axis_distance) + circle_center = center_a + axis_unit * offset_along_axis + circle_radius_sq = float(cutoff_a) * float(cutoff_a) - ( + offset_along_axis * offset_along_axis + ) + if circle_radius_sq < -1e-6: + return () + circle_radius = float(np.sqrt(max(circle_radius_sq, 0.0))) + if circle_radius <= 1e-8: + return (circle_center,) + outward_hint = np.asarray(circle_center, dtype=float) - np.asarray( + cluster_center, + dtype=float, + ) + plane_projection = outward_hint - axis_unit * float( + np.dot(outward_hint, axis_unit) + ) + basis1 = _normalize_vector( + plane_projection, + fallback=_orthogonal_unit_vector(axis_unit), + ) + basis2 = _normalize_vector( + np.cross(axis_unit, basis1), + fallback=np.array([0.0, 1.0, 0.0], dtype=float), + ) + return ( + circle_center + circle_radius * basis1, + circle_center - circle_radius * basis1, + circle_center + circle_radius * basis2, + circle_center - circle_radius * basis2, + ) + + +def _orthogonal_unit_vector(vector: np.ndarray) -> np.ndarray: + normalized = _normalize_vector( + np.asarray(vector, dtype=float), + fallback=np.array([1.0, 0.0, 0.0], dtype=float), + ) + if abs(float(normalized[0])) < 0.9: + return _normalize_vector( + np.cross(normalized, np.array([1.0, 0.0, 0.0], dtype=float)), + fallback=np.array([0.0, 1.0, 0.0], dtype=float), + ) + return _normalize_vector( + np.cross(normalized, np.array([0.0, 1.0, 0.0], dtype=float)), + fallback=np.array([0.0, 0.0, 1.0], dtype=float), + ) + + +def _coordination_count_by_center_atom( + *, + center_atoms: list[PDBAtom], + anchor_positions: Sequence[np.ndarray], + solute_distance_cutoffs_a: dict[str, float], +) -> dict[int, int]: + counts = {int(atom.atom_id): 0 for atom in center_atoms} + if not center_atoms or not anchor_positions: + return counts + center_by_id = {int(atom.atom_id): atom for atom in center_atoms} + for 
anchor_position in anchor_positions: + anchor = np.asarray(anchor_position, dtype=float) + for atom_id, atom in center_by_id.items(): + cutoff_distance = max( + float(solute_distance_cutoffs_a.get(str(atom.element), 0.0)), + 0.0, + ) + if cutoff_distance <= 0.0: + continue + if ( + float( + np.linalg.norm( + anchor - np.asarray(atom.coordinates, dtype=float) + ) + ) + <= cutoff_distance + 1e-6 + ): + counts[int(atom_id)] = counts.get(int(atom_id), 0) + 1 + return counts + + +def _coordination_targets_unmet( + *, + center_atoms: list[PDBAtom], + current_counts: dict[int, int], + target_average_coordination_numbers: dict[str, float], +) -> bool: + achieved = _average_coordination_by_element_from_counts( + center_atoms=center_atoms, + current_counts=current_counts, + ) + for element, target_value in target_average_coordination_numbers.items(): + if achieved.get(str(element), 0.0) + 1e-6 < float(target_value): + return True + return False + + +def _coordinated_center_atom_ids( + *, + center_atoms: list[PDBAtom], + anchor_position: np.ndarray, + solute_distance_cutoffs_a: dict[str, float], +) -> tuple[int, ...]: + coordinated_ids: list[int] = [] + anchor = np.asarray(anchor_position, dtype=float) + for atom in center_atoms: + cutoff_distance = max( + float(solute_distance_cutoffs_a.get(str(atom.element), 0.0)), + 0.0, + ) + if cutoff_distance <= 0.0: + continue + if ( + float( + np.linalg.norm( + anchor - np.asarray(atom.coordinates, dtype=float) + ) + ) + <= cutoff_distance + 1e-6 + ): + coordinated_ids.append(int(atom.atom_id)) + return tuple(sorted(coordinated_ids)) + + +def _coordination_candidate_benefit( + *, + center_atoms: list[PDBAtom], + current_counts: dict[int, int], + coordinated_center_ids: Sequence[int], + target_average_coordination_numbers: dict[str, float], +) -> float: + center_by_id = {int(atom.atom_id): atom for atom in center_atoms} + benefit = 0.0 + for center_id in coordinated_center_ids: + atom = center_by_id.get(int(center_id)) + if atom is 
None: + continue + target_value = float( + target_average_coordination_numbers.get(str(atom.element), 0.0) + ) + current_value = float(current_counts.get(int(center_id), 0)) + benefit += max(target_value - current_value, 0.0) + return benefit + + +def _average_coordination_by_element( + *, + center_atoms: list[PDBAtom], + anchor_positions: Sequence[np.ndarray], + solute_distance_cutoffs_a: dict[str, float], +) -> dict[str, float]: + current_counts = _coordination_count_by_center_atom( + center_atoms=center_atoms, + anchor_positions=anchor_positions, + solute_distance_cutoffs_a=solute_distance_cutoffs_a, + ) + return _average_coordination_by_element_from_counts( + center_atoms=center_atoms, + current_counts=current_counts, + ) + + +def _average_coordination_by_element_from_counts( + *, + center_atoms: list[PDBAtom], + current_counts: dict[int, int], +) -> dict[str, float]: + centers_by_element: dict[str, list[int]] = defaultdict(list) + for atom in center_atoms: + centers_by_element[str(atom.element)].append(int(atom.atom_id)) + averages: dict[str, float] = {} + for element, atom_ids in sorted(centers_by_element.items()): + if not atom_ids: + continue + averages[str(element)] = float( + sum(current_counts.get(int(atom_id), 0) for atom_id in atom_ids) + ) / float(len(atom_ids)) + return averages + + +def _director_anchor_positions_from_atoms( + *, + placed_atoms: list[PDBAtom], + reference_structure: PDBStructure, + director_atom_index: int, +) -> list[np.ndarray]: + if not placed_atoms: + return [] + director_atom_name = str( + reference_structure.atoms[director_atom_index].atom_name + ) + director_name = _normalized_atom_name( + director_atom_name, + fallback=f"{reference_structure.atoms[director_atom_index].element}{director_atom_index + 1}", + ) + positions: list[np.ndarray] = [] + residue_groups: dict[tuple[str, int], list[PDBAtom]] = defaultdict(list) + for atom in placed_atoms: + residue_groups[ + (str(atom.residue_name), int(atom.residue_number)) + 
].append(atom) + for residue_atoms in residue_groups.values(): + director_atom = next( + ( + atom + for atom in residue_atoms + if _normalized_atom_name( + str(atom.atom_name), + fallback=f"{atom.element}{int(atom.atom_id)}", + ) + == director_name + ), + None, + ) + if director_atom is None: + continue + positions.append( + np.asarray(director_atom.coordinates, dtype=float).copy() + ) + return positions + + +def _analyze_pdb_input( + input_path: Path, + *, + reference_entry: ReferenceLibraryEntry, + reference_atoms: tuple[object, ...], + reference_library_dir: Path, + reference_match_tolerance_a: float, +) -> SolventShellAnalysisResult: + structure = PDBStructure.from_file(input_path) + total_atoms = len(structure.atoms) + reference_atom_count = len(reference_atoms) + if reference_atom_count == 1: + return _analyze_single_atom_pdb_input( + input_path, + structure=structure, + reference_entry=reference_entry, + reference_atoms=reference_atoms, + reference_match_tolerance_a=reference_match_tolerance_a, + ) + + workflow = XYZToPDBWorkflow( + input_path, + reference_library_dir=reference_library_dir, + ) + configuration = _single_reference_configuration( + workflow, + reference_entry=reference_entry, + reference_atoms=reference_atoms, + reference_match_tolerance_a=reference_match_tolerance_a, + ) + molecule_name = configuration.molecules[0].name + residue_groups: dict[tuple[str, int], list[object]] = defaultdict(list) + for atom in structure.atoms: + residue_groups[(atom.residue_name, atom.residue_number)].append(atom) + + residue_numbers_by_name: dict[str, list[int]] = defaultdict(list) + molecule_counts_by_name: Counter[str] = Counter() + residue_mismatches: list[SolventShellResidueMismatchSummary] = [] + complete_solvent_atom_ids: set[int] = set() + matched_groups: list[tuple[str, int, list[object]]] = [] + matched_atom_count = 0 + matched_molecule_count = 0 + for (residue_name, residue_number), residue_atoms in sorted( + residue_groups.items(), + key=lambda item: 
(item[0][1], item[0][0]), + ): + frame = _matching_frame_from_pdb_atoms( + input_path, + residue_atoms=residue_atoms, + ) + converted_residues = workflow._convert_first_frame( + frame, + configuration, + ) + matched_molecules = [ + residue + for residue in converted_residues + if residue.molecule_name == molecule_name + ] + unmatched_atom_total = sum( + len(residue.atoms) + for residue in converted_residues + if residue.molecule_name != molecule_name + ) + matched_atom_total = sum( + len(residue.atoms) for residue in matched_molecules + ) + if ( + not matched_molecules + or unmatched_atom_total > 0 + or matched_atom_total != len(residue_atoms) + ): + mismatch_summary = _build_pdb_residue_mismatch_summary( + residue_name=residue_name, + residue_number=int(residue_number), + residue_atoms=residue_atoms, + reference_entry=reference_entry, + reference_atoms=reference_atoms, + reference_match_tolerance_a=reference_match_tolerance_a, + ) + if mismatch_summary is not None: + residue_mismatches.append(mismatch_summary) + continue + matched_here = len(matched_molecules) + residue_numbers_by_name[residue_name].append(residue_number) + molecule_counts_by_name[residue_name] += matched_here + complete_solvent_atom_ids.update( + int(atom.atom_id) for atom in residue_atoms + ) + matched_groups.append( + (str(residue_name), int(residue_number), list(residue_atoms)) + ) + matched_molecule_count += matched_here + matched_atom_count += matched_atom_total + + partial_solvent_atom_ids = { + atom_id + for summary in residue_mismatches + for atom_id in summary.source_atom_ids + } + solute_element_counts = dict( + sorted( + Counter( + atom.element + for atom in structure.atoms + if int(atom.atom_id) not in complete_solvent_atom_ids + and int(atom.atom_id) not in partial_solvent_atom_ids + ).items() + ) + ) + + reference_element_counts = dict( + sorted(Counter(atom.element for atom in reference_atoms).items()) + ) + residue_summaries = tuple( + SolventShellResidueSummary( + 
residue_name=residue_name, + molecule_count=int(molecule_counts_by_name[residue_name]), + residue_numbers=tuple( + sorted(residue_numbers_by_name[residue_name]) + ), + atom_count=reference_atom_count, + element_counts=reference_element_counts, + ) + for residue_name in sorted(residue_numbers_by_name, key=str.casefold) + ) + return SolventShellAnalysisResult( + input_path=input_path, + input_format="pdb", + reference_name=reference_entry.name, + reference_path=reference_entry.path.expanduser().resolve(), + reference_residue_name=reference_entry.residue_name, + reference_atom_count=reference_atom_count, + detected_solvent_molecules=matched_molecule_count, + matched_atom_count=matched_atom_count, + unmatched_atom_count=max(total_atoms - matched_atom_count, 0), + total_atoms=total_atoms, + match_tolerance_a=reference_match_tolerance_a, + solute_element_counts=solute_element_counts, + complete_solvent_source_atom_ids=tuple( + sorted(complete_solvent_atom_ids) + ), + complete_solvent_source_atom_groups=tuple( + tuple(sorted(int(atom.atom_id) for atom in residue_atoms)) + for _residue_name, _residue_number, residue_atoms in matched_groups + ), + partial_solvent_source_atom_ids=tuple( + sorted(partial_solvent_atom_ids) + ), + matched_residue_summaries=residue_summaries, + residue_mismatch_summaries=tuple(residue_mismatches), + notes=_build_pdb_analysis_notes( + residue_mismatch_count=len(residue_mismatches) + ), + ) + + +def _analyze_single_atom_pdb_input( + input_path: Path, + *, + structure: PDBStructure, + reference_entry: ReferenceLibraryEntry, + reference_atoms: tuple[object, ...], + reference_match_tolerance_a: float, +) -> SolventShellAnalysisResult: + reference_atom = reference_atoms[0] + residue_numbers_by_name: dict[str, list[int]] = defaultdict(list) + molecule_counts_by_name: Counter[str] = Counter() + for atom in structure.atoms: + if atom.element != reference_atom.element: + continue + residue_numbers_by_name[atom.residue_name].append(atom.residue_number) + 
molecule_counts_by_name[atom.residue_name] += 1 + residue_summaries = tuple( + SolventShellResidueSummary( + residue_name=residue_name, + molecule_count=int(molecule_counts_by_name[residue_name]), + residue_numbers=tuple( + sorted(residue_numbers_by_name[residue_name]) + ), + atom_count=1, + element_counts={reference_atom.element: 1}, + ) + for residue_name in sorted(residue_numbers_by_name, key=str.casefold) + ) + matched_atom_count = sum( + summary.molecule_count for summary in residue_summaries + ) + matched_atom_ids = { + int(atom.atom_id) + for atom in structure.atoms + if atom.element == reference_atom.element + } + solute_element_counts = dict( + sorted( + Counter( + atom.element + for atom in structure.atoms + if int(atom.atom_id) not in matched_atom_ids + ).items() + ) + ) + return SolventShellAnalysisResult( + input_path=input_path, + input_format="pdb", + reference_name=reference_entry.name, + reference_path=reference_entry.path.expanduser().resolve(), + reference_residue_name=reference_entry.residue_name, + reference_atom_count=1, + detected_solvent_molecules=matched_atom_count, + matched_atom_count=matched_atom_count, + unmatched_atom_count=max(len(structure.atoms) - matched_atom_count, 0), + total_atoms=len(structure.atoms), + match_tolerance_a=reference_match_tolerance_a, + solute_element_counts=solute_element_counts, + complete_solvent_source_atom_ids=tuple(sorted(matched_atom_ids)), + complete_solvent_source_atom_groups=tuple( + (int(atom.atom_id),) + for atom in structure.atoms + if atom.element == reference_atom.element + ), + matched_residue_summaries=residue_summaries, + notes=( + "Single-atom references are matched by element within each PDB " + "residue.", + ), + ) + + +def _analyze_xyz_input( + input_path: Path, + *, + reference_entry: ReferenceLibraryEntry, + reference_atoms: tuple[object, ...], + reference_library_dir: Path, + reference_match_tolerance_a: float, +) -> SolventShellAnalysisResult: + workflow = XYZToPDBWorkflow( + input_path, 
+ reference_library_dir=reference_library_dir, + ) + if len(reference_atoms) == 1: + return _analyze_single_atom_xyz_input( + input_path, + workflow=workflow, + reference_entry=reference_entry, + reference_atoms=reference_atoms, + reference_match_tolerance_a=reference_match_tolerance_a, + ) + + frame = workflow.read_xyz_frame(input_path) + configuration = _single_reference_configuration( + workflow, + reference_entry=reference_entry, + reference_atoms=reference_atoms, + reference_match_tolerance_a=reference_match_tolerance_a, + ) + converted_residues = workflow._convert_first_frame(frame, configuration) + molecule_name = configuration.molecules[0].name + matched_residues = [ + residue + for residue in converted_residues + if residue.molecule_name == molecule_name + ] + matched_source_indices = { + int(source_index) + for residue in matched_residues + for source_index in residue.source_atom_indices + } + matched_source_atom_ids = { + int(source_index) + 1 for source_index in matched_source_indices + } + unmatched_atom_records = [ + (index, frame.atoms[index]) + for index in range(len(frame.atoms)) + if index not in matched_source_indices + ] + partial_candidates = _build_xyz_partial_candidate_summaries( + unmatched_atom_records=unmatched_atom_records, + reference_entry=reference_entry, + reference_atoms=reference_atoms, + reference_match_tolerance_a=reference_match_tolerance_a, + ) + partial_candidate_atom_ids = { + atom_id + for summary in partial_candidates + for atom_id in summary.source_atom_ids + } + matched_atom_count = sum( + len(residue.atoms) for residue in matched_residues + ) + total_atoms = len(frame.atoms) + solute_element_counts = dict( + sorted( + Counter( + atom.element + for atom in frame.atoms + if int(atom.atom_id) not in matched_source_atom_ids + and int(atom.atom_id) not in partial_candidate_atom_ids + ).items() + ) + ) + return SolventShellAnalysisResult( + input_path=input_path, + input_format="xyz", + reference_name=reference_entry.name, + 
reference_path=reference_entry.path.expanduser().resolve(), + reference_residue_name=reference_entry.residue_name, + reference_atom_count=len(reference_atoms), + detected_solvent_molecules=len(matched_residues), + matched_atom_count=matched_atom_count, + unmatched_atom_count=max(total_atoms - matched_atom_count, 0), + total_atoms=total_atoms, + match_tolerance_a=reference_match_tolerance_a, + solute_element_counts=solute_element_counts, + complete_solvent_source_atom_ids=tuple( + sorted(matched_source_atom_ids) + ), + complete_solvent_source_atom_groups=tuple( + tuple( + sorted( + int(source_index) + 1 + for source_index in residue.source_atom_indices + ) + ) + for residue in matched_residues + ), + partial_solvent_source_atom_ids=tuple( + sorted(partial_candidate_atom_ids) + ), + residue_mismatch_summaries=partial_candidates, + notes=( + _build_xyz_analysis_note( + partial_candidate_count=len(partial_candidates) + ), + ), + ) + + +def _analyze_single_atom_xyz_input( + input_path: Path, + *, + workflow: XYZToPDBWorkflow, + reference_entry: ReferenceLibraryEntry, + reference_atoms: tuple[object, ...], + reference_match_tolerance_a: float, +) -> SolventShellAnalysisResult: + frame = workflow.read_xyz_frame(input_path) + reference_atom = reference_atoms[0] + matched_atom_count = sum( + 1 for atom in frame.atoms if atom.element == reference_atom.element + ) + total_atoms = len(frame.atoms) + solute_element_counts = dict( + sorted( + Counter( + atom.element + for atom in frame.atoms + if atom.element != reference_atom.element + ).items() + ) + ) + return SolventShellAnalysisResult( + input_path=input_path, + input_format="xyz", + reference_name=reference_entry.name, + reference_path=reference_entry.path.expanduser().resolve(), + reference_residue_name=reference_entry.residue_name, + reference_atom_count=1, + detected_solvent_molecules=matched_atom_count, + matched_atom_count=matched_atom_count, + unmatched_atom_count=max(total_atoms - matched_atom_count, 0), + 
total_atoms=total_atoms, + match_tolerance_a=reference_match_tolerance_a, + solute_element_counts=solute_element_counts, + complete_solvent_source_atom_ids=tuple( + int(atom.atom_id) + for atom in frame.atoms + if atom.element == reference_atom.element + ), + complete_solvent_source_atom_groups=tuple( + (int(atom.atom_id),) + for atom in frame.atoms + if atom.element == reference_atom.element + ), + notes=( + "Single-atom references are matched by element only for XYZ " + "inputs.", + ), + ) + + +def _single_reference_configuration( + workflow: XYZToPDBWorkflow, + *, + reference_entry: ReferenceLibraryEntry, + reference_atoms: tuple[object, ...], + reference_match_tolerance_a: float, +) -> XYZToPDBConfiguration: + backbone_pairs = ( + reference_entry.backbone_pairs + or _fallback_backbone_pairs(reference_atoms) + ) + if not backbone_pairs: + raise ValueError( + "The selected solvent reference does not provide enough atoms " + "to define a matching backbone pair." + ) + anchors = tuple( + AnchorPairDefinition( + atom1_name=atom1_name, + atom2_name=atom2_name, + tolerance=reference_match_tolerance_a, + ) + for atom1_name, atom2_name in backbone_pairs + ) + resolved_anchor_indices = tuple( + workflow._resolve_anchor_pair_indices( + reference_atoms, + anchor.atom1_name, + anchor.atom2_name, + molecule_name=reference_entry.name, + ) + + (anchor.tolerance,) + for anchor in anchors + ) + molecule = MoleculeDefinition( + name=reference_entry.name, + reference_name=reference_entry.name, + reference_path=reference_entry.path.expanduser().resolve(), + residue_name=reference_entry.residue_name, + reference_atoms=reference_atoms, + anchors=anchors, + resolved_anchor_indices=resolved_anchor_indices, + preferred_anchor_indices=tuple( + (index1, index2) + for index1, index2, _tolerance in resolved_anchor_indices + ), + max_assignment_distance=None, + ) + return XYZToPDBConfiguration( + molecules=(molecule,), + free_atoms={}, + exclude_hydrogen=False, + pbc_params={}, + ) + + +def 
_resolve_reference_entry( + reference_name: str, + *, + library_dir: Path, +) -> ReferenceLibraryEntry: + presets = list_reference_library(library_dir) + lowered_name = str(reference_name).strip().casefold() + for preset in presets: + if preset.name.casefold() == lowered_name: + return preset + reference_path = Path(str(reference_name)).expanduser() + if reference_path.is_file(): + structure = PDBStructure.from_file(reference_path) + residue_name = ( + structure.atoms[0].residue_name if structure.atoms else "UNK" + ) + return ReferenceLibraryEntry( + name=reference_path.stem, + path=reference_path.resolve(), + residue_name=residue_name, + atom_count=len(structure.atoms), + atom_names=tuple(atom.atom_name for atom in structure.atoms), + backbone_pairs=(), + ) + raise ValueError( + f"Reference molecule {reference_name!r} was not found in {library_dir}." + ) + + +def _fallback_backbone_pairs( + reference_atoms: tuple[object, ...], +) -> tuple[tuple[str, str], ...]: + non_hydrogen_atoms = [ + atom for atom in reference_atoms if atom.element.upper() != "H" + ] + candidate_atoms = ( + non_hydrogen_atoms if len(non_hydrogen_atoms) >= 2 else reference_atoms + ) + if len(candidate_atoms) < 2: + return () + return ( + ( + str(candidate_atoms[0].atom_name), + str(candidate_atoms[1].atom_name), + ), + ) + + +_ANCHOR_ELEMENT_PRIORITY = ( + "O", + "N", + "S", + "P", + "F", + "CL", + "BR", + "I", + "C", +) + + +def _matching_frame_from_pdb_atoms( + input_path: Path, + *, + residue_atoms: list[object], +) -> _MatchingFrame: + return _MatchingFrame( + filepath=input_path, + atoms=[ + _MatchingAtom( + atom_id=index, + element=str(atom.element), + coordinates=np.asarray(atom.coordinates, dtype=float).copy(), + ) + for index, atom in enumerate(residue_atoms, start=1) + ], + ) + + +def _build_xyz_partial_candidate_summaries( + *, + unmatched_atom_records: Sequence[tuple[int, object]], + reference_entry: ReferenceLibraryEntry, + reference_atoms: tuple[object, ...], + 
reference_match_tolerance_a: float, +) -> tuple[SolventShellResidueMismatchSummary, ...]: + if not unmatched_atom_records or not reference_atoms: + return () + + reference_element_counts = Counter( + str(atom.element).upper() for atom in reference_atoms + ) + candidate_records = [ + (index, atom) + for index, atom in unmatched_atom_records + if str(atom.element).upper() in reference_element_counts + ] + if not candidate_records: + return () + + available_elements = { + str(atom.element).upper() for _index, atom in candidate_records + } + anchor_index = _select_partial_anchor_index( + reference_atoms, + available_elements=available_elements, + ) + if anchor_index is None: + return () + + reference_names = [ + _normalized_atom_name( + str(atom.atom_name), + fallback=f"{atom.element}{index + 1}", + ) + for index, atom in enumerate(reference_atoms) + ] + reference_anchor_atom = reference_atoms[anchor_index] + anchor_element = str(reference_anchor_atom.element).upper() + anchor_reference_indices = [anchor_index] + [ + index + for index, atom in enumerate(reference_atoms) + if index != anchor_index + and str(atom.element).upper() == anchor_element + ] + anchor_records = [ + (source_index, atom) + for source_index, atom in candidate_records + if str(atom.element).upper() == anchor_element + ] + if not anchor_records: + return () + + anchor_capacity = max(len(anchor_reference_indices), 1) + summaries: list[SolventShellResidueMismatchSummary] = [] + used_source_indices: set[int] = set() + sorted_anchor_records = sorted( + anchor_records, + key=lambda item: (int(item[1].atom_id), int(item[0])), + ) + for candidate_index, chunk_start in enumerate( + range(0, len(sorted_anchor_records), anchor_capacity), + start=1, + ): + anchor_chunk = [ + item + for item in sorted_anchor_records[ + chunk_start : chunk_start + anchor_capacity + ] + if int(item[0]) not in used_source_indices + ] + if not anchor_chunk: + continue + assigned_source_by_ref_index: dict[int, tuple[int, object]] = 
{} + assigned_ref_indices: set[int] = set() + for ref_index, (source_index, source_atom) in zip( + anchor_reference_indices, + anchor_chunk, + ): + used_source_indices.add(int(source_index)) + assigned_source_by_ref_index[int(ref_index)] = ( + int(source_index), + source_atom, + ) + assigned_ref_indices.add(int(ref_index)) + + primary_source_index, primary_source_atom = ( + assigned_source_by_ref_index[anchor_index] + ) + primary_anchor_coordinates = np.asarray( + primary_source_atom.coordinates, + dtype=float, + ) + for ref_index, reference_atom in sorted( + enumerate(reference_atoms), + key=lambda item: ( + float( + np.linalg.norm( + np.asarray(item[1].coordinates, dtype=float) + - np.asarray( + reference_anchor_atom.coordinates, dtype=float + ) + ) + ), + item[0], + ), + ): + if ref_index in assigned_ref_indices: + continue + reference_element = str(reference_atom.element).upper() + candidate_matches = [ + (source_index, atom) + for source_index, atom in candidate_records + if int(source_index) not in used_source_indices + and str(atom.element).upper() == reference_element + ] + if not candidate_matches: + continue + expected_distance = float( + np.linalg.norm( + np.asarray(reference_atom.coordinates, dtype=float) + - np.asarray( + reference_anchor_atom.coordinates, dtype=float + ) + ) + ) + best_source_index: int | None = None + best_source_atom: object | None = None + best_delta: float | None = None + for source_index, source_atom in candidate_matches: + observed_distance = float( + np.linalg.norm( + np.asarray(source_atom.coordinates, dtype=float) + - primary_anchor_coordinates + ) + ) + distance_delta = abs(observed_distance - expected_distance) + if best_delta is None or distance_delta < best_delta: + best_source_index = int(source_index) + best_source_atom = source_atom + best_delta = distance_delta + if ( + best_source_index is None + or best_source_atom is None + or best_delta is None + or best_delta + > _xyz_partial_assignment_tolerance( + 
element=reference_element, + reference_match_tolerance_a=reference_match_tolerance_a, + ) + ): + continue + used_source_indices.add(best_source_index) + assigned_source_by_ref_index[int(ref_index)] = ( + best_source_index, + best_source_atom, + ) + assigned_ref_indices.add(int(ref_index)) + + summaries.append( + _build_xyz_partial_candidate_summary( + residue_name=reference_entry.residue_name, + residue_number=candidate_index, + assigned_source_by_ref_index=assigned_source_by_ref_index, + reference_atoms=reference_atoms, + reference_names=reference_names, + ) + ) + return tuple(summaries) + + +def _build_xyz_partial_candidate_summary( + *, + residue_name: str, + residue_number: int, + assigned_source_by_ref_index: dict[int, tuple[int, object]], + reference_atoms: tuple[object, ...], + reference_names: Sequence[str], +) -> SolventShellResidueMismatchSummary: + ordered_assignments = [ + (ref_index, source_index, source_atom) + for ref_index, (source_index, source_atom) in sorted( + assigned_source_by_ref_index.items() + ) + ] + missing_atom_names = tuple( + reference_names[index] + for index in range(len(reference_atoms)) + if index not in assigned_source_by_ref_index + ) + common_reference_coordinates = [ + np.asarray(reference_atoms[ref_index].coordinates, dtype=float) + for ref_index, _source_index, _source_atom in ordered_assignments + ] + common_source_coordinates = [ + np.asarray(source_atom.coordinates, dtype=float) + for _ref_index, _source_index, source_atom in ordered_assignments + ] + distance_pair_count = 0 + distribution_rmsd_a = 0.0 + max_distance_delta_a = 0.0 + if len(common_reference_coordinates) >= 2: + reference_distances = _pairwise_distance_vector( + common_reference_coordinates + ) + source_distances = _pairwise_distance_vector(common_source_coordinates) + if len(reference_distances) == len(source_distances): + distance_deltas = source_distances - reference_distances + distance_pair_count = int(len(reference_distances)) + if distance_pair_count 
> 0: + distribution_rmsd_a = float( + np.sqrt(np.mean(np.square(distance_deltas))) + ) + max_distance_delta_a = float(np.max(np.abs(distance_deltas))) + mismatch_reason = ( + "partial XYZ solvent candidate inferred from unmatched anchor atoms" + if len(ordered_assignments) == 1 + else "partial XYZ solvent candidate inferred from unmatched reference-element atoms" + ) + return SolventShellResidueMismatchSummary( + residue_name=residue_name, + residue_number=residue_number, + observed_atom_count=len(ordered_assignments), + common_atom_count=len(ordered_assignments), + reference_atom_count=len(reference_atoms), + missing_atom_names=missing_atom_names, + extra_atom_names=(), + distance_pair_count=distance_pair_count, + distribution_rmsd_a=distribution_rmsd_a, + max_distance_delta_a=max_distance_delta_a, + mismatch_reason=mismatch_reason, + source_atom_ids=tuple( + int(source_atom.atom_id) + for _ref_index, _source_index, source_atom in ordered_assignments + ), + ) + + +def _select_partial_anchor_index( + reference_atoms: Sequence[object], + *, + available_elements: set[str] | None = None, +) -> int | None: + if not reference_atoms: + return None + non_hydrogen_indices = [ + index + for index, atom in enumerate(reference_atoms) + if str(atom.element).upper() != "H" + and ( + available_elements is None + or str(atom.element).upper() in available_elements + ) + ] + if not non_hydrogen_indices: + fallback_indices = [ + index + for index, atom in enumerate(reference_atoms) + if available_elements is None + or str(atom.element).upper() in available_elements + ] + return fallback_indices[0] if fallback_indices else None + + non_hydrogen_counts = Counter( + str(reference_atoms[index].element).upper() + for index in non_hydrogen_indices + ) + for element in _ANCHOR_ELEMENT_PRIORITY: + matches = [ + index + for index in non_hydrogen_indices + if str(reference_atoms[index].element).upper() == element + and non_hydrogen_counts[element] == 1 + ] + if matches: + return matches[0] + 
for element in _ANCHOR_ELEMENT_PRIORITY: + matches = [ + index + for index in non_hydrogen_indices + if str(reference_atoms[index].element).upper() == element + ] + if matches: + return matches[0] + return non_hydrogen_indices[0] + + +def _xyz_partial_assignment_tolerance( + *, + element: str, + reference_match_tolerance_a: float, +) -> float: + base_tolerance = max(float(reference_match_tolerance_a) * 4.0, 0.75) + if str(element).upper() == "H": + return base_tolerance + return max(base_tolerance, 1.25) + + +def _build_pdb_residue_mismatch_summary( + *, + residue_name: str, + residue_number: int, + residue_atoms: list[object], + reference_entry: ReferenceLibraryEntry, + reference_atoms: tuple[object, ...], + reference_match_tolerance_a: float, +) -> SolventShellResidueMismatchSummary | None: + reference_names = [ + _normalized_atom_name( + str(atom.atom_name), + fallback=f"{atom.element}{index + 1}", + ) + for index, atom in enumerate(reference_atoms) + ] + residue_names = [ + _normalized_atom_name( + str(atom.atom_name), + fallback=f"{atom.element}{index + 1}", + ) + for index, atom in enumerate(residue_atoms) + ] + reference_counts = Counter(reference_names) + residue_counts = Counter(residue_names) + + remaining_missing = reference_counts - residue_counts + missing_names: list[str] = [] + for atom_name in reference_names: + if remaining_missing[atom_name] <= 0: + continue + missing_names.append(atom_name) + remaining_missing[atom_name] -= 1 + + remaining_extra = residue_counts - reference_counts + extra_names: list[str] = [] + for atom_name in residue_names: + if remaining_extra[atom_name] <= 0: + continue + extra_names.append(atom_name) + remaining_extra[atom_name] -= 1 + + missing_atom_names = tuple(missing_names) + extra_atom_names = tuple(extra_names) + if not missing_atom_names and not extra_atom_names: + return None + + residue_atoms_by_name: dict[str, list[object]] = defaultdict(list) + for index, atom in enumerate(residue_atoms): + 
residue_atoms_by_name[residue_names[index]].append(atom) + + consumed_counts: Counter[str] = Counter() + common_reference_coordinates: list[np.ndarray] = [] + common_residue_coordinates: list[np.ndarray] = [] + common_atom_count = 0 + for index, reference_atom in enumerate(reference_atoms): + atom_name = reference_names[index] + atom_matches = residue_atoms_by_name.get(atom_name, []) + occurrence = consumed_counts[atom_name] + if occurrence >= len(atom_matches): + continue + common_atom_count += 1 + common_reference_coordinates.append( + np.asarray(reference_atom.coordinates, dtype=float) + ) + common_residue_coordinates.append( + np.asarray(atom_matches[occurrence].coordinates, dtype=float) + ) + consumed_counts[atom_name] += 1 + if common_atom_count == 0: + return None + + distance_pair_count = 0 + distribution_rmsd_a = 0.0 + max_distance_delta_a = 0.0 + if common_atom_count >= 2: + reference_distances = _pairwise_distance_vector( + common_reference_coordinates + ) + residue_distances = _pairwise_distance_vector( + common_residue_coordinates + ) + if len(reference_distances) == len(residue_distances): + distance_deltas = residue_distances - reference_distances + distance_pair_count = int(len(reference_distances)) + if distance_pair_count > 0: + distribution_rmsd_a = float( + np.sqrt(np.mean(np.square(distance_deltas))) + ) + max_distance_delta_a = float(np.max(np.abs(distance_deltas))) + + same_reference_residue = ( + residue_name.upper().strip() + == reference_entry.residue_name.upper().strip() + ) + geometry_consistent = common_atom_count < 2 or max_distance_delta_a <= max( + float(reference_match_tolerance_a), 0.35 + ) + if not same_reference_residue and ( + common_atom_count < 2 or not geometry_consistent + ): + return None + + reason_parts: list[str] = [] + if missing_atom_names: + reason_parts.append("missing reference atoms") + if extra_atom_names: + reason_parts.append("contains extra non-reference atoms") + if not reason_parts: + reason_parts.append("did 
not resolve to a complete solvent residue") + return SolventShellResidueMismatchSummary( + residue_name=residue_name, + residue_number=residue_number, + observed_atom_count=len(residue_atoms), + common_atom_count=common_atom_count, + reference_atom_count=len(reference_atoms), + missing_atom_names=missing_atom_names, + extra_atom_names=extra_atom_names, + distance_pair_count=distance_pair_count, + distribution_rmsd_a=distribution_rmsd_a, + max_distance_delta_a=max_distance_delta_a, + mismatch_reason=", ".join(reason_parts), + source_atom_ids=tuple(int(atom.atom_id) for atom in residue_atoms), + ) + + +def _pairwise_distance_vector( + coordinates: Sequence[np.ndarray], +) -> np.ndarray: + coordinate_array = np.asarray(coordinates, dtype=float) + if len(coordinate_array) < 2: + return np.zeros(0, dtype=float) + distances: list[float] = [] + for index in range(len(coordinate_array) - 1): + deltas = coordinate_array[index + 1 :] - coordinate_array[index] + distances.extend( + float(value) for value in np.linalg.norm(deltas, axis=1) + ) + return np.asarray(distances, dtype=float) + + +def _build_pdb_analysis_notes( + *, + residue_mismatch_count: int, +) -> tuple[str, ...]: + notes = [ + "PDB matching is constrained to existing residue groups so " + "reported residue names reflect the source file.", + ] + if residue_mismatch_count > 0: + notes.append( + "Incomplete residue groups that retained identifiable solvent " + "atom names are preserved as mismatches with missing-atom details." + ) + return tuple(notes) + + +def _build_xyz_analysis_note( + *, + partial_candidate_count: int, +) -> str: + if partial_candidate_count > 0: + return ( + "XYZ matching uses geometric complete-molecule fits first, then " + "infers partial solvent candidates heuristically from unmatched " + "reference-element atoms so those anchors can guide solvent rebuilds." 
+ ) + return ( + "XYZ matching uses geometric complete-molecule fits first, then " + "checks unmatched atoms for solvent-like partial candidates that " + "could guide solvent rebuilds." + ) + + +__all__ = [ + "DEFAULT_REFERENCE_MATCH_TOLERANCE_A", + "SolventShellAnalysisResult", + "SolventShellBuildResult", + "SolventShellResidueMismatchSummary", + "SolventShellResidueSummary", + "analyze_solvent_shell", + "build_solvent_shell_output", + "default_director_atom_name", + "reference_atom_choices", +] diff --git a/src/saxshell/fullrmc/ui/__init__.py b/src/saxshell/fullrmc/ui/__init__.py index de421a0..954156d 100644 --- a/src/saxshell/fullrmc/ui/__init__.py +++ b/src/saxshell/fullrmc/ui/__init__.py @@ -1,14 +1,24 @@ """Qt UI package for the rmcsetup scaffold.""" +from .constraints_preview_window import ConstraintsPreviewWindow from .main_window import RMCSetupMainWindow, launch_rmcsetup_ui +from .packmol_docker_dialog import PackmolDockerLinkDialog from .representative_preview_window import ( RepresentativePreviewTab, RepresentativePreviewWindow, ) +from .solvent_shell_builder_window import ( + SolventShellBuilderMainWindow, + launch_solvent_shell_builder_ui, +) __all__ = [ + "ConstraintsPreviewWindow", + "PackmolDockerLinkDialog", "RMCSetupMainWindow", "RepresentativePreviewTab", "RepresentativePreviewWindow", + "SolventShellBuilderMainWindow", "launch_rmcsetup_ui", + "launch_solvent_shell_builder_ui", ] diff --git a/src/saxshell/fullrmc/ui/constraints_preview_window.py b/src/saxshell/fullrmc/ui/constraints_preview_window.py new file mode 100644 index 0000000..eff176a --- /dev/null +++ b/src/saxshell/fullrmc/ui/constraints_preview_window.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from pathlib import Path + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import ( + QApplication, + QHBoxLayout, + QLabel, + QMainWindow, + QPlainTextEdit, + QPushButton, + QVBoxLayout, + QWidget, +) + + +class ConstraintsPreviewWindow(QMainWindow): + def __init__( + 
self, + constraints_path: str | Path, + parent: QWidget | None = None, + ) -> None: + super().__init__(parent) + self.constraints_path = Path(constraints_path).expanduser().resolve() + self.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose, True) + self.setWindowFlag(Qt.WindowType.Window, True) + self.setWindowTitle( + f"Merged Constraints Preview - {self.constraints_path.name}" + ) + self.resize(980, 760) + self._build_ui() + self._load_text() + + def _build_ui(self) -> None: + central = QWidget(self) + root = QVBoxLayout(central) + root.setContentsMargins(8, 8, 8, 8) + root.setSpacing(8) + + self.path_label = QLabel(str(self.constraints_path)) + self.path_label.setTextInteractionFlags( + Qt.TextInteractionFlag.TextSelectableByMouse + ) + self.path_label.setWordWrap(True) + root.addWidget(self.path_label) + + controls_row = QHBoxLayout() + self.copy_all_button = QPushButton("Copy All") + self.copy_all_button.clicked.connect(self._copy_all_text) + controls_row.addWidget(self.copy_all_button) + self.reload_button = QPushButton("Reload") + self.reload_button.clicked.connect(self._load_text) + controls_row.addWidget(self.reload_button) + controls_row.addStretch(1) + root.addLayout(controls_row) + + self.text_box = QPlainTextEdit() + self.text_box.setReadOnly(True) + self.text_box.setLineWrapMode(QPlainTextEdit.LineWrapMode.NoWrap) + root.addWidget(self.text_box, stretch=1) + + self.setCentralWidget(central) + + def _load_text(self) -> None: + self.text_box.setPlainText( + self.constraints_path.read_text(encoding="utf-8") + ) + + def _copy_all_text(self) -> None: + QApplication.clipboard().setText(self.text_box.toPlainText()) + self.statusBar().showMessage("Copied merged constraints to clipboard.") diff --git a/src/saxshell/fullrmc/ui/main_window.py b/src/saxshell/fullrmc/ui/main_window.py index b7d5d72..c759845 100644 --- a/src/saxshell/fullrmc/ui/main_window.py +++ b/src/saxshell/fullrmc/ui/main_window.py @@ -1,6 +1,7 @@ from __future__ import annotations import 
argparse +import json import re import sys from datetime import datetime @@ -13,7 +14,8 @@ NavigationToolbar2QT as NavigationToolbar, ) from matplotlib.figure import Figure -from PySide6.QtCore import QObject, Qt, QThread, Signal, Slot +from PySide6.QtCore import QObject, QSettings, Qt, QThread, QUrl, Signal, Slot +from PySide6.QtGui import QAction, QDesktopServices from PySide6.QtWidgets import ( QApplication, QCheckBox, @@ -55,6 +57,13 @@ ConstraintGenerationSettings, build_constraint_generation, ) +from saxshell.fullrmc.packmol_docker import ( + DEFAULT_PACKMOL_CONTAINER_ROOT, + PackmolDockerClient, + PackmolDockerLink, + PackmolDockerSyncResult, + save_packmol_docker_link_metadata, +) from saxshell.fullrmc.packmol_planning import ( PackmolPlanningMetadata, PackmolPlanningSettings, @@ -74,6 +83,7 @@ RepresentativeSelectionMetadata, RepresentativeSelectionSettings, build_representative_preview_clusters, + representative_source_solvent_mode_to_variant, select_distribution_representatives, select_first_file_representatives, ) @@ -94,24 +104,49 @@ ) from saxshell.fullrmc.solvent_handling import ( GeneratedPDBInspection, + RepresentativeSolventDistributionAnalysis, + SoluteAtomBuildSetting, SolventHandlingMetadata, SolventHandlingSettings, + analyze_representative_solvent_distribution, + available_representative_structure_modes, build_generated_pdb_inspections, build_representative_solvent_outputs, list_solvent_reference_presets, + representative_structure_mode_label, + representative_structure_path_for_mode, + resolved_representative_structure_mode, + save_solvent_handling_metadata, + solvent_entry_lookup_for_representatives, +) +from saxshell.fullrmc.solvent_shell_builder import ( + DEFAULT_REFERENCE_MATCH_TOLERANCE_A, + default_director_atom_name, + reference_atom_choices, +) +from saxshell.fullrmc.ui.constraints_preview_window import ( + ConstraintsPreviewWindow, ) from saxshell.fullrmc.ui.generated_pdb_preview_window import ( GeneratedPDBPreviewWindow, ) +from 
saxshell.fullrmc.ui.packmol_docker_dialog import PackmolDockerLinkDialog from saxshell.fullrmc.ui.representative_preview_window import ( RepresentativePreviewWindow, ) +from saxshell.plotting import Q_A_INVERSE_LABEL from saxshell.saxs.dream import ( DreamModelPlotData, DreamSummary, DreamViolinPlotData, SAXSDreamResultsLoader, ) +from saxshell.saxs.electron_density_mapping.ui.viewer import ( + ElectronDensityStructureViewer, +) +from saxshell.saxs.electron_density_mapping.workflow import ( + load_electron_density_structure, +) from saxshell.saxs.project_manager import ( DreamBestFitSelection, SAXSProjectManager, @@ -178,6 +213,7 @@ "#d0b060", "#5f9e8f", ) +_PACKMOL_DOCKER_PRESETS_KEY = "packmol_docker_presets" _READINESS_TASK_DETAILS = { "project_source": { "title": "Project Source", @@ -225,31 +261,40 @@ ), }, "representative_selection": { - "title": "Representative Clusters", + "title": "Representative Structures", "purpose": ( - "Choose one representative structure for each active DREAM " - "cluster bin." + "Load the saved representative structures that will be combined " + "with the active DREAM weights for the Solvent Shell Builder, " + "Packmol planning, Packmol setup, and cluster-specific " + "constraints." ), "needed": ( - "An active DREAM selection. For bond/angle mode, bond-pair and " - "angle-triplet definitions are also needed." + "A saved representative-structure set in the SAXS project plus " + "an active DREAM selection so the current model weights remain " + "available downstream." ), "prerequisites": ( - "The SAXS project must be loaded and a DREAM model must be " - "selected. Cluster-source data must also be available." + "The SAXS project must be loaded. Use the dedicated " + "Representative Structures tool to create or update the saved " + "structure set for this project." 
), }, "solvent_outputs": { - "title": "Representative PDB Outputs", + "title": "Solvent Shell Builder", "purpose": ( - "Build the representative no-solvent and solvent-completed PDB " - "files used for inspection and Packmol setup." + "Use the active representative structure set directly when it " + "already has full solvent, or build the missing solvent shell " + "for no-solvent and partial-solvent representatives." ), "needed": ( - "A coordinated-solvent mode and either a bundled solvent preset " - "or a custom solvent-reference PDB." + "Either imported representative structures that already provide " + "the Full solvent set, or a solvent reference plus the saved " + "representative solvent outputs with Full solvent selected as " + "active." + ), + "prerequisites": ( + "Representative structures must already be saved for the project." ), - "prerequisites": ("Representative clusters must already be computed."), }, "packmol_plan": { "title": "Packmol Plan", @@ -259,10 +304,10 @@ ), "needed": ( "Packmol planning mode, box side length, saved solution " - "properties, and saved representative clusters." + "properties, and saved representative structures." ), "prerequisites": ( - "Solution properties and representative clusters must both be " + "Solution properties and representative structures must both be " "available first." ), }, @@ -273,8 +318,8 @@ "the simulation box." ), "needed": ( - "A saved Packmol plan and the generated representative solvent " - "PDB files." + "A saved Packmol plan and the Full solvent representative " + "structure set." 
), "prerequisites": ( "Packmol planning and representative PDB solvent outputs must " @@ -384,19 +429,38 @@ def __init__( self.setWindowTitle("SAXSShell (rmcsetup)") self.setWindowIcon(load_saxshell_icon()) self.resize(1080, 860) + self._build_menu_bar() self.project_manager = SAXSProjectManager() self._project_source_state: RMCDreamProjectSource | None = None + self._child_tool_windows: list[object] = [] self._representative_preview_window: ( RepresentativePreviewWindow | None ) = None self._generated_pdb_preview_window: ( GeneratedPDBPreviewWindow | None ) = None + self._constraints_preview_window: ConstraintsPreviewWindow | None = ( + None + ) + self._solvent_distribution_analysis: ( + RepresentativeSolventDistributionAnalysis | None + ) = None self._updating_dream_controls = False self._representative_presets: dict[str, BondAnalysisPreset] = {} self._solution_presets: dict[str, SolutionPropertiesPreset] = {} - self._generated_pdb_inspections: list[GeneratedPDBInspection] = [] + self._generated_pdb_inspections: list[ + GeneratedPDBInspection | None + ] = [] + self._solvent_table_preview_paths: list[Path | None] = [] + self._solvent_table_details: list[str] = [] + self._updating_generated_pdb_mode_combo = False + self._solvent_cutoff_spins: dict[str, QDoubleSpinBox] = {} + self._solvent_coordination_center_items: dict[ + str, QTableWidgetItem + ] = {} + self._solvent_coordination_target_spins: dict[str, QDoubleSpinBox] = {} + self._updating_solvent_table = False self._current_dream_model_plot_data: DreamModelPlotData | None = None self._representative_thread: QThread | None = None self._representative_worker: RepresentativeSelectionWorker | None = ( @@ -407,6 +471,8 @@ def __init__( self._dream_results_loader_cache: dict[ Path, SAXSDreamResultsLoader ] = {} + self._section_toggle_buttons: dict[str, QToolButton] = {} + self._section_content_widgets: dict[str, QWidget] = {} self._readiness_checkboxes: dict[str, QCheckBox] = {} self._available_solvent_presets = 
list_solvent_reference_presets() self._dream_model_preview_figure = Figure( @@ -568,6 +634,11 @@ def __init__( "Ready", "Checked after the SAXS project source loads successfully.", ), + section_key="project_source", + ) + project_content_layout = self._create_collapsible_section_layout( + project_layout, + "project_source", ) project_row = QHBoxLayout() self.project_dir_edit = QLineEdit() @@ -586,12 +657,12 @@ def __init__( self.refresh_button = QPushButton("Reload Project") self.refresh_button.clicked.connect(self._refresh_project_source) project_row.addWidget(self.refresh_button) - project_layout.addLayout(project_row) + project_content_layout.addLayout(project_row) self.project_summary_box = QPlainTextEdit() self.project_summary_box.setReadOnly(True) self.project_summary_box.setMinimumHeight(170) - project_layout.addWidget(self.project_summary_box) + project_content_layout.addWidget(self.project_summary_box) self._left_layout.addWidget(self.project_group) self.output_group = QGroupBox("RMCSetup Output Structure") @@ -611,6 +682,11 @@ def __init__( "Ready", "Checked after a DREAM run is selected for rmcsetup.", ), + section_key="dream_selection", + ) + dream_content_layout = self._create_collapsible_section_layout( + dream_layout, + "dream_selection", ) dream_form = QFormLayout() @@ -680,12 +756,12 @@ def __init__( interval_container = QWidget() interval_container.setLayout(interval_row) dream_form.addRow("Credible interval (%)", interval_container) - dream_layout.addLayout(dream_form) + dream_content_layout.addLayout(dream_form) self.dream_source_summary_box = QPlainTextEdit() self.dream_source_summary_box.setReadOnly(True) self.dream_source_summary_box.setMinimumHeight(170) - dream_layout.addWidget(self.dream_source_summary_box) + dream_content_layout.addWidget(self.dream_source_summary_box) self._left_layout.addWidget(self.dream_group) self.favorite_group = QGroupBox("Saved DREAM Model") @@ -729,6 +805,11 @@ def __init__( "Ready", "Checked after solution 
properties are calculated.", ), + section_key="solution_properties", + ) + solution_content_layout = self._create_collapsible_section_layout( + solution_layout, + "solution_properties", ) self.solution_preset_group = QGroupBox("Solution Presets") solution_preset_layout = QVBoxLayout(self.solution_preset_group) @@ -753,7 +834,7 @@ def __init__( ) self.solution_preset_hint_label.setWordWrap(True) solution_preset_layout.addWidget(self.solution_preset_hint_label) - solution_layout.addWidget(self.solution_preset_group) + solution_content_layout.addWidget(self.solution_preset_group) solution_form = QFormLayout() self.solution_mode_combo = QComboBox() @@ -822,13 +903,13 @@ def __init__( "Solvent molar mass (g/mol)", self.molar_mass_solvent_spin, ) - solution_layout.addLayout(solution_form) + solution_content_layout.addLayout(solution_form) self.solution_mode_hint_label = QLabel() self.solution_mode_hint_label.setWordWrap(True) self.solution_mode_hint_label.setText( solution_properties_mode_hint_text("mass") ) - solution_layout.addWidget(self.solution_mode_hint_label) + solution_content_layout.addWidget(self.solution_mode_hint_label) self.solution_mode_stack = QStackedWidget() @@ -909,7 +990,7 @@ def __init__( ) self.solution_mode_stack.addWidget(molarity_page) - solution_layout.addWidget(self.solution_mode_stack) + solution_content_layout.addWidget(self.solution_mode_stack) solution_button_row = QHBoxLayout() self.calculate_solution_button = QPushButton("Calculate") @@ -918,19 +999,19 @@ def __init__( ) solution_button_row.addWidget(self.calculate_solution_button) solution_button_row.addStretch(1) - solution_layout.addLayout(solution_button_row) + solution_content_layout.addLayout(solution_button_row) self.solution_output_box = QPlainTextEdit() self.solution_output_box.setReadOnly(True) self.solution_output_box.setMinimumHeight(220) - solution_layout.addWidget(self.solution_output_box) + solution_content_layout.addWidget(self.solution_output_box) 
self._left_layout.addWidget(self.solution_group) self.dream_preview_group = QGroupBox("Selected DREAM Model Preview") dream_preview_layout = QVBoxLayout(self.dream_preview_group) self.dream_preview_intro_label = QLabel( "Preview the selected DREAM model fit and posterior weight " - "distributions before building representative clusters or " + "distributions before loading representative structures or " "Packmol inputs." ) self.dream_preview_intro_label.setWordWrap(True) @@ -994,19 +1075,49 @@ def __init__( dream_preview_layout.addWidget(self._dream_preview_splitter) self._right_layout.addWidget(self.dream_preview_group) - self.representative_group = QGroupBox( - "Representative Cluster Selection" - ) + self.representative_group = QGroupBox("Representative Structures") representative_layout = QVBoxLayout(self.representative_group) self._add_group_readiness_row( representative_layout, ( "representative_selection", "Ready", - "Checked after representative clusters are computed.", + "Checked after representative structures are saved and " + "loaded for the active project.", ), + section_key="representative_selection", + ) + representative_content_layout = ( + self._create_collapsible_section_layout( + representative_layout, + "representative_selection", + ) + ) + self.representative_intro_label = QLabel( + "rmcsetup consumes saved representative structures from the " + "dedicated Representative Structures tool. Those saved files are " + "combined here with the selected DREAM distribution weights, the " + "solution density targets, Solvent Shell Builder, Packmol planning, " + "and cluster-specific constraint generation." + ) + self.representative_intro_label.setWordWrap(True) + representative_content_layout.addWidget( + self.representative_intro_label + ) + self.representative_workflow_label = QLabel( + "Use Open Representative Structures to create or update the " + "saved project set, then reload it here. 
The active DREAM model " + "selection in rmcsetup remains the source of the fitted weights " + "used for downstream box planning and constraint generation." ) - representative_form = QFormLayout() + self.representative_workflow_label.setWordWrap(True) + representative_content_layout.addWidget( + self.representative_workflow_label + ) + + self.representative_mode_widget = QWidget() + representative_form = QFormLayout(self.representative_mode_widget) + representative_form.setContentsMargins(0, 0, 0, 0) self.representative_mode_combo = QComboBox() for label, value in _REPRESENTATIVE_MODE_ITEMS: self.representative_mode_combo.addItem(label, value) @@ -1017,7 +1128,10 @@ def __init__( "Selection mode", self.representative_mode_combo, ) - representative_layout.addLayout(representative_form) + self.representative_mode_widget.setVisible(False) + representative_content_layout.addWidget( + self.representative_mode_widget + ) self.representative_preset_group = QGroupBox("Bondanalysis Presets") representative_preset_layout = QVBoxLayout( @@ -1055,7 +1169,10 @@ def __init__( representative_preset_layout.addWidget( self.representative_preset_hint_label ) - representative_layout.addWidget(self.representative_preset_group) + self.representative_preset_group.setVisible(False) + representative_content_layout.addWidget( + self.representative_preset_group + ) self.representative_bond_pairs_row = QGroupBox("Bond Pairs") representative_bond_pairs_layout = QVBoxLayout( @@ -1093,7 +1210,10 @@ def __init__( self.representative_bond_pair_table ) self._add_empty_representative_bond_pair_row(blocked=True) - representative_layout.addWidget(self.representative_bond_pairs_row) + self.representative_bond_pairs_row.setVisible(False) + representative_content_layout.addWidget( + self.representative_bond_pairs_row + ) self.representative_angle_triplets_row = QGroupBox("Angle Triplets") representative_angle_triplets_layout = QVBoxLayout( @@ -1139,7 +1259,10 @@ def __init__( 
self.representative_angle_triplet_table ) self._add_empty_representative_angle_triplet_row(blocked=True) - representative_layout.addWidget(self.representative_angle_triplets_row) + self.representative_angle_triplets_row.setVisible(False) + representative_content_layout.addWidget( + self.representative_angle_triplets_row + ) self.representative_advanced_toggle = QPushButton( "Show Advanced Settings" @@ -1148,7 +1271,10 @@ def __init__( self.representative_advanced_toggle.toggled.connect( self._toggle_representative_advanced_settings ) - representative_layout.addWidget(self.representative_advanced_toggle) + self.representative_advanced_toggle.setVisible(False) + representative_content_layout.addWidget( + self.representative_advanced_toggle + ) self.representative_advanced_widget = QWidget() representative_advanced_layout = QFormLayout( @@ -1190,71 +1316,87 @@ def __init__( self.representative_angle_weight_spin, ) self.representative_advanced_widget.setVisible(False) - representative_layout.addWidget(self.representative_advanced_widget) + representative_content_layout.addWidget( + self.representative_advanced_widget + ) representative_button_row = QHBoxLayout() self.compute_representatives_button = QPushButton( - "Compute Representative Clusters" + "Open Representative Structures" ) self.compute_representatives_button.clicked.connect( - self._compute_representative_clusters + self._open_representative_structures_tool ) representative_button_row.addWidget( self.compute_representatives_button ) self.preview_representatives_button = QPushButton( - "Preview Representative Analysis" + "Reload Saved Representative Structures" ) self.preview_representatives_button.clicked.connect( - self._preview_representative_clusters + self._reload_saved_representative_structures ) representative_button_row.addWidget( self.preview_representatives_button ) representative_button_row.addStretch(1) - representative_layout.addLayout(representative_button_row) + 
representative_content_layout.addLayout(representative_button_row) self.representative_status_label = QLabel( - "Representative selection: idle" + "Representative structures: waiting for saved project data." ) self.representative_status_label.setWordWrap(True) - representative_layout.addWidget(self.representative_status_label) + representative_content_layout.addWidget( + self.representative_status_label + ) self.representative_progress_bar = QProgressBar() self.representative_progress_bar.setRange(0, 1) self.representative_progress_bar.setValue(0) - representative_layout.addWidget(self.representative_progress_bar) + self.representative_progress_bar.setVisible(False) + representative_content_layout.addWidget( + self.representative_progress_bar + ) self.representative_summary_box = QPlainTextEdit() self.representative_summary_box.setReadOnly(True) self.representative_summary_box.setMinimumHeight(210) - representative_layout.addWidget(self.representative_summary_box) + representative_content_layout.addWidget( + self.representative_summary_box + ) self._right_layout.addWidget(self.representative_group) - self.solvent_group = QGroupBox("Solvent Handling") + self.solvent_group = QGroupBox("Solvent Shell Builder") solvent_layout = QVBoxLayout(self.solvent_group) self._add_group_readiness_row( solvent_layout, ( "solvent_outputs", "Ready", - "Checked after solvent-aware representative PDB outputs are built.", + "Checked when the active representative structure set already " + "has full solvent.", ), + section_key="solvent_outputs", ) - solvent_form = QFormLayout() - - self.coordinated_solvent_mode_combo = QComboBox() - for label, value in _COORDINATED_SOLVENT_MODE_ITEMS: - self.coordinated_solvent_mode_combo.addItem(label, value) - solvent_form.addRow( - "Coordinated solvent mode", - self.coordinated_solvent_mode_combo, + solvent_content_layout = self._create_collapsible_section_layout( + solvent_layout, + "solvent_outputs", + ) + self.solvent_intro_label = QLabel( + "This 
subsection is active for no-solvent and partial-solvent " + "representative structure sets. The build action analyzes the " + "selected representatives, then writes the completed full-solvent " + "PDBs for previewing and Packmol." ) + self.solvent_intro_label.setWordWrap(True) + solvent_content_layout.addWidget(self.solvent_intro_label) + + solvent_form = QFormLayout() self.solvent_reference_source_combo = QComboBox() for label, value in _SOLVENT_REFERENCE_SOURCE_ITEMS: self.solvent_reference_source_combo.addItem(label, value) self.solvent_reference_source_combo.currentIndexChanged.connect( - self._update_solvent_reference_widgets + self._handle_solvent_reference_source_changed ) solvent_form.addRow( "Reference source", @@ -1264,6 +1406,9 @@ def __init__( self.solvent_preset_combo = QComboBox() for preset in self._available_solvent_presets: self.solvent_preset_combo.addItem(preset.name, preset.name) + self.solvent_preset_combo.currentIndexChanged.connect( + self._handle_solvent_reference_changed + ) solvent_form.addRow("Preset reference", self.solvent_preset_combo) solvent_path_row = QHBoxLayout() @@ -1271,6 +1416,9 @@ def __init__( self.solvent_reference_edit.setPlaceholderText( "Choose a solvent reference PDB" ) + self.solvent_reference_edit.editingFinished.connect( + self._handle_solvent_reference_changed + ) solvent_path_row.addWidget(self.solvent_reference_edit, stretch=1) self.browse_solvent_reference_button = QPushButton("Browse...") self.browse_solvent_reference_button.clicked.connect( @@ -1281,46 +1429,144 @@ def __init__( solvent_path_widget.setLayout(solvent_path_row) solvent_form.addRow("Custom reference PDB", solvent_path_widget) + self.solvent_reference_match_tolerance_spin = self._new_float_spin( + maximum=5.0, + step=0.05, + decimals=3, + value=DEFAULT_REFERENCE_MATCH_TOLERANCE_A, + ) + self.solvent_reference_match_tolerance_spin.valueChanged.connect( + self._handle_solvent_analysis_setting_changed + ) + solvent_form.addRow( + "Reference match 
tolerance (A)", + self.solvent_reference_match_tolerance_spin, + ) + + self.solvent_director_atom_combo = QComboBox() + solvent_form.addRow( + "Director atom", + self.solvent_director_atom_combo, + ) + self.solvent_minimum_separation_spin = self._new_float_spin( maximum=10.0, step=0.1, decimals=2, value=1.2, ) + self.solvent_minimum_separation_spin.valueChanged.connect( + self._update_solvent_build_panel_state + ) solvent_form.addRow( "Minimum solvent atom separation (A)", self.solvent_minimum_separation_spin, ) - solvent_layout.addLayout(solvent_form) + solvent_content_layout.addLayout(solvent_form) + + self.solvent_reference_details_box = QPlainTextEdit() + self.solvent_reference_details_box.setReadOnly(True) + self.solvent_reference_details_box.setMinimumHeight(90) + solvent_content_layout.addWidget(self.solvent_reference_details_box) + + self.solvent_status_group = QGroupBox( + "Detected Representative Solvent State" + ) + solvent_status_layout = QVBoxLayout(self.solvent_status_group) + self.solvent_status_headline_label = QLabel( + "No representative solvent state has been determined yet." + ) + self.solvent_status_headline_label.setWordWrap(True) + solvent_status_layout.addWidget(self.solvent_status_headline_label) + self.solvent_status_stats_label = QLabel( + "Save representative structures, choose a solvent reference, " + "and press Build Solvent-Decorated Representative PDBs." 
+ ) + self.solvent_status_stats_label.setWordWrap(True) + self.solvent_status_stats_label.setTextInteractionFlags( + Qt.TextInteractionFlag.TextSelectableByMouse + ) + solvent_status_layout.addWidget(self.solvent_status_stats_label) + solvent_content_layout.addWidget(self.solvent_status_group) + + self.solvent_cutoff_group = QGroupBox("Solute Coordination Settings") + solvent_cutoff_layout = QVBoxLayout(self.solvent_cutoff_group) + self.solvent_cutoff_status_label = QLabel( + "Analyze or build the representative structures to populate the " + "solute atom types used for solvent-shell building." + ) + self.solvent_cutoff_status_label.setWordWrap(True) + solvent_cutoff_layout.addWidget(self.solvent_cutoff_status_label) + self.solvent_cutoff_table = QTableWidget(0, 5) + self.solvent_cutoff_table.setHorizontalHeaderLabels( + [ + "Element", + "Count", + "Coordination Center", + "Avg Coord #", + "Director Distance (A)", + ] + ) + self.solvent_cutoff_table.setEditTriggers( + QTableWidget.EditTrigger.NoEditTriggers + ) + self.solvent_cutoff_table.itemChanged.connect( + self._handle_solvent_coordination_table_item_changed + ) + self.solvent_cutoff_table.horizontalHeader().setStretchLastSection( + True + ) + solvent_cutoff_layout.addWidget(self.solvent_cutoff_table) + solvent_content_layout.addWidget(self.solvent_cutoff_group) solvent_button_row = QHBoxLayout() + self.analyze_solvent_outputs_button = QPushButton( + "Analyze / Refresh Solvent State" + ) + self.analyze_solvent_outputs_button.clicked.connect( + self._analyze_representative_solvent_states + ) + solvent_button_row.addWidget(self.analyze_solvent_outputs_button) self.build_solvent_outputs_button = QPushButton( - "Build Representative PDBs" + "Build Solvent-Decorated Representative PDBs" ) self.build_solvent_outputs_button.clicked.connect( self._build_representative_solvent_outputs ) solvent_button_row.addWidget(self.build_solvent_outputs_button) solvent_button_row.addStretch(1) - 
solvent_layout.addLayout(solvent_button_row) + solvent_content_layout.addLayout(solvent_button_row) self.solvent_summary_box = QPlainTextEdit() self.solvent_summary_box.setReadOnly(True) self.solvent_summary_box.setMinimumHeight(190) - solvent_layout.addWidget(self.solvent_summary_box) + solvent_content_layout.addWidget(self.solvent_summary_box) self.generated_pdb_group = QGroupBox( - "Generated Representative PDB Files" + "Active Representative Structures" ) generated_pdb_layout = QVBoxLayout(self.generated_pdb_group) self.generated_pdb_intro_label = QLabel( - "Browse the generated no-solvent and completed representative " - "PDB files, review residue assignments, and open a 3D preview " - "of the selected structure." + "Choose the active representative structure set, then select a " + "row to inspect the currently selected representative structure " + "and preview it directly below. Solvent-aware variants appear " + "here after the Solvent Shell Builder has run." ) self.generated_pdb_intro_label.setWordWrap(True) generated_pdb_layout.addWidget(self.generated_pdb_intro_label) + generated_pdb_mode_row = QHBoxLayout() + generated_pdb_mode_row.addWidget(QLabel("Active structure set")) + self.generated_pdb_mode_combo = QComboBox() + self.generated_pdb_mode_combo.currentIndexChanged.connect( + self._handle_generated_pdb_mode_changed + ) + generated_pdb_mode_row.addWidget( + self.generated_pdb_mode_combo, + stretch=1, + ) + generated_pdb_layout.addLayout(generated_pdb_mode_row) + generated_pdb_button_row = QHBoxLayout() self.open_generated_pdb_preview_button = QPushButton( "Open Selected Preview" @@ -1334,16 +1580,15 @@ def __init__( generated_pdb_button_row.addStretch(1) generated_pdb_layout.addLayout(generated_pdb_button_row) - self.generated_pdb_table = QTableWidget(0, 7) + self.generated_pdb_table = QTableWidget(0, 6) self.generated_pdb_table.setHorizontalHeaderLabels( [ "Representative", - "Variant", - "File", + "Detected State", + "Active Set", + "Structure File", 
"Atoms", - "Elements", - "Solvent Molecules", - "Molecule Residues", + "Source", ] ) self.generated_pdb_table.setSelectionBehavior( @@ -1366,9 +1611,20 @@ def __init__( self.generated_pdb_details_box = QPlainTextEdit() self.generated_pdb_details_box.setReadOnly(True) - self.generated_pdb_details_box.setMinimumHeight(210) + self.generated_pdb_details_box.setMinimumHeight(150) generated_pdb_layout.addWidget(self.generated_pdb_details_box) - solvent_layout.addWidget(self.generated_pdb_group) + + self.generated_pdb_viewer_status_label = QLabel( + "Select a representative row to preview the active structure." + ) + self.generated_pdb_viewer_status_label.setWordWrap(True) + generated_pdb_layout.addWidget(self.generated_pdb_viewer_status_label) + self.generated_pdb_viewer = ElectronDensityStructureViewer( + self.generated_pdb_group + ) + self.generated_pdb_viewer.setMinimumHeight(360) + generated_pdb_layout.addWidget(self.generated_pdb_viewer, stretch=1) + representative_content_layout.addWidget(self.generated_pdb_group) self._right_layout.addWidget(self.solvent_group) self.packmol_group = QGroupBox("Packmol Planning") @@ -1385,7 +1641,26 @@ def __init__( "Setup Ready", "Checked after Packmol setup inputs are built.", ), + section_key="packmol", ) + packmol_content_layout = self._create_collapsible_section_layout( + packmol_layout, + "packmol", + ) + self.packmol_docker_group = QGroupBox("Linked Packmol Docker") + packmol_docker_layout = QVBoxLayout(self.packmol_docker_group) + self.packmol_docker_hint_label = QLabel( + "Use Tools > Link Packmol Docker Container to validate a " + "container, confirm Packmol is installed, and select the " + f"container-side project folder inside {DEFAULT_PACKMOL_CONTAINER_ROOT}." 
+ ) + self.packmol_docker_hint_label.setWordWrap(True) + packmol_docker_layout.addWidget(self.packmol_docker_hint_label) + self.packmol_docker_summary_box = QPlainTextEdit() + self.packmol_docker_summary_box.setReadOnly(True) + self.packmol_docker_summary_box.setMinimumHeight(150) + packmol_docker_layout.addWidget(self.packmol_docker_summary_box) + packmol_content_layout.addWidget(self.packmol_docker_group) packmol_form = QFormLayout() self.packmol_planning_mode_combo = QComboBox() @@ -1406,7 +1681,12 @@ def __init__( "Box side length (A)", self.packmol_box_side_spin, ) - packmol_layout.addLayout(packmol_form) + self.packmol_free_solvent_combo = QComboBox() + packmol_form.addRow( + "Free solvent structure", + self.packmol_free_solvent_combo, + ) + packmol_content_layout.addLayout(packmol_form) packmol_button_row = QHBoxLayout() self.compute_packmol_plan_button = QPushButton( @@ -1416,25 +1696,41 @@ def __init__( self._compute_packmol_plan ) packmol_button_row.addWidget(self.compute_packmol_plan_button) + packmol_button_row.addWidget(QLabel("Tolerance (A)")) + self.packmol_tolerance_spin = QDoubleSpinBox() + self.packmol_tolerance_spin.setDecimals(3) + self.packmol_tolerance_spin.setRange(0.1, 100.0) + self.packmol_tolerance_spin.setSingleStep(0.1) + self.packmol_tolerance_spin.setValue(2.0) + self.packmol_tolerance_spin.setSuffix(" A") + packmol_button_row.addWidget(self.packmol_tolerance_spin) self.build_packmol_setup_button = QPushButton("Build Packmol Setup") self.build_packmol_setup_button.clicked.connect( self._build_packmol_setup ) packmol_button_row.addWidget(self.build_packmol_setup_button) + self.open_packmol_setup_folder_button = QPushButton( + "Open Packmol Setup Folder" + ) + self.open_packmol_setup_folder_button.clicked.connect( + self._open_packmol_setup_folder + ) + self.open_packmol_setup_folder_button.setEnabled(False) + packmol_button_row.addWidget(self.open_packmol_setup_folder_button) packmol_button_row.addStretch(1) - 
packmol_layout.addLayout(packmol_button_row) + packmol_content_layout.addLayout(packmol_button_row) self.packmol_plan_summary_box = QPlainTextEdit() self.packmol_plan_summary_box.setReadOnly(True) self.packmol_plan_summary_box.setMinimumHeight(180) - packmol_layout.addWidget(self.packmol_plan_summary_box) - packmol_layout.addWidget(self._packmol_plan_toolbar) + packmol_content_layout.addWidget(self.packmol_plan_summary_box) + packmol_content_layout.addWidget(self._packmol_plan_toolbar) self._packmol_plan_canvas.setMinimumHeight(260) - packmol_layout.addWidget(self._packmol_plan_canvas) + packmol_content_layout.addWidget(self._packmol_plan_canvas) self.packmol_build_summary_box = QPlainTextEdit() self.packmol_build_summary_box.setReadOnly(True) self.packmol_build_summary_box.setMinimumHeight(150) - packmol_layout.addWidget(self.packmol_build_summary_box) + packmol_content_layout.addWidget(self.packmol_build_summary_box) self._right_layout.addWidget(self.packmol_group) self.constraints_group = QGroupBox("Constraint Generation") @@ -1468,6 +1764,22 @@ def __init__( self._generate_constraints ) constraints_button_row.addWidget(self.generate_constraints_button) + self.open_constraints_folder_button = QPushButton( + "Open Constraints Folder" + ) + self.open_constraints_folder_button.clicked.connect( + self._open_constraints_folder + ) + self.open_constraints_folder_button.setEnabled(False) + constraints_button_row.addWidget(self.open_constraints_folder_button) + self.preview_constraints_button = QPushButton( + "Show Merged Constraints" + ) + self.preview_constraints_button.clicked.connect( + self._open_constraints_preview + ) + self.preview_constraints_button.setEnabled(False) + constraints_button_row.addWidget(self.preview_constraints_button) constraints_button_row.addStretch(1) constraints_layout.addLayout(constraints_button_row) @@ -1485,10 +1797,11 @@ def __init__( "\n".join( [ "1. Load a SAXS project and choose a DREAM result source.", - "2. 
Validate cluster sources and select representative-cluster mode.", - "3. Enter solution properties for Packmol box construction.", - "4. Compute representative structures and inspect Packmol planning counts.", - "5. Build representative structures, constraints, and fullrmc inputs.", + "2. Open Representative Structures and save the representative project set.", + "3. Enter solution properties for the target box density and composition.", + "4. Build solvent shells as needed and convert DREAM " + "weights into Packmol cluster counts.", + "5. Build the Packmol box inputs and cluster-specific fullrmc constraints.", ] ) ) @@ -1517,6 +1830,116 @@ def __init__( else: self._refresh_project_source() + def _build_menu_bar(self) -> None: + menu_bar = self.menuBar() + self.tools_menu = menu_bar.addMenu("Tools") + self.open_representative_structures_action = QAction( + "Open Representative Structures", + self, + ) + self.open_representative_structures_action.triggered.connect( + self._open_representative_structures_tool + ) + self.tools_menu.addAction(self.open_representative_structures_action) + self.link_packmol_docker_action = QAction( + "Link Packmol Docker Container", + self, + ) + self.link_packmol_docker_action.triggered.connect( + self._open_packmol_docker_link_dialog + ) + self.tools_menu.addSeparator() + self.tools_menu.addAction(self.link_packmol_docker_action) + + def _track_child_tool_window(self, window: object) -> None: + if window in self._child_tool_windows: + return + if isinstance(window, QWidget): + window.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose, True) + destroyed_signal = getattr(window, "destroyed", None) + if destroyed_signal is not None and hasattr( + destroyed_signal, "connect" + ): + destroyed_signal.connect( + lambda _obj=None, win=window: self._forget_child_tool_window( + win + ) + ) + self._child_tool_windows.append(window) + + def _forget_child_tool_window(self, window: object) -> None: + self._child_tool_windows = [ + existing + for 
existing in self._child_tool_windows + if existing is not window + ] + + def _reload_saved_representative_structures(self) -> None: + if self.project_dir() is None: + QMessageBox.information( + self, + "No SAXS project loaded", + "Load a SAXS project before reloading representative structures.", + ) + return + self._append_run_log("Reloading saved representative structures.") + self._refresh_project_source() + + def _handle_representative_structure_results_changed( + self, + project_dir_text: str, + ) -> None: + current_project_dir = self.project_dir() + if current_project_dir is None: + return + try: + changed_project_dir = Path(project_dir_text).expanduser().resolve() + except Exception: + return + if changed_project_dir != current_project_dir: + return + self._append_run_log( + "Representative structures were updated in the dedicated tool; " + "reloading the saved project set." + ) + self._refresh_project_source() + + def _open_representative_structures_tool(self) -> None: + from saxshell.representativefinder.ui.main_window import ( + launch_representativefinder_ui, + ) + + state = self._project_source_state + if state is None: + QMessageBox.information( + self, + "No SAXS project loaded", + "Load a SAXS project before opening Representative Structures.", + ) + return + project_dir = Path(state.settings.project_dir).resolve() + initial_input_path = state.settings.resolved_clusters_dir + window = launch_representativefinder_ui( + initial_project_dir=project_dir, + initial_input_path=initial_input_path, + ) + project_results_changed = getattr( + window, "project_results_changed", None + ) + if project_results_changed is not None and hasattr( + project_results_changed, "connect" + ): + project_results_changed.connect( + self._handle_representative_structure_results_changed + ) + self._track_child_tool_window(window) + self.statusBar().showMessage( + f"Opened representative structures for {project_dir}" + ) + self._append_run_log( + "Opened the Representative 
Structures tool for this project." + ) + def project_dir(self) -> Path | None: text = self.project_dir_edit.text().strip() if not text: @@ -1546,6 +1969,112 @@ def _browse_project_dir(self) -> None: def _toggle_software_details(self, checked: bool) -> None: self.software_details_panel.setVisible(checked) + def _packmol_docker_settings(self) -> QSettings: + return QSettings("SAXShell", "RMCSetup") + + def _create_packmol_docker_client(self) -> PackmolDockerClient: + return PackmolDockerClient() + + def _recent_packmol_docker_presets(self) -> list[PackmolDockerLink]: + raw_value = self._packmol_docker_settings().value( + _PACKMOL_DOCKER_PRESETS_KEY, + "[]", + ) + if isinstance(raw_value, str): + try: + payload = json.loads(raw_value) + except Exception: + payload = [] + elif isinstance(raw_value, (list, tuple)): + payload = list(raw_value) + else: + payload = [] + presets: list[PackmolDockerLink] = [] + for entry in payload: + preset = PackmolDockerLink.from_dict( + dict(entry) if isinstance(entry, dict) else None + ) + if preset is not None: + presets.append(preset) + return presets + + def _remember_packmol_docker_preset( + self, + link: PackmolDockerLink, + ) -> None: + preset = PackmolDockerLink.from_dict(link.to_preset_dict()) + if preset is None: + return + signature = ( + preset.container_name, + preset.packmol_command, + preset.shell_command, + preset.container_project_root, + ) + kept = [ + existing + for existing in self._recent_packmol_docker_presets() + if ( + existing.container_name, + existing.packmol_command, + existing.shell_command, + existing.container_project_root, + ) + != signature + ] + payload = [preset.to_preset_dict()] + [ + item.to_preset_dict() for item in kept[:7] + ] + self._packmol_docker_settings().setValue( + _PACKMOL_DOCKER_PRESETS_KEY, + json.dumps(payload), + ) + + def _save_packmol_docker_link( + self, + link: PackmolDockerLink | None, + ) -> None: + state = self._project_source_state + if state is None: + return + 
save_packmol_docker_link_metadata( + state.rmcsetup_paths.packmol_docker_link_path, + link, + ) + state.packmol_docker_link = link + + def _open_packmol_docker_link_dialog(self) -> None: + state = self._project_source_state + if state is None: + QMessageBox.information( + self, + "No SAXS project loaded", + "Load a SAXS project before linking a Packmol Docker container.", + ) + return + dialog = PackmolDockerLinkDialog( + current_link=state.packmol_docker_link, + recent_presets=self._recent_packmol_docker_presets(), + docker_client=self._create_packmol_docker_client(), + parent=self, + ) + if not dialog.exec(): + return + link = dialog.selected_link() + if link is None: + return + link.linked_at = datetime.now().isoformat(timespec="seconds") + self._remember_packmol_docker_preset(link) + self._save_packmol_docker_link(link) + self.packmol_docker_summary_box.setPlainText( + self._packmol_docker_summary_text(state.packmol_setup) + ) + self.output_summary_box.setPlainText(self._output_structure_text()) + self._append_run_log( + "Linked Packmol Docker container " + f"{link.container_name} at {link.container_project_root}." 
+ ) + def _refresh_project_source(self) -> None: self._set_task_progress("Loading project source...", 10) project_dir = self.project_dir() @@ -1693,11 +2222,18 @@ def _output_structure_text(self) -> str: paths = state.rmcsetup_paths lines = [ f"Root: {paths.rmcsetup_dir}", - f"Representative clusters: {paths.representative_clusters_dir}", + ( + "Representative structures root: " + f"{paths.representative_clusters_dir}" + ), ( "Representative metadata: " f"{paths.representative_selection_path}" ), + ( + "Partial-solvent representatives: " + f"{paths.representative_partial_solvent_dir}" + ), f"PDBs without solvent: {paths.pdb_no_solvent_dir}", f"PDBs with solvent: {paths.pdb_with_solvent_dir}", f"Packmol inputs: {paths.packmol_inputs_dir}", @@ -1706,6 +2242,7 @@ def _output_structure_text(self) -> str: ("Distribution metadata: " f"{paths.distribution_selection_path}"), ("Solution metadata: " f"{paths.solution_properties_path}"), ("Solvent metadata: " f"{paths.solvent_handling_path}"), + ("Packmol Docker link: " f"{paths.packmol_docker_link_path}"), ("Packmol plan metadata: " f"{paths.packmol_plan_path}"), ("Packmol setup metadata: " f"{paths.packmol_setup_path}"), ("Constraint metadata: " f"{paths.constraint_generation_path}"), @@ -1914,9 +2451,31 @@ def _add_group_readiness_row( self, layout: QVBoxLayout, *definitions: tuple[str, str, str], + section_key: str | None = None, ) -> None: row = QHBoxLayout() row.setContentsMargins(0, 0, 0, 0) + if section_key is not None: + toggle = QToolButton() + toggle.setCheckable(True) + toggle.setChecked(True) + toggle.setArrowType(Qt.ArrowType.DownArrow) + toggle.setToolButtonStyle( + Qt.ToolButtonStyle.ToolButtonTextBesideIcon + ) + toggle.setText("Collapse") + toggle.setToolTip( + "Collapse or expand this section while keeping its " + "readiness status visible." 
+ ) + toggle.toggled.connect( + lambda checked, key=section_key: self._toggle_section_content( + key, + checked, + ) + ) + self._section_toggle_buttons[section_key] = toggle + row.addWidget(toggle) row.addStretch(1) for key, label, tooltip in definitions: checkbox = QCheckBox(label) @@ -1928,6 +2487,38 @@ def _add_group_readiness_row( row.addWidget(checkbox) layout.addLayout(row) + def _create_collapsible_section_layout( + self, + layout: QVBoxLayout, + section_key: str, + ) -> QVBoxLayout: + content_widget = QWidget() + content_layout = QVBoxLayout(content_widget) + content_layout.setContentsMargins(0, 0, 0, 0) + content_layout.setSpacing(layout.spacing()) + layout.addWidget(content_widget) + self._section_content_widgets[section_key] = content_widget + self._toggle_section_content(section_key, True) + return content_layout + + def _toggle_section_content( + self, + section_key: str, + expanded: bool, + ) -> None: + content_widget = self._section_content_widgets.get(section_key) + if content_widget is not None: + content_widget.setVisible(expanded) + toggle = self._section_toggle_buttons.get(section_key) + if toggle is not None: + toggle.blockSignals(True) + toggle.setChecked(expanded) + toggle.setArrowType( + Qt.ArrowType.DownArrow if expanded else Qt.ArrowType.RightArrow + ) + toggle.setText("Collapse" if expanded else "Expand") + toggle.blockSignals(False) + def _readiness_states(self) -> dict[str, bool]: state = self._project_source_state has_project = state is not None @@ -1943,7 +2534,8 @@ def _readiness_states(self) -> dict[str, bool]: has_project and state.representative_selection is not None ), "solvent_outputs": ( - has_project and state.solvent_handling is not None + has_project + and self._active_representative_structure_set_is_ready() ), "packmol_plan": ( has_project and state.packmol_planning is not None @@ -2173,23 +2765,16 @@ def _configure_tooltips(self) -> None: ) self._set_widget_tooltip( self.compute_representatives_button, - "Compute and save 
the representative clusters for the current " - "DREAM model selection.", + "Open the dedicated Representative Structures tool for this " + "project.", ) self._set_widget_tooltip( self.representative_progress_bar, - "Progress for the background representative-cluster selection job.", + "Progress for the legacy in-panel representative-selection job.", ) self._set_widget_tooltip( self.preview_representatives_button, - "Open a preview window for the saved representative cluster " - "analysis.", - ) - self._set_widget_tooltip( - self.coordinated_solvent_mode_combo, - "Choose whether coordinated solvent is removed, partially " - "retained, or fully retained around each representative " - "cluster.", + "Reload the saved representative structures from this project.", ) self._set_widget_tooltip( self.solvent_reference_source_combo, @@ -2203,33 +2788,60 @@ def _configure_tooltips(self) -> None: self._set_widget_tooltip( self.solvent_reference_edit, "Path to the custom solvent-reference PDB used for coordinated " - "solvent handling.", + "solvent-shell building.", ) self._set_widget_tooltip( self.browse_solvent_reference_button, "Browse for a custom solvent-reference PDB file.", ) + self._set_widget_tooltip( + self.solvent_reference_match_tolerance_spin, + "Tolerance used while matching the selected solvent reference " + "against each representative structure.", + ) + self._set_widget_tooltip( + self.solvent_director_atom_combo, + "Reference atom that should point toward the solute cluster " + "during solvent-shell building.", + ) self._set_widget_tooltip( self.solvent_minimum_separation_spin, "Minimum allowed distance between placed solvent atoms and the " - "surrounding coordination sphere during partial-solvent " - "completion. 
The anchor atom fixed to the shell site is " - "excluded from this clash check.", + "surrounding coordination sphere during solvent-shell " + "placement and refinement.", + ) + self._set_widget_tooltip( + self.analyze_solvent_outputs_button, + "Analyze every saved representative structure to determine the " + "current coordinated-solvent state of the representative set.", + ) + self._set_widget_tooltip( + self.solvent_cutoff_table, + "Choose which solute elements are coordination centers, set " + "their target average coordination numbers, and define the " + "director-atom cutoff distance used for solvent placement.", ) self._set_widget_tooltip( self.build_solvent_outputs_button, - "Build the representative cluster PDB outputs with the chosen " - "solvent-handling mode.", + "Build the stripped and solvent-decorated representative PDB " + "outputs using the current automatic solvent analysis and " + "coordination settings.", + ) + self._set_widget_tooltip( + self.generated_pdb_mode_combo, + "Choose which saved representative structure set is active for " + "the table, viewer, and Packmol setup: source, no solvent, " + "partial solvent, or full solvent when available.", ) self._set_widget_tooltip( self.generated_pdb_table, - "Review the generated no-solvent and completed representative " - "PDB files, then double-click a row to open its structure preview.", + "Review the currently active representative structures and " + "select one row to inspect it in the embedded viewer.", ) self._set_widget_tooltip( self.open_generated_pdb_preview_button, - "Open a 3D preview window for the currently selected generated " - "representative PDB file.", + "Open a 3D preview window for the currently selected saved " + "no-solvent or full-solvent representative PDB file.", ) self._set_widget_tooltip( self.packmol_planning_mode_combo, @@ -2240,6 +2852,11 @@ def _configure_tooltips(self) -> None: self.packmol_box_side_spin, "Side length of the cubic Packmol box used for count 
planning.", ) + self._set_widget_tooltip( + self.packmol_free_solvent_combo, + "Choose the solvent structure file used for the free bulk " + "solvent population in the Packmol box.", + ) self._set_widget_tooltip( self.compute_packmol_plan_button, "Compute cluster counts and target weights for the current " @@ -2248,7 +2865,12 @@ def _configure_tooltips(self) -> None: self._set_widget_tooltip( self.build_packmol_setup_button, "Build Packmol input files and audit outputs from the saved " - "plan.", + "plan using the active Full solvent representative structure set.", + ) + self._set_widget_tooltip( + self.packmol_tolerance_spin, + "Tolerance written into the Packmol input file for the setup " + "build.", ) self._set_widget_tooltip( self.constraint_length_tolerance_spin, @@ -2265,6 +2887,16 @@ def _configure_tooltips(self) -> None: "Generate per-structure and merged fullrmc constraints from the " "selected representative structures.", ) + self._set_widget_tooltip( + self.open_constraints_folder_button, + "Open the folder that contains the generated merged fullrmc " + "constraints file and per-structure constraint files.", + ) + self._set_widget_tooltip( + self.preview_constraints_button, + "Open a copy-friendly window that shows the merged fullrmc " + "constraints file contents.", + ) def _on_posterior_filter_changed(self) -> None: self._update_posterior_filter_widgets() @@ -2305,16 +2937,31 @@ def _populate_representative_controls(self) -> None: self.representative_group.setEnabled(state is not None) if state is None: self._apply_representative_metadata(None) + self.compute_representatives_button.setEnabled(False) self.preview_representatives_button.setEnabled(False) + self.representative_status_label.setText( + "Representative structures: no SAXS project loaded." + ) self.representative_summary_box.setPlainText( "Load a SAXS project and choose a DREAM source before " - "selecting representative cluster files." + "loading saved representative structures." 
) return self._apply_representative_metadata(state.representative_selection) - self.preview_representatives_button.setEnabled( - state.representative_selection is not None - ) + self.compute_representatives_button.setEnabled(True) + self.preview_representatives_button.setEnabled(True) + if state.representative_selection is None: + self.representative_status_label.setText( + "Representative structures: no saved project set loaded." + ) + else: + selection_mode = ( + state.representative_selection.selection_mode or "unknown" + ) + self.representative_status_label.setText( + "Representative structures: saved project set loaded " + f"({selection_mode})." + ) self.representative_summary_box.setPlainText( self._representative_summary_text( state.representative_selection, @@ -2326,19 +2973,18 @@ def _populate_solvent_controls(self) -> None: self.solvent_group.setEnabled(state is not None) if state is None: self._apply_solvent_metadata(None) + self.analyze_solvent_outputs_button.setEnabled(False) self.build_solvent_outputs_button.setEnabled(False) self.solvent_summary_box.setPlainText( - "Load a SAXS project and compute representative clusters before " - "building solvent-aware representative PDB outputs." + "Load a SAXS project and save representative structures before " + "running the Solvent Shell Builder." 
) return self._apply_solvent_metadata(state.solvent_handling) - self.build_solvent_outputs_button.setEnabled( - state.representative_selection is not None - ) self.solvent_summary_box.setPlainText( self._solvent_summary_text(state.solvent_handling) ) + self._update_solvent_build_panel_state() def _populate_packmol_planning_controls(self) -> None: state = self._project_source_state @@ -2347,14 +2993,20 @@ def _populate_packmol_planning_controls(self) -> None: self._apply_packmol_planning_metadata(None) self.compute_packmol_plan_button.setEnabled(False) self.build_packmol_setup_button.setEnabled(False) + self.packmol_free_solvent_combo.setEnabled(False) self.packmol_plan_summary_box.setPlainText( "Load a SAXS project, calculate solution properties, and " - "compute representative clusters before planning Packmol " + "save representative structures before planning Packmol " "cluster counts." ) self.packmol_build_summary_box.setPlainText( - "Build Packmol setup after computing cluster counts and " - "solvent-aware representative PDB outputs." + "Build Packmol setup after computing cluster counts, " + "choosing a free-solvent structure, and preparing the " + "active full-solvent representative files." 
+ ) + self.open_packmol_setup_folder_button.setEnabled(False) + self.packmol_docker_summary_box.setPlainText( + self._packmol_docker_summary_text() ) return self._apply_packmol_planning_metadata(state.packmol_planning) @@ -2364,12 +3016,18 @@ def _populate_packmol_planning_controls(self) -> None: ) self.build_packmol_setup_button.setEnabled( state.packmol_planning is not None - and state.solvent_handling is not None + and self._active_representative_structure_set_is_ready() + ) + self.packmol_free_solvent_combo.setEnabled( + self.packmol_free_solvent_combo.count() > 0 ) self.packmol_plan_summary_box.setPlainText( self._packmol_plan_summary_text(state.packmol_planning) ) self._apply_packmol_setup_metadata(state.packmol_setup) + self.packmol_docker_summary_box.setPlainText( + self._packmol_docker_summary_text(state.packmol_setup) + ) def _populate_constraint_controls(self) -> None: state = self._project_source_state @@ -2377,6 +3035,8 @@ def _populate_constraint_controls(self) -> None: if state is None: self._apply_constraint_metadata(None) self.generate_constraints_button.setEnabled(False) + self.open_constraints_folder_button.setEnabled(False) + self.preview_constraints_button.setEnabled(False) self.constraints_summary_box.setPlainText( "Build Packmol setup inputs before generating per-structure " "constraint files and the merged fullrmc constraints file." 
@@ -2386,6 +3046,9 @@ def _populate_constraint_controls(self) -> None: self.generate_constraints_button.setEnabled( state.packmol_setup is not None ) + has_constraints = state.constraint_generation is not None + self.open_constraints_folder_button.setEnabled(has_constraints) + self.preview_constraints_button.setEnabled(has_constraints) def _selected_representative_mode(self) -> str: return str( @@ -2422,6 +3085,11 @@ def _apply_representative_metadata( self.representative_bond_weight_spin.setValue(settings.bond_weight) self.representative_angle_weight_spin.setValue(settings.angle_weight) self._update_representative_mode_widgets() + self._refresh_generated_pdb_mode_combo() + state = self._project_source_state + self._refresh_generated_pdb_browser( + state.solvent_handling if state is not None else None + ) def _selected_representative_preset_name(self) -> str | None: payload = self.representative_preset_combo.currentData() @@ -2764,80 +3432,403 @@ def _table_text(table: QTableWidget, row: int, column: int) -> str: item = table.item(row, column) return item.text().strip() if item is not None else "" - def _apply_solvent_metadata( + def _available_generated_pdb_mode_items( self, - metadata: SolventHandlingMetadata | None, - ) -> None: - settings = ( - metadata.settings - if metadata is not None - else SolventHandlingSettings() + ) -> list[tuple[str, str]]: + state = self._project_source_state + if state is None or state.representative_selection is None: + return [] + return [ + (representative_structure_mode_label(mode), mode) + for mode in available_representative_structure_modes( + state.representative_selection, + state.solvent_handling, + ) + ] + + def _active_generated_pdb_mode(self) -> str: + state = self._project_source_state + if state is None or state.representative_selection is None: + return "source" + preferred = self.generated_pdb_mode_combo.currentData() + preferred_mode = ( + str(preferred).strip() if preferred is not None else None ) - 
self._set_combo_value( - self.coordinated_solvent_mode_combo, - settings.coordinated_solvent_mode, + return resolved_representative_structure_mode( + state.representative_selection, + state.solvent_handling, + preferred_mode=preferred_mode, ) - self._set_combo_value( - self.solvent_reference_source_combo, - settings.reference_source, + + def _active_representative_structure_set_is_ready(self) -> bool: + state = self._project_source_state + return bool( + state is not None + and state.representative_selection is not None + and self._active_generated_pdb_mode() == "full_solvent" ) - self._set_combo_value(self.solvent_preset_combo, settings.preset_name) - self.solvent_reference_edit.setText( - settings.custom_reference_path or "" + + def _solvent_shell_builder_required(self) -> bool: + state = self._project_source_state + return bool( + state is not None + and state.representative_selection is not None + and self._active_generated_pdb_mode() != "full_solvent" + ) + + def _set_solvent_shell_builder_controls_enabled( + self, + enabled: bool, + ) -> None: + for widget in ( + self.solvent_reference_source_combo, + self.solvent_preset_combo, + self.solvent_reference_edit, + self.browse_solvent_reference_button, + self.solvent_reference_match_tolerance_spin, + self.solvent_director_atom_combo, + self.solvent_minimum_separation_spin, + self.solvent_reference_details_box, + self.solvent_status_group, + self.solvent_cutoff_group, + ): + widget.setEnabled(enabled) + + def _solvent_build_has_required_coordination_settings( + self, + analysis: RepresentativeSolventDistributionAnalysis, + ) -> bool: + if analysis.distribution_status in { + "complete_solvent", + "partial_solvent", + "mixed_complete_and_partial", + }: + return True + selected_settings = self._selected_solvent_coordination_settings() + return any( + setting.coordination_center + and float(setting.target_coordination_number) > 0.0 + and float(setting.director_distance_cutoff_a) > 0.0 + for setting in 
selected_settings.values() + ) + + def _representative_entry_selected_state_text( + self, + representative_entry: object, + active_mode: str, + ) -> str: + source_variant = representative_source_solvent_mode_to_variant( + getattr(representative_entry, "source_solvent_mode", None) + ) + if active_mode == "full_solvent": + return "Full solvent analyzed" + if source_variant == active_mode: + return f"{representative_structure_mode_label(active_mode)} source" + if active_mode in { + "no_solvent", + "partial_solvent", + "full_solvent", + }: + return ( + f"{representative_structure_mode_label(active_mode)} selected" + ) + return "Not analyzed" + + def _refresh_generated_pdb_mode_combo(self) -> None: + state = self._project_source_state + items = self._available_generated_pdb_mode_items() + preferred_mode: str | None = None + if state is not None and state.solvent_handling is not None: + preferred_mode = str( + state.solvent_handling.settings.coordinated_solvent_mode + ).strip() + current = self.generated_pdb_mode_combo.currentData() + if current is not None and str(current).strip(): + preferred_mode = str(current).strip() + self._updating_generated_pdb_mode_combo = True + try: + self.generated_pdb_mode_combo.clear() + for label, mode in items: + self.generated_pdb_mode_combo.addItem(label, mode) + self.generated_pdb_mode_combo.setEnabled(len(items) > 1) + if not items: + return + resolved_mode = resolved_representative_structure_mode( + state.representative_selection if state is not None else None, + state.solvent_handling if state is not None else None, + preferred_mode=preferred_mode, + ) + self._set_combo_value(self.generated_pdb_mode_combo, resolved_mode) + finally: + self._updating_generated_pdb_mode_combo = False + + def _handle_generated_pdb_mode_changed(self) -> None: + if self._updating_generated_pdb_mode_combo: + return + state = self._project_source_state + if state is None or state.representative_selection is None: + return + resolved_mode = 
self._active_generated_pdb_mode() + if state.solvent_handling is not None: + state.solvent_handling.settings.coordinated_solvent_mode = ( + resolved_mode + ) + save_solvent_handling_metadata( + state.rmcsetup_paths.solvent_handling_path, + state.solvent_handling, + ) + self.solvent_summary_box.setPlainText( + self._solvent_summary_text(state.solvent_handling) + ) + self._refresh_generated_pdb_browser( + state.solvent_handling if state is not None else None + ) + self.solvent_summary_box.setPlainText( + self._solvent_summary_text(state.solvent_handling) + ) + self._update_solvent_status_panel(state.solvent_handling) + self._update_solvent_build_panel_state() + self._populate_packmol_planning_controls() + self._update_readiness_progress() + + def _apply_solvent_metadata( + self, + metadata: SolventHandlingMetadata | None, + ) -> None: + settings = ( + metadata.settings + if metadata is not None + else SolventHandlingSettings() + ) + self._set_combo_value( + self.solvent_reference_source_combo, + settings.reference_source, + ) + self._set_combo_value(self.solvent_preset_combo, settings.preset_name) + self.solvent_reference_edit.setText( + settings.custom_reference_path or "" + ) + self.solvent_reference_match_tolerance_spin.setValue( + settings.reference_match_tolerance_a ) self.solvent_minimum_separation_spin.setValue( settings.minimum_solvent_atom_separation_a ) + self._solvent_distribution_analysis = None self._update_solvent_reference_widgets() + self._populate_solvent_director_atom_choices( + selected_name=settings.director_atom_name + ) + self._update_solvent_reference_details() + counts = ( + metadata.aggregate_solute_element_counts + if metadata is not None + else {} + ) + self._populate_solvent_cutoff_table( + counts, + settings.solute_atom_settings, + ) + self._refresh_generated_pdb_mode_combo() self._refresh_generated_pdb_browser(metadata) + self._update_solvent_status_panel(metadata) + self._update_solvent_build_panel_state() def 
_refresh_generated_pdb_browser( self, metadata: SolventHandlingMetadata | None, ) -> None: + state = self._project_source_state self._generated_pdb_inspections = [] + self._solvent_table_preview_paths = [] + self._solvent_table_details = [] self.generated_pdb_table.setRowCount(0) - if metadata is None: + self.generated_pdb_viewer.draw_placeholder() + self.generated_pdb_group.setEnabled( + state is not None and state.representative_selection is not None + ) + if state is None or state.representative_selection is None: self.generated_pdb_details_box.setPlainText( - "Build representative PDB outputs to browse the generated " - "no-solvent and completed PDB files." + "Open Representative Structures, save the project set, and " + "reload it here to browse the active representative " + "structures." ) - self.open_generated_pdb_preview_button.setEnabled(False) - return - try: - inspections = build_generated_pdb_inspections(metadata) - except Exception as exc: - self.generated_pdb_details_box.setPlainText( - f"Unable to inspect generated PDB files: {exc}" + self.generated_pdb_viewer_status_label.setText( + "Select a representative row to preview the active structure." 
) self.open_generated_pdb_preview_button.setEnabled(False) return - self._generated_pdb_inspections = inspections - self.generated_pdb_table.setRowCount(len(inspections)) - for row, inspection in enumerate(inspections): + active_mode = self._active_generated_pdb_mode() + active_mode_label = representative_structure_mode_label(active_mode) + inspection_lookup: dict[ + tuple[str, str, str, str], GeneratedPDBInspection + ] = {} + inspection_error: str | None = None + if metadata is not None: + try: + inspections = build_generated_pdb_inspections(metadata) + except Exception as exc: + inspection_error = str(exc) + else: + inspection_lookup = { + ( + inspection.structure, + inspection.motif, + inspection.param, + inspection.file_role, + ): inspection + for inspection in inspections + } + + solvent_lookup = solvent_entry_lookup_for_representatives( + state.representative_selection, + metadata, + ) + self.generated_pdb_table.setRowCount( + len(state.representative_selection.representative_entries) + ) + for row, representative_entry in enumerate( + state.representative_selection.representative_entries + ): + key = ( + representative_entry.structure, + representative_entry.motif, + representative_entry.param, + ) + solvent_entry = solvent_lookup.get(key) + file_role: str | None = None + if active_mode == "full_solvent": + file_role = "completed" + elif active_mode == "no_solvent": + file_role = "no_solvent" + inspection = ( + inspection_lookup.get((*key, file_role)) + if file_role is not None + else None + ) + preview_path = representative_structure_path_for_mode( + representative_entry, + solvent_entry, + active_mode, + ) + source_text = representative_entry.analysis_source or "n/a" + atom_count = representative_entry.atom_count + if solvent_entry is not None: + if active_mode == "full_solvent": + atom_count = solvent_entry.atom_count_completed + source_text = ( + solvent_entry.completion_strategy + or "saved Solvent Shell Builder output" + ) + elif active_mode == 
"no_solvent": + atom_count = solvent_entry.atom_count_no_solvent + source_text = "saved no-solvent export" + elif active_mode == "partial_solvent": + source_text = "representative selection source" + if inspection is not None and inspection.exists: + atom_count = inspection.atom_count + + detected_state = ( + solvent_entry.detected_source_status_text + if solvent_entry is not None + else self._representative_entry_selected_state_text( + representative_entry, + active_mode, + ) + ) values = [ - inspection.representative_label, - inspection.variant_label, - inspection.file_name, - str(inspection.atom_count) if inspection.exists else "Missing", - inspection.element_counts_text, - str(inspection.solvent_molecule_count), - inspection.molecule_residue_text, + ( + representative_entry.structure + if representative_entry.motif == "no_motif" + else ( + f"{representative_entry.structure}/" + f"{representative_entry.motif}" + ) + ), + detected_state, + active_mode_label, + preview_path.name, + str(atom_count), + source_text, + ] + details_lines = [ + f"Representative: {values[0]}", + f"Detected source solvent state: {detected_state}", + f"Active structure set: {active_mode_label}", + f"Structure file: {preview_path}", + ( + "Selected weight: " + f"{representative_entry.selected_weight:.6g}" + ), + f"Cluster count: {representative_entry.cluster_count}", + ( + "Representative atom count: " + f"{representative_entry.atom_count}" + ), + ( + "Representative selection source: " + f"{representative_entry.analysis_source}" + ), ] + if representative_entry.element_counts: + details_lines.append( + "Representative elements: " + + ", ".join( + f"{element}:{count}" + for element, count in sorted( + representative_entry.element_counts.items() + ) + ) + ) + if solvent_entry is not None: + details_lines.extend( + [ + f"No-solvent PDB: {solvent_entry.no_solvent_pdb}", + f"Full-solvent PDB: {solvent_entry.completed_pdb}", + ( + "Saved solvent-handling strategy: " + 
f"{solvent_entry.completion_strategy or 'n/a'}" + ), + ] + ) + detail_sections = ["\n".join(details_lines)] + if inspection is not None: + detail_sections.append(inspection.details_text()) + if solvent_entry is not None and solvent_entry.analysis_summary: + detail_sections.append(solvent_entry.analysis_summary) + if solvent_entry is not None and solvent_entry.build_summary: + detail_sections.append(solvent_entry.build_summary) + if inspection_error: + detail_sections.append( + "Generated structure inspection warning:\n" + + inspection_error + ) + self._generated_pdb_inspections.append(inspection) + self._solvent_table_preview_paths.append( + preview_path if preview_path.is_file() else None + ) + self._solvent_table_details.append( + "\n\n".join(section for section in detail_sections if section) + ) for column, value in enumerate(values): self.generated_pdb_table.setItem( row, column, QTableWidgetItem(value), ) - if inspections: + + if state.representative_selection.representative_entries: self.generated_pdb_table.selectRow(0) - else: - self.generated_pdb_details_box.setPlainText( - "No generated representative PDB files are available." - ) - self.open_generated_pdb_preview_button.setEnabled(False) + return + + self.generated_pdb_details_box.setPlainText( + "No representative structures are available for browsing." + ) + self.generated_pdb_viewer_status_label.setText( + "No representative structures are available for preview." 
+ ) + self.open_generated_pdb_preview_button.setEnabled(False) def _selected_generated_pdb_inspection( self, @@ -2848,16 +3839,27 @@ def _selected_generated_pdb_inspection( return self._generated_pdb_inspections[row] def _on_generated_pdb_selection_changed(self) -> None: - inspection = self._selected_generated_pdb_inspection() - if inspection is None: + row = self.generated_pdb_table.currentRow() + if row < 0 or row >= len(self._solvent_table_details): self.generated_pdb_details_box.setPlainText( - "Select a generated PDB row to review its residue details." + "Select a representative row to review the active structure details." + ) + self.generated_pdb_viewer.draw_placeholder() + self.generated_pdb_viewer_status_label.setText( + "Select a representative row to preview the active structure." ) self.open_generated_pdb_preview_button.setEnabled(False) return - self.generated_pdb_details_box.setPlainText(inspection.details_text()) + self.generated_pdb_details_box.setPlainText( + self._solvent_table_details[row] + ) + preview_path = self._solvent_table_preview_paths[row] + self._refresh_generated_pdb_viewer(preview_path) + inspection = self._selected_generated_pdb_inspection() self.open_generated_pdb_preview_button.setEnabled( - inspection.exists and inspection.load_error is None + inspection is not None + and inspection.exists + and inspection.load_error is None ) def _open_selected_generated_pdb_preview(self) -> None: @@ -2865,8 +3867,9 @@ def _open_selected_generated_pdb_preview(self) -> None: if inspection is None: QMessageBox.information( self, - "No generated PDB selected", - "Select a generated representative PDB before opening the preview window.", + "No previewable PDB selected", + "Switch to a saved no-solvent or full-solvent representative " + "structure set before opening the separate PDB preview window.", ) return if not inspection.exists or inspection.load_error is not None: @@ -2888,10 +3891,273 @@ def _open_selected_generated_pdb_preview(self) -> None: 
f"Opened generated PDB preview: {inspection.file_name}" ) + def _refresh_generated_pdb_viewer( + self, + preview_path: Path | None, + ) -> None: + if preview_path is None or not preview_path.is_file(): + self.generated_pdb_viewer.draw_placeholder() + self.generated_pdb_viewer_status_label.setText( + "No previewable representative structure is available for the selected row." + ) + return + try: + structure = load_electron_density_structure( + preview_path, + center_mode="center_of_mass", + include_bonds=True, + include_comment=True, + ) + except Exception as exc: + self.generated_pdb_viewer.draw_placeholder() + self.generated_pdb_viewer_status_label.setText( + f"Unable to preview {preview_path.name}: {exc}" + ) + return + self.generated_pdb_viewer.set_structure( + structure, + scene_key=f"rmcsetup-solvent:{preview_path}", + ) + self.generated_pdb_viewer_status_label.setText( + f"Previewing {preview_path.name} with {structure.atom_count} atom(s)." + ) + + def _update_solvent_status_panel( + self, + metadata: SolventHandlingMetadata | None, + ) -> None: + analysis = self._solvent_distribution_analysis + state = self._project_source_state + status_lookup = { + "complete_solvent": "Complete solvent molecules detected", + "partial_solvent": "Partial solvent molecules detected", + "no_solvent": "No solvent molecules detected", + "mixed_complete_and_partial": ( + "Complete and partial solvent molecules detected" + ), + "unknown": "Unknown solvent state", + } + if ( + state is not None + and self._active_representative_structure_set_is_ready() + and (metadata is None or not metadata.entries) + ): + active_mode = self._active_generated_pdb_mode() + self.solvent_status_headline_label.setText( + "Full-solvent representative structures are selected." 
+ ) + self.solvent_status_stats_label.setText( + "\n".join( + [ + "The active representative source files already " + "provide the Full solvent structure set.", + ( + "Active representative structure set: " + + representative_structure_mode_label(active_mode) + ), + "Solvent Shell Builder readiness: Ready for Packmol", + ( + "Solvent state analysis and solvent-shell " + "building are not required for this selection." + ), + ] + ) + ) + return + if metadata is not None: + active_mode = self._active_generated_pdb_mode() + self.solvent_status_headline_label.setText( + status_lookup.get( + metadata.detected_distribution_status, + metadata.detected_distribution_status.replace("_", " "), + ) + ) + status_lines = [ + metadata.detected_distribution_note + or "Representative solvent outputs have been generated.", + ( + "Recognized solute elements: " + + ", ".join( + f"{element}:{count}" + for element, count in sorted( + metadata.aggregate_solute_element_counts.items() + ) + ) + if metadata.aggregate_solute_element_counts + else "Recognized solute elements: none" + ), + ( + "Active representative structure set: " + + representative_structure_mode_label(active_mode) + ), + ( + "Solvent Shell Builder readiness: Ready for Packmol" + if active_mode == "full_solvent" + else "Solvent Shell Builder readiness: Select Full solvent " + "to mark this step ready for Packmol" + ), + ] + self.solvent_status_stats_label.setText("\n".join(status_lines)) + return + if analysis is None: + self.solvent_status_headline_label.setText( + "No representative solvent state has been determined yet." + ) + if self._solvent_shell_builder_required(): + self.solvent_status_stats_label.setText( + "Choose a solvent reference, set coordination options as " + "needed, and press Build Solvent-Decorated Representative " + "PDBs. The required solvent-state analysis will run first." 
+ ) + else: + self.solvent_status_stats_label.setText( + "Save representative structures before running the " + "Solvent Shell Builder." + ) + return + self.solvent_status_headline_label.setText( + status_lookup.get( + analysis.distribution_status, + analysis.distribution_status.replace("_", " "), + ) + ) + status_lines = [ + f"Representative entries analyzed: {len(analysis.entries)}", + ( + "Saved full-solvent representatives are not available yet. " + "Build solvent-decorated representative PDBs to store the " + "Full solvent representative structure set." + ), + ] + if analysis.distribution_note: + status_lines.append(analysis.distribution_note) + status_lines.append( + ( + "Recognized solute elements: " + + ", ".join( + f"{element}:{count}" + for element, count in sorted( + analysis.aggregate_solute_element_counts.items() + ) + ) + ) + if analysis.aggregate_solute_element_counts + else "Recognized solute elements: none" + ) + self.solvent_status_stats_label.setText("\n".join(status_lines)) + + def _populate_solvent_cutoff_table( + self, + element_counts: dict[str, int], + element_settings: dict[str, SoluteAtomBuildSetting] | None = None, + ) -> None: + self._updating_solvent_table = True + try: + self._solvent_cutoff_spins = {} + self._solvent_coordination_center_items = {} + self._solvent_coordination_target_spins = {} + self.solvent_cutoff_table.setRowCount(0) + settings_lookup = element_settings or {} + if not element_counts: + self.solvent_cutoff_status_label.setText( + "Analyze or build the representative structures to " + "populate the solute atom types used for solvent-shell " + "building." + ) + return + self.solvent_cutoff_status_label.setText( + "Choose which solute elements should coordinate the solvent, " + "set their average coordination targets, and review the " + "director-atom cutoffs that will be used across the " + "representative set." 
+ ) + sorted_counts = sorted(element_counts.items()) + self.solvent_cutoff_table.setRowCount(len(sorted_counts)) + for row, (element, count) in enumerate(sorted_counts): + setting = settings_lookup.get( + str(element), SoluteAtomBuildSetting() + ) + self.solvent_cutoff_table.setItem( + row, 0, QTableWidgetItem(str(element)) + ) + self.solvent_cutoff_table.setItem( + row, 1, QTableWidgetItem(str(count)) + ) + center_item = QTableWidgetItem("") + center_item.setFlags( + Qt.ItemFlag.ItemIsEnabled + | Qt.ItemFlag.ItemIsSelectable + | Qt.ItemFlag.ItemIsUserCheckable + ) + center_item.setCheckState( + Qt.CheckState.Checked + if setting.coordination_center + else Qt.CheckState.Unchecked + ) + self.solvent_cutoff_table.setItem(row, 2, center_item) + self._solvent_coordination_center_items[str(element)] = ( + center_item + ) + + coordination_spin = QDoubleSpinBox(self.solvent_cutoff_table) + coordination_spin.setDecimals(2) + coordination_spin.setRange(0.0, 12.0) + coordination_spin.setSingleStep(0.25) + coordination_spin.setValue( + float(setting.target_coordination_number) + ) + coordination_spin.valueChanged.connect( + self._update_solvent_build_panel_state + ) + self.solvent_cutoff_table.setCellWidget( + row, 3, coordination_spin + ) + self._solvent_coordination_target_spins[str(element)] = ( + coordination_spin + ) + + cutoff_spin = QDoubleSpinBox(self.solvent_cutoff_table) + cutoff_spin.setDecimals(3) + cutoff_spin.setRange(0.0, 20.0) + cutoff_spin.setSingleStep(0.1) + cutoff_spin.setSuffix(" A") + cutoff_spin.setValue(float(setting.director_distance_cutoff_a)) + cutoff_spin.valueChanged.connect( + self._update_solvent_build_panel_state + ) + self.solvent_cutoff_table.setCellWidget(row, 4, cutoff_spin) + self._solvent_cutoff_spins[str(element)] = cutoff_spin + finally: + self._updating_solvent_table = False + + def _selected_solvent_coordination_settings( + self, + ) -> dict[str, SoluteAtomBuildSetting]: + settings: dict[str, SoluteAtomBuildSetting] = {} + for 
element, cutoff_spin in self._solvent_cutoff_spins.items(): + center_item = self._solvent_coordination_center_items.get(element) + coordination_spin = self._solvent_coordination_target_spins.get( + element + ) + settings[str(element)] = SoluteAtomBuildSetting( + coordination_center=( + center_item is not None + and center_item.checkState() == Qt.CheckState.Checked + ), + target_coordination_number=( + float(coordination_spin.value()) + if coordination_spin is not None + else 0.0 + ), + director_distance_cutoff_a=float(cutoff_spin.value()), + ) + return settings + def _apply_packmol_planning_metadata( self, metadata: PackmolPlanningMetadata | None, ) -> None: + state = self._project_source_state settings = ( metadata.settings if metadata is not None @@ -2902,15 +4168,46 @@ def _apply_packmol_planning_metadata( settings.planning_mode, ) self.packmol_box_side_spin.setValue(settings.box_side_length_a) + selected_reference = settings.free_solvent_reference + if ( + selected_reference is None + and metadata is not None + and metadata.solvent_allocation is not None + ): + selected_reference = metadata.solvent_allocation.reference_path + if ( + selected_reference is None + and state is not None + and state.packmol_setup is not None + ): + selected_reference = ( + state.packmol_setup.free_solvent_reference_path + ) + if ( + selected_reference is None + and state is not None + and state.solvent_handling is not None + ): + selected_reference = state.solvent_handling.reference_path + self._populate_packmol_free_solvent_choices( + selected_identifier=selected_reference + ) self._refresh_packmol_plan_plot(metadata) def _apply_packmol_setup_metadata( self, metadata: PackmolSetupMetadata | None, ) -> None: + settings = ( + metadata.settings + if metadata is not None + else PackmolSetupSettings() + ) + self.packmol_tolerance_spin.setValue(settings.tolerance_angstrom) self.packmol_build_summary_box.setPlainText( self._packmol_setup_summary_text(metadata) ) + 
self.open_packmol_setup_folder_button.setEnabled(metadata is not None) def _apply_constraint_metadata( self, @@ -2930,6 +4227,8 @@ def _apply_constraint_metadata( self.constraints_summary_box.setPlainText( self._constraint_summary_text(metadata) ) + self.open_constraints_folder_button.setEnabled(metadata is not None) + self.preview_constraints_button.setEnabled(metadata is not None) def _on_representative_mode_changed(self) -> None: self._update_representative_mode_widgets() @@ -2959,6 +4258,125 @@ def _update_solvent_reference_widgets(self) -> None: self.solvent_preset_combo.setEnabled(not use_custom) self.solvent_reference_edit.setEnabled(use_custom) self.browse_solvent_reference_button.setEnabled(use_custom) + self._populate_solvent_director_atom_choices() + self._update_solvent_reference_details() + + def _handle_solvent_reference_source_changed(self) -> None: + self._update_solvent_reference_widgets() + self._clear_solvent_analysis_outputs() + + def _handle_solvent_reference_changed(self) -> None: + self._populate_solvent_director_atom_choices() + self._update_solvent_reference_details() + self._clear_solvent_analysis_outputs() + + def _handle_solvent_analysis_setting_changed(self) -> None: + self._clear_solvent_analysis_outputs() + + def _handle_solvent_coordination_table_item_changed( + self, + _item: QTableWidgetItem, + ) -> None: + if self._updating_solvent_table: + return + self._update_solvent_build_panel_state() + + def _selected_solvent_reference_identifier(self) -> str | None: + source = str( + self.solvent_reference_source_combo.currentData() or "preset" + ) + if source == "custom": + reference_path = self.solvent_reference_edit.text().strip() + return reference_path or None + selected_name = self.solvent_preset_combo.currentData() + if selected_name is None: + return None + return str(selected_name) + + def _populate_solvent_director_atom_choices( + self, + *, + selected_name: str | None = None, + ) -> None: + 
self.solvent_director_atom_combo.blockSignals(True) + self.solvent_director_atom_combo.clear() + reference_identifier = self._selected_solvent_reference_identifier() + if reference_identifier is None: + self.solvent_director_atom_combo.blockSignals(False) + return + try: + atom_names = reference_atom_choices(reference_identifier) + suggested_name = selected_name or default_director_atom_name( + reference_identifier + ) + except Exception: + self.solvent_director_atom_combo.blockSignals(False) + return + for atom_name in atom_names: + self.solvent_director_atom_combo.addItem(atom_name, atom_name) + if suggested_name is not None: + suggested_index = self.solvent_director_atom_combo.findData( + suggested_name + ) + if suggested_index >= 0: + self.solvent_director_atom_combo.setCurrentIndex( + suggested_index + ) + elif atom_names: + self.solvent_director_atom_combo.setCurrentIndex(0) + self.solvent_director_atom_combo.blockSignals(False) + + def _update_solvent_reference_details(self) -> None: + reference_identifier = self._selected_solvent_reference_identifier() + if reference_identifier is None: + self.solvent_reference_details_box.setPlainText( + "Choose a solvent reference before analyzing representative structures." 
+ ) + return + try: + atom_names = reference_atom_choices(reference_identifier) + suggested_director = default_director_atom_name( + reference_identifier + ) + reference_path = Path(reference_identifier).expanduser() + reference_name = reference_path.stem + source_text = ( + str(reference_path.resolve()) + if reference_path.is_file() + else reference_name + ) + except Exception as exc: + self.solvent_reference_details_box.setPlainText( + f"Unable to inspect the selected solvent reference: {exc}" + ) + return + director_text = suggested_director or "n/a" + self.solvent_reference_details_box.setPlainText( + f"Reference molecule: {reference_name}\n" + f"Atom count: {len(atom_names)}\n" + f"Suggested director atom: {director_text}\n" + f"Reference source: {source_text}" + ) + + def _clear_solvent_analysis_outputs(self) -> None: + self._solvent_distribution_analysis = None + state = self._project_source_state + self.solvent_summary_box.setPlainText( + self._solvent_summary_text( + state.solvent_handling if state is not None else None + ) + ) + self._populate_solvent_cutoff_table( + {}, + self._selected_solvent_coordination_settings(), + ) + self._refresh_generated_pdb_browser( + state.solvent_handling if state is not None else None + ) + self._update_solvent_status_panel( + state.solvent_handling if state is not None else None + ) + self._update_solvent_build_panel_state() def _current_representative_settings( self, @@ -3151,7 +4569,7 @@ def _finish_representative_selection( "Computed representative clusters in " f"{metadata.selection_mode} mode." 
) - self.build_solvent_outputs_button.setEnabled(True) + self._populate_solvent_controls() def _fail_representative_selection(self, message: str) -> None: self.compute_representatives_button.setEnabled(True) @@ -3196,7 +4614,7 @@ def _preview_representative_clusters(self) -> None: self, "No representative selection", ( - "Compute representative clusters before opening the " + "Save representative structures before opening the " "preview window." ), ) @@ -3245,13 +4663,19 @@ def _browse_solvent_reference_pdb(self) -> None: if not selected_path: return self.solvent_reference_edit.setText(selected_path) + self._handle_solvent_reference_changed() def _current_solvent_settings(self) -> SolventHandlingSettings: + selected_mode = self.generated_pdb_mode_combo.currentData() + coordinated_solvent_mode = ( + str(selected_mode).strip() + if selected_mode is not None + and str(selected_mode).strip() + in {"no_solvent", "partial_solvent", "full_solvent"} + else "automatic_detection" + ) return SolventHandlingSettings( - coordinated_solvent_mode=str( - self.coordinated_solvent_mode_combo.currentData() - or "no_coordinated_solvent" - ), + coordinated_solvent_mode=coordinated_solvent_mode, reference_source=str( self.solvent_reference_source_combo.currentData() or "preset" ), @@ -3259,21 +4683,228 @@ def _current_solvent_settings(self) -> SolventHandlingSettings: custom_reference_path=( self.solvent_reference_edit.text().strip() or None ), + reference_match_tolerance_a=float( + self.solvent_reference_match_tolerance_spin.value() + ), + director_atom_name=( + str(self.solvent_director_atom_combo.currentData()) + if self.solvent_director_atom_combo.currentData() is not None + else None + ), minimum_solvent_atom_separation_a=float( self.solvent_minimum_separation_spin.value() ), + solute_atom_settings=self._selected_solvent_coordination_settings(), + ) + + def _update_solvent_build_panel_state(self) -> None: + state = self._project_source_state + reference_identifier = 
self._selected_solvent_reference_identifier() + has_reference = reference_identifier is not None + has_representatives = ( + state is not None and state.representative_selection is not None + ) + builder_required = self._solvent_shell_builder_required() + self._set_solvent_shell_builder_controls_enabled(builder_required) + if not builder_required: + self.analyze_solvent_outputs_button.setEnabled(False) + self.build_solvent_outputs_button.setEnabled(False) + return + self.analyze_solvent_outputs_button.setEnabled( + bool(has_reference and has_representatives) + ) + if not has_reference or not has_representatives: + self.build_solvent_outputs_button.setEnabled(False) + return + analysis = self._solvent_distribution_analysis + if analysis is None: + self.build_solvent_outputs_button.setEnabled(True) + return + if self._solvent_build_has_required_coordination_settings(analysis): + self.build_solvent_outputs_button.setEnabled(True) + return + self.build_solvent_outputs_button.setEnabled(False) + + def _analyze_representative_solvent_states(self) -> None: + self._run_representative_solvent_analysis( + progress_message="Analyzing representative solvent states...", + completion_message="Representative solvent analysis ready.", + log_completion="Analyzed representative solvent states.", ) + def _run_representative_solvent_analysis( + self, + *, + progress_message: str, + completion_message: str | None = None, + log_completion: str | None = None, + ) -> RepresentativeSolventDistributionAnalysis | None: + state = self._project_source_state + if state is None: + QMessageBox.information( + self, + "No SAXS project loaded", + "Load a SAXS project before analyzing representative solvent states.", + ) + return None + if state.representative_selection is None: + QMessageBox.information( + self, + "No representative selection", + "Save representative structures before analyzing the representative solvent states.", + ) + return None + if self._selected_solvent_reference_identifier() 
is None: + QMessageBox.information( + self, + "No solvent reference selected", + "Choose a bundled preset or custom solvent reference PDB before analyzing the representative structures.", + ) + return None + try: + self._set_task_progress( + progress_message, + 35, + ) + self._append_run_log("Analyzing representative solvent states.") + analysis = analyze_representative_solvent_distribution( + state, + self._current_solvent_settings(), + representative_metadata=state.representative_selection, + ) + except Exception as exc: + self._solvent_distribution_analysis = None + message = ( + "Unable to analyze representative solvent states: " f"{exc}" + ) + self.solvent_summary_box.setPlainText(message) + self._append_run_log(message) + self._set_task_progress("Solvent analysis failed.", 0) + QMessageBox.warning( + self, + "Representative solvent analysis failed", + str(exc), + ) + self._populate_solvent_cutoff_table( + {}, + self._selected_solvent_coordination_settings(), + ) + self._refresh_generated_pdb_browser(None) + self._update_solvent_status_panel(None) + self._update_solvent_build_panel_state() + return None + self._solvent_distribution_analysis = analysis + self.solvent_summary_box.setPlainText(analysis.summary_text()) + self._populate_solvent_cutoff_table( + analysis.aggregate_solute_element_counts, + self._selected_solvent_coordination_settings(), + ) + self._refresh_generated_pdb_browser(None) + self._update_solvent_status_panel(None) + self._update_solvent_build_panel_state() + if completion_message is not None: + self._set_task_progress(completion_message, 100) + if log_completion is not None: + self._append_run_log(log_completion) + return analysis + def _current_packmol_planning_settings(self) -> PackmolPlanningSettings: return PackmolPlanningSettings( planning_mode=str( self.packmol_planning_mode_combo.currentData() or "per_element" ), box_side_length_a=float(self.packmol_box_side_spin.value()), + 
free_solvent_reference=self._selected_packmol_free_solvent_reference(), ) def _current_packmol_setup_settings(self) -> PackmolSetupSettings: - return PackmolSetupSettings() + return PackmolSetupSettings( + tolerance_angstrom=float(self.packmol_tolerance_spin.value()), + free_solvent_reference=self._selected_packmol_free_solvent_reference(), + ) + + def _selected_packmol_free_solvent_reference(self) -> str | None: + current_data = self.packmol_free_solvent_combo.currentData() + if current_data is None: + return None + text = str(current_data).strip() + return text or None + + def _normalize_packmol_free_solvent_identifier( + self, + identifier: str | None, + ) -> str | None: + if identifier is None: + return None + text = str(identifier).strip() + if not text: + return None + candidate = Path(text).expanduser() + if candidate.is_file(): + return str(candidate.resolve()) + for preset in self._available_solvent_presets: + preset_path = str(Path(preset.path).expanduser().resolve()) + if text == preset.name or text == preset_path: + return preset_path + return None + + def _populate_packmol_free_solvent_choices( + self, + *, + selected_identifier: str | None = None, + ) -> None: + state = self._project_source_state + current_identifier = ( + selected_identifier + or self._selected_packmol_free_solvent_reference() + ) + combo = self.packmol_free_solvent_combo + combo.blockSignals(True) + combo.clear() + seen_paths: set[str] = set() + for preset in self._available_solvent_presets: + preset_path = str(Path(preset.path).expanduser().resolve()) + combo.addItem(preset.name, preset_path) + seen_paths.add(preset_path) + + extra_identifiers = [ + current_identifier, + ( + None + if state is None or state.solvent_handling is None + else state.solvent_handling.reference_path + ), + ( + None + if state is None or state.packmol_planning is None + else state.packmol_planning.settings.free_solvent_reference + ), + ( + None + if state is None or state.packmol_setup is None + else 
state.packmol_setup.free_solvent_reference_path + ), + ] + for identifier in extra_identifiers: + normalized = self._normalize_packmol_free_solvent_identifier( + identifier + ) + if normalized is None or normalized in seen_paths: + continue + combo.addItem(Path(normalized).stem, normalized) + seen_paths.add(normalized) + + target_identifier = self._normalize_packmol_free_solvent_identifier( + current_identifier + ) + if target_identifier is not None: + index = combo.findData(target_identifier) + if index >= 0: + combo.setCurrentIndex(index) + elif combo.count() > 0: + combo.setCurrentIndex(0) + combo.setEnabled(combo.count() > 0 and state is not None) + combo.blockSignals(False) def _current_constraint_settings(self) -> ConstraintGenerationSettings: return ConstraintGenerationSettings( @@ -3285,6 +4916,49 @@ def _current_constraint_settings(self) -> ConstraintGenerationSettings: ), ) + def _sync_packmol_inputs_to_linked_container( + self, + setup_metadata: PackmolSetupMetadata, + ) -> PackmolDockerSyncResult | None: + state = self._project_source_state + if state is None or state.packmol_docker_link is None: + return None + link = state.packmol_docker_link + try: + result = self._create_packmol_docker_client().sync_packmol_inputs( + link, + state.rmcsetup_paths.packmol_inputs_dir, + packmol_setup_metadata=setup_metadata, + ) + except Exception as exc: + link.last_sync_at = datetime.now().isoformat(timespec="seconds") + link.last_sync_status = "error" + link.last_sync_message = str(exc) + self._save_packmol_docker_link(link) + self.packmol_docker_summary_box.setPlainText( + self._packmol_docker_summary_text(setup_metadata) + ) + self._append_run_log( + "Packmol Docker sync failed after local build: " f"{exc}" + ) + QMessageBox.warning( + self, + "Packmol Docker sync failed", + "Packmol inputs were built locally, but syncing them to the " + "linked Docker container failed.\n\n" + f"{exc}", + ) + return None + link.last_sync_at = result.synced_at + 
link.last_sync_status = "success" + link.last_sync_message = result.summary_text() + self._save_packmol_docker_link(link) + self.packmol_docker_summary_box.setPlainText( + self._packmol_docker_summary_text(setup_metadata) + ) + self._append_run_log(result.summary_text()) + return result + def _build_representative_solvent_outputs(self) -> None: state = self._project_source_state if state is None: @@ -3298,21 +4972,53 @@ def _build_representative_solvent_outputs(self) -> None: QMessageBox.information( self, "No representative selection", - "Compute representative clusters before building representative PDB outputs.", + "Save representative structures before building representative PDB outputs.", + ) + return + if not self._solvent_shell_builder_required(): + QMessageBox.information( + self, + "Full-solvent representatives selected", + "The active representative structure set already has full " + "solvent, so solvent-state analysis and solvent-shell " + "building are not required.", + ) + return + analysis = self._solvent_distribution_analysis + if analysis is None: + analysis = self._run_representative_solvent_analysis( + progress_message=( + "Analyzing representative solvent states before build..." + ), + ) + if analysis is None: + return + if not self._solvent_build_has_required_coordination_settings( + analysis + ): + self._update_solvent_build_panel_state() + QMessageBox.information( + self, + "Coordination settings required", + "Select at least one coordination-center element and set " + "its average coordination number and director distance " + "before building full-solvent representatives from " + "no-solvent structures.", ) return try: self._set_task_progress( - "Building solvent-aware representative PDBs...", + "Building solvent-decorated representative PDBs...", 35, ) self._append_run_log( - "Starting solvent-aware representative PDB build." + "Starting Solvent Shell Builder representative PDB build." 
) solvent_metadata = build_representative_solvent_outputs( state, self._current_solvent_settings(), representative_metadata=state.representative_selection, + distribution_analysis=analysis, ) except Exception as exc: message = f"Unable to build representative PDB outputs: {exc}" @@ -3321,11 +5027,12 @@ def _build_representative_solvent_outputs(self) -> None: self._set_task_progress("Solvent build failed.", 0) QMessageBox.warning( self, - "Solvent handling failed", + "Solvent Shell Builder failed", str(exc), ) return state.solvent_handling = solvent_metadata + state.packmol_planning = None state.packmol_setup = None state.constraint_generation = None self._apply_solvent_metadata(solvent_metadata) @@ -3336,10 +5043,12 @@ def _build_representative_solvent_outputs(self) -> None: self._populate_constraint_controls() self._update_readiness_progress() self._set_task_progress( - "Solvent-aware representative PDBs ready.", + "Solvent-decorated representative PDBs ready.", 100, ) - self._append_run_log("Built representative solvent-aware PDB outputs.") + self._append_run_log( + "Built solvent-decorated representative PDB outputs." 
+ ) def _compute_packmol_plan(self) -> None: state = self._project_source_state @@ -3354,7 +5063,7 @@ def _compute_packmol_plan(self) -> None: QMessageBox.information( self, "No representative selection", - "Compute representative clusters before planning Packmol counts.", + "Save representative structures before planning Packmol counts.", ) return if state.solution_properties.result is None: @@ -3393,14 +5102,12 @@ def _compute_packmol_plan(self) -> None: state.constraint_generation = None self.compute_packmol_plan_button.setEnabled(True) self.build_packmol_setup_button.setEnabled( - state.solvent_handling is not None + self._active_representative_structure_set_is_ready() ) self.packmol_plan_summary_box.setPlainText( self._packmol_plan_summary_text(planning_metadata) ) - self.packmol_build_summary_box.setPlainText( - self._packmol_setup_summary_text(None) - ) + self._apply_packmol_setup_metadata(None) self._populate_constraint_controls() self._refresh_packmol_plan_plot(planning_metadata) self.output_summary_box.setPlainText(self._output_structure_text()) @@ -3424,11 +5131,20 @@ def _build_packmol_setup(self) -> None: "Compute Packmol cluster counts before generating Packmol inputs.", ) return - if state.solvent_handling is None: + if not self._active_representative_structure_set_is_ready(): + QMessageBox.information( + self, + "Full-solvent representatives not selected", + "Select the Full solvent representative structure set in " + "Representative Structures before generating Packmol inputs.", + ) + return + if self._selected_packmol_free_solvent_reference() is None: QMessageBox.information( self, - "No solvent handling metadata", - "Build representative solvent-aware PDB outputs before generating Packmol inputs.", + "No free-solvent structure selected", + "Choose the free-solvent structure used for the Packmol " + "bulk-solvent population before generating Packmol inputs.", ) return try: @@ -3457,15 +5173,128 @@ def _build_packmol_setup(self) -> None: return 
state.packmol_setup = setup_metadata state.constraint_generation = None - self.packmol_build_summary_box.setPlainText( - self._packmol_setup_summary_text(setup_metadata) - ) + self._sync_packmol_inputs_to_linked_container(setup_metadata) + self._apply_packmol_setup_metadata(setup_metadata) self._populate_constraint_controls() self.output_summary_box.setPlainText(self._output_structure_text()) self._update_readiness_progress() self._set_task_progress("Packmol setup ready.", 100) self._append_run_log("Built Packmol setup inputs and audit report.") + def _open_packmol_setup_folder(self) -> None: + state = self._project_source_state + if state is None or state.packmol_setup is None: + QMessageBox.information( + self, + "No Packmol setup", + "Build Packmol setup inputs before opening the setup folder.", + ) + return + folder_path = ( + state.rmcsetup_paths.packmol_inputs_dir.expanduser().resolve() + ) + if not folder_path.is_dir(): + QMessageBox.warning( + self, + "Packmol setup folder missing", + f"Could not find the Packmol setup folder:\n{folder_path}", + ) + return + try: + self._open_path_in_file_manager(folder_path) + except Exception as exc: + QMessageBox.warning( + self, + "Packmol setup folder", + f"Could not open the Packmol setup folder:\n{exc}", + ) + return + self.statusBar().showMessage( + f"Opened Packmol setup folder: {folder_path.name}" + ) + self._append_run_log( + f"Opened Packmol setup folder in Finder/file manager: {folder_path}" + ) + + def _open_constraints_folder(self) -> None: + state = self._project_source_state + if state is None or state.constraint_generation is None: + QMessageBox.information( + self, + "No constraints generated", + "Generate constraints before opening the constraints folder.", + ) + return + merged_path = ( + Path(state.constraint_generation.merged_constraints_path) + .expanduser() + .resolve() + ) + if not merged_path.is_file(): + QMessageBox.warning( + self, + "Constraints file missing", + f"Could not find the merged 
constraints file:\n{merged_path}", + ) + return + try: + self._open_path_in_file_manager(merged_path) + except Exception as exc: + QMessageBox.warning( + self, + "Constraints folder", + f"Could not open the constraints folder:\n{exc}", + ) + return + self.statusBar().showMessage( + f"Opened constraints file location: {merged_path.name}" + ) + self._append_run_log( + "Opened constraints file location in Finder/file manager: " + f"{merged_path}" + ) + + def _open_constraints_preview(self) -> None: + state = self._project_source_state + if state is None or state.constraint_generation is None: + QMessageBox.information( + self, + "No constraints generated", + "Generate constraints before opening the merged constraints preview.", + ) + return + merged_path = ( + Path(state.constraint_generation.merged_constraints_path) + .expanduser() + .resolve() + ) + if not merged_path.is_file(): + QMessageBox.warning( + self, + "Constraints preview unavailable", + f"Could not find the merged constraints file:\n{merged_path}", + ) + return + try: + self._constraints_preview_window = ConstraintsPreviewWindow( + merged_path, + parent=self, + ) + except Exception as exc: + QMessageBox.warning( + self, + "Constraints preview unavailable", + f"Could not open the merged constraints file:\n{exc}", + ) + return + self._track_child_tool_window(self._constraints_preview_window) + self._constraints_preview_window.show() + self._constraints_preview_window.raise_() + self._constraints_preview_window.activateWindow() + self._append_run_log( + f"Opened merged constraints preview: {merged_path.name}" + ) + def _generate_constraints(self) -> None: state = self._project_source_state if state is None: @@ -3505,9 +5334,7 @@ def _generate_constraints(self) -> None: ) return state.constraint_generation = metadata - self.constraints_summary_box.setPlainText( - self._constraint_summary_text(metadata) - ) + self._apply_constraint_metadata(metadata) 
self.output_summary_box.setPlainText(self._output_structure_text()) self._update_readiness_progress() self._set_task_progress("Constraint generation complete.", 100) @@ -3848,14 +5675,28 @@ def _solvent_summary_text( state = self._project_source_state if state is None: return ( - "Load a SAXS project and compute representative clusters before " - "building solvent-aware PDB outputs." + "Load a SAXS project and save representative structures " + "before running the Solvent Shell Builder." ) if metadata is None: + active_mode = self._active_generated_pdb_mode() + if active_mode == "full_solvent": + return ( + "Imported representative structures already include the " + "Full solvent structure set, so this step is ready for " + "Packmol without rebuilding solvent outputs.\n\n" + "Solvent state analysis and solvent-shell building are " + "not required for this selection.\n\n" + "Metadata will be written to:\n" + f"{state.rmcsetup_paths.solvent_handling_path}" + ) return ( - "No representative PDB solvent export has been built yet.\n\n" - "Choose a coordinated solvent mode, select a bundled preset or " - "custom solvent PDB, and press Build Representative PDBs.\n\n" + "No full-solvent representative export has been built yet " + f"for the active {representative_structure_mode_label(active_mode)} " + "set.\n\n" + "Choose a solvent reference, review the coordination " + "settings, and press Build Solvent-Decorated Representative " + "PDBs. The required solvent-state analysis will run first.\n\n" "Metadata will be written to:\n" f"{state.rmcsetup_paths.solvent_handling_path}" ) @@ -3873,15 +5714,17 @@ def _packmol_plan_summary_text( if state is None: return ( "Load a SAXS project, calculate solution properties, and " - "compute representative clusters before planning Packmol " + "save representative structures before planning Packmol " "cluster counts." 
) if metadata is None: return ( "No Packmol plan has been saved yet.\n\n" "Press Compute Cluster Counts to convert the selected DREAM " - "distribution and representative clusters into planned box " - "counts and output reports.\n\n" + "distribution and representative structures into planned box " + "counts, solvent-allocation totals, and output reports.\n\n" + "Choose the free-solvent structure above before planning if " + "the representative source files already contain solvent.\n\n" "Metadata will be written to:\n" f"{state.rmcsetup_paths.packmol_plan_path}" ) @@ -3900,26 +5743,83 @@ def _packmol_setup_summary_text( state = self._project_source_state if state is None: return ( - "Load a SAXS project, compute representative clusters, " - "build solvent-aware representative PDBs, and plan counts " + "Load a SAXS project, save representative structures, " + "build solvent-decorated representative PDBs, and plan counts " "before generating Packmol inputs." ) if metadata is None: - return ( + if not self._active_representative_structure_set_is_ready(): + text = ( + "No Packmol setup has been built yet.\n\n" + "Select the Full solvent representative structure set in " + "Representative Structures before generating Packmol " + "inputs.\n\n" + "Metadata will be written to:\n" + f"{state.rmcsetup_paths.packmol_setup_path}" + ) + if state.packmol_docker_link is not None: + text += ( + "\n\nLinked Docker target:\n" + + state.packmol_docker_link.summary_text() + ) + return text + text = ( "No Packmol setup has been built yet.\n\n" "Press Build Packmol Setup to generate representative input " - "PDBs, the Packmol .inp file, the solvent single-molecule PDB, " - "and the audit report.\n\n" + "PDBs, the Packmol .inp file, the selected free-solvent " + "single-molecule PDB, and the audit report.\n\n" "Metadata will be written to:\n" f"{state.rmcsetup_paths.packmol_setup_path}" ) - return ( + if state.packmol_docker_link is not None: + text += ( + "\n\nLinked Docker target:\n" + 
+ state.packmol_docker_link.summary_text() + ) + return text + text = ( metadata.summary_text() + "\n\nMetadata path:\n" + str(state.rmcsetup_paths.packmol_setup_path) + "\nAudit report:\n" + str(state.rmcsetup_paths.packmol_audit_report_path) ) + if state.packmol_docker_link is not None: + text += ( + "\n\nLinked Docker target:\n" + + state.packmol_docker_link.summary_text( + packmol_setup_metadata=metadata + ) + ) + return text + + def _packmol_docker_summary_text( + self, + packmol_setup_metadata: PackmolSetupMetadata | None = None, + ) -> str: + state = self._project_source_state + if state is None: + return ( + "Load a SAXS project before linking a Packmol Docker " + "container." + ) + if state.packmol_docker_link is None: + return ( + "No Packmol Docker container is linked yet.\n\n" + "Use Tools > Link Packmol Docker Container to validate a " + "container, confirm Packmol is installed, and choose the " + "container-side project folder. The required bind-mounted " + f"root inside the container is {DEFAULT_PACKMOL_CONTAINER_ROOT}.\n\n" + "Project link metadata will be written to:\n" + f"{state.rmcsetup_paths.packmol_docker_link_path}" + ) + return ( + state.packmol_docker_link.summary_text( + packmol_setup_metadata=packmol_setup_metadata + ) + + "\n\nProject link metadata path:\n" + + str(state.rmcsetup_paths.packmol_docker_link_path) + ) def _constraint_summary_text( self, @@ -4148,14 +6048,17 @@ def _representative_summary_text( if state is None: return ( "Load a SAXS project and choose a DREAM source before " - "selecting representative cluster files." + "loading saved representative structures." 
) if metadata is None: return ( - "No representative selection has been saved yet.\n\n" - "Choose a DREAM source and press Compute Representative " - "Clusters to select one representative structure per active " - "cluster bin.\n\n" + "No representative structures have been saved yet.\n\n" + "Use Open Representative Structures to create or update the " + "saved project set, then reload it here. rmcsetup will " + "combine those structure files with the selected DREAM " + "distribution weights, solution-density targets, solvent " + "handling, Packmol planning, and cluster-specific " + "constraints.\n\n" "Metadata will be written to:\n" f"{state.rmcsetup_paths.representative_selection_path}" ) @@ -4495,7 +6398,7 @@ def _plot_dream_model_preview( ) axis.set_xscale("log") axis.set_yscale("log") - axis.set_xlabel("q (Å⁻¹)") + axis.set_xlabel(Q_A_INVERSE_LABEL) axis.set_ylabel("Intensity (arb. units)") axis.set_title(f"DREAM refinement: {plot_data.template_name}") axis.text( @@ -4740,6 +6643,19 @@ def _append_run_log(self, message: str) -> None: timestamp = datetime.now().strftime("%H:%M:%S") self.run_log_box.appendPlainText(f"[{timestamp}] {message}") + @staticmethod + def _open_path_in_file_manager(path: Path) -> None: + resolved_path = path.expanduser().resolve() + target_path = ( + resolved_path if resolved_path.is_dir() else resolved_path.parent + ) + if not target_path.exists(): + raise FileNotFoundError(target_path) + if not QDesktopServices.openUrl(QUrl.fromLocalFile(str(target_path))): + raise RuntimeError( + "Qt could not open the requested folder in the file manager." 
class PackmolDockerLinkDialog(QDialog):
    """Dialog for choosing, validating, and linking a Packmol Docker container.

    The user picks a container (from a discovered list, a recent preset, or
    by typing a name), tests it through the Docker client, optionally browses
    the container's directories to choose the project root, and confirms.
    After the dialog is accepted, ``selected_link()`` returns the validated
    link.
    """

    def __init__(
        self,
        *,
        current_link: PackmolDockerLink | None = None,
        recent_presets: (
            list[PackmolDockerLink] | tuple[PackmolDockerLink, ...]
        ) = (),
        docker_client: PackmolDockerClient | None = None,
        parent: QWidget | None = None,
    ) -> None:
        """Build the dialog widgets and pre-fill them.

        Args:
            current_link: Link used to pre-fill the form; takes precedence
                over ``recent_presets``.
            recent_presets: Previously used links offered as presets. The
                first one pre-fills the form when ``current_link`` is None.
            docker_client: Client used for all Docker calls; a default
                ``PackmolDockerClient`` is created when omitted.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.setWindowTitle("Link Packmol Docker Container")
        self.resize(920, 620)

        self._docker_client = docker_client or PackmolDockerClient()
        self._recent_presets = list(recent_presets)
        self._available_containers: list[PackmolDockerContainerRecord] = []
        # The last successfully validated link and the field signature it
        # was validated with; both are cleared whenever validation is stale.
        self._selected_link: PackmolDockerLink | None = None
        self._validated_signature: tuple[str, str, str, str] | None = None

        root_layout = QVBoxLayout(self)
        root_layout.setContentsMargins(12, 12, 12, 12)
        root_layout.setSpacing(12)

        intro_label = QLabel(
            "Validate a Packmol-ready Docker container, choose the "
            "container-side project folder, and link it to the current "
            "rmcsetup workflow. You can pick from discovered Docker "
            "containers or type a container name manually. The "
            "container-side project folder must live inside the bind-mounted "
            f"root {DEFAULT_PACKMOL_CONTAINER_ROOT}."
        )
        intro_label.setWordWrap(True)
        root_layout.addWidget(intro_label)

        # --- Recent presets -------------------------------------------
        preset_group = QGroupBox("Recent Container Presets")
        preset_layout = QHBoxLayout(preset_group)
        self.preset_name_edit = QLineEdit()
        self.preset_name_edit.setPlaceholderText("Preset display name")
        preset_layout.addWidget(self.preset_name_edit, stretch=1)
        # NOTE: despite the attribute name, this is a QTreeWidget used as a
        # flat list of presets.
        self.preset_combo = QTreeWidget()
        self.preset_combo.setHeaderHidden(True)
        self.preset_combo.setMaximumHeight(110)
        preset_layout.addWidget(self.preset_combo, stretch=1)
        preset_button_column = QVBoxLayout()
        self.load_preset_button = QPushButton("Load Selected Preset")
        self.load_preset_button.clicked.connect(self._load_selected_preset)
        preset_button_column.addWidget(self.load_preset_button)
        preset_button_column.addStretch(1)
        preset_layout.addLayout(preset_button_column)
        root_layout.addWidget(preset_group)

        # --- Container settings form ----------------------------------
        form_group = QGroupBox("Container Settings")
        form_layout = QFormLayout(form_group)
        discovered_row = QWidget()
        discovered_layout = QHBoxLayout(discovered_row)
        discovered_layout.setContentsMargins(0, 0, 0, 0)
        discovered_layout.setSpacing(6)
        self.available_container_combo = QComboBox()
        self.available_container_combo.setMinimumContentsLength(32)
        discovered_layout.addWidget(self.available_container_combo, stretch=1)
        self.refresh_containers_button = QPushButton("Refresh List")
        self.refresh_containers_button.clicked.connect(
            self._refresh_available_containers
        )
        discovered_layout.addWidget(self.refresh_containers_button)
        self.use_available_container_button = QPushButton("Use Selected")
        self.use_available_container_button.clicked.connect(
            self._use_available_container
        )
        discovered_layout.addWidget(self.use_available_container_button)
        form_layout.addRow("Discovered containers", discovered_row)
        self.container_name_edit = QLineEdit()
        self.container_name_edit.setPlaceholderText("Docker container name")
        form_layout.addRow("Container name", self.container_name_edit)
        self.packmol_command_edit = QLineEdit("packmol")
        form_layout.addRow("Packmol command", self.packmol_command_edit)
        self.shell_command_edit = QLineEdit("sh")
        form_layout.addRow("Shell command", self.shell_command_edit)
        self.container_root_edit = QLineEdit(DEFAULT_PACKMOL_CONTAINER_ROOT)
        self.container_root_edit.setPlaceholderText(
            "/packmol_input_files/project_name"
        )
        form_layout.addRow("Container project root", self.container_root_edit)
        root_layout.addWidget(form_group)

        # --- Action buttons -------------------------------------------
        command_row = QHBoxLayout()
        self.test_connection_button = QPushButton("Test Container")
        self.test_connection_button.clicked.connect(self._test_connection)
        command_row.addWidget(self.test_connection_button)
        self.refresh_tree_button = QPushButton("Refresh Directory Tree")
        self.refresh_tree_button.clicked.connect(self._refresh_directory_tree)
        command_row.addWidget(self.refresh_tree_button)
        self.use_selected_directory_button = QPushButton(
            "Use Selected Directory"
        )
        self.use_selected_directory_button.clicked.connect(
            self._use_selected_directory
        )
        command_row.addWidget(self.use_selected_directory_button)
        command_row.addStretch(1)
        root_layout.addLayout(command_row)

        # --- Directory tree + status pane -----------------------------
        self.content_splitter = QSplitter(Qt.Orientation.Horizontal)
        self.content_splitter.setChildrenCollapsible(False)
        root_layout.addWidget(self.content_splitter, stretch=1)

        self.directory_tree = QTreeWidget()
        self.directory_tree.setHeaderLabel("Container Directories")
        self.directory_tree.itemExpanded.connect(self._on_tree_item_expanded)
        self.directory_tree.itemSelectionChanged.connect(
            self._on_tree_selection_changed
        )
        self.content_splitter.addWidget(self.directory_tree)

        details_panel = QWidget()
        details_layout = QVBoxLayout(details_panel)
        details_layout.setContentsMargins(0, 0, 0, 0)
        details_layout.setSpacing(8)
        self.selected_directory_label = QLabel("Selected directory: (none)")
        self.selected_directory_label.setWordWrap(True)
        details_layout.addWidget(self.selected_directory_label)
        self.status_box = QPlainTextEdit()
        self.status_box.setReadOnly(True)
        self.status_box.setPlainText(
            "Press Test Container to initialize Docker, verify Packmol "
            "inside the selected container, and load the directory tree."
        )
        details_layout.addWidget(self.status_box, stretch=1)
        self.content_splitter.addWidget(details_panel)
        # Stretch factors are set after both panes have been added.
        self.content_splitter.setStretchFactor(0, 1)
        self.content_splitter.setStretchFactor(1, 1)
        self.content_splitter.setSizes([420, 420])

        # --- Dialog buttons -------------------------------------------
        self.button_box = QDialogButtonBox(
            QDialogButtonBox.StandardButton.Ok
            | QDialogButtonBox.StandardButton.Cancel
        )
        self.button_box.accepted.connect(self.accept)
        self.button_box.rejected.connect(self.reject)
        self.button_box.button(QDialogButtonBox.StandardButton.Ok).setText(
            "Link Container"
        )
        root_layout.addWidget(self.button_box)

        self._populate_preset_tree()
        if current_link is not None:
            self._apply_link_to_fields(current_link)
        elif self._recent_presets:
            self._apply_link_to_fields(self._recent_presets[0])
        # Initial discovery is silent so the introductory status text stays.
        self._refresh_available_containers(show_feedback=False)

    def selected_link(self) -> PackmolDockerLink | None:
        """Return the last validated link, or None if nothing is validated."""
        return self._selected_link

    def _populate_preset_tree(self) -> None:
        """Fill the preset list from the recent presets and select the first."""
        self.preset_combo.clear()
        for preset in self._recent_presets:
            item = QTreeWidgetItem([preset.resolved_display_name])
            # The preset is stored as a plain dict and rebuilt on load.
            item.setData(0, _TREE_PATH_ROLE, preset.to_dict())
            item.setToolTip(0, preset.summary_text())
            self.preset_combo.addTopLevelItem(item)
        has_presets = self.preset_combo.topLevelItemCount() > 0
        self.preset_combo.setEnabled(has_presets)
        self.load_preset_button.setEnabled(has_presets)
        if has_presets:
            self.preset_combo.setCurrentItem(self.preset_combo.topLevelItem(0))

    def _apply_link_to_fields(self, link: PackmolDockerLink) -> None:
        """Copy ``link`` into the form fields and discard prior validation."""
        self.preset_name_edit.setText(link.resolved_display_name)
        self.container_name_edit.setText(link.container_name)
        self.packmol_command_edit.setText(link.packmol_command)
        self.shell_command_edit.setText(link.shell_command)
        self.container_root_edit.setText(link.container_project_root)
        self._validated_signature = None
        self._selected_link = None

    def _load_selected_preset(self) -> None:
        """Load the highlighted preset into the form fields, if any."""
        current_item = self.preset_combo.currentItem()
        if current_item is None:
            return
        payload = current_item.data(0, _TREE_PATH_ROLE)
        if not isinstance(payload, dict):
            return
        preset = PackmolDockerLink.from_dict(payload)
        if preset is None:
            return
        self._apply_link_to_fields(preset)
        self.status_box.setPlainText(
            "Preset loaded. Test the container to validate the current "
            "Docker state and refresh the directory tree."
        )
        self._sync_available_container_selection()

    def _draft_link(self) -> PackmolDockerLink:
        """Build an unvalidated link from the current form fields.

        Blank optional fields fall back to their defaults; a blank display
        name falls back to the container name.

        Raises:
            ValueError: If the container name field is empty.
        """
        container_name = self.container_name_edit.text().strip()
        if not container_name:
            raise ValueError("Enter a Docker container name before linking.")
        display_name = self.preset_name_edit.text().strip() or container_name
        return PackmolDockerLink(
            display_name=display_name,
            container_name=container_name,
            packmol_command=self.packmol_command_edit.text().strip()
            or "packmol",
            shell_command=self.shell_command_edit.text().strip() or "sh",
            container_project_root=self.container_root_edit.text().strip()
            or DEFAULT_PACKMOL_CONTAINER_ROOT,
        )

    def _link_signature(
        self, link: PackmolDockerLink
    ) -> tuple[str, str, str, str]:
        """Return the fields that decide whether ``link`` needs re-validation.

        The display name is deliberately excluded: renaming a preset does
        not change what was validated against Docker.
        """
        return (
            link.container_name.strip(),
            link.packmol_command.strip(),
            link.shell_command.strip(),
            link.container_project_root.strip(),
        )

    def _format_docker_failure(
        self,
        exc: Exception,
        *,
        include_attached_shell_hint: bool,
    ) -> tuple[str, bool]:
        """Turn a Docker failure into user-facing text.

        Args:
            exc: The exception raised by the Docker client.
            include_attached_shell_hint: Whether to append advice about
                containers that only stay alive with an attached shell.

        Returns:
            ``(message, daemon_unavailable)`` where ``daemon_unavailable``
            is True when ``docker_daemon_unavailable_hint`` produced a hint
            for the error details.
        """
        details = str(exc).strip() or "Docker command failed."
        daemon_hint = docker_daemon_unavailable_hint(details)
        daemon_unavailable = daemon_hint is not None
        sections: list[str] = []
        if daemon_unavailable:
            sections.append(daemon_hint)
        sections.append(details)
        if include_attached_shell_hint:
            lead_in = (
                "After Docker is running, if" if daemon_unavailable else "If"
            )
            # `docker start` requires a container argument; show a
            # placeholder so the suggested command is complete.
            sections.append(
                f"{lead_in} your container only stays alive with an attached "
                "shell, start it manually with "
                "`docker start -i <container-name>` before retrying."
            )
        return "\n\n".join(sections), daemon_unavailable

    def _test_connection(self) -> bool:
        """Validate the drafted link through the Docker client.

        On success the verification results are copied onto the draft, the
        draft becomes the selected link, and the directory tree is reloaded.
        On failure the selection is cleared and the error is shown in the
        status box.

        Returns:
            True when validation succeeded, False otherwise.
        """
        try:
            draft = self._draft_link()
            result = self._docker_client.verify_link(draft)
        except Exception as exc:
            # UI boundary: any failure (bad input or Docker error) is
            # reported in the status box rather than raised.
            self._selected_link = None
            self._validated_signature = None
            formatted_message, _ = self._format_docker_failure(
                exc,
                include_attached_shell_hint=True,
            )
            self.status_box.setPlainText(
                "Docker validation failed.\n\n" f"{formatted_message}"
            )
            return False
        draft.last_verified_at = result.verified_at
        draft.container_id = result.container_id
        draft.image_name = result.image_name
        draft.packmol_command_path = result.packmol_command_path
        draft.packmol_version = result.packmol_version
        draft.container_project_root = result.container_project_root
        # Mirror the verified root into the form so the stored signature
        # matches what `accept` will read back from the fields.
        self.container_root_edit.setText(result.container_project_root)
        self.status_box.setPlainText(result.summary_text(draft))
        self._selected_link = draft
        self._validated_signature = self._link_signature(draft)
        self._load_directory_tree(draft, result.container_project_root)
        return True

    def _refresh_available_containers(
        self,
        *,
        show_feedback: bool = True,
    ) -> None:
        """Reload the discovered-container combo box from Docker.

        Args:
            show_feedback: When True, append the outcome (or the failure
                details) to the status box; when False, update silently.
        """
        try:
            records = self._docker_client.list_containers()
        except Exception as exc:
            self._available_containers = []
            self.available_container_combo.clear()
            self.available_container_combo.setEnabled(False)
            self.use_available_container_button.setEnabled(False)
            if show_feedback:
                formatted_message, daemon_unavailable = (
                    self._format_docker_failure(
                        exc,
                        include_attached_shell_hint=False,
                    )
                )
                self.status_box.appendPlainText(
                    "\nUnable to list Docker containers.\n\n"
                    f"{formatted_message}"
                )
                if not daemon_unavailable:
                    self.status_box.appendPlainText(
                        "\nYou can still type a container name manually and "
                        "press Test Container."
                    )
            return
        self._available_containers = list(records)
        self.available_container_combo.clear()
        for record in self._available_containers:
            self.available_container_combo.addItem(
                record.summary_label,
                record,
            )
        has_records = bool(self._available_containers)
        self.available_container_combo.setEnabled(has_records)
        self.use_available_container_button.setEnabled(has_records)
        self._sync_available_container_selection()
        if not show_feedback:
            return
        if has_records:
            self.status_box.appendPlainText(
                "\nDiscovered "
                f"{len(self._available_containers)} Docker container(s). "
                "Select one to populate the container name field, then "
                "press Test Container to verify Packmol."
            )
            return
        self.status_box.appendPlainText(
            "\nNo Docker containers were found. You can still type a "
            "container name manually and press Test Container."
        )

    def _sync_available_container_selection(self) -> None:
        """Select the discovered container matching the typed name, if any."""
        current_name = self.container_name_edit.text().strip()
        if not current_name:
            return
        for index, record in enumerate(self._available_containers):
            if record.name == current_name:
                self.available_container_combo.setCurrentIndex(index)
                return

    def _selected_available_container(
        self,
    ) -> PackmolDockerContainerRecord | None:
        """Return the record attached to the current combo entry, or None."""
        record = self.available_container_combo.currentData()
        if isinstance(record, PackmolDockerContainerRecord):
            return record
        return None

    def _use_available_container(self) -> None:
        """Copy the selected discovered container's name into the form."""
        record = self._selected_available_container()
        if record is None:
            return
        self.container_name_edit.setText(record.name)
        # A different container invalidates any earlier validation.
        self._selected_link = None
        self._validated_signature = None
        self.status_box.appendPlainText(
            "\nLoaded container name from the discovered Docker list. "
            "Press Test Container to verify Packmol in this container."
        )

    def _refresh_directory_tree(self) -> None:
        """Re-validate the container, which reloads the tree on success."""
        self._test_connection()

    def _load_directory_tree(
        self,
        link: PackmolDockerLink,
        root_path: str,
    ) -> None:
        """Rebuild the directory tree with ``root_path`` as its only root."""
        self.directory_tree.clear()
        root_item = QTreeWidgetItem([root_path])
        root_item.setData(0, _TREE_PATH_ROLE, root_path)
        root_item.setData(0, _TREE_LOADED_ROLE, False)
        # Placeholder child so the item is expandable before it is loaded.
        root_item.addChild(QTreeWidgetItem(["Loading..."]))
        self.directory_tree.addTopLevelItem(root_item)
        root_item.setExpanded(True)
        # No-op if the expansion handler already populated the item.
        self._populate_tree_item_children(root_item, link)
        self.directory_tree.setCurrentItem(root_item)
        self._on_tree_selection_changed()

    def _populate_tree_item_children(
        self,
        item: QTreeWidgetItem,
        link: PackmolDockerLink | None = None,
    ) -> None:
        """Load the sub-directories of ``item`` once, on demand.

        Args:
            item: Tree item whose stored path should be listed.
            link: Link to list with; defaults to the selected link. Nothing
                happens when neither is available.
        """
        active_link = link or self._selected_link
        if active_link is None:
            return
        if item.data(0, _TREE_LOADED_ROLE):
            return
        directory = item.data(0, _TREE_PATH_ROLE)
        if not isinstance(directory, str) or not directory:
            return
        # Drop the "Loading..." placeholder before adding real entries.
        item.takeChildren()
        try:
            entries = self._docker_client.list_directories(
                active_link, directory
            )
        except Exception as exc:
            # Show the failure in place and mark the item loaded, so it is
            # not retried on every expansion.
            item.addChild(QTreeWidgetItem([f"Unable to load folders: {exc}"]))
            item.setData(0, _TREE_LOADED_ROLE, True)
            return
        for entry in entries:
            child = QTreeWidgetItem([entry.name])
            child.setData(0, _TREE_PATH_ROLE, entry.path)
            child.setData(0, _TREE_LOADED_ROLE, False)
            child.addChild(QTreeWidgetItem(["Loading..."]))
            item.addChild(child)
        item.setData(0, _TREE_LOADED_ROLE, True)

    def _on_tree_item_expanded(self, item: QTreeWidgetItem) -> None:
        """Lazily load children when a tree item is expanded."""
        self._populate_tree_item_children(item)

    def _on_tree_selection_changed(self) -> None:
        """Show the path of the current tree item in the label."""
        current_item = self.directory_tree.currentItem()
        if current_item is None:
            self.selected_directory_label.setText("Selected directory: (none)")
            return
        directory = current_item.data(0, _TREE_PATH_ROLE)
        if not isinstance(directory, str) or not directory:
            # Placeholder and error rows carry no path.
            self.selected_directory_label.setText("Selected directory: (none)")
            return
        self.selected_directory_label.setText(
            f"Selected directory: {directory}"
        )

    def _use_selected_directory(self) -> None:
        """Copy the current tree item's path into the project root field."""
        current_item = self.directory_tree.currentItem()
        if current_item is None:
            return
        directory = current_item.data(0, _TREE_PATH_ROLE)
        if not isinstance(directory, str) or not directory:
            return
        self.container_root_edit.setText(directory)
        self.status_box.appendPlainText(
            "\nUpdated container project root from the selected directory."
        )

    def accept(self) -> None:
        """Accept the dialog only when the current fields are validated.

        If the fields changed since the last successful test, the container
        is re-tested first; a failed test keeps the dialog open.
        """
        try:
            draft = self._draft_link()
        except Exception as exc:
            QMessageBox.warning(self, "Packmol Docker link invalid", str(exc))
            return
        if self._validated_signature != self._link_signature(draft):
            if not self._test_connection():
                QMessageBox.warning(
                    self,
                    "Packmol Docker link invalid",
                    "Docker validation failed. Review the status box for "
                    "details and retry.",
                )
                return
        elif self._selected_link is not None:
            # The signature ignores the display name, so a name edited
            # after the last test would otherwise be dropped silently.
            self._selected_link.display_name = draft.display_name
        super().accept()
else Path(initial_project_dir).expanduser().resolve() + ) + self.reference_library_dir = ( + default_reference_library_dir() + if reference_library_dir is None + else Path(reference_library_dir).expanduser().resolve() + ) + self._available_references = list_reference_library( + self.reference_library_dir + ) + self._analysis_result: SolventShellAnalysisResult | None = None + self._build_result_text: str | None = None + self._last_suggested_output_path: str | None = None + self._solute_cutoff_spins: dict[str, QDoubleSpinBox] = {} + self._coordination_center_items: dict[str, QTableWidgetItem] = {} + self._coordination_target_spins: dict[str, QDoubleSpinBox] = {} + self._updating_solute_table = False + self._browse_start_path = ( + self._initial_project_dir + if self._initial_project_dir is not None + else Path.home() + ) + + self.setWindowTitle("Solvent Shell Builder (Beta)") + self.setWindowIcon(load_saxshell_icon()) + self.resize(920, 700) + + central_widget = QWidget(self) + self.setCentralWidget(central_widget) + layout = QVBoxLayout(central_widget) + layout.setContentsMargins(12, 12, 12, 12) + layout.setSpacing(10) + + intro_label = QLabel( + "Beta utility for detecting whether an input PDB or XYZ contains " + "no, partial, or complete solvent molecules that match one " + "selected reference preset, then building a solvated output PDB " + "for no-solvent or partial-solvent cases." 
+ ) + intro_label.setWordWrap(True) + layout.addWidget(intro_label) + + self._pane_splitter = QSplitter(Qt.Orientation.Horizontal, self) + self._pane_splitter.setChildrenCollapsible(False) + self._pane_splitter.setStretchFactor(0, 0) + self._pane_splitter.setStretchFactor(1, 1) + layout.addWidget(self._pane_splitter, stretch=1) + + self._left_scroll_area = QScrollArea(self) + self._left_scroll_area.setWidgetResizable(True) + self._left_scroll_area.setHorizontalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAlwaysOff + ) + self._right_scroll_area = QScrollArea(self) + self._right_scroll_area.setWidgetResizable(True) + self._right_scroll_area.setHorizontalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAlwaysOff + ) + self._left_panel = QWidget() + self._right_panel = QWidget() + self._left_layout = QVBoxLayout(self._left_panel) + self._left_layout.setContentsMargins(10, 10, 10, 10) + self._left_layout.setSpacing(10) + self._right_layout = QVBoxLayout(self._right_panel) + self._right_layout.setContentsMargins(10, 10, 10, 10) + self._right_layout.setSpacing(10) + self._left_scroll_area.setWidget(self._left_panel) + self._right_scroll_area.setWidget(self._right_panel) + self._pane_splitter.addWidget(self._left_scroll_area) + self._pane_splitter.addWidget(self._right_scroll_area) + self._pane_splitter.setSizes([380, 620]) + + self.input_group = QGroupBox("Input Settings") + input_layout = QVBoxLayout(self.input_group) + form = QFormLayout() + self.reference_preset_combo = QComboBox() + for preset in self._available_references: + self.reference_preset_combo.addItem(preset.name, preset.name) + self.reference_preset_combo.currentIndexChanged.connect( + self._handle_reference_selection_changed + ) + form.addRow("Reference molecule", self.reference_preset_combo) + + self.reference_details_box = QPlainTextEdit() + self.reference_details_box.setReadOnly(True) + self.reference_details_box.setPlaceholderText( + "Reference preset details will appear here." 
+ ) + self.reference_details_box.setMinimumHeight(72) + self.reference_details_box.setMaximumBlockCount(8) + form.addRow("Preset details", self.reference_details_box) + + self.reference_match_tolerance_spin = QDoubleSpinBox() + self.reference_match_tolerance_spin.setDecimals(3) + self.reference_match_tolerance_spin.setRange(0.001, 10.0) + self.reference_match_tolerance_spin.setSingleStep(0.025) + self.reference_match_tolerance_spin.setSuffix(" A") + self.reference_match_tolerance_spin.setValue( + DEFAULT_REFERENCE_MATCH_TOLERANCE_A + ) + self.reference_match_tolerance_spin.setToolTip( + "Maximum allowed anchor-pair distance deviation when matching " + "the selected solvent reference." + ) + self.reference_match_tolerance_spin.valueChanged.connect( + self._handle_reference_tolerance_changed + ) + form.addRow( + "Reference match tolerance", + self.reference_match_tolerance_spin, + ) + + input_row = QHBoxLayout() + self.input_path_edit = QLineEdit() + self.input_path_edit.setPlaceholderText( + "Choose an input PDB or XYZ structure file..." 
+ ) + self.input_path_edit.editingFinished.connect( + self._handle_input_path_edited + ) + input_row.addWidget(self.input_path_edit, stretch=1) + self.browse_input_button = QPushButton("Browse...") + self.browse_input_button.clicked.connect(self._browse_input_file) + input_row.addWidget(self.browse_input_button) + form.addRow("Input file", input_row) + input_layout.addLayout(form) + + action_row = QHBoxLayout() + action_row.addStretch(1) + self.analyze_button = QPushButton("Analyze Solvent Shell") + self.analyze_button.clicked.connect(self._analyze_input_structure) + self.analyze_button.setEnabled(bool(self._available_references)) + action_row.addWidget(self.analyze_button) + input_layout.addLayout(action_row) + self._left_layout.addWidget(self.input_group) + + self.cluster_status_group = QGroupBox("Detected Cluster Status") + status_layout = QVBoxLayout(self.cluster_status_group) + self.cluster_status_headline_label = QLabel() + self.cluster_status_headline_label.setWordWrap(True) + self.cluster_status_stats_label = QLabel() + self.cluster_status_stats_label.setWordWrap(True) + self.cluster_status_stats_label.setTextInteractionFlags( + Qt.TextInteractionFlag.TextSelectableByMouse + ) + status_layout.addWidget(self.cluster_status_headline_label) + status_layout.addWidget(self.cluster_status_stats_label) + self._left_layout.addWidget(self.cluster_status_group) + + self.build_group = QGroupBox("Build Solvation Shell") + build_layout = QVBoxLayout(self.build_group) + self.build_intro_label = QLabel( + "Use the analyzed solvent status to complete partial solvent " + "molecules or build a new shell for a cluster with no " + "coordinated solvent." + ) + self.build_intro_label.setWordWrap(True) + build_layout.addWidget(self.build_intro_label) + + build_form = QFormLayout() + self.director_atom_combo = QComboBox() + self.director_atom_combo.setToolTip( + "Reference atom that should point toward the solute cluster " + "during solvent placement." 
+ ) + build_form.addRow("Director atom", self.director_atom_combo) + + self.minimum_solvent_separation_spin = QDoubleSpinBox() + self.minimum_solvent_separation_spin.setDecimals(3) + self.minimum_solvent_separation_spin.setRange(0.0, 20.0) + self.minimum_solvent_separation_spin.setSingleStep(0.1) + self.minimum_solvent_separation_spin.setSuffix(" A") + self.minimum_solvent_separation_spin.setValue( + self._DEFAULT_MINIMUM_SOLVENT_SEPARATION_A + ) + self.minimum_solvent_separation_spin.setToolTip( + "Minimum allowed atom-to-atom separation between placed solvent " + "molecules and already placed neighbors." + ) + build_form.addRow( + "Solvent-solvent separation", + self.minimum_solvent_separation_spin, + ) + + output_row = QHBoxLayout() + self.output_path_edit = QLineEdit() + self.output_path_edit.setPlaceholderText( + "Choose where the solvated output PDB should be written..." + ) + self.output_path_edit.editingFinished.connect( + self._handle_output_path_edited + ) + output_row.addWidget(self.output_path_edit, stretch=1) + self.browse_output_button = QPushButton("Browse...") + self.browse_output_button.clicked.connect(self._browse_output_file) + output_row.addWidget(self.browse_output_button) + build_form.addRow("Output PDB", output_row) + build_layout.addLayout(build_form) + + self.solute_cutoff_status_label = QLabel( + "Analyze the input structure to populate the recognized solute " + "atom types and their solvent-placement distances." 
+ ) + self.solute_cutoff_status_label.setWordWrap(True) + build_layout.addWidget(self.solute_cutoff_status_label) + + self.solute_cutoff_table = QTableWidget(0, 5) + self.solute_cutoff_table.setHorizontalHeaderLabels( + [ + "Solute Element", + "Atom Count", + "Coordination Center", + "Avg Coord #", + "Director Distance (A)", + ] + ) + self.solute_cutoff_table.setEditTriggers( + QAbstractItemView.EditTrigger.NoEditTriggers + ) + self.solute_cutoff_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.solute_cutoff_table.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.solute_cutoff_table.verticalHeader().setVisible(False) + cutoff_header = self.solute_cutoff_table.horizontalHeader() + cutoff_header.setSectionResizeMode( + 0, QHeaderView.ResizeMode.ResizeToContents + ) + cutoff_header.setSectionResizeMode( + 1, QHeaderView.ResizeMode.ResizeToContents + ) + cutoff_header.setSectionResizeMode( + 2, QHeaderView.ResizeMode.ResizeToContents + ) + cutoff_header.setSectionResizeMode( + 3, QHeaderView.ResizeMode.ResizeToContents + ) + cutoff_header.setSectionResizeMode(4, QHeaderView.ResizeMode.Stretch) + self.solute_cutoff_table.itemChanged.connect( + self._handle_solute_table_item_changed + ) + build_layout.addWidget(self.solute_cutoff_table) + + self.build_status_label = QLabel( + "Analyze the input structure to enable solvent-shell building." 
+ ) + self.build_status_label.setWordWrap(True) + self.build_status_label.setTextInteractionFlags( + Qt.TextInteractionFlag.TextSelectableByMouse + ) + build_layout.addWidget(self.build_status_label) + + build_action_row = QHBoxLayout() + build_action_row.addStretch(1) + self.build_output_button = QPushButton("Build Solvated Output PDB") + self.build_output_button.clicked.connect(self._build_solvated_output) + build_action_row.addWidget(self.build_output_button) + build_layout.addLayout(build_action_row) + self._left_layout.addWidget(self.build_group) + + left_notes_group = QGroupBox("Notes") + left_notes_layout = QVBoxLayout(left_notes_group) + self.input_notes_label = QLabel( + "Run the analysis after changing the reference or input file. " + "Adjust the reference match tolerance if the solvent geometry " + "needs a looser or tighter match. Then review the director atom, " + "solute cutoffs, and solvent separation before building the " + "solvated output PDB." + ) + self.input_notes_label.setWordWrap(True) + left_notes_layout.addWidget(self.input_notes_label) + self._left_layout.addWidget(left_notes_group) + self._left_layout.addStretch(1) + + visualizer_group = QGroupBox("Structure Visualizer") + visualizer_layout = QVBoxLayout(visualizer_group) + self.visualizer_status_label = QLabel( + "Choose a PDB or XYZ input file to preview the structure." + ) + self.visualizer_status_label.setWordWrap(True) + visualizer_layout.addWidget(self.visualizer_status_label) + self.structure_viewer = ElectronDensityStructureViewer( + self._right_panel + ) + self.structure_viewer.setMinimumHeight(460) + visualizer_layout.addWidget(self.structure_viewer, stretch=1) + self._right_layout.addWidget(visualizer_group, stretch=1) + + summary_group = QGroupBox("Generated Outputs") + summary_layout = QVBoxLayout(summary_group) + self.summary_box = QPlainTextEdit() + self.summary_box.setReadOnly(True) + self.summary_box.setPlaceholderText( + "Analysis results will appear here." 
+ ) + summary_layout.addWidget(self.summary_box) + self._right_layout.addWidget(summary_group) + + residue_group = QGroupBox("Matched PDB Residues") + residue_layout = QVBoxLayout(residue_group) + self.residue_status_label = QLabel( + "Residue-level solvent types are reported only for PDB inputs." + ) + self.residue_status_label.setWordWrap(True) + residue_layout.addWidget(self.residue_status_label) + self.residue_table = QTableWidget(0, 5) + self.residue_table.setHorizontalHeaderLabels( + [ + "Residue", + "Matched Molecules", + "Residue Numbers", + "Atoms / Molecule", + "Elements", + ] + ) + self.residue_table.setEditTriggers( + QAbstractItemView.EditTrigger.NoEditTriggers + ) + self.residue_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.residue_table.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + header = self.residue_table.horizontalHeader() + header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + header.setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) + header.setSectionResizeMode(2, QHeaderView.ResizeMode.Stretch) + header.setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) + header.setSectionResizeMode(4, QHeaderView.ResizeMode.Stretch) + residue_layout.addWidget(self.residue_table) + self._right_layout.addWidget(residue_group) + + mismatch_group = QGroupBox("Incomplete / Partial Solvent Candidates") + mismatch_layout = QVBoxLayout(mismatch_group) + self.mismatch_status_label = QLabel( + "Incomplete solvent-like candidates are reported after analysis." 
+ ) + self.mismatch_status_label.setWordWrap(True) + mismatch_layout.addWidget(self.mismatch_status_label) + self.mismatch_table = QTableWidget(0, 7) + self.mismatch_table.setHorizontalHeaderLabels( + [ + "Residue", + "Residue Number", + "Observed Atoms", + "Matched / Ref", + "Missing Atoms", + "Extra Atoms", + "Reason", + ] + ) + self.mismatch_table.setEditTriggers( + QAbstractItemView.EditTrigger.NoEditTriggers + ) + self.mismatch_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.mismatch_table.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + mismatch_header = self.mismatch_table.horizontalHeader() + mismatch_header.setSectionResizeMode( + 0, QHeaderView.ResizeMode.ResizeToContents + ) + mismatch_header.setSectionResizeMode( + 1, QHeaderView.ResizeMode.ResizeToContents + ) + mismatch_header.setSectionResizeMode( + 2, QHeaderView.ResizeMode.ResizeToContents + ) + mismatch_header.setSectionResizeMode( + 3, QHeaderView.ResizeMode.ResizeToContents + ) + mismatch_header.setSectionResizeMode(4, QHeaderView.ResizeMode.Stretch) + mismatch_header.setSectionResizeMode(5, QHeaderView.ResizeMode.Stretch) + mismatch_header.setSectionResizeMode(6, QHeaderView.ResizeMode.Stretch) + mismatch_layout.addWidget(self.mismatch_table) + self._right_layout.addWidget(mismatch_group) + self._right_layout.addStretch(1) + + if initial_input_path is not None: + resolved_input = Path(initial_input_path).expanduser().resolve() + if resolved_input.is_dir(): + self._browse_start_path = resolved_input + else: + self._browse_start_path = resolved_input.parent + self.input_path_edit.setText(str(resolved_input)) + + self._populate_director_atom_choices() + self._update_reference_details() + self._update_suggested_output_path(force=True) + self._update_cluster_status_panel(None) + self._populate_solute_cutoff_table(None) + self._update_build_panel_state() + if not self._available_references: + self._set_cluster_status_panel_text( + "No 
solvent status is available yet.", + "No reference presets were found. Add a reference molecule " + "to the library before using this beta tool.", + ) + self._set_build_status_text( + "Solvation-shell building is unavailable.", + "No reference presets were found. Add a solvent reference " + "molecule to the library before using this beta tool.", + ) + self.summary_box.setPlainText( + "No reference presets were found. Add a reference molecule " + "to the library before using this beta tool." + ) + elif self.input_path_edit.text().strip(): + self._refresh_structure_preview() + + def _selected_reference(self) -> ReferenceLibraryEntry | None: + selected_name = self.reference_preset_combo.currentData() + if selected_name is None: + return None + for preset in self._available_references: + if preset.name == selected_name: + return preset + return None + + def _update_reference_details(self) -> None: + preset = self._selected_reference() + if preset is None: + self.reference_details_box.setPlainText( + "No solvent reference preset is selected." 
def _analyze_input_structure(self) -> None:
    """Run the solvent-shell analysis for the chosen reference and input.

    An information dialog is shown and the method returns early when no
    reference preset is selected or the input path field is empty.
    Otherwise the structure preview is refreshed and
    ``analyze_solvent_shell`` is called with the resolved input path, the
    preset name, the reference library directory and the tolerance read
    from the spin box.  On success every result panel is repopulated.

    On failure all analysis-derived state is reset through
    ``_clear_analysis_outputs`` before the error text is displayed, so
    the solute table and the build button cannot keep advertising a
    result from an earlier successful run.
    """
    preset = self._selected_reference()
    if preset is None:
        QMessageBox.information(
            self,
            "No reference selected",
            "Choose a solvent reference molecule before analyzing the input file.",
        )
        return
    input_text = self.input_path_edit.text().strip()
    if not input_text:
        QMessageBox.information(
            self,
            "No input file selected",
            "Choose a PDB or XYZ file to inspect.",
        )
        return
    input_path = Path(input_text).expanduser().resolve()
    self._refresh_structure_preview()
    try:
        result = analyze_solvent_shell(
            input_path,
            preset.name,
            reference_library_dir=self.reference_library_dir,
            reference_match_tolerance_a=float(
                self.reference_match_tolerance_spin.value()
            ),
        )
    except Exception as exc:
        failure_text = str(exc)
        # Drop the stored result, the build summary, the solute rows and
        # the build-button state in one place; previously only the
        # result and two tables were reset, leaving the build panel
        # "ready" for an analysis that no longer existed.
        self._clear_analysis_outputs()
        # The reset blanks the summary and status panel, so the error
        # text is written afterwards.
        self.summary_box.setPlainText(failure_text)
        self._set_cluster_status_panel_text(
            "Solvent-shell analysis failed.",
            failure_text,
        )
        self.statusBar().showMessage("Solvent-shell analysis failed", 5000)
        QMessageBox.warning(
            self,
            "Solvent-shell analysis failed",
            failure_text,
        )
        return
    self._analysis_result = result
    # A fresh analysis invalidates any earlier build summary.
    self._build_result_text = None
    self.summary_box.setPlainText(self._combined_summary_text())
    self._update_cluster_status_panel(result)
    # The solute table must be rebuilt before the build panel state is
    # evaluated, because that evaluation reads the table's widgets.
    self._populate_solute_cutoff_table(result)
    self._update_build_panel_state()
    self._populate_residue_table(result)
    self._populate_mismatch_table(result)
    self._refresh_structure_preview(preserve_display=True)
    self.statusBar().showMessage(
        result.cluster_solvent_status_text,
        5000,
    )
self._set_cluster_status_panel_text( + "No solvent status has been determined yet.", + "Choose a reference molecule and input structure, then run " + "Analyze Solvent Shell to identify whether the cluster " + "contains no, partial, or complete solvent molecules.", + ) + return + self._set_cluster_status_panel_text( + result.cluster_solvent_status_text, + result.status_statistics_text(), + ) + + def _previewable_input_path(self) -> Path | None: + input_text = self.input_path_edit.text().strip() + if not input_text: + return None + input_path = Path(input_text).expanduser().resolve() + if not input_path.is_file() or input_path.suffix.lower() not in { + ".pdb", + ".xyz", + }: + return None + return input_path + + def _refresh_structure_preview( + self, + *, + preview_path: Path | None = None, + preserve_display: bool = False, + ) -> None: + structure_path = ( + preview_path + if preview_path is not None + else self._previewable_input_path() + ) + if ( + structure_path is None + or not structure_path.is_file() + or structure_path.suffix.lower() not in {".pdb", ".xyz"} + ): + self.structure_viewer.draw_placeholder() + self.visualizer_status_label.setText( + "Choose a valid PDB or XYZ input file to preview the structure." 
+ ) + return + try: + structure = load_electron_density_structure( + structure_path, + center_mode="center_of_mass", + include_bonds=True, + include_comment=True, + ) + except Exception as exc: + self.structure_viewer.draw_placeholder() + self.visualizer_status_label.setText( + f"Unable to preview {structure_path.name}: {exc}" + ) + return + scene_key = f"solvent-shell-builder:{structure_path}" + if ( + preserve_display + and self.structure_viewer.current_structure is not None + ): + self.structure_viewer.set_structure_preserving_display(structure) + else: + self.structure_viewer.set_structure( + structure, + scene_key=scene_key, + ) + is_generated_output = preview_path is not None and ( + self._previewable_input_path() != structure_path + ) + file_role = ( + "generated output" if is_generated_output else "input structure" + ) + self.visualizer_status_label.setText( + f"Previewing {file_role} {structure_path.name} with " + f"{structure.atom_count} atom(s)." + ) + + def _populate_residue_table( + self, + result: SolventShellAnalysisResult | None, + ) -> None: + self.residue_table.setRowCount(0) + if result is None: + self.residue_status_label.setText( + "Residue-level solvent types are reported only for PDB inputs." + ) + return + if result.input_format != "pdb": + self.residue_status_label.setText( + "Residue-level solvent types are not available for XYZ inputs." + ) + return + if not result.matched_residue_summaries: + self.residue_status_label.setText( + "No complete solvent residues matching the selected reference were detected." + ) + return + self.residue_status_label.setText( + "Residue names below matched the selected solvent geometry." 
+ ) + self.residue_table.setRowCount(len(result.matched_residue_summaries)) + for row, summary in enumerate(result.matched_residue_summaries): + self._set_table_item(row, 0, summary.residue_name) + self._set_table_item(row, 1, str(summary.molecule_count)) + self._set_table_item(row, 2, summary.residue_numbers_text) + self._set_table_item(row, 3, str(summary.atom_count)) + self._set_table_item(row, 4, summary.element_counts_text) + + def _populate_mismatch_table( + self, + result: SolventShellAnalysisResult | None, + ) -> None: + self.mismatch_table.setRowCount(0) + if result is None: + self.mismatch_status_label.setText( + "Incomplete solvent-like candidates are reported after analysis." + ) + return + if not result.residue_mismatch_summaries: + if result.input_format == "pdb": + self.mismatch_status_label.setText( + "No incomplete or mismatched solvent-like PDB residues were preserved." + ) + else: + self.mismatch_status_label.setText( + "No partial solvent candidates were inferred from the XYZ input." + ) + return + if result.input_format == "pdb": + self.mismatch_status_label.setText( + "Incomplete or mismatched solvent-like residues were preserved with missing-atom details." + ) + else: + self.mismatch_status_label.setText( + "Partial solvent candidates were inferred heuristically from XYZ atom sets and preserved with missing-atom details." 
+ ) + self.mismatch_table.setRowCount(len(result.residue_mismatch_summaries)) + for row, summary in enumerate(result.residue_mismatch_summaries): + self._set_mismatch_table_item(row, 0, summary.residue_name) + self._set_mismatch_table_item(row, 1, str(summary.residue_number)) + self._set_mismatch_table_item( + row, + 2, + str(summary.observed_atom_count), + ) + self._set_mismatch_table_item( + row, + 3, + summary.matched_atom_ratio_text, + ) + self._set_mismatch_table_item( + row, + 4, + summary.missing_atom_names_text, + ) + self._set_mismatch_table_item( + row, + 5, + summary.extra_atom_names_text, + ) + self._set_mismatch_table_item( + row, + 6, + summary.mismatch_reason, + ) + + def _set_table_item(self, row: int, column: int, text: str) -> None: + self.residue_table.setItem(row, column, QTableWidgetItem(text)) + + def _set_mismatch_table_item( + self, + row: int, + column: int, + text: str, + ) -> None: + self.mismatch_table.setItem(row, column, QTableWidgetItem(text)) + + def _combined_summary_text(self) -> str: + sections: list[str] = [] + if self._analysis_result is not None: + sections.append(self._analysis_result.summary_text()) + if self._build_result_text: + sections.extend( + [ + "", + "Generated solvent shell output:", + self._build_result_text, + ] + ) + return "\n".join(sections) + + def _populate_director_atom_choices(self) -> None: + self.director_atom_combo.blockSignals(True) + self.director_atom_combo.clear() + preset = self._selected_reference() + if preset is None: + self.director_atom_combo.blockSignals(False) + return + try: + atom_names = reference_atom_choices( + preset.name, + reference_library_dir=self.reference_library_dir, + ) + suggested_name = default_director_atom_name( + preset.name, + reference_library_dir=self.reference_library_dir, + ) + except Exception: + self.director_atom_combo.blockSignals(False) + return + for atom_name in atom_names: + self.director_atom_combo.addItem(atom_name, atom_name) + if suggested_name is not None: 
+ suggested_index = self.director_atom_combo.findData(suggested_name) + if suggested_index >= 0: + self.director_atom_combo.setCurrentIndex(suggested_index) + elif atom_names: + self.director_atom_combo.setCurrentIndex(0) + self.director_atom_combo.blockSignals(False) + + def _suggested_output_path(self) -> str: + input_text = self.input_path_edit.text().strip() + preset = self._selected_reference() + if not input_text: + base_dir = self._browse_start_path + reference_suffix = ( + preset.name.casefold() if preset is not None else "solvent" + ) + return str( + (base_dir / f"solvent_shell_builder_{reference_suffix}.pdb") + .expanduser() + .resolve() + ) + input_path = Path(input_text).expanduser().resolve() + reference_suffix = ( + preset.name.casefold() if preset is not None else "solvent" + ) + return str( + input_path.with_name( + f"{input_path.stem}__solvated_{reference_suffix}.pdb" + ) + ) + + def _update_suggested_output_path(self, *, force: bool = False) -> None: + suggested_path = self._suggested_output_path() + current_text = self.output_path_edit.text().strip() + if ( + force + or not current_text + or current_text == self._last_suggested_output_path + ): + self.output_path_edit.setText(suggested_path) + self._last_suggested_output_path = suggested_path + + def _browse_output_file(self) -> None: + start_path = self.output_path_edit.text().strip() or str( + self._browse_start_path + ) + selected_path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Choose Solvated Output PDB", + start_path, + "PDB files (*.pdb);;All files (*)", + ) + if not selected_path: + return + resolved_path = Path(selected_path).expanduser().resolve() + if resolved_path.suffix.lower() != ".pdb": + resolved_path = resolved_path.with_suffix(".pdb") + self._browse_start_path = resolved_path.parent + self.output_path_edit.setText(str(resolved_path)) + + def _populate_solute_cutoff_table( + self, + result: SolventShellAnalysisResult | None, + ) -> None: + 
self._updating_solute_table = True + previous_values = { + element: spin.value() + for element, spin in self._solute_cutoff_spins.items() + } + previous_coordination_targets = { + element: spin.value() + for element, spin in self._coordination_target_spins.items() + } + previous_center_states = { + element: item.checkState() == Qt.CheckState.Checked + for element, item in self._coordination_center_items.items() + } + self._solute_cutoff_spins = {} + self._coordination_center_items = {} + self._coordination_target_spins = {} + self.solute_cutoff_table.setRowCount(0) + if result is None: + self.solute_cutoff_status_label.setText( + "Analyze the input structure to populate the recognized " + "solute atom types and their solvent-placement distances." + ) + self._updating_solute_table = False + return + if not result.solute_element_counts: + if result.partial_solvent_molecule_count > 0: + self.solute_cutoff_status_label.setText( + "Partial solvent anchors were found. No additional " + "solute atom types remain after excluding those solvent " + "candidates, so coordination-center selection is not " + "needed unless you want to extend this workflow later." + ) + else: + self.solute_cutoff_status_label.setText( + "No remaining solute atom types were recognized after " + "excluding solvent-like atoms." + ) + self._updating_solute_table = False + return + if result.partial_solvent_molecule_count > 0: + self.solute_cutoff_status_label.setText( + "Mark any solute elements that should act as coordinating " + "centers, set their average target coordination numbers, and " + "review the director-atom distance cutoffs if you want this " + "partial-solvent build to add new molecules beyond the " + "reconstructed anchors." 
+ ) + else: + self.solute_cutoff_status_label.setText( + "Choose which solute elements should coordinate the solvent, " + "set their average target coordination numbers, and review " + "the director-atom distance cutoffs before building a new " + "solvent shell from scratch." + ) + self.solute_cutoff_table.setRowCount(len(result.solute_element_counts)) + for row, (element, atom_count) in enumerate( + sorted(result.solute_element_counts.items()) + ): + self.solute_cutoff_table.setItem( + row, + 0, + QTableWidgetItem(str(element)), + ) + self.solute_cutoff_table.setItem( + row, + 1, + QTableWidgetItem(str(atom_count)), + ) + center_item = QTableWidgetItem("") + center_item.setFlags( + Qt.ItemFlag.ItemIsEnabled + | Qt.ItemFlag.ItemIsSelectable + | Qt.ItemFlag.ItemIsUserCheckable + ) + center_item.setCheckState( + Qt.CheckState.Checked + if previous_center_states.get(str(element), False) + else Qt.CheckState.Unchecked + ) + self.solute_cutoff_table.setItem(row, 2, center_item) + self._coordination_center_items[str(element)] = center_item + + coordination_spin = QDoubleSpinBox(self.solute_cutoff_table) + coordination_spin.setDecimals(2) + coordination_spin.setRange(0.0, 12.0) + coordination_spin.setSingleStep(0.25) + coordination_spin.setValue( + previous_coordination_targets.get(str(element), 0.0) + ) + coordination_spin.valueChanged.connect( + self._handle_coordination_settings_changed + ) + self.solute_cutoff_table.setCellWidget(row, 3, coordination_spin) + self._coordination_target_spins[str(element)] = coordination_spin + + spin = QDoubleSpinBox(self.solute_cutoff_table) + spin.setDecimals(3) + spin.setRange(0.0, 20.0) + spin.setSingleStep(0.1) + spin.setSuffix(" A") + spin.setValue( + previous_values.get( + str(element), + self._DEFAULT_SOLUTE_DISTANCE_CUTOFF_A, + ) + ) + spin.valueChanged.connect( + self._handle_coordination_settings_changed + ) + self.solute_cutoff_table.setCellWidget(row, 4, spin) + self._solute_cutoff_spins[str(element)] = spin + 
self._updating_solute_table = False + + def _set_build_status_text(self, headline: str, details: str) -> None: + self.build_status_label.setText(f"{headline}\n{details}") + + def _update_build_panel_state(self) -> None: + result = self._analysis_result + can_choose_reference = bool(self.director_atom_combo.count()) + self.director_atom_combo.setEnabled(can_choose_reference) + self.minimum_solvent_separation_spin.setEnabled( + bool(self._available_references) + ) + self.output_path_edit.setEnabled(bool(self._available_references)) + self.browse_output_button.setEnabled(bool(self._available_references)) + self.solute_cutoff_table.setEnabled(result is not None) + if result is None: + self.build_output_button.setEnabled(False) + self._set_build_status_text( + "Solvation-shell build is waiting for analysis.", + "Analyze the input structure first so the beta tool can " + "identify the solvent status, populate the solute cutoffs, " + "and determine whether a build is needed.", + ) + return + selected_centers = self._selected_coordination_center_elements() + selected_coordination_targets = ( + self._selected_target_average_coordination_numbers() + ) + solute_cutoffs = self._collect_solute_distance_cutoffs() + selected_centers_missing_cutoffs = [ + element + for element in selected_centers + if solute_cutoffs.get(element, 0.0) <= 0.0 + ] + if ( + result.complete_solvent_molecule_count > 0 + and result.partial_solvent_molecule_count == 0 + ): + self.build_output_button.setEnabled(False) + self._set_build_status_text( + "Solvation-shell build is disabled for this input.", + "The analyzed structure already contains complete solvent " + "molecules and does not expose partial solvent candidates " + "to rebuild.", + ) + return + if result.partial_solvent_molecule_count == 0 and not selected_centers: + self.build_output_button.setEnabled(False) + self._set_build_status_text( + "Solvation-shell build needs coordination targets.", + "Select at least one solute element as a 
coordinating " + "center and provide its average coordination number before " + "building a shell from scratch.", + ) + return + if ( + result.partial_solvent_molecule_count == 0 + and not selected_coordination_targets + ): + self.build_output_button.setEnabled(False) + self._set_build_status_text( + "Solvation-shell build needs coordination targets.", + "Set a positive average coordination number for at least " + "one selected coordinating center element.", + ) + return + if ( + result.partial_solvent_molecule_count == 0 + and selected_centers_missing_cutoffs + ): + self.build_output_button.setEnabled(False) + self._set_build_status_text( + "Solvation-shell build needs coordination cutoffs.", + "Each selected coordinating center element needs a positive " + "director-distance cutoff before the shell can be built.", + ) + return + if ( + result.partial_solvent_molecule_count == 0 + and not result.solute_element_counts + ): + self.build_output_button.setEnabled(False) + self._set_build_status_text( + "Solvation-shell build cannot start yet.", + "No partial solvent anchors or remaining solute atom types " + "were recognized for solvent placement.", + ) + return + self.build_output_button.setEnabled(True) + if result.partial_solvent_molecule_count > 0: + if selected_coordination_targets: + self._set_build_status_text( + "Solvation-shell build is ready.", + "The build will complete the partial solvent anchors and " + "then add more solvent molecules until the selected " + "average coordination targets are met or no valid " + "placements remain.", + ) + return + self._set_build_status_text( + "Solvation-shell build is ready.", + "The selected director atom will be used to complete the " + "partial solvent candidates that were detected in the input.", + ) + return + self._set_build_status_text( + "Solvation-shell build is ready.", + "Review the per-solute director distances and build a solvated " + "output PDB for this no-solvent cluster.", + ) + + def 
_collect_solute_distance_cutoffs(self) -> dict[str, float]: + return { + element: float(spin.value()) + for element, spin in sorted(self._solute_cutoff_spins.items()) + } + + def _selected_coordination_center_elements(self) -> tuple[str, ...]: + return tuple( + sorted( + element + for element, item in self._coordination_center_items.items() + if item.checkState() == Qt.CheckState.Checked + ) + ) + + def _selected_target_average_coordination_numbers( + self, + ) -> dict[str, float]: + selected_elements = set(self._selected_coordination_center_elements()) + return { + element: float(spin.value()) + for element, spin in sorted( + self._coordination_target_spins.items() + ) + if element in selected_elements and float(spin.value()) > 0.0 + } + + def _handle_solute_table_item_changed( + self, + item: QTableWidgetItem, + ) -> None: + if self._updating_solute_table: + return + if item.column() != 2: + return + self._update_build_panel_state() + + def _handle_coordination_settings_changed(self) -> None: + if self._updating_solute_table: + return + self._update_build_panel_state() + + def _build_solvated_output(self) -> None: + preset = self._selected_reference() + result = self._analysis_result + if preset is None: + QMessageBox.information( + self, + "No reference selected", + "Choose a solvent reference molecule before building an output PDB.", + ) + return + if result is None: + QMessageBox.information( + self, + "No analysis available", + "Analyze the input structure before building a solvated output PDB.", + ) + return + output_text = self.output_path_edit.text().strip() + if not output_text: + QMessageBox.information( + self, + "No output path selected", + "Choose where the solvated output PDB should be written.", + ) + return + director_atom_name = str(self.director_atom_combo.currentData() or "") + if not director_atom_name: + QMessageBox.information( + self, + "No director atom selected", + "Choose the solvent reference atom that should point toward the solute 
cluster.", + ) + return + solute_cutoffs = self._collect_solute_distance_cutoffs() + coordinating_center_elements = ( + self._selected_coordination_center_elements() + ) + target_average_coordination_numbers = ( + self._selected_target_average_coordination_numbers() + ) + output_path = Path(output_text).expanduser().resolve() + if output_path.suffix.lower() != ".pdb": + output_path = output_path.with_suffix(".pdb") + self.output_path_edit.setText(str(output_path)) + try: + build_result = build_solvent_shell_output( + result.input_path, + preset.name, + output_path=output_path, + director_atom_name=director_atom_name, + minimum_solvent_atom_separation_a=float( + self.minimum_solvent_separation_spin.value() + ), + solute_distance_cutoffs_a=solute_cutoffs, + coordinating_center_elements=coordinating_center_elements, + target_average_coordination_numbers=( + target_average_coordination_numbers + ), + reference_library_dir=self.reference_library_dir, + reference_match_tolerance_a=float( + self.reference_match_tolerance_spin.value() + ), + analysis_result=result, + ) + except Exception as exc: + self._build_result_text = None + self.summary_box.setPlainText(self._combined_summary_text()) + self._set_build_status_text( + "Solvation-shell build failed.", + str(exc), + ) + self.statusBar().showMessage( + "Solvation-shell build failed", + 5000, + ) + QMessageBox.warning( + self, + "Solvation-shell build failed", + str(exc), + ) + return + self._build_result_text = build_result.summary_text() + self.summary_box.setPlainText(self._combined_summary_text()) + self._refresh_structure_preview( + preview_path=build_result.output_path, + preserve_display=True, + ) + self._set_build_status_text( + "Solvated output PDB generated.", + ( + f"Wrote {build_result.output_path.name} with " + f"{build_result.solvent_molecules_added} solvent molecule(s) " + "added." 
+ ), + ) + self.statusBar().showMessage( + f"Built solvated output: {build_result.output_path.name}", + 5000, + ) + + +def launch_solvent_shell_builder_ui( + *, + initial_project_dir: str | Path | None = None, + initial_input_path: str | Path | None = None, + reference_library_dir: str | Path | None = None, +) -> SolventShellBuilderMainWindow: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication(sys.argv) + configure_saxshell_application(app) + window = SolventShellBuilderMainWindow( + initial_project_dir=( + None + if initial_project_dir is None + else Path(initial_project_dir).expanduser().resolve() + ), + initial_input_path=( + None + if initial_input_path is None + else Path(initial_input_path).expanduser().resolve() + ), + reference_library_dir=( + None + if reference_library_dir is None + else Path(reference_library_dir).expanduser().resolve() + ), + ) + window.show() + window.raise_() + return window + + +__all__ = [ + "SolventShellBuilderMainWindow", + "launch_solvent_shell_builder_ui", +] diff --git a/tests/test_fullrmc_cli.py b/tests/test_fullrmc_cli.py index 13a6bb2..ebf9112 100644 --- a/tests/test_fullrmc_cli.py +++ b/tests/test_fullrmc_cli.py @@ -7,23 +7,40 @@ import numpy as np import pytest -from PySide6.QtWidgets import QApplication, QMessageBox +from PySide6.QtCore import Qt, Signal +from PySide6.QtWidgets import ( + QApplication, + QDoubleSpinBox, + QMessageBox, + QWidget, +) import saxshell.fullrmc.cli as fullrmc_cli +import saxshell.fullrmc.solvent_shell_builder as solvent_shell_builder_module +import saxshell.fullrmc.ui.main_window as fullrmc_ui_module from saxshell.fullrmc import ( + DEFAULT_REFERENCE_MATCH_TOLERANCE_A, ConstraintGenerationSettings, + PackmolDockerContainerRecord, + PackmolDockerDirectoryEntry, + PackmolDockerLink, + PackmolDockerSyncResult, + PackmolDockerValidationResult, PackmolPlanningSettings, PackmolSetupSettings, RepresentativeSelectionSettings, 
SolutionPropertiesSettings, SolventHandlingSettings, + analyze_solvent_shell, build_constraint_generation, build_distribution_selection, build_packmol_plan, build_packmol_setup, build_representative_preview_clusters, build_representative_solvent_outputs, + build_solvent_shell_output, calculate_solution_properties, + container_project_root_is_valid, load_constraint_generation_metadata, load_packmol_planning_metadata, load_packmol_setup_metadata, @@ -32,21 +49,62 @@ load_solvent_handling_metadata, parse_angle_triplet_text, parse_bond_pair_text, + save_packmol_docker_link_metadata, + save_representative_selection_metadata, save_solution_properties_metadata, select_distribution_representatives, select_first_file_representatives, ) from saxshell.fullrmc.cli import main as fullrmc_main from saxshell.fullrmc.ui.main_window import RMCSetupMainWindow +from saxshell.fullrmc.ui.solvent_shell_builder_window import ( + SolventShellBuilderMainWindow, +) from saxshell.saxs.dream import DreamRunSettings from saxshell.saxs.project_manager import ( DreamBestFitSelection, SAXSProjectManager, build_project_paths, + project_artifact_paths, ) from saxshell.saxs.stoichiometry import format_stoich_for_axis from saxshell.saxshell import main as saxshell_main -from saxshell.structure import PDBStructure +from saxshell.structure import PDBAtom, PDBStructure +from saxshell.xyz2pdb import create_reference_molecule + + +def _integrated_solvent_handling_settings( + *, + reference_source: str = "preset", + preset_name: str = "dmf", + custom_reference_path: str | None = None, + director_atom_name: str | None = None, + pb_target_coordination: float = 1.0, + pb_cutoff_a: float = 2.6, + i_cutoff_a: float = 3.0, +) -> SolventHandlingSettings: + return SolventHandlingSettings.from_dict( + { + "coordinated_solvent_mode": "automatic_detection", + "reference_source": reference_source, + "preset_name": preset_name, + "custom_reference_path": custom_reference_path, + "director_atom_name": 
director_atom_name, + "minimum_solvent_atom_separation_a": 1.2, + "solute_atom_settings": { + "Pb": { + "coordination_center": True, + "target_coordination_number": pb_target_coordination, + "director_distance_cutoff_a": pb_cutoff_a, + }, + "I": { + "coordination_center": False, + "target_coordination_number": 0.0, + "director_distance_cutoff_a": i_cutoff_a, + }, + }, + } + ) def _build_sample_saxs_project(tmp_path): @@ -132,6 +190,294 @@ def _build_sample_saxs_project(tmp_path): return project_dir, paths +def _build_sample_distribution_scoped_saxs_project(tmp_path): + manager = SAXSProjectManager() + project_dir = tmp_path / "rmcsetup_distribution_source" + settings = manager.create_project(project_dir) + paths = build_project_paths(project_dir) + clusters_dir = _build_sample_clusters_dir(tmp_path) + settings.clusters_dir = str(clusters_dir) + settings.cluster_inventory_rows = [ + { + "structure": "PbI2", + "motif": "no_motif", + "count": 2, + "source_dir": str(clusters_dir / "PbI2"), + }, + { + "structure": "PbI2O", + "motif": "motif_1", + "count": 1, + "source_dir": str(clusters_dir / "PbI2O" / "motif_1"), + }, + ] + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + manager.ensure_project_dirs(paths) + manager.ensure_artifact_dirs(artifact_paths) + + run_a = _write_sample_dream_run( + artifact_paths.dream_runtime_dir / "dream_run_001", + settings=DreamRunSettings( + bestfit_method="map", + posterior_filter_mode="all_post_burnin", + credible_interval_low=10.0, + credible_interval_high=90.0, + model_name="template_pd_likelihood_monosq_decoupled", + ), + template_name="template_pd_likelihood_monosq_decoupled", + ) + run_b = _write_sample_dream_run( + artifact_paths.dream_runtime_dir / "dream_run_002", + settings=DreamRunSettings( + bestfit_method="median", + posterior_filter_mode="top_percent_logp", + posterior_top_percent=7.5, + posterior_top_n=200, + credible_interval_low=20.0, + 
credible_interval_high=80.0, + model_name="template_pd_likelihood_monosq_decoupled", + ), + template_name="template_pd_likelihood_monosq_decoupled", + ) + + favorite = DreamBestFitSelection( + run_name=run_a.name, + run_relative_path=str(run_a.relative_to(project_dir)), + bestfit_method="chain_mean", + posterior_filter_mode="top_n_logp", + posterior_top_percent=10.0, + posterior_top_n=42, + credible_interval_low=25.0, + credible_interval_high=75.0, + label="2026-03-23T10:00:00 • dream_run_001 • chain_mean", + template_name="template_pd_likelihood_monosq_decoupled", + model_name="template_pd_likelihood_monosq_decoupled", + selection_source="rmcsetup", + selected_at="2026-03-23T10:00:00", + ) + history_entry = DreamBestFitSelection( + run_name=run_b.name, + run_relative_path=str(run_b.relative_to(project_dir)), + bestfit_method="median", + posterior_filter_mode="top_percent_logp", + posterior_top_percent=7.5, + posterior_top_n=200, + credible_interval_low=20.0, + credible_interval_high=80.0, + label="2026-03-22T09:00:00 • dream_run_002 • median", + template_name="template_pd_likelihood_monosq_decoupled", + model_name="template_pd_likelihood_monosq_decoupled", + selection_source="rmcsetup", + selected_at="2026-03-22T09:00:00", + ) + settings.dream_favorite_selection = favorite + settings.dream_favorite_history = [history_entry] + manager.save_project(settings) + manager._write_distribution_metadata( + settings, + artifact_paths=artifact_paths, + ) + return project_dir, paths, artifact_paths + + +def _build_sample_predicted_distribution_scoped_saxs_project(tmp_path): + manager = SAXSProjectManager() + project_dir = tmp_path / "rmcsetup_predicted_distribution_source" + settings = manager.create_project(project_dir) + paths = build_project_paths(project_dir) + clusters_dir = _build_sample_clusters_dir(tmp_path) + predicted_dir = tmp_path / "predicted_structures" + predicted_dir.mkdir(parents=True, exist_ok=True) + predicted_source = predicted_dir / "zn1_rank01.xyz" + 
predicted_source.write_text( + "\n".join( + [ + "1", + "predicted Zn1 structure", + "Zn 0.0 0.0 0.0", + ] + ) + + "\n", + encoding="utf-8", + ) + + settings.use_predicted_structure_weights = True + settings.clusters_dir = str(clusters_dir) + # Keep the saved settings stale on purpose so the loader has to prefer + # the active predicted-structure artifact inventory on window open. + settings.cluster_inventory_rows = [ + { + "structure": "PbI2", + "motif": "no_motif", + "count": 2, + "source_dir": str(clusters_dir / "PbI2"), + }, + { + "structure": "PbI2O", + "motif": "motif_1", + "count": 1, + "source_dir": str(clusters_dir / "PbI2O" / "motif_1"), + }, + ] + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + manager.ensure_project_dirs(paths) + manager.ensure_artifact_dirs(artifact_paths) + + run = _write_sample_dream_run( + artifact_paths.dream_runtime_dir / "dream_run_predicted", + settings=DreamRunSettings( + bestfit_method="map", + posterior_filter_mode="all_post_burnin", + credible_interval_low=10.0, + credible_interval_high=90.0, + model_name="template_pd_likelihood_monosq_decoupled", + ), + template_name="template_pd_likelihood_monosq_decoupled", + weight_entries=[ + { + "structure": "PbI2", + "motif": "no_motif", + "param_type": "Both", + "param": "w0", + "value": 0.45, + "vary": True, + "distribution": "lognorm", + "dist_params": { + "loc": 0.0, + "scale": 0.45, + "s": 0.3, + }, + }, + { + "structure": "PbI2O", + "motif": "motif_1", + "param_type": "Both", + "param": "w1", + "value": 0.25, + "vary": True, + "distribution": "lognorm", + "dist_params": { + "loc": 0.0, + "scale": 0.25, + "s": 0.3, + }, + }, + { + "structure": "Zn1", + "motif": "predicted_rank01", + "param_type": "Both", + "param": "w2", + "value": 0.30, + "vary": True, + "distribution": "lognorm", + "dist_params": { + "loc": 0.0, + "scale": 0.30, + "s": 0.3, + }, + }, + ], + sampled_params=np.asarray( + [[[0.45, 0.25, 0.30], 
[0.40, 0.30, 0.30]]], + dtype=float, + ), + ) + + settings.dream_favorite_selection = DreamBestFitSelection( + run_name=run.name, + run_relative_path=str(run.relative_to(project_dir)), + bestfit_method="chain_mean", + posterior_filter_mode="top_n_logp", + posterior_top_percent=10.0, + posterior_top_n=42, + credible_interval_low=25.0, + credible_interval_high=75.0, + label="2026-03-23T11:00:00 • dream_run_predicted • chain_mean", + template_name="template_pd_likelihood_monosq_decoupled", + model_name="template_pd_likelihood_monosq_decoupled", + selection_source="rmcsetup", + selected_at="2026-03-23T11:00:00", + ) + manager.save_project(settings) + + artifact_paths.prior_weights_file.write_text( + json.dumps( + { + "origin": "clusters_predicted_structures", + "total_files": 4, + "includes_predicted_structures": True, + "structures": { + "PbI2": { + "no_motif": { + "count": 2, + "weight": 0.45, + "representative": "frame_0002.xyz", + "profile_file": "PbI2_no_motif.txt", + "source_kind": "cluster_dir", + "source_dir": str(clusters_dir / "PbI2"), + "source_file": str( + ( + clusters_dir / "PbI2" / "frame_0002.xyz" + ).resolve() + ), + "source_file_name": "frame_0002.xyz", + } + }, + "PbI2O": { + "motif_1": { + "count": 1, + "weight": 0.25, + "representative": "frame_0003.xyz", + "profile_file": "PbI2O_motif_1.txt", + "source_kind": "cluster_dir", + "source_dir": str( + clusters_dir / "PbI2O" / "motif_1" + ), + "source_file": str( + ( + clusters_dir + / "PbI2O" + / "motif_1" + / "frame_0003.xyz" + ).resolve() + ), + "source_file_name": "frame_0003.xyz", + } + }, + "Zn1": { + "predicted_rank01": { + "count": 1, + "weight": 0.30, + "representative": predicted_source.name, + "profile_file": "Zn1_predicted_rank01.txt", + "source_kind": "predicted_structure", + "source_dir": str(predicted_source.parent), + "source_file": str(predicted_source.resolve()), + "source_file_name": predicted_source.name, + } + }, + }, + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + 
manager._write_distribution_metadata( + settings, + artifact_paths=artifact_paths, + ) + return project_dir, predicted_source.resolve() + + def _write_sample_dream_run( run_dir: Path, *, @@ -538,6 +884,212 @@ def _write_custom_solvent_pdb(tmp_path: Path) -> Path: return solvent_path +def _build_test_solvent_reference_library( + tmp_path: Path, +) -> tuple[Path, Path]: + reference_source = _write_custom_solvent_pdb(tmp_path) + reference_library_dir = tmp_path / "reference_library" + reference_library_dir.mkdir(parents=True, exist_ok=True) + result = create_reference_molecule( + reference_source, + reference_name="water_test", + residue_name="HOH", + library_dir=reference_library_dir, + ) + return reference_library_dir, result.path + + +def _write_test_solvent_shell_pdb( + tmp_path: Path, + *, + reference_path: Path, +) -> Path: + reference_structure = PDBStructure.from_file(reference_path) + atoms = [ + PDBAtom( + atom_id=1, + atom_name="PB1", + residue_name="PBI", + residue_number=1, + coordinates=np.array([0.0, 0.0, 0.0], dtype=float), + element="Pb", + ) + ] + atom_id = 2 + for residue_name, residue_number, shift in ( + ("HOH", 2, np.array([3.0, 0.0, 0.0], dtype=float)), + ("ALT", 3, np.array([6.0, 0.0, 0.0], dtype=float)), + ): + for reference_atom in reference_structure.atoms: + atoms.append( + PDBAtom( + atom_id=atom_id, + atom_name=reference_atom.atom_name, + residue_name=residue_name, + residue_number=residue_number, + coordinates=reference_atom.coordinates.copy() + shift, + element=reference_atom.element, + ) + ) + atom_id += 1 + structure = PDBStructure(atoms=atoms, source_name="solvent_shell") + output_path = tmp_path / "solvent_shell_input.pdb" + structure.write_pdb_file(output_path) + return output_path + + +def _write_test_incomplete_solvent_shell_pdb( + tmp_path: Path, + *, + reference_path: Path, +) -> Path: + reference_structure = PDBStructure.from_file(reference_path) + atoms = [ + PDBAtom( + atom_id=1, + atom_name="PB1", + residue_name="PBI", + 
residue_number=1, + coordinates=np.array([0.0, 0.0, 0.0], dtype=float), + element="Pb", + ) + ] + atom_id = 2 + complete_shift = np.array([3.0, 0.0, 0.0], dtype=float) + for reference_atom in reference_structure.atoms: + atoms.append( + PDBAtom( + atom_id=atom_id, + atom_name=reference_atom.atom_name, + residue_name="HOH", + residue_number=2, + coordinates=reference_atom.coordinates.copy() + complete_shift, + element=reference_atom.element, + ) + ) + atom_id += 1 + partial_shift = np.array([6.0, 0.0, 0.0], dtype=float) + for reference_atom in reference_structure.atoms[:2]: + atoms.append( + PDBAtom( + atom_id=atom_id, + atom_name=reference_atom.atom_name, + residue_name="HOH", + residue_number=3, + coordinates=reference_atom.coordinates.copy() + partial_shift, + element=reference_atom.element, + ) + ) + atom_id += 1 + structure = PDBStructure( + atoms=atoms, + source_name="solvent_shell_incomplete", + ) + output_path = tmp_path / "solvent_shell_incomplete_input.pdb" + structure.write_pdb_file(output_path) + return output_path + + +def _write_test_no_solvent_shell_pdb(tmp_path: Path) -> Path: + structure = PDBStructure( + atoms=[ + PDBAtom( + atom_id=1, + atom_name="PB1", + residue_name="PBI", + residue_number=1, + coordinates=np.array([0.0, 0.0, 0.0], dtype=float), + element="Pb", + ) + ], + source_name="solvent_shell_none", + ) + output_path = tmp_path / "solvent_shell_none_input.pdb" + structure.write_pdb_file(output_path) + return output_path + + +def _write_test_no_solvent_mixed_shell_pdb(tmp_path: Path) -> Path: + structure = PDBStructure( + atoms=[ + PDBAtom( + atom_id=1, + atom_name="PB1", + residue_name="PBI", + residue_number=1, + coordinates=np.array([-1.2, 0.0, 0.0], dtype=float), + element="Pb", + ), + PDBAtom( + atom_id=2, + atom_name="PB2", + residue_name="PBI", + residue_number=1, + coordinates=np.array([1.2, 0.0, 0.0], dtype=float), + element="Pb", + ), + PDBAtom( + atom_id=3, + atom_name="I1", + residue_name="PBI", + residue_number=1, + 
coordinates=np.array([-1.2, 2.6, 0.0], dtype=float), + element="I", + ), + PDBAtom( + atom_id=4, + atom_name="I2", + residue_name="PBI", + residue_number=1, + coordinates=np.array([1.2, -2.6, 0.0], dtype=float), + element="I", + ), + ], + source_name="solvent_shell_mixed_none", + ) + output_path = tmp_path / "solvent_shell_mixed_none_input.pdb" + structure.write_pdb_file(output_path) + return output_path + + +def _write_test_solvent_shell_xyz( + tmp_path: Path, + *, + reference_path: Path, +) -> Path: + reference_structure = PDBStructure.from_file(reference_path) + xyz_lines = ["7", "solvent shell xyz"] + for shift in ( + np.array([0.0, 0.0, 0.0], dtype=float), + np.array([4.0, 0.0, 0.0], dtype=float), + ): + for atom in reference_structure.atoms: + coordinates = atom.coordinates + shift + xyz_lines.append( + f"{atom.element} " + f"{coordinates[0]:.6f} " + f"{coordinates[1]:.6f} " + f"{coordinates[2]:.6f}" + ) + xyz_lines.append("Pb 9.000000 0.000000 0.000000") + output_path = tmp_path / "solvent_shell_input.xyz" + output_path.write_text("\n".join(xyz_lines) + "\n", encoding="utf-8") + return output_path + + +def _write_test_partial_solvent_shell_xyz(tmp_path: Path) -> Path: + xyz_lines = [ + "3", + "partial solvent xyz", + "Pb 0.000000 0.000000 0.000000", + "I 2.000000 0.000000 0.000000", + "O 0.000000 2.000000 0.000000", + ] + output_path = tmp_path / "partial_solvent_shell_input.xyz" + output_path.write_text("\n".join(xyz_lines) + "\n", encoding="utf-8") + return output_path + + def qapp(): os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") app = QApplication.instance() @@ -546,6 +1098,130 @@ def qapp(): return app +class _FakeSettings: + def __init__(self): + self.values: dict[str, object] = {} + + def value(self, key, default=None): + return self.values.get(key, default) + + def setValue(self, key, value): + self.values[key] = value + + +class _FakePackmolDockerClient: + def __init__(self): + self.listed_containers: int = 0 + self.verified_links: 
list[PackmolDockerLink] = [] + self.listed_directories: list[tuple[str, str]] = [] + self.synced_calls: list[tuple[str, Path, str | None]] = [] + + def list_containers(self) -> list[PackmolDockerContainerRecord]: + self.listed_containers += 1 + return [ + PackmolDockerContainerRecord( + name="analysis-helper", + image_name="ubuntu:22.04", + status="Exited (0) 2 hours ago", + ), + PackmolDockerContainerRecord( + name="packmol-dev", + image_name="packmol:test-image", + status="Up 8 minutes", + ), + ] + + def verify_link( + self, + link: PackmolDockerLink, + ) -> PackmolDockerValidationResult: + self.verified_links.append( + PackmolDockerLink.from_dict(link.to_dict()) or link + ) + return PackmolDockerValidationResult( + verified_at="2026-04-17T12:30:00", + container_id="sha256:fakepackmol", + image_name="packmol:test-image", + packmol_command_path="/usr/local/bin/packmol", + packmol_version="Packmol version 20.14.4", + container_project_root=link.container_project_root, + ) + + def list_directories( + self, + link: PackmolDockerLink, + directory: str, + ) -> list[PackmolDockerDirectoryEntry]: + self.listed_directories.append((link.container_name, directory)) + mapping = { + "/packmol_input_files": [ + PackmolDockerDirectoryEntry( + name="project_alpha", + path="/packmol_input_files/project_alpha", + ), + PackmolDockerDirectoryEntry( + name="project_beta", + path="/packmol_input_files/project_beta", + ), + ], + "/packmol_input_files/project_alpha": [ + PackmolDockerDirectoryEntry( + name="subrun", + path="/packmol_input_files/project_alpha/subrun", + ) + ], + } + return mapping.get(directory, []) + + def sync_packmol_inputs( + self, + link: PackmolDockerLink, + local_packmol_inputs_dir: str | Path, + *, + packmol_setup_metadata=None, + ) -> PackmolDockerSyncResult: + local_dir = Path(local_packmol_inputs_dir).resolve() + self.synced_calls.append( + ( + link.container_name, + local_dir, + ( + None + if packmol_setup_metadata is None + else 
packmol_setup_metadata.packmol_input_path + ), + ) + ) + input_name = ( + "packmol_combined.inp" + if packmol_setup_metadata is None + else Path(packmol_setup_metadata.packmol_input_path).name + ) + output_name = ( + "packed_combined.pdb" + if packmol_setup_metadata is None + else packmol_setup_metadata.packed_output_filename + ) + return PackmolDockerSyncResult( + synced_at="2026-04-17T12:45:00", + remote_packmol_inputs_dir=str(link.remote_packmol_inputs_dir()), + remote_packmol_input_path=str( + link.remote_packmol_inputs_dir() / input_name + ), + remote_packed_output_path=str( + link.remote_packmol_inputs_dir() / output_name + ), + synced_file_count=4, + ) + + +class _FakeRepresentativeStructuresWindow(QWidget): + project_results_changed = Signal(str) + + def __init__(self): + super().__init__() + + def _wait_for_representative_worker( window: RMCSetupMainWindow, *, @@ -563,11 +1239,44 @@ def _wait_for_representative_worker( assert window._representative_thread is None -def test_project_settings_roundtrip_preserves_dream_favorite_data(tmp_path): - project_dir, _paths = _build_sample_saxs_project(tmp_path) - settings = SAXSProjectManager().load_project(project_dir) +def _configure_integrated_rmcsetup_solvent_panel( + window: RMCSetupMainWindow, + *, + minimum_separation_a: float = 1.4, + pb_target_coordination: float = 1.0, + pb_cutoff_a: float = 2.6, +) -> None: + window.solvent_reference_source_combo.setCurrentIndex(0) + preset_index = window.solvent_preset_combo.findData("dmf") + assert preset_index >= 0 + window.solvent_preset_combo.setCurrentIndex(preset_index) + window.solvent_minimum_separation_spin.setValue(minimum_separation_a) + window._analyze_representative_solvent_states() - assert settings.dream_favorite_selection is not None + pb_row = None + for row in range(window.solvent_cutoff_table.rowCount()): + if window.solvent_cutoff_table.item(row, 0).text() == "Pb": + pb_row = row + break + assert pb_row is not None + + center_item = 
window.solvent_cutoff_table.item(pb_row, 2) + assert center_item is not None + center_item.setCheckState(Qt.CheckState.Checked) + + coordination_spin = window.solvent_cutoff_table.cellWidget(pb_row, 3) + cutoff_spin = window.solvent_cutoff_table.cellWidget(pb_row, 4) + assert isinstance(coordination_spin, QDoubleSpinBox) + assert isinstance(cutoff_spin, QDoubleSpinBox) + coordination_spin.setValue(pb_target_coordination) + cutoff_spin.setValue(pb_cutoff_a) + + +def test_project_settings_roundtrip_preserves_dream_favorite_data(tmp_path): + project_dir, _paths = _build_sample_saxs_project(tmp_path) + settings = SAXSProjectManager().load_project(project_dir) + + assert settings.dream_favorite_selection is not None assert settings.dream_favorite_selection.bestfit_method == "chain_mean" assert settings.dream_favorite_selection.run_relative_path.endswith( "dream_run_001" @@ -603,6 +1312,160 @@ def test_fullrmc_project_loader_discovers_valid_runs_and_favorites(tmp_path): assert state.rmcsetup_paths.solution_properties_path.is_file() assert state.rmcsetup_paths.solvent_handling_path.is_file() assert state.rmcsetup_paths.representative_selection_path.is_file() + assert state.rmcsetup_paths.packmol_docker_link_path.is_file() + + +def test_fullrmc_project_loader_restores_packmol_docker_link(tmp_path): + project_dir, _paths = _build_sample_saxs_project(tmp_path) + link = PackmolDockerLink( + display_name="Mounted Packmol Container", + container_name="packmol-test", + container_project_root="/packmol_input_files/project_alpha", + packmol_command="packmol", + shell_command="bash", + packmol_version="Packmol version 20.14.4", + linked_at="2026-04-17T09:00:00", + last_verified_at="2026-04-17T09:05:00", + container_id="sha256:restoreme", + image_name="packmol:latest", + packmol_command_path="/usr/local/bin/packmol", + ) + save_packmol_docker_link_metadata( + project_dir / "rmcsetup" / "packmol_docker_link.json", + link, + ) + + state = load_rmc_project_source(project_dir) + + 
assert state.packmol_docker_link is not None + assert state.packmol_docker_link.container_name == "packmol-test" + assert state.packmol_docker_link.container_project_root == ( + "/packmol_input_files/project_alpha" + ) + assert state.packmol_docker_link.packmol_command_path == ( + "/usr/local/bin/packmol" + ) + + +def test_packmol_docker_project_root_validation_targets_input_mount(): + assert container_project_root_is_valid("/packmol_input_files") + assert container_project_root_is_valid( + "/packmol_input_files/project_alpha" + ) + assert container_project_root_is_valid( + "/packmol_input_files/project_alpha/.." + ) + assert container_project_root_is_valid( + "/packmol_input_files/project_alpha/subrun" + ) + assert container_project_root_is_valid( + "/packmol_input_files/project_alpha/../subrun" + ) + assert ( + container_project_root_is_valid("/packmol_input_files_extra") is False + ) + assert ( + container_project_root_is_valid("/packmol_input_files/../etc") is False + ) + assert ( + container_project_root_is_valid( + "/packmol_input_files/project_alpha/../../etc" + ) + is False + ) + assert container_project_root_is_valid("/tmp/project_alpha") is False + + +def test_fullrmc_project_loader_discovers_distribution_scoped_runs( + tmp_path, +): + project_dir, _paths, artifact_paths = ( + _build_sample_distribution_scoped_saxs_project(tmp_path) + ) + + state = load_rmc_project_source(project_dir) + + assert [run.run_name for run in state.valid_runs] == [ + "dream_run_002", + "dream_run_001", + ] + assert all( + run.relative_path.startswith("saved_distributions/") + for run in state.valid_runs + ) + assert state.favorite_selection is not None + assert state.find_run_for_selection(state.favorite_selection) is not None + assert ( + state.valid_runs[0].run_dir.parent == artifact_paths.dream_runtime_dir + ) + + +def test_fullrmc_project_loader_uses_active_predicted_distribution_rows( + tmp_path, +): + project_dir, predicted_source = ( + 
_build_sample_predicted_distribution_scoped_saxs_project(tmp_path) + ) + + state = load_rmc_project_source(project_dir) + + assert state.cluster_validation.is_valid is True + assert any( + row.get("source_kind") == "predicted_structure" + for row in state.cluster_validation.expected_rows + ) + assert all( + row.get("structure") != "Zn1" + for row in state.cluster_validation.current_rows + ) + + selection = state.favorite_selection + assert selection is not None + distribution = build_distribution_selection(state, selection) + predicted_entry = next( + entry + for entry in distribution.entries + if entry.source_kind == "predicted_structure" + ) + + assert predicted_entry.structure == "Zn1" + assert predicted_entry.motif == "predicted_rank01" + assert predicted_entry.source_file == str(predicted_source) + assert predicted_entry.source_file_name == predicted_source.name + assert predicted_entry.cluster_count == 1 + + +def test_fullrmc_project_loader_falls_back_to_project_root_runs( + tmp_path, +): + project_dir, paths = _build_sample_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + manager.ensure_artifact_dirs(artifact_paths) + manager._write_distribution_metadata( + settings, + artifact_paths=artifact_paths, + ) + + state = load_rmc_project_source(project_dir) + + assert [run.run_name for run in state.valid_runs] == [ + "dream_run_002", + "dream_run_001", + ] + assert all( + run.run_dir.parent == paths.dream_runtime_dir + for run in state.valid_runs + ) + assert all( + not run.relative_path.startswith("saved_distributions/") + for run in state.valid_runs + ) def test_fullrmc_project_loader_detects_cluster_count_drift(tmp_path): @@ -886,23 +1749,30 @@ def test_build_representative_solvent_outputs_with_preset_reference( metadata = build_representative_solvent_outputs( state, - 
SolventHandlingSettings( - coordinated_solvent_mode="partial_coordinated_solvent", - reference_source="preset", - preset_name="dmf", - ), + _integrated_solvent_handling_settings(), ) reloaded = load_solvent_handling_metadata( state.rmcsetup_paths.solvent_handling_path ) assert metadata.reference_name == "dmf" + assert metadata.detected_distribution_status == "no_solvent" assert len(metadata.entries) == 2 assert all( entry.atom_count_completed > entry.atom_count_no_solvent for entry in metadata.entries ) assert all(entry.solvent_atoms_added > 0 for entry in metadata.entries) + assert all( + Path(entry.completed_pdb).parent + == state.rmcsetup_paths.pdb_with_solvent_dir / entry.structure + for entry in metadata.entries + ) + assert all( + Path(entry.no_solvent_pdb).parent + == state.rmcsetup_paths.pdb_no_solvent_dir / entry.structure + for entry in metadata.entries + ) assert reloaded is not None assert reloaded.settings.preset_name == "dmf" assert reloaded.settings.minimum_solvent_atom_separation_a == ( @@ -924,12 +1794,7 @@ def test_partial_coordinated_solvent_replaces_anchor_atoms_and_points_outward( metadata = build_representative_solvent_outputs( state, - SolventHandlingSettings( - coordinated_solvent_mode="partial_coordinated_solvent", - reference_source="preset", - preset_name="dmf", - minimum_solvent_atom_separation_a=1.2, - ), + _integrated_solvent_handling_settings(), ) anchored_entry = next( @@ -937,12 +1802,14 @@ def test_partial_coordinated_solvent_replaces_anchor_atoms_and_points_outward( for entry in metadata.entries if entry.structure == "PbI2O" and entry.motif == "motif_1" ) - assert anchored_entry.atom_count_no_solvent == 4 + assert metadata.detected_distribution_status == "no_solvent" + assert anchored_entry.atom_count_no_solvent == 3 assert anchored_entry.atom_count_completed == 15 - assert anchored_entry.solvent_atoms_added == 11 + assert anchored_entry.solvent_atoms_added == 12 assert anchored_entry.solvent_molecules_added == 1 - assert 
anchored_entry.completion_strategy.startswith( - "anchored_solvent_completion" + assert ( + anchored_entry.completion_strategy + == "rebuilt_from_no_solvent_distribution" ) completed_structure = PDBStructure.from_file(anchored_entry.completed_pdb) @@ -959,31 +1826,9 @@ def test_partial_coordinated_solvent_replaces_anchor_atoms_and_points_outward( assert len(solvent_atoms) == 12 assert len(solute_atoms) == 3 - anchor_atom = next(atom for atom in solvent_atoms if atom.element == "O") - assert anchor_atom.coordinates == pytest.approx([0.0, 0.0, 1.0], abs=1e-3) - - solute_center = np.mean( - [atom.coordinates for atom in solute_atoms], - axis=0, - ) - solvent_body_center = np.mean( - [ - atom.coordinates - for atom in solvent_atoms - if atom.atom_id != anchor_atom.atom_id - ], - axis=0, - ) - outward_alignment = np.dot( - solvent_body_center - anchor_atom.coordinates, - anchor_atom.coordinates - solute_center, - ) - assert outward_alignment > 0.0 - minimum_distance = min( np.linalg.norm(solvent_atom.coordinates - solute_atom.coordinates) for solvent_atom in solvent_atoms - if solvent_atom.atom_id != anchor_atom.atom_id for solute_atom in solute_atoms ) assert minimum_distance >= 1.2 - 1e-6 @@ -1004,15 +1849,16 @@ def test_build_representative_solvent_outputs_with_custom_reference( metadata = build_representative_solvent_outputs( state, - SolventHandlingSettings( - coordinated_solvent_mode="partial_coordinated_solvent", + _integrated_solvent_handling_settings( reference_source="custom", custom_reference_path=str(custom_solvent), + director_atom_name="O1", ), ) assert metadata.reference_name == "water_ref" assert metadata.reference_residue_name == "HOH" + assert metadata.detected_distribution_status == "no_solvent" added_counts = { (entry.structure, entry.motif): entry.solvent_atoms_added for entry in metadata.entries @@ -1022,7 +1868,7 @@ def test_build_representative_solvent_outputs_with_custom_reference( for entry in metadata.entries } assert added_counts[("PbI2", 
"no_motif")] == 3 - assert added_counts[("PbI2O", "motif_1")] == 2 + assert added_counts[("PbI2O", "motif_1")] == 3 assert molecule_counts[("PbI2", "no_motif")] == 1 assert molecule_counts[("PbI2O", "motif_1")] == 1 @@ -1043,11 +1889,7 @@ def test_build_representative_solvent_outputs_preserves_single_atom_sources( metadata = build_representative_solvent_outputs( state, - SolventHandlingSettings( - coordinated_solvent_mode="partial_coordinated_solvent", - reference_source="preset", - preset_name="dmf", - ), + _integrated_solvent_handling_settings(), ) zn_entry = next( @@ -1060,7 +1902,10 @@ def test_build_representative_solvent_outputs_preserves_single_atom_sources( assert zn_entry.atom_count_completed == 1 assert zn_entry.solvent_atoms_added == 0 assert zn_entry.solvent_molecules_added == 0 - assert zn_entry.completion_strategy == "preserved_single_structure_file" + assert ( + zn_entry.completion_strategy + == "preserved_without_matching_coordination_settings" + ) assert [atom.element for atom in completed_structure.atoms] == ["Zn"] @@ -1120,6 +1965,47 @@ def test_build_packmol_plan_writes_metadata_and_reports(tmp_path): ) +def test_build_packmol_plan_requires_all_positive_weight_representatives( + tmp_path, +): + project_dir, _paths = _build_sample_saxs_project(tmp_path) + state = load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + metadata = select_first_file_representatives( + state, + selection, + ) + metadata.representative_entries = metadata.representative_entries[:1] + state.representative_selection = metadata + solution_settings = SolutionPropertiesSettings( + mode="mass", + solution_density=1.05, + solute_stoich="Pb1I2", + solvent_stoich="C3H7NO", + molar_mass_solute=461.0, + molar_mass_solvent=73.09, + mass_solute=4.61, + mass_solvent=95.39, + ) + state.solution_properties = save_solution_properties_metadata( + state.rmcsetup_paths.solution_properties_path, + settings=solution_settings, + 
result=calculate_solution_properties(solution_settings), + ) + + with pytest.raises( + ValueError, match="exactly one representative structure" + ): + build_packmol_plan( + state, + PackmolPlanningSettings( + planning_mode="per_element", + box_side_length_a=80.0, + ), + ) + + def test_build_packmol_plan_includes_single_atom_model_sources(tmp_path): project_dir, _paths, _single_atom_path = ( _build_sample_saxs_project_with_single_atom_model(tmp_path) @@ -1158,11 +2044,7 @@ def test_build_packmol_plan_includes_single_atom_model_sources(tmp_path): ) state.solvent_handling = build_representative_solvent_outputs( state, - SolventHandlingSettings( - coordinated_solvent_mode="partial_coordinated_solvent", - reference_source="preset", - preset_name="dmf", - ), + _integrated_solvent_handling_settings(), ) metadata = build_packmol_plan( @@ -1184,7 +2066,7 @@ def test_build_packmol_plan_includes_single_atom_model_sources(tmp_path): assert zn_entry.planned_count_weight >= 0.0 -def test_build_packmol_setup_writes_input_files_and_audit(tmp_path): +def test_build_packmol_plan_tracks_solvent_allocation(tmp_path): project_dir, _paths = _build_sample_saxs_project(tmp_path) state = load_rmc_project_source(project_dir) selection = state.favorite_selection @@ -1220,12 +2102,76 @@ def test_build_packmol_setup_writes_input_files_and_audit(tmp_path): ) state.solvent_handling = build_representative_solvent_outputs( state, - SolventHandlingSettings( - coordinated_solvent_mode="partial_coordinated_solvent", - reference_source="preset", - preset_name="dmf", + _integrated_solvent_handling_settings(), + ) + + metadata = build_packmol_plan( + state, + PackmolPlanningSettings( + planning_mode="per_element", + box_side_length_a=80.0, + ), + ) + + allocation = metadata.solvent_allocation + + assert allocation is not None + assert allocation.reference_name == "dmf" + assert allocation.target_solvent_molecules == int( + metadata.target_box_composition["solvent_molecules"] + ) + assert 
allocation.solvent_molecules_in_clusters == sum( + entry.solvent_molecules_total for entry in allocation.entries + ) + assert allocation.free_solvent_molecules == max( + 0, + allocation.target_solvent_molecules + - allocation.solvent_molecules_in_clusters, + ) + assert any( + entry.solvent_molecules_total > 0 for entry in allocation.entries + ) + assert "Cluster solvent molecules:" in metadata.summary_text() + + +def test_build_packmol_setup_writes_input_files_and_audit(tmp_path): + project_dir, _paths = _build_sample_saxs_project(tmp_path) + state = load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + state.representative_selection = select_first_file_representatives( + state, + selection, + ) + state.solution_properties = save_solution_properties_metadata( + state.rmcsetup_paths.solution_properties_path, + settings=SolutionPropertiesSettings( + mode="mass", + solution_density=1.05, + solute_stoich="Pb1I2", + solvent_stoich="C3H7NO", + molar_mass_solute=461.0, + molar_mass_solvent=73.09, + mass_solute=4.61, + mass_solvent=95.39, + ), + result=calculate_solution_properties( + SolutionPropertiesSettings( + mode="mass", + solution_density=1.05, + solute_stoich="Pb1I2", + solvent_stoich="C3H7NO", + molar_mass_solute=461.0, + molar_mass_solvent=73.09, + mass_solute=4.61, + mass_solvent=95.39, + ) ), ) + state.solvent_handling = build_representative_solvent_outputs( + state, + _integrated_solvent_handling_settings(), + ) state.packmol_planning = build_packmol_plan( state, PackmolPlanningSettings( @@ -1251,6 +2197,17 @@ def test_build_packmol_setup_writes_input_files_and_audit(tmp_path): assert Path(metadata.audit_report_path).is_file() assert reloaded is not None assert reloaded.free_solvent_molecules >= 0 + assert metadata.free_solvent_reference_name == "dmf" + assert state.packmol_planning.solvent_allocation is not None + assert metadata.target_solvent_molecules == ( + 
state.packmol_planning.solvent_allocation.target_solvent_molecules + ) + assert metadata.solvent_molecules_in_clusters == ( + state.packmol_planning.solvent_allocation.solvent_molecules_in_clusters + ) + assert metadata.free_solvent_molecules == ( + state.packmol_planning.solvent_allocation.free_solvent_molecules + ) assert len({entry.residue_name for entry in metadata.entries}) == len( metadata.entries ) @@ -1262,9 +2219,59 @@ def test_build_packmol_setup_writes_input_files_and_audit(tmp_path): assert "structure 001_PbI2O_motif_1_CAA.pdb" in packmol_text audit_text = Path(metadata.audit_report_path).read_text(encoding="utf-8") assert "# Packmol Build Audit" in audit_text + assert "Cluster solvent molecules:" in audit_text assert "Count-normalized weights" in audit_text +def test_build_packmol_setup_requires_all_positive_weight_representatives( + tmp_path, +): + project_dir, _paths = _build_sample_saxs_project(tmp_path) + state = load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + state.representative_selection = select_first_file_representatives( + state, + selection, + ) + solution_settings = SolutionPropertiesSettings( + mode="mass", + solution_density=1.05, + solute_stoich="Pb1I2", + solvent_stoich="C3H7NO", + molar_mass_solute=461.0, + molar_mass_solvent=73.09, + mass_solute=4.61, + mass_solvent=95.39, + ) + state.solution_properties = save_solution_properties_metadata( + state.rmcsetup_paths.solution_properties_path, + settings=solution_settings, + result=calculate_solution_properties(solution_settings), + ) + state.packmol_planning = build_packmol_plan( + state, + PackmolPlanningSettings( + planning_mode="per_element", + box_side_length_a=80.0, + ), + ) + assert state.representative_selection is not None + state.representative_selection.representative_entries = ( + state.representative_selection.representative_entries[:1] + ) + + with pytest.raises( + ValueError, match="exactly one representative 
structure" + ): + build_packmol_setup( + state, + PackmolSetupSettings( + include_free_solvent=False, + ), + ) + + def test_build_packmol_setup_includes_single_atom_model_sources(tmp_path): project_dir, _paths, _single_atom_path = ( _build_sample_saxs_project_with_single_atom_model(tmp_path) @@ -1303,11 +2310,7 @@ def test_build_packmol_setup_includes_single_atom_model_sources(tmp_path): ) state.solvent_handling = build_representative_solvent_outputs( state, - SolventHandlingSettings( - coordinated_solvent_mode="partial_coordinated_solvent", - reference_source="preset", - preset_name="dmf", - ), + _integrated_solvent_handling_settings(), ) state.packmol_planning = build_packmol_plan( state, @@ -1363,11 +2366,7 @@ def test_build_constraint_generation_writes_per_structure_and_merged_files( ) state.solvent_handling = build_representative_solvent_outputs( state, - SolventHandlingSettings( - coordinated_solvent_mode="partial_coordinated_solvent", - reference_source="preset", - preset_name="dmf", - ), + _integrated_solvent_handling_settings(), ) state.packmol_planning = build_packmol_plan( state, @@ -1486,27 +2485,107 @@ def test_rmcsetup_main_window_prefills_project_and_applies_favorite(tmp_path): assert window._project_source_state.cluster_validation.is_valid is True assert not hasattr(window, "cluster_validation_box") assert not hasattr(window, "validation_warning_label") - assert "/rmcsetup/representative_clusters" in ( + assert "/rmcsetup/representative_structures" in ( window.output_summary_box.toPlainText() ) assert window.solution_group.isEnabled() is True assert "No saved solution-properties calculation yet" in ( window.solution_output_box.toPlainText() ) - assert "No representative selection has been saved yet" in ( + assert window.representative_group.title() == "Representative Structures" + assert window.compute_representatives_button.text() == ( + "Open Representative Structures" + ) + assert window.preview_representatives_button.text() == ( + "Reload Saved 
Representative Structures" + ) + assert "No representative structures have been saved yet" in ( window.representative_summary_box.toPlainText() ) -def test_rmcsetup_software_details_section_is_collapsible_and_link_ready( +def test_rmcsetup_main_window_lists_distribution_scoped_dream_runs(tmp_path): + qapp() + project_dir, _paths, _artifact_paths = ( + _build_sample_distribution_scoped_saxs_project(tmp_path) + ) + + window = RMCSetupMainWindow(initial_project_dir=project_dir) + + assert [window.dream_run_combo.itemText(i) for i in range(2)] == [ + "dream_run_002", + "dream_run_001", + ] + assert window.dream_run_combo.currentText() == "dream_run_001" + assert window._project_source_state is not None + assert all( + run.relative_path.startswith("saved_distributions/") + for run in window._project_source_state.valid_runs + ) + + +def test_rmcsetup_representative_panel_opens_representative_structures_tool( tmp_path, + monkeypatch, ): qapp() project_dir, _paths = _build_sample_saxs_project(tmp_path) - window = RMCSetupMainWindow(initial_project_dir=project_dir) + fake_window = _FakeRepresentativeStructuresWindow() + launched: dict[str, Path | None] = {} - assert window.software_details_button.isCheckable() + import saxshell.representativefinder.ui.main_window as representativefinder_ui_module + + def fake_launch( + *, + initial_project_dir=None, + initial_input_path=None, + ): + launched["initial_project_dir"] = ( + None + if initial_project_dir is None + else Path(initial_project_dir).expanduser().resolve() + ) + launched["initial_input_path"] = ( + None + if initial_input_path is None + else Path(initial_input_path).expanduser().resolve() + ) + return fake_window + + monkeypatch.setattr( + representativefinder_ui_module, + "launch_representativefinder_ui", + fake_launch, + ) + + window.compute_representatives_button.click() + + assert launched["initial_project_dir"] == project_dir.resolve() + assert launched["initial_input_path"] == ( + 
window._project_source_state.settings.resolved_clusters_dir + ) + assert fake_window in window._child_tool_windows + + fake_window.project_results_changed.emit(str(project_dir.resolve())) + QApplication.processEvents() + + assert "Representative structures were updated in the dedicated tool" in ( + window.run_log_box.toPlainText() + ) + fake_window.close() + window.close() + + +def test_rmcsetup_software_details_section_is_collapsible_and_link_ready( + tmp_path, +): + qapp() + project_dir, _paths = _build_sample_saxs_project(tmp_path) + + window = RMCSetupMainWindow(initial_project_dir=project_dir) + + assert window.software_details_button.isCheckable() assert window.software_details_button.isChecked() is False assert window.software_details_panel.isHidden() is True assert window.software_details_label.openExternalLinks() is True @@ -1546,10 +2625,57 @@ def test_rmcsetup_main_window_uses_two_scrollable_panes_with_splitter( assert window._left_panel.isAncestorOf(window.solution_group) assert window._right_panel.isAncestorOf(window.dream_preview_group) assert window._right_panel.isAncestorOf(window.representative_group) + assert window._right_panel.isAncestorOf(window.solvent_group) assert window._right_panel.isAncestorOf(window.packmol_group) assert window._right_panel.isAncestorOf(window.run_log_group) +def test_rmcsetup_readiness_sections_can_collapse_without_hiding_status( + tmp_path, +): + qapp() + project_dir, _paths = _build_sample_saxs_project(tmp_path) + + window = RMCSetupMainWindow(initial_project_dir=project_dir) + + readiness_sections = { + "project_source": ["project_source"], + "dream_selection": ["dream_selection"], + "solution_properties": ["solution_properties"], + "representative_selection": ["representative_selection"], + "solvent_outputs": ["solvent_outputs"], + "packmol": ["packmol_plan", "packmol_setup"], + } + + for section_key, readiness_keys in readiness_sections.items(): + toggle = window._section_toggle_buttons[section_key] + content = 
window._section_content_widgets[section_key] + assert content.isHidden() is False + for readiness_key in readiness_keys: + checkbox = window._readiness_checkboxes[readiness_key] + assert content.isAncestorOf(checkbox) is False + assert checkbox.isHidden() is False + toggle.click() + assert content.isHidden() is True + assert toggle.text() == "Expand" + for readiness_key in readiness_keys: + assert ( + window._readiness_checkboxes[readiness_key].isHidden() is False + ) + toggle.click() + assert content.isHidden() is False + assert toggle.text() == "Collapse" + + representative_content = window._section_content_widgets[ + "representative_selection" + ] + assert representative_content.isAncestorOf(window.solvent_group) is False + window._section_toggle_buttons["representative_selection"].click() + assert representative_content.isHidden() is True + assert window.solvent_group.isHidden() is False + assert window._readiness_checkboxes["solvent_outputs"].isHidden() is False + + def test_rmcsetup_main_window_renders_selected_dream_preview_and_tooltips( tmp_path, ): @@ -1605,6 +2731,187 @@ def test_rmcsetup_main_window_renders_selected_dream_preview_and_tooltips( ) +def test_packmol_docker_link_dialog_validates_container_and_updates_tree( + tmp_path, +): + qapp() + client = _FakePackmolDockerClient() + dialog = fullrmc_ui_module.PackmolDockerLinkDialog( + recent_presets=[ + PackmolDockerLink( + display_name="Mounted Packmol", + container_name="packmol-dev", + container_project_root="/packmol_input_files", + ) + ], + docker_client=client, + ) + + assert client.listed_containers >= 1 + assert dialog.available_container_combo.count() == 2 + assert "packmol-dev" in dialog.available_container_combo.itemText(1) + assert dialog._test_connection() is True + assert "Docker validation succeeded." 
in dialog.status_box.toPlainText() + assert "Packmol version 20.14.4" in dialog.status_box.toPlainText() + assert dialog.directory_tree.topLevelItemCount() == 1 + + root_item = dialog.directory_tree.topLevelItem(0) + assert root_item.text(0) == "/packmol_input_files" + assert root_item.childCount() == 2 + + first_child = root_item.child(0) + dialog.directory_tree.setCurrentItem(first_child) + dialog._use_selected_directory() + + assert dialog.container_root_edit.text() == ( + "/packmol_input_files/project_alpha" + ) + dialog.close() + + +def test_packmol_docker_link_dialog_can_load_discovered_container_name(): + qapp() + dialog = fullrmc_ui_module.PackmolDockerLinkDialog( + docker_client=_FakePackmolDockerClient(), + ) + + assert dialog.available_container_combo.count() == 2 + dialog.available_container_combo.setCurrentIndex(1) + dialog._use_available_container() + + assert dialog.container_name_edit.text() == "packmol-dev" + assert "Press Test Container to verify Packmol" in ( + dialog.status_box.toPlainText() + ) + dialog.close() + + +def test_packmol_docker_link_dialog_rejects_invalid_container_project_root(): + qapp() + + class _RejectingClient(_FakePackmolDockerClient): + def verify_link(self, link): + raise RuntimeError( + "Container project root must be inside /packmol_input_files " + "so Packmol input files stay inside the expected bind-mounted folder." 
+ ) + + dialog = fullrmc_ui_module.PackmolDockerLinkDialog( + docker_client=_RejectingClient(), + ) + dialog.container_name_edit.setText("packmol-dev") + dialog.container_root_edit.setText("/tmp/not_allowed") + + assert dialog._test_connection() is False + assert "must be inside /packmol_input_files" in ( + dialog.status_box.toPlainText() + ) + dialog.close() + + +def test_packmol_docker_link_dialog_explains_docker_daemon_failure(): + qapp() + + class _DaemonDownClient(_FakePackmolDockerClient): + def verify_link(self, link): + del link + raise RuntimeError( + 'WARNING: Plugin "/Users/test/.docker/cli-plugins/docker-scan" ' + "is not valid: failed to fetch metadata: fork/exec " + "/Users/test/.docker/cli-plugins/docker-scan: no such file " + "or directory\n" + "ERROR: Cannot connect to the Docker daemon at " + "unix:///Users/test/.docker/run/docker.sock. Is the docker " + "daemon running?\n" + "errors pretty printing info" + ) + + dialog = fullrmc_ui_module.PackmolDockerLinkDialog( + docker_client=_DaemonDownClient(), + ) + dialog.container_name_edit.setText("packmol-dev") + + assert dialog._test_connection() is False + assert ( + "Docker Desktop or the Docker daemon does not appear to be running." 
+ in (dialog.status_box.toPlainText()) + ) + assert "wait for `docker info` to succeed" in ( + dialog.status_box.toPlainText() + ) + assert "Cannot connect to the Docker daemon" in ( + dialog.status_box.toPlainText() + ) + dialog.close() + + +def test_rmcsetup_tools_menu_can_link_packmol_docker_container( + tmp_path, + monkeypatch, +): + qapp() + project_dir, _paths = _build_sample_saxs_project(tmp_path) + settings_store = _FakeSettings() + linked = PackmolDockerLink( + display_name="Saved Container", + container_name="packmol-dev", + container_project_root="/packmol_input_files/project_alpha", + packmol_command="packmol", + shell_command="sh", + packmol_version="Packmol version 20.14.4", + last_verified_at="2026-04-17T12:30:00", + container_id="sha256:dialog", + image_name="packmol:test-image", + packmol_command_path="/usr/local/bin/packmol", + ) + + class _FakeDialog: + def __init__(self, *args, **kwargs) -> None: + del args, kwargs + + def exec(self): + return 1 + + def selected_link(self): + return PackmolDockerLink.from_dict(linked.to_dict()) + + monkeypatch.setattr( + RMCSetupMainWindow, + "_packmol_docker_settings", + lambda self: settings_store, + ) + monkeypatch.setattr( + fullrmc_ui_module, + "PackmolDockerLinkDialog", + _FakeDialog, + ) + + window = RMCSetupMainWindow(initial_project_dir=project_dir) + + assert window.tools_menu.title() == "Tools" + assert window.link_packmol_docker_action.text() == ( + "Link Packmol Docker Container" + ) + + window._open_packmol_docker_link_dialog() + + assert window._project_source_state is not None + assert window._project_source_state.packmol_docker_link is not None + assert window._project_source_state.packmol_docker_link.container_name == ( + "packmol-dev" + ) + assert "packmol-dev" in window.packmol_docker_summary_box.toPlainText() + assert "Packmol version 20.14.4" in ( + window.packmol_docker_summary_box.toPlainText() + ) + raw_presets = settings_store.value("packmol_docker_presets", "[]") + preset_payload = 
json.loads(raw_presets) + assert preset_payload[0]["container_name"] == "packmol-dev" + + reloaded = RMCSetupMainWindow(initial_project_dir=project_dir) + assert "packmol-dev" in reloaded.packmol_docker_summary_box.toPlainText() + + def test_rmcsetup_solution_properties_mode_switch_changes_active_page( tmp_path, ): @@ -1915,18 +3222,22 @@ def test_rmcsetup_solvent_handling_ui_builds_and_reloads(tmp_path): window._compute_representative_clusters() _wait_for_representative_worker(window) - window.coordinated_solvent_mode_combo.setCurrentIndex(1) - window.solvent_reference_source_combo.setCurrentIndex(0) - window.solvent_minimum_separation_spin.setValue(1.4) + _configure_integrated_rmcsetup_solvent_panel(window) + window._solvent_distribution_analysis = None + assert window.build_solvent_outputs_button.isEnabled() is True window._build_representative_solvent_outputs() assert "reference molecule: dmf" in ( window.solvent_summary_box.toPlainText().lower() ) + assert ( + "Detected representative distribution state: No solvent molecules detected" + in (window.solvent_summary_box.toPlainText()) + ) assert "Minimum solvent atom separation: 1.4 A" in ( window.solvent_summary_box.toPlainText() ) - assert window.generated_pdb_table.rowCount() == 4 + assert window.generated_pdb_table.rowCount() == 2 generated_rows = { ( window.generated_pdb_table.item(row, 0).text(), @@ -1935,13 +3246,11 @@ def test_rmcsetup_solvent_handling_ui_builds_and_reloads(tmp_path): for row in range(window.generated_pdb_table.rowCount()) } assert set(generated_rows) == { - ("PbI2", "No solvent"), - ("PbI2", "With solvent"), - ("PbI2O/motif_1", "No solvent"), - ("PbI2O/motif_1", "With solvent"), + ("PbI2", "No solvent molecules detected"), + ("PbI2O/motif_1", "Partial solvent molecules detected"), } window.generated_pdb_table.selectRow( - generated_rows[("PbI2", "With solvent")] + generated_rows[("PbI2", "No solvent molecules detected")] ) details_text = window.generated_pdb_details_box.toPlainText() 
assert "Atom count: 15" in details_text @@ -1950,9 +3259,10 @@ def test_rmcsetup_solvent_handling_ui_builds_and_reloads(tmp_path): details_text ) assert "PBI 1: 3 atoms (I:2, Pb:1)" in details_text - assert "DMF 2: 12 atoms (C:3, H:7, N:1, O:1)" in details_text + assert "12 atoms (C:3, H:7, N:1, O:1)" in details_text assert "1: Pb1 (Pb) -> PBI 1" in details_text - assert "4: O1 (O) -> DMF 2" in details_text + assert "O1 (O) -> DMF" in details_text + assert window.generated_pdb_viewer.current_structure is not None window._open_selected_generated_pdb_preview() @@ -1982,146 +3292,946 @@ def test_rmcsetup_solvent_handling_ui_builds_and_reloads(tmp_path): len(generated_preview_window.figure.axes[0].collections) > initial_collection_count ) - assert "Built representative solvent-aware PDB outputs." in ( + assert "Built solvent-decorated representative PDB outputs." in ( window.run_log_box.toPlainText() ) assert "Opened generated PDB preview:" in window.run_log_box.toPlainText() reloaded = RMCSetupMainWindow(initial_project_dir=project_dir) - assert reloaded.coordinated_solvent_mode_combo.currentData() == ( - "partial_coordinated_solvent" - ) assert reloaded.solvent_reference_source_combo.currentData() == "preset" assert reloaded.solvent_preset_combo.currentData() == "dmf" assert reloaded.solvent_minimum_separation_spin.value() == pytest.approx( 1.4 ) - assert reloaded.generated_pdb_table.rowCount() == 4 + assert reloaded.generated_pdb_table.rowCount() == 2 assert "Representative entries exported: 2" in ( reloaded.solvent_summary_box.toPlainText() ) -def test_rmcsetup_ui_can_compute_packmol_plan_and_reload(tmp_path): +def test_rmcsetup_imported_full_solvent_representatives_mark_solvent_ready( + tmp_path, +): qapp() project_dir, _paths = _build_sample_saxs_project(tmp_path) - window = RMCSetupMainWindow(initial_project_dir=project_dir) - - window.solution_density_spin.setValue(1.05) - window.solute_stoich_edit.setText("Pb1I2") - window.solvent_stoich_edit.setText("C3H7NO") 
- window.molar_mass_solute_spin.setValue(461.0) - window.molar_mass_solvent_spin.setValue(73.09) - window.mass_solute_spin.setValue(4.61) - window.mass_solvent_spin.setValue(95.39) - window._calculate_solution_properties() - window._compute_representative_clusters() - _wait_for_representative_worker(window) - window.packmol_box_side_spin.setValue(80.0) - - window._compute_packmol_plan() - - assert "Planned clusters:" in ( - window.packmol_plan_summary_box.toPlainText() + state = load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + metadata = select_first_file_representatives( + state, + selection, + ) + state.representative_selection = metadata + solvent_metadata = build_representative_solvent_outputs( + state, + _integrated_solvent_handling_settings(), ) - axes = window._packmol_plan_figure.axes - assert len(axes) == 3 - assert [axis.get_title() for axis in axes] == [ - "Original cluster distribution", - "DREAM fit model distribution used", - "Packmol planned distribution", - ] - assert all(len(axis.patches) == 4 for axis in axes) - assert [tick.get_text() for tick in axes[-1].get_xticklabels()] == [ - format_stoich_for_axis("PbI2"), - format_stoich_for_axis("PbI2O"), - ] - plan_entries = { - (entry.structure, entry.motif): entry - for entry in window._project_source_state.packmol_planning.entries + solvent_lookup = { + (entry.structure, entry.motif, entry.param): entry + for entry in solvent_metadata.entries } - assert [patch.get_height() for patch in axes[0].patches] == pytest.approx( - [66.6666666667, 0.0, 0.0, 33.3333333333] - ) - dream_total = sum(entry.selected_weight for entry in plan_entries.values()) - assert [patch.get_height() for patch in axes[1].patches] == pytest.approx( - [ - 100.0 - * plan_entries[("PbI2", "no_motif")].selected_weight - / dream_total, - 0.0, - 0.0, - 100.0 - * plan_entries[("PbI2O", "motif_1")].selected_weight - / dream_total, - ] - ) - assert [patch.get_height() for patch in 
axes[2].patches] == pytest.approx( - [ - 100.0 * plan_entries[("PbI2", "no_motif")].planned_count_weight, - 0.0, - 0.0, - 100.0 * plan_entries[("PbI2O", "motif_1")].planned_count_weight, + for entry in metadata.representative_entries: + solvent_entry = solvent_lookup[ + (entry.structure, entry.motif, entry.param) ] + completed_path = Path(solvent_entry.completed_pdb).resolve() + entry.source_dir = str(completed_path.parent) + entry.source_file = str(completed_path) + entry.source_file_name = completed_path.name + entry.source_solvent_mode = "unknown" + save_representative_selection_metadata( + state.rmcsetup_paths.representative_selection_path, + metadata, + ) + state.rmcsetup_paths.solvent_handling_path.write_text( + "{}\n", + encoding="utf-8", ) - assert "Computed Packmol planning counts." in ( - window.run_log_box.toPlainText() + + reloaded_selection = load_representative_selection_metadata( + state.rmcsetup_paths.representative_selection_path ) - assert "packmol_plan.json" in window.output_summary_box.toPlainText() + assert reloaded_selection is not None + assert { + entry.source_solvent_mode + for entry in reloaded_selection.representative_entries + } == {"fullsolv"} - reloaded = RMCSetupMainWindow(initial_project_dir=project_dir) - assert reloaded.packmol_planning_mode_combo.currentData() == ( - "per_element" + window = RMCSetupMainWindow(initial_project_dir=project_dir) + + assert window._readiness_checkboxes["solvent_outputs"].isChecked() is True + assert window.readiness_progress_bar.value() == 4 + assert window.generated_pdb_mode_combo.currentData() == "full_solvent" + assert window.generated_pdb_table.rowCount() == 2 + assert { + window.generated_pdb_table.item(row, 1).text() + for row in range(window.generated_pdb_table.rowCount()) + } == {"Full solvent analyzed"} + assert ( + "Imported representative structures already include the Full solvent structure set" + in (window.solvent_summary_box.toPlainText()) ) - assert 
reloaded.packmol_box_side_spin.value() == pytest.approx(80.0) - assert "Planned clusters:" in ( - reloaded.packmol_plan_summary_box.toPlainText() + assert ( + "The active representative source files already provide the Full " + "solvent structure set." + ) in window.solvent_status_stats_label.text() + assert "Solvent Shell Builder readiness: Ready for Packmol" in ( + window.solvent_status_stats_label.text() ) + assert window.solvent_group.title() == "Solvent Shell Builder" + assert window.analyze_solvent_outputs_button.isEnabled() is False + assert window.build_solvent_outputs_button.isEnabled() is False + assert window.solvent_cutoff_group.isEnabled() is False + window.close() -def test_rmcsetup_ui_packmol_preview_includes_single_atom_model_sources( +def test_rmcsetup_imported_full_solvent_representatives_can_build_packmol_setup( tmp_path, ): qapp() - project_dir, _paths, _single_atom_path = ( - _build_sample_saxs_project_with_single_atom_model(tmp_path) + project_dir, _paths = _build_sample_saxs_project(tmp_path) + state = load_rmc_project_source(project_dir) + selection = state.favorite_selection + assert selection is not None + metadata = select_first_file_representatives( + state, + selection, + ) + state.representative_selection = metadata + save_solution_properties_metadata( + state.rmcsetup_paths.solution_properties_path, + settings=SolutionPropertiesSettings( + mode="mass", + solution_density=1.05, + solute_stoich="Pb1I2", + solvent_stoich="C3H7NO", + molar_mass_solute=461.0, + molar_mass_solvent=73.09, + mass_solute=4.61, + mass_solvent=95.39, + ), + result=calculate_solution_properties( + SolutionPropertiesSettings( + mode="mass", + solution_density=1.05, + solute_stoich="Pb1I2", + solvent_stoich="C3H7NO", + molar_mass_solute=461.0, + molar_mass_solvent=73.09, + mass_solute=4.61, + mass_solvent=95.39, + ) + ), + ) + solvent_metadata = build_representative_solvent_outputs( + state, + _integrated_solvent_handling_settings(), ) + + solvent_lookup = { + 
(entry.structure, entry.motif, entry.param): entry + for entry in solvent_metadata.entries + } + for entry in metadata.representative_entries: + solvent_entry = solvent_lookup[ + (entry.structure, entry.motif, entry.param) + ] + completed_path = Path(solvent_entry.completed_pdb).resolve() + entry.source_dir = str(completed_path.parent) + entry.source_file = str(completed_path) + entry.source_file_name = completed_path.name + entry.source_solvent_mode = "unknown" + save_representative_selection_metadata( + state.rmcsetup_paths.representative_selection_path, + metadata, + ) + state.rmcsetup_paths.solvent_handling_path.write_text( + "{}\n", + encoding="utf-8", + ) + window = RMCSetupMainWindow(initial_project_dir=project_dir) - window.solution_density_spin.setValue(1.05) - window.solute_stoich_edit.setText("Pb1I2") - window.solvent_stoich_edit.setText("C3H7NO") - window.molar_mass_solute_spin.setValue(461.0) - window.molar_mass_solvent_spin.setValue(73.09) - window.mass_solute_spin.setValue(4.61) - window.mass_solvent_spin.setValue(95.39) - window._calculate_solution_properties() - window._compute_representative_clusters() - _wait_for_representative_worker(window) - window.coordinated_solvent_mode_combo.setCurrentIndex(1) - window._build_representative_solvent_outputs() + assert window._project_source_state is not None + assert window._project_source_state.solvent_handling is None + + dmf_index = window.packmol_free_solvent_combo.findText("dmf") + assert dmf_index >= 0 + window.packmol_free_solvent_combo.setCurrentIndex(dmf_index) window.packmol_box_side_spin.setValue(80.0) window._compute_packmol_plan() + window._build_packmol_setup() - axes = window._packmol_plan_figure.axes - assert len(axes) == 3 - assert all(len(axis.patches) == 6 for axis in axes) - assert format_stoich_for_axis("Zn1") in [ - tick.get_text() for tick in axes[-1].get_xticklabels() - ] - assert any( - entry.structure == "Zn1" - for entry in window._project_source_state.packmol_planning.entries + 
assert "Total solvent molecules:" in ( + window.packmol_plan_summary_box.toPlainText() ) - - + assert "Cluster solvent molecules:" in ( + window.packmol_plan_summary_box.toPlainText() + ) + assert "Free solvent structure: dmf" in ( + window.packmol_build_summary_box.toPlainText() + ) + assert "Free solvent molecules:" in ( + window.packmol_build_summary_box.toPlainText() + ) + assert window.open_packmol_setup_folder_button.isEnabled() is True + assert "Built Packmol setup inputs and audit report." in ( + window.run_log_box.toPlainText() + ) + window.close() + + +def test_solvent_shell_builder_analysis_detects_pdb_residue_types(tmp_path): + reference_library_dir, reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_solvent_shell_pdb( + tmp_path, + reference_path=reference_path, + ) + + result = analyze_solvent_shell( + input_path, + "water_test", + reference_library_dir=reference_library_dir, + ) + + assert result.input_format == "pdb" + assert result.detected_solvent_molecules == 2 + assert result.has_solvent_molecules is True + assert result.matched_atom_count == 6 + assert result.unmatched_atom_count == 1 + assert [ + summary.residue_name for summary in result.matched_residue_summaries + ] == [ + "ALT", + "HOH", + ] + assert { + summary.residue_name: summary.residue_numbers + for summary in result.matched_residue_summaries + } == { + "ALT": (3,), + "HOH": (2,), + } + assert "Solvent molecules detected: 2" in result.summary_text() + assert "ALT: 1 molecule(s)" in result.summary_text() + + +def test_solvent_shell_builder_analysis_detects_xyz_solvent_count(tmp_path): + reference_library_dir, reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_solvent_shell_xyz( + tmp_path, + reference_path=reference_path, + ) + + result = analyze_solvent_shell( + input_path, + "water_test", + reference_library_dir=reference_library_dir, + ) + + assert result.input_format == "xyz" + assert 
result.detected_solvent_molecules == 2 + assert result.matched_residue_summaries == () + assert result.matched_atom_count == 6 + assert result.unmatched_atom_count == 1 + assert result.no_solvent_status_text == "no" + assert result.partial_solvent_status_text == "no" + assert result.complete_solvent_status_text == "yes" + assert ( + result.cluster_solvent_status_text + == "Complete solvent molecules detected." + ) + assert "Matched residue types: n/a for XYZ inputs" in result.summary_text() + + +def test_solvent_shell_builder_analysis_infers_partial_xyz_solvent_candidates( + tmp_path, +): + reference_library_dir, _reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_partial_solvent_shell_xyz(tmp_path) + + result = analyze_solvent_shell( + input_path, + "water_test", + reference_library_dir=reference_library_dir, + ) + + assert result.input_format == "xyz" + assert result.complete_solvent_molecule_count == 0 + assert result.partial_solvent_molecule_count == 1 + assert result.no_solvent_status_text == "no" + assert result.partial_solvent_status_text == "yes" + assert result.complete_solvent_status_text == "no" + assert ( + result.cluster_solvent_status_text + == "Partial solvent molecules detected." 
+ ) + assert len(result.residue_mismatch_summaries) == 1 + candidate = result.residue_mismatch_summaries[0] + assert candidate.residue_name == "HOH" + assert candidate.common_atom_count == 1 + assert candidate.reference_atom_count == 3 + assert candidate.missing_atom_names == ("H1", "H2") + assert candidate.source_atom_ids == (3,) + summary_text = result.summary_text() + assert "Partial solvent candidate count: 1" in summary_text + assert "XYZ partial solvent candidates:" in summary_text + assert "source atom ids 3" in summary_text + + +def test_solvent_shell_builder_analysis_identifies_no_solvent_status(tmp_path): + reference_library_dir, _reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_no_solvent_shell_pdb(tmp_path) + + result = analyze_solvent_shell( + input_path, + "water_test", + reference_library_dir=reference_library_dir, + ) + + assert result.input_format == "pdb" + assert result.complete_solvent_molecule_count == 0 + assert result.partial_solvent_molecule_count == 0 + assert result.no_solvent_status_text == "yes" + assert result.partial_solvent_status_text == "no" + assert result.complete_solvent_status_text == "no" + assert ( + result.cluster_solvent_status_text == "No solvent molecules detected." + ) + assert "Cluster solvent status: No solvent molecules detected." 
in ( + result.summary_text() + ) + + +def test_solvent_shell_builder_analysis_preserves_incomplete_pdb_residues( + tmp_path, +): + reference_library_dir, reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_incomplete_solvent_shell_pdb( + tmp_path, + reference_path=reference_path, + ) + + result = analyze_solvent_shell( + input_path, + "water_test", + reference_library_dir=reference_library_dir, + ) + + assert result.input_format == "pdb" + assert result.detected_solvent_molecules == 1 + assert result.matched_atom_count == 3 + assert result.unmatched_atom_count == 3 + assert [ + summary.residue_name for summary in result.matched_residue_summaries + ] == ["HOH"] + assert result.matched_residue_summaries[0].residue_numbers == (2,) + assert len(result.residue_mismatch_summaries) == 1 + mismatch = result.residue_mismatch_summaries[0] + assert mismatch.residue_name == "HOH" + assert mismatch.residue_number == 3 + assert mismatch.observed_atom_count == 2 + assert mismatch.common_atom_count == 2 + assert mismatch.reference_atom_count == 3 + assert mismatch.missing_atom_names == ("H2",) + assert mismatch.extra_atom_names == () + assert result.no_solvent_status_text == "no" + assert result.partial_solvent_status_text == "yes" + assert result.complete_solvent_status_text == "yes" + assert ( + result.cluster_solvent_status_text + == "Complete and partial solvent molecules detected." 
+ ) + summary_text = result.summary_text() + assert "Residue mismatches preserved: 1" in summary_text + assert "HOH 3: missing reference atoms" in summary_text + assert "missing H2" in summary_text + + +def test_solvent_shell_builder_builds_no_solvent_output_pdb(tmp_path): + reference_library_dir, _reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_no_solvent_shell_pdb(tmp_path) + + analysis_result = analyze_solvent_shell( + input_path, + "water_test", + reference_library_dir=reference_library_dir, + ) + output_path = tmp_path / "solvent_shell_no_solvent_output.pdb" + build_result = build_solvent_shell_output( + input_path, + "water_test", + output_path=output_path, + director_atom_name="O1", + minimum_solvent_atom_separation_a=1.2, + solute_distance_cutoffs_a={"Pb": 2.6}, + coordinating_center_elements=("Pb",), + target_average_coordination_numbers={"Pb": 1.0}, + reference_library_dir=reference_library_dir, + analysis_result=analysis_result, + ) + + output_structure = PDBStructure.from_file(output_path) + solvent_atoms = [ + atom for atom in output_structure.atoms if atom.residue_name == "HOH" + ] + oxygen_atoms = [atom for atom in solvent_atoms if atom.element == "O"] + + assert build_result.build_mode == "no_solvent_shell_build" + assert build_result.solvent_molecules_added == 1 + assert build_result.solvent_atoms_added == 3 + assert build_result.partial_candidates_completed == 0 + assert len(output_structure.atoms) == 4 + assert len(solvent_atoms) == 3 + assert len(oxygen_atoms) == 1 + assert np.allclose( + oxygen_atoms[0].coordinates, + np.array([2.6, 0.0, 0.0], dtype=float), + ) + assert build_result.target_average_coordination_numbers == {"Pb": 1.0} + assert build_result.achieved_average_coordination_numbers == {"Pb": 1.0} + + +def test_solvent_shell_builder_builds_no_solvent_output_using_average_coordination( + tmp_path, +): + reference_library_dir, _reference_path = ( + 
_build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_no_solvent_mixed_shell_pdb(tmp_path) + + analysis_result = analyze_solvent_shell( + input_path, + "water_test", + reference_library_dir=reference_library_dir, + ) + output_path = tmp_path / "solvent_shell_coordination_average_output.pdb" + build_result = build_solvent_shell_output( + input_path, + "water_test", + output_path=output_path, + director_atom_name="O1", + minimum_solvent_atom_separation_a=1.2, + solute_distance_cutoffs_a={"I": 3.0, "Pb": 2.6}, + coordinating_center_elements=("Pb",), + target_average_coordination_numbers={"Pb": 2.0}, + reference_library_dir=reference_library_dir, + analysis_result=analysis_result, + ) + + assert build_result.build_mode == "no_solvent_shell_build" + assert 2 <= build_result.solvent_molecules_added <= 4 + assert build_result.target_average_coordination_numbers == {"Pb": 2.0} + assert build_result.achieved_average_coordination_numbers is not None + assert build_result.achieved_average_coordination_numbers["Pb"] >= 2.0 + + +def test_solvent_shell_builder_prefers_octahedral_vacancies_for_center_candidates(): + center_atom = PDBAtom( + atom_id=1, + atom_name="PB1", + residue_name="PBI", + residue_number=1, + coordinates=np.array([0.0, 0.0, 0.0], dtype=float), + element="Pb", + ) + solute_atoms = [ + center_atom, + PDBAtom( + atom_id=2, + atom_name="I1", + residue_name="PBI", + residue_number=1, + coordinates=np.array([2.6, 0.0, 0.0], dtype=float), + element="I", + ), + PDBAtom( + atom_id=3, + atom_name="I2", + residue_name="PBI", + residue_number=1, + coordinates=np.array([-2.6, 0.0, 0.0], dtype=float), + element="I", + ), + ] + + candidate_positions = ( + solvent_shell_builder_module._coordination_candidate_positions( + center_atoms=[center_atom], + solute_atoms=solute_atoms, + existing_anchor_positions=[], + solute_distance_cutoffs_a={"Pb": 2.6}, + ) + ) + + assert len(candidate_positions) >= 4 + leading_positions = candidate_positions[:4] + assert 
all(abs(float(position[0])) < 0.2 for position in leading_positions) + assert any(abs(float(position[1])) > 2.0 for position in leading_positions) + assert any(abs(float(position[2])) > 2.0 for position in leading_positions) + + +def test_solvent_shell_builder_builds_partial_xyz_output_pdb(tmp_path): + reference_library_dir, _reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_partial_solvent_shell_xyz(tmp_path) + + analysis_result = analyze_solvent_shell( + input_path, + "water_test", + reference_library_dir=reference_library_dir, + ) + output_path = tmp_path / "solvent_shell_partial_xyz_output.pdb" + build_result = build_solvent_shell_output( + input_path, + "water_test", + output_path=output_path, + director_atom_name="O1", + minimum_solvent_atom_separation_a=1.2, + solute_distance_cutoffs_a={"I": 3.0, "Pb": 2.6}, + reference_library_dir=reference_library_dir, + analysis_result=analysis_result, + ) + + output_structure = PDBStructure.from_file(output_path) + solvent_atoms = [ + atom for atom in output_structure.atoms if atom.residue_name == "HOH" + ] + oxygen_atoms = [atom for atom in solvent_atoms if atom.element == "O"] + + assert build_result.build_mode == "partial_solvent_completion" + assert build_result.solvent_molecules_added == 1 + assert build_result.solvent_atoms_added == 2 + assert build_result.partial_candidates_completed == 1 + assert build_result.replaced_source_atom_count == 1 + assert len(output_structure.atoms) == 5 + assert len(solvent_atoms) == 3 + assert len(oxygen_atoms) == 1 + assert np.allclose( + oxygen_atoms[0].coordinates, + np.array([0.0, 2.0, 0.0], dtype=float), + ) + + +def test_solvent_shell_builder_window_reports_residue_breakdown(tmp_path): + qapp() + reference_library_dir, reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_solvent_shell_pdb( + tmp_path, + reference_path=reference_path, + ) + window = SolventShellBuilderMainWindow( + 
initial_input_path=input_path, + reference_library_dir=reference_library_dir, + ) + + preset_index = window.reference_preset_combo.findData("water_test") + assert preset_index >= 0 + window.reference_preset_combo.setCurrentIndex(preset_index) + window._analyze_input_structure() + + central_layout = window.centralWidget().layout() + assert central_layout.itemAt(1).widget() is window._pane_splitter + assert window._pane_splitter.count() == 2 + assert window._pane_splitter.widget(0) is window._left_scroll_area + assert window._pane_splitter.widget(1) is window._right_scroll_area + assert window._left_scroll_area.widget() is window._left_panel + assert window._right_scroll_area.widget() is window._right_panel + assert window.cluster_status_group.parentWidget() is window._left_panel + assert "Residue HOH" in window.reference_details_box.toPlainText() + assert window.structure_viewer.current_structure is not None + assert window.structure_viewer.current_structure.file_path == input_path + assert "complete solvent molecules detected" in ( + window.cluster_status_headline_label.text().lower() + ) + assert ( + "No solvent molecules: no" in window.cluster_status_stats_label.text() + ) + assert "Partial solvent molecules: no" in ( + window.cluster_status_stats_label.text() + ) + assert "Complete solvent molecules: yes" in ( + window.cluster_status_stats_label.text() + ) + assert ( + "Complete solvent count: 2" in window.cluster_status_stats_label.text() + ) + assert "Solvent molecules detected: 2" in window.summary_box.toPlainText() + assert "PDB residue matches:" in window.summary_box.toPlainText() + assert window.residue_table.rowCount() == 2 + table_rows = { + window.residue_table.item(row, 0).text(): { + "molecules": window.residue_table.item(row, 1).text(), + "numbers": window.residue_table.item(row, 2).text(), + } + for row in range(window.residue_table.rowCount()) + } + assert table_rows == { + "ALT": {"molecules": "1", "numbers": "3"}, + "HOH": {"molecules": "1", 
"numbers": "2"}, + } + assert "matched the selected solvent geometry" in ( + window.residue_status_label.text().lower() + ) + window.close() + + +def test_solvent_shell_builder_window_reports_incomplete_residue_mismatch( + tmp_path, +): + qapp() + reference_library_dir, reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_incomplete_solvent_shell_pdb( + tmp_path, + reference_path=reference_path, + ) + window = SolventShellBuilderMainWindow( + initial_input_path=input_path, + reference_library_dir=reference_library_dir, + ) + + preset_index = window.reference_preset_combo.findData("water_test") + assert preset_index >= 0 + window.reference_preset_combo.setCurrentIndex(preset_index) + window._analyze_input_structure() + + assert "complete and partial solvent molecules detected" in ( + window.cluster_status_headline_label.text().lower() + ) + assert ( + "No solvent molecules: no" in window.cluster_status_stats_label.text() + ) + assert "Partial solvent molecules: yes" in ( + window.cluster_status_stats_label.text() + ) + assert "Complete solvent molecules: yes" in ( + window.cluster_status_stats_label.text() + ) + assert ( + "Complete solvent count: 1" in window.cluster_status_stats_label.text() + ) + assert "Partial solvent residue count: 1" in ( + window.cluster_status_stats_label.text() + ) + assert ( + "Residue mismatches preserved: 1" in window.summary_box.toPlainText() + ) + assert window.mismatch_table.rowCount() == 1 + assert window.mismatch_table.item(0, 0).text() == "HOH" + assert window.mismatch_table.item(0, 1).text() == "3" + assert window.mismatch_table.item(0, 2).text() == "2" + assert window.mismatch_table.item(0, 3).text() == "2/3" + assert window.mismatch_table.item(0, 4).text() == "H2" + assert window.mismatch_table.item(0, 5).text() == "none" + assert ( + "missing-atom details" in window.mismatch_status_label.text().lower() + ) + window.close() + + +def 
test_solvent_shell_builder_window_reports_partial_xyz_candidates(tmp_path): + qapp() + reference_library_dir, _reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_partial_solvent_shell_xyz(tmp_path) + window = SolventShellBuilderMainWindow( + initial_input_path=input_path, + reference_library_dir=reference_library_dir, + ) + + preset_index = window.reference_preset_combo.findData("water_test") + assert preset_index >= 0 + window.reference_preset_combo.setCurrentIndex(preset_index) + window._analyze_input_structure() + + assert "partial solvent molecules detected" in ( + window.cluster_status_headline_label.text().lower() + ) + assert "Partial solvent molecules: yes" in ( + window.cluster_status_stats_label.text() + ) + assert "Partial solvent candidate count: 1" in ( + window.cluster_status_stats_label.text() + ) + assert window.mismatch_table.rowCount() == 1 + assert window.mismatch_table.item(0, 0).text() == "HOH" + assert window.mismatch_table.item(0, 3).text() == "1/3" + assert window.mismatch_table.item(0, 4).text() == "H1, H2" + assert "xyz atom sets" in window.mismatch_status_label.text().lower() + window.close() + + +def test_solvent_shell_builder_window_populates_build_controls_and_writes_output( + tmp_path, +): + qapp() + reference_library_dir, _reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_partial_solvent_shell_xyz(tmp_path) + output_path = tmp_path / "beta_builder_output.pdb" + window = SolventShellBuilderMainWindow( + initial_input_path=input_path, + reference_library_dir=reference_library_dir, + ) + + preset_index = window.reference_preset_combo.findData("water_test") + assert preset_index >= 0 + window.reference_preset_combo.setCurrentIndex(preset_index) + window._analyze_input_structure() + + assert [ + window.director_atom_combo.itemText(index) for index in range(3) + ] == [ + "O1", + "H1", + "H2", + ] + assert 
window.director_atom_combo.currentData() == "O1" + assert window.build_output_button.isEnabled() is True + assert window.solute_cutoff_table.rowCount() == 2 + assert window.solute_cutoff_table.item(0, 0).text() == "I" + assert window.solute_cutoff_table.item(0, 1).text() == "1" + assert window.solute_cutoff_table.item(1, 0).text() == "Pb" + assert window.solute_cutoff_table.item(1, 1).text() == "1" + + window.output_path_edit.setText(str(output_path)) + window._build_solvated_output() + + assert output_path.is_file() + assert ( + "Generated solvent shell output:" in window.summary_box.toPlainText() + ) + assert "Build mode: partial_solvent_completion" in ( + window.summary_box.toPlainText() + ) + assert ( + "Previewing generated output" in window.visualizer_status_label.text() + ) + window.close() + + +def test_solvent_shell_builder_window_requires_coordination_targets_for_no_solvent_build( + tmp_path, +): + qapp() + reference_library_dir, _reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_no_solvent_mixed_shell_pdb(tmp_path) + output_path = tmp_path / "beta_coordination_target_output.pdb" + window = SolventShellBuilderMainWindow( + initial_input_path=input_path, + reference_library_dir=reference_library_dir, + ) + + preset_index = window.reference_preset_combo.findData("water_test") + assert preset_index >= 0 + window.reference_preset_combo.setCurrentIndex(preset_index) + window._analyze_input_structure() + + assert window.solute_cutoff_table.rowCount() == 2 + assert window.build_output_button.isEnabled() is False + assert ( + "needs coordination targets" + in window.build_status_label.text().lower() + ) + assert window.solute_cutoff_table.item(0, 0).text() == "I" + assert window.solute_cutoff_table.item(1, 0).text() == "Pb" + + pb_center_item = window.solute_cutoff_table.item(1, 2) + assert pb_center_item is not None + pb_center_item.setCheckState(Qt.CheckState.Checked) + pb_coordination_spin = 
window.solute_cutoff_table.cellWidget(1, 3) + assert isinstance(pb_coordination_spin, QDoubleSpinBox) + pb_coordination_spin.setValue(2.0) + + assert window.build_output_button.isEnabled() is True + window.output_path_edit.setText(str(output_path)) + window._build_solvated_output() + + assert output_path.is_file() + assert ( + "Target average coordination: Pb:2" in window.summary_box.toPlainText() + ) + assert ( + "Achieved average coordination: Pb:" + in window.summary_box.toPlainText() + ) + window.close() + + +def test_solvent_shell_builder_window_uses_selected_match_tolerance(tmp_path): + qapp() + reference_library_dir, reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + input_path = _write_test_solvent_shell_pdb( + tmp_path, + reference_path=reference_path, + ) + window = SolventShellBuilderMainWindow( + initial_input_path=input_path, + reference_library_dir=reference_library_dir, + ) + + assert window.reference_match_tolerance_spin.value() == pytest.approx( + DEFAULT_REFERENCE_MATCH_TOLERANCE_A + ) + window.reference_match_tolerance_spin.setValue(0.5) + preset_index = window.reference_preset_combo.findData("water_test") + assert preset_index >= 0 + window.reference_preset_combo.setCurrentIndex(preset_index) + window._analyze_input_structure() + + assert ( + "Reference match tolerance: 0.5 A" in window.summary_box.toPlainText() + ) + window.close() + + +def test_rmcsetup_ui_can_compute_packmol_plan_and_reload(tmp_path): + qapp() + project_dir, _paths = _build_sample_saxs_project(tmp_path) + window = RMCSetupMainWindow(initial_project_dir=project_dir) + + window.solution_density_spin.setValue(1.05) + window.solute_stoich_edit.setText("Pb1I2") + window.solvent_stoich_edit.setText("C3H7NO") + window.molar_mass_solute_spin.setValue(461.0) + window.molar_mass_solvent_spin.setValue(73.09) + window.mass_solute_spin.setValue(4.61) + window.mass_solvent_spin.setValue(95.39) + window._calculate_solution_properties() + 
window._compute_representative_clusters() + _wait_for_representative_worker(window) + window.packmol_box_side_spin.setValue(80.0) + + window._compute_packmol_plan() + + assert "Planned clusters:" in ( + window.packmol_plan_summary_box.toPlainText() + ) + axes = window._packmol_plan_figure.axes + assert len(axes) == 3 + assert [axis.get_title() for axis in axes] == [ + "Original cluster distribution", + "DREAM fit model distribution used", + "Packmol planned distribution", + ] + assert all(len(axis.patches) == 4 for axis in axes) + assert [tick.get_text() for tick in axes[-1].get_xticklabels()] == [ + format_stoich_for_axis("PbI2"), + format_stoich_for_axis("PbI2O"), + ] + + plan_entries = { + (entry.structure, entry.motif): entry + for entry in window._project_source_state.packmol_planning.entries + } + assert [patch.get_height() for patch in axes[0].patches] == pytest.approx( + [66.6666666667, 0.0, 0.0, 33.3333333333] + ) + dream_total = sum(entry.selected_weight for entry in plan_entries.values()) + assert [patch.get_height() for patch in axes[1].patches] == pytest.approx( + [ + 100.0 + * plan_entries[("PbI2", "no_motif")].selected_weight + / dream_total, + 0.0, + 0.0, + 100.0 + * plan_entries[("PbI2O", "motif_1")].selected_weight + / dream_total, + ] + ) + assert [patch.get_height() for patch in axes[2].patches] == pytest.approx( + [ + 100.0 * plan_entries[("PbI2", "no_motif")].planned_count_weight, + 0.0, + 0.0, + 100.0 * plan_entries[("PbI2O", "motif_1")].planned_count_weight, + ] + ) + assert "Computed Packmol planning counts." 
in ( + window.run_log_box.toPlainText() + ) + assert "packmol_plan.json" in window.output_summary_box.toPlainText() + + reloaded = RMCSetupMainWindow(initial_project_dir=project_dir) + assert reloaded.packmol_planning_mode_combo.currentData() == ( + "per_element" + ) + assert reloaded.packmol_box_side_spin.value() == pytest.approx(80.0) + assert "Planned clusters:" in ( + reloaded.packmol_plan_summary_box.toPlainText() + ) + + +def test_rmcsetup_ui_packmol_preview_includes_single_atom_model_sources( + tmp_path, +): + qapp() + project_dir, _paths, _single_atom_path = ( + _build_sample_saxs_project_with_single_atom_model(tmp_path) + ) + window = RMCSetupMainWindow(initial_project_dir=project_dir) + + window.solution_density_spin.setValue(1.05) + window.solute_stoich_edit.setText("Pb1I2") + window.solvent_stoich_edit.setText("C3H7NO") + window.molar_mass_solute_spin.setValue(461.0) + window.molar_mass_solvent_spin.setValue(73.09) + window.mass_solute_spin.setValue(4.61) + window.mass_solvent_spin.setValue(95.39) + window._calculate_solution_properties() + window._compute_representative_clusters() + _wait_for_representative_worker(window) + _configure_integrated_rmcsetup_solvent_panel(window) + window._build_representative_solvent_outputs() + window.packmol_box_side_spin.setValue(80.0) + + window._compute_packmol_plan() + + axes = window._packmol_plan_figure.axes + assert len(axes) == 3 + assert all(len(axis.patches) == 6 for axis in axes) + assert format_stoich_for_axis("Zn1") in [ + tick.get_text() for tick in axes[-1].get_xticklabels() + ] + assert any( + entry.structure == "Zn1" + for entry in window._project_source_state.packmol_planning.entries + ) + + def test_rmcsetup_ui_can_build_packmol_setup_and_reload(tmp_path): qapp() project_dir, _paths = _build_sample_saxs_project(tmp_path) window = RMCSetupMainWindow(initial_project_dir=project_dir) + assert window.packmol_tolerance_spin.value() == pytest.approx(2.0) + window.solution_density_spin.setValue(1.05) 
window.solute_stoich_edit.setText("Pb1I2") window.solvent_stoich_edit.setText("C3H7NO") @@ -2132,28 +4242,157 @@ def test_rmcsetup_ui_can_build_packmol_setup_and_reload(tmp_path): window._calculate_solution_properties() window._compute_representative_clusters() _wait_for_representative_worker(window) - window.coordinated_solvent_mode_combo.setCurrentIndex(1) + _configure_integrated_rmcsetup_solvent_panel(window) window._build_representative_solvent_outputs() window.packmol_box_side_spin.setValue(80.0) window._compute_packmol_plan() + window.packmol_tolerance_spin.setValue(2.2) window._build_packmol_setup() + assert window._project_source_state is not None + assert window._project_source_state.packmol_setup is not None + assert "Packmol tolerance: 2.200 A" in ( + window.packmol_build_summary_box.toPlainText() + ) + assert "tolerance 2.200" in Path( + window._project_source_state.packmol_setup.packmol_input_path + ).read_text(encoding="utf-8") assert "Representative PDBs copied:" in ( window.packmol_build_summary_box.toPlainText() ) + assert window.open_packmol_setup_folder_button.isEnabled() is True assert "Built Packmol setup inputs and audit report." 
in ( window.run_log_box.toPlainText() ) assert "packmol_combined.inp" in window.output_summary_box.toPlainText() reloaded = RMCSetupMainWindow(initial_project_dir=project_dir) + assert reloaded.packmol_tolerance_spin.value() == pytest.approx(2.2) + assert "Packmol tolerance: 2.200 A" in ( + reloaded.packmol_build_summary_box.toPlainText() + ) assert "Representative PDBs copied:" in ( reloaded.packmol_build_summary_box.toPlainText() ) + assert reloaded.open_packmol_setup_folder_button.isEnabled() is True assert "packmol_audit.md" in reloaded.output_summary_box.toPlainText() +def test_rmcsetup_ui_can_open_packmol_setup_folder( + tmp_path, + monkeypatch, +): + qapp() + project_dir, _paths = _build_sample_saxs_project(tmp_path) + window = RMCSetupMainWindow(initial_project_dir=project_dir) + opened_paths: list[Path] = [] + monkeypatch.setattr( + window, + "_open_path_in_file_manager", + lambda path: opened_paths.append(path), + ) + + window.solution_density_spin.setValue(1.05) + window.solute_stoich_edit.setText("Pb1I2") + window.solvent_stoich_edit.setText("C3H7NO") + window.molar_mass_solute_spin.setValue(461.0) + window.molar_mass_solvent_spin.setValue(73.09) + window.mass_solute_spin.setValue(4.61) + window.mass_solvent_spin.setValue(95.39) + window._calculate_solution_properties() + window._compute_representative_clusters() + _wait_for_representative_worker(window) + _configure_integrated_rmcsetup_solvent_panel(window) + window._build_representative_solvent_outputs() + window.packmol_box_side_spin.setValue(80.0) + window._compute_packmol_plan() + window._build_packmol_setup() + + assert window.open_packmol_setup_folder_button.isEnabled() is True + + window.open_packmol_setup_folder_button.click() + + assert opened_paths == [ + window._project_source_state.rmcsetup_paths.packmol_inputs_dir.resolve() + ] + assert "Opened Packmol setup folder in Finder/file manager:" in ( + window.run_log_box.toPlainText() + ) + assert window.statusBar().currentMessage() == ( + 
"Opened Packmol setup folder: packmol_inputs" + ) + + +def test_rmcsetup_ui_syncs_packmol_setup_to_linked_docker_container( + tmp_path, + monkeypatch, +): + qapp() + project_dir, _paths = _build_sample_saxs_project(tmp_path) + fake_client = _FakePackmolDockerClient() + monkeypatch.setattr( + RMCSetupMainWindow, + "_create_packmol_docker_client", + lambda self: fake_client, + ) + window = RMCSetupMainWindow(initial_project_dir=project_dir) + assert window._project_source_state is not None + + linked = PackmolDockerLink( + display_name="Mounted Packmol", + container_name="packmol-dev", + container_project_root="/packmol_input_files/project_alpha", + packmol_command="packmol", + shell_command="sh", + packmol_version="Packmol version 20.14.4", + last_verified_at="2026-04-17T12:30:00", + container_id="sha256:sync", + image_name="packmol:test-image", + packmol_command_path="/usr/local/bin/packmol", + ) + window._save_packmol_docker_link(linked) + window.packmol_docker_summary_box.setPlainText( + window._packmol_docker_summary_text() + ) + + window.solution_density_spin.setValue(1.05) + window.solute_stoich_edit.setText("Pb1I2") + window.solvent_stoich_edit.setText("C3H7NO") + window.molar_mass_solute_spin.setValue(461.0) + window.molar_mass_solvent_spin.setValue(73.09) + window.mass_solute_spin.setValue(4.61) + window.mass_solvent_spin.setValue(95.39) + window._calculate_solution_properties() + window._compute_representative_clusters() + _wait_for_representative_worker(window) + _configure_integrated_rmcsetup_solvent_panel(window) + window._build_representative_solvent_outputs() + window.packmol_box_side_spin.setValue(80.0) + window._compute_packmol_plan() + + window._build_packmol_setup() + + assert fake_client.synced_calls + assert fake_client.synced_calls[0][0] == "packmol-dev" + assert window._project_source_state.packmol_docker_link is not None + assert ( + window._project_source_state.packmol_docker_link.last_sync_status + == ("success") + ) + assert 
"/packmol_input_files/project_alpha/rmcsetup/packmol_inputs" in ( + window.packmol_build_summary_box.toPlainText() + ) + assert "Last sync status: success" in ( + window.packmol_docker_summary_box.toPlainText() + ) + + reloaded = RMCSetupMainWindow(initial_project_dir=project_dir) + assert "Last sync status: success" in ( + reloaded.packmol_docker_summary_box.toPlainText() + ) + + def test_rmcsetup_ui_can_generate_constraints_and_reload(tmp_path): qapp() project_dir, _paths = _build_sample_saxs_project(tmp_path) @@ -2169,7 +4408,7 @@ def test_rmcsetup_ui_can_generate_constraints_and_reload(tmp_path): window._calculate_solution_properties() window._compute_representative_clusters() _wait_for_representative_worker(window) - window.coordinated_solvent_mode_combo.setCurrentIndex(1) + _configure_integrated_rmcsetup_solvent_panel(window) window._build_representative_solvent_outputs() window.packmol_box_side_spin.setValue(80.0) window._compute_packmol_plan() @@ -2182,6 +4421,8 @@ def test_rmcsetup_ui_can_generate_constraints_and_reload(tmp_path): assert "Per-structure files:" in ( window.constraints_summary_box.toPlainText() ) + assert window.open_constraints_folder_button.isEnabled() is True + assert window.preview_constraints_button.isEnabled() is True assert ( "Generated per-structure constraints and merged fullrmc constraints." 
in (window.run_log_box.toPlainText()) @@ -2200,6 +4441,99 @@ def test_rmcsetup_ui_can_generate_constraints_and_reload(tmp_path): assert "Per-structure files:" in ( reloaded.constraints_summary_box.toPlainText() ) + assert reloaded.open_constraints_folder_button.isEnabled() is True + assert reloaded.preview_constraints_button.isEnabled() is True + + +def test_rmcsetup_ui_can_open_constraints_folder( + tmp_path, + monkeypatch, +): + qapp() + project_dir, _paths = _build_sample_saxs_project(tmp_path) + window = RMCSetupMainWindow(initial_project_dir=project_dir) + opened_paths: list[Path] = [] + monkeypatch.setattr( + window, + "_open_path_in_file_manager", + lambda path: opened_paths.append(path), + ) + + window.solution_density_spin.setValue(1.05) + window.solute_stoich_edit.setText("Pb1I2") + window.solvent_stoich_edit.setText("C3H7NO") + window.molar_mass_solute_spin.setValue(461.0) + window.molar_mass_solvent_spin.setValue(73.09) + window.mass_solute_spin.setValue(4.61) + window.mass_solvent_spin.setValue(95.39) + window._calculate_solution_properties() + window._compute_representative_clusters() + _wait_for_representative_worker(window) + _configure_integrated_rmcsetup_solvent_panel(window) + window._build_representative_solvent_outputs() + window.packmol_box_side_spin.setValue(80.0) + window._compute_packmol_plan() + window._build_packmol_setup() + window._generate_constraints() + + assert window.open_constraints_folder_button.isEnabled() is True + + window.open_constraints_folder_button.click() + + assert window._project_source_state is not None + assert window._project_source_state.constraint_generation is not None + assert opened_paths == [ + Path( + window._project_source_state.constraint_generation.merged_constraints_path + ).resolve() + ] + assert "Opened constraints file location in Finder/file manager:" in ( + window.run_log_box.toPlainText() + ) + assert window.statusBar().currentMessage() == ( + "Opened constraints file location: 
def test_rmcsetup_ui_can_preview_merged_constraints(
    tmp_path,
):
    """Generate constraints through the UI, then open the merged preview.

    The preview button must be enabled once constraints exist, and
    clicking it must open a preview window showing both constraint
    dictionaries and record the action in the run log.
    """
    qapp()
    project_dir, _unused_paths = _build_sample_saxs_project(tmp_path)
    window = RMCSetupMainWindow(initial_project_dir=project_dir)

    # Solution inputs, applied in the same order as the form fields.
    solution_inputs = (
        (window.solution_density_spin.setValue, 1.05),
        (window.solute_stoich_edit.setText, "Pb1I2"),
        (window.solvent_stoich_edit.setText, "C3H7NO"),
        (window.molar_mass_solute_spin.setValue, 461.0),
        (window.molar_mass_solvent_spin.setValue, 73.09),
        (window.mass_solute_spin.setValue, 4.61),
        (window.mass_solvent_spin.setValue, 95.39),
    )
    for apply_value, value in solution_inputs:
        apply_value(value)

    # Walk the pipeline up to (and including) constraint generation.
    window._calculate_solution_properties()
    window._compute_representative_clusters()
    _wait_for_representative_worker(window)
    _configure_integrated_rmcsetup_solvent_panel(window)
    window._build_representative_solvent_outputs()
    window.packmol_box_side_spin.setValue(80.0)
    window._compute_packmol_plan()
    window._build_packmol_setup()
    window._generate_constraints()

    preview_button = window.preview_constraints_button
    assert preview_button.isEnabled() is True

    preview_button.click()

    preview_window = window._constraints_preview_window
    assert preview_window is not None
    preview_text = preview_window.text_box.toPlainText()
    assert "BOND_ANGLE_CONSTRAINTS = {" in preview_text
    assert "BOND_LENGTH_CONSTRAINTS = {" in preview_text
    run_log_text = window.run_log_box.toPlainText()
    assert "Opened merged constraints preview:" in run_log_text
    preview_window.close()
window._compute_packmol_plan() From d67dab11f09c911b6ae73756e70a42ce4daffccb Mon Sep 17 00:00:00 2001 From: kewh5868 Date: Thu, 7 May 2026 16:47:22 -0600 Subject: [PATCH 3/7] feat(representativefinder): add project-backed CLI and UI Introduce the representativefinder package, console entry point, run-file workflow, Qt setup/analyzer windows, and representative persistence helpers. Add tests and a performance harness while keeping generated benchmark outputs ignored. --- .gitignore | 3 + pyproject.toml | 1 + src/saxshell/representativefinder/__init__.py | 67 + src/saxshell/representativefinder/__main__.py | 4 + src/saxshell/representativefinder/cli.py | 246 ++ .../representativefinder/run_config.py | 539 +++ .../representativefinder/ui/__init__.py | 11 + .../representativefinder/ui/main_window.py | 3166 ++++++++++++++ .../ui/run_file_window.py | 692 +++ src/saxshell/representativefinder/workflow.py | 3896 +++++++++++++++++ src/saxshell/saxs/contrast/descriptors.py | 187 +- .../benchmark_representativefinder.py | 161 + .../test_parallel_representativefinder.py | 141 + tests/test_representativefinder.py | 1676 +++++++ 14 files changed, 10724 insertions(+), 66 deletions(-) create mode 100644 src/saxshell/representativefinder/__init__.py create mode 100644 src/saxshell/representativefinder/__main__.py create mode 100644 src/saxshell/representativefinder/cli.py create mode 100644 src/saxshell/representativefinder/run_config.py create mode 100644 src/saxshell/representativefinder/ui/__init__.py create mode 100644 src/saxshell/representativefinder/ui/main_window.py create mode 100644 src/saxshell/representativefinder/ui/run_file_window.py create mode 100644 src/saxshell/representativefinder/workflow.py create mode 100644 tests/representativefinder_performance/benchmark_representativefinder.py create mode 100644 tests/representativefinder_performance/test_parallel_representativefinder.py create mode 100644 tests/test_representativefinder.py diff --git a/.gitignore 
b/.gitignore index 88a157e..68e8832 100644 --- a/.gitignore +++ b/.gitignore @@ -81,7 +81,10 @@ site/ # pytest .pytest_cache/ +tests/born_vs_debye_backend_debug_*/ +tests/contrast_fft_backend_debug_*/ tests/edm_contiguous_mode_report/ +tests/representativefinder_performance/output_results/ tests/smearing_analysis_*/ # PyBuilder diff --git a/pyproject.toml b/pyproject.toml index f5af0ad..08240b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ clusters = "saxshell.cluster.cli:main" mdtrajectory = "saxshell.mdtrajectory.cli:main" saxshell = "saxshell.saxshell:main" pdfsetup = "saxshell.pdfsetup:main" +representativefinder = "saxshell.representativefinder.cli:main" structureviewer = "saxshell.saxs.structure_viewer.cli:main" xyz2pdb = "saxshell.xyz2pdb.cli:main" diff --git a/src/saxshell/representativefinder/__init__.py b/src/saxshell/representativefinder/__init__.py new file mode 100644 index 0000000..af2c3c6 --- /dev/null +++ b/src/saxshell/representativefinder/__init__.py @@ -0,0 +1,67 @@ +"""Representative-structure selection workflow and UI.""" + +from .run_config import ( + DEFAULT_RUN_FILE_NAME, + RepresentativeFinderRunConfig, + RepresentativeFinderRunExecutionSummary, + RepresentativeFinderRunFailure, + RepresentativeFinderRunTarget, + build_representativefinder_run_config, + default_representativefinder_run_file_path, + load_representativefinder_run_config, + representativefinder_run_targets, + representativefinder_settings_from_dict, + run_representativefinder_run_config, + save_representativefinder_run_config, + suggest_run_config_output_dir, +) +from .workflow import ( + RepresentativeFinderCandidate, + RepresentativeFinderFolderInspection, + RepresentativeFinderInputInspection, + RepresentativeFinderOperationCancelled, + RepresentativeFinderPlotSeries, + RepresentativeFinderResult, + RepresentativeFinderSettings, + analyze_representative_structure_folder, + estimate_representativefinder_total_work, + 
inspect_representative_structure_folder, + inspect_representative_structure_input, + load_representativefinder_result, + persist_representativefinder_result_to_project, + representativefinder_result_from_dict, + suggest_representativefinder_output_dir, + suggest_representativefinder_target_output_dir, +) + +__all__ = [ + "RepresentativeFinderCandidate", + "RepresentativeFinderFolderInspection", + "RepresentativeFinderInputInspection", + "RepresentativeFinderOperationCancelled", + "RepresentativeFinderPlotSeries", + "RepresentativeFinderResult", + "RepresentativeFinderRunConfig", + "RepresentativeFinderRunExecutionSummary", + "RepresentativeFinderRunFailure", + "RepresentativeFinderRunTarget", + "RepresentativeFinderSettings", + "DEFAULT_RUN_FILE_NAME", + "analyze_representative_structure_folder", + "build_representativefinder_run_config", + "default_representativefinder_run_file_path", + "estimate_representativefinder_total_work", + "inspect_representative_structure_input", + "inspect_representative_structure_folder", + "load_representativefinder_result", + "load_representativefinder_run_config", + "persist_representativefinder_result_to_project", + "representativefinder_result_from_dict", + "representativefinder_run_targets", + "representativefinder_settings_from_dict", + "run_representativefinder_run_config", + "save_representativefinder_run_config", + "suggest_representativefinder_output_dir", + "suggest_representativefinder_target_output_dir", + "suggest_run_config_output_dir", +] diff --git a/src/saxshell/representativefinder/__main__.py b/src/saxshell/representativefinder/__main__.py new file mode 100644 index 0000000..bfdcd0c --- /dev/null +++ b/src/saxshell/representativefinder/__main__.py @@ -0,0 +1,4 @@ +from .cli import main + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/saxshell/representativefinder/cli.py b/src/saxshell/representativefinder/cli.py new file mode 100644 index 0000000..7ec33d6 --- /dev/null +++ 
def build_parser() -> argparse.ArgumentParser:
    """Create the ``representativefinder`` command-line parser.

    The parser exposes a top-level ``--version`` flag and four
    subcommands (``setup-ui``, ``ui``, ``inspect``, ``run``). Each
    subcommand stores its handler function in ``args.handler``.
    """
    root = argparse.ArgumentParser(
        prog="representativefinder",
        description=(
            "Build or run project-backed representative-structure analysis "
            "run files. Running without a subcommand launches the beta run "
            "file setup UI."
        ),
    )
    root.add_argument(
        "--version",
        action="store_true",
        help="Show the representativefinder version number and exit.",
    )
    commands = root.add_subparsers(dest="command")

    def add_ui_command(name, summary, handler):
        # Both UI launchers take an optional project folder and input folder.
        sub = commands.add_parser(name, help=summary)
        sub.add_argument(
            "project_dir",
            nargs="?",
            type=Path,
            help="Optional SAXSShell project folder.",
        )
        sub.add_argument(
            "--input-dir",
            type=Path,
            default=None,
            help="Optional cluster/stoichiometry folder to prefill.",
        )
        sub.set_defaults(handler=handler)
        return sub

    def add_run_file_command(name, summary, handler):
        # Run-file commands require the project folder and accept --run-file.
        sub = commands.add_parser(name, help=summary)
        sub.add_argument(
            "project_dir",
            type=Path,
            help="SAXSShell project folder containing the run file.",
        )
        sub.add_argument(
            "--run-file",
            type=Path,
            default=None,
            help=(
                "Run file path. Defaults to "
                "representative_structure_cli_run.json in the project folder."
            ),
        )
        sub.set_defaults(handler=handler)
        return sub

    add_ui_command(
        "setup-ui",
        "Launch the beta Qt run-file setup interface.",
        _handle_setup_ui,
    )
    add_ui_command(
        "ui",
        "Launch the full representative-structure analysis UI.",
        _handle_ui,
    )
    add_run_file_command(
        "inspect",
        "Inspect the targets described by a representative run file.",
        _handle_inspect,
    )
    run_command = add_run_file_command(
        "run",
        "Run representative-structure analysis from a project run file.",
        _handle_run,
    )
    run_command.add_argument(
        "--workers",
        type=int,
        default=None,
        help="Override worker thread count stored in the run file.",
    )
    run_command.add_argument(
        "--overwrite-existing",
        action="store_true",
        help="Recalculate stoichiometries that already have saved project representatives.",
    )

    return root


def main(argv: list[str] | None = None) -> int:
    """Parse *argv* and dispatch to the selected subcommand handler.

    ``--version`` prints the version and returns 0. With no subcommand
    the setup UI handler is run. Any exception raised by a subcommand
    handler is reported through ``parser.exit`` with status 2.
    """
    parser = build_parser()
    options = parser.parse_args(argv)

    if options.version:
        print(f"representativefinder {__version__}")
        return 0

    no_subcommand = options.command is None
    if no_subcommand:
        return _handle_setup_ui(options)

    try:
        return int(options.handler(options))
    except Exception as exc:
        # Top-level CLI boundary: turn any handler failure into exit code 2.
        parser.exit(2, f"Error: {exc}\n")


def _handle_setup_ui(args: argparse.Namespace) -> int:
    """Launch the run-file setup window; run the event loop if we own it."""
    # Qt is imported lazily so non-UI subcommands do not need PySide6.
    from PySide6.QtWidgets import QApplication

    from .ui.run_file_window import launch_representativefinder_run_file_ui

    app_already_running = QApplication.instance() is not None
    launch_representativefinder_run_file_ui(
        initial_project_dir=getattr(args, "project_dir", None),
        initial_input_path=getattr(args, "input_dir", None),
    )
    app = QApplication.instance()
    if app_already_running or app is None:
        return 0
    return app.exec()


def _handle_ui(args: argparse.Namespace) -> int:
    """Launch the full analysis window; run the event loop if we own it."""
    from PySide6.QtWidgets import QApplication

    from .ui.main_window import launch_representativefinder_ui

    app_already_running = QApplication.instance() is not None
    launch_representativefinder_ui(
        initial_project_dir=getattr(args, "project_dir", None),
        initial_input_path=getattr(args, "input_dir", None),
    )
    app = QApplication.instance()
    if app_already_running or app is None:
        return 0
    return app.exec()


def _handle_inspect(args: argparse.Namespace) -> int:
    """Print the targets a run file would process, without running them."""
    project_root = Path(args.project_dir).expanduser().resolve()
    run_file = _resolve_run_file(project_root, args.run_file)
    config = load_representativefinder_run_config(run_file)
    targets, skipped_existing = representativefinder_run_targets(
        project_dir=project_root,
        config=config,
    )

    print(f"Project folder: {project_root}")
    print(f"Run file: {run_file}")
    print(f"Analysis mode: {config.analysis_mode}")
    print(f"Targets to run: {len(targets)}")
    if skipped_existing:
        print("Skipped existing: " + ", ".join(skipped_existing))
    for target in targets:
        inspection = target.inspection
        print(
            f"- {inspection.structure_label}: "
            f"{inspection.candidate_count} candidate file(s) -> "
            f"{target.output_dir}"
        )
    return 0


def _handle_run(args: argparse.Namespace) -> int:
    """Execute a run file and print a summary; return 1 if any target failed."""
    project_root = Path(args.project_dir).expanduser().resolve()
    run_file = _resolve_run_file(project_root, args.run_file)
    config = load_representativefinder_run_config(run_file)

    if args.workers is not None:
        # Rebuild the settings object with only the worker count overridden.
        stored = config.settings
        config.settings = type(stored)(
            selection_algorithm=stored.selection_algorithm,
            bond_weight=stored.bond_weight,
            angle_weight=stored.angle_weight,
            solvent_weight=stored.solvent_weight,
            generate_predicted_optimized_representative=(
                stored.generate_predicted_optimized_representative
            ),
            parallel_workers=int(args.workers),
            quantiles=stored.quantiles,
            bond_pairs=stored.bond_pairs,
            angle_triplets=stored.angle_triplets,
        )
    if args.overwrite_existing:
        config.overwrite_existing = True

    summary = run_representativefinder_run_config(
        project_root,
        config,
        run_file_path=run_file,
        log_callback=print,
        progress_callback=_print_progress,
    )

    print("")
    print("Representative CLI run complete")
    print(f"Completed: {summary.completed_count}")
    print(f"Failed: {summary.failed_count}")
    if summary.skipped_existing:
        print("Skipped existing: " + ", ".join(summary.skipped_existing))
    for representative_path in summary.project_representative_paths:
        print(f"Project representative: {representative_path}")
    for failure in summary.failures:
        print(f"FAILED {failure.structure_label}: {failure.message}")

    if summary.failures:
        return 1
    return 0


def _resolve_run_file(project_dir: Path, run_file: Path | None) -> Path:
    """Return the explicit run file path, or the project default if omitted."""
    if run_file is not None:
        return Path(run_file).expanduser().resolve()
    return default_representativefinder_run_file_path(project_dir)


def _print_progress(processed: int, total: int, message: str) -> None:
    """Print one ``processed/total message`` progress line."""
    print(f"{processed}/{total} {message}")


__all__ = ["build_parser", "main"]
DEFAULT_RUN_FILE_NAME = "representative_structure_cli_run.json"
RUN_CONFIG_VERSION = 1
RepresentativeFinderRunLogCallback = Callable[[str], None]
RepresentativeFinderRunProgressCallback = Callable[[int, int, str], None]


@dataclass(slots=True)
class RepresentativeFinderRunConfig:
    """Serializable description of one representative-finder CLI run.

    ``input_dir`` and ``output_dir`` are stored as text; callers in this
    module resolve them against the project folder when the run starts.
    """

    input_dir: str
    output_dir: str | None
    analysis_mode: str = "all"
    selected_stoichiometry: str | None = None
    overwrite_existing: bool = False
    settings: RepresentativeFinderSettings = field(
        default_factory=RepresentativeFinderSettings
    )
    created_at: str = field(
        default_factory=lambda: datetime.now().isoformat(timespec="seconds")
    )

    def to_dict(self) -> dict[str, object]:
        """Return the JSON-ready payload written to the run file."""
        return {
            "version": RUN_CONFIG_VERSION,
            "created_at": self.created_at,
            "input_dir": self.input_dir,
            "output_dir": self.output_dir,
            "analysis_mode": _normalize_analysis_mode(self.analysis_mode),
            "selected_stoichiometry": self.selected_stoichiometry,
            "overwrite_existing": bool(self.overwrite_existing),
            "settings": self.settings.to_dict(),
        }

    @classmethod
    def from_dict(
        cls,
        payload: dict[str, object],
    ) -> "RepresentativeFinderRunConfig":
        """Build a config from a decoded run-file payload.

        Raises:
            ValueError: If ``input_dir`` is absent, ``null`` or blank.
        """
        # JSON ``null`` decodes to None; str(None) would yield the bogus
        # text "None", so missing/blank detection goes through _optional_text.
        input_dir = _optional_text(payload.get("input_dir"))
        if input_dir is None:
            raise ValueError("Representative run file is missing input_dir.")
        created_at = _optional_text(payload.get("created_at"))
        if created_at is None:
            created_at = datetime.now().isoformat(timespec="seconds")
        return cls(
            input_dir=input_dir,
            output_dir=_optional_text(payload.get("output_dir")),
            analysis_mode=_normalize_analysis_mode(
                payload.get("analysis_mode", "all")
            ),
            selected_stoichiometry=_optional_text(
                payload.get("selected_stoichiometry")
            ),
            overwrite_existing=bool(payload.get("overwrite_existing", False)),
            settings=representativefinder_settings_from_dict(
                payload.get("settings")
            ),
            created_at=created_at,
        )


@dataclass(slots=True, frozen=True)
class RepresentativeFinderRunTarget:
    """One stoichiometry folder to analyze and where its output goes."""

    inspection: RepresentativeFinderFolderInspection
    output_dir: Path
inspection: RepresentativeFinderFolderInspection + output_dir: Path + + +@dataclass(slots=True, frozen=True) +class RepresentativeFinderRunFailure: + input_dir: Path + structure_label: str + message: str + + +@dataclass(slots=True, frozen=True) +class RepresentativeFinderRunExecutionSummary: + project_dir: Path + run_file_path: Path | None + targets: tuple[RepresentativeFinderRunTarget, ...] + results: tuple[RepresentativeFinderResult, ...] + project_representative_paths: tuple[Path, ...] + failures: tuple[RepresentativeFinderRunFailure, ...] + skipped_existing: tuple[str, ...] + + @property + def completed_count(self) -> int: + return len(self.results) + + @property + def failed_count(self) -> int: + return len(self.failures) + + +def default_representativefinder_run_file_path( + project_dir: str | Path, +) -> Path: + return Path(project_dir).expanduser().resolve() / DEFAULT_RUN_FILE_NAME + + +def representativefinder_settings_from_dict( + payload: object, +) -> RepresentativeFinderSettings: + source = dict(payload) if isinstance(payload, dict) else {} + quantile_values = ( + source.get("quantiles") or RepresentativeFinderSettings().quantiles + ) + quantiles = tuple(float(value) for value in quantile_values) + return RepresentativeFinderSettings( + selection_algorithm=str( + source.get( + "selection_algorithm", + "target_distribution_quantile_distance", + ) + ).strip() + or "target_distribution_quantile_distance", + bond_weight=_float_value(source.get("bond_weight"), 1.0), + angle_weight=_float_value(source.get("angle_weight"), 1.0), + solvent_weight=_float_value(source.get("solvent_weight"), 1.0), + generate_predicted_optimized_representative=bool( + source.get("generate_predicted_optimized_representative", False) + ), + parallel_workers=_int_value(source.get("parallel_workers"), 0), + quantiles=quantiles or RepresentativeFinderSettings().quantiles, + bond_pairs=tuple( + _bond_pair_from_dict(entry) + for entry in source.get("bond_pairs", []) + if 
def save_representativefinder_run_config(
    output_path: str | Path,
    config: RepresentativeFinderRunConfig,
) -> Path:
    """Write *config* as indented JSON and return the resolved file path.

    Parent folders are created when missing.
    """
    destination = Path(output_path).expanduser().resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(config.to_dict(), indent=2) + "\n"
    destination.write_text(serialized, encoding="utf-8")
    return destination


def load_representativefinder_run_config(
    run_file_path: str | Path,
) -> RepresentativeFinderRunConfig:
    """Read a run file and return the config it describes.

    Raises:
        ValueError: If the file's top-level JSON value is not an object.
    """
    source = Path(run_file_path).expanduser().resolve()
    decoded = json.loads(source.read_text(encoding="utf-8"))
    if isinstance(decoded, dict):
        return RepresentativeFinderRunConfig.from_dict(decoded)
    raise ValueError(
        f"Representative run file must contain a JSON object: {source}"
    )


def path_text_for_run_config(
    path: str | Path | None,
    *,
    project_dir: str | Path,
) -> str | None:
    """Return *path* as text suitable for storing in a run file.

    Paths inside the project folder become POSIX-style relative paths;
    anything else is stored as an absolute path. ``None`` stays ``None``.
    """
    if path is None:
        return None
    project_root = Path(project_dir).expanduser().resolve()
    absolute = Path(path).expanduser().resolve()
    if absolute.is_relative_to(project_root):
        return absolute.relative_to(project_root).as_posix()
    return str(absolute)


def resolve_run_config_path(
    path_text: str | None,
    *,
    project_dir: str | Path,
) -> Path | None:
    """Turn stored run-file path text back into an absolute path.

    Blank or missing text yields ``None``; relative text is interpreted
    relative to the project folder.
    """
    cleaned = str(path_text or "").strip()
    if not cleaned:
        return None
    candidate = Path(cleaned).expanduser()
    if candidate.is_absolute():
        return candidate.resolve()
    project_root = Path(project_dir).expanduser().resolve()
    return (project_root / candidate).resolve()


def build_representativefinder_run_config(
    *,
    project_dir: str | Path,
    input_dir: str | Path,
    output_dir: str | Path | None,
    analysis_mode: str,
    settings: RepresentativeFinderSettings,
    selected_stoichiometry: str | None = None,
    overwrite_existing: bool = False,
) -> RepresentativeFinderRunConfig:
    """Assemble a run config, storing paths relative to the project."""
    input_text = path_text_for_run_config(input_dir, project_dir=project_dir)
    output_text = path_text_for_run_config(output_dir, project_dir=project_dir)
    return RepresentativeFinderRunConfig(
        input_dir=input_text or "",
        output_dir=output_text,
        analysis_mode=_normalize_analysis_mode(analysis_mode),
        selected_stoichiometry=_optional_text(selected_stoichiometry),
        overwrite_existing=bool(overwrite_existing),
        settings=settings,
    )
def suggest_run_config_output_dir(
    *,
    project_dir: str | Path,
    input_dir: str | Path,
    analysis_mode: str,
) -> Path:
    """Suggest an output folder for a run over *input_dir*.

    The run is treated as a batch when the mode is "all", when more than
    one stoichiometry was found, or when the input is not itself a
    stoichiometry folder; otherwise the single stoichiometry folder is
    used as the basis for the suggestion.
    """
    inspection = inspect_representative_structure_input(input_dir)

    is_batch = True
    if _normalize_analysis_mode(analysis_mode) != "all":
        is_batch = (
            inspection.stoichiometry_count > 1
            or not inspection.input_is_stoichiometry_folder
        )

    if is_batch:
        basis_dir = inspection.input_dir
    else:
        basis_dir = inspection.stoichiometry_folders[0].input_dir

    return suggest_representativefinder_output_dir(
        basis_dir,
        project_dir=project_dir,
        batch=is_batch,
    )


def representativefinder_run_targets(
    *,
    project_dir: str | Path,
    config: RepresentativeFinderRunConfig,
) -> tuple[tuple[RepresentativeFinderRunTarget, ...], tuple[str, ...]]:
    """Resolve a run config into concrete targets.

    Returns:
        ``(targets, skipped_existing)`` where ``skipped_existing`` holds
        the labels left out because the project already has a saved
        representative for them (only when ``overwrite_existing`` is off).

    Raises:
        ValueError: If the config has no usable ``input_dir``.
    """
    project_root = Path(project_dir).expanduser().resolve()
    input_dir = resolve_run_config_path(
        config.input_dir,
        project_dir=project_root,
    )
    if input_dir is None:
        raise ValueError("Representative run file is missing input_dir.")

    inspection = inspect_representative_structure_input(input_dir)
    candidates = _selected_stoichiometries_for_config(
        inspection.stoichiometry_folders,
        analysis_mode=config.analysis_mode,
        selected_stoichiometry=config.selected_stoichiometry,
    )

    skipped_labels: list[str] = []
    if not config.overwrite_existing:
        saved_labels = _saved_project_representative_labels(project_root)
        remaining = []
        for stoich in candidates:
            if stoich.structure_label in saved_labels:
                skipped_labels.append(stoich.structure_label)
            else:
                remaining.append(stoich)
        candidates = tuple(remaining)

    output_root = resolve_run_config_path(
        config.output_dir,
        project_dir=project_root,
    )
    if output_root is None:
        output_root = suggest_run_config_output_dir(
            project_dir=project_root,
            input_dir=input_dir,
            analysis_mode=config.analysis_mode,
        )

    # A lone stoichiometry folder writes straight into the output root;
    # otherwise every stoichiometry gets its own folder beneath it.
    write_directly = (
        inspection.input_is_stoichiometry_folder
        and inspection.stoichiometry_count == 1
    )
    targets = []
    for stoich in candidates:
        if write_directly:
            target_dir = output_root
        else:
            target_dir = suggest_representativefinder_target_output_dir(
                output_root,
                stoich.structure_label,
            )
        targets.append(
            RepresentativeFinderRunTarget(
                inspection=stoich,
                output_dir=target_dir,
            )
        )
    return tuple(targets), tuple(skipped_labels)
def run_representativefinder_run_config(
    project_dir: str | Path,
    config: RepresentativeFinderRunConfig,
    *,
    run_file_path: str | Path | None = None,
    log_callback: RepresentativeFinderRunLogCallback | None = None,
    progress_callback: RepresentativeFinderRunProgressCallback | None = None,
) -> RepresentativeFinderRunExecutionSummary:
    """Run every target of *config* and persist each result to the project.

    Targets are processed one after another. A target that raises is
    recorded as a failure and the run continues with the next target.

    Args:
        project_dir: Project folder; resolved before use.
        config: Run configuration to execute.
        run_file_path: Optional path of the run file, recorded in the
            summary only.
        log_callback: Optional sink for log lines.
        progress_callback: Optional sink for ``(processed, total, message)``
            progress updates; messages are prefixed with the target label.

    Returns:
        Summary with the targets, successful results, saved project
        representative paths, failures and skipped labels.
    """
    resolved_project_dir = Path(project_dir).expanduser().resolve()
    targets, skipped_existing = representativefinder_run_targets(
        project_dir=resolved_project_dir,
        config=config,
    )
    results: list[RepresentativeFinderResult] = []
    failures: list[RepresentativeFinderRunFailure] = []
    project_paths: list[Path] = []

    if skipped_existing:
        _emit_run_log(
            log_callback,
            "Skipping saved representative structures: "
            + ", ".join(skipped_existing),
        )
    if not targets:
        _emit_run_log(
            log_callback,
            "No representative-structure targets need to be run.",
        )
    target_count = len(targets)
    for index, target in enumerate(targets, start=1):
        label = target.inspection.structure_label
        _emit_run_log(
            log_callback,
            f"[{index}/{target_count}] Starting {label}.",
        )

        # Bind the label per iteration so a callback that is invoked late
        # still reports the target it was created for.
        def on_progress(
            processed: int,
            total: int,
            message: str,
            _label: str = label,
        ) -> None:
            if progress_callback is not None:
                progress_callback(processed, total, f"[{_label}] {message}")

        def on_log(message: str, _label: str = label) -> None:
            _emit_run_log(log_callback, f"[{_label}] {message}")

        try:
            result = analyze_representative_structure_folder(
                target.inspection.input_dir,
                settings=config.settings,
                output_dir=target.output_dir,
                project_dir=resolved_project_dir,
                progress_callback=on_progress,
                log_callback=on_log,
            )
            shared_path = persist_representativefinder_result_to_project(
                resolved_project_dir,
                result,
            )
        except Exception as exc:
            # Deliberate best-effort: one bad target must not abort the batch.
            failures.append(
                RepresentativeFinderRunFailure(
                    input_dir=target.inspection.input_dir,
                    structure_label=label,
                    message=str(exc),
                )
            )
            _emit_run_log(
                log_callback,
                f"[{label}] Failed representative selection: {exc}",
            )
            continue
        results.append(result)
        project_paths.append(shared_path)
        _emit_run_log(
            log_callback,
            f"[{label}] Project representative: {shared_path}",
        )

    # Expand "~" like every other path handled in this module.
    resolved_run_file = (
        None
        if run_file_path is None
        else Path(run_file_path).expanduser().resolve()
    )
    return RepresentativeFinderRunExecutionSummary(
        project_dir=resolved_project_dir,
        run_file_path=resolved_run_file,
        targets=targets,
        results=tuple(results),
        project_representative_paths=tuple(project_paths),
        failures=tuple(failures),
        skipped_existing=skipped_existing,
    )
representative-structure targets need to be run.", + ) + target_count = len(targets) + for index, target in enumerate(targets, start=1): + label = target.inspection.structure_label + _emit_run_log( + log_callback, + f"[{index}/{target_count}] Starting {label}.", + ) + + def on_progress( + processed: int, + total: int, + message: str, + ) -> None: + if progress_callback is not None: + progress_callback(processed, total, f"[{label}] {message}") + + def on_log(message: str) -> None: + _emit_run_log(log_callback, f"[{label}] {message}") + + try: + result = analyze_representative_structure_folder( + target.inspection.input_dir, + settings=config.settings, + output_dir=target.output_dir, + project_dir=resolved_project_dir, + progress_callback=on_progress, + log_callback=on_log, + ) + shared_path = persist_representativefinder_result_to_project( + resolved_project_dir, + result, + ) + except Exception as exc: + failures.append( + RepresentativeFinderRunFailure( + input_dir=target.inspection.input_dir, + structure_label=label, + message=str(exc), + ) + ) + _emit_run_log( + log_callback, + f"[{label}] Failed representative selection: {exc}", + ) + continue + results.append(result) + project_paths.append(shared_path) + _emit_run_log( + log_callback, + f"[{label}] Project representative: {shared_path}", + ) + + return RepresentativeFinderRunExecutionSummary( + project_dir=resolved_project_dir, + run_file_path=( + None if run_file_path is None else Path(run_file_path).resolve() + ), + targets=targets, + results=tuple(results), + project_representative_paths=tuple(project_paths), + failures=tuple(failures), + skipped_existing=skipped_existing, + ) + + +def _selected_stoichiometries_for_config( + stoichiometries: tuple[RepresentativeFinderFolderInspection, ...], + *, + analysis_mode: str, + selected_stoichiometry: str | None, +) -> tuple[RepresentativeFinderFolderInspection, ...]: + if not stoichiometries: + raise ValueError("No stoichiometry folders were found.") + if 
_normalize_analysis_mode(analysis_mode) == "all": + return stoichiometries + selected_label = str(selected_stoichiometry or "").strip() + if selected_label: + for stoich in stoichiometries: + if stoich.structure_label == selected_label: + return (stoich,) + raise ValueError( + "Selected stoichiometry was not found in the input folder: " + f"{selected_label}" + ) + return (stoichiometries[0],) + + +def _saved_project_representative_labels(project_dir: Path) -> set[str]: + paths = ensure_rmcsetup_structure(project_dir) + metadata = load_representative_selection_metadata( + paths.representative_selection_path + ) + if metadata is None: + return set() + labels: set[str] = set() + for entry in metadata.representative_entries: + source_file = str(entry.source_file or "").strip() + if not source_file: + continue + if not Path(source_file).expanduser().resolve().is_file(): + continue + structure = str(entry.structure or "").strip() + if structure: + labels.add(structure) + return labels + + +def _bond_pair_from_dict(payload: dict[str, object]) -> BondPairDefinition: + return BondPairDefinition( + str(payload["atom1"]), + str(payload["atom2"]), + float(payload["cutoff_angstrom"]), + ) + + +def _angle_triplet_from_dict( + payload: dict[str, object], +) -> AngleTripletDefinition: + return AngleTripletDefinition( + str(payload["vertex"]), + str(payload["arm1"]), + str(payload["arm2"]), + float(payload["cutoff1_angstrom"]), + float(payload["cutoff2_angstrom"]), + ) + + +def _normalize_analysis_mode(value: object) -> str: + text = str(value or "").strip().lower() + return "single" if text == "single" else "all" + + +def _optional_text(value: object) -> str | None: + if value is None: + return None + text = str(value).strip() + return text or None + + +def _float_value(value: object, default: float) -> float: + if value is None: + return float(default) + text = str(value).strip() + if not text: + return float(default) + return float(text) + + +def _int_value(value: object, 
default: int) -> int: + if value is None: + return int(default) + text = str(value).strip() + if not text: + return int(default) + return int(text) + + +def _emit_run_log( + callback: RepresentativeFinderRunLogCallback | None, + message: str, +) -> None: + if callback is not None: + callback(str(message).strip()) + + +__all__ = [ + "DEFAULT_RUN_FILE_NAME", + "RepresentativeFinderRunConfig", + "RepresentativeFinderRunExecutionSummary", + "RepresentativeFinderRunFailure", + "RepresentativeFinderRunTarget", + "build_representativefinder_run_config", + "default_representativefinder_run_file_path", + "load_representativefinder_run_config", + "path_text_for_run_config", + "representativefinder_run_targets", + "representativefinder_settings_from_dict", + "resolve_run_config_path", + "run_representativefinder_run_config", + "save_representativefinder_run_config", + "suggest_run_config_output_dir", +] diff --git a/src/saxshell/representativefinder/ui/__init__.py b/src/saxshell/representativefinder/ui/__init__.py new file mode 100644 index 0000000..5122948 --- /dev/null +++ b/src/saxshell/representativefinder/ui/__init__.py @@ -0,0 +1,11 @@ +"""Qt UI for representative-structure screening.""" + +from .main_window import ( + RepresentativeStructureFinderMainWindow, + launch_representativefinder_ui, +) + +__all__ = [ + "RepresentativeStructureFinderMainWindow", + "launch_representativefinder_ui", +] diff --git a/src/saxshell/representativefinder/ui/main_window.py b/src/saxshell/representativefinder/ui/main_window.py new file mode 100644 index 0000000..b91e872 --- /dev/null +++ b/src/saxshell/representativefinder/ui/main_window.py @@ -0,0 +1,3166 @@ +from __future__ import annotations + +import subprocess +import sys +from dataclasses import dataclass, replace +from pathlib import Path + +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.backends.backend_qtagg import ( + NavigationToolbar2QT as NavigationToolbar, +) +from 
matplotlib.figure import Figure +from PySide6.QtCore import QObject, Qt, QThread, QTimer, QUrl, Signal, Slot +from PySide6.QtGui import QDesktopServices +from PySide6.QtWidgets import ( + QAbstractItemView, + QApplication, + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QHeaderView, + QInputDialog, + QLabel, + QLineEdit, + QMainWindow, + QMessageBox, + QPlainTextEdit, + QProgressBar, + QProgressDialog, + QPushButton, + QScrollArea, + QSplitter, + QTableWidget, + QTableWidgetItem, + QVBoxLayout, + QWidget, +) + +from saxshell.bondanalysis import ( + AngleTripletDefinition, + BondAnalysisPreset, + BondPairDefinition, + load_presets, + ordered_preset_names, + save_custom_preset, +) +from saxshell.representativefinder.workflow import ( + RepresentativeFinderCandidate, + RepresentativeFinderFolderInspection, + RepresentativeFinderInputInspection, + RepresentativeFinderOperationCancelled, + RepresentativeFinderPlotSeries, + RepresentativeFinderResult, + RepresentativeFinderSettings, + analyze_representative_structure_folder, + estimate_representativefinder_total_work, + inspect_representative_structure_input, + load_representativefinder_result, + persist_representativefinder_result_to_project, + suggest_representativefinder_output_dir, + suggest_representativefinder_target_output_dir, +) +from saxshell.saxs.electron_density_mapping.ui.viewer import ( + ElectronDensityStructureViewer, +) +from saxshell.saxs.electron_density_mapping.workflow import ( + ElectronDensityMeshGeometry, + ElectronDensityStructure, + build_electron_density_mesh, + legacy_born_average_default_mesh_settings, + load_electron_density_structure, +) +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) + +_ALGORITHM_ITEMS = [ + ( + "Quantile Distance (Recommended)", + "target_distribution_quantile_distance", + ), + ("Mean/Std Distance", 
"target_distribution_moment_distance"), +] +_ANALYSIS_MODE_ITEMS = [ + ("Selected Stoichiometry Only", "single"), + ("All Discovered Stoichiometries", "all"), +] +_DISPLAY_MODE_ITEMS = [ + ("Selected Candidate", "selected_candidate"), + ("Observed Representative", "observed_representative"), + ( + "Predicted Optimized Representative", + "predicted_optimized_representative", + ), + ( + "Solvent-completed Predicted Representative", + "solvent_completed_predicted_representative", + ), +] + + +@dataclass(slots=True) +class RepresentativeFinderAnalysisTarget: + inspection: RepresentativeFinderFolderInspection + output_dir: Path + estimated_total_work: int + + +@dataclass(slots=True) +class RepresentativeFinderJobConfig: + analysis_mode: str + targets: tuple[RepresentativeFinderAnalysisTarget, ...] + settings: RepresentativeFinderSettings + project_dir: Path | None = None + + +@dataclass(slots=True, frozen=True) +class RepresentativeFinderTargetFailure: + input_dir: Path + structure_label: str + message: str + + +@dataclass(slots=True, frozen=True) +class RepresentativeFinderRunSummary: + analysis_mode: str + targets: tuple[RepresentativeFinderAnalysisTarget, ...] + results: tuple[RepresentativeFinderResult, ...] + failures: tuple[RepresentativeFinderTargetFailure, ...] 
class RepresentativeFinderWorker(QObject):
    """Run representative-structure selection for each configured target.

    ``run`` walks ``config.targets`` in order, calls
    ``analyze_representative_structure_folder`` once per target and
    reports everything through the Qt signals declared below.

    NOTE(review): presumably this object is moved onto a ``QThread`` by
    the main window; that wiring is not part of this class -- confirm
    against the caller.
    """

    # One log line; prefixed with "[<structure label>] " when the job
    # has more than one target.
    log = Signal(str)
    # (completed work, total work, status text) on the job-wide scale,
    # i.e. summed over all targets rather than per target.
    progress = Signal(int, int, str)
    # str(target.inspection.input_dir) of the target about to be analyzed.
    target_started = Signal(str)
    # The value returned by analyze_representative_structure_folder.
    result_ready = Signal(object)
    # A RepresentativeFinderTargetFailure for a target that raised.
    target_failed = Signal(object)
    # A RepresentativeFinderRunSummary once every target was attempted.
    finished = Signal(object)
    # Error text: unexpected exception, or failure of the only target.
    failed = Signal(str)
    # Emitted when RepresentativeFinderOperationCancelled ends the run.
    canceled = Signal()

    def __init__(self, config: RepresentativeFinderJobConfig) -> None:
        """Store the job configuration and clear the cancel flag."""
        super().__init__()
        self.config = config
        self._cancel_requested = False

    @Slot()
    def cancel(self) -> None:
        """Request cancellation; honored at the next cancel check."""
        self._cancel_requested = True

    def is_cancel_requested(self) -> bool:
        """Return True if ``cancel`` was called or the current thread
        has an interruption request pending."""
        return (
            self._cancel_requested
            or QThread.currentThread().isInterruptionRequested()
        )

    @Slot()
    def run(self) -> None:
        """Analyze every target and emit exactly one terminal signal.

        Terminal signals:

        * ``finished`` -- all targets were attempted; the summary lists
          successful results and per-target failures.
        * ``failed`` -- an unexpected exception occurred, or the job had
          a single target and that target raised.
        * ``canceled`` -- cancellation was detected before a target
          started or was raised by the analysis call.
        """
        try:
            results: list[RepresentativeFinderResult] = []
            failures: list[RepresentativeFinderTargetFailure] = []
            target_count = len(self.config.targets)
            # Clamp to at least 1 so the progress total is never zero.
            global_total_work = max(
                sum(
                    target.estimated_total_work
                    for target in self.config.targets
                ),
                1,
            )
            # Work units of targets already finished (or failed); read
            # by on_progress at call time to offset per-target progress.
            completed_work = 0

            for index, target in enumerate(self.config.targets, start=1):
                if self.is_cancel_requested():
                    raise RepresentativeFinderOperationCancelled(
                        "Representative-structure analysis canceled."
                    )
                target_key = str(target.inspection.input_dir)
                target_label = target.inspection.structure_label
                # Prefixes are only added for multi-target jobs.
                target_prefix = (
                    f"[{index}/{target_count}] {target_label}: "
                    if target_count > 1
                    else ""
                )
                log_prefix = f"[{target_label}] " if target_count > 1 else ""
                self.target_started.emit(target_key)
                self.log.emit(
                    f"{log_prefix}Starting representative selection for "
                    f"{target_label}."
                )

                def on_progress(
                    processed: int,
                    total: int,
                    message: str,
                ) -> None:
                    # The total reported by the analysis call is ignored;
                    # progress is scaled against estimated_total_work.
                    del total
                    bounded = min(
                        max(int(processed), 0),
                        max(target.estimated_total_work, 1),
                    )
                    self.progress.emit(
                        min(completed_work + bounded, global_total_work),
                        global_total_work,
                        f"{target_prefix}{message}",
                    )

                def on_log(message: str) -> None:
                    self.log.emit(f"{log_prefix}{message}")

                try:
                    result = analyze_representative_structure_folder(
                        target.inspection.input_dir,
                        settings=self.config.settings,
                        output_dir=target.output_dir,
                        project_dir=self.config.project_dir,
                        progress_callback=on_progress,
                        log_callback=on_log,
                        cancel_callback=self.is_cancel_requested,
                    )
                except RepresentativeFinderOperationCancelled:
                    # Cancellation must reach the outer handler instead
                    # of being recorded as a per-target failure.
                    raise
                except Exception as exc:
                    failure = RepresentativeFinderTargetFailure(
                        input_dir=target.inspection.input_dir,
                        structure_label=target_label,
                        message=str(exc),
                    )
                    failures.append(failure)
                    self.target_failed.emit(failure)
                    # Count the failed target as done so the bar advances.
                    completed_work += target.estimated_total_work
                    self.progress.emit(
                        min(completed_work, global_total_work),
                        global_total_work,
                        f"{target_prefix}failed",
                    )
                    if target_count == 1:
                        # Single-target job: report an overall failure;
                        # `finished` is not emitted on this path.
                        self.failed.emit(str(exc))
                        return
                    continue

                results.append(result)
                self.result_ready.emit(result)
                completed_work += target.estimated_total_work

            completion_message = (
                "Representative selection complete."
                if not failures
                else (
                    "Representative selection complete with "
                    f"{len(failures)} failed stoichiometry run(s)."
                )
            )
            self.progress.emit(
                global_total_work, global_total_work, completion_message
            )
            self.finished.emit(
                RepresentativeFinderRunSummary(
                    analysis_mode=self.config.analysis_mode,
                    targets=self.config.targets,
                    results=tuple(results),
                    failures=tuple(failures),
                )
            )
        except RepresentativeFinderOperationCancelled:
            self.canceled.emit()
        except Exception as exc:
            self.failed.emit(str(exc))
+ self._candidate = None + self._set_plot_series(()) + self._draw_message( + "Run representative selection to compare the folder-wide bond and angle distributions with one candidate structure.", + secondary=( + "Use the distribution selector or Previous/Next buttons to inspect one computed distribution at a time." + ), + ) + + def set_result( + self, + result: RepresentativeFinderResult | None, + *, + candidate: RepresentativeFinderCandidate | None = None, + ) -> None: + previous_label = self._current_distribution_label() + self._result = result + self._candidate = candidate or ( + None if result is None else result.selected_candidate + ) + if self._result is None or self._candidate is None: + self._set_plot_series(()) + else: + filtered_series = tuple( + series + for series in self._result.plot_series_for_candidate( + self._candidate + ) + if series.distribution_values.size > 0 + or bool(series.candidate_values) + ) + self._set_plot_series( + filtered_series, + selected_label=previous_label, + ) + self.refresh_plot() + + def refresh_plot(self) -> None: + self.figure.clear() + if self._result is None or self._candidate is None: + self._draw_message( + "Run representative selection to compare the folder-wide bond and angle distributions with one candidate structure.", + secondary=( + "Use the distribution selector or Previous/Next buttons to inspect one computed distribution at a time." + ), + ) + return + if not self._plot_series: + self._draw_message( + "No computed bond or angle distributions are available for the active stoichiometry and candidate.", + secondary=( + "Only distributions with measured values for the active stoichiometry are listed in the selector." 
+ ), + ) + return + self._selected_series_index = min( + max(self._selected_series_index, 0), + len(self._plot_series) - 1, + ) + series = self._plot_series[self._selected_series_index] + axis = self.figure.add_subplot(111) + self._draw_series(axis, series) + self.figure.suptitle( + f"Distribution Comparison • {self._candidate.file_name}", + y=0.995, + ) + self.figure.tight_layout(rect=(0.0, 0.0, 1.0, 0.975)) + self.canvas.draw_idle() + + def _set_plot_series( + self, + plot_series: tuple[RepresentativeFinderPlotSeries, ...], + *, + selected_label: str | None = None, + ) -> None: + self._plot_series = tuple(plot_series) + target_index = 0 + if self._plot_series and selected_label: + for index, series in enumerate(self._plot_series): + if self._series_selector_label(series) == selected_label: + target_index = index + break + self._selected_series_index = min( + max(target_index, 0), + max(len(self._plot_series) - 1, 0), + ) + self.distribution_selector_combo.blockSignals(True) + self.distribution_selector_combo.clear() + for series in self._plot_series: + self.distribution_selector_combo.addItem( + self._series_selector_label(series) + ) + if self._plot_series: + self.distribution_selector_combo.setCurrentIndex( + self._selected_series_index + ) + self.distribution_selector_combo.blockSignals(False) + self._update_distribution_controls() + + def _update_distribution_controls(self) -> None: + series_count = len(self._plot_series) + has_series = series_count > 0 + self.distribution_selector_combo.setEnabled(has_series) + self.previous_distribution_button.setEnabled(series_count > 1) + self.next_distribution_button.setEnabled(series_count > 1) + + def _current_distribution_label(self) -> str | None: + if not self._plot_series: + return None + return self._series_selector_label( + self._plot_series[ + min( + max(self._selected_series_index, 0), + len(self._plot_series) - 1, + ) + ] + ) + + @staticmethod + def _series_selector_label(series: 
RepresentativeFinderPlotSeries) -> str: + prefix = "Bond" if series.category == "bond" else "Angle" + return f"{prefix}: {series.display_label}" + + def _draw_message( + self, + message: str, + *, + secondary: str | None = None, + ) -> None: + self.figure.clear() + axis = self.figure.add_subplot(111) + axis.text( + 0.5, + 0.56 if secondary else 0.5, + message, + ha="center", + va="center", + wrap=True, + transform=axis.transAxes, + ) + if secondary: + axis.text( + 0.5, + 0.40, + secondary, + ha="center", + va="center", + wrap=True, + transform=axis.transAxes, + alpha=0.8, + ) + axis.set_axis_off() + self.canvas.draw_idle() + + def _on_distribution_changed(self, index: int) -> None: + if index < 0 or index >= len(self._plot_series): + return + self._selected_series_index = index + self.refresh_plot() + + def _select_previous_distribution(self) -> None: + if len(self._plot_series) <= 1: + return + new_index = (self._selected_series_index - 1) % len(self._plot_series) + self.distribution_selector_combo.setCurrentIndex(new_index) + + def _select_next_distribution(self) -> None: + if len(self._plot_series) <= 1: + return + new_index = (self._selected_series_index + 1) % len(self._plot_series) + self.distribution_selector_combo.setCurrentIndex(new_index) + + def _draw_series( + self, axis, series: RepresentativeFinderPlotSeries + ) -> None: + if series.distribution_values.size > 0: + axis.hist( + series.distribution_values, + bins=60, + color="#355070" if series.category == "bond" else "#bc6c25", + edgecolor="white", + alpha=0.88, + ) + else: + axis.text( + 0.5, + 0.5, + "No distribution values were available for this definition.", + ha="center", + va="center", + wrap=True, + transform=axis.transAxes, + ) + for index, value in enumerate(series.candidate_values): + axis.axvline( + value, + color="black", + linestyle="--", + linewidth=1.2, + label=("Candidate value" if index == 0 else None), + ) + axis.set_title(series.display_label) + axis.set_xlabel(series.xlabel) + 
axis.set_ylabel("Count") + if series.candidate_values: + axis.legend(frameon=False, loc="upper right") + + +class RepresentativeStructureFinderMainWindow(QMainWindow): + project_results_changed = Signal(str) + + def __init__( + self, + *, + initial_project_dir: str | Path | None = None, + initial_input_path: str | Path | None = None, + ) -> None: + super().__init__() + self._initial_project_dir = ( + None + if initial_project_dir is None + else Path(initial_project_dir).expanduser().resolve() + ) + self._browse_start_dir = ( + self._initial_project_dir + if self._initial_project_dir is not None + else Path.home() + ) + self._last_suggested_output_dir: str | None = None + self._presets: dict[str, BondAnalysisPreset] = {} + self._input_inspection: RepresentativeFinderInputInspection | None = ( + None + ) + self._analysis_results_by_input_dir: dict[ + str, RepresentativeFinderResult + ] = {} + self._analysis_failures_by_input_dir: dict[str, str] = {} + self._stoichiometry_row_by_input_dir: dict[str, int] = {} + self._active_stoichiometry_key: str | None = None + self._run_summary: RepresentativeFinderRunSummary | None = None + self._viewer_scene_payload_by_path: dict[ + str, + tuple[ + ElectronDensityStructure, ElectronDensityMeshGeometry | None + ], + ] = {} + self._shared_project_representative_path_by_input_dir: dict[ + str, Path + ] = {} + self._shared_project_representative_entry_by_input_dir: dict[ + str, + object, + ] = {} + self._analysis_thread: QThread | None = None + self._analysis_worker: RepresentativeFinderWorker | None = None + self._closing_after_analysis_cancel = False + + self.setWindowTitle("Representative Structures") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1380, 900) + self._build_ui() + self._refresh_analysis_mode_ui() + self._reload_presets() + + if initial_input_path is not None: + resolved_input_path = ( + Path(initial_input_path).expanduser().resolve() + ) + if resolved_input_path.is_dir(): + self._browse_start_dir = 
resolved_input_path + self.input_dir_edit.setText(str(resolved_input_path)) + self._refresh_input_preview() + restored_from_session = self._restore_project_session_state() + if not restored_from_session: + self._restore_project_cached_results_with_startup_progress() + + def closeEvent(self, event) -> None: + if ( + self._analysis_thread is not None + and self._analysis_thread.isRunning() + ): + self._cancel_analysis_for_close() + event.ignore() + return + self._save_project_session_state() + super().closeEvent(event) + + def _cancel_analysis(self) -> None: + worker = self._analysis_worker + if worker is not None: + worker.cancel() + thread = self._analysis_thread + if thread is not None and thread.isRunning(): + thread.requestInterruption() + thread.quit() + + def _cancel_analysis_for_close(self) -> None: + if self._closing_after_analysis_cancel: + return + self._closing_after_analysis_cancel = True + self.run_status_label.setText( + "Representative selection: canceling so the window can close..." + ) + self.statusBar().showMessage("Stopping representative selection...") + self._append_console( + "Close requested. Canceling representative-structure analysis." + ) + self._cancel_analysis() + self.setEnabled(False) + self.hide() + + def _build_ui(self) -> None: + central = QWidget(self) + self.setCentralWidget(central) + root = QVBoxLayout(central) + root.setContentsMargins(10, 10, 10, 10) + root.setSpacing(8) + + intro_label = QLabel( + "Select one representative structure per stoichiometry using the shared bondanalysis preset workflow and solvent-aware scoring from the contrast descriptor backend. Saved project representatives can then be reused by compatible SAXS and RMCSetup tools." 
+ ) + intro_label.setWordWrap(True) + root.addWidget(intro_label) + + self._pane_splitter = QSplitter(Qt.Orientation.Horizontal, self) + self._pane_splitter.setChildrenCollapsible(False) + self._pane_splitter.setStretchFactor(0, 0) + self._pane_splitter.setStretchFactor(1, 1) + root.addWidget(self._pane_splitter, stretch=1) + + self._left_scroll = QScrollArea(self) + self._left_scroll.setWidgetResizable(True) + self._left_scroll.setHorizontalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAlwaysOff + ) + self._right_scroll = QScrollArea(self) + self._right_scroll.setWidgetResizable(True) + self._right_scroll.setHorizontalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAlwaysOff + ) + + self._left_panel = QWidget() + self._right_panel = QWidget() + self._left_layout = QVBoxLayout(self._left_panel) + self._left_layout.setContentsMargins(10, 10, 10, 10) + self._left_layout.setSpacing(10) + self._right_layout = QVBoxLayout(self._right_panel) + self._right_layout.setContentsMargins(10, 10, 10, 10) + self._right_layout.setSpacing(10) + self._left_scroll.setWidget(self._left_panel) + self._right_scroll.setWidget(self._right_panel) + self._pane_splitter.addWidget(self._left_scroll) + self._pane_splitter.addWidget(self._right_scroll) + self._pane_splitter.setSizes([430, 900]) + + self._build_left_panel() + self._build_right_panel() + + self.statusBar().showMessage("Ready") + + def _build_left_panel(self) -> None: + self._left_layout.addWidget(self._build_input_group()) + self._left_layout.addWidget(self._build_preset_group()) + self._left_layout.addWidget(self._build_bond_pairs_group()) + self._left_layout.addWidget(self._build_angle_triplets_group()) + self._left_layout.addWidget(self._build_advanced_group()) + self._left_layout.addWidget(self._build_solvent_shell_group()) + self._left_layout.addWidget(self._build_run_group()) + self._left_layout.addStretch(1) + + def _build_right_panel(self) -> None: + self._right_splitter = QSplitter( + Qt.Orientation.Vertical, + 
self._right_panel, + ) + self._right_splitter.setChildrenCollapsible(False) + self._right_layout.addWidget(self._right_splitter, stretch=1) + + self._stoichiometry_group = self._build_stoichiometry_group() + self._result_summary_group = self._build_result_summary_group() + self._candidate_scores_group = self._build_candidate_scores_group() + self._plot_group = self._build_plot_group() + self._viewer_group = self._build_viewer_group() + + for widget in ( + self._stoichiometry_group, + self._result_summary_group, + self._candidate_scores_group, + self._plot_group, + self._viewer_group, + ): + self._right_splitter.addWidget(widget) + + self._right_splitter.setStretchFactor(0, 3) + self._right_splitter.setStretchFactor(1, 2) + self._right_splitter.setStretchFactor(2, 2) + self._right_splitter.setStretchFactor(3, 3) + self._right_splitter.setStretchFactor(4, 3) + self._apply_initial_right_splitter_sizes() + QTimer.singleShot(0, self._apply_initial_right_splitter_sizes) + + def _apply_initial_right_splitter_sizes(self) -> None: + if not hasattr(self, "_right_splitter"): + return + self._right_splitter.setSizes([420, 210, 240, 320, 300]) + + def _build_input_group(self) -> QGroupBox: + group = QGroupBox("Input Settings") + layout = QVBoxLayout(group) + form = QFormLayout() + + if self._initial_project_dir is not None: + self.project_dir_edit = QLineEdit(str(self._initial_project_dir)) + self.project_dir_edit.setReadOnly(True) + form.addRow("Project folder", self.project_dir_edit) + else: + self.project_dir_edit = None + + input_row = QHBoxLayout() + self.input_dir_edit = QLineEdit() + self.input_dir_edit.setPlaceholderText( + "Choose one stoichiometry folder, or a parent folder whose immediate subfolders are stoichiometries..." 
+ ) + self.input_dir_edit.editingFinished.connect( + self._refresh_input_preview + ) + input_row.addWidget(self.input_dir_edit, stretch=1) + self.browse_input_button = QPushButton("Browse...") + self.browse_input_button.clicked.connect(self._browse_input_dir) + input_row.addWidget(self.browse_input_button) + input_widget = QWidget() + input_widget.setLayout(input_row) + form.addRow("Input folder", input_widget) + + self.analysis_mode_combo = QComboBox() + for label, value in _ANALYSIS_MODE_ITEMS: + self.analysis_mode_combo.addItem(label, value) + self.analysis_mode_combo.currentIndexChanged.connect( + self._refresh_analysis_mode_ui + ) + form.addRow("Analysis mode", self.analysis_mode_combo) + + output_row = QHBoxLayout() + self.output_dir_edit = QLineEdit() + self.output_dir_edit.setPlaceholderText( + "Representative output folder or batch output root" + ) + output_row.addWidget(self.output_dir_edit, stretch=1) + self.browse_output_button = QPushButton("Browse...") + self.browse_output_button.clicked.connect(self._browse_output_dir) + output_row.addWidget(self.browse_output_button) + output_widget = QWidget() + output_widget.setLayout(output_row) + form.addRow("Output folder", output_widget) + + layout.addLayout(form) + self.input_preview_box = QPlainTextEdit() + self.input_preview_box.setReadOnly(True) + self.input_preview_box.setMinimumHeight(130) + layout.addWidget(self.input_preview_box) + return group + + def _build_preset_group(self) -> QGroupBox: + group = QGroupBox("Bondanalysis Presets") + layout = QVBoxLayout(group) + row = QHBoxLayout() + self.preset_combo = QComboBox() + row.addWidget(self.preset_combo, stretch=1) + self.load_preset_button = QPushButton("Load") + self.load_preset_button.clicked.connect(self._load_selected_preset) + row.addWidget(self.load_preset_button) + self.save_preset_button = QPushButton("Save Current As...") + self.save_preset_button.clicked.connect(self._save_current_preset) + row.addWidget(self.save_preset_button) + 
layout.addLayout(row) + hint = QLabel( + "Uses the same preset file as bondanalysis so the same bond-pair and angle-triplet definitions can drive representative selection here." + ) + hint.setWordWrap(True) + layout.addWidget(hint) + return group + + def _build_bond_pairs_group(self) -> QGroupBox: + group = QGroupBox("Bond Pairs") + layout = QVBoxLayout(group) + controls = QHBoxLayout() + self.add_bond_pair_button = QPushButton("Add Bond Pair") + self.add_bond_pair_button.clicked.connect(self._add_bond_pair_row) + controls.addWidget(self.add_bond_pair_button) + self.remove_bond_pair_button = QPushButton("Remove Selected") + self.remove_bond_pair_button.clicked.connect( + self._remove_selected_bond_pair_rows + ) + controls.addWidget(self.remove_bond_pair_button) + controls.addStretch(1) + layout.addLayout(controls) + self.bond_pair_table = QTableWidget(0, 3) + self.bond_pair_table.setHorizontalHeaderLabels( + ["Atom 1", "Atom 2", "Cutoff (A)"] + ) + self.bond_pair_table.horizontalHeader().setStretchLastSection(True) + self.bond_pair_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + layout.addWidget(self.bond_pair_table) + self._add_empty_bond_pair_row(blocked=True) + return group + + def _build_angle_triplets_group(self) -> QGroupBox: + group = QGroupBox("Angle Triplets") + layout = QVBoxLayout(group) + controls = QHBoxLayout() + self.add_angle_triplet_button = QPushButton("Add Angle Triplet") + self.add_angle_triplet_button.clicked.connect( + self._add_angle_triplet_row + ) + controls.addWidget(self.add_angle_triplet_button) + self.remove_angle_triplet_button = QPushButton("Remove Selected") + self.remove_angle_triplet_button.clicked.connect( + self._remove_selected_angle_triplet_rows + ) + controls.addWidget(self.remove_angle_triplet_button) + controls.addStretch(1) + layout.addLayout(controls) + self.angle_triplet_table = QTableWidget(0, 5) + self.angle_triplet_table.setHorizontalHeaderLabels( + [ + "Vertex", + "Arm 1", + "Arm 2", 
+ "Vertex-Arm 1 Cutoff (A)", + "Vertex-Arm 2 Cutoff (A)", + ] + ) + self.angle_triplet_table.horizontalHeader().setStretchLastSection(True) + self.angle_triplet_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + layout.addWidget(self.angle_triplet_table) + self._add_empty_angle_triplet_row(blocked=True) + return group + + def _build_advanced_group(self) -> QGroupBox: + group = QGroupBox("Advanced Scoring") + layout = QFormLayout(group) + self.algorithm_combo = QComboBox() + for label, value in _ALGORITHM_ITEMS: + self.algorithm_combo.addItem(label, value) + layout.addRow("Bond/angle distance", self.algorithm_combo) + self.bond_weight_spin = self._new_float_spin( + maximum=100.0, + step=0.1, + decimals=3, + value=1.0, + ) + layout.addRow("Bond weight", self.bond_weight_spin) + self.angle_weight_spin = self._new_float_spin( + maximum=100.0, + step=0.1, + decimals=3, + value=1.0, + ) + layout.addRow("Angle weight", self.angle_weight_spin) + self.solvent_weight_spin = self._new_float_spin( + maximum=100.0, + step=0.1, + decimals=3, + value=1.0, + ) + layout.addRow("Solvent weight", self.solvent_weight_spin) + hint = QLabel( + "Solvent scoring is inferred automatically from the selected stoichiometry folder name and the solvent-shell metrics measured from each candidate structure." + ) + hint.setWordWrap(True) + layout.addRow("", hint) + self.generate_predicted_checkbox = QCheckBox( + "Generate Predicted Optimized Representative" + ) + self.generate_predicted_checkbox.setChecked(False) + layout.addRow("", self.generate_predicted_checkbox) + predicted_hint = QLabel( + "Optional: generate a synthetic optimized representative alongside the observed representative. When project solvent settings are available, the tool will also attempt a solvent-completed predicted output." 
+ ) + predicted_hint.setWordWrap(True) + layout.addRow("", predicted_hint) + return group + + def _build_run_group(self) -> QGroupBox: + group = QGroupBox("Run") + layout = QVBoxLayout(group) + self.overwrite_existing_checkbox = QCheckBox( + "Overwrite Existing Representative Structures" + ) + self.overwrite_existing_checkbox.setChecked(False) + self.overwrite_existing_checkbox.setToolTip( + "When unchecked, stoichiometries with saved project " + "representatives are skipped instead of recalculated." + ) + layout.addWidget(self.overwrite_existing_checkbox) + button_row = QHBoxLayout() + self.run_button = QPushButton("Analyze Representative Structure") + self.run_button.clicked.connect(self._run_analysis) + button_row.addWidget(self.run_button) + self.open_output_button = QPushButton("Show Output Path") + self.open_output_button.clicked.connect(self._show_output_folder) + self.open_output_button.setEnabled(False) + button_row.addWidget(self.open_output_button) + button_row.addStretch(1) + layout.addLayout(button_row) + self.run_status_label = QLabel("Representative selection: idle") + self.run_status_label.setWordWrap(True) + layout.addWidget(self.run_status_label) + self.run_progress_bar = QProgressBar() + self.run_progress_bar.setRange(0, 1) + self.run_progress_bar.setValue(0) + layout.addWidget(self.run_progress_bar) + self.console_box = QPlainTextEdit() + self.console_box.setReadOnly(True) + self.console_box.setMinimumHeight(220) + layout.addWidget(self.console_box) + return group + + def _build_solvent_shell_group(self) -> QGroupBox: + group = QGroupBox("Build Solvent Shell") + layout = QVBoxLayout(group) + self.solvent_shell_toggle_button = QPushButton( + "Show Solvent Shell Builder Options" + ) + self.solvent_shell_toggle_button.setCheckable(True) + self.solvent_shell_toggle_button.setChecked(False) + self.solvent_shell_toggle_button.toggled.connect( + self._toggle_solvent_shell_options + ) + layout.addWidget(self.solvent_shell_toggle_button) + + 
self.solvent_shell_body = QWidget(group) + body_layout = QVBoxLayout(self.solvent_shell_body) + body_layout.setContentsMargins(0, 0, 0, 0) + body_layout.setSpacing(8) + hint = QLabel( + "Open the shared solvent-shell builder preloaded with the active representative structure. Use this after representative selection when the stored source representative does not yet contain the solvent shell you want to preserve." + ) + hint.setWordWrap(True) + body_layout.addWidget(hint) + self.solvent_shell_status_label = QLabel( + "Select or compute a representative structure to enable this handoff." + ) + self.solvent_shell_status_label.setWordWrap(True) + body_layout.addWidget(self.solvent_shell_status_label) + button_row = QHBoxLayout() + self.open_selected_solvent_shell_button = QPushButton( + "Open for Selected Representative" + ) + self.open_selected_solvent_shell_button.clicked.connect( + self._open_solvent_shell_builder_for_selected_representative + ) + button_row.addWidget(self.open_selected_solvent_shell_button) + button_row.addStretch(1) + body_layout.addLayout(button_row) + self.solvent_shell_body.setVisible(False) + layout.addWidget(self.solvent_shell_body) + self._refresh_solvent_shell_controls() + return group + + def _build_result_summary_group(self) -> QGroupBox: + group = QGroupBox("Selected Stoichiometry Summary") + layout = QVBoxLayout(group) + selector_row = QHBoxLayout() + selector_row.addWidget(QLabel("Displayed structure")) + self.display_mode_combo = QComboBox() + self.display_mode_combo.setEnabled(False) + self.display_mode_combo.currentIndexChanged.connect( + self._update_selected_candidate_view + ) + selector_row.addWidget(self.display_mode_combo, stretch=1) + layout.addLayout(selector_row) + self.result_summary_box = QPlainTextEdit() + self.result_summary_box.setReadOnly(True) + self.result_summary_box.setMinimumHeight(180) + self.result_summary_box.setPlainText( + "Select a stoichiometry row to inspect it. 
Completed runs will populate the summary, candidate scores, plots, and representative viewer." + ) + layout.addWidget(self.result_summary_box) + return group + + def _build_stoichiometry_group(self) -> QGroupBox: + group = QGroupBox("Stoichiometry Results") + layout = QVBoxLayout(group) + self.stoichiometry_table = QTableWidget(0, 8) + self.stoichiometry_table.setMinimumHeight(280) + self.stoichiometry_table.setHorizontalHeaderLabels( + [ + "Stoichiometry", + "Candidates", + "Motifs", + "Status", + "Representative", + "Score", + "Output", + "Open", + ] + ) + self.stoichiometry_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.stoichiometry_table.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.stoichiometry_table.setEditTriggers( + QAbstractItemView.EditTrigger.NoEditTriggers + ) + self.stoichiometry_table.horizontalHeader().setSectionResizeMode( + QHeaderView.ResizeMode.ResizeToContents + ) + self.stoichiometry_table.horizontalHeader().setStretchLastSection(True) + self.stoichiometry_table.itemSelectionChanged.connect( + self._update_selected_stoichiometry_view + ) + layout.addWidget(self.stoichiometry_table) + return group + + def _build_candidate_scores_group(self) -> QGroupBox: + group = QGroupBox("Candidate Scores") + layout = QVBoxLayout(group) + self.candidate_table = QTableWidget(0, 8) + self.candidate_table.setHorizontalHeaderLabels( + [ + "File", + "Source", + "Score", + "Bond", + "Angle", + "Solvent", + "Atoms", + "Solvent Atoms", + ] + ) + self.candidate_table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.candidate_table.setSelectionMode( + QAbstractItemView.SelectionMode.SingleSelection + ) + self.candidate_table.setEditTriggers( + QAbstractItemView.EditTrigger.NoEditTriggers + ) + self.candidate_table.horizontalHeader().setSectionResizeMode( + QHeaderView.ResizeMode.ResizeToContents + ) + 
self.candidate_table.horizontalHeader().setStretchLastSection(False) + self.candidate_table.itemSelectionChanged.connect( + self._update_selected_candidate_view + ) + layout.addWidget(self.candidate_table) + return group + + def _build_plot_group(self) -> QGroupBox: + group = QGroupBox("Distribution Plots") + layout = QVBoxLayout(group) + self.plot_widget = RepresentativeDistributionPlotWidget(self) + layout.addWidget(self.plot_widget) + return group + + def _build_viewer_group(self) -> QGroupBox: + group = QGroupBox("Structure Viewer") + layout = QVBoxLayout(group) + self.viewer_widget = ElectronDensityStructureViewer(self) + self.viewer_widget.mesh_contrast_spin.setValue(90.0) + self.viewer_widget.mesh_linewidth_spin.setValue(1.6) + layout.addWidget(self.viewer_widget) + return group + + def _browse_input_dir(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select stoichiometry folder or stoichiometry parent folder", + str(self._browse_start_dir), + ) + if not selected: + return + self.input_dir_edit.setText(selected) + self._browse_start_dir = Path(selected).expanduser().resolve() + self._refresh_input_preview() + + def _browse_output_dir(self) -> None: + start_dir = self.output_dir_edit.text().strip() or str( + self._browse_start_dir + ) + selected = QFileDialog.getExistingDirectory( + self, + "Select representative output folder", + start_dir, + ) + if selected: + self.output_dir_edit.setText(selected) + + def _current_analysis_mode(self) -> str: + payload = self.analysis_mode_combo.currentData() + if payload is None: + return "single" + return str(payload) + + def _refresh_analysis_mode_ui(self) -> None: + mode = self._current_analysis_mode() + if hasattr(self, "run_button"): + self.run_button.setText( + "Analyze All Stoichiometries" + if mode == "all" + else "Analyze Selected Stoichiometry" + ) + if hasattr(self, "output_dir_edit"): + self._refresh_suggested_output_dir() + + def _refresh_input_preview(self) -> None: + input_dir = 
self.input_dir_edit.text().strip() + if not input_dir: + self._input_inspection = None + self._reset_analysis_results() + self.input_preview_box.setPlainText( + "Choose a stoichiometry folder, or a parent folder whose immediate subfolders are stoichiometries, to inspect representative-selection inputs." + ) + self._populate_stoichiometry_table(None) + self._refresh_project_representative_path_cache() + self._refresh_solvent_shell_controls() + return + try: + inspection = inspect_representative_structure_input(input_dir) + except Exception as exc: + self._input_inspection = None + self._reset_analysis_results() + self.input_preview_box.setPlainText(str(exc)) + self._populate_stoichiometry_table(None) + self._refresh_project_representative_path_cache() + self._refresh_solvent_shell_controls() + self.statusBar().showMessage( + "Representative folder inspection failed" + ) + return + self._input_inspection = inspection + self._reset_analysis_results() + self.input_preview_box.setPlainText(inspection.summary_text()) + self._populate_stoichiometry_table(inspection) + self._refresh_project_representative_path_cache() + self._refresh_solvent_shell_controls() + self.statusBar().showMessage( + f"Discovered {inspection.stoichiometry_count} stoichiometry folder(s)" + ) + self._refresh_suggested_output_dir() + + def _toggle_solvent_shell_options(self, checked: bool) -> None: + expanded = bool(checked) + self.solvent_shell_body.setVisible(expanded) + self.solvent_shell_toggle_button.setText( + "Hide Solvent Shell Builder Options" + if expanded + else "Show Solvent Shell Builder Options" + ) + + def _refresh_project_representative_path_cache(self) -> None: + self._shared_project_representative_path_by_input_dir = {} + self._shared_project_representative_entry_by_input_dir = {} + if self._initial_project_dir is None or self._input_inspection is None: + return + try: + from saxshell.fullrmc.project_model import ( + ensure_rmcsetup_structure, + ) + from 
saxshell.fullrmc.representatives import ( + load_representative_selection_metadata, + ) + except Exception: + return + rmcsetup_paths = ensure_rmcsetup_structure(self._initial_project_dir) + metadata = load_representative_selection_metadata( + rmcsetup_paths.representative_selection_path + ) + if metadata is None: + return + shared_entry_by_structure = {} + for entry in metadata.representative_entries: + structure = str(entry.structure).strip() + source_file = str(entry.source_file).strip() + if structure and source_file: + shared_entry_by_structure[structure] = entry + for stoich in self._input_inspection.stoichiometry_folders: + representative_entry = shared_entry_by_structure.get( + stoich.structure_label + ) + if representative_entry is None: + continue + shared_path = ( + Path(representative_entry.source_file).expanduser().resolve() + ) + key = str(stoich.input_dir) + self._shared_project_representative_path_by_input_dir[key] = ( + shared_path + ) + self._shared_project_representative_entry_by_input_dir[key] = ( + representative_entry + ) + row = self._stoichiometry_row_by_input_dir.get(key) + if row is not None: + output_text, output_tooltip = ( + self._project_representative_output_display( + representative_entry, + shared_path, + ) + ) + self._update_stoichiometry_row( + key, + status="Complete", + representative=( + str(representative_entry.source_file_name).strip() + or shared_path.name + ), + score=_format_score( + getattr(representative_entry, "score_total", None) + ), + output_text=output_text, + output_tooltip=output_tooltip, + representative_path=shared_path, + ) + + def _project_representative_output_display( + self, + representative_entry: object, + shared_path: Path, + ) -> tuple[str, str]: + cached_results_path = str( + getattr(representative_entry, "cached_results_path", "") or "" + ).strip() + if cached_results_path: + output_dir = ( + Path(cached_results_path).expanduser().resolve().parent + ) + return output_dir.name, str(output_dir) + 
source_dir = shared_path.parent + return source_dir.name, str(source_dir) + + def _project_representative_path_for_key( + self, key: str | None + ) -> Path | None: + if key is None: + return None + return self._shared_project_representative_path_by_input_dir.get(key) + + def _project_representative_entry_for_key( + self, key: str | None + ) -> object | None: + if key is None: + return None + return self._shared_project_representative_entry_by_input_dir.get(key) + + def _current_display_mode(self) -> str | None: + if not hasattr(self, "display_mode_combo"): + return None + payload = self.display_mode_combo.currentData() + if payload is None: + return None + return str(payload) + + def _set_display_mode(self, mode: str) -> bool: + if not hasattr(self, "display_mode_combo"): + return False + for index in range(self.display_mode_combo.count()): + if self.display_mode_combo.itemData(index) == mode: + self.display_mode_combo.setCurrentIndex(index) + return True + return False + + def _refresh_display_mode_options( + self, + result: RepresentativeFinderResult | None, + ) -> None: + previous_mode = self._current_display_mode() + self.display_mode_combo.blockSignals(True) + self.display_mode_combo.clear() + self.display_mode_combo.setEnabled(False) + if result is not None: + available_modes = {"selected_candidate", "observed_representative"} + if ( + result.predicted_candidate is not None + and result.predicted_output_path is not None + ): + available_modes.add("predicted_optimized_representative") + if ( + result.solvent_completed_predicted_candidate is not None + and result.solvent_completed_predicted_output_path is not None + ): + available_modes.add( + "solvent_completed_predicted_representative" + ) + for label, value in _DISPLAY_MODE_ITEMS: + if value in available_modes: + self.display_mode_combo.addItem(label, value) + self.display_mode_combo.setEnabled( + self.display_mode_combo.count() > 0 + ) + if previous_mode is not None and self._set_display_mode( + 
previous_mode + ): + pass + elif self.display_mode_combo.count() > 0: + self.display_mode_combo.setCurrentIndex(0) + self.display_mode_combo.blockSignals(False) + + def _active_display_payload( + self, + ) -> tuple[RepresentativeFinderCandidate, Path, str] | None: + result = self._selected_result() + if result is None: + return None + mode = self._current_display_mode() or "selected_candidate" + shared_project_path = self._project_representative_path_for_key( + str(result.input_dir) + ) + if mode == "observed_representative": + return ( + result.selected_candidate, + shared_project_path or result.representative_output_path, + "Observed Representative", + ) + if ( + mode == "predicted_optimized_representative" + and result.predicted_candidate is not None + and result.predicted_output_path is not None + ): + return ( + result.predicted_candidate, + result.predicted_output_path, + "Predicted Optimized Representative", + ) + if ( + mode == "solvent_completed_predicted_representative" + and result.solvent_completed_predicted_candidate is not None + and result.solvent_completed_predicted_output_path is not None + ): + return ( + result.solvent_completed_predicted_candidate, + result.solvent_completed_predicted_output_path, + "Solvent-completed Predicted Representative", + ) + candidate = self._selected_candidate() + if candidate is None: + candidate = result.selected_candidate + return candidate, candidate.file_path, "Selected Candidate" + + def _active_representative_input_path(self) -> Path | None: + payload = self._active_display_payload() + if payload is not None: + _candidate, display_path, _display_label = payload + if display_path.is_file(): + return display_path + key = self._selected_stoichiometry_key() + shared_path = self._project_representative_path_for_key(key) + if shared_path is not None and shared_path.is_file(): + return shared_path + return None + + def _refresh_solvent_shell_controls(self) -> None: + representative_path = 
self._active_representative_input_path() + self.open_selected_solvent_shell_button.setEnabled( + representative_path is not None + ) + project_root_text = ( + "" + if self._initial_project_dir is None + else f"Project root: {self._initial_project_dir}\n" + ) + if representative_path is not None: + self.solvent_shell_status_label.setText( + project_root_text + + "Selected representative source:\n" + + str(representative_path) + ) + return + self.solvent_shell_status_label.setText( + project_root_text + + "Select or compute a representative structure to enable this handoff." + ) + + def _open_solvent_shell_builder_for_selected_representative(self) -> None: + from saxshell.fullrmc.ui.solvent_shell_builder_window import ( + launch_solvent_shell_builder_ui, + ) + + representative_path = self._active_representative_input_path() + if representative_path is None: + QMessageBox.information( + self, + "Solvent Shell Builder", + "Select or compute a representative structure first.", + ) + return + window = launch_solvent_shell_builder_ui( + initial_project_dir=self._initial_project_dir, + initial_input_path=representative_path, + ) + window.raise_() + self.statusBar().showMessage( + f"Opened solvent shell builder for {representative_path.name}" + ) + self._append_console( + "Opened solvent shell builder for representative structure: " + f"{representative_path}" + ) + + def _refresh_suggested_output_dir(self) -> None: + inspection = self._input_inspection + if inspection is None: + return + batch_output = ( + self._current_analysis_mode() == "all" + or inspection.stoichiometry_count > 1 + or not inspection.input_is_stoichiometry_folder + ) + suggestion_source = ( + inspection.input_dir + if batch_output + else inspection.stoichiometry_folders[0].input_dir + ) + suggested_output = suggest_representativefinder_output_dir( + suggestion_source, + project_dir=self._initial_project_dir, + batch=batch_output, + ) + current_output = self.output_dir_edit.text().strip() + if ( + not 
current_output + or current_output == self._last_suggested_output_dir + ): + self.output_dir_edit.setText(str(suggested_output)) + self._last_suggested_output_dir = str(suggested_output) + + def _project_session_cache_key(self) -> str | None: + if self._initial_project_dir is not None: + return str(self._initial_project_dir) + input_dir_text = self.input_dir_edit.text().strip() + if input_dir_text: + try: + return str(Path(input_dir_text).expanduser().resolve()) + except Exception: + return input_dir_text + return None + + def _capture_project_session_state( + self, + ) -> RepresentativeFinderSessionState | None: + cache_key = self._project_session_cache_key() + if cache_key is None or not self._analysis_results_by_input_dir: + return None + settings: RepresentativeFinderSettings | None = None + try: + settings = self._current_settings() + except Exception: + result = next( + iter(self._analysis_results_by_input_dir.values()), None + ) + settings = None if result is None else result.settings + return RepresentativeFinderSessionState( + input_dir_text=self.input_dir_edit.text().strip(), + output_dir_text=self.output_dir_edit.text().strip(), + analysis_mode=self._current_analysis_mode(), + settings=settings, + results_by_input_dir=dict(self._analysis_results_by_input_dir), + failures_by_input_dir=dict(self._analysis_failures_by_input_dir), + selected_stoichiometry_key=self._selected_stoichiometry_key(), + display_mode=self._current_display_mode(), + console_text=self.console_box.toPlainText(), + ) + + def _save_project_session_state(self) -> None: + cache_key = self._project_session_cache_key() + if cache_key is None: + return + state = self._capture_project_session_state() + if state is None: + return + _PROJECT_SESSION_STATES[cache_key] = state + + def _restore_project_session_state(self) -> bool: + cache_key = self._project_session_cache_key() + if cache_key is None: + return False + state = _PROJECT_SESSION_STATES.get(cache_key) + if state is None: + return False 
+ restored_input_dir = str(state.input_dir_text).strip() + if ( + restored_input_dir + and restored_input_dir != self.input_dir_edit.text().strip() + ): + restored_input_path = ( + Path(restored_input_dir).expanduser().resolve() + ) + if restored_input_path.is_dir(): + self._browse_start_dir = restored_input_path + self.input_dir_edit.setText(str(restored_input_path)) + self._refresh_input_preview() + self._set_analysis_mode(state.analysis_mode) + if state.output_dir_text: + self.output_dir_edit.setText(state.output_dir_text) + self._last_suggested_output_dir = state.output_dir_text + if state.settings is not None: + self._apply_settings_to_controls(state.settings) + restored_any = False + valid_keys = set(self._stoichiometry_row_by_input_dir) + self._analysis_results_by_input_dir = {} + self._analysis_failures_by_input_dir = {} + for key, result in state.results_by_input_dir.items(): + if key not in valid_keys: + continue + self._analysis_results_by_input_dir[key] = result + self._update_stoichiometry_row( + key, + status="Complete", + representative=result.selected_candidate.file_name, + score=_format_score(result.selected_candidate.score_total), + output_text=result.output_dir.name, + output_tooltip=str(result.output_dir), + representative_path=result.representative_output_path, + ) + restored_any = True + for key, message in state.failures_by_input_dir.items(): + if ( + key not in valid_keys + or key in self._analysis_results_by_input_dir + ): + continue + self._analysis_failures_by_input_dir[key] = message + self._update_stoichiometry_row( + key, + status="Failed", + representative="", + score="", + output_text="", + output_tooltip="", + representative_path=False, + ) + restored_any = True + if not restored_any: + return False + self.console_box.setPlainText(state.console_text) + self.run_status_label.setText( + "Representative selection: restored from project session" + ) + self.open_output_button.setEnabled( + bool(self._analysis_results_by_input_dir) + or 
bool(self.output_dir_edit.text().strip()) + ) + selected_key = state.selected_stoichiometry_key + if ( + selected_key is not None + and selected_key in self._analysis_results_by_input_dir + ): + self._select_stoichiometry_row_by_key(selected_key) + else: + first_key = next(iter(self._analysis_results_by_input_dir), None) + if first_key is not None: + self._select_stoichiometry_row_by_key(first_key) + if state.display_mode: + self._set_display_mode(state.display_mode) + self._refresh_solvent_shell_controls() + self.statusBar().showMessage( + "Restored representative-structure results from the current project session" + ) + return True + + def _restore_project_cached_results_with_startup_progress(self) -> bool: + if self._initial_project_dir is None or self._input_inspection is None: + return False + load_items: list[tuple[str, str, object, Path]] = [] + for stoich in self._input_inspection.stoichiometry_folders: + key = str(stoich.input_dir) + representative_entry = self._project_representative_entry_for_key( + key + ) + if representative_entry is None: + continue + cached_result_path = self._project_cached_result_path_for_entry( + representative_entry + ) + if cached_result_path is None: + continue + load_items.append( + ( + key, + stoich.structure_label, + representative_entry, + cached_result_path, + ) + ) + if not load_items: + return False + + progress_dialog = QProgressDialog( + "Loading saved representative-structure analysis...", + "Cancel", + 0, + len(load_items), + self, + ) + progress_dialog.setWindowTitle("Loading Representative Structures") + progress_dialog.setWindowModality(Qt.WindowModality.WindowModal) + progress_dialog.setMinimumDuration(0) + progress_dialog.setCancelButton(None) + progress_dialog.setValue(0) + progress_dialog.show() + app = QApplication.instance() + if app is not None: + app.processEvents() + + restored_keys: list[str] = [] + first_settings: RepresentativeFinderSettings | None = None + for ( + index, + (key, label, 
representative_entry, cached_result_path), + ) in enumerate(load_items, start=1): + progress_dialog.setLabelText( + f"Loading saved representative analysis for {label}..." + ) + if app is not None: + app.processEvents() + try: + result = load_representativefinder_result(cached_result_path) + except Exception as exc: + self._append_console( + "Unable to restore saved representative analysis for " + f"{label}: {exc}" + ) + progress_dialog.setValue(index) + continue + + result.input_dir = Path(key).expanduser().resolve() + self._analysis_results_by_input_dir[key] = result + self._analysis_failures_by_input_dir.pop(key, None) + if first_settings is None: + first_settings = result.settings + shared_path = self._project_representative_path_for_key(key) + output_text, output_tooltip = ( + self._project_representative_output_display( + representative_entry, + shared_path or result.representative_output_path, + ) + ) + if result.output_dir: + output_text = result.output_dir.name + output_tooltip = str(result.output_dir) + self._update_stoichiometry_row( + key, + status="Complete", + representative=( + str( + getattr( + representative_entry, + "source_file_name", + "", + ) + ).strip() + or result.selected_candidate.file_name + ), + score=_format_score(result.selected_candidate.score_total), + output_text=output_text, + output_tooltip=output_tooltip, + representative_path=( + shared_path or result.representative_output_path + ), + ) + restored_keys.append(key) + progress_dialog.setValue(index) + if app is not None: + app.processEvents() + + progress_dialog.close() + if not restored_keys: + return False + + if first_settings is not None: + self._apply_settings_to_controls(first_settings) + self.run_status_label.setText( + "Representative selection: restored from saved project analysis" + ) + self.open_output_button.setEnabled(True) + self._append_console( + "Loaded saved representative-structure analysis for " + f"{len(restored_keys)} stoichiometry row(s) from the active 
project." + ) + selected_key = self._selected_stoichiometry_key() + if selected_key not in self._analysis_results_by_input_dir: + selected_key = restored_keys[0] + self._select_stoichiometry_row_by_key(selected_key) + self._refresh_solvent_shell_controls() + self.statusBar().showMessage( + "Loaded saved representative-structure analysis from the active project" + ) + return True + + def _project_cached_result_path_for_entry( + self, + representative_entry: object, + ) -> Path | None: + for attribute_name in ( + "project_cached_results_path", + "cached_results_path", + ): + path_text = str( + getattr(representative_entry, attribute_name, "") or "" + ).strip() + if not path_text: + continue + path = Path(path_text).expanduser() + if ( + not path.is_absolute() + and self._initial_project_dir is not None + ): + path = self._initial_project_dir / path + resolved_path = path.resolve() + if resolved_path.is_file(): + return resolved_path + return None + + def _set_analysis_mode(self, mode: str) -> None: + for index in range(self.analysis_mode_combo.count()): + if self.analysis_mode_combo.itemData(index) == mode: + self.analysis_mode_combo.setCurrentIndex(index) + return + + def _apply_settings_to_controls( + self, + settings: RepresentativeFinderSettings, + ) -> None: + self._set_bond_pair_rows(settings.bond_pairs) + self._set_angle_triplet_rows(settings.angle_triplets) + self.bond_weight_spin.setValue(float(settings.bond_weight)) + self.angle_weight_spin.setValue(float(settings.angle_weight)) + self.solvent_weight_spin.setValue(float(settings.solvent_weight)) + self.generate_predicted_checkbox.setChecked( + bool(settings.generate_predicted_optimized_representative) + ) + for index in range(self.algorithm_combo.count()): + if ( + self.algorithm_combo.itemData(index) + == settings.selection_algorithm + ): + self.algorithm_combo.setCurrentIndex(index) + break + + def _reset_analysis_results(self) -> None: + self._analysis_results_by_input_dir = {} + 
self._analysis_failures_by_input_dir = {} + self._active_stoichiometry_key = None + self._run_summary = None + self._viewer_scene_payload_by_path = {} + self.open_output_button.setEnabled(False) + self._clear_selected_result_view( + "Select a stoichiometry row to inspect it. Completed runs will populate the summary, candidate scores, plots, and representative viewer." + ) + self._refresh_solvent_shell_controls() + + def _clear_selected_result_view(self, message: str) -> None: + self.result_summary_box.setPlainText(message) + self.candidate_table.setRowCount(0) + self.display_mode_combo.blockSignals(True) + self.display_mode_combo.clear() + self.display_mode_combo.setEnabled(False) + self.display_mode_combo.blockSignals(False) + self.plot_widget.set_result(None) + self.viewer_widget.draw_placeholder() + self._refresh_solvent_shell_controls() + + def _populate_stoichiometry_table( + self, + inspection: RepresentativeFinderInputInspection | None, + ) -> None: + self.stoichiometry_table.blockSignals(True) + self.stoichiometry_table.setRowCount(0) + self._stoichiometry_row_by_input_dir = {} + if inspection is not None: + for row, stoich in enumerate(inspection.stoichiometry_folders): + self.stoichiometry_table.insertRow(row) + key = str(stoich.input_dir) + self._stoichiometry_row_by_input_dir[key] = row + self._set_stoichiometry_table_item( + row, 0, stoich.structure_label, key + ) + self._set_stoichiometry_table_item( + row, + 1, + str(stoich.candidate_count), + key, + ) + self._set_stoichiometry_table_item( + row, + 2, + ( + ", ".join(stoich.motif_labels) + if stoich.motif_labels + else "none" + ), + key, + ) + self._set_stoichiometry_table_item(row, 3, "Pending", key) + self._set_stoichiometry_table_item(row, 4, "", key) + self._set_stoichiometry_table_item(row, 5, "", key) + self._set_stoichiometry_table_item(row, 6, "", key) + self._set_stoichiometry_open_button(row, key, None) + self.stoichiometry_table.blockSignals(False) + 
self.stoichiometry_table.resizeColumnsToContents() + if self.stoichiometry_table.rowCount() > 0: + self.stoichiometry_table.selectRow(0) + self._update_selected_stoichiometry_view() + + def _set_stoichiometry_table_item( + self, + row: int, + column: int, + text: str, + key: str, + ) -> None: + item = QTableWidgetItem(text) + item.setData(Qt.ItemDataRole.UserRole, key) + self.stoichiometry_table.setItem(row, column, item) + + def _set_stoichiometry_open_button( + self, + row: int, + key: str, + representative_path: Path | None, + ) -> None: + button = self.stoichiometry_table.cellWidget(row, 7) + if not isinstance(button, QPushButton): + button = QPushButton("Open in Finder", self.stoichiometry_table) + button.clicked.connect( + lambda _checked=False, row_key=key: ( + self._select_stoichiometry_row_by_key(row_key), + self._open_stoichiometry_representative_path(row_key), + ) + ) + self.stoichiometry_table.setCellWidget(row, 7, button) + resolved_path = ( + None + if representative_path is None + else representative_path.expanduser().resolve() + ) + if resolved_path is None or not resolved_path.is_file(): + button.setEnabled(False) + button.setToolTip( + "Representative output file is not available for this stoichiometry yet." 
+ ) + return + button.setEnabled(True) + button.setToolTip(str(resolved_path)) + + def _reload_presets(self, *, selected_name: str | None = None) -> None: + previous_name = selected_name or self._selected_preset_name() + self._presets = load_presets() + self.preset_combo.blockSignals(True) + self.preset_combo.clear() + self.preset_combo.addItem("Select preset...", None) + selected_index = 0 + for index, name in enumerate( + ordered_preset_names(self._presets), + start=1, + ): + preset = self._presets[name] + label = f"{name} (Built-in)" if preset.builtin else name + self.preset_combo.addItem(label, name) + if name == previous_name: + selected_index = index + self.preset_combo.setCurrentIndex(selected_index) + self.preset_combo.blockSignals(False) + + def _selected_preset_name(self) -> str | None: + payload = self.preset_combo.currentData() + if payload is None: + return None + return str(payload) + + def load_preset(self, preset_name: str) -> None: + preset = self._presets.get(preset_name) + if preset is None: + raise KeyError(f"Unknown preset: {preset_name}") + self._set_bond_pair_rows(preset.bond_pairs) + self._set_angle_triplet_rows(preset.angle_triplets) + self._select_preset_name(preset.name) + + def _load_selected_preset(self) -> None: + preset_name = self._selected_preset_name() + if not preset_name: + QMessageBox.information( + self, + "Representative Presets", + "Choose a preset first.", + ) + return + try: + self.load_preset(preset_name) + except KeyError: + QMessageBox.warning( + self, + "Representative Presets", + f"The selected preset is no longer available: {preset_name}", + ) + return + self._append_console(f"Loaded representative preset: {preset_name}") + + def _save_current_preset(self) -> None: + try: + bond_pairs = self._read_bond_pairs() + angle_triplets = self._read_angle_triplets() + except ValueError as exc: + QMessageBox.warning(self, "Representative Presets", str(exc)) + return + suggested_name = self._selected_preset_name() or "" + name, 
accepted = QInputDialog.getText( + self, + "Save Representative Preset", + "Preset name", + text=suggested_name, + ) + if not accepted: + return + normalized_name = name.strip() + if not normalized_name: + QMessageBox.warning( + self, + "Representative Presets", + "Preset names cannot be empty.", + ) + return + preset = BondAnalysisPreset( + name=normalized_name, + bond_pairs=bond_pairs, + angle_triplets=angle_triplets, + builtin=False, + ) + save_custom_preset(preset) + self._reload_presets(selected_name=normalized_name) + self._append_console(f"Saved representative preset: {normalized_name}") + + def _select_preset_name(self, preset_name: str) -> None: + for index in range(self.preset_combo.count()): + if self.preset_combo.itemData(index) == preset_name: + self.preset_combo.setCurrentIndex(index) + return + + def _set_bond_pair_rows( + self, + definitions: tuple[BondPairDefinition, ...], + ) -> None: + self.bond_pair_table.blockSignals(True) + self.bond_pair_table.setRowCount(0) + if not definitions: + self._add_empty_bond_pair_row(blocked=True) + else: + for definition in definitions: + row = self.bond_pair_table.rowCount() + self.bond_pair_table.insertRow(row) + self.bond_pair_table.setItem( + row, 0, QTableWidgetItem(definition.atom1) + ) + self.bond_pair_table.setItem( + row, 1, QTableWidgetItem(definition.atom2) + ) + self.bond_pair_table.setItem( + row, + 2, + QTableWidgetItem(f"{definition.cutoff_angstrom:g}"), + ) + self.bond_pair_table.blockSignals(False) + + def _set_angle_triplet_rows( + self, + definitions: tuple[AngleTripletDefinition, ...], + ) -> None: + self.angle_triplet_table.blockSignals(True) + self.angle_triplet_table.setRowCount(0) + if not definitions: + self._add_empty_angle_triplet_row(blocked=True) + else: + for definition in definitions: + row = self.angle_triplet_table.rowCount() + self.angle_triplet_table.insertRow(row) + self.angle_triplet_table.setItem( + row, + 0, + QTableWidgetItem(definition.vertex), + ) + 
self.angle_triplet_table.setItem( + row, + 1, + QTableWidgetItem(definition.arm1), + ) + self.angle_triplet_table.setItem( + row, + 2, + QTableWidgetItem(definition.arm2), + ) + self.angle_triplet_table.setItem( + row, + 3, + QTableWidgetItem(f"{definition.cutoff1_angstrom:g}"), + ) + self.angle_triplet_table.setItem( + row, + 4, + QTableWidgetItem(f"{definition.cutoff2_angstrom:g}"), + ) + self.angle_triplet_table.blockSignals(False) + + def _add_empty_bond_pair_row(self, *, blocked: bool = False) -> None: + previous = self.bond_pair_table.blockSignals(blocked) + row = self.bond_pair_table.rowCount() + self.bond_pair_table.insertRow(row) + for column in range(self.bond_pair_table.columnCount()): + self.bond_pair_table.setItem(row, column, QTableWidgetItem("")) + self.bond_pair_table.blockSignals(previous) + + def _add_empty_angle_triplet_row(self, *, blocked: bool = False) -> None: + previous = self.angle_triplet_table.blockSignals(blocked) + row = self.angle_triplet_table.rowCount() + self.angle_triplet_table.insertRow(row) + for column in range(self.angle_triplet_table.columnCount()): + self.angle_triplet_table.setItem(row, column, QTableWidgetItem("")) + self.angle_triplet_table.blockSignals(previous) + + def _add_bond_pair_row(self) -> None: + self._add_empty_bond_pair_row(blocked=True) + + def _remove_selected_bond_pair_rows(self) -> None: + rows = sorted( + {index.row() for index in self.bond_pair_table.selectedIndexes()}, + reverse=True, + ) + for row in rows: + self.bond_pair_table.removeRow(row) + if self.bond_pair_table.rowCount() == 0: + self._add_empty_bond_pair_row(blocked=True) + + def _add_angle_triplet_row(self) -> None: + self._add_empty_angle_triplet_row(blocked=True) + + def _remove_selected_angle_triplet_rows(self) -> None: + rows = sorted( + { + index.row() + for index in self.angle_triplet_table.selectedIndexes() + }, + reverse=True, + ) + for row in rows: + self.angle_triplet_table.removeRow(row) + if self.angle_triplet_table.rowCount() == 
0: + self._add_empty_angle_triplet_row(blocked=True) + + def _read_bond_pairs(self) -> tuple[BondPairDefinition, ...]: + definitions: list[BondPairDefinition] = [] + for row in range(self.bond_pair_table.rowCount()): + atom1 = self._table_text(self.bond_pair_table, row, 0) + atom2 = self._table_text(self.bond_pair_table, row, 1) + cutoff_text = self._table_text(self.bond_pair_table, row, 2) + if not atom1 and not atom2 and not cutoff_text: + continue + if not atom1 or not atom2 or not cutoff_text: + raise ValueError( + "Every non-empty bond-pair row must include Atom 1, Atom 2, and a cutoff." + ) + definitions.append( + BondPairDefinition(atom1, atom2, float(cutoff_text)) + ) + return tuple(definitions) + + def _read_angle_triplets(self) -> tuple[AngleTripletDefinition, ...]: + definitions: list[AngleTripletDefinition] = [] + for row in range(self.angle_triplet_table.rowCount()): + vertex = self._table_text(self.angle_triplet_table, row, 0) + arm1 = self._table_text(self.angle_triplet_table, row, 1) + arm2 = self._table_text(self.angle_triplet_table, row, 2) + cutoff1_text = self._table_text(self.angle_triplet_table, row, 3) + cutoff2_text = self._table_text(self.angle_triplet_table, row, 4) + if ( + not vertex + and not arm1 + and not arm2 + and not cutoff1_text + and not cutoff2_text + ): + continue + if not all((vertex, arm1, arm2, cutoff1_text, cutoff2_text)): + raise ValueError( + "Every non-empty angle-triplet row must include all five values." 
+ ) + definitions.append( + AngleTripletDefinition( + vertex=vertex, + arm1=arm1, + arm2=arm2, + cutoff1_angstrom=float(cutoff1_text), + cutoff2_angstrom=float(cutoff2_text), + ) + ) + return tuple(definitions) + + def _current_settings(self) -> RepresentativeFinderSettings: + return RepresentativeFinderSettings( + selection_algorithm=str( + self.algorithm_combo.currentData() + or "target_distribution_quantile_distance" + ), + bond_weight=float(self.bond_weight_spin.value()), + angle_weight=float(self.angle_weight_spin.value()), + solvent_weight=float(self.solvent_weight_spin.value()), + generate_predicted_optimized_representative=bool( + self.generate_predicted_checkbox.isChecked() + ), + bond_pairs=self._read_bond_pairs(), + angle_triplets=self._read_angle_triplets(), + ) + + def _stoichiometry_inspection_for_key( + self, + key: str | None, + ) -> RepresentativeFinderFolderInspection | None: + inspection = self._input_inspection + if inspection is None or key is None: + return None + for stoich in inspection.stoichiometry_folders: + if str(stoich.input_dir) == key: + return stoich + return None + + def _selected_stoichiometry_key(self) -> str | None: + if not hasattr(self, "stoichiometry_table"): + return None + selected_items = self.stoichiometry_table.selectedItems() + if selected_items: + return str( + selected_items[0].data(Qt.ItemDataRole.UserRole) or "" + ).strip() + if self.stoichiometry_table.rowCount() <= 0: + return None + item = self.stoichiometry_table.item(0, 0) + if item is None: + return None + return str(item.data(Qt.ItemDataRole.UserRole) or "").strip() + + def _selected_stoichiometry_inspection( + self, + ) -> RepresentativeFinderFolderInspection | None: + return self._stoichiometry_inspection_for_key( + self._selected_stoichiometry_key() + ) + + def _selected_result(self) -> RepresentativeFinderResult | None: + key = self._selected_stoichiometry_key() + if key is None: + return None + return self._analysis_results_by_input_dir.get(key) + + def 
_overwrite_existing_representatives(self) -> bool: + checkbox = getattr(self, "overwrite_existing_checkbox", None) + return bool(checkbox is not None and checkbox.isChecked()) + + def _stoichiometry_has_saved_project_representative( + self, + stoich: RepresentativeFinderFolderInspection, + ) -> bool: + key = str(stoich.input_dir) + shared_path = self._project_representative_path_for_key(key) + return ( + self._project_representative_entry_for_key(key) is not None + and shared_path is not None + and shared_path.is_file() + ) + + def _selected_stoichiometries_for_current_mode( + self, + ) -> tuple[RepresentativeFinderFolderInspection, ...]: + inspection = self._input_inspection + if inspection is None or not inspection.stoichiometry_folders: + raise ValueError( + "Choose a valid stoichiometry folder or stoichiometry parent folder before running the analysis." + ) + if self._current_analysis_mode() == "all": + return tuple(inspection.stoichiometry_folders) + selected_stoich = ( + self._selected_stoichiometry_inspection() + or inspection.stoichiometry_folders[0] + ) + return (selected_stoich,) + + def _analysis_targets_from_inputs( + self, + *, + output_root: Path, + settings: RepresentativeFinderSettings, + ) -> tuple[RepresentativeFinderAnalysisTarget, ...]: + inspection = self._input_inspection + selected_stoichiometries = ( + self._selected_stoichiometries_for_current_mode() + ) + if not self._overwrite_existing_representatives(): + selected_stoichiometries = tuple( + stoich + for stoich in selected_stoichiometries + if not self._stoichiometry_has_saved_project_representative( + stoich + ) + ) + + use_direct_output_dir = ( + inspection.input_is_stoichiometry_folder + and inspection.stoichiometry_count == 1 + ) + targets: list[RepresentativeFinderAnalysisTarget] = [] + for stoich in selected_stoichiometries: + target_output_dir = ( + output_root + if use_direct_output_dir + else suggest_representativefinder_target_output_dir( + output_root, + stoich.structure_label, 
+ ) + ) + targets.append( + RepresentativeFinderAnalysisTarget( + inspection=stoich, + output_dir=target_output_dir, + estimated_total_work=estimate_representativefinder_total_work( + stoich.candidate_count, + solvent_phase_enabled=settings.solvent_weight > 0.0, + predicted_phase_enabled=bool( + settings.generate_predicted_optimized_representative + ), + predicted_solvent_phase_enabled=bool( + settings.generate_predicted_optimized_representative + and self._initial_project_dir is not None + ), + ), + ) + ) + return tuple(targets) + + def _reset_stoichiometry_run_state(self, target_keys: set[str]) -> None: + self._analysis_results_by_input_dir = {} + self._analysis_failures_by_input_dir = {} + self._active_stoichiometry_key = None + self._run_summary = None + self._viewer_scene_payload_by_path = {} + for key in self._stoichiometry_row_by_input_dir: + representative_entry = self._project_representative_entry_for_key( + key + ) + shared_path = self._project_representative_path_for_key(key) + if ( + key not in target_keys + and representative_entry is not None + and shared_path is not None + and shared_path.is_file() + ): + output_text, output_tooltip = ( + self._project_representative_output_display( + representative_entry, + shared_path, + ) + ) + self._update_stoichiometry_row( + key, + status="Complete", + representative=( + str( + getattr( + representative_entry, + "source_file_name", + "", + ) + ).strip() + or shared_path.name + ), + score=_format_score( + getattr(representative_entry, "score_total", None) + ), + output_text=output_text, + output_tooltip=output_tooltip, + representative_path=shared_path, + ) + continue + self._update_stoichiometry_row( + key, + status=("Queued" if key in target_keys else "Not selected"), + representative="", + score="", + output_text="", + output_tooltip="", + representative_path=False, + ) + + def _run_analysis(self) -> None: + if self._analysis_thread is not None: + QMessageBox.information( + self, + "Representative selection 
running", + "A representative-selection run is already in progress.", + ) + return + input_dir_text = self.input_dir_edit.text().strip() + output_dir_text = self.output_dir_edit.text().strip() + if not input_dir_text: + QMessageBox.warning( + self, + "Representative Structure Finder", + "Choose a stoichiometry folder or stoichiometry parent folder before running the analysis.", + ) + return + if not output_dir_text: + QMessageBox.warning( + self, + "Representative Structure Finder", + "Choose an output folder before running the analysis.", + ) + return + try: + settings = self._current_settings() + except Exception as exc: + QMessageBox.warning( + self, + "Representative Structure Finder", + str(exc), + ) + return + + output_root = Path(output_dir_text).expanduser().resolve() + try: + self._refresh_project_representative_path_cache() + requested_stoichiometries = ( + self._selected_stoichiometries_for_current_mode() + ) + targets = self._analysis_targets_from_inputs( + output_root=output_root, + settings=settings, + ) + except Exception as exc: + QMessageBox.warning( + self, + "Representative Structure Finder", + str(exc), + ) + return + skipped_existing_count = ( + len(requested_stoichiometries) - len(targets) + if not self._overwrite_existing_representatives() + else 0 + ) + if not targets: + self._reset_stoichiometry_run_state(set()) + message = ( + "All selected stoichiometries already have saved " + "representative structures. Enable overwrite to recalculate " + "them." 
+ ) + self.run_status_label.setText( + f"Representative selection: {message}" + ) + self.statusBar().showMessage(message) + self._append_console(message) + self.open_output_button.setEnabled(bool(output_dir_text)) + return + + self._reset_stoichiometry_run_state( + {str(target.inspection.input_dir) for target in targets} + ) + self.run_button.setEnabled(False) + self.open_output_button.setEnabled(bool(output_dir_text)) + self.run_status_label.setText( + "Representative selection: starting background task..." + ) + self.run_progress_bar.setRange(0, 1) + self.run_progress_bar.setValue(0) + self.result_summary_box.setPlainText( + "Representative selection is running. Click a stoichiometry row to follow its status. Completed rows will populate the score table, plots, and representative viewer." + ) + self.candidate_table.setRowCount(0) + self.display_mode_combo.blockSignals(True) + self.display_mode_combo.clear() + self.display_mode_combo.setEnabled(False) + self.display_mode_combo.blockSignals(False) + self.plot_widget.set_result(None) + self.viewer_widget.draw_placeholder() + self._append_console("Starting representative-structure analysis.") + if skipped_existing_count > 0: + self._append_console( + "Skipping " + f"{skipped_existing_count} stoichiometr" + f"{'y' if skipped_existing_count == 1 else 'ies'} " + "with saved project representative structures. Enable " + "overwrite to recalculate them." 
+ ) + + config = RepresentativeFinderJobConfig( + analysis_mode=self._current_analysis_mode(), + targets=targets, + settings=settings, + project_dir=self._initial_project_dir, + ) + self._analysis_thread = QThread(self) + self._analysis_worker = RepresentativeFinderWorker(config) + self._analysis_worker.moveToThread(self._analysis_thread) + self._analysis_thread.started.connect(self._analysis_worker.run) + self._analysis_worker.log.connect(self._append_console) + self._analysis_worker.progress.connect(self._update_progress) + self._analysis_worker.target_started.connect(self._on_target_started) + self._analysis_worker.result_ready.connect( + self._on_target_result_ready + ) + self._analysis_worker.target_failed.connect(self._on_target_failed) + self._analysis_worker.finished.connect(self._finish_analysis_run) + self._analysis_worker.failed.connect(self._fail_analysis) + self._analysis_worker.canceled.connect(self._cancel_analysis_complete) + self._analysis_worker.finished.connect(self._analysis_thread.quit) + self._analysis_worker.failed.connect(self._analysis_thread.quit) + self._analysis_worker.canceled.connect(self._analysis_thread.quit) + self._analysis_thread.finished.connect(self._cleanup_thread) + self._analysis_thread.finished.connect( + self._analysis_thread.deleteLater + ) + self._analysis_worker.finished.connect( + self._analysis_worker.deleteLater + ) + self._analysis_worker.failed.connect(self._analysis_worker.deleteLater) + self._analysis_worker.canceled.connect( + self._analysis_worker.deleteLater + ) + self._analysis_thread.start() + + def _update_progress( + self, + processed: int, + total: int, + message: str, + ) -> None: + self.run_progress_bar.setRange(0, max(total, 1)) + self.run_progress_bar.setValue(min(max(processed, 0), max(total, 1))) + self.run_status_label.setText(f"Representative selection: {message}") + self.statusBar().showMessage(message) + + def _on_target_started(self, key: str) -> None: + self._active_stoichiometry_key = key + 
self._update_stoichiometry_row( + key, + status="Running", + ) + if self._selected_stoichiometry_key() == key: + self._update_selected_stoichiometry_view() + + def _on_target_result_ready( + self, result: RepresentativeFinderResult + ) -> None: + key = str(result.input_dir) + had_results = bool(self._analysis_results_by_input_dir) + shared_project_path: Path | None = None + if self._initial_project_dir is not None: + try: + shared_project_path = ( + persist_representativefinder_result_to_project( + self._initial_project_dir, + result, + ) + ) + except Exception as exc: + self._append_console( + "Unable to publish representative structure into the project " + f"folder: {exc}" + ) + else: + self._shared_project_representative_path_by_input_dir[key] = ( + shared_project_path + ) + self.project_results_changed.emit( + str(Path(self._initial_project_dir).resolve()) + ) + self._analysis_results_by_input_dir[key] = result + self._analysis_failures_by_input_dir.pop(key, None) + self._update_stoichiometry_row( + key, + status="Complete", + representative=result.selected_candidate.file_name, + score=_format_score(result.selected_candidate.score_total), + output_text=result.output_dir.name, + output_tooltip=str(result.output_dir), + representative_path=( + shared_project_path or result.representative_output_path + ), + ) + if not had_results or self._selected_stoichiometry_key() == key: + self._select_stoichiometry_row_by_key(key) + + def _on_target_failed( + self, failure: RepresentativeFinderTargetFailure + ) -> None: + key = str(failure.input_dir) + self._analysis_failures_by_input_dir[key] = failure.message + self._update_stoichiometry_row( + key, + status="Failed", + representative="", + score="", + output_text="", + output_tooltip="", + representative_path=False, + ) + self._append_console( + f"[{failure.structure_label}] Representative selection failed: {failure.message}" + ) + if self._selected_stoichiometry_key() == key: + self._update_selected_stoichiometry_view() + 
+ def _finish_analysis_run( + self, summary: RepresentativeFinderRunSummary + ) -> None: + self._run_summary = summary + self.run_button.setEnabled(True) + self.open_output_button.setEnabled( + bool(self._analysis_results_by_input_dir) + or bool(self.output_dir_edit.text().strip()) + ) + if summary.failures: + self.run_status_label.setText( + "Representative selection: complete with failures" + ) + else: + self.run_status_label.setText("Representative selection: complete") + self.run_progress_bar.setValue(self.run_progress_bar.maximum()) + self.statusBar().showMessage("Representative selection complete") + selected_key = self._selected_stoichiometry_key() + if summary.results and ( + selected_key not in self._analysis_results_by_input_dir + ): + self._select_stoichiometry_row_by_key( + str(summary.results[0].input_dir) + ) + elif summary.failures: + failed_keys = { + str(failure.input_dir) for failure in summary.failures + } + if selected_key not in failed_keys and not summary.results: + self._select_stoichiometry_row_by_key( + str(summary.failures[0].input_dir) + ) + else: + self._update_selected_stoichiometry_view() + else: + self._update_selected_stoichiometry_view() + self._append_console( + "Representative selection complete: " + f"{len(summary.results)} stoichiometry run(s) completed, " + f"{len(summary.failures)} failed." 
+ ) + + def _fail_analysis(self, message: str) -> None: + if self._closing_after_analysis_cancel: + self.statusBar().showMessage("Representative selection stopped") + self._append_console( + f"Representative selection stopped while closing: {message}" + ) + return + self.run_button.setEnabled(True) + self.open_output_button.setEnabled( + bool(self._analysis_results_by_input_dir) + or bool(self.output_dir_edit.text().strip()) + ) + self.run_status_label.setText("Representative selection: failed") + self._update_selected_stoichiometry_view() + self.statusBar().showMessage("Representative selection failed") + QMessageBox.warning( + self, + "Representative selection failed", + message, + ) + + def _cancel_analysis_complete(self) -> None: + if self._closing_after_analysis_cancel: + self.statusBar().showMessage("Representative selection canceled") + return + self.run_button.setEnabled(True) + self.open_output_button.setEnabled( + bool(self._analysis_results_by_input_dir) + or bool(self.output_dir_edit.text().strip()) + ) + self.run_status_label.setText("Representative selection: canceled") + self._update_selected_stoichiometry_view() + self.statusBar().showMessage("Representative selection canceled") + self._append_console( + "Representative selection canceled. Any completed stoichiometry rows remain available for review." 
+ ) + + def _cleanup_thread(self) -> None: + self._analysis_worker = None + self._analysis_thread = None + if self._closing_after_analysis_cancel: + self._closing_after_analysis_cancel = False + QTimer.singleShot(0, self.close) + + def _update_stoichiometry_row( + self, + key: str, + *, + status: str | None = None, + representative: str | None = None, + score: str | None = None, + output_text: str | None = None, + output_tooltip: str | None = None, + representative_path: Path | None | bool = None, + ) -> None: + row = self._stoichiometry_row_by_input_dir.get(key) + if row is None: + return + + def set_column( + column: int, + value: str | None, + *, + tooltip: str | None = None, + ) -> None: + if value is None: + return + item = self.stoichiometry_table.item(row, column) + if item is None: + self._set_stoichiometry_table_item(row, column, value, key) + item = self.stoichiometry_table.item(row, column) + else: + item.setText(value) + if item is not None: + item.setData(Qt.ItemDataRole.UserRole, key) + item.setToolTip(tooltip if tooltip is not None else value) + + set_column(3, status) + set_column(4, representative) + set_column(5, score) + set_column(6, output_text, tooltip=output_tooltip) + if representative_path is not None: + self._set_stoichiometry_open_button( + row, + key, + ( + representative_path + if isinstance(representative_path, Path) + else None + ), + ) + self.stoichiometry_table.resizeColumnsToContents() + + def _select_stoichiometry_row_by_key(self, key: str) -> None: + row = self._stoichiometry_row_by_input_dir.get(key) + if row is None: + return + self.stoichiometry_table.selectRow(row) + self._update_selected_stoichiometry_view() + + def _populate_candidate_table( + self, + result: RepresentativeFinderResult, + ) -> None: + self.candidate_table.setRowCount(0) + for row, candidate in enumerate(result.candidates): + self.candidate_table.insertRow(row) + self._set_candidate_table_item( + row, 0, candidate.file_name, candidate + ) + 
self._set_candidate_table_item( + row, 1, candidate.relative_label, candidate + ) + self._set_candidate_table_item( + row, 2, _format_score(candidate.score_total), candidate + ) + self._set_candidate_table_item( + row, 3, _format_score(candidate.score_bond), candidate + ) + self._set_candidate_table_item( + row, 4, _format_score(candidate.score_angle), candidate + ) + self._set_candidate_table_item( + row, 5, _format_score(candidate.score_solvent), candidate + ) + self._set_candidate_table_item( + row, 6, str(candidate.atom_count), candidate + ) + self._set_candidate_table_item( + row, 7, str(candidate.solvent_atom_count), candidate + ) + self.candidate_table.resizeColumnsToContents() + + def _set_candidate_table_item( + self, + row: int, + column: int, + text: str, + candidate: RepresentativeFinderCandidate, + ) -> None: + item = QTableWidgetItem(text) + item.setData(Qt.ItemDataRole.UserRole, str(candidate.file_path)) + self.candidate_table.setItem(row, column, item) + + def _select_candidate_row(self, row: int) -> None: + if row < 0 or row >= self.candidate_table.rowCount(): + return + self.candidate_table.selectRow(row) + self._update_selected_candidate_view() + + def _selected_candidate(self) -> RepresentativeFinderCandidate | None: + result = self._selected_result() + if result is None: + return None + selected_items = self.candidate_table.selectedItems() + if not selected_items: + return result.selected_candidate + target_path = str( + selected_items[0].data(Qt.ItemDataRole.UserRole) or "" + ).strip() + for candidate in result.candidates: + if str(candidate.file_path) == target_path: + return candidate + return result.selected_candidate + + def _update_selected_stoichiometry_view(self) -> None: + key = self._selected_stoichiometry_key() + self._active_stoichiometry_key = key + if key is None: + self._clear_selected_result_view( + "Select a stoichiometry row to inspect it." 
+ ) + return + result = self._analysis_results_by_input_dir.get(key) + if result is not None: + self._populate_candidate_table(result) + self._refresh_display_mode_options(result) + self.plot_widget.set_result(result) + self._select_candidate_row(0) + return + representative_entry = self._project_representative_entry_for_key(key) + if representative_entry is not None: + shared_path = self._project_representative_path_for_key(key) + if shared_path is not None: + self._clear_selected_result_view( + self._project_representative_summary_text( + representative_entry, + shared_path, + ) + ) + return + + inspection = self._stoichiometry_inspection_for_key(key) + if inspection is None: + self._clear_selected_result_view( + "Select a stoichiometry row to inspect it." + ) + return + row = self._stoichiometry_row_by_input_dir.get(key, -1) + status_item = ( + self.stoichiometry_table.item(row, 3) if row >= 0 else None + ) + status_text = ( + status_item.text().strip() + if status_item is not None + else "Pending" + ) + failure_message = self._analysis_failures_by_input_dir.get(key) + lines = [ + f"Stoichiometry: {inspection.structure_label}", + f"Input folder: {inspection.input_dir}", + f"Status: {status_text}", + f"Candidate files: {inspection.candidate_count}", + f"Direct files: {inspection.direct_file_count}", + "Motif folders: " + + ( + ", ".join(inspection.motif_labels) + if inspection.motif_labels + else "none" + ), + ] + if failure_message: + lines.extend(["", "Failure", failure_message]) + else: + lines.extend( + [ + "", + "Run representative selection to populate this stoichiometry with a representative structure and candidate ranking.", + ] + ) + self._clear_selected_result_view("\n".join(lines)) + + def _project_representative_summary_text( + self, + representative_entry: object, + shared_path: Path, + ) -> str: + element_counts = ( + getattr(representative_entry, "element_counts", {}) or {} + ) + element_text = ( + ", ".join( + f"{element} x{count}" + for element, 
count in sorted(dict(element_counts).items()) + ) + if element_counts + else "none" + ) + output_text, output_tooltip = ( + self._project_representative_output_display( + representative_entry, + shared_path, + ) + ) + cached_results_path = str( + getattr(representative_entry, "cached_results_path", "") or "" + ).strip() + lines = [ + f"Stoichiometry: {getattr(representative_entry, 'structure', '')}", + "Status: Complete", + "Representative: " + + ( + str( + getattr(representative_entry, "source_file_name", "") + ).strip() + or shared_path.name + ), + "Score: " + + _format_score( + getattr(representative_entry, "score_total", None) + ), + f"Atoms: {int(getattr(representative_entry, 'atom_count', 0))}", + f"Elements: {element_text}", + f"Output: {output_text}", + f"Output folder: {output_tooltip}", + f"Project representative file: {shared_path}", + ] + analysis_source = str( + getattr(representative_entry, "analysis_source", "") or "" + ).strip() + if analysis_source: + lines.append(f"Analysis source: {analysis_source}") + source_mode = str( + getattr(representative_entry, "source_solvent_mode", "") or "" + ).strip() + if source_mode: + lines.append(f"Source solvent mode: {source_mode}") + if cached_results_path: + lines.append(f"Cached analysis metadata: {cached_results_path}") + return "\n".join(lines) + + def _update_selected_candidate_view(self) -> None: + result = self._selected_result() + payload = self._active_display_payload() + if result is None or payload is None: + return + candidate, display_path, display_mode_label = payload + self.result_summary_box.setPlainText( + self._candidate_summary_text( + result, + candidate, + display_mode_label=display_mode_label, + display_path=display_path, + ) + ) + self.plot_widget.set_result(result, candidate=candidate) + try: + structure, mesh_geometry = self._load_viewer_scene( + candidate, + structure_path=display_path, + display_label=( + candidate.relative_label + if display_mode_label == "Selected Candidate" + else 
display_mode_label + ), + ) + self.viewer_widget.set_structure( + structure, + mesh_geometry=mesh_geometry, + scene_key=str(display_path), + ) + except Exception: + self.viewer_widget.draw_placeholder() + self._refresh_solvent_shell_controls() + return + self._refresh_solvent_shell_controls() + + def _load_viewer_scene( + self, + candidate: RepresentativeFinderCandidate, + *, + structure_path: Path | None = None, + display_label: str | None = None, + ) -> tuple[ElectronDensityStructure, ElectronDensityMeshGeometry | None]: + resolved_path = ( + candidate.file_path.expanduser().resolve() + if structure_path is None + else structure_path.expanduser().resolve() + ) + cache_key = str(resolved_path) + cached = self._viewer_scene_payload_by_path.get(cache_key) + if cached is not None: + structure, mesh_geometry = cached + resolved_label = str( + display_label or candidate.relative_label + ).strip() + if resolved_label and resolved_label != structure.display_label: + structure = replace(structure, display_label=resolved_label) + return structure, mesh_geometry + + structure = load_electron_density_structure(resolved_path) + resolved_label = str(display_label or candidate.relative_label).strip() + if resolved_label and resolved_label != structure.display_label: + structure = replace(structure, display_label=resolved_label) + mesh_geometry = build_electron_density_mesh( + structure, + legacy_born_average_default_mesh_settings(structure), + ) + payload = (structure, mesh_geometry) + self._viewer_scene_payload_by_path[cache_key] = payload + return payload + + def _candidate_summary_text( + self, + result: RepresentativeFinderResult, + candidate: RepresentativeFinderCandidate, + *, + display_mode_label: str, + display_path: Path, + ) -> str: + shared_project_path = self._project_representative_path_for_key( + str(result.input_dir) + ) + representative_entry = self._project_representative_entry_for_key( + str(result.input_dir) + ) + lines = [ + f"Stoichiometry: 
{result.structure_label}", + "Status: Complete", + f"Displayed structure: {display_mode_label}", + "Representative: " + + ( + shared_project_path.name + if shared_project_path is not None + else result.representative_output_path.name + ), + f"Candidate: {candidate.file_name}", + f"Source label: {candidate.relative_label}", + f"Displayed file: {display_path}", + ( + "Scores: " + f"total={_format_score(candidate.score_total)}, " + f"bond={_format_score(candidate.score_bond)}, " + f"angle={_format_score(candidate.score_angle)}, " + f"solvent={_format_score(candidate.score_solvent)}" + ), + f"Atoms: {candidate.atom_count}", + "Elements: " + + ", ".join( + f"{element} x{count}" + for element, count in sorted(candidate.element_counts.items()) + ), + ( + "Solvent shell: " + f"total={candidate.solvent_atom_count}, " + f"direct={candidate.direct_solvent_atom_count}, " + f"outer={candidate.outer_solvent_atom_count}" + ), + ( + "Mean direct solvent coordination: " + f"{candidate.mean_direct_solvent_coordination:.6g}" + ), + "", + "Observed representative output: " + + str(shared_project_path or result.representative_output_path), + ] + if shared_project_path is not None: + lines.append(f"Project representative file: {shared_project_path}") + source_mode = ( + "" + if representative_entry is None + else str( + getattr(representative_entry, "source_solvent_mode", "") or "" + ).strip() + ) + if source_mode: + lines.append(f"Source solvent mode: {source_mode}") + if result.predicted_output_path is not None: + lines.append( + "Predicted optimized representative output: " + + str(result.predicted_output_path) + ) + if result.solvent_completed_predicted_output_path is not None: + lines.append( + "Solvent-completed predicted output: " + + str(result.solvent_completed_predicted_output_path) + ) + if result.predicted_generation_notes: + lines.extend(["", "Predicted-output notes"]) + lines.extend( + f" {note}" for note in result.predicted_generation_notes + ) + if 
candidate.descriptor_notes: + lines.extend(["", "Notes"]) + lines.extend(f" {note}" for note in candidate.descriptor_notes) + return "\n".join(lines) + + def _show_output_folder(self) -> None: + result = self._selected_result() + if result is not None: + output_path: Path | None = result.output_dir + else: + output_text = self.output_dir_edit.text().strip() + output_path = ( + Path(output_text).expanduser().resolve() + if output_text + else None + ) + if output_path is None: + return + QMessageBox.information( + self, + "Representative Output Folder", + str(output_path), + ) + + def _open_stoichiometry_representative_path(self, key: str) -> None: + shared_project_path = self._project_representative_path_for_key(key) + if shared_project_path is not None and shared_project_path.is_file(): + representative_path = shared_project_path + else: + result = self._analysis_results_by_input_dir.get(key) + representative_path = ( + None + if result is None + else result.representative_output_path.expanduser().resolve() + ) + if representative_path is None or not representative_path.is_file(): + QMessageBox.information( + self, + "Representative Structure Path", + "This stoichiometry does not have a saved representative structure file yet.", + ) + return + try: + self._reveal_path_in_file_manager(representative_path) + except Exception as exc: + QMessageBox.warning( + self, + "Representative Structure Path", + f"Could not open the representative structure path:\n{exc}", + ) + return + self.statusBar().showMessage( + f"Opened representative structure path for {representative_path.name}" + ) + self._append_console( + f"Opened representative structure path in Finder: {representative_path}" + ) + + @staticmethod + def _reveal_path_in_file_manager(path: Path) -> None: + resolved_path = path.expanduser().resolve() + if sys.platform == "darwin": + subprocess.Popen(["open", "-R", str(resolved_path)]) + return + if sys.platform.startswith("win"): + subprocess.Popen(["explorer", 
f"/select,{resolved_path}"]) + return + QDesktopServices.openUrl(QUrl.fromLocalFile(str(resolved_path.parent))) + + def _append_console(self, message: str) -> None: + text = str(message).strip() + if not text: + return + existing = self.console_box.toPlainText().strip() + if existing: + self.console_box.setPlainText(existing + "\n" + text) + else: + self.console_box.setPlainText(text) + cursor = self.console_box.textCursor() + cursor.movePosition(cursor.MoveOperation.End) + self.console_box.setTextCursor(cursor) + + @staticmethod + def _table_text(table: QTableWidget, row: int, column: int) -> str: + item = table.item(row, column) + if item is None: + return "" + return item.text().strip() + + @staticmethod + def _new_float_spin( + *, + maximum: float, + step: float, + decimals: int, + value: float, + ) -> QDoubleSpinBox: + spin = QDoubleSpinBox() + spin.setRange(0.0, maximum) + spin.setSingleStep(step) + spin.setDecimals(decimals) + spin.setValue(value) + return spin + + +def _format_score(value: float | None) -> str: + if value is None: + return "n/a" + return f"{float(value):.6f}" + + +def launch_representativefinder_ui( + *, + initial_project_dir: str | Path | None = None, + initial_input_path: str | Path | None = None, +) -> RepresentativeStructureFinderMainWindow: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication(sys.argv) + configure_saxshell_application(app) + window = RepresentativeStructureFinderMainWindow( + initial_project_dir=initial_project_dir, + initial_input_path=initial_input_path, + ) + window.show() + window.raise_() + return window + + +__all__ = [ + "RepresentativeStructureFinderMainWindow", + "launch_representativefinder_ui", +] diff --git a/src/saxshell/representativefinder/ui/run_file_window.py b/src/saxshell/representativefinder/ui/run_file_window.py new file mode 100644 index 0000000..0f5ecb6 --- /dev/null +++ b/src/saxshell/representativefinder/ui/run_file_window.py @@ -0,0 
+1,692 @@ +from __future__ import annotations + +import os +import sys +from pathlib import Path + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import ( + QApplication, + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QMainWindow, + QMessageBox, + QPlainTextEdit, + QPushButton, + QScrollArea, + QSpinBox, + QSplitter, + QVBoxLayout, + QWidget, +) + +from saxshell.bondanalysis import ( + AngleTripletDefinition, + BondAnalysisPreset, + BondPairDefinition, + load_presets, + ordered_preset_names, +) +from saxshell.representativefinder.run_config import ( + build_representativefinder_run_config, + default_representativefinder_run_file_path, + save_representativefinder_run_config, + suggest_run_config_output_dir, +) +from saxshell.representativefinder.workflow import ( + RepresentativeFinderInputInspection, + RepresentativeFinderSettings, + inspect_representative_structure_input, +) +from saxshell.saxs.project_manager import SAXSProjectManager +from saxshell.saxs.ui.branding import ( + configure_saxshell_application, + load_saxshell_icon, + prepare_saxshell_application_identity, +) + +_ALGORITHM_ITEMS = ( + ("Quantile Distance", "target_distribution_quantile_distance"), + ("Mean/Std Distance", "target_distribution_moment_distance"), +) +_ANALYSIS_MODE_ITEMS = ( + ("All Discovered Stoichiometries", "all"), + ("Selected Stoichiometry Only", "single"), +) + + +class RepresentativeFinderRunFileWindow(QMainWindow): + def __init__( + self, + *, + initial_project_dir: str | Path | None = None, + initial_input_path: str | Path | None = None, + ) -> None: + super().__init__() + self._inspection: RepresentativeFinderInputInspection | None = None + self._last_suggested_output_dir: str | None = None + self._presets: dict[str, BondAnalysisPreset] = {} + self._browse_start_dir = Path.home() + + project_dir = ( + None + if initial_project_dir is None + else 
Path(initial_project_dir).expanduser().resolve() + ) + input_path = ( + None + if initial_input_path is None + else Path(initial_input_path).expanduser().resolve() + ) + if project_dir is not None: + self._browse_start_dir = project_dir + if input_path is None: + input_path = self._project_clusters_dir(project_dir) + + self.setWindowTitle("Representative Structure CLI Setup (Beta)") + self.setWindowIcon(load_saxshell_icon()) + self.resize(1040, 760) + self._build_ui() + self._reload_presets() + + if project_dir is not None: + self.project_dir_edit.setText(str(project_dir)) + self._refresh_run_file_path() + if input_path is not None and input_path.is_dir(): + self.input_dir_edit.setText(str(input_path)) + self._browse_start_dir = input_path + self._inspect_input() + self._update_command_preview() + + def _build_ui(self) -> None: + central = QWidget(self) + root = QVBoxLayout(central) + root.setContentsMargins(10, 10, 10, 10) + root.setSpacing(8) + self.setCentralWidget(central) + + splitter = QSplitter(Qt.Orientation.Horizontal, self) + splitter.setChildrenCollapsible(False) + root.addWidget(splitter, stretch=1) + + left_scroll = QScrollArea(self) + left_scroll.setWidgetResizable(True) + left_panel = QWidget() + self.left_layout = QVBoxLayout(left_panel) + self.left_layout.setContentsMargins(10, 10, 10, 10) + self.left_layout.setSpacing(10) + left_scroll.setWidget(left_panel) + + right_scroll = QScrollArea(self) + right_scroll.setWidgetResizable(True) + right_panel = QWidget() + self.right_layout = QVBoxLayout(right_panel) + self.right_layout.setContentsMargins(10, 10, 10, 10) + self.right_layout.setSpacing(10) + right_scroll.setWidget(right_panel) + + splitter.addWidget(left_scroll) + splitter.addWidget(right_scroll) + splitter.setSizes([500, 540]) + + self.left_layout.addWidget(self._build_project_group()) + self.left_layout.addWidget(self._build_input_group()) + self.left_layout.addWidget(self._build_preset_group()) + 
def _build_project_group(self) -> QGroupBox:
    """Create the "Project" group box.

    The group holds an editable project-folder field with a Browse
    button and a read-only field that displays the run-file path.
    """
    box = QGroupBox("Project")
    rows = QFormLayout(box)

    # Editable project folder; committing an edit re-evaluates the project.
    folder_edit = QLineEdit()
    folder_edit.editingFinished.connect(self._on_project_dir_changed)
    self.project_dir_edit = folder_edit

    pick_button = QPushButton("Browse...")
    pick_button.clicked.connect(self._browse_project_dir)

    # Line edit and button share one form row through a wrapper widget.
    picker_layout = QHBoxLayout()
    picker_layout.addWidget(folder_edit, stretch=1)
    picker_layout.addWidget(pick_button)
    picker = QWidget()
    picker.setLayout(picker_layout)
    rows.addRow("Project folder", picker)

    # The run-file path is display-only.
    run_file_edit = QLineEdit()
    run_file_edit.setReadOnly(True)
    self.run_file_edit = run_file_edit
    rows.addRow("Run file", run_file_edit)
    return box
def _build_measurement_group(self) -> QGroupBox:
    """Create the "Measurements" group with the two definition editors.

    Both editors are free-text boxes (one row per definition) and every
    text change refreshes the command preview.
    """

    def _definition_editor() -> QPlainTextEdit:
        editor = QPlainTextEdit()
        editor.setMinimumHeight(90)
        editor.textChanged.connect(self._update_command_preview)
        return editor

    box = QGroupBox("Measurements")
    column = QVBoxLayout(box)

    self.bond_pairs_edit = _definition_editor()
    column.addWidget(QLabel("Bond pairs"))
    column.addWidget(self.bond_pairs_edit)

    self.angle_triplets_edit = _definition_editor()
    column.addWidget(QLabel("Angle triplets"))
    column.addWidget(self.angle_triplets_edit)
    return box
def _build_save_group(self) -> QGroupBox:
    """Create the "Save" group with the inspect and save buttons."""
    box = QGroupBox("Save")
    row = QHBoxLayout(box)
    actions = (
        ("Inspect Input", self._inspect_input),
        ("Save Run File", self._save_run_file),
    )
    for caption, handler in actions:
        button = QPushButton(caption)
        button.clicked.connect(handler)
        row.addWidget(button)
    # Trailing stretch keeps the buttons packed to the left.
    row.addStretch(1)
    return box
self.inspection_box.setMinimumHeight(240) + layout.addWidget(self.inspection_box) + return group + + def _build_command_group(self) -> QGroupBox: + group = QGroupBox("CLI Command") + layout = QVBoxLayout(group) + self.command_box = QPlainTextEdit() + self.command_box.setReadOnly(True) + self.command_box.setMinimumHeight(170) + layout.addWidget(self.command_box) + self.json_preview_box = QPlainTextEdit() + self.json_preview_box.setReadOnly(True) + self.json_preview_box.setMinimumHeight(280) + layout.addWidget(self.json_preview_box) + return group + + def _browse_project_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select project folder", + str(self._browse_start_dir), + ) + if not selected: + return + self.project_dir_edit.setText(selected) + self._on_project_dir_changed() + + def _browse_input_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select representative input folder", + str(self._browse_start_dir), + ) + if not selected: + return + self.input_dir_edit.setText(selected) + self._browse_start_dir = Path(selected).expanduser().resolve() + self._inspect_input() + + def _browse_output_dir(self, *_args: object) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Select representative output folder", + self.output_dir_edit.text().strip() or str(self._browse_start_dir), + ) + if selected: + self.output_dir_edit.setText(selected) + self._update_command_preview() + + def _on_project_dir_changed(self, *_args: object) -> None: + project_dir = self._project_dir() + if project_dir is None: + return + self._browse_start_dir = project_dir + self._refresh_run_file_path() + if not self.input_dir_edit.text().strip(): + clusters_dir = self._project_clusters_dir(project_dir) + if clusters_dir is not None and clusters_dir.is_dir(): + self.input_dir_edit.setText(str(clusters_dir)) + self._inspect_input() + + def _on_analysis_mode_changed(self, *_args: object) -> None: + 
def _inspect_input(self, *_args: object) -> None:
    """Inspect the input folder and refresh the widgets derived from it.

    On success the stoichiometry combo is repopulated from the inspection
    result, the summary is shown, and the suggested output folder and
    command preview are refreshed. The previously selected stoichiometry
    is kept when it is still present, so re-inspecting the same folder
    does not silently reset the user's choice.

    With an empty input field, or when inspection raises, the stored
    inspection and the combo are cleared and the message is shown in the
    inspection box instead.
    """
    combo = self.stoichiometry_combo

    def _show_failure(text: str, status: str | None = None) -> None:
        self._inspection = None
        # Block signals so clearing does not trigger an extra preview
        # refresh before the widgets below are consistent again.
        combo.blockSignals(True)
        try:
            combo.clear()
        finally:
            combo.blockSignals(False)
        self._refresh_stoichiometry_enabled()
        self.inspection_box.setPlainText(text)
        if status is not None:
            self.statusBar().showMessage(status)
        self._update_command_preview()

    input_text = self.input_dir_edit.text().strip()
    if not input_text:
        _show_failure("No input folder selected.")
        return
    try:
        inspection = inspect_representative_structure_input(input_text)
    except Exception as exc:
        # UI boundary: any inspection error is reported, never raised.
        _show_failure(str(exc), "Input inspection failed")
        return

    self._inspection = inspection
    previous_label = combo.currentData()
    combo.blockSignals(True)
    try:
        combo.clear()
        for stoich in inspection.stoichiometry_folders:
            combo.addItem(stoich.structure_label, stoich.structure_label)
        if previous_label is not None:
            restored_index = combo.findData(previous_label)
            if restored_index >= 0:
                combo.setCurrentIndex(restored_index)
    finally:
        combo.blockSignals(False)

    self.inspection_box.setPlainText(inspection.summary_text())
    self._refresh_stoichiometry_enabled()
    self._refresh_suggested_output_dir()
    self.statusBar().showMessage(
        f"Discovered {inspection.stoichiometry_count} stoichiometry folder(s)"
    )
    self._update_command_preview()
def _reload_presets(self, *, selected_name: str | None = None) -> None:
    """Reload presets via ``load_presets()`` and rebuild the preset combo.

    The entry named ``selected_name`` is re-selected when it exists;
    without it, the entry that was selected before the rebuild is used.
    Otherwise the "Select preset..." placeholder stays selected. Combo
    signals are blocked during the rebuild.
    """
    self._presets = load_presets()
    wanted_name = selected_name or self._selected_preset_name()

    combo = self.preset_combo
    combo.blockSignals(True)
    combo.clear()
    combo.addItem("Select preset...", None)

    target_index = 0
    preset_names = ordered_preset_names(self._presets)
    # The placeholder occupies index 0, so preset rows start at index 1.
    for position, name in enumerate(preset_names, start=1):
        if self._presets[name].builtin:
            caption = f"{name} (Built-in)"
        else:
            caption = name
        combo.addItem(caption, name)
        if name == wanted_name:
            target_index = position

    combo.setCurrentIndex(target_index)
    combo.blockSignals(False)
def _save_run_file(self, *_args: object) -> None:
    """Validate the form, write the run file, and report the outcome.

    Validation problems and write failures are both shown in a warning
    dialog. This method is connected to a button signal, so no exception
    is allowed to escape it.
    """
    dialog_title = "Representative CLI Setup"
    try:
        project_dir = self._require_project_dir()
        config = self._current_config(project_dir)
    except Exception as exc:
        QMessageBox.warning(self, dialog_title, str(exc))
        return

    run_file_path = default_representativefinder_run_file_path(project_dir)
    try:
        save_representativefinder_run_config(run_file_path, config)
    except Exception as exc:
        # UI boundary: a failed write (missing folder, permissions, ...)
        # must reach the user instead of escaping the slot.
        self.statusBar().showMessage("Saving the run file failed")
        QMessageBox.warning(
            self,
            dialog_title,
            f"Could not save the run file:\n{run_file_path}\n\n{exc}",
        )
        return

    self.run_file_edit.setText(str(run_file_path))
    self.json_preview_box.setPlainText(save_preview_text(config.to_dict()))
    self._update_command_preview()
    self.statusBar().showMessage(f"Saved run file: {run_file_path}")
    QMessageBox.information(
        self,
        dialog_title,
        f"Saved representative CLI run file:\n{run_file_path}",
    )
angle_weight=float(self.angle_weight_spin.value()), + solvent_weight=float(self.solvent_weight_spin.value()), + generate_predicted_optimized_representative=bool( + self.generate_predicted_checkbox.isChecked() + ), + parallel_workers=int(self.worker_spin.value()), + bond_pairs=self._read_bond_pairs(), + angle_triplets=self._read_angle_triplets(), + ) + return build_representativefinder_run_config( + project_dir=project_dir, + input_dir=input_text, + output_dir=output_text or None, + analysis_mode=self._analysis_mode(), + settings=settings, + selected_stoichiometry=self._selected_stoichiometry(), + overwrite_existing=bool( + self.overwrite_existing_checkbox.isChecked() + ), + ) + + def _read_bond_pairs(self) -> tuple[BondPairDefinition, ...]: + definitions: list[BondPairDefinition] = [] + for raw in self.bond_pairs_edit.toPlainText().splitlines(): + text = raw.strip() + if not text: + continue + parts = [part.strip() for part in text.split(":")] + if len(parts) != 3: + raise ValueError("Bond-pair rows must use ATOM1:ATOM2:CUTOFF.") + definitions.append( + BondPairDefinition(parts[0], parts[1], float(parts[2])) + ) + return tuple(definitions) + + def _read_angle_triplets(self) -> tuple[AngleTripletDefinition, ...]: + definitions: list[AngleTripletDefinition] = [] + for raw in self.angle_triplets_edit.toPlainText().splitlines(): + text = raw.strip() + if not text: + continue + parts = [part.strip() for part in text.split(":")] + if len(parts) != 5: + raise ValueError( + "Angle-triplet rows must use " + "VERTEX:ARM1:ARM2:CUTOFF1:CUTOFF2." 
def save_preview_text(payload: dict[str, object]) -> str:
    """Render ``payload`` as JSON text indented by two spaces."""
    from json import dumps

    rendered = dumps(payload, indent=2)
    return rendered
window.raise_() + return window + + +__all__ = [ + "RepresentativeFinderRunFileWindow", + "launch_representativefinder_run_file_ui", +] diff --git a/src/saxshell/representativefinder/workflow.py b/src/saxshell/representativefinder/workflow.py new file mode 100644 index 0000000..a42f423 --- /dev/null +++ b/src/saxshell/representativefinder/workflow.py @@ -0,0 +1,3896 @@ +from __future__ import annotations + +import csv +import json +import os +import re +import shutil +from collections import Counter, defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Callable + +import numpy as np + +from saxshell.bondanalysis.bondanalyzer import ( + AngleTripletDefinition, + BondAnalyzer, + BondPairDefinition, +) +from saxshell.saxs.contrast.descriptors import ( + ParsedContrastStructure, + describe_parsed_contrast_structure, + estimate_pair_contact_distance_medians, +) +from saxshell.saxs.debye import load_structure_file +from saxshell.saxs.stoichiometry import parse_stoich_label + +_STRUCTURE_SUFFIXES = {".pdb", ".xyz"} +_DEFAULT_QUANTILES = tuple(np.linspace(0.0, 1.0, 11).tolist()) +RepresentativeFinderProgressCallback = Callable[[int, int, str], None] +RepresentativeFinderLogCallback = Callable[[str], None] +RepresentativeFinderCancelCallback = Callable[[], bool] + + +class RepresentativeFinderOperationCancelled(RuntimeError): + """Raised when representative-structure analysis is canceled.""" + + +def _emit_progress( + callback: RepresentativeFinderProgressCallback | None, + processed: int, + total: int, + message: str, +) -> None: + if callback is None: + return + callback( + max(int(processed), 0), + max(int(total), 1), + str(message).strip(), + ) + + +def _emit_log( + callback: RepresentativeFinderLogCallback | None, + message: str, +) -> None: + if callback is None: + return + text = str(message).strip() + if text: + callback(text) + + 
+def _raise_if_cancelled( + cancel_callback: RepresentativeFinderCancelCallback | None, +) -> None: + if cancel_callback is not None and cancel_callback(): + raise RepresentativeFinderOperationCancelled( + "Representative-structure analysis canceled." + ) + + +@dataclass(slots=True, frozen=True) +class RepresentativeFinderSettings: + selection_algorithm: str = "target_distribution_quantile_distance" + bond_weight: float = 1.0 + angle_weight: float = 1.0 + solvent_weight: float = 1.0 + generate_predicted_optimized_representative: bool = False + parallel_workers: int = 0 + quantiles: tuple[float, ...] = _DEFAULT_QUANTILES + bond_pairs: tuple[BondPairDefinition, ...] = () + angle_triplets: tuple[AngleTripletDefinition, ...] = () + + def __post_init__(self) -> None: + object.__setattr__( + self, + "parallel_workers", + max(int(self.parallel_workers), 0), + ) + + def to_dict(self) -> dict[str, object]: + return { + "selection_algorithm": self.selection_algorithm, + "bond_weight": self.bond_weight, + "angle_weight": self.angle_weight, + "solvent_weight": self.solvent_weight, + "generate_predicted_optimized_representative": bool( + self.generate_predicted_optimized_representative + ), + "parallel_workers": int(self.parallel_workers), + "quantiles": list(self.quantiles), + "bond_pairs": [ + definition.to_dict() for definition in self.bond_pairs + ], + "angle_triplets": [ + definition.to_dict() for definition in self.angle_triplets + ], + } + + +@dataclass(slots=True, frozen=True) +class RepresentativeFinderFolderCandidate: + file_path: Path + relative_label: str + motif_label: str + + +@dataclass(slots=True, frozen=True) +class RepresentativeFinderFolderInspection: + input_dir: Path + structure_label: str + candidate_count: int + direct_file_count: int + motif_labels: tuple[str, ...] + candidate_labels: tuple[str, ...] 
def summary_text(self) -> str:
    """Return a multi-line, human-readable summary of this folder.

    At most twelve candidate labels are listed; any further labels are
    summarized in a trailing "... and N more" line.
    """
    if self.motif_labels:
        motif_text = ", ".join(self.motif_labels)
    else:
        motif_text = "none"
    report = [
        f"Input folder: {self.input_dir}",
        f"Structure label: {self.structure_label}",
        f"Candidate files: {self.candidate_count}",
        f"Direct files: {self.direct_file_count}",
        "Motif folders: " + motif_text,
    ]

    labels = self.candidate_labels
    if not labels:
        return "\n".join(report)

    report += ["", "Candidates"]
    for label in labels[:12]:
        report.append(f" {label}")
    hidden_count = len(labels) - 12
    if hidden_count > 0:
        report.append(f" ... and {hidden_count} more")
    return "\n".join(report)
def score_sort_key(self) -> tuple[float, float, float, str]:
    """Return the key used to rank candidates (smaller sorts first).

    The key is ``(total, solvent, bond, file path)``. A missing total
    score counts as infinity; missing solvent and bond scores count as
    zero.
    """
    total = np.inf if self.score_total is None else self.score_total
    solvent = 0.0 if self.score_solvent is None else self.score_solvent
    bond = 0.0 if self.score_bond is None else self.score_bond
    return float(total), float(solvent), float(bond), str(self.file_path)
@dataclass(slots=True, frozen=True)
class RepresentativeFinderPlotSeries:
    """One plottable measurement: a target distribution plus candidate values.

    Instances are built by
    ``RepresentativeFinderResult.plot_series_for_candidate``; the field
    comments describe the values that method stores.
    """

    # "bond" or "angle".
    category: str
    # The measurement definition's ``display_label``.
    display_label: str
    # Axis label: "<label> distance (Angstrom)" or "<label> angle (deg)".
    xlabel: str
    # Target values for the definition as a float array (may be empty).
    distribution_values: np.ndarray
    # The active candidate's values for the same definition.
    candidate_values: tuple[float, ...]
= () + + def summary_text(self) -> str: + lines = [ + "Representative structure selection complete", + f"Generated at: {self.generated_at}", + f"Input folder: {self.input_dir}", + f"Output folder: {self.output_dir}", + f"Stoichiometry label: {self.structure_label}", + f"Candidate files analyzed: {len(self.candidates)}", + f"Skipped files: {len(self.skipped_files)}", + f"Selection algorithm: {self.settings.selection_algorithm}", + ( + "Weights: " + f"bond={self.settings.bond_weight:.3g}, " + f"angle={self.settings.angle_weight:.3g}, " + f"solvent={self.settings.solvent_weight:.3g}" + ), + "", + "Selected representative", + f" File: {self.selected_candidate.file_name}", + f" Source: {self.selected_candidate.relative_label}", + ( + " Scores: " + f"total={_format_score(self.selected_candidate.score_total)}, " + f"bond={_format_score(self.selected_candidate.score_bond)}, " + f"angle={_format_score(self.selected_candidate.score_angle)}, " + f"solvent={_format_score(self.selected_candidate.score_solvent)}" + ), + ( + " Solvent shell: " + f"total={self.selected_candidate.solvent_atom_count}, " + f"direct={self.selected_candidate.direct_solvent_atom_count}, " + f"outer={self.selected_candidate.outer_solvent_atom_count}" + ), + f" Copied output: {self.representative_output_path}", + ] + if self.selected_candidate.descriptor_notes: + lines.extend(["", "Representative notes"]) + lines.extend( + f" {note}" + for note in self.selected_candidate.descriptor_notes + ) + if self.predicted_output_path is not None: + predicted_candidate = self.predicted_candidate + lines.extend( + [ + "", + "Predicted optimized representative", + " File: " + + ( + predicted_candidate.file_name + if predicted_candidate is not None + else self.predicted_output_path.name + ), + " Output: " + str(self.predicted_output_path), + ] + ) + if predicted_candidate is not None: + lines.append( + " Scores: " + f"total={_format_score(predicted_candidate.score_total)}, " + 
f"bond={_format_score(predicted_candidate.score_bond)}, " + f"angle={_format_score(predicted_candidate.score_angle)}, " + f"solvent={_format_score(predicted_candidate.score_solvent)}" + ) + if self.solvent_completed_predicted_output_path is not None: + completed_candidate = self.solvent_completed_predicted_candidate + lines.extend( + [ + "", + "Solvent-completed predicted representative", + " File: " + + ( + completed_candidate.file_name + if completed_candidate is not None + else self.solvent_completed_predicted_output_path.name + ), + " Output: " + + str(self.solvent_completed_predicted_output_path), + ] + ) + if completed_candidate is not None: + lines.append( + " Scores: " + f"total={_format_score(completed_candidate.score_total)}, " + f"bond={_format_score(completed_candidate.score_bond)}, " + f"angle={_format_score(completed_candidate.score_angle)}, " + f"solvent={_format_score(completed_candidate.score_solvent)}" + ) + if self.predicted_generation_notes: + lines.extend(["", "Predicted representative notes"]) + lines.extend( + f" {note}" for note in self.predicted_generation_notes + ) + if self.skipped_files: + lines.extend(["", "Skipped files"]) + lines.extend(f" {line}" for line in self.skipped_files[:12]) + if len(self.skipped_files) > 12: + lines.append(f" ... 
def to_dict(self) -> dict[str, object]:
    """Return this result as a plain dictionary tagged ``"version": 3``.

    Paths are stored as strings, optional paths and optional candidates
    as ``None`` when absent, and mapping entries in sorted key order.
    """

    def _text_or_none(path: Path | None) -> str | None:
        return str(path) if path is not None else None

    def _candidate_or_none(
        candidate: RepresentativeFinderCandidate | None,
    ) -> dict[str, object] | None:
        return candidate.to_dict() if candidate is not None else None

    # Keys are inserted in the same order as the serialized layout.
    payload: dict[str, object] = {}
    payload["version"] = 3
    payload["generated_at"] = self.generated_at
    payload["input_dir"] = str(self.input_dir)
    payload["output_dir"] = str(self.output_dir)
    payload["structure_label"] = self.structure_label
    payload["expected_core_counts"] = dict(
        sorted(self.expected_core_counts.items())
    )
    payload["settings"] = self.settings.to_dict()

    payload["representative_output_path"] = str(
        self.representative_output_path
    )
    payload["predicted_output_path"] = _text_or_none(
        self.predicted_output_path
    )
    payload["solvent_completed_predicted_output_path"] = _text_or_none(
        self.solvent_completed_predicted_output_path
    )
    payload["summary_json_path"] = str(self.summary_json_path)
    payload["score_table_path"] = str(self.score_table_path)
    payload["summary_text_path"] = str(self.summary_text_path)

    payload["target_solvent_metrics"] = dict(
        sorted(self.target_solvent_metrics.items())
    )
    payload["target_bond_values"] = _definition_value_map_to_list(
        self.target_bond_values
    )
    payload["target_angle_values"] = _definition_value_map_to_list(
        self.target_angle_values
    )

    payload["candidates"] = [
        candidate.to_dict() for candidate in self.candidates
    ]
    payload["selected_candidate"] = self.selected_candidate.to_dict()
    payload["predicted_candidate"] = _candidate_or_none(
        self.predicted_candidate
    )
    payload["solvent_completed_predicted_candidate"] = _candidate_or_none(
        self.solvent_completed_predicted_candidate
    )
    payload["predicted_generation_notes"] = list(
        self.predicted_generation_notes
    )
    payload["skipped_files"] = list(self.skipped_files)
    return payload
RepresentativeFinderPlotSeries( + category="bond", + display_label=definition.display_label, + xlabel=f"{definition.display_label} distance (Angstrom)", + distribution_values=np.asarray( + self.target_bond_values.get( + definition, + np.array([], dtype=float), + ), + dtype=float, + ), + candidate_values=_line_values( + active_candidate.bond_values.get(definition, []) + ), + ) + for definition in self.settings.bond_pairs + ) + angle_series = tuple( + RepresentativeFinderPlotSeries( + category="angle", + display_label=definition.display_label, + xlabel=f"{definition.display_label} angle (deg)", + distribution_values=np.asarray( + self.target_angle_values.get( + definition, + np.array([], dtype=float), + ), + dtype=float, + ), + candidate_values=_line_values( + active_candidate.angle_values.get(definition, []) + ), + ) + for definition in self.settings.angle_triplets + ) + return bond_series + angle_series + + +def representativefinder_settings_from_dict( + payload: object, +) -> RepresentativeFinderSettings: + source = dict(payload) if isinstance(payload, dict) else {} + quantile_values = source.get("quantiles") or _DEFAULT_QUANTILES + quantiles = tuple(float(value) for value in quantile_values) + return RepresentativeFinderSettings( + selection_algorithm=str( + source.get( + "selection_algorithm", + "target_distribution_quantile_distance", + ) + ).strip() + or "target_distribution_quantile_distance", + bond_weight=_float_from_payload(source.get("bond_weight"), 1.0), + angle_weight=_float_from_payload(source.get("angle_weight"), 1.0), + solvent_weight=_float_from_payload(source.get("solvent_weight"), 1.0), + generate_predicted_optimized_representative=bool( + source.get("generate_predicted_optimized_representative", False) + ), + parallel_workers=_int_from_payload(source.get("parallel_workers"), 0), + quantiles=quantiles or _DEFAULT_QUANTILES, + bond_pairs=tuple( + _bond_pair_definition_from_dict(entry) + for entry in source.get("bond_pairs", []) + if 
def representativefinder_result_from_dict(
    payload: dict[str, object],
    *,
    source_path: str | Path | None = None,
) -> RepresentativeFinderResult:
    """Rebuild a ``RepresentativeFinderResult`` from a serialized payload.

    Args:
        payload: Mapping with the keys written by the result's
            ``to_dict``. Missing or empty entries fall back to defaults.
        source_path: Optional path of the JSON file the payload was read
            from. It stands in for a missing ``summary_json_path`` and,
            through its parent folder, for a missing ``output_dir``.

    Returns:
        The reconstructed result object.

    Raises:
        ValueError: If the payload provides neither a usable
            ``selected_candidate`` nor any ``candidates`` entry.
    """
    source = dict(payload)
    source_result_path = (
        None
        if source_path is None
        else Path(source_path).expanduser().resolve()
    )
    settings = representativefinder_settings_from_dict(source.get("settings"))
    candidates = tuple(
        _candidate_from_dict(candidate_payload, settings=settings)
        for candidate_payload in (source.get("candidates", []) or [])
        if isinstance(candidate_payload, dict)
    )

    # Prefer the already-built candidate object that matches the selected
    # payload; only build a separate one when nothing in the list matches.
    selected_payload = source.get("selected_candidate")
    selected_candidate = (
        _matching_candidate_from_payload(candidates, selected_payload)
        if isinstance(selected_payload, dict)
        else None
    )
    if selected_candidate is None and isinstance(selected_payload, dict):
        selected_candidate = _candidate_from_dict(
            selected_payload,
            settings=settings,
        )
    if selected_candidate is None and candidates:
        selected_candidate = candidates[0]
    if selected_candidate is None:
        raise ValueError(
            "Representative finder result has no selected candidate."
        )
    if not candidates:
        candidates = (selected_candidate,)

    predicted_payload = source.get("predicted_candidate")
    predicted_candidate = (
        _candidate_from_dict(predicted_payload, settings=settings)
        if isinstance(predicted_payload, dict)
        else None
    )
    solvent_completed_payload = source.get(
        "solvent_completed_predicted_candidate"
    )
    solvent_completed_predicted_candidate = (
        _candidate_from_dict(solvent_completed_payload, settings=settings)
        if isinstance(solvent_completed_payload, dict)
        else None
    )

    # Keep this value optional. _path_from_payload never returns None (it
    # substitutes the working directory), which made the fallbacks below
    # unreachable and derived output_dir from the parent of the working
    # directory when neither the payload nor source_path supplied a path.
    summary_json_path = (
        _optional_path_from_payload(source.get("summary_json_path"))
        or source_result_path
    )
    output_dir = _path_from_payload(
        source.get("output_dir"),
        fallback=(
            summary_json_path.parent
            if summary_json_path is not None
            else Path.cwd()
        ),
    )
    return RepresentativeFinderResult(
        input_dir=_path_from_payload(
            source.get("input_dir"), fallback=Path.cwd()
        ),
        output_dir=output_dir,
        structure_label=str(source.get("structure_label", "")).strip()
        or output_dir.name,
        expected_core_counts={
            str(element): int(count)
            for element, count in dict(
                source.get("expected_core_counts", {}) or {}
            ).items()
        },
        settings=settings,
        generated_at=str(source.get("generated_at", "")).strip(),
        candidates=candidates,
        selected_candidate=selected_candidate,
        representative_output_path=_path_from_payload(
            source.get("representative_output_path"),
            fallback=selected_candidate.file_path,
        ),
        skipped_files=tuple(
            str(item) for item in (source.get("skipped_files", []) or [])
        ),
        target_bond_values=_definition_value_map_from_list(
            source.get("target_bond_values"),
            category="bond",
            array_values=True,
        ),
        target_angle_values=_definition_value_map_from_list(
            source.get("target_angle_values"),
            category="angle",
            array_values=True,
        ),
        target_solvent_metrics=_float_mapping_from_payload(
            source.get("target_solvent_metrics")
        ),
        summary_json_path=(
            summary_json_path or output_dir / "representative_selection.json"
        ),
        score_table_path=_path_from_payload(
            source.get("score_table_path"),
            fallback=output_dir / "candidate_scores.tsv",
        ),
        summary_text_path=_path_from_payload(
            source.get("summary_text_path"),
            fallback=output_dir / "selection_summary.txt",
        ),
        predicted_candidate=predicted_candidate,
        predicted_output_path=_optional_path_from_payload(
            source.get("predicted_output_path")
        ),
        solvent_completed_predicted_candidate=(
            solvent_completed_predicted_candidate
        ),
        solvent_completed_predicted_output_path=_optional_path_from_payload(
            source.get("solvent_completed_predicted_output_path")
        ),
        predicted_generation_notes=tuple(
            str(note)
            for note in (source.get("predicted_generation_notes", []) or [])
        ),
    )
source.get("outer_solvent_atom_count"), + 0, + ), + mean_direct_solvent_coordination=_float_from_payload( + source.get("mean_direct_solvent_coordination"), + 0.0, + ), + descriptor_notes=tuple( + str(note) for note in (source.get("descriptor_notes", []) or []) + ), + score_total=_optional_float_from_payload(source.get("score_total")), + score_bond=_optional_float_from_payload(source.get("score_bond")), + score_angle=_optional_float_from_payload(source.get("score_angle")), + score_solvent=_optional_float_from_payload( + source.get("score_solvent") + ), + ) + + +def _matching_candidate_from_payload( + candidates: tuple[RepresentativeFinderCandidate, ...], + payload: object, +) -> RepresentativeFinderCandidate | None: + if not isinstance(payload, dict): + return None + target_path = str(payload.get("file_path", "")).strip() + target_relative_label = str(payload.get("relative_label", "")).strip() + for candidate in candidates: + if target_path and str(candidate.file_path) == target_path: + return candidate + if ( + target_relative_label + and candidate.relative_label == target_relative_label + ): + return candidate + return None + + +def _definition_value_map_to_list( + values_by_definition, +) -> list[dict[str, object]]: + rows: list[dict[str, object]] = [] + for definition, values in values_by_definition.items(): + rows.append( + { + "definition": definition.to_dict(), + "values": _float_list(values), + } + ) + return rows + + +def _definition_value_map_from_list( + payload: object, + *, + category: str, + array_values: bool, +): + if not isinstance(payload, list): + return {} + values_by_definition = {} + for entry in payload: + if not isinstance(entry, dict): + continue + definition_payload = entry.get("definition") + if not isinstance(definition_payload, dict): + continue + definition = ( + _bond_pair_definition_from_dict(definition_payload) + if category == "bond" + else _angle_triplet_definition_from_dict(definition_payload) + ) + values = 
_float_list(entry.get("values", [])) + values_by_definition[definition] = ( + np.asarray(values, dtype=float) if array_values else values + ) + return values_by_definition + + +def _bond_pair_definition_from_dict( + payload: dict[str, object] +) -> BondPairDefinition: + return BondPairDefinition( + str(payload["atom1"]), + str(payload["atom2"]), + float(payload["cutoff_angstrom"]), + ) + + +def _angle_triplet_definition_from_dict( + payload: dict[str, object], +) -> AngleTripletDefinition: + return AngleTripletDefinition( + str(payload["vertex"]), + str(payload["arm1"]), + str(payload["arm2"]), + float(payload["cutoff1_angstrom"]), + float(payload["cutoff2_angstrom"]), + ) + + +def _float_mapping_from_payload(payload: object) -> dict[str, float]: + if not isinstance(payload, dict): + return {} + return { + str(key): float(value) + for key, value in payload.items() + if value is not None + } + + +def _float_list(values: object) -> list[float]: + if values is None: + return [] + array = np.asarray(values, dtype=float) + if array.size <= 0: + return [] + return [float(value) for value in array.reshape(-1).tolist()] + + +def _path_from_payload(value: object, *, fallback: Path | None) -> Path: + text = str(value or "").strip() + if not text: + if fallback is None: + return Path.cwd() + return Path(fallback).expanduser().resolve() + return Path(text).expanduser().resolve() + + +def _optional_path_from_payload(value: object) -> Path | None: + text = str(value or "").strip() + if not text or text.lower() == "none": + return None + return Path(text).expanduser().resolve() + + +def _optional_float_from_payload(value: object) -> float | None: + if value is None: + return None + text = str(value).strip() + if not text or text.lower() in {"none", "n/a", "nan"}: + return None + return float(text) + + +def _float_from_payload(value: object, default: float) -> float: + parsed = _optional_float_from_payload(value) + return float(default) if parsed is None else parsed + + +def 
_int_from_payload(value: object, default: int) -> int: + if value is None: + return int(default) + text = str(value).strip() + if not text: + return int(default) + return int(text) + + +@dataclass(slots=True) +class _MeasuredCandidateStructure: + candidate: RepresentativeFinderCandidate + coordinates: np.ndarray + elements: tuple[str, ...] + parsed_structure: ParsedContrastStructure | None + + +@dataclass(slots=True, frozen=True) +class _SolventDescriptorMeasurement: + candidate: RepresentativeFinderCandidate + solvent_metrics: dict[str, float] + solvent_atom_count: int + direct_solvent_atom_count: int + outer_solvent_atom_count: int + mean_direct_solvent_coordination: float + descriptor_notes: tuple[str, ...] + + +@dataclass(slots=True, frozen=True) +class _CandidateScoreMeasurement: + candidate: RepresentativeFinderCandidate + score_total: float + score_bond: float + score_angle: float + score_solvent: float + + +def inspect_representative_structure_folder( + input_dir: str | Path, +) -> RepresentativeFinderFolderInspection: + resolved_input_dir = Path(input_dir).expanduser().resolve() + candidates = _discover_candidate_files(resolved_input_dir) + motif_labels = tuple( + sorted( + { + candidate.motif_label + for candidate in candidates + if candidate.motif_label != "no_motif" + }, + key=_natural_sort_key, + ) + ) + direct_file_count = sum( + 1 for candidate in candidates if candidate.motif_label == "no_motif" + ) + return RepresentativeFinderFolderInspection( + input_dir=resolved_input_dir, + structure_label=resolved_input_dir.name, + candidate_count=len(candidates), + direct_file_count=direct_file_count, + motif_labels=motif_labels, + candidate_labels=tuple( + candidate.relative_label for candidate in candidates + ), + ) + + +def inspect_representative_structure_input( + input_dir: str | Path, +) -> RepresentativeFinderInputInspection: + resolved_input_dir = Path(input_dir).expanduser().resolve() + if not resolved_input_dir.is_dir(): + raise ValueError( + 
f"Input directory does not exist: {resolved_input_dir}" + ) + direct_inspection = inspect_representative_structure_folder( + resolved_input_dir + ) + if direct_inspection.candidate_count > 0: + return RepresentativeFinderInputInspection( + input_dir=resolved_input_dir, + input_is_stoichiometry_folder=True, + stoichiometry_folders=(direct_inspection,), + ) + + stoichiometry_folders = tuple( + inspection + for inspection in ( + inspect_representative_structure_folder(child) + for child in sorted( + [ + child + for child in resolved_input_dir.iterdir() + if child.is_dir() + ], + key=lambda path: _natural_sort_key(path.name), + ) + ) + if inspection.candidate_count > 0 + ) + if not stoichiometry_folders: + raise ValueError( + "No representative-structure candidate folders were found in " + f"{resolved_input_dir}. Choose a stoichiometry folder directly, or " + "choose a parent folder whose immediate subfolders are " + "stoichiometries containing .xyz/.pdb cluster files." + ) + return RepresentativeFinderInputInspection( + input_dir=resolved_input_dir, + input_is_stoichiometry_folder=False, + stoichiometry_folders=stoichiometry_folders, + ) + + +def analyze_representative_structure_folder( + input_dir: str | Path, + *, + settings: RepresentativeFinderSettings, + output_dir: str | Path | None = None, + project_dir: str | Path | None = None, + progress_callback: RepresentativeFinderProgressCallback | None = None, + log_callback: RepresentativeFinderLogCallback | None = None, + cancel_callback: RepresentativeFinderCancelCallback | None = None, +) -> RepresentativeFinderResult: + resolved_input_dir = Path(input_dir).expanduser().resolve() + if not resolved_input_dir.is_dir(): + raise ValueError( + f"Representative input directory does not exist: {resolved_input_dir}" + ) + candidates_to_measure = _discover_candidate_files(resolved_input_dir) + if not candidates_to_measure: + raise ValueError( + "No candidate .xyz or .pdb files were found in the selected " + f"folder: 
{resolved_input_dir}" + ) + candidate_count = len(candidates_to_measure) + solvent_phase_enabled = settings.solvent_weight > 0.0 + predicted_phase_enabled = bool( + settings.generate_predicted_optimized_representative + ) + predicted_solvent_phase_enabled = ( + predicted_phase_enabled and project_dir is not None + ) + total_work = estimate_representativefinder_total_work( + candidate_count, + solvent_phase_enabled=solvent_phase_enabled, + predicted_phase_enabled=predicted_phase_enabled, + predicted_solvent_phase_enabled=predicted_solvent_phase_enabled, + ) + + resolved_output_dir = ( + Path(output_dir).expanduser().resolve() + if output_dir is not None + else suggest_representativefinder_output_dir(resolved_input_dir) + ) + resolved_output_dir.mkdir(parents=True, exist_ok=True) + analyzer = BondAnalyzer( + bond_pairs=settings.bond_pairs, + angle_triplets=settings.angle_triplets, + ) + expected_core_counts = parse_stoich_label(resolved_input_dir.name) + _emit_log( + log_callback, + f"Scanning {candidate_count} candidate structure file(s).", + ) + _emit_progress( + progress_callback, + 0, + total_work, + "Preparing representative-structure analysis...", + ) + _raise_if_cancelled(cancel_callback) + + parallel_workers = _effective_parallel_workers( + settings.parallel_workers, + candidate_count, + ) + processed_work = 0 + resolved_project_dir = ( + None + if project_dir is None + else Path(project_dir).expanduser().resolve() + ) + + measured_structures, skipped_files, processed_work = ( + _measure_candidate_entries( + candidates_to_measure, + analyzer=analyzer, + include_parsed_structure=settings.solvent_weight > 0.0, + parallel_workers=parallel_workers, + progress_callback=progress_callback, + log_callback=log_callback, + cancel_callback=cancel_callback, + processed_work=processed_work, + total_work=total_work, + ) + ) + measured_candidates = [ + measured.candidate for measured in measured_structures + ] + parsed_structures = [ + parsed_structure + for measured in 
measured_structures + for parsed_structure in (measured.parsed_structure,) + if parsed_structure is not None + ] + parsed_cache = { + measured.candidate.file_path: measured.parsed_structure + for measured in measured_structures + if measured.parsed_structure is not None + } + + if not measured_candidates: + raise ValueError( + "No valid candidate structures could be measured in the selected " + "folder." + ) + + if _single_atom_shortcut_applies(measured_candidates): + shortcut_note = ( + "Single-atom candidate structures were detected; bond, angle, " + "and solvent-distribution scoring was skipped." + ) + for candidate in measured_candidates: + candidate.score_bond = 0.0 + candidate.score_angle = 0.0 + candidate.score_solvent = 0.0 + candidate.score_total = 0.0 + candidate.descriptor_notes = (shortcut_note,) + + processed_work = max(processed_work, total_work - 2) + _emit_log( + log_callback, + "Detected a uniform single-atom candidate set; skipping full " + "representative-distribution analysis.", + ) + _emit_progress( + progress_callback, + processed_work, + total_work, + "Single-atom structures detected; selecting representative directly.", + ) + ranked_candidates = tuple( + sorted( + measured_candidates, + key=RepresentativeFinderCandidate.score_sort_key, + ) + ) + selected_candidate = ranked_candidates[0] + + predicted_candidate = None + predicted_output_path = None + solvent_completed_predicted_candidate = None + solvent_completed_predicted_output_path = None + predicted_generation_notes: tuple[str, ...] 
= () + if settings.generate_predicted_optimized_representative: + _raise_if_cancelled(cancel_callback) + _emit_progress( + progress_callback, + processed_work, + total_work, + "Generating predicted optimized representative...", + ) + ( + predicted_candidate, + predicted_output_path, + solvent_completed_predicted_candidate, + solvent_completed_predicted_output_path, + predicted_generation_notes, + ) = _build_optional_predicted_representatives( + input_dir=resolved_input_dir, + output_dir=resolved_output_dir, + project_dir=resolved_project_dir, + settings=settings, + analyzer=analyzer, + expected_core_counts=expected_core_counts, + measured_candidates=measured_candidates, + measured_structures=measured_structures, + selected_candidate=selected_candidate, + target_bond_features={}, + target_angle_features={}, + target_solvent_metrics={}, + pair_contact_distance_medians=None, + ) + processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + ( + "Predicted optimized representative ready." + if predicted_output_path is not None + else ( + "Predicted optimized representative unavailable for " + "this single-atom stoichiometry." + ) + ), + ) + if resolved_project_dir is not None: + _raise_if_cancelled(cancel_callback) + processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + ( + "Predicted solvent-shell completion ready." + if solvent_completed_predicted_output_path is not None + else "Skipping solvent-completed predicted representative." 
+ ), + ) + + _raise_if_cancelled(cancel_callback) + _emit_progress( + progress_callback, + processed_work, + total_work, + "Writing representative outputs...", + ) + representative_output_path = _copy_representative_file( + input_dir=resolved_input_dir, + output_dir=resolved_output_dir, + selected_candidate=selected_candidate, + ) + result = _build_representativefinder_result( + input_dir=resolved_input_dir, + output_dir=resolved_output_dir, + expected_core_counts=expected_core_counts, + settings=settings, + candidates=ranked_candidates, + selected_candidate=selected_candidate, + representative_output_path=representative_output_path, + skipped_files=tuple(skipped_files), + target_bond_values={}, + target_angle_values={}, + target_solvent_metrics={}, + predicted_candidate=predicted_candidate, + predicted_output_path=predicted_output_path, + solvent_completed_predicted_candidate=( + solvent_completed_predicted_candidate + ), + solvent_completed_predicted_output_path=( + solvent_completed_predicted_output_path + ), + predicted_generation_notes=predicted_generation_notes, + ) + _raise_if_cancelled(cancel_callback) + _write_outputs(result) + processed_work += 1 + _emit_progress( + progress_callback, + total_work, + total_work, + "Representative-structure selection complete.", + ) + _emit_log( + log_callback, + f"Selected {selected_candidate.file_name} as the representative " + "single-atom structure.", + ) + return result + + if ( + not settings.bond_pairs + and not settings.angle_triplets + and settings.solvent_weight <= 0.0 + ): + raise ValueError( + "Provide at least one bond pair, one angle triplet, or a positive " + "solvent weight before running representative selection." 
+ ) + + _raise_if_cancelled(cancel_callback) + _emit_progress( + progress_callback, + processed_work, + total_work, + "Aggregating bond and angle distributions...", + ) + target_bond_values, target_angle_values = _aggregate_candidate_values( + measured_candidates, + settings.bond_pairs, + settings.angle_triplets, + ) + processed_work += 1 + + target_solvent_metrics: dict[str, float] = {} + pair_contact_distance_medians: dict[tuple[str, str], float] | None = None + if settings.solvent_weight > 0.0 and parsed_structures: + _raise_if_cancelled(cancel_callback) + _emit_progress( + progress_callback, + processed_work, + total_work, + "Building solvent contact-distance targets...", + ) + pair_contact_distance_medians = estimate_pair_contact_distance_medians( + tuple(parsed_structures) + ) + processed_work += 1 + for candidate in measured_candidates: + _raise_if_cancelled(cancel_callback) + parsed_structure = parsed_cache.get(candidate.file_path) + if parsed_structure is None: + processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + f"No solvent descriptor for {candidate.file_name}", + ) + continue + processed_work = _apply_solvent_descriptors( + [ + (candidate, parsed_cache[candidate.file_path]) + for candidate in measured_candidates + if candidate.file_path in parsed_cache + ], + expected_core_counts=expected_core_counts, + pair_contact_distance_medians=pair_contact_distance_medians, + parallel_workers=parallel_workers, + progress_callback=progress_callback, + cancel_callback=cancel_callback, + processed_work=processed_work, + total_work=total_work, + ) + target_solvent_metrics = _median_summary( + [ + candidate.solvent_metrics + for candidate in measured_candidates + if candidate.solvent_metrics + ] + ) + elif settings.solvent_weight > 0.0: + _raise_if_cancelled(cancel_callback) + _emit_progress( + progress_callback, + processed_work, + total_work, + "No solvent descriptors available; continuing with bond/angle scoring.", + ) + 
processed_work += 1 + candidate_count + _emit_log( + log_callback, + "No solvent descriptors could be computed; solvent weighting was " + "ignored for this run.", + ) + + if settings.selection_algorithm == "target_distribution_moment_distance": + target_bond_features = { + definition: _moment_feature_vector(values) + for definition, values in target_bond_values.items() + } + target_angle_features = { + definition: _moment_feature_vector(values) + for definition, values in target_angle_values.items() + } + else: + target_bond_features = { + definition: _quantile_feature_vector( + values, + quantiles=settings.quantiles, + ) + for definition, values in target_bond_values.items() + } + target_angle_features = { + definition: _quantile_feature_vector( + values, + quantiles=settings.quantiles, + ) + for definition, values in target_angle_values.items() + } + + processed_work = _score_measured_candidates( + measured_candidates, + settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + target_solvent_metrics=target_solvent_metrics, + parallel_workers=parallel_workers, + progress_callback=progress_callback, + cancel_callback=cancel_callback, + processed_work=processed_work, + total_work=total_work, + ) + + _raise_if_cancelled(cancel_callback) + _emit_progress( + progress_callback, + processed_work, + total_work, + "Ranking candidate scores...", + ) + ranked_candidates = tuple( + sorted( + measured_candidates, + key=RepresentativeFinderCandidate.score_sort_key, + ) + ) + selected_candidate = ranked_candidates[0] + processed_work += 1 + + predicted_candidate = None + predicted_output_path = None + solvent_completed_predicted_candidate = None + solvent_completed_predicted_output_path = None + predicted_generation_notes: tuple[str, ...] 
= () + if settings.generate_predicted_optimized_representative: + _raise_if_cancelled(cancel_callback) + _emit_progress( + progress_callback, + processed_work, + total_work, + "Generating predicted optimized representative...", + ) + if pair_contact_distance_medians is None and measured_structures: + pair_contact_distance_medians = estimate_pair_contact_distance_medians( + tuple( + ( + measured.parsed_structure + if measured.parsed_structure is not None + else ParsedContrastStructure( + file_path=measured.candidate.file_path, + coordinates=np.asarray( + measured.coordinates, dtype=float + ), + elements=measured.elements, + element_counts=dict( + sorted( + measured.candidate.element_counts.items() + ) + ), + ) + ) + for measured in measured_structures + ) + ) + ( + predicted_candidate, + predicted_output_path, + solvent_completed_predicted_candidate, + solvent_completed_predicted_output_path, + predicted_generation_notes, + ) = _build_optional_predicted_representatives( + input_dir=resolved_input_dir, + output_dir=resolved_output_dir, + project_dir=resolved_project_dir, + settings=settings, + analyzer=analyzer, + expected_core_counts=expected_core_counts, + measured_candidates=measured_candidates, + measured_structures=measured_structures, + selected_candidate=selected_candidate, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + target_solvent_metrics=target_solvent_metrics, + pair_contact_distance_medians=pair_contact_distance_medians, + ) + processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + ( + "Predicted optimized representative ready." + if predicted_output_path is not None + else "Predicted optimized representative unavailable." + ), + ) + if resolved_project_dir is not None: + _raise_if_cancelled(cancel_callback) + processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + ( + "Predicted solvent-shell completion ready." 
def estimate_representativefinder_total_work(
    candidate_count: int,
    *,
    solvent_phase_enabled: bool,
    predicted_phase_enabled: bool = False,
    predicted_solvent_phase_enabled: bool = False,
) -> int:
    """Estimate how many progress units a representative-finder run has.

    The base cost is two units per candidate plus three fixed units. The
    solvent phase adds one fixed unit and one unit per candidate, and each
    predicted-structure flag adds a single unit.

    Returns:
        The estimated unit count, never less than 1.
    """
    per_candidate = int(candidate_count)
    units = 2 * per_candidate + 3
    if solvent_phase_enabled:
        units += per_candidate + 1
    # The two predicted-structure flags are counted independently.
    units += 1 if predicted_phase_enabled else 0
    units += 1 if predicted_solvent_phase_enabled else 0
    return units if units > 1 else 1
_structure_files_in_dir(directory: Path) -> list[Path]: + return sorted( + [ + file_path + for file_path in directory.iterdir() + if file_path.is_file() + and file_path.suffix.lower() in _STRUCTURE_SUFFIXES + ], + key=lambda path: _natural_sort_key(path.name), + ) + + +def _effective_parallel_workers( + configured_workers: int, + item_count: int, +) -> int: + if int(item_count) <= 1: + return 1 + requested = int(configured_workers) + if requested <= 0: + env_value = os.environ.get("SAXSHELL_REPRESENTATIVEFINDER_WORKERS", "") + if env_value.strip(): + try: + requested = max(int(env_value), 1) + except ValueError: + return 1 + else: + return 1 + return max(1, min(int(item_count), requested, 32)) + + +def _measure_candidate_entries( + entries: tuple[RepresentativeFinderFolderCandidate, ...], + *, + analyzer: BondAnalyzer, + include_parsed_structure: bool, + parallel_workers: int, + progress_callback: RepresentativeFinderProgressCallback | None, + log_callback: RepresentativeFinderLogCallback | None, + cancel_callback: RepresentativeFinderCancelCallback | None, + processed_work: int, + total_work: int, +) -> tuple[list[_MeasuredCandidateStructure], list[str], int]: + measured_by_index: dict[int, _MeasuredCandidateStructure] = {} + skipped_by_index: dict[int, str] = {} + worker_count = _effective_parallel_workers(parallel_workers, len(entries)) + if worker_count <= 1: + for index, entry in enumerate(entries): + _raise_if_cancelled(cancel_callback) + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Measuring {entry.file_path.name}", + ) + try: + measured = _measure_candidate_structure_file( + entry.file_path, + relative_label=entry.relative_label, + motif_label=entry.motif_label, + analyzer=analyzer, + include_parsed_structure=include_parsed_structure, + ) + except Exception as exc: + skipped_by_index[index] = f"{entry.relative_label}: {exc}" + _emit_log( + log_callback, + "Skipped unreadable structure " + f"{entry.relative_label}: {exc}", + ) + 
processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Skipped {entry.file_path.name}", + ) + continue + measured_by_index[index] = measured + processed_work += 1 + _emit_log( + log_callback, + "Measured " + f"{entry.relative_label} ({len(measured.elements)} atoms).", + ) + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Measured {entry.file_path.name}", + ) + return ( + [measured_by_index[index] for index in sorted(measured_by_index)], + [skipped_by_index[index] for index in sorted(skipped_by_index)], + processed_work, + ) + + _emit_log( + log_callback, + "Measuring " + f"{len(entries)} candidate structure file(s) with {worker_count} " + "worker threads.", + ) + _emit_progress( + progress_callback, + processed_work, + total_work, + "Measuring " + f"{len(entries)} candidate structure file(s) with {worker_count} " + "worker thread(s)...", + ) + with ThreadPoolExecutor( + max_workers=worker_count, + thread_name_prefix="representativefinder-measure", + ) as executor: + futures = {} + try: + for index, entry in enumerate(entries): + _raise_if_cancelled(cancel_callback) + futures[ + executor.submit( + _measure_candidate_structure_file, + entry.file_path, + relative_label=entry.relative_label, + motif_label=entry.motif_label, + analyzer=analyzer, + include_parsed_structure=include_parsed_structure, + ) + ] = (index, entry) + for future in as_completed(futures): + _raise_if_cancelled(cancel_callback) + index, entry = futures[future] + try: + measured = future.result() + except Exception as exc: + skipped_by_index[index] = f"{entry.relative_label}: {exc}" + _emit_log( + log_callback, + "Skipped unreadable structure " + f"{entry.relative_label}: {exc}", + ) + processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Skipped {entry.file_path.name}", + ) + continue + measured_by_index[index] = measured + processed_work += 1 + _emit_log( + log_callback, + "Measured " + 
f"{entry.relative_label} ({len(measured.elements)} atoms).", + ) + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Measured {entry.file_path.name}", + ) + except BaseException: + for future in futures: + future.cancel() + raise + return ( + [measured_by_index[index] for index in sorted(measured_by_index)], + [skipped_by_index[index] for index in sorted(skipped_by_index)], + processed_work, + ) + + +def _apply_solvent_descriptors( + candidate_rows: list[ + tuple[RepresentativeFinderCandidate, ParsedContrastStructure] + ], + *, + expected_core_counts: dict[str, int], + pair_contact_distance_medians: dict[tuple[str, str], float], + parallel_workers: int, + progress_callback: RepresentativeFinderProgressCallback | None, + cancel_callback: RepresentativeFinderCancelCallback | None, + processed_work: int, + total_work: int, +) -> int: + worker_count = _effective_parallel_workers( + parallel_workers, + len(candidate_rows), + ) + if worker_count <= 1: + for candidate, parsed_structure in candidate_rows: + _raise_if_cancelled(cancel_callback) + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Building solvent descriptor for {candidate.file_name}", + ) + measurement = _measure_solvent_descriptor( + candidate, + parsed_structure=parsed_structure, + expected_core_counts=expected_core_counts, + pair_contact_distance_medians=pair_contact_distance_medians, + ) + _assign_solvent_descriptor_measurement(measurement) + processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Built solvent descriptor for {candidate.file_name}", + ) + return processed_work + + _emit_progress( + progress_callback, + processed_work, + total_work, + "Building solvent descriptors for " + f"{len(candidate_rows)} candidate structure(s) with {worker_count} " + "worker thread(s)...", + ) + with ThreadPoolExecutor( + max_workers=worker_count, + thread_name_prefix="representativefinder-solvent", + ) as executor: + futures = {} 
+ try: + for candidate, parsed_structure in candidate_rows: + _raise_if_cancelled(cancel_callback) + futures[ + executor.submit( + _measure_solvent_descriptor, + candidate, + parsed_structure=parsed_structure, + expected_core_counts=expected_core_counts, + pair_contact_distance_medians=( + pair_contact_distance_medians + ), + ) + ] = candidate + for future in as_completed(futures): + _raise_if_cancelled(cancel_callback) + candidate = futures[future] + measurement = future.result() + _assign_solvent_descriptor_measurement(measurement) + processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Built solvent descriptor for {candidate.file_name}", + ) + except BaseException: + for future in futures: + future.cancel() + raise + return processed_work + + +def _measure_solvent_descriptor( + candidate: RepresentativeFinderCandidate, + *, + parsed_structure: ParsedContrastStructure, + expected_core_counts: dict[str, int], + pair_contact_distance_medians: dict[tuple[str, str], float], +) -> _SolventDescriptorMeasurement: + descriptor = describe_parsed_contrast_structure( + parsed_structure, + expected_core_counts=expected_core_counts, + pair_contact_distance_medians=pair_contact_distance_medians, + include_geometry_metrics=False, + ) + return _SolventDescriptorMeasurement( + candidate=candidate, + solvent_metrics=descriptor.solvent_metrics(), + solvent_atom_count=descriptor.solvent_atom_count, + direct_solvent_atom_count=descriptor.direct_solvent_atom_count, + outer_solvent_atom_count=descriptor.outer_solvent_atom_count, + mean_direct_solvent_coordination=float( + descriptor.mean_direct_solvent_coordination + ), + descriptor_notes=tuple(descriptor.notes), + ) + + +def _assign_solvent_descriptor_measurement( + measurement: _SolventDescriptorMeasurement, +) -> None: + candidate = measurement.candidate + candidate.solvent_metrics = measurement.solvent_metrics + candidate.solvent_atom_count = measurement.solvent_atom_count + 
candidate.direct_solvent_atom_count = measurement.direct_solvent_atom_count + candidate.outer_solvent_atom_count = measurement.outer_solvent_atom_count + candidate.mean_direct_solvent_coordination = ( + measurement.mean_direct_solvent_coordination + ) + candidate.descriptor_notes = measurement.descriptor_notes + + +def _score_measured_candidates( + candidates: list[RepresentativeFinderCandidate], + *, + settings: RepresentativeFinderSettings, + target_bond_features: dict[object, np.ndarray], + target_angle_features: dict[object, np.ndarray], + target_solvent_metrics: dict[str, float], + parallel_workers: int, + progress_callback: RepresentativeFinderProgressCallback | None, + cancel_callback: RepresentativeFinderCancelCallback | None, + processed_work: int, + total_work: int, +) -> int: + worker_count = _effective_parallel_workers( + parallel_workers, len(candidates) + ) + if worker_count <= 1: + for candidate in candidates: + _raise_if_cancelled(cancel_callback) + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Scoring {candidate.file_name}", + ) + measurement = _measure_candidate_score( + candidate, + settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + target_solvent_metrics=target_solvent_metrics, + ) + _assign_candidate_score_measurement(measurement) + processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Scored {candidate.file_name}", + ) + return processed_work + + _emit_progress( + progress_callback, + processed_work, + total_work, + "Scoring " + f"{len(candidates)} candidate structure(s) with {worker_count} " + "worker thread(s)...", + ) + with ThreadPoolExecutor( + max_workers=worker_count, + thread_name_prefix="representativefinder-score", + ) as executor: + futures = {} + try: + for candidate in candidates: + _raise_if_cancelled(cancel_callback) + futures[ + executor.submit( + _measure_candidate_score, + candidate, + 
settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + target_solvent_metrics=target_solvent_metrics, + ) + ] = candidate + for future in as_completed(futures): + _raise_if_cancelled(cancel_callback) + candidate = futures[future] + measurement = future.result() + _assign_candidate_score_measurement(measurement) + processed_work += 1 + _emit_progress( + progress_callback, + processed_work, + total_work, + f"Scored {candidate.file_name}", + ) + except BaseException: + for future in futures: + future.cancel() + raise + return processed_work + + +def _measure_candidate_score( + candidate: RepresentativeFinderCandidate, + *, + settings: RepresentativeFinderSettings, + target_bond_features: dict[object, np.ndarray], + target_angle_features: dict[object, np.ndarray], + target_solvent_metrics: dict[str, float], +) -> _CandidateScoreMeasurement: + if settings.selection_algorithm == "target_distribution_moment_distance": + bond_score = _category_moment_distance_from_features( + target_bond_features, + candidate.bond_values, + ) + angle_score = _category_moment_distance_from_features( + target_angle_features, + candidate.angle_values, + ) + else: + bond_score = _category_distance_from_features( + target_bond_features, + candidate.bond_values, + quantiles=settings.quantiles, + ) + angle_score = _category_distance_from_features( + target_angle_features, + candidate.angle_values, + quantiles=settings.quantiles, + ) + solvent_score = _score_feature_map( + candidate.solvent_metrics, + target_solvent_metrics, + default_scale=1.0, + ) + total_score = float( + settings.bond_weight * bond_score + + settings.angle_weight * angle_score + + settings.solvent_weight * solvent_score + ) + return _CandidateScoreMeasurement( + candidate=candidate, + score_total=total_score, + score_bond=float(bond_score), + score_angle=float(angle_score), + score_solvent=float(solvent_score), + ) + + +def _assign_candidate_score_measurement( + 
measurement: _CandidateScoreMeasurement, +) -> None: + candidate = measurement.candidate + candidate.score_bond = measurement.score_bond + candidate.score_angle = measurement.score_angle + candidate.score_solvent = measurement.score_solvent + candidate.score_total = measurement.score_total + + +def _aggregate_candidate_values( + candidates: list[RepresentativeFinderCandidate], + bond_pairs: tuple[BondPairDefinition, ...], + angle_triplets: tuple[AngleTripletDefinition, ...], +) -> tuple[ + dict[BondPairDefinition, np.ndarray], + dict[AngleTripletDefinition, np.ndarray], +]: + target_bonds: dict[BondPairDefinition, np.ndarray] = {} + target_angles: dict[AngleTripletDefinition, np.ndarray] = {} + for definition in bond_pairs: + merged = [ + value + for candidate in candidates + for value in candidate.bond_values.get(definition, []) + ] + target_bonds[definition] = np.asarray(merged, dtype=float) + for definition in angle_triplets: + merged = [ + value + for candidate in candidates + for value in candidate.angle_values.get(definition, []) + ] + target_angles[definition] = np.asarray(merged, dtype=float) + return target_bonds, target_angles + + +def _median_summary(rows: list[dict[str, float]]) -> dict[str, float]: + values_by_key: defaultdict[str, list[float]] = defaultdict(list) + for row in rows: + for key, value in row.items(): + values_by_key[str(key)].append(float(value)) + return { + key: float(np.median(np.asarray(values, dtype=float))) + for key, values in sorted(values_by_key.items()) + if values + } + + +def _score_feature_map( + candidate_values: dict[str, float], + target_values: dict[str, float], + *, + default_scale: float, + missing_penalty: float = 1.0, +) -> float: + if not target_values: + return 0.0 + deltas = [ + _relative_difference( + candidate_values.get(key), + expected_value, + scale_floor=default_scale, + missing_penalty=missing_penalty, + ) + for key, expected_value in sorted(target_values.items()) + ] + if not deltas: + return 0.0 + return 
float(np.mean(np.asarray(deltas, dtype=float))) + + +def _relative_difference( + observed: float | None, + expected: float, + *, + scale_floor: float = 1.0, + missing_penalty: float = 1.0, +) -> float: + if observed is None: + return float(missing_penalty) + scale = max(abs(float(expected)), float(scale_floor)) + return abs(float(observed) - float(expected)) / scale + + +def _category_distance( + target_values: dict[object, np.ndarray], + candidate_values: dict[object, list[float]], + *, + quantiles: tuple[float, ...], +) -> float: + target_features = { + definition: _quantile_feature_vector(target, quantiles=quantiles) + for definition, target in target_values.items() + } + return _category_distance_from_features( + target_features, + candidate_values, + quantiles=quantiles, + ) + + +def _category_distance_from_features( + target_features: dict[object, np.ndarray], + candidate_values: dict[object, list[float]], + *, + quantiles: tuple[float, ...], +) -> float: + distances: list[float] = [] + for definition, target_vector in target_features.items(): + candidate_vector = _quantile_feature_vector( + candidate_values.get(definition, []), + quantiles=quantiles, + ) + distances.append( + float(np.linalg.norm(candidate_vector - target_vector)) + ) + if not distances: + return 0.0 + return float(np.mean(np.asarray(distances, dtype=float))) + + +def _category_moment_distance( + target_values: dict[object, np.ndarray], + candidate_values: dict[object, list[float]], +) -> float: + target_features = { + definition: _moment_feature_vector(target) + for definition, target in target_values.items() + } + return _category_moment_distance_from_features( + target_features, + candidate_values, + ) + + +def _category_moment_distance_from_features( + target_features: dict[object, np.ndarray], + candidate_values: dict[object, list[float]], +) -> float: + distances: list[float] = [] + for definition, target_vector in target_features.items(): + candidate_vector = _moment_feature_vector( + 
candidate_values.get(definition, []) + ) + distances.append( + float(np.linalg.norm(candidate_vector - target_vector)) + ) + if not distances: + return 0.0 + return float(np.mean(np.asarray(distances, dtype=float))) + + +def _quantile_feature_vector( + values: list[float] | np.ndarray, + *, + quantiles: tuple[float, ...], +) -> np.ndarray: + array = np.asarray(values, dtype=float) + if array.size <= 0: + return np.zeros(len(quantiles), dtype=float) + return np.asarray(np.quantile(array, quantiles), dtype=float) + + +def _moment_feature_vector(values: list[float] | np.ndarray) -> np.ndarray: + array = np.asarray(values, dtype=float) + if array.size <= 0: + return np.zeros(3, dtype=float) + return np.asarray( + [ + float(np.mean(array)), + float(np.std(array)), + float(np.median(array)), + ], + dtype=float, + ) + + +def _line_values(values: list[float] | np.ndarray) -> tuple[float, ...]: + array = np.asarray(values, dtype=float) + if array.size <= 0: + return () + return tuple(float(value) for value in np.unique(array)) + + +def _single_atom_shortcut_applies( + candidates: list[RepresentativeFinderCandidate], +) -> bool: + if not candidates: + return False + if any(int(candidate.atom_count) != 1 for candidate in candidates): + return False + element_signatures = { + tuple(sorted(candidate.element_counts.items())) + for candidate in candidates + } + return len(element_signatures) == 1 + + +def _build_representativefinder_result( + *, + input_dir: Path, + output_dir: Path, + expected_core_counts: dict[str, int], + settings: RepresentativeFinderSettings, + candidates: tuple[RepresentativeFinderCandidate, ...], + selected_candidate: RepresentativeFinderCandidate, + representative_output_path: Path, + skipped_files: tuple[str, ...], + target_bond_values: dict[BondPairDefinition, np.ndarray], + target_angle_values: dict[AngleTripletDefinition, np.ndarray], + target_solvent_metrics: dict[str, float], + predicted_candidate: RepresentativeFinderCandidate | None = None, + 
predicted_output_path: Path | None = None, + solvent_completed_predicted_candidate: ( + RepresentativeFinderCandidate | None + ) = None, + solvent_completed_predicted_output_path: Path | None = None, + predicted_generation_notes: tuple[str, ...] = (), +) -> RepresentativeFinderResult: + return RepresentativeFinderResult( + input_dir=input_dir, + output_dir=output_dir, + structure_label=input_dir.name, + expected_core_counts=dict(sorted(expected_core_counts.items())), + settings=settings, + generated_at=datetime.now().isoformat(timespec="seconds"), + candidates=candidates, + selected_candidate=selected_candidate, + representative_output_path=representative_output_path, + skipped_files=skipped_files, + target_bond_values=target_bond_values, + target_angle_values=target_angle_values, + target_solvent_metrics=target_solvent_metrics, + summary_json_path=output_dir / "representative_selection.json", + score_table_path=output_dir / "candidate_scores.tsv", + summary_text_path=output_dir / "selection_summary.txt", + predicted_candidate=predicted_candidate, + predicted_output_path=predicted_output_path, + solvent_completed_predicted_candidate=( + solvent_completed_predicted_candidate + ), + solvent_completed_predicted_output_path=( + solvent_completed_predicted_output_path + ), + predicted_generation_notes=predicted_generation_notes, + ) + + +def _copy_representative_file( + *, + input_dir: Path, + output_dir: Path, + selected_candidate: RepresentativeFinderCandidate, +) -> Path: + relative_label = re.sub( + r"[^0-9A-Za-z._-]+", + "_", + selected_candidate.relative_label, + ).strip("_") + destination_name = ( + f"{_safe_folder_name(input_dir.name)}__representative__" + f"{relative_label or selected_candidate.file_name}" + ) + destination = output_dir / destination_name + shutil.copy2(selected_candidate.file_path, destination) + return destination + + +def _write_outputs(result: RepresentativeFinderResult) -> None: + result.summary_json_path.write_text( + 
json.dumps(result.to_dict(), indent=2) + "\n", + encoding="utf-8", + ) + with result.score_table_path.open( + "w", + encoding="utf-8", + newline="", + ) as handle: + writer = csv.writer(handle, delimiter="\t") + writer.writerow( + [ + "rank", + "file_name", + "relative_label", + "motif_label", + "score_total", + "score_bond", + "score_angle", + "score_solvent", + "atom_count", + "solvent_atom_count", + "direct_solvent_atom_count", + "outer_solvent_atom_count", + "mean_direct_solvent_coordination", + ] + ) + for rank, candidate in enumerate(result.candidates, start=1): + writer.writerow( + [ + rank, + candidate.file_name, + candidate.relative_label, + candidate.motif_label, + _format_score(candidate.score_total), + _format_score(candidate.score_bond), + _format_score(candidate.score_angle), + _format_score(candidate.score_solvent), + candidate.atom_count, + candidate.solvent_atom_count, + candidate.direct_solvent_atom_count, + candidate.outer_solvent_atom_count, + f"{candidate.mean_direct_solvent_coordination:.8f}", + ] + ) + result.summary_text_path.write_text( + result.summary_text() + "\n", + encoding="utf-8", + ) + + +def _build_optional_predicted_representatives( + *, + input_dir: Path, + output_dir: Path, + project_dir: Path | None, + settings: RepresentativeFinderSettings, + analyzer: BondAnalyzer, + expected_core_counts: dict[str, int], + measured_candidates: list[RepresentativeFinderCandidate], + measured_structures: list[_MeasuredCandidateStructure], + selected_candidate: RepresentativeFinderCandidate, + target_bond_features: dict[object, np.ndarray], + target_angle_features: dict[object, np.ndarray], + target_solvent_metrics: dict[str, float], + pair_contact_distance_medians: dict[tuple[str, str], float] | None, +) -> tuple[ + RepresentativeFinderCandidate | None, + Path | None, + RepresentativeFinderCandidate | None, + Path | None, + tuple[str, ...], +]: + notes: list[str] = [] + core_counts = { + str(element).strip(): int(count) + for element, count in 
expected_core_counts.items() + if str(element).strip() and int(count) > 0 + } + if not core_counts: + notes.append( + "Predicted optimized representative skipped because the folder name " + "did not provide a parseable stoichiometric core." + ) + return None, None, None, None, tuple(notes) + if sum(core_counts.values()) <= 1: + return _build_single_atom_predicted_representatives( + input_dir=input_dir, + output_dir=output_dir, + settings=settings, + analyzer=analyzer, + measured_structures=measured_structures, + selected_candidate=selected_candidate, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + target_solvent_metrics=target_solvent_metrics, + ) + + atom_type_definitions = _infer_predicted_atom_type_definitions( + core_counts=core_counts, + measured_candidates=measured_candidates, + settings=settings, + ) + pair_cutoff_definitions = _infer_predicted_pair_cutoff_definitions( + settings=settings, + core_counts=core_counts, + pair_contact_distance_medians=pair_contact_distance_medians, + ) + if not atom_type_definitions.get("node"): + notes.append( + "Predicted optimized representative skipped because no core node " + "elements could be inferred from the current stoichiometry." + ) + return None, None, None, None, tuple(notes) + if not pair_cutoff_definitions: + notes.append( + "Predicted optimized representative skipped because no geometry " + "cutoffs were available from the current bond and angle settings." 
+ ) + return None, None, None, None, tuple(notes) + + try: + from saxshell.clusterdynamicsml.workflow import ( + ClusterDynamicsMLTrainingObservation, + ClusterDynamicsMLWorkflow, + ) + except Exception as exc: + notes.append( + "Predicted optimized representative skipped because the Cluster " + f"Dynamics ML scaffold builder was unavailable: {exc}" + ) + return None, None, None, None, tuple(notes) + + try: + workflow = ClusterDynamicsMLWorkflow( + frames_dir=selected_candidate.file_path.parent, + atom_type_definitions=atom_type_definitions, + pair_cutoff_definitions=pair_cutoff_definitions, + clusters_dir=input_dir, + project_dir=project_dir, + ) + training_observations, source_observation = ( + _build_predicted_training_observations( + observation_cls=ClusterDynamicsMLTrainingObservation, + core_counts=core_counts, + node_elements=workflow._atom_type_elements("node"), + measured_structures=measured_structures, + selected_candidate=selected_candidate, + ) + ) + geometry_statistics = workflow._collect_training_geometry_statistics( + training_observations + ) + predicted_max_radius = _predicted_target_max_radius( + measured_structures, + core_elements=set(core_counts), + ) + generated_elements, generated_coordinates = ( + workflow._generate_predicted_structure( + source_observation, + target_counts=core_counts, + predicted_max_radius=predicted_max_radius, + geometry_statistics=geometry_statistics, + ) + ) + except Exception as exc: + notes.append( + "Predicted optimized representative generation failed while " + f"building the synthetic scaffold: {exc}" + ) + return None, None, None, None, tuple(notes) + + generated_array = np.asarray(generated_coordinates, dtype=float) + if not generated_elements or generated_array.size <= 0: + notes.append( + "Predicted optimized representative generation produced an empty " + "structure." 
+ ) + return None, None, None, None, tuple(notes) + + refined_coordinates = _refine_predicted_coordinates( + elements=tuple(str(element) for element in generated_elements), + coordinates=generated_array, + analyzer=analyzer, + settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + ) + + predicted_output_path = output_dir / ( + f"{_safe_folder_name(input_dir.name)}" + "__predicted_optimized_representative.xyz" + ) + _write_xyz_structure_file( + predicted_output_path, + tuple(str(element) for element in generated_elements), + refined_coordinates, + comment=( + "Predicted optimized representative generated from aggregate " + "geometry targets" + ), + ) + predicted_measured = _measure_candidate_structure_file( + predicted_output_path, + relative_label=predicted_output_path.name, + motif_label="predicted_optimized", + analyzer=analyzer, + include_parsed_structure=False, + ) + predicted_candidate = predicted_measured.candidate + predicted_candidate.descriptor_notes = ( + "Synthetic predicted structure generated from aggregate bond and " + "angle targets.", + ) + _score_candidate_against_targets( + predicted_candidate, + settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + target_solvent_metrics=target_solvent_metrics, + include_solvent_component=False, + ) + notes.append( + "Predicted optimized representative generated with the Cluster " + "Dynamics ML geometry scaffold and locally refined against the " + "current geometric scoring target." 
+ ) + + solvent_completed_predicted_candidate = None + solvent_completed_predicted_output_path = None + if project_dir is not None: + ( + solvent_completed_predicted_candidate, + solvent_completed_predicted_output_path, + solvent_completion_notes, + ) = _build_solvent_completed_predicted_representative( + project_dir=project_dir, + output_dir=output_dir, + predicted_output_path=predicted_output_path, + analyzer=analyzer, + expected_core_counts=core_counts, + settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + target_solvent_metrics=target_solvent_metrics, + pair_contact_distance_medians=pair_contact_distance_medians, + ) + notes.extend(solvent_completion_notes) + + return ( + predicted_candidate, + predicted_output_path, + solvent_completed_predicted_candidate, + solvent_completed_predicted_output_path, + tuple(notes), + ) + + +def _build_single_atom_predicted_representatives( + *, + input_dir: Path, + output_dir: Path, + settings: RepresentativeFinderSettings, + analyzer: BondAnalyzer, + measured_structures: list[_MeasuredCandidateStructure], + selected_candidate: RepresentativeFinderCandidate, + target_bond_features: dict[object, np.ndarray], + target_angle_features: dict[object, np.ndarray], + target_solvent_metrics: dict[str, float], +) -> tuple[ + RepresentativeFinderCandidate | None, + Path | None, + RepresentativeFinderCandidate | None, + Path | None, + tuple[str, ...], +]: + if not measured_structures: + return ( + None, + None, + None, + None, + ( + "Predicted optimized representative skipped because no measured " + "single-atom structure was available.", + ), + ) + source_structure = measured_structures[0] + predicted_output_path = output_dir / ( + f"{_safe_folder_name(input_dir.name)}" + "__predicted_optimized_representative.xyz" + ) + _write_xyz_structure_file( + predicted_output_path, + source_structure.elements, + source_structure.coordinates, + comment=( + "Predicted optimized 
representative copied from the single-atom " + "source structure" + ), + ) + predicted_measured = _measure_candidate_structure_file( + predicted_output_path, + relative_label=predicted_output_path.name, + motif_label="predicted_optimized", + analyzer=analyzer, + include_parsed_structure=False, + ) + predicted_candidate = predicted_measured.candidate + predicted_candidate.descriptor_notes = ( + "Single-atom stoichiometry: predicted optimized representative is " + "identical to the observed representative.", + ) + _score_candidate_against_targets( + predicted_candidate, + settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + target_solvent_metrics=target_solvent_metrics, + include_solvent_component=False, + ) + predicted_candidate.score_total = 0.0 + predicted_candidate.score_bond = 0.0 + predicted_candidate.score_angle = 0.0 + predicted_candidate.score_solvent = 0.0 + return ( + predicted_candidate, + predicted_output_path, + None, + None, + ( + "Single-atom stoichiometry: the predicted optimized representative " + "is the same single-atom structure as the observed representative.", + ), + ) + + +def _build_solvent_completed_predicted_representative( + *, + project_dir: Path, + output_dir: Path, + predicted_output_path: Path, + analyzer: BondAnalyzer, + expected_core_counts: dict[str, int], + settings: RepresentativeFinderSettings, + target_bond_features: dict[object, np.ndarray], + target_angle_features: dict[object, np.ndarray], + target_solvent_metrics: dict[str, float], + pair_contact_distance_medians: dict[tuple[str, str], float] | None, +) -> tuple[ + RepresentativeFinderCandidate | None, + Path | None, + tuple[str, ...], +]: + notes: list[str] = [] + try: + from saxshell.fullrmc.project_model import ensure_rmcsetup_structure + from saxshell.fullrmc.solvent_handling import ( + load_solvent_handling_metadata, + ) + from saxshell.fullrmc.solvent_shell_builder import ( + analyze_solvent_shell, + 
build_solvent_shell_output, + default_director_atom_name, + ) + except Exception as exc: + return ( + None, + None, + ("Predicted solvent-shell completion was unavailable: " f"{exc}",), + ) + + try: + rmcsetup_paths = ensure_rmcsetup_structure(project_dir) + solvent_metadata = load_solvent_handling_metadata( + rmcsetup_paths.solvent_handling_path + ) + except Exception as exc: + return ( + None, + None, + ( + "Predicted solvent-shell completion could not read the project " + f"solvent settings: {exc}", + ), + ) + if solvent_metadata is None: + return ( + None, + None, + ( + "Predicted solvent-shell completion was skipped because the " + "project does not yet have solvent-handling settings.", + ), + ) + + reference_identifier = _solvent_reference_identifier_from_metadata( + solvent_metadata + ) + settings_payload = solvent_metadata.settings + director_atom_name = ( + settings_payload.director_atom_name + or default_director_atom_name(reference_identifier) + ) + if not director_atom_name: + return ( + None, + None, + ( + "Predicted solvent-shell completion was skipped because no " + "director atom could be resolved for the project solvent reference.", + ), + ) + try: + analysis_result = analyze_solvent_shell( + predicted_output_path, + reference_identifier, + reference_match_tolerance_a=( + settings_payload.reference_match_tolerance_a + ), + ) + except Exception as exc: + return ( + None, + None, + ( + "Predicted solvent-shell completion was skipped because the " + f"predicted structure could not be analyzed for solvent anchors: {exc}", + ), + ) + + solute_distance_cutoffs = { + str(element): float(setting.director_distance_cutoff_a) + for element, setting in settings_payload.solute_atom_settings.items() + if element in analysis_result.solute_element_counts + and float(setting.director_distance_cutoff_a) > 0.0 + } + coordinating_center_elements = tuple( + sorted( + element + for element, setting in settings_payload.solute_atom_settings.items() + if element in 
analysis_result.solute_element_counts + and setting.coordination_center + and float(setting.target_coordination_number) > 0.0 + ) + ) + target_coordination_numbers = { + str(element): float(setting.target_coordination_number) + for element, setting in settings_payload.solute_atom_settings.items() + if element in analysis_result.solute_element_counts + and setting.coordination_center + and float(setting.target_coordination_number) > 0.0 + } + if not solute_distance_cutoffs and not target_coordination_numbers: + return ( + None, + None, + ( + "Predicted solvent-shell completion was skipped because the " + "project solvent settings do not define any active shell-building " + "cutoffs or coordination targets for this stoichiometry.", + ), + ) + + solvent_completed_output_path = output_dir / ( + f"{_safe_folder_name(predicted_output_path.stem)}" + "__solvent_completed.pdb" + ) + try: + build_result = build_solvent_shell_output( + predicted_output_path, + reference_identifier, + output_path=solvent_completed_output_path, + director_atom_name=director_atom_name, + minimum_solvent_atom_separation_a=( + settings_payload.minimum_solvent_atom_separation_a + ), + solute_distance_cutoffs_a=solute_distance_cutoffs, + coordinating_center_elements=coordinating_center_elements, + target_average_coordination_numbers=target_coordination_numbers, + reference_match_tolerance_a=( + settings_payload.reference_match_tolerance_a + ), + analysis_result=analysis_result, + ) + except Exception as exc: + return ( + None, + None, + ( + "Predicted solvent-shell completion failed while building the " + f"solvent shell: {exc}", + ), + ) + if int(build_result.solvent_molecules_added) <= 0: + return ( + None, + None, + ( + "Predicted solvent-shell completion finished without placing any " + "solvent molecules, so only the no-solvent predicted structure was kept.", + ), + ) + + measured_completed = _measure_candidate_structure_file( + solvent_completed_output_path, + 
relative_label=solvent_completed_output_path.name, + motif_label="predicted_optimized_solvent_completed", + analyzer=analyzer, + include_parsed_structure=True, + ) + completed_candidate = measured_completed.candidate + if ( + measured_completed.parsed_structure is not None + and pair_contact_distance_medians is not None + ): + _apply_solvent_descriptor( + completed_candidate, + parsed_structure=measured_completed.parsed_structure, + expected_core_counts=expected_core_counts, + pair_contact_distance_medians=pair_contact_distance_medians, + ) + completed_candidate.descriptor_notes = ( + "Synthetic predicted structure with a solvent shell built from the " + "current project solvent settings.", + ) + _score_candidate_against_targets( + completed_candidate, + settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + target_solvent_metrics=target_solvent_metrics, + include_solvent_component=bool(target_solvent_metrics), + ) + notes.append( + "Built a solvent-completed predicted representative from the project " + f"solvent settings ({build_result.solvent_molecules_added} solvent " + "molecule(s) added)." 
+ ) + return completed_candidate, solvent_completed_output_path, tuple(notes) + + +def _infer_predicted_atom_type_definitions( + *, + core_counts: dict[str, int], + measured_candidates: list[RepresentativeFinderCandidate], + settings: RepresentativeFinderSettings, +): + core_elements = { + str(element).strip() for element in core_counts if str(element).strip() + } + observed_elements = { + str(element).strip() + for candidate in measured_candidates + for element in candidate.element_counts + if str(element).strip() + } + node_elements = [ + definition.vertex + for definition in settings.angle_triplets + if definition.vertex in core_elements + ] + if not node_elements: + node_elements = [ + definition.atom1 + for definition in settings.bond_pairs + if definition.atom1 in core_elements + ] + if not node_elements and core_counts: + minimum_count = min(int(count) for count in core_counts.values()) + node_elements = [ + str(element) + for element, count in sorted(core_counts.items()) + if int(count) == minimum_count + ] + normalized_node_elements = tuple(dict.fromkeys(node_elements)) + linker_elements = tuple( + element + for element in sorted(core_elements) + if element not in normalized_node_elements + ) + if not normalized_node_elements: + normalized_node_elements = tuple(sorted(core_elements)) + linker_elements = () + shell_elements = tuple( + sorted( + element + for element in observed_elements + if element not in core_elements + ) + ) + definitions = { + "node": [(element, None) for element in normalized_node_elements], + } + if linker_elements: + definitions["linker"] = [ + (element, None) for element in linker_elements + ] + if shell_elements: + definitions["shell"] = [(element, None) for element in shell_elements] + return definitions + + +def _infer_predicted_pair_cutoff_definitions( + *, + settings: RepresentativeFinderSettings, + core_counts: dict[str, int], + pair_contact_distance_medians: dict[tuple[str, str], float] | None, +): + pair_cutoffs: 
defaultdict[tuple[str, str], dict[int, float]] = defaultdict( + dict + ) + + def add_cutoff(element_a: str, element_b: str, cutoff: float) -> None: + normalized_pair = tuple( + sorted((str(element_a).strip(), str(element_b).strip())) + ) + if not normalized_pair[0] or not normalized_pair[1]: + return + previous = pair_cutoffs[normalized_pair].get(0, 0.0) + pair_cutoffs[normalized_pair][0] = max(previous, float(cutoff)) + + for definition in settings.bond_pairs: + add_cutoff( + definition.atom1, + definition.atom2, + max(float(definition.cutoff_angstrom), 0.1), + ) + for definition in settings.angle_triplets: + add_cutoff( + definition.vertex, + definition.arm1, + max(float(definition.cutoff1_angstrom), 0.1), + ) + add_cutoff( + definition.vertex, + definition.arm2, + max(float(definition.cutoff2_angstrom), 0.1), + ) + if pair_cutoffs or not pair_contact_distance_medians: + return { + pair: dict(levels) for pair, levels in sorted(pair_cutoffs.items()) + } + + core_elements = set(core_counts) + for pair, median_distance in sorted(pair_contact_distance_medians.items()): + if pair[0] not in core_elements or pair[1] not in core_elements: + continue + add_cutoff( + pair[0], + pair[1], + max(float(median_distance) * 1.15, float(median_distance) + 0.05), + ) + return { + pair: dict(levels) for pair, levels in sorted(pair_cutoffs.items()) + } + + +def _build_predicted_training_observations( + *, + observation_cls, + core_counts: dict[str, int], + node_elements: set[str], + measured_structures: list[_MeasuredCandidateStructure], + selected_candidate: RepresentativeFinderCandidate, +): + observations = [] + source_observation = None + cluster_size = int(sum(core_counts.values())) + node_count = ( + int(sum(core_counts.get(element, 0) for element in node_elements)) + or cluster_size + ) + selected_path = selected_candidate.file_path.resolve() + for measured in sorted( + measured_structures, + key=lambda row: str(row.candidate.file_path), + ): + representative_path = 
measured.candidate.file_path.resolve() + observation = observation_cls( + label=measured.candidate.relative_label, + node_count=node_count, + cluster_size=cluster_size, + element_counts=dict(sorted(core_counts.items())), + file_count=1, + representative_path=representative_path, + structure_dir=representative_path, + motifs=( + (measured.candidate.motif_label,) + if measured.candidate.motif_label != "no_motif" + else () + ), + mean_atom_count=float(measured.candidate.atom_count), + mean_radius_of_gyration=0.0, + mean_max_radius=_max_radius_from_coordinates( + _filtered_coordinates_for_elements( + measured.coordinates, + measured.elements, + set(core_counts), + ) + ), + mean_semiaxis_a=0.0, + mean_semiaxis_b=0.0, + mean_semiaxis_c=0.0, + total_observations=1, + occupied_frames=1, + mean_count_per_frame=1.0, + occupancy_fraction=1.0, + association_events=0, + dissociation_events=0, + association_rate_per_ps=0.0, + dissociation_rate_per_ps=0.0, + completed_lifetime_count=0, + window_truncated_lifetime_count=0, + mean_lifetime_fs=None, + std_lifetime_fs=None, + ) + observations.append(observation) + if representative_path == selected_path: + source_observation = observation + if not observations: + raise ValueError( + "No measured structures were available for prediction." 
+ ) + return observations, source_observation or observations[0] + + +def _predicted_target_max_radius( + measured_structures: list[_MeasuredCandidateStructure], + *, + core_elements: set[str], +) -> float: + radii = [ + _max_radius_from_coordinates( + _filtered_coordinates_for_elements( + measured.coordinates, + measured.elements, + core_elements, + ) + ) + for measured in measured_structures + ] + finite_radii = [ + float(radius) + for radius in radii + if np.isfinite(float(radius)) and float(radius) > 0.0 + ] + if not finite_radii: + return 1.0 + return float(np.median(np.asarray(finite_radii, dtype=float))) + + +def _measure_candidate_structure_file( + file_path: Path, + *, + relative_label: str, + motif_label: str, + analyzer: BondAnalyzer, + include_parsed_structure: bool, +) -> _MeasuredCandidateStructure: + coordinates, elements = load_structure_file(file_path) + coordinates_array = np.asarray(coordinates, dtype=float) + normalized_elements = tuple(str(element).strip() for element in elements) + bond_values, angle_values = analyzer.measure_structure_data( + coordinates_array, + normalized_elements, + ) + element_counts = Counter(normalized_elements) + parsed_structure = None + if include_parsed_structure: + parsed_structure = ParsedContrastStructure( + file_path=file_path, + coordinates=coordinates_array, + elements=normalized_elements, + element_counts=dict(sorted(element_counts.items())), + ) + candidate = RepresentativeFinderCandidate( + file_path=file_path, + relative_label=relative_label, + motif_label=motif_label, + atom_count=len(normalized_elements), + element_counts=dict(sorted(element_counts.items())), + bond_values=bond_values, + angle_values=angle_values, + solvent_metrics={}, + solvent_atom_count=0, + direct_solvent_atom_count=0, + outer_solvent_atom_count=0, + mean_direct_solvent_coordination=0.0, + ) + return _MeasuredCandidateStructure( + candidate=candidate, + coordinates=coordinates_array, + elements=normalized_elements, + 
parsed_structure=parsed_structure, + ) + + +def _apply_solvent_descriptor( + candidate: RepresentativeFinderCandidate, + *, + parsed_structure: ParsedContrastStructure, + expected_core_counts: dict[str, int], + pair_contact_distance_medians: dict[tuple[str, str], float], +) -> None: + descriptor = describe_parsed_contrast_structure( + parsed_structure, + expected_core_counts=expected_core_counts, + pair_contact_distance_medians=pair_contact_distance_medians, + include_geometry_metrics=False, + ) + candidate.solvent_metrics = descriptor.solvent_metrics() + candidate.solvent_atom_count = descriptor.solvent_atom_count + candidate.direct_solvent_atom_count = descriptor.direct_solvent_atom_count + candidate.outer_solvent_atom_count = descriptor.outer_solvent_atom_count + candidate.mean_direct_solvent_coordination = float( + descriptor.mean_direct_solvent_coordination + ) + candidate.descriptor_notes = tuple(descriptor.notes) + + +def _score_candidate_against_targets( + candidate: RepresentativeFinderCandidate, + *, + settings: RepresentativeFinderSettings, + target_bond_features: dict[object, np.ndarray], + target_angle_features: dict[object, np.ndarray], + target_solvent_metrics: dict[str, float], + include_solvent_component: bool, +) -> None: + if settings.selection_algorithm == "target_distribution_moment_distance": + bond_score = _category_moment_distance_from_features( + target_bond_features, + candidate.bond_values, + ) + angle_score = _category_moment_distance_from_features( + target_angle_features, + candidate.angle_values, + ) + else: + bond_score = _category_distance_from_features( + target_bond_features, + candidate.bond_values, + quantiles=settings.quantiles, + ) + angle_score = _category_distance_from_features( + target_angle_features, + candidate.angle_values, + quantiles=settings.quantiles, + ) + solvent_score: float | None + total_score = float( + settings.bond_weight * bond_score + settings.angle_weight * angle_score + ) + if include_solvent_component: 
+ solvent_score = _score_feature_map( + candidate.solvent_metrics, + target_solvent_metrics, + default_scale=1.0, + ) + total_score += settings.solvent_weight * float(solvent_score) + else: + solvent_score = 0.0 if settings.solvent_weight <= 0.0 else None + candidate.score_bond = float(bond_score) + candidate.score_angle = float(angle_score) + candidate.score_solvent = solvent_score + candidate.score_total = total_score + + +def _refine_predicted_coordinates( + *, + elements: tuple[str, ...], + coordinates: np.ndarray, + analyzer: BondAnalyzer, + settings: RepresentativeFinderSettings, + target_bond_features: dict[object, np.ndarray], + target_angle_features: dict[object, np.ndarray], +) -> np.ndarray: + coordinate_array = np.asarray(coordinates, dtype=float) + if ( + coordinate_array.ndim != 2 + or coordinate_array.shape[0] <= 1 + or (not target_bond_features and not target_angle_features) + ): + return coordinate_array + seed_basis = "|".join(elements) + f":{coordinate_array.shape[0]}" + rng = np.random.default_rng(_stable_seed_from_text(seed_basis)) + best_coordinates = coordinate_array.copy() + best_score = _geometry_only_score_from_coordinates( + elements=elements, + coordinates=best_coordinates, + analyzer=analyzer, + settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + ) + for step in range(48): + perturbation_scale = max(0.03, 0.22 * (0.95**step)) + trial = best_coordinates.copy() + trial[1:] += rng.normal( + 0.0, + perturbation_scale, + size=trial[1:].shape, + ) + trial -= np.mean(trial, axis=0, keepdims=True) + trial_score = _geometry_only_score_from_coordinates( + elements=elements, + coordinates=trial, + analyzer=analyzer, + settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + ) + if trial_score < best_score: + best_coordinates = trial + best_score = trial_score + return best_coordinates + + +def _geometry_only_score_from_coordinates( 
+ *, + elements: tuple[str, ...], + coordinates: np.ndarray, + analyzer: BondAnalyzer, + settings: RepresentativeFinderSettings, + target_bond_features: dict[object, np.ndarray], + target_angle_features: dict[object, np.ndarray], +) -> float: + bond_values, angle_values = analyzer.measure_structure_data( + np.asarray(coordinates, dtype=float), + elements, + ) + temp_candidate = RepresentativeFinderCandidate( + file_path=Path("."), + relative_label="predicted_trial", + motif_label="predicted_trial", + atom_count=len(elements), + element_counts=dict(sorted(Counter(elements).items())), + bond_values=bond_values, + angle_values=angle_values, + solvent_metrics={}, + solvent_atom_count=0, + direct_solvent_atom_count=0, + outer_solvent_atom_count=0, + mean_direct_solvent_coordination=0.0, + ) + _score_candidate_against_targets( + temp_candidate, + settings=settings, + target_bond_features=target_bond_features, + target_angle_features=target_angle_features, + target_solvent_metrics={}, + include_solvent_component=False, + ) + return float(temp_candidate.score_total or 0.0) + + +def _filtered_coordinates_for_elements( + coordinates: np.ndarray, + elements: tuple[str, ...], + allowed_elements: set[str], +) -> np.ndarray: + if not allowed_elements: + return np.asarray(coordinates, dtype=float) + indices = [ + index + for index, element in enumerate(elements) + if element in allowed_elements + ] + if not indices: + return np.asarray(coordinates, dtype=float) + return np.asarray(coordinates, dtype=float)[indices] + + +def _max_radius_from_coordinates(coordinates: np.ndarray) -> float: + coordinate_array = np.asarray(coordinates, dtype=float) + if coordinate_array.size <= 0: + return 0.0 + centered = coordinate_array - np.mean( + coordinate_array, + axis=0, + keepdims=True, + ) + radial = np.linalg.norm(centered, axis=1) + if radial.size <= 0: + return 0.0 + return float(np.max(radial)) + + +def _write_xyz_structure_file( + output_path: Path, + elements: tuple[str, ...], + 
coordinates: np.ndarray, + *, + comment: str, +) -> Path: + coordinate_array = np.asarray(coordinates, dtype=float) + lines = [str(len(elements)), str(comment).strip() or output_path.stem] + for element, (x_coord, y_coord, z_coord) in zip( + elements, + coordinate_array, + strict=False, + ): + lines.append(f"{element} {x_coord:.6f} {y_coord:.6f} {z_coord:.6f}") + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return output_path + + +def _stable_seed_from_text(text: str) -> int: + seed = 0 + for character in str(text): + seed = (seed * 131 + ord(character)) % (2**32) + return int(seed) + + +def persist_representativefinder_result_to_project( + project_dir: str | Path, + result: RepresentativeFinderResult, +) -> Path: + from saxshell.fullrmc.project_model import ensure_rmcsetup_structure + from saxshell.fullrmc.representatives import ( + DistributionSelectionEntry, + DistributionSelectionMetadata, + RepresentativeSelectionEntry, + RepresentativeSelectionMetadata, + RepresentativeSelectionSettings, + load_representative_selection_metadata, + save_representative_selection_metadata, + ) + from saxshell.fullrmc.solvent_handling import ( + load_solvent_handling_metadata, + save_solvent_handling_metadata, + ) + from saxshell.saxs.project_manager import DreamBestFitSelection + + resolved_project_dir = Path(project_dir).expanduser().resolve() + rmcsetup_paths = ensure_rmcsetup_structure(resolved_project_dir) + solvent_metadata = load_solvent_handling_metadata( + rmcsetup_paths.solvent_handling_path + ) + source_solvent_mode = _classify_project_representative_source_solvent_mode( + result, + solvent_metadata=solvent_metadata, + rmcsetup_paths=rmcsetup_paths, + ) + shared_output_path, mirrored_output_paths = ( + _copy_project_representative_file( + rmcsetup_paths, + result, + source_solvent_mode=source_solvent_mode, + ) + ) + project_cached_results_path = _write_project_cached_result( + 
rmcsetup_paths, + result, + ) + updated_key = ( + result.structure_label, + "no_motif", + result.structure_label, + ) + + now = datetime.now().isoformat(timespec="seconds") + existing_metadata = load_representative_selection_metadata( + rmcsetup_paths.representative_selection_path + ) + _remove_stale_project_representative_artifacts( + rmcsetup_paths, + updated_key, + existing_metadata=existing_metadata, + solvent_metadata=solvent_metadata, + preserved_paths=tuple(mirrored_output_paths), + ) + selection = ( + existing_metadata.selection + if existing_metadata is not None + else DreamBestFitSelection( + run_name="Representative Structure Finder", + run_relative_path=str( + rmcsetup_paths.representative_clusters_dir.relative_to( + resolved_project_dir + ) + ), + label="Representative Structure Finder", + selection_source="representativefinder", + selected_at=now, + ) + ) + selection.run_relative_path = str( + rmcsetup_paths.representative_clusters_dir.relative_to( + resolved_project_dir + ) + ) + selection.label = selection.label or "Representative Structure Finder" + selection.selection_source = ( + selection.selection_source or "representativefinder" + ) + selection.selected_at = now + + updated_entry = RepresentativeSelectionEntry( + structure=result.structure_label, + motif="no_motif", + param=result.structure_label, + selected_weight=0.0, + cluster_count=max(len(result.candidates), 1), + source_dir=str(shared_output_path.parent), + source_file=str(shared_output_path), + source_file_name=shared_output_path.name, + atom_count=int(result.selected_candidate.atom_count), + element_counts=dict( + sorted(result.selected_candidate.element_counts.items()) + ), + source_solvent_mode=source_solvent_mode, + analysis_source=( + "representativefinder:" + f"{result.selected_candidate.relative_label}" + ), + score_total=_optional_float(result.selected_candidate.score_total), + score_bond=_optional_float(result.selected_candidate.score_bond), + 
score_angle=_optional_float(result.selected_candidate.score_angle), + cached_results_path=str(result.summary_json_path), + project_cached_results_path=str(project_cached_results_path), + ) + + merged_entries = _merge_project_representative_entries( + existing_metadata, + updated_entry, + ) + _reweight_project_representative_entries(merged_entries) + distribution_entries = [ + DistributionSelectionEntry( + param=entry.param, + structure=entry.structure, + motif=entry.motif, + selected_weight=float(entry.selected_weight), + vary=True, + cluster_count=int(entry.cluster_count), + source_dir=entry.source_dir, + source_file=entry.source_file, + source_file_name=entry.source_file_name, + source_kind="representative_structure", + is_active=True, + ) + for entry in merged_entries + ] + distribution_selection = DistributionSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + run_dir=str(rmcsetup_paths.representative_clusters_dir), + updated_at=now, + entries=distribution_entries, + ) + metadata = RepresentativeSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + distribution_selection=distribution_selection, + settings=RepresentativeSelectionSettings( + selection_mode="representative_finder", + selection_algorithm=result.settings.selection_algorithm, + minimum_cluster_count_for_analysis=1, + bond_weight=float(result.settings.bond_weight), + angle_weight=float(result.settings.angle_weight), + quantiles=tuple( + float(value) for value in result.settings.quantiles + ), + bond_pairs=tuple(result.settings.bond_pairs), + angle_triplets=tuple(result.settings.angle_triplets), + ), + updated_at=now, + representative_entries=merged_entries, + missing_bins=[], + invalid_bins=[], + ) + save_representative_selection_metadata( + rmcsetup_paths.representative_selection_path, + metadata, + ) + if solvent_metadata is not None and solvent_metadata.entries: + filtered_entries = [ + entry + for entry in solvent_metadata.entries + 
if not _entry_matches_project_representative_key( + entry, + updated_key, + ) + ] + if len(filtered_entries) != len(solvent_metadata.entries): + solvent_metadata.entries = filtered_entries + solvent_metadata.updated_at = now + save_solvent_handling_metadata( + rmcsetup_paths.solvent_handling_path, + solvent_metadata, + ) + return shared_output_path + + +def _write_project_cached_result( + rmcsetup_paths, + result: RepresentativeFinderResult, +) -> Path: + cache_dir = ( + rmcsetup_paths.representative_clusters_dir + / "analysis_cache" + / _safe_folder_name(result.structure_label) + ) + cache_dir.mkdir(parents=True, exist_ok=True) + cached_result_path = cache_dir / "representative_selection.json" + cached_score_table_path = cache_dir / "candidate_scores.tsv" + cached_summary_text_path = cache_dir / "selection_summary.txt" + + payload = result.to_dict() + payload["project_cached_at"] = datetime.now().isoformat(timespec="seconds") + payload["summary_json_path"] = str(cached_result_path) + payload["score_table_path"] = str(cached_score_table_path) + payload["summary_text_path"] = str(cached_summary_text_path) + cached_result_path.write_text( + json.dumps(payload, indent=2) + "\n", + encoding="utf-8", + ) + if result.score_table_path.is_file(): + shutil.copy2(result.score_table_path, cached_score_table_path) + else: + _write_candidate_score_table(result, cached_score_table_path) + if result.summary_text_path.is_file(): + shutil.copy2(result.summary_text_path, cached_summary_text_path) + else: + cached_summary_text_path.write_text( + result.summary_text() + "\n", + encoding="utf-8", + ) + return cached_result_path.resolve() + + +def _write_candidate_score_table( + result: RepresentativeFinderResult, + output_path: Path, +) -> None: + with output_path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.writer(handle, delimiter="\t") + writer.writerow( + [ + "rank", + "file_name", + "relative_label", + "motif_label", + "score_total", + "score_bond", + 
"score_angle", + "score_solvent", + "atom_count", + "solvent_atom_count", + "direct_solvent_atom_count", + "outer_solvent_atom_count", + "mean_direct_solvent_coordination", + ] + ) + for rank, candidate in enumerate(result.candidates, start=1): + writer.writerow( + [ + rank, + candidate.file_name, + candidate.relative_label, + candidate.motif_label, + _format_score(candidate.score_total), + _format_score(candidate.score_bond), + _format_score(candidate.score_angle), + _format_score(candidate.score_solvent), + candidate.atom_count, + candidate.solvent_atom_count, + candidate.direct_solvent_atom_count, + candidate.outer_solvent_atom_count, + f"{candidate.mean_direct_solvent_coordination:.8f}", + ] + ) + + +def _copy_project_representative_file( + rmcsetup_paths, + result: RepresentativeFinderResult, + *, + source_solvent_mode: str, +) -> tuple[Path, tuple[Path, ...]]: + relative_label = re.sub( + r"[^0-9A-Za-z._-]+", + "_", + result.selected_candidate.relative_label, + ).strip("_") + destination_name = ( + f"{_safe_folder_name(result.structure_label)}__representative__" + f"{relative_label or result.selected_candidate.file_name}" + ) + destination_paths: list[Path] = [] + for destination_root in _project_representative_destination_roots( + rmcsetup_paths, + result, + source_solvent_mode=source_solvent_mode, + ): + destination_dir = destination_root / _safe_folder_name( + result.structure_label + ) + destination_dir.mkdir(parents=True, exist_ok=True) + destination = destination_dir / destination_name + shutil.copy2(result.selected_candidate.file_path, destination) + destination_paths.append(destination.resolve()) + primary_destination = next( + ( + path + for path in destination_paths + if path.parent.parent.resolve() + == _project_representative_destination_root( + rmcsetup_paths, + source_solvent_mode=source_solvent_mode, + ).resolve() + ), + destination_paths[0], + ) + return primary_destination, tuple(destination_paths) + + +def 
_project_representative_destination_roots( + rmcsetup_paths, + result: RepresentativeFinderResult, + *, + source_solvent_mode: str, +) -> tuple[Path, ...]: + primary_root = _project_representative_destination_root( + rmcsetup_paths, + source_solvent_mode=source_solvent_mode, + ) + if int(result.selected_candidate.atom_count) != 1: + return (primary_root,) + return ( + rmcsetup_paths.pdb_no_solvent_dir, + rmcsetup_paths.representative_partial_solvent_dir, + rmcsetup_paths.pdb_with_solvent_dir, + ) + + +def _project_representative_destination_root( + rmcsetup_paths, + *, + source_solvent_mode: str, +) -> Path: + normalized = str(source_solvent_mode).strip().lower() + if normalized == "nosolv": + return rmcsetup_paths.pdb_no_solvent_dir + if normalized == "fullsolv": + return rmcsetup_paths.pdb_with_solvent_dir + return rmcsetup_paths.representative_partial_solvent_dir + + +def _classify_project_representative_source_solvent_mode( + result: RepresentativeFinderResult, + *, + solvent_metadata, + rmcsetup_paths, +) -> str: + candidate = result.selected_candidate + if int(candidate.solvent_atom_count) <= 0: + return "nosolv" + inferred_from_path = _infer_project_representative_source_mode_from_path( + candidate.file_path, + rmcsetup_paths=rmcsetup_paths, + ) + if inferred_from_path is not None: + return inferred_from_path + if solvent_metadata is None: + return "partialsolv" + try: + from saxshell.fullrmc.solvent_shell_builder import ( + analyze_solvent_shell, + ) + except Exception: + return "partialsolv" + try: + analysis_result = analyze_solvent_shell( + candidate.file_path, + _solvent_reference_identifier_from_metadata(solvent_metadata), + reference_match_tolerance_a=( + solvent_metadata.settings.reference_match_tolerance_a + ), + ) + except Exception: + return "partialsolv" + if analysis_result.complete_solvent_molecule_count > 0: + if analysis_result.partial_solvent_molecule_count > 0: + return "partialsolv" + return "fullsolv" + if 
analysis_result.partial_solvent_molecule_count > 0: + return "partialsolv" + return "nosolv" + + +def _infer_project_representative_source_mode_from_path( + file_path: Path, + *, + rmcsetup_paths, +) -> str | None: + resolved_path = Path(file_path).expanduser().resolve() + for mode, directory in ( + ("nosolv", rmcsetup_paths.pdb_no_solvent_dir), + ("partialsolv", rmcsetup_paths.representative_partial_solvent_dir), + ("fullsolv", rmcsetup_paths.pdb_with_solvent_dir), + ): + try: + resolved_path.relative_to(directory.resolve()) + except ValueError: + continue + return mode + return None + + +def _remove_stale_project_representative_artifacts( + rmcsetup_paths, + target_key: tuple[str, str, str], + *, + existing_metadata, + solvent_metadata, + preserved_paths: tuple[Path, ...] = (), +) -> None: + preserved = {Path(path).expanduser().resolve() for path in preserved_paths} + candidate_paths: set[Path] = set() + if existing_metadata is not None: + candidate_paths.update( + _tracked_representative_source_paths_for_key( + existing_metadata, + target_key, + ) + ) + if solvent_metadata is not None: + candidate_paths.update( + _tracked_solvent_output_paths_for_key( + solvent_metadata, + target_key, + ) + ) + allowed_roots = ( + rmcsetup_paths.representative_clusters_dir.resolve(), + rmcsetup_paths.representative_partial_solvent_dir.resolve(), + rmcsetup_paths.pdb_no_solvent_dir.resolve(), + rmcsetup_paths.pdb_with_solvent_dir.resolve(), + ) + for path in candidate_paths: + resolved = Path(path).expanduser().resolve() + if resolved in preserved or not resolved.is_file(): + continue + if not any( + _path_is_within_dir(resolved, root) for root in allowed_roots + ): + continue + resolved.unlink() + + +def _tracked_representative_source_paths_for_key( + metadata, + target_key: tuple[str, str, str], +) -> set[Path]: + return { + Path(entry.source_file).expanduser().resolve() + for entry in metadata.representative_entries + if _entry_matches_project_representative_key(entry, 
target_key) + and str(entry.source_file).strip() + } + + +def _tracked_solvent_output_paths_for_key( + solvent_metadata, + target_key: tuple[str, str, str], +) -> set[Path]: + tracked_paths: set[Path] = set() + for entry in solvent_metadata.entries: + if not _entry_matches_project_representative_key(entry, target_key): + continue + for candidate in (entry.no_solvent_pdb, entry.completed_pdb): + text = str(candidate).strip() + if text: + tracked_paths.add(Path(text).expanduser().resolve()) + return tracked_paths + + +def _entry_matches_project_representative_key( + entry, + target_key: tuple[str, str, str], +) -> bool: + return ( + str(entry.structure).strip(), + str(entry.motif).strip() or "no_motif", + str(entry.param).strip(), + ) == target_key + + +def _path_is_within_dir(path: Path, directory: Path) -> bool: + try: + path.relative_to(directory) + except ValueError: + return False + return True + + +def _solvent_reference_identifier_from_metadata(solvent_metadata) -> str: + settings = solvent_metadata.settings + if ( + str(settings.reference_source).strip().lower() == "custom" + and settings.custom_reference_path + ): + return str(Path(settings.custom_reference_path).expanduser().resolve()) + return str(settings.preset_name).strip() or "dmf" + + +def _merge_project_representative_entries( + existing_metadata, + updated_entry, +): + existing_entries = ( + [] + if existing_metadata is None + else list(existing_metadata.representative_entries) + ) + updated_key = ( + updated_entry.structure, + updated_entry.motif, + updated_entry.param, + ) + merged_entries = [ + entry + for entry in existing_entries + if ( + str(entry.structure).strip(), + str(entry.motif).strip() or "no_motif", + str(entry.param).strip(), + ) + != updated_key + ] + merged_entries.append(updated_entry) + merged_entries.sort( + key=lambda entry: ( + _natural_sort_key(entry.structure), + _natural_sort_key(entry.motif), + _natural_sort_key(entry.param), + ) + ) + return merged_entries + + +def 
_reweight_project_representative_entries(entries) -> None: + total_cluster_count = max( + sum(max(int(entry.cluster_count), 0) for entry in entries), + 1, + ) + for entry in entries: + entry.selected_weight = ( + max(int(entry.cluster_count), 0) / total_cluster_count + ) + + +def _next_available_output_dir(parent_dir: Path, folder_name: str) -> Path: + candidate = parent_dir / folder_name + if not candidate.exists(): + return candidate + index = 1 + while True: + candidate = parent_dir / f"{folder_name}{index:04d}" + if not candidate.exists(): + return candidate + index += 1 + + +def _safe_folder_name(text: str) -> str: + cleaned = re.sub(r"[^0-9A-Za-z]+", "_", str(text)).strip("_") + return cleaned or "structure_folder" + + +def _natural_sort_key(value: str) -> list[object]: + return [ + int(token) if token.isdigit() else token.lower() + for token in re.split(r"(\d+)", value) + if token + ] + + +def _optional_float(value: float | None) -> float | None: + if value is None: + return None + return float(value) + + +def _format_score(value: float | None) -> str: + if value is None: + return "n/a" + return f"{float(value):.8f}" + + +__all__ = [ + "RepresentativeFinderCandidate", + "RepresentativeFinderFolderInspection", + "RepresentativeFinderInputInspection", + "RepresentativeFinderOperationCancelled", + "RepresentativeFinderPlotSeries", + "RepresentativeFinderResult", + "RepresentativeFinderSettings", + "analyze_representative_structure_folder", + "estimate_representativefinder_total_work", + "inspect_representative_structure_input", + "inspect_representative_structure_folder", + "load_representativefinder_result", + "persist_representativefinder_result_to_project", + "representativefinder_result_from_dict", + "representativefinder_settings_from_dict", + "suggest_representativefinder_output_dir", + "suggest_representativefinder_target_output_dir", +] diff --git a/src/saxshell/saxs/contrast/descriptors.py b/src/saxshell/saxs/contrast/descriptors.py index 
59a7a60..c9408bc 100644 --- a/src/saxshell/saxs/contrast/descriptors.py +++ b/src/saxshell/saxs/contrast/descriptors.py @@ -55,6 +55,19 @@ def _mean(values: list[float] | tuple[float, ...]) -> float: return float(np.mean(np.asarray(values, dtype=float))) +def _element_index_lookup( + elements: tuple[str, ...], +) -> dict[str, np.ndarray]: + grouped: defaultdict[str, list[int]] = defaultdict(list) + for atom_index, element in enumerate(elements): + grouped[str(element)].append(int(atom_index)) + return { + element: np.asarray(indices, dtype=int) + for element, indices in sorted(grouped.items()) + if indices + } + + def _angle_between_vectors( vector_a: np.ndarray, vector_b: np.ndarray, @@ -211,23 +224,41 @@ def estimate_pair_contact_distance_medians( pair_distances: defaultdict[str, list[float]] = defaultdict(list) for parsed in parsed_structures: coordinates = np.asarray(parsed.coordinates, dtype=float) - elements = list(parsed.elements) - for atom_index, element in enumerate(elements): - nearest_by_pair: dict[str, float] = {} - for other_index, other_element in enumerate(elements): - if atom_index == other_index: + elements = parsed.elements + if len(elements) < 2: + continue + element_indices = _element_index_lookup(elements) + ordered_elements = tuple(element_indices.keys()) + for index, element_a in enumerate(ordered_elements): + indices_a = element_indices[element_a] + coords_a = coordinates[indices_a] + for element_b in ordered_elements[index:]: + indices_b = element_indices[element_b] + if element_a == element_b and len(indices_a) < 2: continue - pair_key = _pair_key(element, other_element) - distance = float( - np.linalg.norm( - coordinates[atom_index] - coordinates[other_index] - ) + coords_b = coordinates[indices_b] + pair_key = _pair_key(element_a, element_b) + distances = np.linalg.norm( + coords_a[:, np.newaxis, :] - coords_b[np.newaxis, :, :], + axis=2, + ) + if element_a == element_b: + np.fill_diagonal(distances, np.inf) + nearest = 
distances.min(axis=1) + finite = nearest[np.isfinite(nearest)] + if finite.size > 0: + pair_distances[pair_key].extend( + finite.astype(float).tolist() + ) + continue + nearest_a = distances.min(axis=1) + nearest_b = distances.min(axis=0) + pair_distances[pair_key].extend( + nearest_a.astype(float).tolist() + ) + pair_distances[pair_key].extend( + nearest_b.astype(float).tolist() ) - previous = nearest_by_pair.get(pair_key) - if previous is None or distance < previous: - nearest_by_pair[pair_key] = distance - for pair_key, distance in nearest_by_pair.items(): - pair_distances[pair_key].append(float(distance)) return { pair_key: _median(values) for pair_key, values in sorted(pair_distances.items()) @@ -242,18 +273,39 @@ def _build_contact_neighbors( ) -> tuple[list[set[int]], set[tuple[int, int]]]: neighbors = [set() for _ in range(len(elements))] contact_pairs: set[tuple[int, int]] = set() - for index_a, index_b in combinations(range(len(elements)), 2): - pair_key = _pair_key(elements[index_a], elements[index_b]) - cutoff = float(pair_contact_distance_medians.get(pair_key, 0.0)) - if cutoff <= 0.0: - continue - distance = float( - np.linalg.norm(coordinates[index_a] - coordinates[index_b]) + if len(elements) < 2: + return neighbors, contact_pairs + + coordinate_array = np.asarray(coordinates, dtype=float) + distance_matrix = np.linalg.norm( + coordinate_array[:, np.newaxis, :] + - coordinate_array[np.newaxis, :, :], + axis=2, + ) + thresholds = { + key: float(value) * _CONTACT_DISTANCE_SCALE + for key, value in pair_contact_distance_medians.items() + if float(value) > 0.0 + } + if not thresholds: + return neighbors, contact_pairs + + row_indices, column_indices = np.triu_indices(len(elements), k=1) + for atom_index, neighbor_index, distance in zip( + row_indices.tolist(), + column_indices.tolist(), + distance_matrix[row_indices, column_indices].tolist(), + strict=False, + ): + threshold = thresholds.get( + _pair_key(elements[atom_index], elements[neighbor_index]), 
+ 0.0, ) - if distance <= cutoff * _CONTACT_DISTANCE_SCALE: - neighbors[index_a].add(index_b) - neighbors[index_b].add(index_a) - contact_pairs.add((index_a, index_b)) + if threshold <= 0.0 or float(distance) > threshold: + continue + neighbors[atom_index].add(neighbor_index) + neighbors[neighbor_index].add(atom_index) + contact_pairs.add((atom_index, neighbor_index)) return neighbors, contact_pairs @@ -329,6 +381,7 @@ def describe_parsed_contrast_structure( *, expected_core_counts: dict[str, int], pair_contact_distance_medians: dict[str, float], + include_geometry_metrics: bool = True, ) -> ContrastStructureDescriptor: coordinates = np.asarray(parsed_structure.coordinates, dtype=float) elements = parsed_structure.elements @@ -349,52 +402,54 @@ def describe_parsed_contrast_structure( solvent_set = set(solvent_indices) bond_lengths: defaultdict[str, list[float]] = defaultdict(list) - for atom_index in range(len(elements)): - for neighbor_index in neighbors[atom_index]: - if neighbor_index <= atom_index: - continue - bond_lengths[ - _pair_key(elements[atom_index], elements[neighbor_index]) - ].append( - float( - np.linalg.norm( - coordinates[atom_index] - coordinates[neighbor_index] + angle_values: defaultdict[str, list[float]] = defaultdict(list) + coordination_values: defaultdict[str, list[int]] = defaultdict(list) + if include_geometry_metrics: + for atom_index in range(len(elements)): + for neighbor_index in neighbors[atom_index]: + if neighbor_index <= atom_index: + continue + bond_lengths[ + _pair_key(elements[atom_index], elements[neighbor_index]) + ].append( + float( + np.linalg.norm( + coordinates[atom_index] + - coordinates[neighbor_index] + ) ) ) - ) - angle_values: defaultdict[str, list[float]] = defaultdict(list) - for center_index, center_neighbors in enumerate(neighbors): - if len(center_neighbors) < 2: - continue - for neighbor_a, neighbor_b in combinations( - sorted(center_neighbors), 2 - ): - angle = _angle_between_vectors( - coordinates[neighbor_a] 
- coordinates[center_index], - coordinates[neighbor_b] - coordinates[center_index], - ) - if angle is None: + for center_index, center_neighbors in enumerate(neighbors): + if len(center_neighbors) < 2: continue - angle_values[ - _triplet_key( - elements[neighbor_a], - elements[center_index], - elements[neighbor_b], + for neighbor_a, neighbor_b in combinations( + sorted(center_neighbors), 2 + ): + angle = _angle_between_vectors( + coordinates[neighbor_a] - coordinates[center_index], + coordinates[neighbor_b] - coordinates[center_index], ) - ].append(float(angle)) + if angle is None: + continue + angle_values[ + _triplet_key( + elements[neighbor_a], + elements[center_index], + elements[neighbor_b], + ) + ].append(float(angle)) - present_elements = tuple(sorted(set(elements))) - coordination_values: defaultdict[str, list[int]] = defaultdict(list) - for center_index, center_element in enumerate(elements): - neighbor_counts = Counter( - elements[neighbor_index] - for neighbor_index in neighbors[center_index] - ) - for neighbor_element in present_elements: - coordination_values[ - _coordination_key(center_element, neighbor_element) - ].append(int(neighbor_counts.get(neighbor_element, 0))) + present_elements = tuple(sorted(set(elements))) + for center_index, center_element in enumerate(elements): + neighbor_counts = Counter( + elements[neighbor_index] + for neighbor_index in neighbors[center_index] + ) + for neighbor_element in present_elements: + coordination_values[ + _coordination_key(center_element, neighbor_element) + ].append(int(neighbor_counts.get(neighbor_element, 0))) direct_solvent_indices: set[int] = set() for atom_index in solvent_set: diff --git a/tests/representativefinder_performance/benchmark_representativefinder.py b/tests/representativefinder_performance/benchmark_representativefinder.py new file mode 100644 index 0000000..5b7ce2f --- /dev/null +++ b/tests/representativefinder_performance/benchmark_representativefinder.py @@ -0,0 +1,161 @@ +from 
__future__ import annotations + +import argparse +import json +import shutil +import tempfile +import time +from datetime import datetime +from pathlib import Path + +from saxshell.bondanalysis import AngleTripletDefinition, BondPairDefinition +from saxshell.representativefinder import ( + RepresentativeFinderSettings, + analyze_representative_structure_folder, +) + +REFERENCE_PB2I4_DIR = Path( + "/Users/keithwhite/repos/cluster_extraction/" + "041_cp2k_pbi2_dmf_0p7M_RT/" + "clusters_xyz2pdb_splitxyz_f1002_t497p5fs0001/Pb2I4" +) +OUTPUT_ROOT = Path(__file__).resolve().parent / "output_results" + + +def _settings(worker_count: int) -> RepresentativeFinderSettings: + return RepresentativeFinderSettings( + bond_pairs=(BondPairDefinition("Pb", "I", 4.2),), + angle_triplets=(AngleTripletDefinition("Pb", "I", "I", 4.2, 4.2),), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + parallel_workers=worker_count, + ) + + +def _prepare_sample( + reference_dir: Path, + workspace_dir: Path, + *, + sample_count: int, +) -> tuple[Path, list[str]]: + sample_dir = workspace_dir / reference_dir.name + sample_dir.mkdir(parents=True) + source_files = sorted(reference_dir.glob("*.pdb"))[:sample_count] + for source_path in source_files: + target_path = sample_dir / source_path.name + try: + target_path.symlink_to(source_path) + except OSError: + shutil.copy2(source_path, target_path) + return sample_dir, [path.name for path in source_files] + + +def _run_once( + sample_dir: Path, + output_dir: Path, + *, + worker_count: int, +) -> dict[str, object]: + started = time.perf_counter() + result = analyze_representative_structure_folder( + sample_dir, + settings=_settings(worker_count), + output_dir=output_dir, + ) + elapsed = time.perf_counter() - started + return { + "worker_count": int(worker_count), + "elapsed_seconds": round(float(elapsed), 6), + "candidate_count": len(result.candidates), + "selected_candidate": result.selected_candidate.file_name, + "selected_score_total": 
result.selected_candidate.score_total, + "skipped_files": len(result.skipped_files), + "summary_json_path": str(result.summary_json_path), + "score_table_path": str(result.score_table_path), + "summary_text_path": str(result.summary_text_path), + } + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Benchmark representative finder serial and parallel runs." + ) + parser.add_argument( + "--reference-dir", + default=str(REFERENCE_PB2I4_DIR), + help="Folder containing Pb2I4 .pdb representative candidates.", + ) + parser.add_argument( + "--sample-size", + type=int, + default=256, + help="Number of reference files to include in the benchmark sample.", + ) + parser.add_argument( + "--workers", + type=int, + nargs="+", + default=[1, 0], + help="Worker counts to benchmark. Use 0 for auto-sized parallelism.", + ) + parser.add_argument( + "--output-root", + default=str(OUTPUT_ROOT), + help="Directory where benchmark outputs and reports are written.", + ) + args = parser.parse_args() + + reference_dir = Path(args.reference_dir).expanduser().resolve() + if not reference_dir.is_dir(): + raise SystemExit(f"Reference folder is missing: {reference_dir}") + + output_root = Path(args.output_root).expanduser().resolve() + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + run_dir = output_root / f"pb2i4_benchmark_{timestamp}" + run_dir.mkdir(parents=True, exist_ok=True) + workspace_dir = Path(tempfile.mkdtemp(prefix="pb2i4_benchmark_")) + try: + sample_dir, sample_files = _prepare_sample( + reference_dir, + workspace_dir, + sample_count=max(int(args.sample_size), 1), + ) + runs = [] + for worker_count in args.workers: + label = ( + "auto" if int(worker_count) == 0 else str(int(worker_count)) + ) + runs.append( + _run_once( + sample_dir, + run_dir / f"workers_{label}", + worker_count=int(worker_count), + ) + ) + report = { + "generated_at": datetime.now().isoformat(timespec="seconds"), + "reference_dir": str(reference_dir), + "sample_size": 
len(sample_files), + "sample_files": sample_files, + "runs": runs, + } + report_path = run_dir / "benchmark_results.json" + report_path.write_text( + json.dumps(report, indent=2) + "\n", + encoding="utf-8", + ) + latest_path = output_root / "latest_benchmark_results.json" + latest_path.write_text( + json.dumps(report, indent=2) + "\n", + encoding="utf-8", + ) + finally: + shutil.rmtree(workspace_dir, ignore_errors=True) + + print(report_path) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/representativefinder_performance/test_parallel_representativefinder.py b/tests/representativefinder_performance/test_parallel_representativefinder.py new file mode 100644 index 0000000..1717159 --- /dev/null +++ b/tests/representativefinder_performance/test_parallel_representativefinder.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import shutil +from pathlib import Path + +import pytest + +from saxshell.bondanalysis import AngleTripletDefinition, BondPairDefinition +from saxshell.representativefinder import ( + RepresentativeFinderSettings, + analyze_representative_structure_folder, +) + +REFERENCE_PB2I4_DIR = Path( + "/Users/keithwhite/repos/cluster_extraction/" + "041_cp2k_pbi2_dmf_0p7M_RT/" + "clusters_xyz2pdb_splitxyz_f1002_t497p5fs0001/Pb2I4" +) + + +def _write_xyz_structure( + path: Path, + atoms: list[tuple[str, float, float, float]], +) -> None: + lines = [str(len(atoms)), path.stem] + for element, x_coord, y_coord, z_coord in atoms: + lines.append(f"{element} {x_coord:.3f} {y_coord:.3f} {z_coord:.3f}") + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def _build_representative_sample_folder(tmp_path: Path) -> Path: + stoich_dir = tmp_path / "PbI2" + stoich_dir.mkdir(parents=True) + for index, distance in enumerate( + (1.95, 2.05, 2.15, 2.25, 2.35, 2.45, 2.55, 2.65), + start=1, + ): + _write_xyz_structure( + stoich_dir / f"candidate_{index:02d}.xyz", + [ + ("Pb", 0.0, 0.0, 0.0), + ("I", distance, 0.0, 
0.0), + ("I", 0.0, distance, 0.0), + ("O", 0.0, 0.0, distance + 0.35), + ("O", 0.0, 0.0, distance + 2.7), + ], + ) + return stoich_dir + + +def _settings(worker_count: int) -> RepresentativeFinderSettings: + return RepresentativeFinderSettings( + bond_pairs=(BondPairDefinition("Pb", "I", 4.2),), + angle_triplets=(AngleTripletDefinition("Pb", "I", "I", 4.2, 4.2),), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + parallel_workers=worker_count, + ) + + +def _link_reference_sample( + reference_dir: Path, + sample_dir: Path, + *, + sample_count: int, +) -> None: + sample_dir.mkdir(parents=True) + for source_path in sorted(reference_dir.glob("*.pdb"))[:sample_count]: + target_path = sample_dir / source_path.name + try: + target_path.symlink_to(source_path) + except OSError: + shutil.copy2(source_path, target_path) + + +def test_parallel_representativefinder_matches_serial_result(tmp_path): + stoich_dir = _build_representative_sample_folder(tmp_path) + + serial = analyze_representative_structure_folder( + stoich_dir, + settings=_settings(1), + output_dir=tmp_path / "serial_output", + ) + parallel = analyze_representative_structure_folder( + stoich_dir, + settings=_settings(4), + output_dir=tmp_path / "parallel_output", + ) + + assert parallel.selected_candidate.file_name == ( + serial.selected_candidate.file_name + ) + assert len(parallel.candidates) == len(serial.candidates) + serial_scores = { + candidate.relative_label: candidate.score_total + for candidate in serial.candidates + } + for candidate in parallel.candidates: + assert candidate.score_total == pytest.approx( + serial_scores[candidate.relative_label], + rel=0.0, + abs=1.0e-12, + ) + assert parallel.summary_json_path.is_file() + assert parallel.score_table_path.is_file() + + +@pytest.mark.skipif( + not REFERENCE_PB2I4_DIR.is_dir(), + reason="Pb2I4 representative benchmark reference folder is unavailable.", +) +def test_reference_pb2i4_sample_matches_serial_and_parallel(tmp_path): + sample_dir = 
tmp_path / "Pb2I4" + _link_reference_sample( + REFERENCE_PB2I4_DIR, + sample_dir, + sample_count=32, + ) + + serial = analyze_representative_structure_folder( + sample_dir, + settings=_settings(1), + output_dir=tmp_path / "reference_serial_output", + ) + parallel = analyze_representative_structure_folder( + sample_dir, + settings=_settings(4), + output_dir=tmp_path / "reference_parallel_output", + ) + + assert len(serial.candidates) == 32 + assert len(parallel.candidates) == 32 + assert parallel.selected_candidate.file_name == ( + serial.selected_candidate.file_name + ) + assert parallel.selected_candidate.score_total == pytest.approx( + serial.selected_candidate.score_total, + rel=0.0, + abs=1.0e-12, + ) diff --git a/tests/test_representativefinder.py b/tests/test_representativefinder.py new file mode 100644 index 0000000..8890d46 --- /dev/null +++ b/tests/test_representativefinder.py @@ -0,0 +1,1676 @@ +from __future__ import annotations + +import os +import time +from datetime import datetime +from pathlib import Path + +import numpy as np +import pytest +from PySide6.QtWidgets import QApplication, QPushButton, QScrollArea, QSplitter + +from saxshell.bondanalysis import ( + AngleTripletDefinition, + BondAnalysisPreset, + BondPairDefinition, +) +from saxshell.fullrmc.project_model import ensure_rmcsetup_structure +from saxshell.fullrmc.representatives import ( + load_representative_selection_metadata, +) +from saxshell.fullrmc.solvent_handling import ( + SolventHandlingEntry, + SolventHandlingMetadata, + SolventHandlingSettings, + available_representative_structure_modes, + load_solvent_handling_metadata, + representative_structure_mode_is_ready, + representative_structure_path_for_mode, + save_solvent_handling_metadata, +) +from saxshell.representativefinder import ( + RepresentativeFinderCandidate, + RepresentativeFinderOperationCancelled, + RepresentativeFinderResult, + RepresentativeFinderSettings, + analyze_representative_structure_folder, + 
build_representativefinder_run_config, + default_representativefinder_run_file_path, + inspect_representative_structure_input, + load_representativefinder_result, + load_representativefinder_run_config, + persist_representativefinder_result_to_project, + run_representativefinder_run_config, + save_representativefinder_run_config, +) +from saxshell.representativefinder.cli import ( + main as representativefinder_cli_main, +) +from saxshell.representativefinder.ui.main_window import ( + RepresentativeStructureFinderMainWindow, +) +from saxshell.representativefinder.ui.run_file_window import ( + RepresentativeFinderRunFileWindow, +) +from saxshell.saxs.electron_density_mapping.ui.viewer import ( + ElectronDensityStructureViewer, +) +from saxshell.saxs.project_manager import SAXSProjectManager +from saxshell.structure import PDBAtom, PDBStructure +from saxshell.xyz2pdb import create_reference_molecule + + +def _write_xyz_structure( + path: Path, atoms: list[tuple[str, float, float, float]] +) -> None: + lines = [str(len(atoms)), path.stem] + for element, x_coord, y_coord, z_coord in atoms: + lines.append(f"{element} {x_coord:.3f} {y_coord:.3f} {z_coord:.3f}") + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def _build_test_solvent_reference_library( + tmp_path: Path, +) -> tuple[Path, Path]: + reference_source = tmp_path / "water_source.pdb" + PDBStructure( + atoms=[ + PDBAtom( + atom_id=1, + atom_name="O1", + residue_name="HOH", + residue_number=1, + coordinates=np.array([0.0, 0.0, 0.0], dtype=float), + element="O", + ), + PDBAtom( + atom_id=2, + atom_name="H1", + residue_name="HOH", + residue_number=1, + coordinates=np.array([0.958, 0.0, 0.0], dtype=float), + element="H", + ), + PDBAtom( + atom_id=3, + atom_name="H2", + residue_name="HOH", + residue_number=1, + coordinates=np.array([-0.239, 0.927, 0.0], dtype=float), + element="H", + ), + ], + source_name="water_source", + ).write_pdb_file(reference_source) + reference_library_dir = tmp_path / 
"reference_library" + reference_library_dir.mkdir(parents=True, exist_ok=True) + result = create_reference_molecule( + reference_source, + reference_name="water_test", + residue_name="HOH", + library_dir=reference_library_dir, + ) + return reference_library_dir, result.path + + +def _write_complete_solvent_representative_pdb( + tmp_path: Path, + *, + reference_path: Path, +) -> Path: + reference_structure = PDBStructure.from_file(reference_path) + atoms = [ + PDBAtom( + atom_id=1, + atom_name="PB1", + residue_name="PBI", + residue_number=1, + coordinates=np.array([0.0, 0.0, 0.0], dtype=float), + element="Pb", + ) + ] + atom_id = 2 + for residue_name, residue_number, shift in ( + ("HOH", 2, np.array([3.0, 0.0, 0.0], dtype=float)), + ("ALT", 3, np.array([6.0, 0.0, 0.0], dtype=float)), + ): + for reference_atom in reference_structure.atoms: + atoms.append( + PDBAtom( + atom_id=atom_id, + atom_name=reference_atom.atom_name, + residue_name=residue_name, + residue_number=residue_number, + coordinates=reference_atom.coordinates.copy() + shift, + element=reference_atom.element, + ) + ) + atom_id += 1 + output_path = tmp_path / "fullsolv_candidate.pdb" + PDBStructure(atoms=atoms, source_name="fullsolv_candidate").write_pdb_file( + output_path + ) + return output_path + + +def _build_manual_representative_result( + structure_path: Path, + *, + output_dir: Path, + structure_label: str = "Pb", + atom_count: int | None = None, + element_counts: dict[str, int] | None = None, + solvent_atom_count: int = 0, +) -> RepresentativeFinderResult: + output_dir.mkdir(parents=True, exist_ok=True) + candidate = RepresentativeFinderCandidate( + file_path=structure_path.resolve(), + relative_label=structure_path.name, + motif_label="no_motif", + atom_count=int(atom_count if atom_count is not None else 1), + element_counts=dict(element_counts or {"Pb": 1}), + bond_values={}, + angle_values={}, + solvent_metrics={}, + solvent_atom_count=int(solvent_atom_count), + 
direct_solvent_atom_count=int(solvent_atom_count), + outer_solvent_atom_count=0, + mean_direct_solvent_coordination=0.0, + score_total=0.0, + score_bond=0.0, + score_angle=0.0, + score_solvent=0.0, + ) + summary_json_path = output_dir / "summary.json" + summary_json_path.write_text("{}", encoding="utf-8") + score_table_path = output_dir / "scores.tsv" + score_table_path.write_text("", encoding="utf-8") + summary_text_path = output_dir / "summary.txt" + summary_text_path.write_text("", encoding="utf-8") + representative_output_path = output_dir / structure_path.name + representative_output_path.write_text("", encoding="utf-8") + return RepresentativeFinderResult( + input_dir=structure_path.parent.resolve(), + output_dir=output_dir.resolve(), + structure_label=structure_label, + expected_core_counts=dict(element_counts or {"Pb": 1}), + settings=RepresentativeFinderSettings(), + generated_at=datetime.now().isoformat(timespec="seconds"), + candidates=(candidate,), + selected_candidate=candidate, + representative_output_path=representative_output_path, + skipped_files=(), + target_bond_values={}, + target_angle_values={}, + target_solvent_metrics={}, + summary_json_path=summary_json_path, + score_table_path=score_table_path, + summary_text_path=summary_text_path, + ) + + +def _build_representative_test_folder(tmp_path: Path) -> Path: + stoich_dir = tmp_path / "PbI2" + motif_dir = stoich_dir / "motif_corner" + stoich_dir.mkdir(parents=True) + motif_dir.mkdir(parents=True) + + _write_xyz_structure( + stoich_dir / "candidate_low.xyz", + [ + ("Pb", 0.0, 0.0, 0.0), + ("I", 2.0, 0.0, 0.0), + ("I", 0.0, 2.0, 0.0), + ("O", 0.0, 0.0, 2.4), + ("O", 0.0, 0.0, 4.9), + ], + ) + _write_xyz_structure( + motif_dir / "candidate_mid.xyz", + [ + ("Pb", 0.0, 0.0, 0.0), + ("I", 2.2, 0.0, 0.0), + ("I", 0.0, 2.2, 0.0), + ("O", 0.0, 0.0, 2.6), + ("O", 0.0, 0.0, 5.2), + ], + ) + _write_xyz_structure( + stoich_dir / "candidate_high.xyz", + [ + ("Pb", 0.0, 0.0, 0.0), + ("I", 2.8, 0.0, 0.0), + 
("I", 0.0, 2.8, 0.0), + ("O", 0.0, 0.0, 1.8), + ], + ) + return stoich_dir + + +def _build_single_atom_test_folder(tmp_path: Path) -> Path: + stoich_dir = tmp_path / "I" + stoich_dir.mkdir(parents=True) + _write_xyz_structure( + stoich_dir / "candidate_01.xyz", + [("I", 0.0, 0.0, 0.0)], + ) + _write_xyz_structure( + stoich_dir / "candidate_02.xyz", + [("I", 0.0, 0.0, 0.0)], + ) + return stoich_dir + + +def _build_multi_stoichiometry_root(tmp_path: Path) -> tuple[Path, Path, Path]: + root_dir = tmp_path / "cluster_root" + root_dir.mkdir(parents=True) + pb_dir = _build_representative_test_folder(root_dir) + + sn_dir = root_dir / "SnBr2" + sn_motif_dir = sn_dir / "motif_edge" + sn_dir.mkdir(parents=True) + sn_motif_dir.mkdir(parents=True) + _write_xyz_structure( + sn_dir / "candidate_low.xyz", + [ + ("Sn", 0.0, 0.0, 0.0), + ("Br", 2.0, 0.0, 0.0), + ("Br", 0.0, 2.0, 0.0), + ("O", 0.0, 0.0, 2.4), + ("O", 0.0, 0.0, 4.9), + ], + ) + _write_xyz_structure( + sn_motif_dir / "candidate_mid.xyz", + [ + ("Sn", 0.0, 0.0, 0.0), + ("Br", 2.2, 0.0, 0.0), + ("Br", 0.0, 2.2, 0.0), + ("O", 0.0, 0.0, 2.6), + ("O", 0.0, 0.0, 5.2), + ], + ) + _write_xyz_structure( + sn_dir / "candidate_high.xyz", + [ + ("Sn", 0.0, 0.0, 0.0), + ("Br", 2.8, 0.0, 0.0), + ("Br", 0.0, 2.8, 0.0), + ("O", 0.0, 0.0, 1.8), + ], + ) + return root_dir, pb_dir, sn_dir + + +def _build_zinc_stoichiometry_folder(root_dir: Path) -> Path: + stoich_dir = root_dir / "ZnCl2" + motif_dir = stoich_dir / "motif_edge" + stoich_dir.mkdir(parents=True) + motif_dir.mkdir(parents=True) + _write_xyz_structure( + stoich_dir / "candidate_low.xyz", + [ + ("Zn", 0.0, 0.0, 0.0), + ("Cl", 2.0, 0.0, 0.0), + ("Cl", 0.0, 2.0, 0.0), + ], + ) + _write_xyz_structure( + motif_dir / "candidate_mid.xyz", + [ + ("Zn", 0.0, 0.0, 0.0), + ("Cl", 2.2, 0.0, 0.0), + ("Cl", 0.0, 2.2, 0.0), + ], + ) + return stoich_dir + + +@pytest.fixture(scope="module") +def qapp(): + os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") + app = QApplication.instance() + 
if app is None: + app = QApplication([]) + yield app + + +def test_representativefinder_workflow_selects_middle_candidate(tmp_path): + stoich_dir = _build_representative_test_folder(tmp_path) + settings = RepresentativeFinderSettings( + bond_pairs=(BondPairDefinition("Pb", "I", 3.2),), + angle_triplets=(AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2),), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + ) + + result = analyze_representative_structure_folder( + stoich_dir, + settings=settings, + output_dir=tmp_path / "representative_output", + ) + + assert result.selected_candidate.file_name == "candidate_mid.xyz" + assert result.representative_output_path.is_file() + assert result.summary_json_path.is_file() + assert result.score_table_path.is_file() + assert result.summary_text_path.is_file() + assert "candidate_mid.xyz" in result.summary_text() + + +def test_representativefinder_result_json_preserves_analysis_details(tmp_path): + stoich_dir = _build_representative_test_folder(tmp_path) + settings = RepresentativeFinderSettings( + bond_pairs=(BondPairDefinition("Pb", "I", 3.2),), + angle_triplets=(AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2),), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + ) + result = analyze_representative_structure_folder( + stoich_dir, + settings=settings, + output_dir=tmp_path / "representative_output", + ) + + loaded = load_representativefinder_result(result.summary_json_path) + bond_definition = loaded.settings.bond_pairs[0] + angle_definition = loaded.settings.angle_triplets[0] + + assert ( + loaded.selected_candidate.file_name + == result.selected_candidate.file_name + ) + assert loaded.settings.bond_pairs == settings.bond_pairs + assert loaded.settings.angle_triplets == settings.angle_triplets + assert loaded.target_bond_values[bond_definition].size > 0 + assert loaded.target_angle_values[angle_definition].size > 0 + assert loaded.selected_candidate.bond_values[bond_definition] + assert 
loaded.selected_candidate.angle_values[angle_definition] + assert loaded.candidates[0].score_total is not None + + +def test_representativefinder_single_atom_shortcuts_full_analysis(tmp_path): + stoich_dir = _build_single_atom_test_folder(tmp_path) + progress_events: list[tuple[int, int, str]] = [] + log_messages: list[str] = [] + + result = analyze_representative_structure_folder( + stoich_dir, + settings=RepresentativeFinderSettings(), + output_dir=tmp_path / "single_atom_output", + progress_callback=lambda processed, total, message: progress_events.append( + (processed, total, message) + ), + log_callback=log_messages.append, + ) + + assert result.selected_candidate.atom_count == 1 + assert result.selected_candidate.element_counts == {"I": 1} + assert result.selected_candidate.score_total == pytest.approx(0.0) + assert result.selected_candidate.score_bond == pytest.approx(0.0) + assert result.selected_candidate.score_angle == pytest.approx(0.0) + assert result.selected_candidate.score_solvent == pytest.approx(0.0) + assert any( + "Single-atom candidate structures were detected" in note + for note in result.selected_candidate.descriptor_notes + ) + assert result.representative_output_path.is_file() + assert result.summary_json_path.is_file() + assert result.score_table_path.is_file() + assert result.summary_text_path.is_file() + assert not any( + "Aggregating bond and angle distributions" in message + for _processed, _total, message in progress_events + ) + assert not any( + "Scoring " in message for _p, _t, message in progress_events + ) + assert any( + "single-atom candidate set" in message.lower() + for message in log_messages + ) + assert progress_events[-1][0] == progress_events[-1][1] + assert ( + progress_events[-1][2] + == "Representative-structure selection complete." 
+ ) + + +def test_representativefinder_workflow_reports_post_measurement_progress( + tmp_path, +): + stoich_dir = _build_representative_test_folder(tmp_path) + settings = RepresentativeFinderSettings( + bond_pairs=(BondPairDefinition("Pb", "I", 3.2),), + angle_triplets=(AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2),), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + ) + progress_events: list[tuple[int, int, str]] = [] + + analyze_representative_structure_folder( + stoich_dir, + settings=settings, + output_dir=tmp_path / "representative_output", + progress_callback=lambda processed, total, message: progress_events.append( + (processed, total, message) + ), + ) + + assert progress_events + assert any("Scoring " in message for _p, _t, message in progress_events) + assert any( + "Writing representative outputs" in message + for _p, _t, message in progress_events + ) + assert progress_events[-1][0] == progress_events[-1][1] + assert ( + progress_events[-1][2] + == "Representative-structure selection complete." 
+ ) + + +def test_representativefinder_workflow_generates_optional_predicted_output( + tmp_path, +): + stoich_dir = _build_representative_test_folder(tmp_path) + settings = RepresentativeFinderSettings( + bond_pairs=(BondPairDefinition("Pb", "I", 3.2),), + angle_triplets=(AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2),), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + generate_predicted_optimized_representative=True, + ) + + result = analyze_representative_structure_folder( + stoich_dir, + settings=settings, + output_dir=tmp_path / "representative_output", + ) + + assert result.selected_candidate.file_name == "candidate_mid.xyz" + assert result.predicted_candidate is not None + assert result.predicted_output_path is not None + assert result.predicted_output_path.is_file() + assert result.predicted_candidate.file_path == result.predicted_output_path + assert result.predicted_candidate.atom_count == 3 + assert result.solvent_completed_predicted_candidate is None + assert result.solvent_completed_predicted_output_path is None + assert any( + "Cluster Dynamics ML geometry scaffold" in note + for note in result.predicted_generation_notes + ) + assert "Predicted optimized representative" in result.summary_text() + + +def test_representativefinder_project_persistence_writes_shared_partialsolv_outputs( + tmp_path, +): + stoich_dir = _build_representative_test_folder(tmp_path) + sn_root, _pb_dir, sn_dir = _build_multi_stoichiometry_root( + tmp_path / "multi" + ) + del sn_root + settings = RepresentativeFinderSettings( + bond_pairs=( + BondPairDefinition("Pb", "I", 3.2), + BondPairDefinition("Sn", "Br", 3.2), + ), + angle_triplets=( + AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2), + AngleTripletDefinition("Sn", "Br", "Br", 3.2, 3.2), + ), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + ) + pb_result = analyze_representative_structure_folder( + stoich_dir, + settings=settings, + output_dir=tmp_path / "pb_output", + ) + sn_result = 
analyze_representative_structure_folder( + sn_dir, + settings=settings, + output_dir=tmp_path / "sn_output", + ) + + pb_shared_path = persist_representativefinder_result_to_project( + tmp_path, + pb_result, + ) + sn_shared_path = persist_representativefinder_result_to_project( + tmp_path, + sn_result, + ) + + rmcsetup_paths = ensure_rmcsetup_structure(tmp_path) + metadata = load_representative_selection_metadata( + rmcsetup_paths.representative_selection_path + ) + assert metadata is not None + assert metadata.selection_mode == "representative_finder" + assert pb_shared_path.is_file() + assert sn_shared_path.is_file() + assert pb_shared_path.parent == ( + rmcsetup_paths.representative_partial_solvent_dir / "PbI2" + ) + assert sn_shared_path.parent == ( + rmcsetup_paths.representative_partial_solvent_dir / "SnBr2" + ) + assert [entry.structure for entry in metadata.representative_entries] == [ + "PbI2", + "SnBr2", + ] + assert all( + entry.motif == "no_motif" for entry in metadata.representative_entries + ) + assert all( + entry.source_solvent_mode == "partialsolv" + for entry in metadata.representative_entries + ) + assert all( + Path(entry.source_file).is_file() + for entry in metadata.representative_entries + ) + assert all( + entry.project_cached_results_path + and Path(entry.project_cached_results_path).is_file() + for entry in metadata.representative_entries + ) + cached_pb_result = load_representativefinder_result( + metadata.representative_entries[0].project_cached_results_path + ) + assert cached_pb_result.structure_label == "PbI2" + assert cached_pb_result.target_bond_values + assert cached_pb_result.selected_candidate.bond_values + assert ( + pytest.approx( + sum( + entry.selected_weight + for entry in metadata.representative_entries + ), + rel=0.0, + abs=1.0e-9, + ) + == 1.0 + ) + + state = SAXSProjectManager().inspect_representative_structures(tmp_path) + assert state.representative_count == 2 + assert state.source_files_ready is True + assert 
"partialsolv" in state.available_modes + assert ( + state.partialsolv_dir + == rmcsetup_paths.representative_partial_solvent_dir + ) + + +def test_representativefinder_run_file_round_trip_uses_project_relative_paths( + tmp_path, +): + root_dir, _pb_dir, _sn_dir = _build_multi_stoichiometry_root(tmp_path) + output_dir = tmp_path / "representative_finder" / "batch_run" + settings = RepresentativeFinderSettings( + bond_pairs=( + BondPairDefinition("Pb", "I", 3.2), + BondPairDefinition("Sn", "Br", 3.2), + ), + angle_triplets=( + AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2), + AngleTripletDefinition("Sn", "Br", "Br", 3.2, 3.2), + ), + solvent_weight=0.0, + parallel_workers=2, + ) + config = build_representativefinder_run_config( + project_dir=tmp_path, + input_dir=root_dir, + output_dir=output_dir, + analysis_mode="all", + settings=settings, + overwrite_existing=True, + ) + run_file_path = default_representativefinder_run_file_path(tmp_path) + + save_representativefinder_run_config(run_file_path, config) + loaded = load_representativefinder_run_config(run_file_path) + + assert loaded.input_dir == "cluster_root" + assert loaded.output_dir == "representative_finder/batch_run" + assert loaded.analysis_mode == "all" + assert loaded.overwrite_existing is True + assert loaded.settings.parallel_workers == 2 + assert [pair.display_label for pair in loaded.settings.bond_pairs] == [ + "Pb-I", + "Sn-Br", + ] + + +def test_representativefinder_cli_run_file_publishes_project_registry( + tmp_path, +): + root_dir, _pb_dir, _sn_dir = _build_multi_stoichiometry_root(tmp_path) + settings = RepresentativeFinderSettings( + bond_pairs=( + BondPairDefinition("Pb", "I", 3.2), + BondPairDefinition("Sn", "Br", 3.2), + ), + angle_triplets=( + AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2), + AngleTripletDefinition("Sn", "Br", "Br", 3.2, 3.2), + ), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.0, + parallel_workers=2, + ) + config = build_representativefinder_run_config( + 
project_dir=tmp_path, + input_dir=root_dir, + output_dir=tmp_path / "representative_finder" / "cli_batch", + analysis_mode="all", + settings=settings, + overwrite_existing=True, + ) + run_file_path = default_representativefinder_run_file_path(tmp_path) + save_representativefinder_run_config(run_file_path, config) + + summary = run_representativefinder_run_config( + tmp_path, + load_representativefinder_run_config(run_file_path), + run_file_path=run_file_path, + ) + + assert summary.completed_count == 2 + assert summary.failed_count == 0 + assert len(summary.project_representative_paths) == 2 + assert all(path.is_file() for path in summary.project_representative_paths) + + rmcsetup_paths = ensure_rmcsetup_structure(tmp_path) + metadata = load_representative_selection_metadata( + rmcsetup_paths.representative_selection_path + ) + assert metadata is not None + assert metadata.selection_mode == "representative_finder" + assert [entry.structure for entry in metadata.representative_entries] == [ + "PbI2", + "SnBr2", + ] + assert all( + Path(entry.source_file).is_file() + for entry in metadata.representative_entries + ) + state = SAXSProjectManager().inspect_representative_structures(tmp_path) + assert state.representative_count == 2 + assert state.source_files_ready is True + assert "nosolv" in state.available_modes + + +def test_representativefinder_cli_command_uses_project_default_run_file( + tmp_path, + capsys, +): + root_dir, _pb_dir, _sn_dir = _build_multi_stoichiometry_root(tmp_path) + settings = RepresentativeFinderSettings( + bond_pairs=( + BondPairDefinition("Pb", "I", 3.2), + BondPairDefinition("Sn", "Br", 3.2), + ), + solvent_weight=0.0, + parallel_workers=2, + ) + config = build_representativefinder_run_config( + project_dir=tmp_path, + input_dir=root_dir, + output_dir=tmp_path / "representative_finder" / "cli_entrypoint", + analysis_mode="all", + settings=settings, + overwrite_existing=True, + ) + save_representativefinder_run_config( + 
default_representativefinder_run_file_path(tmp_path), + config, + ) + + exit_code = representativefinder_cli_main(["run", str(tmp_path)]) + + captured = capsys.readouterr() + assert exit_code == 0 + assert "Representative CLI run complete" in captured.out + assert "Completed: 2" in captured.out + assert "Failed: 0" in captured.out + + +def test_representativefinder_run_file_window_builds_beta_config( + qapp, + tmp_path, +): + del qapp + root_dir, _pb_dir, _sn_dir = _build_multi_stoichiometry_root(tmp_path) + window = RepresentativeFinderRunFileWindow( + initial_project_dir=tmp_path, + initial_input_path=root_dir, + ) + window.bond_pairs_edit.setPlainText("Pb:I:3.2\nSn:Br:3.2") + window.angle_triplets_edit.setPlainText("Pb:I:I:3.2:3.2\nSn:Br:Br:3.2:3.2") + window.output_dir_edit.setText( + str(tmp_path / "representative_finder" / "window_config") + ) + window.analysis_mode_combo.setCurrentIndex(0) + window.worker_spin.setValue(4) + + config = window._current_config(tmp_path) + + assert config.analysis_mode == "all" + assert config.input_dir == "cluster_root" + assert config.output_dir == "representative_finder/window_config" + assert config.settings.parallel_workers == 4 + assert [pair.display_label for pair in config.settings.bond_pairs] == [ + "Pb-I", + "Sn-Br", + ] + window.close() + + +def test_representativefinder_project_persistence_classifies_fullsolv_outputs( + tmp_path, +): + reference_library_dir, reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + representative_path = _write_complete_solvent_representative_pdb( + tmp_path, + reference_path=reference_path, + ) + result = _build_manual_representative_result( + representative_path, + output_dir=tmp_path / "representative_output", + structure_label="Pb", + atom_count=7, + element_counts={"Pb": 1, "O": 2, "H": 4}, + solvent_atom_count=6, + ) + + rmcsetup_paths = ensure_rmcsetup_structure(tmp_path) + solvent_metadata = SolventHandlingMetadata( + 
settings=SolventHandlingSettings.from_dict( + { + "reference_source": "custom", + "custom_reference_path": str(reference_path), + "reference_match_tolerance_a": 0.25, + } + ), + reference_path=str(reference_path), + reference_name="water_test", + reference_residue_name="HOH", + updated_at=datetime.now().isoformat(timespec="seconds"), + representative_selection_mode="representative_finder", + detected_distribution_status="unknown", + detected_distribution_note="", + aggregate_solute_element_counts={"Pb": 1}, + entries=[], + ) + save_solvent_handling_metadata( + rmcsetup_paths.solvent_handling_path, + solvent_metadata, + ) + + shared_path = persist_representativefinder_result_to_project( + tmp_path, + result, + ) + + metadata = load_representative_selection_metadata( + rmcsetup_paths.representative_selection_path + ) + assert metadata is not None + assert shared_path.parent == rmcsetup_paths.pdb_with_solvent_dir / "Pb" + assert metadata.representative_entries[0].source_solvent_mode == "fullsolv" + assert ( + Path(metadata.representative_entries[0].source_file).resolve() + == shared_path.resolve() + ) + + state = SAXSProjectManager().inspect_representative_structures(tmp_path) + assert state.representative_count == 1 + assert state.source_files_ready is True + assert "fullsolv" in state.available_modes + assert "partialsolv" not in state.available_modes + + assert available_representative_structure_modes(metadata, None) == [ + "full_solvent" + ] + assert ( + representative_structure_path_for_mode( + metadata.representative_entries[0], + None, + "full_solvent", + ) + == shared_path.resolve() + ) + + +def test_representativefinder_project_persistence_mirrors_single_atom_outputs_to_all_solvent_variants( + tmp_path, +): + single_atom_path = tmp_path / "I_candidate.xyz" + _write_xyz_structure( + single_atom_path, + [("I", 0.0, 0.0, 0.0)], + ) + result = _build_manual_representative_result( + single_atom_path, + output_dir=tmp_path / "single_atom_output", + 
structure_label="I", + atom_count=1, + element_counts={"I": 1}, + solvent_atom_count=0, + ) + + shared_path = persist_representativefinder_result_to_project( + tmp_path, + result, + ) + + rmcsetup_paths = ensure_rmcsetup_structure(tmp_path) + metadata = load_representative_selection_metadata( + rmcsetup_paths.representative_selection_path + ) + assert metadata is not None + assert len(metadata.representative_entries) == 1 + + entry = metadata.representative_entries[0] + nosolv_path = ( + rmcsetup_paths.pdb_no_solvent_dir / "I" / shared_path.name + ).resolve() + partialsolv_path = ( + rmcsetup_paths.representative_partial_solvent_dir + / "I" + / shared_path.name + ).resolve() + fullsolv_path = ( + rmcsetup_paths.pdb_with_solvent_dir / "I" / shared_path.name + ).resolve() + + assert shared_path.resolve() == nosolv_path + assert entry.source_solvent_mode == "nosolv" + assert nosolv_path.is_file() + assert partialsolv_path.is_file() + assert fullsolv_path.is_file() + + state = SAXSProjectManager().inspect_representative_structures(tmp_path) + assert state.representative_count == 1 + assert state.source_files_ready is True + assert set(state.available_modes) == {"nosolv", "partialsolv", "fullsolv"} + + assert available_representative_structure_modes(metadata, None) == [ + "no_solvent", + "partial_solvent", + "full_solvent", + ] + assert representative_structure_mode_is_ready(metadata, None) is True + assert ( + representative_structure_path_for_mode( + entry, + None, + "no_solvent", + ) + == nosolv_path + ) + assert ( + representative_structure_path_for_mode( + entry, + None, + "partial_solvent", + ) + == partialsolv_path + ) + assert ( + representative_structure_path_for_mode( + entry, + None, + "full_solvent", + ) + == fullsolv_path + ) + + +def test_representativefinder_project_persistence_replaces_stale_mode_artifacts( + tmp_path, +): + partial_source = tmp_path / "partial_candidate.xyz" + _write_xyz_structure( + partial_source, + [ + ("Pb", 0.0, 0.0, 0.0), + ("O", 
2.5, 0.0, 0.0), + ], + ) + partial_result = _build_manual_representative_result( + partial_source, + output_dir=tmp_path / "partial_output", + structure_label="Pb", + atom_count=2, + element_counts={"Pb": 1, "O": 1}, + solvent_atom_count=1, + ) + partial_shared_path = persist_representativefinder_result_to_project( + tmp_path, + partial_result, + ) + + rmcsetup_paths = ensure_rmcsetup_structure(tmp_path) + stale_no_solvent_path = ( + rmcsetup_paths.pdb_no_solvent_dir / "Pb" / "Pb__stale_nosolv.pdb" + ) + stale_full_solvent_path = ( + rmcsetup_paths.pdb_with_solvent_dir / "Pb" / "Pb__stale_fullsolv.pdb" + ) + stale_no_solvent_path.parent.mkdir(parents=True, exist_ok=True) + stale_full_solvent_path.parent.mkdir(parents=True, exist_ok=True) + stale_no_solvent_path.write_text("stale\n", encoding="utf-8") + stale_full_solvent_path.write_text("stale\n", encoding="utf-8") + + _reference_library_dir, reference_path = ( + _build_test_solvent_reference_library(tmp_path) + ) + solvent_metadata = SolventHandlingMetadata( + settings=SolventHandlingSettings.from_dict( + { + "reference_source": "custom", + "custom_reference_path": str(reference_path), + "reference_match_tolerance_a": 0.25, + } + ), + reference_path=str(reference_path), + reference_name="water_test", + reference_residue_name="HOH", + updated_at=datetime.now().isoformat(timespec="seconds"), + representative_selection_mode="representative_finder", + detected_distribution_status="partial_solvent", + detected_distribution_note="", + aggregate_solute_element_counts={"Pb": 1}, + entries=[ + SolventHandlingEntry( + structure="Pb", + motif="no_motif", + param="Pb", + source_file=str(partial_shared_path), + no_solvent_pdb=str(stale_no_solvent_path), + completed_pdb=str(stale_full_solvent_path), + atom_count_no_solvent=1, + atom_count_completed=4, + solvent_atoms_added=3, + solvent_molecules_added=1, + solvent_mode="partial_solvent", + completion_strategy="stale", + heuristic_note="stale", + ) + ], + ) + 
save_solvent_handling_metadata( + rmcsetup_paths.solvent_handling_path, + solvent_metadata, + ) + + fullsolv_source = _write_complete_solvent_representative_pdb( + tmp_path, + reference_path=reference_path, + ) + fullsolv_result = _build_manual_representative_result( + fullsolv_source, + output_dir=tmp_path / "fullsolv_output", + structure_label="Pb", + atom_count=7, + element_counts={"Pb": 1, "O": 2, "H": 4}, + solvent_atom_count=6, + ) + + shared_path = persist_representativefinder_result_to_project( + tmp_path, + fullsolv_result, + ) + + metadata = load_representative_selection_metadata( + rmcsetup_paths.representative_selection_path + ) + reloaded_solvent_metadata = load_solvent_handling_metadata( + rmcsetup_paths.solvent_handling_path + ) + + assert metadata is not None + assert shared_path.parent == rmcsetup_paths.pdb_with_solvent_dir / "Pb" + assert shared_path.is_file() + assert not partial_shared_path.exists() + assert not stale_no_solvent_path.exists() + assert not stale_full_solvent_path.exists() + assert metadata.representative_entries[0].source_solvent_mode == "fullsolv" + assert ( + Path(metadata.representative_entries[0].source_file).resolve() + == shared_path.resolve() + ) + assert reloaded_solvent_metadata is not None + assert reloaded_solvent_metadata.entries == [] + + +def test_representativefinder_window_restores_saved_project_representative_attributes( + qapp, + tmp_path, + monkeypatch, +): + del qapp + monkeypatch.setenv( + "SAXSHELL_BONDANALYSIS_PRESETS_PATH", + str(tmp_path / "bondanalysis_presets.json"), + ) + root_dir, pb_dir, sn_dir = _build_multi_stoichiometry_root(tmp_path) + settings = RepresentativeFinderSettings( + bond_pairs=( + BondPairDefinition("Pb", "I", 3.2), + BondPairDefinition("Sn", "Br", 3.2), + ), + angle_triplets=( + AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2), + AngleTripletDefinition("Sn", "Br", "Br", 3.2, 3.2), + ), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + ) + pb_result = 
analyze_representative_structure_folder( + pb_dir, + settings=settings, + output_dir=tmp_path / "pb_output", + ) + sn_result = analyze_representative_structure_folder( + sn_dir, + settings=settings, + output_dir=tmp_path / "sn_output", + ) + pb_shared_path = persist_representativefinder_result_to_project( + tmp_path, + pb_result, + ) + sn_shared_path = persist_representativefinder_result_to_project( + tmp_path, + sn_result, + ) + + window = RepresentativeStructureFinderMainWindow( + initial_project_dir=tmp_path, + initial_input_path=root_dir, + ) + + assert window.stoichiometry_table.item(0, 3).text() == "Complete" + assert window.stoichiometry_table.item(0, 4).text() == pb_shared_path.name + assert window.stoichiometry_table.item(0, 5).text() == ( + f"{float(pb_result.selected_candidate.score_total):.6f}" + ) + assert window.stoichiometry_table.item(0, 6).text() == "pb_output" + pb_open_button = window.stoichiometry_table.cellWidget(0, 7) + assert isinstance(pb_open_button, QPushButton) + assert pb_open_button.isEnabled() is True + + assert window.stoichiometry_table.item(1, 3).text() == "Complete" + assert window.stoichiometry_table.item(1, 4).text() == sn_shared_path.name + assert window.stoichiometry_table.item(1, 5).text() == ( + f"{float(sn_result.selected_candidate.score_total):.6f}" + ) + assert window.stoichiometry_table.item(1, 6).text() == "sn_output" + sn_open_button = window.stoichiometry_table.cellWidget(1, 7) + assert isinstance(sn_open_button, QPushButton) + assert sn_open_button.isEnabled() is True + + window._select_stoichiometry_row_by_key(str(sn_dir.resolve())) + summary = window.result_summary_box.toPlainText() + assert ( + window.run_status_label.text() + == "Representative selection: restored from saved project analysis" + ) + assert "Status: Complete" in summary + assert f"Representative: {sn_shared_path.name}" in summary + assert ( + f"Project representative file: {sn_shared_path.resolve()}" in summary + ) + assert "Source solvent mode: 
partialsolv" in summary + assert window.candidate_table.rowCount() == len(sn_result.candidates) + assert window.plot_widget.distribution_selector_combo.count() == 2 + labels = [ + window.plot_widget.distribution_selector_combo.itemText(index) + for index in range( + window.plot_widget.distribution_selector_combo.count() + ) + ] + assert any("Sn-Br" in label for label in labels) + assert window.bond_pair_table.rowCount() == 2 + assert window.angle_triplet_table.rowCount() == 2 + window.close() + + +def test_representativefinder_all_mode_skips_saved_project_representatives_until_overwrite( + qapp, + tmp_path, + monkeypatch, +): + del qapp + monkeypatch.setenv( + "SAXSHELL_BONDANALYSIS_PRESETS_PATH", + str(tmp_path / "bondanalysis_presets.json"), + ) + root_dir, pb_dir, sn_dir = _build_multi_stoichiometry_root(tmp_path) + zn_dir = _build_zinc_stoichiometry_folder(root_dir) + settings = RepresentativeFinderSettings( + bond_pairs=( + BondPairDefinition("Pb", "I", 3.2), + BondPairDefinition("Sn", "Br", 3.2), + ), + angle_triplets=( + AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2), + AngleTripletDefinition("Sn", "Br", "Br", 3.2, 3.2), + ), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + ) + pb_result = analyze_representative_structure_folder( + pb_dir, + settings=settings, + output_dir=tmp_path / "pb_output", + ) + sn_result = analyze_representative_structure_folder( + sn_dir, + settings=settings, + output_dir=tmp_path / "sn_output", + ) + persist_representativefinder_result_to_project(tmp_path, pb_result) + persist_representativefinder_result_to_project(tmp_path, sn_result) + + window = RepresentativeStructureFinderMainWindow( + initial_project_dir=tmp_path, + initial_input_path=root_dir, + ) + window.analysis_mode_combo.setCurrentIndex(1) + assert window.overwrite_existing_checkbox.isChecked() is False + + targets = window._analysis_targets_from_inputs( + output_root=tmp_path / "batch_output", + settings=RepresentativeFinderSettings(), + ) + assert 
[target.inspection.input_dir for target in targets] == [ + zn_dir.resolve() + ] + + window._reset_stoichiometry_run_state( + {str(target.inspection.input_dir) for target in targets} + ) + pb_row = window._stoichiometry_row_by_input_dir[str(pb_dir.resolve())] + sn_row = window._stoichiometry_row_by_input_dir[str(sn_dir.resolve())] + zn_row = window._stoichiometry_row_by_input_dir[str(zn_dir.resolve())] + assert window.stoichiometry_table.item(pb_row, 3).text() == "Complete" + assert window.stoichiometry_table.item(sn_row, 3).text() == "Complete" + assert window.stoichiometry_table.item(zn_row, 3).text() == "Queued" + + window.overwrite_existing_checkbox.setChecked(True) + overwrite_targets = window._analysis_targets_from_inputs( + output_root=tmp_path / "batch_output", + settings=RepresentativeFinderSettings(), + ) + assert [target.inspection.input_dir for target in overwrite_targets] == [ + pb_dir.resolve(), + sn_dir.resolve(), + zn_dir.resolve(), + ] + window.close() + + +def test_representativefinder_workflow_supports_cancellation(tmp_path): + stoich_dir = _build_representative_test_folder(tmp_path) + settings = RepresentativeFinderSettings( + bond_pairs=(BondPairDefinition("Pb", "I", 3.2),), + angle_triplets=(AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2),), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + ) + + with pytest.raises(RepresentativeFinderOperationCancelled): + analyze_representative_structure_folder( + stoich_dir, + settings=settings, + output_dir=tmp_path / "representative_output", + cancel_callback=lambda: True, + ) + + +def test_representativefinder_input_inspection_discovers_stoichiometry_subfolders( + tmp_path, +): + root_dir, pb_dir, sn_dir = _build_multi_stoichiometry_root(tmp_path) + + inspection = inspect_representative_structure_input(root_dir) + + assert inspection.input_dir == root_dir.resolve() + assert inspection.input_is_stoichiometry_folder is False + assert inspection.stoichiometry_count == 2 + assert 
inspection.total_candidate_count == 6 + assert [ + item.structure_label for item in inspection.stoichiometry_folders + ] == [ + pb_dir.name, + sn_dir.name, + ] + + +def test_representativefinder_window_builds_split_scrollable_layout( + qapp, + tmp_path, + monkeypatch, +): + del qapp + monkeypatch.setenv( + "SAXSHELL_BONDANALYSIS_PRESETS_PATH", + str(tmp_path / "bondanalysis_presets.json"), + ) + stoich_dir = _build_representative_test_folder(tmp_path) + window = RepresentativeStructureFinderMainWindow( + initial_project_dir=tmp_path, + initial_input_path=stoich_dir, + ) + + assert window.windowTitle() == "Representative Structures" + assert isinstance(window._pane_splitter, QSplitter) + assert isinstance(window._left_scroll, QScrollArea) + assert isinstance(window._right_scroll, QScrollArea) + assert isinstance(window._right_splitter, QSplitter) + assert window._right_splitter.count() == 5 + assert window.stoichiometry_table.columnCount() == 8 + assert isinstance(window.viewer_widget, ElectronDensityStructureViewer) + open_button = window.stoichiometry_table.cellWidget(0, 7) + assert isinstance(open_button, QPushButton) + assert open_button.isEnabled() is False + assert window.input_dir_edit.text() == str(stoich_dir.resolve()) + assert ( + "Discovered stoichiometries: 1" + in window.input_preview_box.toPlainText() + ) + assert window.stoichiometry_table.rowCount() == 1 + assert window.run_button.text() == "Analyze Selected Stoichiometry" + assert window.overwrite_existing_checkbox.isChecked() is False + assert "representativefinder_PbI2" in window.output_dir_edit.text() + assert window.plot_widget.distribution_selector_combo.count() == 0 + assert window.plot_widget.distribution_selector_combo.isEnabled() is False + assert window.solvent_shell_toggle_button.isChecked() is False + assert window.solvent_shell_body.isVisible() is False + + window.show() + app = QApplication.instance() + assert app is not None + app.processEvents() + right_sizes = 
window._right_splitter.sizes() + assert len(right_sizes) == 5 + assert right_sizes[0] >= right_sizes[1] + assert right_sizes[0] >= right_sizes[2] + + window.load_preset("DMF") + assert window.bond_pair_table.rowCount() == 7 + assert window.angle_triplet_table.rowCount() == 5 + window.close() + + +def test_representativefinder_window_switches_display_by_stoichiometry_row( + qapp, + tmp_path, + monkeypatch, +): + del qapp + monkeypatch.setenv( + "SAXSHELL_BONDANALYSIS_PRESETS_PATH", + str(tmp_path / "bondanalysis_presets.json"), + ) + root_dir, pb_dir, sn_dir = _build_multi_stoichiometry_root(tmp_path) + window = RepresentativeStructureFinderMainWindow( + initial_project_dir=tmp_path, + initial_input_path=root_dir, + ) + + assert window.stoichiometry_table.rowCount() == 2 + assert ( + "Discovered stoichiometries: 2" + in window.input_preview_box.toPlainText() + ) + assert ( + "representativefinder_batch_cluster_root" + in window.output_dir_edit.text() + ) + + settings = RepresentativeFinderSettings( + bond_pairs=( + BondPairDefinition("Pb", "I", 3.2), + BondPairDefinition("Sn", "Br", 3.2), + ), + angle_triplets=( + AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2), + AngleTripletDefinition("Sn", "Br", "Br", 3.2, 3.2), + ), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + ) + pb_result = analyze_representative_structure_folder( + pb_dir, + settings=settings, + output_dir=tmp_path / "pb_output", + ) + sn_result = analyze_representative_structure_folder( + sn_dir, + settings=settings, + output_dir=tmp_path / "sn_output", + ) + + window._on_target_result_ready(pb_result) + window._on_target_result_ready(sn_result) + + opened_paths: list[Path] = [] + monkeypatch.setattr( + window, + "_reveal_path_in_file_manager", + lambda path: opened_paths.append(path), + ) + + window._select_stoichiometry_row_by_key(str(sn_dir.resolve())) + assert "Stoichiometry: SnBr2" in window.result_summary_box.toPlainText() + assert window.candidate_table.rowCount() == 
len(sn_result.candidates) + assert window.plot_widget.distribution_selector_combo.count() == 2 + assert window.viewer_widget.current_structure is not None + assert window.viewer_widget.current_structure.display_label == ( + sn_result.selected_candidate.relative_label + ) + assert window.viewer_widget.current_mesh_geometry is not None + sn_labels = [ + window.plot_widget.distribution_selector_combo.itemText(index) + for index in range( + window.plot_widget.distribution_selector_combo.count() + ) + ] + assert any("Sn-Br" in label for label in sn_labels) + assert all("Pb-I" not in label for label in sn_labels) + assert window.stoichiometry_table.item(1, 3).text() == "Complete" + sn_open_button = window.stoichiometry_table.cellWidget(1, 7) + assert isinstance(sn_open_button, QPushButton) + assert sn_open_button.isEnabled() is True + sn_open_button.click() + assert opened_paths[-1] == window._project_representative_path_for_key( + str(sn_dir.resolve()) + ) + + window._select_stoichiometry_row_by_key(str(pb_dir.resolve())) + assert "Stoichiometry: PbI2" in window.result_summary_box.toPlainText() + assert window.candidate_table.rowCount() == len(pb_result.candidates) + assert window.plot_widget.distribution_selector_combo.count() == 2 + assert window.viewer_widget.current_structure is not None + assert window.viewer_widget.current_structure.display_label == ( + pb_result.selected_candidate.relative_label + ) + assert window.viewer_widget.current_mesh_geometry is not None + pb_labels = [ + window.plot_widget.distribution_selector_combo.itemText(index) + for index in range( + window.plot_widget.distribution_selector_combo.count() + ) + ] + assert any("Pb-I" in label for label in pb_labels) + assert all("Sn-Br" not in label for label in pb_labels) + assert window.stoichiometry_table.item(0, 3).text() == "Complete" + pb_open_button = window.stoichiometry_table.cellWidget(0, 7) + assert isinstance(pb_open_button, QPushButton) + assert pb_open_button.isEnabled() is True + 
pb_open_button.click() + assert opened_paths[-1] == window._project_representative_path_for_key( + str(pb_dir.resolve()) + ) + window.close() + + +def test_representativefinder_window_switches_between_observed_and_predicted_outputs( + qapp, + tmp_path, + monkeypatch, +): + del qapp + monkeypatch.setenv( + "SAXSHELL_BONDANALYSIS_PRESETS_PATH", + str(tmp_path / "bondanalysis_presets.json"), + ) + stoich_dir = _build_representative_test_folder(tmp_path) + settings = RepresentativeFinderSettings( + bond_pairs=(BondPairDefinition("Pb", "I", 3.2),), + angle_triplets=(AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2),), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + generate_predicted_optimized_representative=True, + ) + result = analyze_representative_structure_folder( + stoich_dir, + settings=settings, + output_dir=tmp_path / "representative_output", + ) + + window = RepresentativeStructureFinderMainWindow( + initial_project_dir=tmp_path, + initial_input_path=stoich_dir, + ) + window._on_target_result_ready(result) + + assert window.display_mode_combo.count() == 3 + display_modes = [ + window.display_mode_combo.itemData(index) + for index in range(window.display_mode_combo.count()) + ] + assert display_modes == [ + "selected_candidate", + "observed_representative", + "predicted_optimized_representative", + ] + + assert ( + window._active_representative_input_path() + == result.selected_candidate.file_path + ) + + predicted_index = display_modes.index("predicted_optimized_representative") + window.display_mode_combo.setCurrentIndex(predicted_index) + + assert "Displayed structure: Predicted Optimized Representative" in ( + window.result_summary_box.toPlainText() + ) + assert ( + window._active_representative_input_path() + == result.predicted_output_path + ) + assert window.viewer_widget.current_structure is not None + assert ( + window.viewer_widget.current_structure.display_label + == "Predicted Optimized Representative" + ) + + observed_index = 
display_modes.index("observed_representative") + window.display_mode_combo.setCurrentIndex(observed_index) + assert "Displayed structure: Observed Representative" in ( + window.result_summary_box.toPlainText() + ) + assert window._active_representative_input_path() == ( + window._project_representative_path_for_key(str(stoich_dir.resolve())) + ) + window.close() + + +def test_representativefinder_window_restores_project_session_results( + qapp, + tmp_path, + monkeypatch, +): + del qapp + monkeypatch.setenv( + "SAXSHELL_BONDANALYSIS_PRESETS_PATH", + str(tmp_path / "bondanalysis_presets.json"), + ) + root_dir, pb_dir, sn_dir = _build_multi_stoichiometry_root(tmp_path) + settings = RepresentativeFinderSettings( + bond_pairs=( + BondPairDefinition("Pb", "I", 3.2), + BondPairDefinition("Sn", "Br", 3.2), + ), + angle_triplets=( + AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2), + AngleTripletDefinition("Sn", "Br", "Br", 3.2, 3.2), + ), + bond_weight=1.0, + angle_weight=0.5, + solvent_weight=0.5, + ) + pb_result = analyze_representative_structure_folder( + pb_dir, + settings=settings, + output_dir=tmp_path / "pb_output", + ) + sn_result = analyze_representative_structure_folder( + sn_dir, + settings=settings, + output_dir=tmp_path / "sn_output", + ) + + window = RepresentativeStructureFinderMainWindow( + initial_project_dir=tmp_path, + initial_input_path=root_dir, + ) + window.analysis_mode_combo.setCurrentIndex(1) + window.output_dir_edit.setText(str(tmp_path / "restored_batch_output")) + window._on_target_result_ready(pb_result) + window._on_target_result_ready(sn_result) + window._select_stoichiometry_row_by_key(str(sn_dir.resolve())) + window.close() + + restored_window = RepresentativeStructureFinderMainWindow( + initial_project_dir=tmp_path, + initial_input_path=root_dir, + ) + + assert ( + restored_window.run_status_label.text() + == "Representative selection: restored from project session" + ) + assert restored_window._current_analysis_mode() == "all" + assert 
restored_window.output_dir_edit.text() == str( + tmp_path / "restored_batch_output" + ) + assert restored_window.stoichiometry_table.item(0, 3).text() == "Complete" + assert restored_window.stoichiometry_table.item(1, 3).text() == "Complete" + assert ( + "Stoichiometry: SnBr2" + in restored_window.result_summary_box.toPlainText() + ) + assert restored_window.candidate_table.rowCount() == len( + sn_result.candidates + ) + assert restored_window.viewer_widget.current_structure is not None + assert restored_window.viewer_widget.current_structure.display_label == ( + sn_result.selected_candidate.relative_label + ) + restored_window.close() + + +def test_representativefinder_window_close_cancels_active_analysis( + qapp, + tmp_path, + monkeypatch, +): + monkeypatch.setenv( + "SAXSHELL_BONDANALYSIS_PRESETS_PATH", + str(tmp_path / "bondanalysis_presets.json"), + ) + monkeypatch.setattr( + "saxshell.representativefinder.ui.main_window.load_presets", + lambda: { + "Test": BondAnalysisPreset( + name="Test", + bond_pairs=(BondPairDefinition("Pb", "I", 3.2),), + angle_triplets=( + AngleTripletDefinition("Pb", "I", "I", 3.2, 3.2), + ), + ) + }, + ) + + cancel_seen: list[bool] = [] + + def _fake_analyze( + _input_dir, + *, + settings, + output_dir, + project_dir=None, + progress_callback=None, + log_callback=None, + cancel_callback=None, + ): + del settings, output_dir, project_dir + if log_callback is not None: + log_callback("Starting fake analysis.") + start = time.monotonic() + while time.monotonic() - start < 1.0: + if cancel_callback is not None and cancel_callback(): + cancel_seen.append(True) + raise RepresentativeFinderOperationCancelled( + "Representative-structure analysis canceled." + ) + time.sleep(0.01) + if progress_callback is not None: + progress_callback(0, 1, "Waiting for cancellation...") + raise AssertionError( + "Representative finder close did not cancel the worker." 
+ ) + + monkeypatch.setattr( + "saxshell.representativefinder.ui.main_window.analyze_representative_structure_folder", + _fake_analyze, + ) + + stoich_dir = _build_representative_test_folder(tmp_path) + output_dir = tmp_path / "representative_output" + window = RepresentativeStructureFinderMainWindow( + initial_project_dir=tmp_path, + initial_input_path=stoich_dir, + ) + window.output_dir_edit.setText(str(output_dir)) + window.show() + qapp.processEvents() + + window._run_analysis() + + deadline = time.monotonic() + 2.0 + while ( + window._analysis_thread is None + or not window._analysis_thread.isRunning() + ) and time.monotonic() < deadline: + qapp.processEvents() + time.sleep(0.01) + + assert window._analysis_thread is not None + assert window._analysis_thread.isRunning() + + window.close() + + deadline = time.monotonic() + 2.0 + while window._analysis_thread is not None and time.monotonic() < deadline: + qapp.processEvents() + time.sleep(0.01) + + assert window._analysis_thread is None + assert window.isVisible() is False + assert cancel_seen in ([], [True]) From f20c96d7d5cad48a63a5145efe17795441c6d575 Mon Sep 17 00:00:00 2001 From: kewh5868 Date: Thu, 7 May 2026 16:48:00 -0600 Subject: [PATCH 4/7] feat(viewer): allow structure background color edits Add a structure-viewer background color control, preserve the selected background across redraws, and cover the behavior in the viewer tests. 
--- .../saxs/structure_viewer/ui/widget.py | 70 ++++++++++++++++--- tests/test_structure_viewer.py | 14 +++- 2 files changed, 74 insertions(+), 10 deletions(-) diff --git a/src/saxshell/saxs/structure_viewer/ui/widget.py b/src/saxshell/saxs/structure_viewer/ui/widget.py index 1543010..db7e982 100644 --- a/src/saxshell/saxs/structure_viewer/ui/widget.py +++ b/src/saxshell/saxs/structure_viewer/ui/widget.py @@ -58,6 +58,13 @@ def _clamp_fraction(value: float) -> float: return float(max(0.0, min(1.0, float(value)))) +def _button_foreground_color(color_value: str) -> str: + color = QColor(str(color_value)) + if not color.isValid(): + return "#f8fafc" + return "#111827" if color.lightnessF() >= 0.62 else "#f8fafc" + + def _rotation_matrix_from_axis_angle( axis: np.ndarray, angle_radians: float, @@ -146,6 +153,7 @@ def __init__(self, parent: QWidget | None = None) -> None: self._mesh_contrast = _DEFAULT_MESH_CONTRAST self._mesh_linewidth = _DEFAULT_MESH_LINEWIDTH self._mesh_color = _DEFAULT_MESH_COLOR + self._background_color = _BACKGROUND_COLOR self._atom_render_mode = "points" self._pending_view_update = False self._pending_axis_reference_refresh = False @@ -247,6 +255,12 @@ def _build_ui(self) -> None: self.mesh_color_button.clicked.connect(self._choose_mesh_color) contrast_row.addWidget(self.mesh_color_button) + self.background_color_button = QPushButton("Background Color") + self.background_color_button.clicked.connect( + self._choose_background_color + ) + contrast_row.addWidget(self.background_color_button) + self.point_atoms_checkbox = QCheckBox("Point Atoms") self.point_atoms_checkbox.setChecked(True) self.point_atoms_checkbox.toggled.connect( @@ -256,6 +270,7 @@ def _build_ui(self) -> None: contrast_row.addStretch(1) layout.addLayout(contrast_row) self._update_mesh_color_button_style() + self._update_background_color_button_style() self.canvas.setMinimumHeight(520) layout.addWidget(self.canvas, stretch=1) @@ -293,7 +308,9 @@ def draw_placeholder(self) -> None: 
self._clear_pending_view_update() self._active_settings_artist = None self.figure.clear() + self.figure.set_facecolor(self._background_color) axis = self.figure.add_subplot(111) + axis.set_facecolor(self._background_color) axis.text( 0.5, 0.57, @@ -453,8 +470,8 @@ def _draw_view( axis_reference.set_in_layout(False) self._axis = axis self._axis_reference = axis_reference - self.figure.set_facecolor(_BACKGROUND_COLOR) - axis.set_facecolor(_BACKGROUND_COLOR) + self.figure.set_facecolor(self._background_color) + axis.set_facecolor(self._background_color) axis.grid(False) axis.set_xticks([]) axis.set_yticks([]) @@ -781,7 +798,7 @@ def _draw_legend( fontsize=8.2, ) frame = legend.get_frame() - frame.set_facecolor(_BACKGROUND_COLOR) + frame.set_facecolor(self._background_color) frame.set_edgecolor("#c6ccd4") frame.set_linewidth(0.8) @@ -918,6 +935,7 @@ def _active_settings_readout_text(self) -> str: f"MESH {self._mesh_contrast * 100.0:05.1f}%", f"LINE {self._mesh_linewidth:04.2f}px", f"COLOR {self._mesh_color.upper()}", + f"BG {self._background_color.upper()}", ) ) @@ -937,17 +955,36 @@ def _refresh_active_settings_readout(self) -> None: ) def _update_mesh_color_button_style(self) -> None: - self.mesh_color_button.setStyleSheet( + self._update_color_button_style( + self.mesh_color_button, + label="Mesh Color", + color_value=self._mesh_color, + ) + + def _update_background_color_button_style(self) -> None: + self._update_color_button_style( + self.background_color_button, + label="Background Color", + color_value=self._background_color, + ) + + def _update_color_button_style( + self, + button: QPushButton, + *, + label: str, + color_value: str, + ) -> None: + text_color = _button_foreground_color(color_value) + button.setStyleSheet( "QPushButton {" - f"background-color: {self._mesh_color};" - "color: white;" + f"background-color: {color_value};" + f"color: {text_color};" "border: 1px solid #475569;" "padding: 4px 8px;" "}" ) - self.mesh_color_button.setText( - f"Mesh Color 
{self._mesh_color.upper()}" - ) + button.setText(f"{label} {color_value.upper()}") def _choose_mesh_color(self) -> None: selected = QColorDialog.getColor( @@ -962,6 +999,21 @@ def _choose_mesh_color(self) -> None: if self.current_structure is not None: self._draw_view(reset_view=False) + def _choose_background_color(self) -> None: + selected = QColorDialog.getColor( + QColor(self._background_color), + self, + "Select Background Color", + ) + if not selected.isValid(): + return + self._background_color = str(selected.name()) + self._update_background_color_button_style() + if self.current_structure is not None: + self._draw_view(reset_view=False) + return + self.draw_placeholder() + def _schedule_view_update( self, *, diff --git a/tests/test_structure_viewer.py b/tests/test_structure_viewer.py index 954006f..f991f9e 100644 --- a/tests/test_structure_viewer.py +++ b/tests/test_structure_viewer.py @@ -54,6 +54,10 @@ def test_structure_viewer_loads_single_structure(qapp, tmp_path): == structure_path.resolve() ) assert window.structure_viewer.current_mesh_geometry is not None + assert ( + window.structure_viewer.background_color_button.text() + == "Background Color #FDFBF8" + ) assert "shells=" in window.active_mesh_value.text() window.close() @@ -86,11 +90,13 @@ def test_structure_viewer_center_updates_preserve_viewer_display( viewer.mesh_linewidth_spin.interpretText() viewer.point_atoms_checkbox.setChecked(True) + selected_colors = iter((QColor("#ff6600"), QColor("#112233"))) monkeypatch.setattr( "saxshell.saxs.structure_viewer.ui.widget.QColorDialog.getColor", - lambda *args, **kwargs: QColor("#ff6600"), + lambda *args, **kwargs: next(selected_colors), ) viewer.mesh_color_button.click() + viewer.background_color_button.click() viewer._view_radius = 7.5 viewer._view_center = np.asarray([0.2, -0.4, 0.6], dtype=float) @@ -113,6 +119,7 @@ def test_structure_viewer_center_updates_preserve_viewer_display( assert viewer._mesh_contrast == pytest.approx(0.45) assert 
viewer._mesh_linewidth == pytest.approx(2.7) assert viewer._mesh_color == "#ff6600" + assert viewer._background_color == "#112233" assert viewer._atom_render_mode == "points" assert viewer._view_radius == pytest.approx(7.5) assert np.allclose(viewer._view_center, [0.2, -0.4, 0.6]) @@ -141,6 +148,11 @@ def test_structure_viewer_center_updates_preserve_viewer_display( assert "MESH 045.0%" in overlay_text assert "LINE 2.70px" in overlay_text assert "#FF6600" in overlay_text + assert "BG #112233" in overlay_text + assert np.allclose( + viewer.figure.get_facecolor()[:3], + QColor("#112233").getRgbF()[:3], + ) window.close() From b00a55232719a7c48b71f666418b7cef5462fadc Mon Sep 17 00:00:00 2001 From: kewh5868 Date: Thu, 7 May 2026 16:49:28 -0600 Subject: [PATCH 5/7] feat(saxs): add representative and FFT Born workflows Add representative-structure component sources, template metadata capabilities, persisted plot editor state, and 1D/3D Born component workflow integration. Introduce the 3D FFT Born backend/UI and expand project, prefit, electron-density, and SAXS UI tests. 
--- .../saxs/_model_templates/__init__.py | 158 + ...ream_monosq_normalized_scaled_solvent.json | 18 + ...ydream_monosq_normalized_scaled_solvent.py | 139 + src/saxshell/saxs/born_refinement/__init__.py | 21 + src/saxshell/saxs/born_refinement/backend.py | 389 ++ src/saxshell/saxs/contrast/__init__.py | 4 + src/saxshell/saxs/contrast/settings.py | 21 +- src/saxshell/saxs/contrast/ui/main_window.py | 3 +- src/saxshell/saxs/contrast_fft/__init__.py | 21 + src/saxshell/saxs/contrast_fft/backend.py | 650 +++ src/saxshell/saxs/contrast_fft/ui/__init__.py | 9 + .../saxs/contrast_fft/ui/main_window.py | 4788 +++++++++++++++++ .../electron_density_mapping/ui/__init__.py | 26 +- .../ui/main_window.py | 1438 ++++- .../electron_density_mapping/ui/viewer.py | 3 +- .../saxs/electron_density_mapping/workflow.py | 55 +- src/saxshell/saxs/prefit/cluster_geometry.py | 7 +- src/saxshell/saxs/prefit/workflow.py | 324 +- src/saxshell/saxs/project_manager/__init__.py | 10 + .../saxs/project_manager/prior_plot.py | 161 +- src/saxshell/saxs/project_manager/project.py | 1013 +++- src/saxshell/saxs/ui/__init__.py | 33 +- src/saxshell/saxs/ui/_pane_snap.py | 20 +- src/saxshell/saxs/ui/main_window.py | 1063 +++- src/saxshell/saxs/ui/prefit_tab.py | 517 +- src/saxshell/saxs/ui/project_setup_tab.py | 1358 ++++- .../saxs/ui/solution_scattering_widget.py | 14 +- src/saxshell/saxshell.py | 22 + tests/test_contrast_fft_backend.py | 44 + tests/test_electron_density_mapping.py | 485 +- tests/test_saxs_prefit.py | 442 +- tests/test_saxs_ui.py | 3490 +++++++++++- 32 files changed, 15836 insertions(+), 910 deletions(-) create mode 100644 src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent.json create mode 100644 src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent.py create mode 100644 src/saxshell/saxs/born_refinement/__init__.py create mode 100644 src/saxshell/saxs/born_refinement/backend.py create mode 100644 
src/saxshell/saxs/contrast_fft/__init__.py create mode 100644 src/saxshell/saxs/contrast_fft/backend.py create mode 100644 src/saxshell/saxs/contrast_fft/ui/__init__.py create mode 100644 src/saxshell/saxs/contrast_fft/ui/main_window.py create mode 100644 tests/test_contrast_fft_backend.py diff --git a/src/saxshell/saxs/_model_templates/__init__.py b/src/saxshell/saxs/_model_templates/__init__.py index bc9d9af..6c1f999 100644 --- a/src/saxshell/saxs/_model_templates/__init__.py +++ b/src/saxshell/saxs/_model_templates/__init__.py @@ -46,6 +46,20 @@ def runtime_input_names(self) -> tuple[str, ...]: return tuple(binding.runtime_name for binding in self.runtime_bindings) +@dataclass(slots=True, frozen=True) +class TemplateSolutionScatteringSupport: + volume_fraction_parameter: str | None = None + volume_fraction_kind: str = "solute" + volume_fraction_source: str = "saxs_effective" + solvent_contribution_scale_mode: str = "unscaled" + + +@dataclass(slots=True, frozen=True) +class TemplatePrefitSupport: + auto_apply_autoscale_on_load: bool = False + autoscale_bounds_mode: str = "preserve_existing" + + @dataclass(slots=True, frozen=True) class TemplateSpec: name: str @@ -61,6 +75,8 @@ class TemplateSpec: param_columns: tuple[str, ...] parameters: tuple[TemplateParameter, ...] 
cluster_geometry_support: TemplateClusterGeometrySupport + solution_scattering_support: TemplateSolutionScatteringSupport + prefit_support: TemplatePrefitSupport @property def label(self) -> str: @@ -184,6 +200,8 @@ def load_template_spec( param_columns=_split_csv(directives["param_columns"]), parameters=tuple(_parse_param_lines(module_path)), cluster_geometry_support=metadata["cluster_geometry_support"], + solution_scattering_support=metadata["solution_scattering_support"], + prefit_support=metadata["prefit_support"], ) _validate_template_runtime_contract(spec) return spec @@ -311,6 +329,8 @@ def _load_template_metadata( "cluster_geometry_support": TemplateClusterGeometrySupport( supported=False ), + "solution_scattering_support": TemplateSolutionScatteringSupport(), + "prefit_support": TemplatePrefitSupport(), } payload = json.loads(metadata_path.read_text(encoding="utf-8")) display_name = str(payload.get("display_name", "")).strip() @@ -330,10 +350,20 @@ def _load_template_metadata( metadata_path=metadata_path, directives=directives, ) + solution_scattering_support = _parse_solution_scattering_support( + payload, + metadata_path=metadata_path, + ) + prefit_support = _parse_prefit_support( + payload, + metadata_path=metadata_path, + ) return { "display_name": display_name, "description": description, "cluster_geometry_support": cluster_geometry_support, + "solution_scattering_support": solution_scattering_support, + "prefit_support": prefit_support, } @@ -532,6 +562,132 @@ def _parse_cluster_geometry_support( ) +def _parse_solution_scattering_support( + payload: dict[str, object], + *, + metadata_path: Path, +) -> TemplateSolutionScatteringSupport: + capabilities = payload.get("capabilities", {}) + if not isinstance(capabilities, dict): + capabilities = {} + raw_support = capabilities.get("solution_scattering_estimator", {}) + if raw_support is None: + raw_support = {} + if not isinstance(raw_support, dict): + raise ValueError( + f"Template metadata file 
{metadata_path.name} defines " + "'capabilities.solution_scattering_estimator' with an invalid " + "schema." + ) + + raw_target = raw_support.get("volume_fraction_target") + parameter_name: str | None = None + fraction_kind = "solute" + fraction_source = "saxs_effective" + if raw_target is not None: + if not isinstance(raw_target, dict): + raise ValueError( + f"Template metadata file {metadata_path.name} defines " + "'volume_fraction_target' with an invalid schema." + ) + parameter_name = str(raw_target.get("parameter", "")).strip() or None + if parameter_name is None: + raise ValueError( + f"Template metadata file {metadata_path.name} declares a " + "volume_fraction_target without a parameter name." + ) + fraction_kind = ( + str(raw_target.get("fraction_kind", "solute")).strip().lower() + or "solute" + ) + if fraction_kind not in {"solute", "solvent"}: + raise ValueError( + f"Template metadata file {metadata_path.name} declares " + f"unsupported fraction_kind {fraction_kind!r}." + ) + fraction_source = ( + str(raw_target.get("source", "saxs_effective")).strip().lower() + or "saxs_effective" + ) + if fraction_source not in {"saxs_effective", "physical"}: + raise ValueError( + f"Template metadata file {metadata_path.name} declares " + f"unsupported volume fraction source {fraction_source!r}." + ) + + scale_mode = ( + str( + raw_support.get( + "solvent_contribution_scale_mode", + "unscaled", + ) + ) + .strip() + .lower() + or "unscaled" + ) + if scale_mode not in {"unscaled", "global_scale"}: + raise ValueError( + f"Template metadata file {metadata_path.name} declares unsupported " + f"solvent_contribution_scale_mode {scale_mode!r}." 
+ ) + + return TemplateSolutionScatteringSupport( + volume_fraction_parameter=parameter_name, + volume_fraction_kind=fraction_kind, + volume_fraction_source=fraction_source, + solvent_contribution_scale_mode=scale_mode, + ) + + +def _parse_prefit_support( + payload: dict[str, object], + *, + metadata_path: Path, +) -> TemplatePrefitSupport: + capabilities = payload.get("capabilities", {}) + if not isinstance(capabilities, dict): + capabilities = {} + raw_support = capabilities.get("prefit", {}) + if raw_support is None: + raw_support = {} + if not isinstance(raw_support, dict): + raise ValueError( + f"Template metadata file {metadata_path.name} defines " + "'capabilities.prefit' with an invalid schema." + ) + + auto_apply_payload = raw_support.get( + "auto_apply_autoscale_on_load", + False, + ) + if isinstance(auto_apply_payload, bool): + auto_apply = auto_apply_payload + else: + auto_apply = _parse_bool_directive( + str(auto_apply_payload), + field_name="capabilities.prefit.auto_apply_autoscale_on_load", + source_name=f"Template metadata file {metadata_path.name}", + ) + + bounds_mode = ( + str(raw_support.get("autoscale_bounds_mode", "preserve_existing")) + .strip() + .lower() + or "preserve_existing" + ) + if bounds_mode not in {"preserve_existing", "adaptive"}: + raise ValueError( + f"Template metadata file {metadata_path.name} declares unsupported " + f"autoscale_bounds_mode {bounds_mode!r}." 
+ ) + + return TemplatePrefitSupport( + auto_apply_autoscale_on_load=auto_apply, + autoscale_bounds_mode=bounds_mode, + ) + + def _parse_bool_directive( value: str, *, @@ -552,7 +708,9 @@ def _parse_bool_directive( __all__ = [ "TemplateClusterGeometrySupport", "TemplateParameter", + "TemplatePrefitSupport", "TemplateRuntimeMetadataBinding", + "TemplateSolutionScatteringSupport", "TemplateSpec", "default_template_dir", "list_template_specs", diff --git a/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent.json b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent.json new file mode 100644 index 0000000..ba21d0b --- /dev/null +++ b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent.json @@ -0,0 +1,18 @@ +{ + "display_name": "pyDREAM MonoSQ Normalized (Scaled Solvent Weight)", + "description": "pyDREAM MonoSQ Normalized (Scaled Solvent Weight)\n\nPurpose:\nScale-coupled-solvent MonoSQ variant for projects where the experimental solvent trace should be placed under the same global intensity scale as the MD-derived solute model. The original pyDREAM MonoSQ Normalized template remains available unchanged for projects that rely on the historical unscaled-solvent convention.\n\nStructure Factor:\nMonodisperse hard-sphere structure factor evaluated with calc_monodisperse_sq using the Percus-Yevick approximation. The effective particle size is controlled by eff_r and the hard-sphere packing term is controlled by vol_frac. The default eff_r is 3 A so new projects start at the lower bound instead of assuming a larger interaction radius before the user or fit has evidence for it.\n\nForm Factor:\nWeighted mixture of MD-derived SAXS component profiles assembled from the project's averaged cluster scattering curves. The weighted component mixture is multiplied by the hard-sphere structure factor. 
The solvent trace is multiplied by solv_w, added to the solute branch, and then the combined solute + solvent expression is multiplied by the global scale before adding offset.\n\nModel Equation:\nI_model(q) = scale * (sum_i w_i I_i(q) S_HS(q; eff_r, vol_frac) + solv_w * I_solv(q)) + offset\n\nCalculator Integration:\nThe solution-scattering estimator can pre-populate vol_frac from the physical solute-associated volume fraction, because vol_frac is the hard-sphere packing term rather than the solvent background multiplier. The estimator can also pre-populate solv_w with the combined solvent background multiplier. In this template solv_w is interpreted before the global scale is applied, so it can remain representative of a solvent-background fraction instead of compensating for the arbitrary MD model intensity scale.\n\nPrefit Startup:\nWhen this template is loaded with experimental data available, Prefit automatically applies its autoscale estimate to the scale and offset parameters unless a saved Best Prefit or current Prefit state already exists. The scale and offset limits are then centered around the autoscale result instead of preserving broad template-default bounds.\n\nLikelihood Convention:\nThe pyDREAM likelihood is normalized by the number of experimental points. 
This keeps the log-likelihood on a per-point basis so that its magnitude is less sensitive to changes in fitted q-range, resampling density, or dataset length.\n\nModel Parameters:\nw1, w2, ..., wN: Relative contribution of each MD-derived component profile.\nsolv_w: Solvent-background multiplier applied to the experimental solvent trace before the global scale is applied; constrained to [0, 1] and fixed by default for calculator-seeded workflows.\noffset: Additive baseline term used to capture residual background intensity.\neff_r: Effective hard-sphere radius used in the structure-factor calculation; defaults to 3 A.\nvol_frac: Hard-sphere packing fraction used in the Percus-Yevick structure factor; the solution-scattering estimator targets this field with the physical solute-associated volume fraction.\nscale: Multiplicative scale factor applied to the combined solute + weighted solvent contribution.", + "capabilities": { + "solution_scattering_estimator": { + "volume_fraction_target": { + "parameter": "vol_frac", + "fraction_kind": "solute", + "source": "physical" + }, + "solvent_contribution_scale_mode": "global_scale" + }, + "prefit": { + "auto_apply_autoscale_on_load": true, + "autoscale_bounds_mode": "adaptive" + } + } +} diff --git a/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent.py b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent.py new file mode 100644 index 0000000..3387393 --- /dev/null +++ b/src/saxshell/saxs/_model_templates/template_pydream_monosq_normalized_scaled_solvent.py @@ -0,0 +1,139 @@ +import numpy as np +from scipy.stats import norm + +# ============================================== +# model_lmfit: lmfit_model_profile +# model_pydream: log_likelihood_monosq_scaled_solvent +# inputs_lmfit: q, solvent_data, model_data, params +# inputs_pydream: q, solvent_data, model_data, params +# param_columns: Structure, Motif, Param, Value, Vary, Min, Max +# +# param: 
solv_w,1.0,False,0.0,1.0 +# param: offset,0,True,-20,30 +# param: eff_r,3.0,True,3,20 +# param: vol_frac,0.0,False,0.0,0.5 +# param: scale,5e-4,False,1e-8,5e-3 +# +# MonoSQ normalized variant with scale-coupled solvent contribution: +# I(q) = scale * (I_solute(q) + solv_w * I_solvent(q)) + offset +# +# The original template_pydream_monosq_normalized.py is intentionally left +# unchanged for existing projects that rely on the unscaled solvent branch. +# ============================================== + + +def calc_monodisperse_sq(r, vol_frac, q_values): + """Return the hard-sphere Percus-Yevick structure factor.""" + sqs = [] + + alpha = (1 + 2 * vol_frac) ** 2 / (1 - vol_frac) ** 4 + beta = -6 * vol_frac * (1 + vol_frac / 2) ** 2 / (1 - vol_frac) ** 4 + gamma = 0.5 * vol_frac * (1 + 2 * vol_frac) ** 2 / (1 - vol_frac) ** 4 + + for q in q_values: + a = 2 * q * r + g1 = alpha / a**2 * (np.sin(a) - a * np.cos(a)) + g2 = beta / a**3 * (2 * a * np.sin(a) + (2 - a**2) * np.cos(a) - 2) + g3 = ( + gamma + / a**5 + * ( + -(a**4) * np.cos(a) + + 4 + * ((3 * a**2 - 6) * np.cos(a) + (a**3 - 6 * a) * np.sin(a) + 6) + ) + ) + g = g1 + g2 + g3 + sqs.append(1 / (1 + 24 * vol_frac * (g / a))) + + return np.asarray(sqs) + + +def _bounded_solvent_weight(value): + return float(np.clip(float(value), 0.0, 1.0)) + + +def structure_factor_profile(q, solvent_data, model_data, **params): + """Return the pure hard-sphere structure-factor trace S(q).""" + del solvent_data, model_data + return calc_monodisperse_sq( + params["eff_r"], + params["vol_frac"], + np.asarray(q, dtype=float), + ) + + +def lmfit_model_profile(q, solvent_data, model_data, **params): + """Evaluate the scale-coupled-solvent monodisperse SAXS model.""" + weight_keys = sorted( + (key for key in params if key.startswith("w")), + key=lambda key: int(key.lstrip("w")), + ) + weights = [params[key] for key in weight_keys] + + solv_w = _bounded_solvent_weight(params["solv_w"]) + offset = params["offset"] + eff_r = params["eff_r"] 
+ vol_frac = params["vol_frac"] + scale = params["scale"] + + mixture = np.zeros_like(q) + for weight, component in zip(weights, model_data): + mixture += weight * component + + solute_intensity = mixture * calc_monodisperse_sq(eff_r, vol_frac, q) + solvent_contribution = solv_w * np.asarray(solvent_data, dtype=float) + + return scale * (solute_intensity + solvent_contribution) + offset + + +def model_monosq_scaled_solvent(params): + """Return the forward model intensity for pyDREAM.""" + global q_values + global theoretical_intensities + global solvent_intensities + + n_profiles = len(theoretical_intensities) + + weights = params[:n_profiles] + solv_w = _bounded_solvent_weight(params[n_profiles]) + offset = params[n_profiles + 1] + eff_r = params[n_profiles + 2] + vol_frac = params[n_profiles + 3] + scale = params[n_profiles + 4] + + mixture = np.zeros_like(q_values) + for weight, component in zip(weights, theoretical_intensities): + mixture += weight * component + + solute_intensity = mixture * calc_monodisperse_sq( + eff_r, + vol_frac, + q_values, + ) + solvent_contribution = solv_w * np.asarray( + solvent_intensities, dtype=float + ) + + return scale * (solute_intensity + solvent_contribution) + offset + + +def log_likelihood_monosq_scaled_solvent(params): + """Return the normalized Gaussian log-likelihood for pyDREAM.""" + global experimental_intensities + + model_intensity = model_monosq_scaled_solvent(params) + n_points = len(experimental_intensities) + + log_likelihood = np.sum( + norm.logpdf( + experimental_intensities, + loc=model_intensity, + scale=1e-4, + ) + ) + + if n_points == 0: + return log_likelihood + + return log_likelihood / n_points diff --git a/src/saxshell/saxs/born_refinement/__init__.py b/src/saxshell/saxs/born_refinement/__init__.py new file mode 100644 index 0000000..5a961e4 --- /dev/null +++ b/src/saxshell/saxs/born_refinement/__init__.py @@ -0,0 +1,21 @@ +from .backend import ( + GridBornResult, + GridBornSettings, + 
build_shared_q_grid, + compute_constant_weight_debye_intensity, + compute_directional_born_intensity, + compute_fft_grid_born_intensity, + compute_spherical_average_point_born_intensity, + fibonacci_sphere_directions, +) + +__all__ = [ + "GridBornResult", + "GridBornSettings", + "build_shared_q_grid", + "compute_constant_weight_debye_intensity", + "compute_directional_born_intensity", + "compute_fft_grid_born_intensity", + "compute_spherical_average_point_born_intensity", + "fibonacci_sphere_directions", +] diff --git a/src/saxshell/saxs/born_refinement/backend.py b/src/saxshell/saxs/born_refinement/backend.py new file mode 100644 index 0000000..59f21dc --- /dev/null +++ b/src/saxshell/saxs/born_refinement/backend.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +from scipy.spatial.distance import cdist, pdist + + +def build_shared_q_grid( + q_min: float, + q_max: float, + *, + q_step: float = 0.01, +) -> np.ndarray: + if float(q_step) <= 0.0: + raise ValueError("q_step must be greater than zero.") + if float(q_max) < float(q_min): + raise ValueError("q_max must be greater than or equal to q_min.") + step = float(q_step) + count = int(np.floor((float(q_max) - float(q_min)) / step + 0.5)) + 1 + q_values = float(q_min) + step * np.arange(count, dtype=float) + upper = float(q_max) + q_values = q_values[q_values <= upper + 1.0e-12] + if q_values.size == 0: + raise ValueError("The requested q-grid did not contain any samples.") + endpoint_tolerance = max( + 1.0e-12, + 1.0e-9 * max(abs(float(q_min)), abs(upper), 1.0), + ) + if abs(float(q_values[-1]) - upper) <= endpoint_tolerance: + q_values[-1] = upper + elif float(q_values[-1]) < upper: + q_values = np.append(q_values, upper) + return np.asarray(q_values, dtype=float) + + +def fibonacci_sphere_directions(direction_count: int) -> np.ndarray: + count = int(direction_count) + if count < 1: + raise ValueError("direction_count must be at least 1.") + 
indices = np.arange(count, dtype=float) + golden_angle = np.pi * (3.0 - np.sqrt(5.0)) + y_values = 1.0 - (2.0 * indices + 1.0) / float(count) + radial_values = np.sqrt(np.clip(1.0 - y_values * y_values, 0.0, None)) + theta_values = golden_angle * indices + x_values = np.cos(theta_values) * radial_values + z_values = np.sin(theta_values) * radial_values + return np.asarray( + np.column_stack((x_values, y_values, z_values)), + dtype=float, + ) + + +def compute_constant_weight_debye_intensity( + coordinates: np.ndarray, + weights: np.ndarray, + q_values: np.ndarray, + *, + atom_block_size: int = 128, + q_chunk_size: int = 24, +) -> np.ndarray: + q_grid = np.asarray(q_values, dtype=float) + coords = np.asarray(coordinates, dtype=float) + atom_weights = np.asarray(weights, dtype=float) + if coords.ndim != 2 or coords.shape[1] != 3: + raise ValueError("coordinates must be an Nx3 array.") + if atom_weights.ndim != 1 or atom_weights.shape[0] != coords.shape[0]: + raise ValueError( + "weights must be a one-dimensional array matching coordinates." 
+ ) + if q_grid.ndim != 1: + raise ValueError("q_values must be one-dimensional.") + if coords.size == 0: + return np.zeros_like(q_grid, dtype=float) + + intensity = np.sum(np.square(atom_weights), dtype=float) * np.ones_like( + q_grid, + dtype=float, + ) + block = max(int(atom_block_size), 1) + q_block = max(int(q_chunk_size), 1) + atom_count = int(coords.shape[0]) + + def accumulate_pair_terms( + pair_distances: np.ndarray, + pair_weights: np.ndarray, + ) -> None: + if pair_distances.size == 0 or pair_weights.size == 0: + return + for q_start in range(0, q_grid.size, q_block): + q_stop = min(q_start + q_block, q_grid.size) + q_chunk = q_grid[q_start:q_stop] + kernel = np.sinc( + pair_distances[:, np.newaxis] * q_chunk[np.newaxis, :] / np.pi + ) + intensity[q_start:q_stop] += 2.0 * np.sum( + pair_weights[:, np.newaxis] * kernel, + axis=0, + dtype=float, + ) + + for i_start in range(0, atom_count, block): + i_stop = min(i_start + block, atom_count) + coords_i = coords[i_start:i_stop] + weights_i = atom_weights[i_start:i_stop] + local_count = int(i_stop - i_start) + if local_count > 1: + intra_distances = pdist(coords_i, metric="euclidean") + intra_weights = np.concatenate( + [ + weights_i[index] * weights_i[index + 1 :] + for index in range(local_count - 1) + ] + ) + accumulate_pair_terms(intra_distances, intra_weights) + for j_start in range(i_stop, atom_count, block): + j_stop = min(j_start + block, atom_count) + coords_j = coords[j_start:j_stop] + weights_j = atom_weights[j_start:j_stop] + inter_distances = cdist(coords_i, coords_j).reshape(-1) + inter_weights = ( + weights_i[:, np.newaxis] * weights_j[np.newaxis, :] + ).reshape(-1) + accumulate_pair_terms(inter_distances, inter_weights) + return np.asarray(intensity, dtype=float) + + +def compute_spherical_average_point_born_intensity( + coordinates: np.ndarray, + weights: np.ndarray, + q_values: np.ndarray, +) -> np.ndarray: + coords = np.asarray(coordinates, dtype=float) + atom_weights = np.asarray(weights, 
dtype=float) + q_grid = np.asarray(q_values, dtype=float) + if coords.ndim != 2 or coords.shape[1] != 3: + raise ValueError("coordinates must be an Nx3 array.") + if atom_weights.ndim != 1 or atom_weights.shape[0] != coords.shape[0]: + raise ValueError( + "weights must be a one-dimensional array matching coordinates." + ) + radii = np.linalg.norm(coords, axis=1) + amplitude = np.sum( + atom_weights[np.newaxis, :] + * np.sinc(q_grid[:, np.newaxis] * radii[np.newaxis, :] / np.pi), + axis=1, + dtype=float, + ) + return np.square(np.abs(amplitude)) + + +def compute_directional_born_intensity( + coordinates: np.ndarray, + weights: np.ndarray, + q_values: np.ndarray, + *, + direction_count: int = 256, + q_chunk_size: int = 12, +) -> np.ndarray: + coords = np.asarray(coordinates, dtype=float) + atom_weights = np.asarray(weights, dtype=float) + q_grid = np.asarray(q_values, dtype=float) + if coords.ndim != 2 or coords.shape[1] != 3: + raise ValueError("coordinates must be an Nx3 array.") + if atom_weights.ndim != 1 or atom_weights.shape[0] != coords.shape[0]: + raise ValueError( + "weights must be a one-dimensional array matching coordinates." 
+ ) + directions = fibonacci_sphere_directions(direction_count) + projections = np.asarray(coords @ directions.T, dtype=float) + intensity = np.zeros_like(q_grid, dtype=float) + q_block = max(int(q_chunk_size), 1) + for q_start in range(0, q_grid.size, q_block): + q_stop = min(q_start + q_block, q_grid.size) + q_chunk = q_grid[q_start:q_stop] + phase = ( + q_chunk[:, np.newaxis, np.newaxis] * projections[np.newaxis, :, :] + ) + amplitude = np.sum( + atom_weights[np.newaxis, :, np.newaxis] * np.exp(1j * phase), + axis=1, + ) + intensity[q_start:q_stop] = np.mean( + np.square(np.abs(amplitude)), + axis=1, + dtype=float, + ) + return np.asarray(intensity, dtype=float) + + +@dataclass(slots=True, frozen=True) +class GridBornSettings: + spacing_a: float + padding_a: float + sigma_a: float + support_sigma: float = 4.0 + + def normalized(self) -> "GridBornSettings": + spacing = float(self.spacing_a) + padding = float(self.padding_a) + sigma = float(self.sigma_a) + support_sigma = float(self.support_sigma) + if spacing <= 0.0: + raise ValueError("spacing_a must be greater than zero.") + if padding < 0.0: + raise ValueError("padding_a must be non-negative.") + if sigma < 0.0: + raise ValueError("sigma_a must be non-negative.") + if support_sigma <= 0.0: + raise ValueError("support_sigma must be greater than zero.") + return GridBornSettings( + spacing_a=spacing, + padding_a=padding, + sigma_a=sigma, + support_sigma=support_sigma, + ) + + +@dataclass(slots=True, frozen=True) +class GridBornResult: + settings: GridBornSettings + q_values: np.ndarray + intensity: np.ndarray + q_shell_counts: np.ndarray + density_integral: float + expected_weight: float + grid_shape: tuple[int, int, int] + box_lengths_a: tuple[float, float, float] + voxel_spacing_a: tuple[float, float, float] + q_nyquist_a_inverse: float + q_frequency_step_a_inverse: tuple[float, float, float] + q_convention: str + uses_two_pi_frequency_conversion: bool + density_subtraction_active: bool + + +def 
_deposit_density_to_grid( + coordinates: np.ndarray, + weights: np.ndarray, + settings: GridBornSettings, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + normalized = settings.normalized() + coords = np.asarray(coordinates, dtype=float) + atom_weights = np.asarray(weights, dtype=float) + spacing = float(normalized.spacing_a) + padding = float(normalized.padding_a) + sigma = float(normalized.sigma_a) + d_volume = spacing**3 + + coord_min = np.min(coords, axis=0) - padding + coord_max = np.max(coords, axis=0) + padding + span = np.asarray(coord_max - coord_min, dtype=float) + grid_shape = tuple( + int(np.ceil(float(axis_span) / spacing)) + 1 for axis_span in span + ) + density = np.zeros(grid_shape, dtype=float) + + if sigma <= 0.0: + grid_indices = np.rint((coords - coord_min) / spacing).astype(int) + for grid_index, atom_weight in zip( + grid_indices, + atom_weights, + strict=False, + ): + ix, iy, iz = ( + int(grid_index[0]), + int(grid_index[1]), + int(grid_index[2]), + ) + density[ix, iy, iz] += float(atom_weight) / d_volume + return density, coord_min, span + + support_radius = int( + np.ceil(float(normalized.support_sigma) * sigma / spacing) + ) + for point, atom_weight in zip(coords, atom_weights, strict=False): + center_index = np.rint((point - coord_min) / spacing).astype(int) + axis_ranges: list[np.ndarray] = [] + axis_kernels: list[np.ndarray] = [] + axis_sums: list[float] = [] + for axis in range(3): + start = max(int(center_index[axis]) - support_radius, 0) + stop = min( + int(center_index[axis]) + support_radius + 1, grid_shape[axis] + ) + index_range = np.arange(start, stop, dtype=int) + axis_positions = ( + coord_min[axis] + index_range.astype(float) * spacing + ) + kernel = np.exp( + -0.5 * np.square((axis_positions - float(point[axis])) / sigma) + ) + axis_ranges.append(index_range) + axis_kernels.append(np.asarray(kernel, dtype=float)) + axis_sums.append(float(np.sum(kernel))) + normalization = axis_sums[0] * axis_sums[1] * axis_sums[2] * 
d_volume + contribution = ( + (float(atom_weight) / normalization) + * axis_kernels[0][:, np.newaxis, np.newaxis] + * axis_kernels[1][np.newaxis, :, np.newaxis] + * axis_kernels[2][np.newaxis, np.newaxis, :] + ) + density[ + np.ix_(axis_ranges[0], axis_ranges[1], axis_ranges[2]) + ] += contribution + return density, coord_min, span + + +def compute_fft_grid_born_intensity( + coordinates: np.ndarray, + weights: np.ndarray, + q_values: np.ndarray, + settings: GridBornSettings, +) -> GridBornResult: + q_grid = np.asarray(q_values, dtype=float) + coords = np.asarray(coordinates, dtype=float) + atom_weights = np.asarray(weights, dtype=float) + normalized = settings.normalized() + density, _origin, requested_span = _deposit_density_to_grid( + coords, + atom_weights, + normalized, + ) + spacing = float(normalized.spacing_a) + d_volume = spacing**3 + amplitude = np.fft.fftn(density) * d_volume + intensity_grid = np.square(np.abs(amplitude)) + q_axes = [ + 2.0 * np.pi * np.fft.fftfreq(axis_count, d=spacing) + for axis_count in density.shape + ] + qx, qy, qz = np.meshgrid(*q_axes, indexing="ij") + q_magnitude = np.sqrt(qx * qx + qy * qy + qz * qz).reshape(-1) + flat_intensity = intensity_grid.reshape(-1) + + if q_grid.size > 1: + q_step = float(np.median(np.diff(q_grid))) + else: + q_step = max(float(q_grid[0]) * 0.5, 1.0e-6) + q_edges = np.concatenate( + ( + [max(0.0, float(q_grid[0]) - 0.5 * q_step)], + 0.5 * (q_grid[:-1] + q_grid[1:]), + [float(q_grid[-1]) + 0.5 * q_step], + ) + ) + bin_indices = np.digitize(q_magnitude, q_edges) - 1 + valid_mask = (bin_indices >= 0) & (bin_indices < q_grid.size) + q_shell_counts = np.bincount( + bin_indices[valid_mask], + minlength=q_grid.size, + ) + shell_sums = np.bincount( + bin_indices[valid_mask], + weights=flat_intensity[valid_mask], + minlength=q_grid.size, + ) + shell_average = np.divide( + shell_sums, + q_shell_counts, + out=np.full(q_grid.shape, np.nan, dtype=float), + where=q_shell_counts > 0, + ) + box_lengths = tuple( + 
float(axis_count * spacing) for axis_count in density.shape + ) + q_frequency_steps = tuple( + 0.0 if len(axis_values) < 2 else float(axis_values[1] - axis_values[0]) + for axis_values in q_axes + ) + return GridBornResult( + settings=normalized, + q_values=np.asarray(q_grid, dtype=float), + intensity=np.asarray(shell_average, dtype=float), + q_shell_counts=np.asarray(q_shell_counts, dtype=int), + density_integral=float(np.sum(density, dtype=float) * d_volume), + expected_weight=float(np.sum(atom_weights, dtype=float)), + grid_shape=tuple(int(axis_count) for axis_count in density.shape), + box_lengths_a=box_lengths, + voxel_spacing_a=(spacing, spacing, spacing), + q_nyquist_a_inverse=float(np.pi / spacing), + q_frequency_step_a_inverse=q_frequency_steps, + q_convention=( + "3D FFT shell average with q = 2πf, where f is the Cartesian " + "FFT frequency in cycles per Å." + ), + uses_two_pi_frequency_conversion=True, + density_subtraction_active=False, + ) diff --git a/src/saxshell/saxs/contrast/__init__.py b/src/saxshell/saxs/contrast/__init__.py index 1924f25..714561a 100644 --- a/src/saxshell/saxs/contrast/__init__.py +++ b/src/saxshell/saxs/contrast/__init__.py @@ -31,6 +31,8 @@ ) from .settings import ( COMPONENT_BUILD_MODE_BORN_APPROXIMATION, + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_1D, + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, COMPONENT_BUILD_MODE_CONTRAST, COMPONENT_BUILD_MODE_NO_CONTRAST, ContrastModeLaunchContext, @@ -42,6 +44,8 @@ __all__ = [ "COMPONENT_BUILD_MODE_BORN_APPROXIMATION", + "COMPONENT_BUILD_MODE_BORN_APPROXIMATION_1D", + "COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT", "COMPONENT_BUILD_MODE_CONTRAST", "COMPONENT_BUILD_MODE_NO_CONTRAST", "CONTRAST_SOLVENT_METHOD_NEAT", diff --git a/src/saxshell/saxs/contrast/settings.py b/src/saxshell/saxs/contrast/settings.py index 2cae163..2bc88a9 100644 --- a/src/saxshell/saxs/contrast/settings.py +++ b/src/saxshell/saxs/contrast/settings.py @@ -6,10 +6,17 @@ COMPONENT_BUILD_MODE_NO_CONTRAST = 
"no_contrast" COMPONENT_BUILD_MODE_CONTRAST = "contrast" COMPONENT_BUILD_MODE_BORN_APPROXIMATION = "born_approximation" +COMPONENT_BUILD_MODE_BORN_APPROXIMATION_1D = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION +) +COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT = "born_approximation_3d_fft" _COMPONENT_BUILD_MODE_LABELS = { COMPONENT_BUILD_MODE_NO_CONTRAST: "No Contrast (Debye)", COMPONENT_BUILD_MODE_CONTRAST: "Contrast (Debye)", - COMPONENT_BUILD_MODE_BORN_APPROXIMATION: "Born Approximation (Average)", + COMPONENT_BUILD_MODE_BORN_APPROXIMATION: "1D Born Approximation (Average)", + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT: ( + "3D FFT Born Approximation" + ), } @@ -25,9 +32,21 @@ def normalize_component_build_mode(value: object) -> str: "born_approx", "born_approximation_mode", "born_approximation_average", + "born_approximation_1d", + "born_approximation_average_1d", + "1d_born_approximation", "average", }: return COMPONENT_BUILD_MODE_BORN_APPROXIMATION + if normalized in { + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, + "born_approximation_fft", + "born_approximation_3d", + "born_fft", + "3d_fft_born", + "3d_fft_born_approximation", + }: + return COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT return COMPONENT_BUILD_MODE_NO_CONTRAST diff --git a/src/saxshell/saxs/contrast/ui/main_window.py b/src/saxshell/saxs/contrast/ui/main_window.py index 44ccd18..c56875f 100644 --- a/src/saxshell/saxs/contrast/ui/main_window.py +++ b/src/saxshell/saxs/contrast/ui/main_window.py @@ -49,6 +49,7 @@ QWidget, ) +from saxshell.plotting import Q_A_INVERSE_LABEL from saxshell.saxs.contrast.descriptors import ContrastStructureDescriptor from saxshell.saxs.contrast.electron_density import ( CONTRAST_SOLVENT_METHOD_DIRECT, @@ -3121,7 +3122,7 @@ def _apply_trace_axis_style( self, axis, *, is_generated_axis: bool ) -> None: if not is_generated_axis or self._experimental_summary is None: - axis.set_xlabel("q (Å⁻¹)") + axis.set_xlabel(Q_A_INVERSE_LABEL) if not is_generated_axis: 
axis.set_ylabel("Intensity (arb. units)") axis.grid(True, alpha=0.25) diff --git a/src/saxshell/saxs/contrast_fft/__init__.py b/src/saxshell/saxs/contrast_fft/__init__.py new file mode 100644 index 0000000..0bd4264 --- /dev/null +++ b/src/saxshell/saxs/contrast_fft/__init__.py @@ -0,0 +1,21 @@ +from .backend import ( + ContrastFFTResult, + ContrastFFTSettings, + ContrastFFTTiming, + build_atomic_density_grid, + build_contrast_density_grid, + build_exclusion_mask_grid, + compute_contrast_fft_intensity, + default_contrast_fft_settings, +) + +__all__ = [ + "ContrastFFTResult", + "ContrastFFTSettings", + "ContrastFFTTiming", + "build_atomic_density_grid", + "build_contrast_density_grid", + "build_exclusion_mask_grid", + "compute_contrast_fft_intensity", + "default_contrast_fft_settings", +] diff --git a/src/saxshell/saxs/contrast_fft/backend.py b/src/saxshell/saxs/contrast_fft/backend.py new file mode 100644 index 0000000..50d5014 --- /dev/null +++ b/src/saxshell/saxs/contrast_fft/backend.py @@ -0,0 +1,650 @@ +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from time import perf_counter + +import numpy as np + +from saxshell.xyz2pdb.workflow import _covalent_radius + + +@dataclass(slots=True, frozen=True) +class ContrastFFTSettings: + spacing_a: float + gaussian_sigma_a: float + minimum_box_length_a: float + padding_a: float = 0.0 + support_sigma: float = 4.0 + solvent_density_e_per_a3: float = 0.0 + exclusion_radius_scale: float = 1.0 + exclusion_radius_padding_a: float = 0.0 + use_cubic_box: bool = True + + def normalized(self) -> "ContrastFFTSettings": + spacing = float(self.spacing_a) + sigma = float(self.gaussian_sigma_a) + minimum_box_length = float(self.minimum_box_length_a) + padding = float(self.padding_a) + support_sigma = float(self.support_sigma) + solvent_density = float(self.solvent_density_e_per_a3) + exclusion_radius_scale = float(self.exclusion_radius_scale) + exclusion_radius_padding = 
float(self.exclusion_radius_padding_a) + if spacing <= 0.0: + raise ValueError("spacing_a must be greater than zero.") + if sigma < 0.0: + raise ValueError("gaussian_sigma_a must be non-negative.") + if minimum_box_length < spacing: + raise ValueError( + "minimum_box_length_a must be at least one voxel wide." + ) + if padding < 0.0: + raise ValueError("padding_a must be non-negative.") + if support_sigma <= 0.0: + raise ValueError("support_sigma must be greater than zero.") + if exclusion_radius_scale <= 0.0: + raise ValueError( + "exclusion_radius_scale must be greater than zero." + ) + if exclusion_radius_padding < 0.0: + raise ValueError( + "exclusion_radius_padding_a must be non-negative." + ) + return ContrastFFTSettings( + spacing_a=spacing, + gaussian_sigma_a=sigma, + minimum_box_length_a=minimum_box_length, + padding_a=padding, + support_sigma=support_sigma, + solvent_density_e_per_a3=solvent_density, + exclusion_radius_scale=exclusion_radius_scale, + exclusion_radius_padding_a=exclusion_radius_padding, + use_cubic_box=bool(self.use_cubic_box), + ) + + +def default_contrast_fft_settings( + *, + solvent_density_e_per_a3: float = 0.0, + exclusion_radius_scale: float = 1.0, + exclusion_radius_padding_a: float = 0.0, +) -> ContrastFFTSettings: + return ContrastFFTSettings( + spacing_a=2.5, + gaussian_sigma_a=0.75, + minimum_box_length_a=640.0, + padding_a=24.0, + solvent_density_e_per_a3=float(solvent_density_e_per_a3), + exclusion_radius_scale=float(exclusion_radius_scale), + exclusion_radius_padding_a=float(exclusion_radius_padding_a), + ).normalized() + + +@dataclass(slots=True, frozen=True) +class ContrastFFTGrid: + density: np.ndarray + grid_shape: tuple[int, int, int] + origin_a: tuple[float, float, float] + box_lengths_a: tuple[float, float, float] + voxel_spacing_a: tuple[float, float, float] + density_integral: float + expected_weight: float + + +@dataclass(slots=True, frozen=True) +class ContrastFFTTiming: + atomic_density_seconds: float + 
contrast_density_seconds: float + fft_seconds: float + shell_average_seconds: float + total_seconds: float + + +@dataclass(slots=True, frozen=True) +class ContrastFFTResult: + settings: ContrastFFTSettings + q_values: np.ndarray + raw_intensity: np.ndarray + kernel_corrected_intensity: np.ndarray + q_shell_counts: np.ndarray + density_integral: float + expected_weight: float + contrast_density_integral: float + expected_contrast_weight: float + solvent_exclusion_volume_a3: float + grid_shape: tuple[int, int, int] + box_lengths_a: tuple[float, float, float] + voxel_spacing_a: tuple[float, float, float] + q_nyquist_a_inverse: float + q_frequency_step_a_inverse: tuple[float, float, float] + q_convention: str + uses_two_pi_frequency_conversion: bool + density_subtraction_active: bool + first_nonempty_q_a_inverse: float | None + solvent_density_e_per_a3: float + contrast_mode: str + kernel_correction_supported: bool + kernel_correction_applied: bool + kernel_correction_model: str | None + timing: ContrastFFTTiming + + +def _raise_if_cancelled( + cancelled: Callable[[], bool] | None, +) -> None: + if cancelled is not None and bool(cancelled()): + raise RuntimeError("3D FFT Born calculation cancelled.") + + +def _box_axis_counts( + coordinates: np.ndarray, + settings: ContrastFFTSettings, +) -> tuple[int, int, int]: + spacing = float(settings.spacing_a) + span = np.ptp(np.asarray(coordinates, dtype=float), axis=0) + requested_lengths = np.asarray(span + 2.0 * float(settings.padding_a)) + minimum_length = max(float(settings.minimum_box_length_a), spacing) + if bool(settings.use_cubic_box): + axis_counts = [] + longest = max(float(np.max(requested_lengths)), minimum_length) + count = int(np.ceil(longest / spacing)) + if count % 2 == 0: + count += 1 + axis_counts = [count, count, count] + else: + axis_counts = [] + for requested_length in requested_lengths: + count = int( + np.ceil(max(float(requested_length), minimum_length) / spacing) + ) + if count % 2 == 0: + count += 1 
+ axis_counts.append(count) + return tuple(int(value) for value in axis_counts) + + +def _grid_origin_from_shape( + grid_shape: tuple[int, int, int], + spacing_a: float, +) -> np.ndarray: + counts = np.asarray(grid_shape, dtype=float) + return -0.5 * (counts - 1.0) * float(spacing_a) + + +def build_atomic_density_grid( + coordinates: np.ndarray, + weights: np.ndarray, + settings: ContrastFFTSettings, + *, + cancelled: Callable[[], bool] | None = None, +) -> ContrastFFTGrid: + normalized = settings.normalized() + coords = np.asarray(coordinates, dtype=float) + atom_weights = np.asarray(weights, dtype=float) + if coords.ndim != 2 or coords.shape[1] != 3: + raise ValueError("coordinates must be an Nx3 array.") + if atom_weights.ndim != 1 or atom_weights.shape[0] != coords.shape[0]: + raise ValueError( + "weights must be a one-dimensional array matching coordinates." + ) + if coords.shape[0] == 0: + raise ValueError("At least one atom is required.") + + spacing = float(normalized.spacing_a) + sigma = float(normalized.gaussian_sigma_a) + voxel_volume = spacing**3 + grid_shape = _box_axis_counts(coords, normalized) + origin = _grid_origin_from_shape(grid_shape, spacing) + density = np.zeros(grid_shape, dtype=float) + + if sigma <= 0.0: + grid_indices = np.rint( + (coords - origin[np.newaxis, :]) / spacing + ).astype(int) + for grid_index, atom_weight in zip( + grid_indices, + atom_weights, + strict=False, + ): + _raise_if_cancelled(cancelled) + ix, iy, iz = ( + int(grid_index[0]), + int(grid_index[1]), + int(grid_index[2]), + ) + if ( + ix < 0 + or iy < 0 + or iz < 0 + or ix >= grid_shape[0] + or iy >= grid_shape[1] + or iz >= grid_shape[2] + ): + raise ValueError("An atom fell outside the FFT grid bounds.") + density[ix, iy, iz] += float(atom_weight) / voxel_volume + return ContrastFFTGrid( + density=density, + grid_shape=grid_shape, + origin_a=tuple(float(value) for value in origin), + box_lengths_a=tuple( + float(count * spacing) for count in grid_shape + ), + 
voxel_spacing_a=(spacing, spacing, spacing), + density_integral=float( + np.sum(density, dtype=float) * voxel_volume + ), + expected_weight=float(np.sum(atom_weights, dtype=float)), + ) + + support_radius = int( + np.ceil(float(normalized.support_sigma) * sigma / spacing) + ) + for point, atom_weight in zip(coords, atom_weights, strict=False): + _raise_if_cancelled(cancelled) + center_index = np.rint((point - origin) / spacing).astype(int) + axis_ranges: list[np.ndarray] = [] + axis_kernels: list[np.ndarray] = [] + axis_sums: list[float] = [] + for axis in range(3): + start = max(int(center_index[axis]) - support_radius, 0) + stop = min( + int(center_index[axis]) + support_radius + 1, + grid_shape[axis], + ) + index_range = np.arange(start, stop, dtype=int) + axis_positions = origin[axis] + index_range.astype(float) * spacing + kernel = np.exp( + -0.5 * np.square((axis_positions - float(point[axis])) / sigma) + ) + axis_ranges.append(index_range) + axis_kernels.append(np.asarray(kernel, dtype=float)) + axis_sums.append(float(np.sum(kernel, dtype=float))) + normalization = ( + axis_sums[0] * axis_sums[1] * axis_sums[2] * voxel_volume + ) + if normalization <= 0.0: + raise ValueError( + "Encountered a non-positive Gaussian normalization." 
+ ) + contribution = ( + (float(atom_weight) / normalization) + * axis_kernels[0][:, np.newaxis, np.newaxis] + * axis_kernels[1][np.newaxis, :, np.newaxis] + * axis_kernels[2][np.newaxis, np.newaxis, :] + ) + density[ + np.ix_(axis_ranges[0], axis_ranges[1], axis_ranges[2]) + ] += contribution + + return ContrastFFTGrid( + density=density, + grid_shape=grid_shape, + origin_a=tuple(float(value) for value in origin), + box_lengths_a=tuple(float(count * spacing) for count in grid_shape), + voxel_spacing_a=(spacing, spacing, spacing), + density_integral=float(np.sum(density, dtype=float) * voxel_volume), + expected_weight=float(np.sum(atom_weights, dtype=float)), + ) + + +def build_exclusion_mask_grid( + coordinates: np.ndarray, + elements: list[str] | tuple[str, ...], + settings: ContrastFFTSettings, + *, + origin_a: tuple[float, float, float], + grid_shape: tuple[int, int, int], + cancelled: Callable[[], bool] | None = None, +) -> np.ndarray: + normalized = settings.normalized() + coords = np.asarray(coordinates, dtype=float) + if coords.ndim != 2 or coords.shape[1] != 3: + raise ValueError("coordinates must be an Nx3 array.") + if len(elements) != coords.shape[0]: + raise ValueError("elements must match the number of coordinates.") + spacing = float(normalized.spacing_a) + origin = np.asarray(origin_a, dtype=float) + mask = np.zeros(grid_shape, dtype=bool) + for point, element in zip(coords, elements, strict=False): + _raise_if_cancelled(cancelled) + radius = float(normalized.exclusion_radius_scale) * float( + _covalent_radius(str(element)) + ) + float(normalized.exclusion_radius_padding_a) + radius = max(radius, 0.0) + if radius <= 0.0: + continue + support_radius = int(np.ceil(radius / spacing)) + center_index = np.rint((point - origin) / spacing).astype(int) + axis_ranges: list[np.ndarray] = [] + axis_positions: list[np.ndarray] = [] + for axis in range(3): + start = max(int(center_index[axis]) - support_radius, 0) + stop = min( + int(center_index[axis]) + 
support_radius + 1, + grid_shape[axis], + ) + index_range = np.arange(start, stop, dtype=int) + positions = origin[axis] + index_range.astype(float) * spacing + axis_ranges.append(index_range) + axis_positions.append(np.asarray(positions, dtype=float)) + dx = axis_positions[0][:, np.newaxis, np.newaxis] - float(point[0]) + dy = axis_positions[1][np.newaxis, :, np.newaxis] - float(point[1]) + dz = axis_positions[2][np.newaxis, np.newaxis, :] - float(point[2]) + local_mask = np.square(dx) + np.square(dy) + np.square(dz) <= radius**2 + mask[ + np.ix_(axis_ranges[0], axis_ranges[1], axis_ranges[2]) + ] |= local_mask + return np.asarray(mask, dtype=float) + + +def build_contrast_density_grid( + atomic_grid: ContrastFFTGrid, + settings: ContrastFFTSettings, + *, + coordinates: np.ndarray, + elements: list[str] | tuple[str, ...] | None = None, + cancelled: Callable[[], bool] | None = None, +) -> tuple[np.ndarray, float, float]: + normalized = settings.normalized() + atomic_density = np.asarray(atomic_grid.density, dtype=float) + solvent_density = float(normalized.solvent_density_e_per_a3) + if abs(solvent_density) <= 1.0e-15: + return ( + np.asarray(atomic_density, dtype=float), + 0.0, + float(atomic_grid.expected_weight), + ) + if elements is None: + raise ValueError( + "elements are required when solvent_density_e_per_a3 is non-zero." 
+ ) + exclusion_mask = build_exclusion_mask_grid( + coordinates, + elements, + normalized, + origin_a=atomic_grid.origin_a, + grid_shape=atomic_grid.grid_shape, + cancelled=cancelled, + ) + voxel_volume = float(normalized.spacing_a) ** 3 + solvent_exclusion_volume = float( + np.sum(exclusion_mask, dtype=float) * voxel_volume + ) + contrast_density = atomic_density - solvent_density * exclusion_mask + expected_contrast_weight = float( + atomic_grid.expected_weight + - solvent_density * solvent_exclusion_volume + ) + return ( + np.asarray(contrast_density, dtype=float), + solvent_exclusion_volume, + expected_contrast_weight, + ) + + +def _q_bin_edges(q_values: np.ndarray) -> np.ndarray: + q_grid = np.asarray(q_values, dtype=float) + if q_grid.ndim != 1: + raise ValueError("q_values must be one-dimensional.") + if q_grid.size == 0: + raise ValueError("q_values must not be empty.") + if q_grid.size == 1: + q_step = max(float(q_grid[0]) * 0.5, 1.0e-6) + else: + q_step = float(np.median(np.diff(q_grid))) + return np.concatenate( + ( + [max(0.0, float(q_grid[0]) - 0.5 * q_step)], + 0.5 * (q_grid[:-1] + q_grid[1:]), + [float(q_grid[-1]) + 0.5 * q_step], + ) + ) + + +def compute_contrast_fft_intensity( + coordinates: np.ndarray, + weights: np.ndarray, + q_values: np.ndarray, + settings: ContrastFFTSettings, + *, + elements: list[str] | tuple[str, ...] 
| None = None, + cancelled: Callable[[], bool] | None = None, +) -> ContrastFFTResult: + total_start = perf_counter() + q_grid = np.asarray(q_values, dtype=float) + if q_grid.ndim != 1: + raise ValueError("q_values must be one-dimensional.") + if q_grid.size == 0: + raise ValueError("q_values must not be empty.") + normalized = settings.normalized() + atomic_start = perf_counter() + atomic_grid = build_atomic_density_grid( + coordinates, + weights, + normalized, + cancelled=cancelled, + ) + atomic_seconds = perf_counter() - atomic_start + _raise_if_cancelled(cancelled) + if ( + np.asarray(coordinates, dtype=float).shape[0] == 1 + and abs(float(normalized.solvent_density_e_per_a3)) <= 1.0e-15 + ): + spacing = float(normalized.spacing_a) + intensity = np.full( + q_grid.shape, + float(atomic_grid.expected_weight) ** 2, + dtype=float, + ) + total_seconds = perf_counter() - total_start + return ContrastFFTResult( + settings=normalized, + q_values=np.asarray(q_grid, dtype=float), + raw_intensity=intensity, + kernel_corrected_intensity=np.asarray(intensity, dtype=float), + q_shell_counts=np.ones_like(q_grid, dtype=int), + density_integral=float(atomic_grid.density_integral), + expected_weight=float(atomic_grid.expected_weight), + contrast_density_integral=float(atomic_grid.density_integral), + expected_contrast_weight=float(atomic_grid.expected_weight), + solvent_exclusion_volume_a3=0.0, + grid_shape=tuple(int(value) for value in atomic_grid.grid_shape), + box_lengths_a=tuple( + float(value) for value in atomic_grid.box_lengths_a + ), + voxel_spacing_a=tuple( + float(value) for value in atomic_grid.voxel_spacing_a + ), + q_nyquist_a_inverse=float(np.pi / spacing), + q_frequency_step_a_inverse=tuple( + float(2.0 * np.pi / max(length, spacing)) + for length in atomic_grid.box_lengths_a + ), + q_convention=( + "Direct single-atom Born evaluation for a bare-density " + "single atom; the Cartesian FFT grid is still built for " + "geometry diagnostics." 
+ ), + uses_two_pi_frequency_conversion=True, + density_subtraction_active=False, + first_nonempty_q_a_inverse=float(q_grid[0]), + solvent_density_e_per_a3=0.0, + contrast_mode="single_atom_bare_density_direct_born", + kernel_correction_supported=False, + kernel_correction_applied=False, + kernel_correction_model=None, + timing=ContrastFFTTiming( + atomic_density_seconds=float(atomic_seconds), + contrast_density_seconds=0.0, + fft_seconds=0.0, + shell_average_seconds=0.0, + total_seconds=float(total_seconds), + ), + ) + contrast_start = perf_counter() + contrast_density, solvent_exclusion_volume, expected_contrast_weight = ( + build_contrast_density_grid( + atomic_grid, + normalized, + coordinates=np.asarray(coordinates, dtype=float), + elements=elements, + cancelled=cancelled, + ) + ) + contrast_seconds = perf_counter() - contrast_start + _raise_if_cancelled(cancelled) + spacing = float(normalized.spacing_a) + voxel_volume = spacing**3 + fft_start = perf_counter() + amplitude = np.fft.rfftn(contrast_density) * voxel_volume + fft_seconds = perf_counter() - fft_start + _raise_if_cancelled(cancelled) + qx_axis = ( + 2.0 * np.pi * np.fft.fftfreq(atomic_grid.grid_shape[0], d=spacing) + ) + qy_axis = ( + 2.0 * np.pi * np.fft.fftfreq(atomic_grid.grid_shape[1], d=spacing) + ) + qz_axis = ( + 2.0 * np.pi * np.fft.rfftfreq(atomic_grid.grid_shape[2], d=spacing) + ) + q_edges = _q_bin_edges(q_grid) + + shell_sums = np.zeros_like(q_grid, dtype=float) + corrected_shell_sums = np.zeros_like(q_grid, dtype=float) + shell_counts = np.zeros_like(q_grid, dtype=int) + qy_squared = np.square(qy_axis)[:, np.newaxis] + qz_squared = np.square(qz_axis)[np.newaxis, :] + z_plane_weights = np.ones_like(qz_axis, dtype=int) + if qz_axis.size > 1: + if atomic_grid.grid_shape[2] % 2 == 0: + z_plane_weights[1:-1] = 2 + else: + z_plane_weights[1:] = 2 + z_plane_weights_2d = z_plane_weights[np.newaxis, :] + sigma = float(normalized.gaussian_sigma_a) + raw_kernel_correction_valid = bool( + sigma > 
0.0 + and abs(float(normalized.solvent_density_e_per_a3)) <= 1.0e-15 + ) + shell_average_start = perf_counter() + + for x_index, qx_value in enumerate(qx_axis): + _raise_if_cancelled(cancelled) + q_magnitude = np.sqrt(float(qx_value) ** 2 + qy_squared + qz_squared) + intensity_slice = np.square(np.abs(amplitude[x_index])) + repeated_counts = np.broadcast_to( + z_plane_weights_2d, intensity_slice.shape + ) + bin_indices = np.digitize(q_magnitude.reshape(-1), q_edges) - 1 + valid_mask = (bin_indices >= 0) & (bin_indices < q_grid.size) + flat_counts = repeated_counts.reshape(-1) + flat_intensity = intensity_slice.reshape(-1) + shell_counts += np.bincount( + bin_indices[valid_mask], + weights=flat_counts[valid_mask], + minlength=q_grid.size, + ).astype(int) + shell_sums += np.bincount( + bin_indices[valid_mask], + weights=flat_intensity[valid_mask] * flat_counts[valid_mask], + minlength=q_grid.size, + ) + if raw_kernel_correction_valid: + intensity_response = np.exp( + -np.square(sigma) * np.square(q_magnitude) + ) + corrected_slice = intensity_slice / np.maximum( + intensity_response, 1.0e-12 + ) + corrected_shell_sums += np.bincount( + bin_indices[valid_mask], + weights=corrected_slice.reshape(-1)[valid_mask] + * flat_counts[valid_mask], + minlength=q_grid.size, + ) + + raw_shell_average = np.divide( + shell_sums, + shell_counts, + out=np.full_like(q_grid, np.nan, dtype=float), + where=shell_counts > 0, + ) + if raw_kernel_correction_valid: + corrected_shell_average = np.divide( + corrected_shell_sums, + shell_counts, + out=np.full_like(q_grid, np.nan, dtype=float), + where=shell_counts > 0, + ) + else: + corrected_shell_average = np.asarray(raw_shell_average, dtype=float) + + nonempty = np.flatnonzero(shell_counts > 0) + first_nonempty_q = ( + None + if nonempty.size == 0 + else float(np.asarray(q_grid)[int(nonempty[0])]) + ) + shell_average_seconds = perf_counter() - shell_average_start + contrast_integral = float( + np.sum(contrast_density, dtype=float) * 
voxel_volume + ) + total_seconds = perf_counter() - total_start + density_subtraction_active = bool( + abs(float(normalized.solvent_density_e_per_a3)) > 1.0e-15 + ) + return ContrastFFTResult( + settings=normalized, + q_values=np.asarray(q_grid, dtype=float), + raw_intensity=np.asarray(raw_shell_average, dtype=float), + kernel_corrected_intensity=np.asarray( + corrected_shell_average, + dtype=float, + ), + q_shell_counts=np.asarray(shell_counts, dtype=int), + density_integral=float(atomic_grid.density_integral), + expected_weight=float(atomic_grid.expected_weight), + contrast_density_integral=contrast_integral, + expected_contrast_weight=float(expected_contrast_weight), + solvent_exclusion_volume_a3=float(solvent_exclusion_volume), + grid_shape=tuple(int(value) for value in atomic_grid.grid_shape), + box_lengths_a=tuple( + float(value) for value in atomic_grid.box_lengths_a + ), + voxel_spacing_a=tuple( + float(value) for value in atomic_grid.voxel_spacing_a + ), + q_nyquist_a_inverse=float(np.pi / spacing), + q_frequency_step_a_inverse=tuple( + float(2.0 * np.pi / max(length, spacing)) + for length in atomic_grid.box_lengths_a + ), + q_convention=( + "3D FFT of a Cartesian contrast-density grid with q = 2πf, " + "followed by radial q-shell averaging of |A(qx, qy, qz)|^2." 
+ ), + uses_two_pi_frequency_conversion=True, + density_subtraction_active=density_subtraction_active, + first_nonempty_q_a_inverse=first_nonempty_q, + solvent_density_e_per_a3=float(normalized.solvent_density_e_per_a3), + contrast_mode=( + "constant_solvent_density_inside_union_of_atomic_spheres" + if density_subtraction_active + else "bare_atomic_density_only" + ), + kernel_correction_supported=raw_kernel_correction_valid, + kernel_correction_applied=raw_kernel_correction_valid, + kernel_correction_model=( + "Gaussian deposition intensity factor exp(-sigma^2 q^2)" + if raw_kernel_correction_valid + else None + ), + timing=ContrastFFTTiming( + atomic_density_seconds=float(atomic_seconds), + contrast_density_seconds=float(contrast_seconds), + fft_seconds=float(fft_seconds), + shell_average_seconds=float(shell_average_seconds), + total_seconds=float(total_seconds), + ), + ) diff --git a/src/saxshell/saxs/contrast_fft/ui/__init__.py b/src/saxshell/saxs/contrast_fft/ui/__init__.py new file mode 100644 index 0000000..bc0e2de --- /dev/null +++ b/src/saxshell/saxs/contrast_fft/ui/__init__.py @@ -0,0 +1,9 @@ +from .main_window import ( + FFTBornApproximationMainWindow, + launch_3d_fft_born_approximation_ui, +) + +__all__ = [ + "FFTBornApproximationMainWindow", + "launch_3d_fft_born_approximation_ui", +] diff --git a/src/saxshell/saxs/contrast_fft/ui/main_window.py b/src/saxshell/saxs/contrast_fft/ui/main_window.py new file mode 100644 index 0000000..bda7e70 --- /dev/null +++ b/src/saxshell/saxs/contrast_fft/ui/main_window.py @@ -0,0 +1,4788 @@ +from __future__ import annotations + +import csv +import json +import sys +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from time import perf_counter + +import numpy as np +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.backends.backend_qtagg import ( + NavigationToolbar2QT as NavigationToolbar, +) +from matplotlib.figure import Figure 
+from mpl_toolkits.mplot3d import Axes3D # noqa: F401 +from PySide6.QtCore import ( + QObject, + QSettings, + Qt, + QThread, + QTimer, + Signal, + Slot, +) +from PySide6.QtGui import QAction, QCloseEvent +from PySide6.QtWidgets import ( + QApplication, + QCheckBox, + QComboBox, + QDoubleSpinBox, + QFileDialog, + QGridLayout, + QGroupBox, + QHBoxLayout, + QInputDialog, + QLabel, + QLineEdit, + QMainWindow, + QMessageBox, + QPlainTextEdit, + QPushButton, + QScrollArea, + QSpinBox, + QSplitter, + QToolButton, + QVBoxLayout, + QWidget, +) + +from saxshell.fullrmc import load_rmc_project_source +from saxshell.plotting import Q_A_INVERSE_LABEL +from saxshell.saxs.born_refinement.backend import ( + build_shared_q_grid, + compute_constant_weight_debye_intensity, +) +from saxshell.saxs.contrast.electron_density import ( + ANGSTROM3_PER_CM3, + CONTRAST_SOLVENT_METHOD_DIRECT, + CONTRAST_SOLVENT_METHOD_NEAT, + CONTRAST_SOLVENT_METHOD_REFERENCE, + ContrastElectronDensityEstimate, + ContrastSolventDensitySettings, + _direct_solvent_electron_density, + _neat_solvent_electron_density, +) +from saxshell.saxs.contrast.solvents import ( + ContrastSolventPreset, + delete_custom_solvent_preset, + load_solvent_presets, + ordered_solvent_preset_names, + save_custom_solvent_preset, +) +from saxshell.saxs.contrast_fft import ( + ContrastFFTResult, + ContrastFFTSettings, + ContrastFFTTiming, + compute_contrast_fft_intensity, + default_contrast_fft_settings, +) +from saxshell.saxs.debye import discover_cluster_bins +from saxshell.saxs.electron_density_mapping.ui.viewer import ( + ElectronDensityStructureViewer, +) +from saxshell.saxs.electron_density_mapping.workflow import ( + ElectronDensityFourierTransformSettings, + ElectronDensityMeshSettings, + ElectronDensitySmearingSettings, + ElectronDensityStructure, + apply_solvent_contrast_to_profile_result, + build_electron_density_mesh, + compute_electron_density_profile, + compute_electron_density_scattering_profile, + inspect_structure_input, 
class _FFTCancelledError(RuntimeError):
    """Raised when the 3D FFT worker is cancelled cooperatively."""


@dataclass(slots=True, frozen=True)
class _FFTProfileTarget:
    """Immutable description of one profile the FFT worker should compute.

    The worker loads every path in ``source_files`` and averages the
    per-file results; ``display_name`` is used in its progress messages.
    """

    key: str
    display_name: str
    structure_name: str
    motif_name: str
    file_count: int  # presumably len(source_files) -- TODO confirm
    reference_file: Path
    source_files: tuple[Path, ...]
    representative: str | None
    source_mode: str
    solvent_mode: str


@dataclass(slots=True, frozen=True)
class _FFTProfileComputationResult:
    """Computed curves and timings for a single ``_FFTProfileTarget``.

    The optional fields are ``None`` when the corresponding comparison
    (legacy 1D Born or exact Debye) was not requested for the run.
    """

    target: _FFTProfileTarget
    q_values: np.ndarray
    fft_result: ContrastFFTResult
    legacy_q_values: np.ndarray | None
    legacy_intensity: np.ndarray | None
    exact_debye_intensity: np.ndarray | None
    # Elapsed times are per-file averages, not totals.
    legacy_elapsed_seconds: float | None
    debye_elapsed_seconds: float | None


@dataclass(slots=True, frozen=True)
class _FFTComputationPayload:
    """Bundle emitted by the worker: the shared q grid plus one result per target."""

    q_values: np.ndarray
    profile_results: tuple[_FFTProfileComputationResult, ...]
class _CollapsibleSection(QWidget):
    """Header button with an arrow that shows or hides a body widget.

    The section starts collapsed. ``toggled`` carries the new state and
    is emitted only when the expanded state actually changes.
    """

    toggled = Signal(bool)

    def __init__(
        self,
        title: str,
        body: QWidget,
        parent: QWidget | None = None,
    ) -> None:
        """Create the header row and place the hidden *body* below it."""
        super().__init__(parent)
        outer = QVBoxLayout(self)
        outer.setContentsMargins(0, 0, 0, 0)
        outer.setSpacing(0)

        header_row = QWidget(self)
        row_layout = QHBoxLayout(header_row)
        row_layout.setContentsMargins(0, 2, 0, 2)
        row_layout.setSpacing(4)

        button = QToolButton(self)
        self._toggle_button = button
        button.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonTextBesideIcon)
        button.setArrowType(Qt.ArrowType.RightArrow)
        button.setText(title)
        button.setAutoRaise(True)
        button.clicked.connect(self._toggle)
        row_layout.addWidget(button)
        row_layout.addStretch(1)

        self._body = body
        self._expanded = False
        self._body.setVisible(self._expanded)

        outer.addWidget(header_row)
        outer.addWidget(self._body)

    def _toggle(self) -> None:
        """Flip the current expanded state (header button click handler)."""
        flipped = not self._expanded
        self.set_expanded(flipped)

    def set_expanded(self, expanded: bool) -> None:
        """Show or hide the body; a no-op when the state is unchanged."""
        wanted = bool(expanded)
        if wanted == self._expanded:
            return
        if wanted:
            arrow = Qt.ArrowType.DownArrow
        else:
            arrow = Qt.ArrowType.RightArrow
        self._expanded = wanted
        self._body.setVisible(wanted)
        self._toggle_button.setArrowType(arrow)
        # Ask the layouts to recompute sizes now that visibility changed.
        self._body.updateGeometry()
        self.updateGeometry()
        self.toggled.emit(wanted)

    def expand(self) -> None:
        """Expand the section."""
        self.set_expanded(True)

    def collapse(self) -> None:
        """Collapse the section."""
        self.set_expanded(False)

    @property
    def is_expanded(self) -> bool:
        """Whether the body is currently shown."""
        return self._expanded
class _FFTComparisonPlot(QWidget):
    """Matplotlib canvas showing the FFT intensity curve plus optional overlays."""

    def __init__(self, parent: QWidget | None = None) -> None:
        """Create the figure, canvas and toolbar, then draw the placeholder."""
        super().__init__(parent)
        self.figure = Figure(figsize=(8.4, 3.6))
        self.canvas = FigureCanvas(self.figure)
        self.toolbar = NavigationToolbar(self.canvas, self)
        stack = QVBoxLayout(self)
        stack.setContentsMargins(0, 0, 0, 0)
        stack.setSpacing(4)
        stack.addWidget(self.toolbar)
        self.canvas.setMinimumHeight(300)
        stack.addWidget(self.canvas, stretch=1)
        self.draw_placeholder()

    def draw_placeholder(self) -> None:
        """Replace the figure content with two lines of explanatory text."""
        self.figure.clear()
        axis = self.figure.add_subplot(111)
        centered = {
            "ha": "center",
            "va": "center",
            "wrap": True,
            "transform": axis.transAxes,
        }
        axis.text(
            0.5,
            0.58,
            "Run the 3D FFT Born calculation to populate the q-space comparison plot.",
            **centered,
        )
        axis.text(
            0.5,
            0.39,
            "Optional overlays can include the legacy 1D Born approximation, "
            "exact Debye scattering, and the zero-contrast kernel-corrected FFT diagnostic.",
            alpha=0.78,
            **centered,
        )
        axis.set_axis_off()
        self.canvas.draw_idle()

    def set_curves(
        self,
        *,
        q_values: np.ndarray,
        primary_values: np.ndarray,
        primary_label: str,
        additional_series: list[dict[str, object]],
        log_q_axis: bool,
        log_intensity_axis: bool,
        show_legend: bool,
    ) -> None:
        """Redraw the plot with the primary curve and every overlay series.

        Each overlay is a mapping with ``q_values`` and ``intensity`` and
        optional ``label``, ``color``, ``linestyle`` and ``linewidth``
        entries; missing or falsy style entries fall back to defaults.
        """
        self.figure.clear()
        axis = self.figure.add_subplot(111)
        scale_flags = {
            "log_q_axis": log_q_axis,
            "log_intensity_axis": log_intensity_axis,
        }
        self._plot_series(
            axis,
            q_values=np.asarray(q_values, dtype=float),
            intensity=np.asarray(primary_values, dtype=float),
            label=str(primary_label),
            color="#1d4ed8",
            linestyle="-",
            linewidth=2.3,
            **scale_flags,
        )
        for overlay in additional_series:
            self._plot_series(
                axis,
                q_values=np.asarray(overlay.get("q_values"), dtype=float),
                intensity=np.asarray(overlay.get("intensity"), dtype=float),
                label=str(overlay.get("label") or "Comparison"),
                color=str(overlay.get("color") or "#64748b"),
                linestyle=str(overlay.get("linestyle") or "--"),
                linewidth=float(overlay.get("linewidth") or 1.6),
                **scale_flags,
            )
        for enabled, apply_scale in (
            (log_q_axis, axis.set_xscale),
            (log_intensity_axis, axis.set_yscale),
        ):
            if enabled:
                apply_scale("log")
        axis.set_xlabel(Q_A_INVERSE_LABEL, labelpad=10.0)
        axis.set_ylabel("Intensity (arb. units)")
        axis.set_title("3D FFT Born Approximation")
        axis.grid(True, which="both", alpha=0.28)
        handles, _ = axis.get_legend_handles_labels()
        # Only build a legend when at least one series was actually drawn.
        if show_legend and handles:
            axis.legend(loc="lower left", frameon=True)
        self.figure.tight_layout()
        self.canvas.draw_idle()

    def _plot_series(
        self,
        axis,
        *,
        q_values: np.ndarray,
        intensity: np.ndarray,
        label: str,
        color: str,
        linestyle: str,
        linewidth: float,
        log_q_axis: bool,
        log_intensity_axis: bool,
    ) -> None:
        """Plot one curve, dropping points that a log axis cannot show.

        Nothing is drawn when no point survives the filtering.
        """
        keep = np.ones_like(q_values, dtype=bool)
        for enabled, values in (
            (log_q_axis, q_values),
            (log_intensity_axis, intensity),
        ):
            if enabled:
                # Non-positive values have no position on a log axis.
                keep &= values > 0.0
        visible_q = q_values[keep]
        visible_intensity = intensity[keep]
        if not (visible_q.size and visible_intensity.size):
            return
        axis.plot(
            visible_q,
            visible_intensity,
            color=color,
            linestyle=linestyle,
            linewidth=linewidth,
            label=label,
        )
class _FFTShellCountPlot(QWidget):
    """Matplotlib canvas plotting the q-shell counts against q."""

    def __init__(self, parent: QWidget | None = None) -> None:
        """Create the figure, canvas and toolbar, then draw the placeholder."""
        super().__init__(parent)
        self.figure = Figure(figsize=(8.4, 2.8))
        self.canvas = FigureCanvas(self.figure)
        self.toolbar = NavigationToolbar(self.canvas, self)
        stack = QVBoxLayout(self)
        stack.setContentsMargins(0, 0, 0, 0)
        stack.setSpacing(4)
        stack.addWidget(self.toolbar)
        self.canvas.setMinimumHeight(240)
        stack.addWidget(self.canvas, stretch=1)
        self.draw_placeholder()

    def draw_placeholder(self) -> None:
        """Show a centered hint until shell counts are available."""
        self.figure.clear()
        axis = self.figure.add_subplot(111)
        hint = "q-shell population diagnostics will appear after the FFT run."
        axis.text(
            0.5,
            0.5,
            hint,
            ha="center",
            va="center",
            wrap=True,
            transform=axis.transAxes,
        )
        axis.set_axis_off()
        self.canvas.draw_idle()

    def set_counts(
        self, q_values: np.ndarray, q_shell_counts: np.ndarray
    ) -> None:
        """Redraw the panel with *q_shell_counts* plotted over *q_values*."""
        self.figure.clear()
        axis = self.figure.add_subplot(111)
        q_axis = np.asarray(q_values, dtype=float)
        counts = np.asarray(q_shell_counts, dtype=int)
        axis.plot(q_axis, counts, color="#0f766e", linewidth=1.9)
        axis.set_xlabel(Q_A_INVERSE_LABEL)
        axis.set_ylabel("Shell count")
        axis.set_title("3D FFT q-Shell Population")
        axis.grid(True, alpha=0.28)
        self.figure.tight_layout()
        self.canvas.draw_idle()
class _FFTRealSpaceVisualizer(QWidget):
    """3D matplotlib preview of a centered structure and its FFT box.

    Before an FFT run only the structure is drawn. Afterwards the panel
    shows two views side by side: the structure inside the FFT box, and
    a zoomed view fitted to the atoms.
    """

    def __init__(self, parent: QWidget | None = None) -> None:
        """Create the figure, canvas and toolbar, then draw the placeholder."""
        super().__init__(parent)
        self.figure = Figure(figsize=(8.4, 3.3))
        self.canvas = FigureCanvas(self.figure)
        self.toolbar = NavigationToolbar(self.canvas, self)
        layout = QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setSpacing(4)
        layout.addWidget(self.toolbar)
        self.canvas.setMinimumHeight(260)
        layout.addWidget(self.canvas, stretch=1)
        self.draw_placeholder()

    def draw_placeholder(self) -> None:
        """Show explanatory text while no structure is loaded."""
        self.figure.clear()
        axis = self.figure.add_subplot(111, projection="3d")
        axis.text2D(
            0.5,
            0.58,
            "Load a structure to preview its centered 3D real-space geometry.",
            ha="center",
            va="center",
            wrap=True,
            transform=axis.transAxes,
        )
        axis.text2D(
            0.5,
            0.39,
            "After the 3D FFT run, this panel adds the FFT box volume and a zoomed 3D structure view.",
            ha="center",
            va="center",
            wrap=True,
            alpha=0.78,
            transform=axis.transAxes,
        )
        axis.set_axis_off()
        self.canvas.draw_idle()

    def set_structure_preview(
        self,
        structure: ElectronDensityStructure | None,
        *,
        fft_result: ContrastFFTResult | None = None,
        contrast_summary: str | None = None,
    ) -> None:
        """Redraw the panel for *structure*.

        Args:
            structure: Structure whose ``centered_coordinates`` are drawn;
                ``None`` or an empty coordinate array shows the placeholder.
            fft_result: When given, its box lengths, voxel spacing and
                Nyquist limit are drawn and summarised next to the atoms.
            contrast_summary: Optional extra line for the summary box.
        """
        if structure is None:
            self.draw_placeholder()
            return
        coordinates = np.asarray(structure.centered_coordinates, dtype=float)
        if coordinates.size == 0:
            self.draw_placeholder()
            return
        self.figure.clear()
        if fft_result is None:
            axis = self.figure.add_subplot(111, projection="3d")
            self._draw_scene(
                axis,
                coordinates=coordinates,
                title="Centered Structure Preview",
            )
            axis.text2D(
                0.02,
                0.98,
                "Run the 3D FFT calculation to overlay the FFT box volume and spacing diagnostics.",
                ha="left",
                va="top",
                wrap=True,
                fontsize=9.0,
                bbox={
                    "boxstyle": "round,pad=0.3",
                    "facecolor": "#ffffff",
                    "edgecolor": "#cbd5e1",
                    "alpha": 0.94,
                },
                transform=axis.transAxes,
            )
            self.figure.tight_layout()
            self.canvas.draw_idle()
            return
        full_axis = self.figure.add_subplot(121, projection="3d")
        zoom_axis = self.figure.add_subplot(122, projection="3d")
        self._draw_scene(
            full_axis,
            coordinates=coordinates,
            title="FFT Box Volume",
            box_lengths=fft_result.box_lengths_a,
            fit_to_box=True,
        )
        self._draw_scene(
            zoom_axis,
            coordinates=coordinates,
            title="Structure Zoom",
        )
        summary_lines = [
            (
                "Box (Å): "
                + " × ".join(
                    f"{value:.1f}" for value in fft_result.box_lengths_a
                )
            ),
            (
                "Spacing: "
                f"{fft_result.voxel_spacing_a[0]:.3f} Å, "
                f"q_Nyquist={fft_result.q_nyquist_a_inverse:.3f} Å⁻¹"
            ),
        ]
        if contrast_summary:
            summary_lines.append(str(contrast_summary).strip())
        full_axis.text2D(
            0.02,
            0.02,
            "\n".join(summary_lines),
            ha="left",
            va="bottom",
            fontsize=8.8,
            bbox={
                "boxstyle": "round,pad=0.3",
                "facecolor": "#ffffff",
                "edgecolor": "#cbd5e1",
                "alpha": 0.94,
            },
            transform=full_axis.transAxes,
        )
        self.figure.tight_layout()
        self.canvas.draw_idle()

    def _draw_scene(
        self,
        axis,
        *,
        coordinates: np.ndarray,
        title: str,
        box_lengths: tuple[float, float, float] | None = None,
        fit_to_box: bool = False,
    ) -> None:
        """Scatter the atoms on *axis* and choose the axis limits.

        Atoms are colored by distance from the origin. With ``fit_to_box``
        and ``box_lengths`` the limits enclose the FFT box plus a margin;
        otherwise they enclose the atoms plus padding. Previously a call
        with ``fit_to_box=True`` and no ``box_lengths`` set no limits at
        all; it now falls back to the atom-based limits.
        """
        xyz = np.asarray(coordinates[:, :3], dtype=float)
        radial_distance = np.linalg.norm(xyz, axis=1)
        axis.scatter(
            xyz[:, 0],
            xyz[:, 1],
            xyz[:, 2],
            s=26.0,
            c=radial_distance,
            cmap="viridis",
            alpha=0.90,
            edgecolors="#0f172a",
            linewidths=0.2,
            depthshade=True,
        )
        if box_lengths is not None:
            self._draw_fft_box(axis, box_lengths=box_lengths)
        if fit_to_box and box_lengths is not None:
            half_lengths = 0.5 * np.asarray(box_lengths, dtype=float)
            # 8 % of the largest half-length, but never less than 1 Å.
            margin = max(float(np.max(half_lengths)) * 0.08, 1.0)
            self._set_3d_limits(
                axis,
                *(
                    (-float(half) - margin, float(half) + margin)
                    for half in half_lengths[:3]
                ),
            )
        else:
            lower = xyz.min(axis=0)
            upper = xyz.max(axis=0)
            # Pad by 18 % of the widest extent, using at least a 1 Å extent.
            pad = max(float(np.max(upper - lower)), 1.0) * 0.18
            self._set_3d_limits(
                axis,
                *(
                    (float(low) - pad, float(high) + pad)
                    for low, high in zip(lower, upper)
                ),
            )
        axis.set_xlabel("x (Å)")
        axis.set_ylabel("y (Å)")
        axis.set_zlabel("z (Å)")
        axis.set_title(title)
        axis.view_init(
            elev=20.0 if fit_to_box else 24.0,
            azim=36.0 if fit_to_box else 48.0,
        )
        axis.grid(True, alpha=0.24)

    def _draw_fft_box(
        self,
        axis,
        *,
        box_lengths: tuple[float, float, float],
    ) -> None:
        """Draw the twelve edges of the origin-centered FFT box as dashed lines."""
        half_lengths = 0.5 * np.asarray(box_lengths, dtype=float)
        x_half, y_half, z_half = (
            float(half_lengths[0]),
            float(half_lengths[1]),
            float(half_lengths[2]),
        )
        # Corners 0-3 form the bottom face, 4-7 the top face.
        corners = np.asarray(
            [
                [-x_half, -y_half, -z_half],
                [x_half, -y_half, -z_half],
                [x_half, y_half, -z_half],
                [-x_half, y_half, -z_half],
                [-x_half, -y_half, z_half],
                [x_half, -y_half, z_half],
                [x_half, y_half, z_half],
                [-x_half, y_half, z_half],
            ],
            dtype=float,
        )
        edges = (
            (0, 1),
            (1, 2),
            (2, 3),
            (3, 0),
            (4, 5),
            (5, 6),
            (6, 7),
            (7, 4),
            (0, 4),
            (1, 5),
            (2, 6),
            (3, 7),
        )
        for start, stop in edges:
            axis.plot(
                [corners[start, 0], corners[stop, 0]],
                [corners[start, 1], corners[stop, 1]],
                [corners[start, 2], corners[stop, 2]],
                color="#1d4ed8",
                linestyle="--",
                linewidth=1.2,
                alpha=0.95,
            )

    def _set_3d_limits(
        self,
        axis,
        x_limits: tuple[float, float],
        y_limits: tuple[float, float],
        z_limits: tuple[float, float],
    ) -> None:
        """Apply the limits and make the box aspect proportional to the spans."""
        axis.set_xlim(*x_limits)
        axis.set_ylim(*y_limits)
        axis.set_zlim(*z_limits)
        # Each span is clamped to at least 1.0 so no aspect entry is zero.
        spans = np.asarray(
            [
                max(float(x_limits[1] - x_limits[0]), 1.0),
                max(float(y_limits[1] - y_limits[0]), 1.0),
                max(float(z_limits[1] - z_limits[0]), 1.0),
            ],
            dtype=float,
        )
        axis.set_box_aspect(tuple(float(value) for value in spans))
class _FFTComputationWorker(QObject):
    """Worker object that computes 3D FFT Born profiles for a set of targets.

    ``run`` is the slot entry point. It reports through the signals:
    ``progress`` (status text), ``finished`` (a ``_FFTComputationPayload``),
    ``failed`` (error text) and ``cancelled`` (cancellation text).
    Cancellation is cooperative: ``cancel`` sets a flag that is checked
    between computation steps.
    """

    progress = Signal(str)
    finished = Signal(object)
    failed = Signal(str)
    cancelled = Signal(str)

    def __init__(
        self,
        *,
        targets: tuple[_FFTProfileTarget, ...],
        fft_settings: ContrastFFTSettings,
        legacy_mesh_settings: ElectronDensityMeshSettings | None,
        legacy_smearing_settings: ElectronDensitySmearingSettings | None,
        legacy_fourier_settings: (
            ElectronDensityFourierTransformSettings | None
        ),
        active_contrast_settings: ContrastSolventDensitySettings | None,
        active_contrast_name: str | None,
        q_min: float,
        q_max: float,
        q_step: float,
        compare_legacy_1d: bool,
        compare_exact_debye: bool,
    ) -> None:
        """Store normalized copies of the run configuration.

        Legacy settings left as ``None`` are replaced by the
        ``legacy_born_average_default_*`` helpers when the legacy
        comparison runs.
        """
        super().__init__()
        self._targets = tuple(targets)
        self._fft_settings = fft_settings
        self._legacy_mesh_settings = (
            None
            if legacy_mesh_settings is None
            else legacy_mesh_settings.normalized()
        )
        self._legacy_smearing_settings = (
            None
            if legacy_smearing_settings is None
            else legacy_smearing_settings.normalized()
        )
        self._legacy_fourier_settings = (
            None
            if legacy_fourier_settings is None
            else legacy_fourier_settings.normalized()
        )
        # Rebuild from the dict form so the worker holds its own instance.
        self._active_contrast_settings = (
            None
            if active_contrast_settings is None
            else ContrastSolventDensitySettings.from_values(
                **active_contrast_settings.to_dict()
            )
        )
        self._active_contrast_name = (
            None if active_contrast_name is None else str(active_contrast_name)
        )
        self._q_min = float(q_min)
        self._q_max = float(q_max)
        self._q_step = float(q_step)
        self._compare_legacy_1d = bool(compare_legacy_1d)
        self._compare_exact_debye = bool(compare_exact_debye)
        self._cancel_requested = False

    def cancel(self) -> None:
        """Request cooperative cancellation of the running computation."""
        self._cancel_requested = True

    def _raise_if_cancelled(self) -> None:
        """Raise ``_FFTCancelledError`` once cancellation has been requested."""
        if self._cancel_requested:
            raise _FFTCancelledError("3D FFT Born calculation cancelled.")

    @Slot()
    def run(self) -> None:
        """Compute every target and emit exactly one terminal signal."""
        try:
            if not self._targets:
                raise ValueError(
                    "No 3D FFT profile targets were prepared for the calculation."
                )
            self._raise_if_cancelled()
            self.progress.emit("Preparing shared q grid...")
            q_values = build_shared_q_grid(
                self._q_min,
                self._q_max,
                q_step=self._q_step,
            )
            profile_results: list[_FFTProfileComputationResult] = []
            total_targets = len(self._targets)
            for target_index, target in enumerate(self._targets, start=1):
                self._raise_if_cancelled()
                profile_results.append(
                    self._compute_profile_result_for_target(
                        target,
                        q_values=q_values,
                        target_index=target_index,
                        total_targets=total_targets,
                    )
                )
            self._raise_if_cancelled()
            self.progress.emit("Updating 3D FFT Born outputs...")
            self.finished.emit(
                _FFTComputationPayload(
                    q_values=np.asarray(q_values, dtype=float),
                    profile_results=tuple(profile_results),
                )
            )
        except _FFTCancelledError as exc:
            self.cancelled.emit(str(exc))
        except Exception as exc:  # pragma: no cover - defensive UI worker path
            # Exceptions raised without arguments stringify to "", which
            # would give the receiver an empty failure message.
            self.failed.emit(str(exc) or type(exc).__name__)

    def _legacy_q_step(self, q_values: np.ndarray) -> float:
        """Return the median spacing of *q_values* for the legacy transform.

        A grid with fewer than two points has no spacing; the median of
        an empty array is NaN, so the configured step is used instead.
        """
        spacings = np.diff(np.asarray(q_values, dtype=float))
        if spacings.size == 0:
            return self._q_step
        return float(np.median(spacings))

    def _compute_profile_result_for_target(
        self,
        target: _FFTProfileTarget,
        *,
        q_values: np.ndarray,
        target_index: int,
        total_targets: int,
    ) -> _FFTProfileComputationResult:
        """Compute and average the curves for every source file of *target*.

        The FFT intensity is always computed. The legacy 1D Born and
        exact Debye curves are computed only when enabled, and their
        reported elapsed times are per-file averages.

        Raises:
            _FFTCancelledError: If cancellation was requested.
            ValueError: If the target has no source files.
        """
        self._raise_if_cancelled()
        self.progress.emit(
            "Running 3D FFT Born approximation for "
            f"{target.display_name} ({target_index}/{total_targets})..."
        )
        fft_results: list[ContrastFFTResult] = []
        legacy_q_values: np.ndarray | None = None
        legacy_intensities: list[np.ndarray] = []
        exact_debye_intensities: list[np.ndarray] = []
        legacy_elapsed_seconds = 0.0
        debye_elapsed_seconds = 0.0
        # The q grid is shared by all files, so its spacing is loop-invariant.
        legacy_q_step = (
            self._legacy_q_step(q_values) if self._compare_legacy_1d else None
        )

        for file_index, file_path in enumerate(target.source_files, start=1):
            self._raise_if_cancelled()
            self.progress.emit(
                f"Loading structure {file_index}/{len(target.source_files)} "
                f"for {target.display_name}: {file_path.name}"
            )
            structure = load_electron_density_structure(file_path)
            coordinates = np.asarray(
                structure.centered_coordinates,
                dtype=float,
            )
            weights = np.asarray(structure.atomic_numbers, dtype=float)
            fft_results.append(
                compute_contrast_fft_intensity(
                    coordinates,
                    weights,
                    q_values,
                    self._fft_settings,
                    elements=structure.elements,
                    cancelled=lambda: self._cancel_requested,
                )
            )
            self._raise_if_cancelled()
            if self._compare_legacy_1d:
                legacy_start = perf_counter()
                mesh_settings = self._legacy_mesh_settings
                if mesh_settings is None:
                    mesh_settings = legacy_born_average_default_mesh_settings(
                        structure
                    )
                profile = compute_electron_density_profile(
                    structure,
                    mesh_settings,
                    smearing_settings=(
                        self._legacy_smearing_settings
                        or legacy_born_average_default_smearing_settings()
                    ),
                )
                if self._active_contrast_settings is not None:
                    profile = apply_solvent_contrast_to_profile_result(
                        profile,
                        self._active_contrast_settings,
                        solvent_name=self._active_contrast_name,
                    )
                legacy_fourier_template = self._legacy_fourier_settings
                if legacy_fourier_template is None:
                    legacy_fourier_template = (
                        legacy_born_average_default_fourier_settings(
                            r_max=float(profile.radial_centers[-1]),
                            q_min=float(q_values[0]),
                            q_max=float(q_values[-1]),
                            q_step=legacy_q_step,
                        )
                    )
                # The template supplies the r-range and windowing; the q
                # range always follows the shared grid.
                legacy_result = compute_electron_density_scattering_profile(
                    profile,
                    ElectronDensityFourierTransformSettings(
                        r_min=float(legacy_fourier_template.r_min),
                        r_max=float(legacy_fourier_template.r_max),
                        domain_mode=str(legacy_fourier_template.domain_mode),
                        window_function=str(
                            legacy_fourier_template.window_function
                        ),
                        resampling_points=int(
                            legacy_fourier_template.resampling_points
                        ),
                        q_min=float(q_values[0]),
                        q_max=float(q_values[-1]),
                        q_step=legacy_q_step,
                        use_solvent_subtracted_profile=bool(
                            self._active_contrast_settings is not None
                            or legacy_fourier_template.use_solvent_subtracted_profile
                        ),
                        log_q_axis=bool(legacy_fourier_template.log_q_axis),
                        log_intensity_axis=bool(
                            legacy_fourier_template.log_intensity_axis
                        ),
                    ).normalized(),
                )
                legacy_q_values = np.asarray(
                    legacy_result.q_values,
                    dtype=float,
                )
                legacy_intensities.append(
                    np.asarray(legacy_result.intensity, dtype=float)
                )
                legacy_elapsed_seconds += float(perf_counter() - legacy_start)
            if self._compare_exact_debye:
                self._raise_if_cancelled()
                debye_start = perf_counter()
                exact_debye_intensities.append(
                    np.asarray(
                        compute_constant_weight_debye_intensity(
                            coordinates,
                            weights,
                            q_values,
                        ),
                        dtype=float,
                    )
                )
                debye_elapsed_seconds += float(perf_counter() - debye_start)

        aggregated_fft_result = self._aggregate_fft_results(
            q_values,
            fft_results,
        )
        return _FFTProfileComputationResult(
            target=target,
            q_values=np.asarray(q_values, dtype=float),
            fft_result=aggregated_fft_result,
            legacy_q_values=(
                None
                if legacy_q_values is None
                else np.asarray(legacy_q_values, dtype=float)
            ),
            legacy_intensity=(
                None
                if not legacy_intensities
                else np.nanmean(np.vstack(legacy_intensities), axis=0)
            ),
            exact_debye_intensity=(
                None
                if not exact_debye_intensities
                else np.nanmean(np.vstack(exact_debye_intensities), axis=0)
            ),
            legacy_elapsed_seconds=(
                None
                if not legacy_intensities
                else legacy_elapsed_seconds / float(len(legacy_intensities))
            ),
            debye_elapsed_seconds=(
                None
                if not exact_debye_intensities
                else debye_elapsed_seconds
                / float(len(exact_debye_intensities))
            ),
        )

    def _aggregate_fft_results(
        self,
        q_values: np.ndarray,
        fft_results: list[ContrastFFTResult],
    ) -> ContrastFFTResult:
        """Average per-file FFT results into a single result.

        A single result is returned as-is. Otherwise intensities are
        NaN-aware means, scalar and triplet diagnostics are plain means,
        shell counts are rounded means, the kernel-correction flags are
        true only if true for every file, and the remaining fields
        (settings, grid shape, conventions, contrast mode) are taken
        from the first result.

        Raises:
            ValueError: If *fft_results* is empty.
        """
        if not fft_results:
            raise ValueError("No FFT results were available to aggregate.")
        if len(fft_results) == 1:
            return fft_results[0]
        first = fft_results[0]

        def _mean_array(name: str) -> np.ndarray:
            stacked = np.stack(
                [
                    np.asarray(getattr(result, name), dtype=float)
                    for result in fft_results
                ],
                axis=0,
            )
            return np.nanmean(stacked, axis=0)

        def _mean_scalar(name: str) -> float:
            return float(
                np.mean(
                    [float(getattr(result, name)) for result in fft_results],
                    dtype=float,
                )
            )

        def _mean_triplet(name: str) -> tuple[float, float, float]:
            return tuple(
                float(
                    np.mean(
                        [
                            float(getattr(result, name)[axis])
                            for result in fft_results
                        ],
                        dtype=float,
                    )
                )
                for axis in range(3)
            )

        def _mean_timing(name: str) -> float:
            return float(
                np.mean(
                    [getattr(result.timing, name) for result in fft_results],
                    dtype=float,
                )
            )

        mean_shell_counts = np.rint(
            np.mean(
                [
                    np.asarray(result.q_shell_counts, dtype=float)
                    for result in fft_results
                ],
                axis=0,
            )
        ).astype(int)
        nonempty = np.flatnonzero(mean_shell_counts > 0)
        first_nonempty_q = (
            None
            if nonempty.size == 0
            else float(np.asarray(q_values, dtype=float)[int(nonempty[0])])
        )
        return ContrastFFTResult(
            settings=first.settings,
            q_values=np.asarray(q_values, dtype=float),
            raw_intensity=_mean_array("raw_intensity"),
            kernel_corrected_intensity=_mean_array(
                "kernel_corrected_intensity"
            ),
            q_shell_counts=np.asarray(mean_shell_counts, dtype=int),
            density_integral=_mean_scalar("density_integral"),
            expected_weight=_mean_scalar("expected_weight"),
            contrast_density_integral=_mean_scalar(
                "contrast_density_integral"
            ),
            expected_contrast_weight=_mean_scalar("expected_contrast_weight"),
            solvent_exclusion_volume_a3=_mean_scalar(
                "solvent_exclusion_volume_a3"
            ),
            grid_shape=first.grid_shape,
            box_lengths_a=_mean_triplet("box_lengths_a"),
            voxel_spacing_a=_mean_triplet("voxel_spacing_a"),
            q_nyquist_a_inverse=_mean_scalar("q_nyquist_a_inverse"),
            q_frequency_step_a_inverse=_mean_triplet(
                "q_frequency_step_a_inverse"
            ),
            q_convention=first.q_convention,
            uses_two_pi_frequency_conversion=bool(
                first.uses_two_pi_frequency_conversion
            ),
            density_subtraction_active=bool(first.density_subtraction_active),
            first_nonempty_q_a_inverse=first_nonempty_q,
            solvent_density_e_per_a3=_mean_scalar("solvent_density_e_per_a3"),
            contrast_mode=first.contrast_mode,
            kernel_correction_supported=all(
                bool(result.kernel_correction_supported)
                for result in fft_results
            ),
            kernel_correction_applied=all(
                bool(result.kernel_correction_applied)
                for result in fft_results
            ),
            kernel_correction_model=first.kernel_correction_model,
            timing=type(first.timing)(
                atomic_density_seconds=_mean_timing("atomic_density_seconds"),
                contrast_density_seconds=_mean_timing(
                    "contrast_density_seconds"
                ),
                fft_seconds=_mean_timing("fft_seconds"),
                shell_average_seconds=_mean_timing("shell_average_seconds"),
                total_seconds=_mean_timing("total_seconds"),
            ),
        )
initial_distribution_root_dir: Path | None = None, + initial_use_predicted_structure_weights: bool = False, + initial_use_representative_structures: bool = False, + preview_mode: bool = True, + ) -> None: + super().__init__() + self._preview_mode = bool(preview_mode) + self._project_dir = initial_project_dir + self._output_dir = initial_output_dir + self._project_q_min = initial_project_q_min + self._project_q_max = initial_project_q_max + self._distribution_id = ( + None + if initial_distribution_id is None + else str(initial_distribution_id) + ) + self._distribution_root_dir = initial_distribution_root_dir + self._use_predicted_structure_weights = bool( + initial_use_predicted_structure_weights + ) + self._prefer_representative_structures = bool( + initial_use_representative_structures + ) + self._auto_snap_panes_enabled = self._load_auto_snap_setting() + self._deferred_initial_input_path = initial_input_path + self._loaded_input_path: Path | None = None + self._loaded_reference_file: Path | None = None + self._loaded_structure: ElectronDensityStructure | None = None + self._loaded_structure_count = 0 + self._available_profile_targets: dict[ + tuple[str, str], + tuple[_FFTProfileTarget, ...], + ] = {} + self._current_profile_targets: tuple[_FFTProfileTarget, ...] 
    def _build_ui(self) -> None:
        """Build the two-pane central widget.

        The left pane is a scrollable column of input, settings and
        action groups. The right pane is a scrollable column holding the
        structure viewer, the plots and the run summary. Both panes sit
        in a horizontal splitter.
        """
        central = QWidget(self)
        root_layout = QVBoxLayout(central)
        root_layout.setContentsMargins(8, 8, 8, 8)
        root_layout.setSpacing(8)
        self._pane_splitter = QSplitter(Qt.Orientation.Horizontal, central)
        root_layout.addWidget(self._pane_splitter, stretch=1)
        self.setCentralWidget(central)

        # Left pane: controls.
        self._left_scroll_area = QScrollArea(self)
        self._left_scroll_area.setWidgetResizable(True)
        self._left_scroll_area.setHorizontalScrollBarPolicy(
            Qt.ScrollBarPolicy.ScrollBarAsNeeded
        )
        self._left_scroll_area.setVerticalScrollBarPolicy(
            Qt.ScrollBarPolicy.ScrollBarAsNeeded
        )
        left_container = QWidget(self)
        left_layout = QVBoxLayout(left_container)
        left_layout.setContentsMargins(8, 8, 8, 8)
        left_layout.setSpacing(10)
        self.preview_mode_banner = QLabel()
        self.preview_mode_banner.setWordWrap(True)
        self.preview_mode_banner.setStyleSheet(
            "QLabel { background: #f8fafc; border: 1px solid #cbd5e1; "
            "border-radius: 8px; padding: 10px; }"
        )
        left_layout.addWidget(self.preview_mode_banner)
        left_layout.addWidget(self._build_input_group())
        # Collapsible sections start collapsed; all except the 1D
        # comparison settings are expanded right after creation.
        self.fft_settings_section = _CollapsibleSection(
            "3D FFT Settings",
            self._build_fft_settings_group(),
            self,
        )
        self.fft_settings_section.expand()
        left_layout.addWidget(self.fft_settings_section)
        self.legacy_1d_settings_section = _CollapsibleSection(
            "1D FFT Comparison Settings",
            self._build_legacy_1d_settings_group(),
            self,
        )
        left_layout.addWidget(self.legacy_1d_settings_section)
        self.contrast_section = _CollapsibleSection(
            "Electron Density Contrast",
            self._build_contrast_group(),
            self,
        )
        self.contrast_section.expand()
        left_layout.addWidget(self.contrast_section)
        self.comparison_section = _CollapsibleSection(
            "Comparison Overlays",
            self._build_comparison_group(),
            self,
        )
        self.comparison_section.expand()
        left_layout.addWidget(self.comparison_section)
        left_layout.addWidget(self._build_plot_options_group())
        left_layout.addWidget(self._build_actions_group())
        left_layout.addWidget(self._build_log_group(), stretch=1)
        left_layout.addStretch(1)
        self._left_scroll_area.setWidget(left_container)
        self._pane_splitter.addWidget(self._left_scroll_area)

        # Right pane: viewer, plots and summary.
        self._right_scroll_area = QScrollArea(self)
        self._right_scroll_area.setWidgetResizable(True)
        self._right_scroll_area.setHorizontalScrollBarPolicy(
            Qt.ScrollBarPolicy.ScrollBarAsNeeded
        )
        self._right_scroll_area.setVerticalScrollBarPolicy(
            Qt.ScrollBarPolicy.ScrollBarAsNeeded
        )
        right_container = QWidget(self)
        right_layout = QVBoxLayout(right_container)
        right_layout.setContentsMargins(8, 8, 8, 8)
        right_layout.setSpacing(10)
        self.structure_viewer = ElectronDensityStructureViewer(self)
        structure_group = QGroupBox("Structure Viewer")
        structure_layout = QVBoxLayout(structure_group)
        structure_layout.addWidget(self.structure_viewer)
        right_layout.addWidget(structure_group)
        self.curve_plot = _FFTComparisonPlot(self)
        curve_group = QGroupBox("q-Space Curves")
        curve_layout = QVBoxLayout(curve_group)
        # Button row above the curve plot: legend toggle and CSV export.
        curve_controls = QHBoxLayout()
        curve_controls.setContentsMargins(0, 0, 0, 0)
        curve_controls.setSpacing(6)
        self.toggle_curve_legend_button = QPushButton("Hide Legend")
        self.toggle_curve_legend_button.clicked.connect(
            self._toggle_curve_legend
        )
        curve_controls.addWidget(self.toggle_curve_legend_button)
        self.export_curve_csv_button = QPushButton("Export Plot CSV")
        self.export_curve_csv_button.clicked.connect(
            self._export_q_space_curves_csv
        )
        curve_controls.addWidget(self.export_curve_csv_button)
        curve_controls.addStretch(1)
        curve_layout.addLayout(curve_controls)
        curve_layout.addWidget(self.curve_plot)
        right_layout.addWidget(curve_group)
        self.fft_box_visualizer = _FFTRealSpaceVisualizer(self)
        fft_visualizer_group = QGroupBox("FFT Real-Space Visualizer")
        fft_visualizer_layout = QVBoxLayout(fft_visualizer_group)
        fft_visualizer_layout.addWidget(self.fft_box_visualizer)
        right_layout.addWidget(fft_visualizer_group)
        self.shell_count_plot = _FFTShellCountPlot(self)
        shell_group = QGroupBox("FFT Shell Diagnostics")
        shell_layout = QVBoxLayout(shell_group)
        shell_layout.addWidget(self.shell_count_plot)
        right_layout.addWidget(shell_group)
        self.result_summary_box = QPlainTextEdit(self)
        self.result_summary_box.setReadOnly(True)
        self.result_summary_box.setMinimumHeight(180)
        summary_group = QGroupBox("Run Summary")
        summary_layout = QVBoxLayout(summary_group)
        summary_layout.addWidget(self.result_summary_box)
        right_layout.addWidget(summary_group)
        right_layout.addStretch(1)
        self._right_scroll_area.setWidget(right_container)
        self._pane_splitter.addWidget(self._right_scroll_area)
        # Stretch factor 0 for the control pane, 1 for the results pane.
        self._pane_splitter.setStretchFactor(0, 0)
        self._pane_splitter.setStretchFactor(1, 1)
        self._pane_splitter.setSizes([430, 1090])
        self._auto_snap_filter = PaneSnapFilter(
            self._pane_splitter,
            self._left_scroll_area,
            self._right_scroll_area,
            self,
        )
        # Apply the loaded preference; persist=False is passed because
        # the value came from the stored setting.
        self._set_auto_snap_panes_enabled(
            self._auto_snap_panes_enabled,
            persist=False,
        )
self._right_scroll_area.setVerticalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAsNeeded + ) + right_container = QWidget(self) + right_layout = QVBoxLayout(right_container) + right_layout.setContentsMargins(8, 8, 8, 8) + right_layout.setSpacing(10) + self.structure_viewer = ElectronDensityStructureViewer(self) + structure_group = QGroupBox("Structure Viewer") + structure_layout = QVBoxLayout(structure_group) + structure_layout.addWidget(self.structure_viewer) + right_layout.addWidget(structure_group) + self.curve_plot = _FFTComparisonPlot(self) + curve_group = QGroupBox("q-Space Curves") + curve_layout = QVBoxLayout(curve_group) + curve_controls = QHBoxLayout() + curve_controls.setContentsMargins(0, 0, 0, 0) + curve_controls.setSpacing(6) + self.toggle_curve_legend_button = QPushButton("Hide Legend") + self.toggle_curve_legend_button.clicked.connect( + self._toggle_curve_legend + ) + curve_controls.addWidget(self.toggle_curve_legend_button) + self.export_curve_csv_button = QPushButton("Export Plot CSV") + self.export_curve_csv_button.clicked.connect( + self._export_q_space_curves_csv + ) + curve_controls.addWidget(self.export_curve_csv_button) + curve_controls.addStretch(1) + curve_layout.addLayout(curve_controls) + curve_layout.addWidget(self.curve_plot) + right_layout.addWidget(curve_group) + self.fft_box_visualizer = _FFTRealSpaceVisualizer(self) + fft_visualizer_group = QGroupBox("FFT Real-Space Visualizer") + fft_visualizer_layout = QVBoxLayout(fft_visualizer_group) + fft_visualizer_layout.addWidget(self.fft_box_visualizer) + right_layout.addWidget(fft_visualizer_group) + self.shell_count_plot = _FFTShellCountPlot(self) + shell_group = QGroupBox("FFT Shell Diagnostics") + shell_layout = QVBoxLayout(shell_group) + shell_layout.addWidget(self.shell_count_plot) + right_layout.addWidget(shell_group) + self.result_summary_box = QPlainTextEdit(self) + self.result_summary_box.setReadOnly(True) + self.result_summary_box.setMinimumHeight(180) + summary_group = 
    def _build_menu_bar(self) -> None:
        """Create the "Settings" menu with the checkable auto-snap action."""
        settings_menu = self.menuBar().addMenu("Settings")
        self.auto_snap_panes_action = QAction("Auto-Snap Panes", self)
        self.auto_snap_panes_action.setCheckable(True)
        # The initial check state is set before the handler is connected.
        self.auto_snap_panes_action.setChecked(
            bool(self._auto_snap_panes_enabled)
        )
        self.auto_snap_panes_action.triggered.connect(
            self._toggle_auto_snap_panes
        )
        settings_menu.addAction(self.auto_snap_panes_action)

    def _build_input_group(self) -> QWidget:
        """Build the "Input" group box.

        Contains the path editor with file/folder browse buttons, the
        "Load Input" button, the structure-source / representative-solvent /
        active-profile combo boxes, and two word-wrapped hint labels.

        Returns:
            The populated group box.
        """
        group = QGroupBox("Input")
        layout = QVBoxLayout(group)
        path_row = QHBoxLayout()
        self.input_path_edit = QLineEdit(self)
        path_row.addWidget(self.input_path_edit, stretch=1)
        browse_file_button = QPushButton("Open File...")
        browse_file_button.clicked.connect(self._browse_input_file)
        path_row.addWidget(browse_file_button)
        browse_folder_button = QPushButton("Open Folder...")
        browse_folder_button.clicked.connect(self._browse_input_folder)
        path_row.addWidget(browse_folder_button)
        layout.addLayout(path_row)
        load_row = QHBoxLayout()
        self.load_input_button = QPushButton("Load Input")
        self.load_input_button.clicked.connect(self._load_input_from_edit)
        load_row.addWidget(self.load_input_button)
        load_row.addStretch(1)
        layout.addLayout(load_row)
        source_grid = QGridLayout()
        source_grid.setHorizontalSpacing(8)
        source_grid.setVerticalSpacing(6)
        self.structure_source_combo = QComboBox(self)
        # Item data ("average" / "representative") is the mode token stored
        # alongside each display label.
        self.structure_source_combo.addItem(
            "Average cluster folders / input structures",
            "average",
        )
        self.structure_source_combo.addItem(
            "Representative structures",
            "representative",
        )
        self.structure_source_combo.currentIndexChanged.connect(
            self._on_structure_source_mode_changed
        )
        source_grid.addWidget(QLabel("Structure source"), 0, 0)
        source_grid.addWidget(self.structure_source_combo, 0, 1)
        # This combo is created empty here; no items are added in this method.
        self.representative_solvent_mode_combo = QComboBox(self)
        self.representative_solvent_mode_combo.currentIndexChanged.connect(
            self._refresh_available_profile_targets
        )
        source_grid.addWidget(QLabel("Representative solvent"), 1, 0)
        source_grid.addWidget(self.representative_solvent_mode_combo, 1, 1)
        self.active_profile_combo = QComboBox(self)
        self.active_profile_combo.currentIndexChanged.connect(
            self._on_active_profile_changed
        )
        source_grid.addWidget(QLabel("Active profile"), 2, 0)
        source_grid.addWidget(self.active_profile_combo, 2, 1)
        layout.addLayout(source_grid)
        self.structure_source_hint = QLabel(
            "Average structure mode is active until representative targets or cluster bins are detected."
        )
        self.structure_source_hint.setWordWrap(True)
        layout.addWidget(self.structure_source_hint)
        self.loaded_input_summary = QLabel(
            "Load a structure file or folder to inspect the 3D FFT Born setup."
        )
        self.loaded_input_summary.setWordWrap(True)
        layout.addWidget(self.loaded_input_summary)
        return group
    def _build_fft_settings_group(self) -> QWidget:
        """Build the panel with the shared q-grid and voxel-grid spin boxes.

        The q-min / q-max spin boxes start from ``self._project_q_min`` /
        ``self._project_q_max`` when those are not ``None`` and otherwise from
        0.01 / 1.20; spacing, sigma, minimum box length and padding start from
        ``default_contrast_fft_settings()``.

        Returns:
            The populated panel widget.
        """
        panel = QWidget(self)
        layout = QVBoxLayout(panel)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setSpacing(6)
        default_settings = default_contrast_fft_settings()
        intro = QLabel(
            "These controls define the shared q grid and the Cartesian voxel grid "
            "used by the 3D FFT Born approximation. The q-range inherits from "
            "Project Setup when the window is launched from the main UI."
        )
        intro.setWordWrap(True)
        layout.addWidget(intro)
        grid = QGridLayout()
        grid.setHorizontalSpacing(8)
        grid.setVerticalSpacing(6)
        # Columns 1 and 3 hold the editors and take the horizontal stretch.
        grid.setColumnStretch(1, 1)
        grid.setColumnStretch(3, 1)
        # Tooltip text keyed by setting; applied to both label and editor.
        tooltips = {
            "q_min": (
                "Lower bound of the shared q grid in Å^-1. When the 3D FFT "
                "window is launched from the main UI, this inherits the "
                "project q-range when that value is available."
            ),
            "q_max": (
                "Upper bound of the shared q grid in Å^-1. Keep this below the "
                "useful FFT Nyquist limit for the chosen voxel spacing."
            ),
            "q_step": (
                "Spacing between neighboring q samples in the comparison plot. "
                "This shared q grid is reused for the 3D FFT result and the "
                "optional 1D Born and Debye overlays."
            ),
            "spacing": (
                "Real-space voxel spacing for the Cartesian 3D density grid. "
                "Smaller spacing improves real-space detail and raises the FFT "
                "Nyquist limit, but increases memory and compute cost."
            ),
            "sigma": (
                "Gaussian deposition width used when mapping atomic electron "
                "density onto the voxel grid. Larger sigma smooths the real-space "
                "map and suppresses high-q scattering."
            ),
            "minimum_box_length": (
                "Minimum side length of the FFT box in Å before odd-grid "
                "rounding. Larger boxes improve low-q sampling because the FFT "
                "frequency spacing scales roughly as 2π/L."
            ),
            "padding": (
                "Extra vacuum padding added around the structure before the FFT "
                "box is built. Padding helps reduce boundary coupling and "
                "periodic-image contamination."
            ),
        }

        def _add_setting(
            row: int,
            label_text: str,
            widget: QWidget,
            tooltip: str,
            column: int = 0,
        ) -> None:
            # Place a label at (row, column) and its editor at (row, column + 1),
            # giving both the same tooltip.
            label = QLabel(label_text)
            label.setToolTip(tooltip)
            widget.setToolTip(tooltip)
            grid.addWidget(label, row, column)
            grid.addWidget(widget, row, column + 1)

        self.q_min_spin = QDoubleSpinBox(self)
        self.q_min_spin.setRange(0.0, 100.0)
        self.q_min_spin.setDecimals(4)
        self.q_min_spin.setValue(
            0.01 if self._project_q_min is None else float(self._project_q_min)
        )
        self.q_max_spin = QDoubleSpinBox(self)
        self.q_max_spin.setRange(0.0, 100.0)
        self.q_max_spin.setDecimals(4)
        self.q_max_spin.setValue(
            1.20 if self._project_q_max is None else float(self._project_q_max)
        )
        self.q_step_spin = QDoubleSpinBox(self)
        self.q_step_spin.setRange(1.0e-4, 10.0)
        self.q_step_spin.setDecimals(4)
        self.q_step_spin.setValue(0.01)
        self.spacing_spin = QDoubleSpinBox(self)
        self.spacing_spin.setRange(0.1, 25.0)
        self.spacing_spin.setDecimals(3)
        self.spacing_spin.setValue(float(default_settings.spacing_a))
        self.sigma_spin = QDoubleSpinBox(self)
        self.sigma_spin.setRange(0.0, 25.0)
        self.sigma_spin.setDecimals(3)
        self.sigma_spin.setValue(float(default_settings.gaussian_sigma_a))
        self.min_box_length_spin = QDoubleSpinBox(self)
        self.min_box_length_spin.setRange(1.0, 5000.0)
        self.min_box_length_spin.setDecimals(3)
        self.min_box_length_spin.setValue(
            float(default_settings.minimum_box_length_a)
        )
        self.padding_spin = QDoubleSpinBox(self)
        self.padding_spin.setRange(0.0, 500.0)
        self.padding_spin.setDecimals(3)
        self.padding_spin.setValue(float(default_settings.padding_a))
        # Two label/editor pairs per row: column 0 (default) and column 2.
        _add_setting(0, "q min (Å⁻¹)", self.q_min_spin, tooltips["q_min"])
        _add_setting(0, "q max (Å⁻¹)", self.q_max_spin, tooltips["q_max"], 2)
        _add_setting(1, "q step (Å⁻¹)", self.q_step_spin, tooltips["q_step"])
        _add_setting(
            1,
            "Voxel spacing (Å)",
            self.spacing_spin,
            tooltips["spacing"],
            2,
        )
        _add_setting(
            2,
            "Gaussian sigma (Å)",
            self.sigma_spin,
            tooltips["sigma"],
        )
        _add_setting(
            2,
            "Minimum box length (Å)",
            self.min_box_length_spin,
            tooltips["minimum_box_length"],
            2,
        )
        _add_setting(
            3,
            "Extra padding (Å)",
            self.padding_spin,
            tooltips["padding"],
        )
        layout.addLayout(grid)
        return panel
    def _build_legacy_1d_settings_group(self) -> QWidget:
        """Build the panel configuring the legacy 1D Born comparison overlay.

        Creates the mesh editors (rstep, theta/phi divisions, rmax), the
        transform editors (Debye-Waller factor, domain, window, resampling
        points) and two explanatory notes, all initialised from the
        ``legacy_born_average_default_*`` helpers.

        Returns:
            The populated panel widget.
        """
        panel = QWidget(self)
        layout = QVBoxLayout(panel)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setSpacing(6)
        mesh_defaults = legacy_born_average_default_mesh_settings()
        smearing_defaults = legacy_born_average_default_smearing_settings()
        # Same 0.01 / 1.20 fallbacks as the q-range spin boxes in
        # _build_fft_settings_group.
        fourier_defaults = legacy_born_average_default_fourier_settings(
            r_max=float(mesh_defaults.rmax),
            q_min=float(
                0.01 if self._project_q_min is None else self._project_q_min
            ),
            q_max=float(
                1.20 if self._project_q_max is None else self._project_q_max
            ),
            q_step=0.01,
        )

        intro = QLabel(
            "These controls configure the legacy 1D Born overlay that can be "
            "plotted alongside the 3D FFT result. The q grid stays shared with "
            "the 3D FFT calculation, and any active solvent contrast is reused "
            "for the 1D q-space curve."
        )
        intro.setWordWrap(True)
        layout.addWidget(intro)

        mesh_grid = QGridLayout()
        mesh_grid.setHorizontalSpacing(8)
        mesh_grid.setVerticalSpacing(6)
        mesh_grid.setColumnStretch(1, 1)
        mesh_grid.setColumnStretch(3, 1)
        self.legacy_1d_rstep_spin = QDoubleSpinBox(self)
        self.legacy_1d_rstep_spin.setRange(0.01, 5.0)
        self.legacy_1d_rstep_spin.setDecimals(3)
        self.legacy_1d_rstep_spin.setValue(float(mesh_defaults.rstep))
        self.legacy_1d_theta_spin = QSpinBox(self)
        self.legacy_1d_theta_spin.setRange(4, 720)
        self.legacy_1d_theta_spin.setValue(int(mesh_defaults.theta_divisions))
        self.legacy_1d_phi_spin = QSpinBox(self)
        self.legacy_1d_phi_spin.setRange(4, 360)
        self.legacy_1d_phi_spin.setValue(int(mesh_defaults.phi_divisions))
        self.legacy_1d_rmax_spin = QDoubleSpinBox(self)
        self.legacy_1d_rmax_spin.setRange(0.01, 5000.0)
        self.legacy_1d_rmax_spin.setDecimals(3)
        self.legacy_1d_rmax_spin.setValue(float(mesh_defaults.rmax))
        mesh_grid.addWidget(QLabel("rstep (Å)"), 0, 0)
        mesh_grid.addWidget(self.legacy_1d_rstep_spin, 0, 1)
        mesh_grid.addWidget(QLabel("Theta divisions"), 0, 2)
        mesh_grid.addWidget(self.legacy_1d_theta_spin, 0, 3)
        mesh_grid.addWidget(QLabel("Phi divisions"), 1, 0)
        mesh_grid.addWidget(self.legacy_1d_phi_spin, 1, 1)
        mesh_grid.addWidget(QLabel("rmax (Å)"), 1, 2)
        mesh_grid.addWidget(self.legacy_1d_rmax_spin, 1, 3)
        layout.addLayout(mesh_grid)

        transform_grid = QGridLayout()
        transform_grid.setHorizontalSpacing(8)
        transform_grid.setVerticalSpacing(6)
        transform_grid.setColumnStretch(1, 1)
        transform_grid.setColumnStretch(3, 1)
        self.legacy_1d_smearing_factor_spin = QDoubleSpinBox(self)
        self.legacy_1d_smearing_factor_spin.setRange(0.0, 500.0)
        self.legacy_1d_smearing_factor_spin.setDecimals(6)
        self.legacy_1d_smearing_factor_spin.setValue(
            float(smearing_defaults.debye_waller_factor)
        )
        self.legacy_1d_domain_combo = QComboBox(self)
        self.legacy_1d_domain_combo.addItem("One-sided (legacy)", "legacy")
        self.legacy_1d_domain_combo.addItem("Mirrored", "mirrored")
        # Any default other than "legacy" selects the "Mirrored" entry.
        self.legacy_1d_domain_combo.setCurrentIndex(
            0 if str(fourier_defaults.domain_mode) == "legacy" else 1
        )
        self.legacy_1d_window_combo = QComboBox(self)
        for label, value in (
            ("None", "none"),
            ("Lorch", "lorch"),
            ("Cosine", "cosine"),
            ("Hanning", "hanning"),
            ("Parzen", "parzen"),
            ("Welch", "welch"),
            ("Gaussian", "gaussian"),
            ("Sine", "sine"),
            ("Kaiser-Bessel", "kaiser_bessel"),
        ):
            self.legacy_1d_window_combo.addItem(label, value)
            # Select the entry just added when it matches the default window.
            if value == str(fourier_defaults.window_function):
                self.legacy_1d_window_combo.setCurrentIndex(
                    self.legacy_1d_window_combo.count() - 1
                )
        self.legacy_1d_resampling_points_spin = QSpinBox(self)
        self.legacy_1d_resampling_points_spin.setRange(8, 200000)
        self.legacy_1d_resampling_points_spin.setValue(
            int(fourier_defaults.resampling_points)
        )
        self.legacy_1d_shared_q_note = QLabel(
            "The 1D overlay always uses the shared 3D FFT q range and q step for direct plot comparison."
        )
        self.legacy_1d_shared_q_note.setWordWrap(True)
        self.legacy_1d_contrast_note = QLabel(
            "No active solvent contrast is currently being reused by the 1D comparison curve."
        )
        self.legacy_1d_contrast_note.setWordWrap(True)
        transform_grid.addWidget(QLabel("Debye-Waller factor (Ų)"), 0, 0)
        transform_grid.addWidget(self.legacy_1d_smearing_factor_spin, 0, 1)
        transform_grid.addWidget(QLabel("Transform domain"), 0, 2)
        transform_grid.addWidget(self.legacy_1d_domain_combo, 0, 3)
        transform_grid.addWidget(QLabel("Window"), 1, 0)
        transform_grid.addWidget(self.legacy_1d_window_combo, 1, 1)
        transform_grid.addWidget(QLabel("Resample pts"), 1, 2)
        transform_grid.addWidget(self.legacy_1d_resampling_points_spin, 1, 3)
        layout.addLayout(transform_grid)
        layout.addWidget(self.legacy_1d_shared_q_note)
        layout.addWidget(self.legacy_1d_contrast_note)
        return panel
    def _build_contrast_group(self) -> QWidget:
        """Build the solvent electron-density contrast panel.

        Rows of the grid, top to bottom: intro text, compute-method combo,
        saved-solvent row (combo + save/delete buttons), formula, mass density,
        direct electron density, reference solvent file, exclusion radius
        scale and padding, method hint, apply button, active-contrast label
        and notes label.

        Returns:
            The populated panel widget.
        """
        panel = QWidget(self)
        layout = QGridLayout(panel)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setHorizontalSpacing(8)
        layout.setVerticalSpacing(6)
        layout.setColumnStretch(1, 1)

        intro = QLabel(
            "Set up an optional flat solvent electron-density subtraction for "
            "the 3D FFT Born approximation. These settings stay separate from "
            "the legacy 1D Born workflow and must be applied before the FFT run."
        )
        intro.setWordWrap(True)
        layout.addWidget(intro, 0, 0, 1, 2)

        # The method constant is stored as item data and read back through
        # currentData() elsewhere in the class.
        self.solvent_method_combo = QComboBox(self)
        self.solvent_method_combo.addItem(
            "Estimate from Solvent Formula and Density",
            userData=CONTRAST_SOLVENT_METHOD_NEAT,
        )
        self.solvent_method_combo.addItem(
            "Reference Solvent Structure (XYZ/PDB)",
            userData=CONTRAST_SOLVENT_METHOD_REFERENCE,
        )
        self.solvent_method_combo.addItem(
            "Direct Electron Density Value",
            userData=CONTRAST_SOLVENT_METHOD_DIRECT,
        )
        self.solvent_method_combo.currentIndexChanged.connect(
            self._sync_density_method_controls
        )
        layout.addWidget(QLabel("Compute option"), 1, 0)
        layout.addWidget(self.solvent_method_combo, 1, 1)

        # Created empty here; this method adds no items to the preset combo.
        self.solvent_preset_combo = QComboBox(self)
        self.solvent_preset_combo.currentIndexChanged.connect(
            self._load_selected_solvent_preset
        )
        self.save_custom_solvent_button = QPushButton("Save Custom")
        self.save_custom_solvent_button.clicked.connect(
            self._save_current_solvent_preset
        )
        self.delete_custom_solvent_button = QPushButton("Delete Custom")
        self.delete_custom_solvent_button.clicked.connect(
            self._delete_current_solvent_preset
        )
        solvent_row = QWidget(self)
        solvent_row_layout = QHBoxLayout(solvent_row)
        solvent_row_layout.setContentsMargins(0, 0, 0, 0)
        solvent_row_layout.setSpacing(6)
        solvent_row_layout.addWidget(self.solvent_preset_combo, stretch=1)
        solvent_row_layout.addWidget(self.save_custom_solvent_button)
        solvent_row_layout.addWidget(self.delete_custom_solvent_button)
        layout.addWidget(QLabel("Saved solvents"), 2, 0)
        layout.addWidget(solvent_row, 2, 1)

        self.solvent_formula_edit = QLineEdit(self)
        self.solvent_formula_edit.setPlaceholderText(
            "Examples: H2O, Vacuum, C3H7NO (DMF), C2H6OS (DMSO)"
        )
        layout.addWidget(QLabel("Solvent formula"), 3, 0)
        layout.addWidget(self.solvent_formula_edit, 3, 1)

        self.solvent_density_spin = QDoubleSpinBox(self)
        self.solvent_density_spin.setDecimals(6)
        self.solvent_density_spin.setRange(0.0, 100.0)
        self.solvent_density_spin.setSingleStep(0.01)
        self.solvent_density_spin.setValue(1.0)
        self.solvent_density_spin.setKeyboardTracking(False)
        layout.addWidget(QLabel("Density (g/mL)"), 4, 0)
        layout.addWidget(self.solvent_density_spin, 4, 1)

        self.direct_density_spin = QDoubleSpinBox(self)
        self.direct_density_spin.setDecimals(9)
        self.direct_density_spin.setRange(0.0, 100.0)
        self.direct_density_spin.setSingleStep(0.001)
        self.direct_density_spin.setValue(0.334)
        self.direct_density_spin.setKeyboardTracking(False)
        layout.addWidget(QLabel("Direct density (e-/ų)"), 5, 0)
        layout.addWidget(self.direct_density_spin, 5, 1)

        self.reference_solvent_file_edit = QLineEdit(self)
        self.reference_solvent_file_edit.setPlaceholderText(
            "Choose a reference solvent XYZ or PDB file"
        )
        self.reference_solvent_browse_button = QPushButton("Browse…")
        self.reference_solvent_browse_button.clicked.connect(
            self._choose_reference_solvent_file
        )
        reference_row = QWidget(self)
        reference_layout = QHBoxLayout(reference_row)
        reference_layout.setContentsMargins(0, 0, 0, 0)
        reference_layout.setSpacing(6)
        reference_layout.addWidget(self.reference_solvent_file_edit, stretch=1)
        reference_layout.addWidget(self.reference_solvent_browse_button)
        layout.addWidget(QLabel("Reference solvent file"), 6, 0)
        layout.addWidget(reference_row, 6, 1)

        contrast_fft_defaults = default_contrast_fft_settings()
        self.exclusion_radius_scale_spin = QDoubleSpinBox(self)
        self.exclusion_radius_scale_spin.setRange(0.1, 10.0)
        self.exclusion_radius_scale_spin.setDecimals(3)
        self.exclusion_radius_scale_spin.setValue(
            float(contrast_fft_defaults.exclusion_radius_scale)
        )
        layout.addWidget(QLabel("Exclusion radius scale"), 7, 0)
        layout.addWidget(self.exclusion_radius_scale_spin, 7, 1)

        self.exclusion_radius_padding_spin = QDoubleSpinBox(self)
        self.exclusion_radius_padding_spin.setRange(0.0, 25.0)
        self.exclusion_radius_padding_spin.setDecimals(3)
        self.exclusion_radius_padding_spin.setValue(
            float(contrast_fft_defaults.exclusion_radius_padding_a)
        )
        layout.addWidget(QLabel("Exclusion radius padding (Å)"), 8, 0)
        layout.addWidget(self.exclusion_radius_padding_spin, 8, 1)

        # Left without text here; _sync_density_method_controls sets it.
        self.solvent_method_hint_label = QLabel(self)
        self.solvent_method_hint_label.setWordWrap(True)
        layout.addWidget(self.solvent_method_hint_label, 9, 0, 1, 2)

        self.apply_contrast_button = QPushButton(
            "Apply Electron Density Contrast"
        )
        self.apply_contrast_button.clicked.connect(
            self._apply_contrast_settings
        )
        layout.addWidget(self.apply_contrast_button, 10, 0, 1, 2)

        self.active_contrast_value = QLabel(
            "No active solvent electron density contrast yet."
        )
        self.active_contrast_value.setWordWrap(True)
        layout.addWidget(QLabel("Active contrast"), 11, 0)
        layout.addWidget(self.active_contrast_value, 11, 1)

        self.contrast_notice_value = QLabel(
            "Apply contrast settings to use a solvent-density subtraction in the next 3D FFT Born calculation."
        )
        self.contrast_notice_value.setWordWrap(True)
        layout.addWidget(QLabel("Notes"), 12, 0)
        layout.addWidget(self.contrast_notice_value, 12, 1)
        return panel
+ ) + self.contrast_notice_value.setWordWrap(True) + layout.addWidget(QLabel("Notes"), 12, 0) + layout.addWidget(self.contrast_notice_value, 12, 1) + return panel + + def _build_comparison_group(self) -> QWidget: + panel = QWidget(self) + layout = QVBoxLayout(panel) + layout.setContentsMargins(0, 0, 0, 0) + self.compare_legacy_checkbox = QCheckBox( + "Overlay 1D Born Approximation (Average)" + ) + self.compare_legacy_checkbox.setChecked(True) + self.compare_legacy_checkbox.toggled.connect( + self._refresh_plot_controls + ) + self.compare_exact_debye_checkbox = QCheckBox( + "Overlay exact Debye scattering" + ) + self.compare_exact_debye_checkbox.setChecked(False) + self.compare_exact_debye_checkbox.toggled.connect( + self._refresh_plot_controls + ) + self.show_kernel_corrected_checkbox = QCheckBox( + "Show kernel-corrected FFT overlay (zero-contrast diagnostic)" + ) + self.show_kernel_corrected_checkbox.setChecked(False) + self.show_kernel_corrected_checkbox.setToolTip( + "Kernel correction removes the Gaussian voxel-deposition response " + "from the zero-contrast FFT intensity so it can be compared against " + "the point-scatterer Debye limit. It is not used for solvent-contrast production curves." 
+ ) + self.show_kernel_corrected_checkbox.toggled.connect( + self._refresh_plot_controls + ) + layout.addWidget(self.compare_legacy_checkbox) + layout.addWidget(self.compare_exact_debye_checkbox) + layout.addWidget(self.show_kernel_corrected_checkbox) + return panel + + def _build_plot_options_group(self) -> QWidget: + group = QGroupBox("Plot Options") + layout = QVBoxLayout(group) + self.log_q_checkbox = QCheckBox("Log q axis") + self.log_q_checkbox.setChecked(True) + self.log_q_checkbox.toggled.connect(self._refresh_plot_controls) + self.log_intensity_checkbox = QCheckBox("Log intensity axis") + self.log_intensity_checkbox.setChecked(True) + self.log_intensity_checkbox.toggled.connect( + self._refresh_plot_controls + ) + layout.addWidget(self.log_q_checkbox) + layout.addWidget(self.log_intensity_checkbox) + return group + + def _build_actions_group(self) -> QWidget: + group = QGroupBox("Actions") + layout = QHBoxLayout(group) + self.compute_button = QPushButton("Compute 3D FFT Born Approximation") + self.compute_button.clicked.connect(self._start_calculation) + layout.addWidget(self.compute_button) + self.push_to_model_button = QPushButton("Push to Model") + self.push_to_model_button.clicked.connect( + self._push_components_to_model + ) + layout.addWidget(self.push_to_model_button) + self.clear_results_button = QPushButton("Clear Results") + self.clear_results_button.clicked.connect(self._clear_results) + layout.addWidget(self.clear_results_button) + return group + + def _build_log_group(self) -> QWidget: + group = QGroupBox("Status Log") + layout = QVBoxLayout(group) + self.status_log_box = QPlainTextEdit(self) + self.status_log_box.setReadOnly(True) + self.status_log_box.setMinimumHeight(180) + layout.addWidget(self.status_log_box) + return group + + def _connect_trace_configuration_controls(self) -> None: + for spinbox in ( + self.q_min_spin, + self.q_max_spin, + self.q_step_spin, + self.spacing_spin, + self.sigma_spin, + self.min_box_length_spin, + 
self.padding_spin, + self.exclusion_radius_scale_spin, + self.exclusion_radius_padding_spin, + ): + spinbox.valueChanged.connect(self._on_trace_configuration_changed) + for widget in ( + self.solvent_method_combo, + self.solvent_preset_combo, + ): + widget.currentIndexChanged.connect( + self._on_contrast_controls_changed + ) + for widget in ( + self.solvent_formula_edit, + self.reference_solvent_file_edit, + ): + widget.textChanged.connect(self._on_contrast_controls_changed) + for spinbox in ( + self.solvent_density_spin, + self.direct_density_spin, + ): + spinbox.valueChanged.connect(self._on_contrast_controls_changed) + + def _on_trace_configuration_changed(self, *_args: object) -> None: + if self._restoring_workspace_state: + return + self._refresh_fft_box_visualizer() + self._update_push_to_model_state() + + def _on_contrast_controls_changed(self, *_args: object) -> None: + if self._restoring_workspace_state: + return + self._contrast_controls_dirty = True + if self._computed_profile_results: + self.contrast_notice_value.setText( + "Electron density contrast controls changed. Apply the new " + "contrast or recompute the 3D FFT traces before pushing these " + "components to the model." 
+ ) + self.contrast_notice_value.setStyleSheet("color: #b45309;") + self._update_push_to_model_state() + + def _legacy_1d_mesh_settings_from_controls( + self, + ) -> ElectronDensityMeshSettings: + return ElectronDensityMeshSettings( + rstep=float(self.legacy_1d_rstep_spin.value()), + theta_divisions=int(self.legacy_1d_theta_spin.value()), + phi_divisions=int(self.legacy_1d_phi_spin.value()), + rmax=float(self.legacy_1d_rmax_spin.value()), + ).normalized() + + def _legacy_1d_smearing_settings_from_controls( + self, + ) -> ElectronDensitySmearingSettings: + return ElectronDensitySmearingSettings( + debye_waller_factor=float( + self.legacy_1d_smearing_factor_spin.value() + ) + ).normalized() + + def _legacy_1d_fourier_settings_from_controls( + self, + ) -> ElectronDensityFourierTransformSettings: + mesh_settings = self._legacy_1d_mesh_settings_from_controls() + return ElectronDensityFourierTransformSettings( + r_min=0.0, + r_max=float(mesh_settings.rmax), + domain_mode=str( + self.legacy_1d_domain_combo.currentData() or "legacy" + ), + window_function=str( + self.legacy_1d_window_combo.currentData() or "none" + ), + resampling_points=int( + self.legacy_1d_resampling_points_spin.value() + ), + q_min=float(self.q_min_spin.value()), + q_max=float(self.q_max_spin.value()), + q_step=float(self.q_step_spin.value()), + use_solvent_subtracted_profile=bool( + self._active_contrast_settings is not None + ), + log_q_axis=bool(self.log_q_checkbox.isChecked()), + log_intensity_axis=bool(self.log_intensity_checkbox.isChecked()), + ).normalized() + + def _sync_legacy_1d_defaults_to_structure( + self, + structure: ElectronDensityStructure, + ) -> None: + mesh_defaults = legacy_born_average_default_mesh_settings(structure) + self.legacy_1d_rstep_spin.setValue(float(mesh_defaults.rstep)) + self.legacy_1d_theta_spin.setValue(int(mesh_defaults.theta_divisions)) + self.legacy_1d_phi_spin.setValue(int(mesh_defaults.phi_divisions)) + 
self.legacy_1d_rmax_spin.setValue(float(mesh_defaults.rmax)) + + def _refresh_legacy_1d_contrast_note(self) -> None: + if self._active_contrast_settings is None: + self.legacy_1d_contrast_note.setText( + "No active solvent contrast is currently being reused by the 1D comparison curve." + ) + self.legacy_1d_contrast_note.setStyleSheet("color: #475569;") + return + if abs(float(self._active_solvent_density_e_per_a3)) <= 1.0e-15: + self.legacy_1d_contrast_note.setText( + "The 1D comparison curve reuses the active zero-density contrast, so both 1D and 3D curves stay in the bare-density limit." + ) + self.legacy_1d_contrast_note.setStyleSheet("color: #166534;") + return + self.legacy_1d_contrast_note.setText( + "The 1D comparison curve reuses the active solvent-density subtraction before its Fourier transform so the q-space comparison stays contrast-matched." + ) + self.legacy_1d_contrast_note.setStyleSheet("color: #166534;") + + @Slot() + def _on_structure_source_mode_changed(self) -> None: + self._prefer_representative_structures = ( + self._current_structure_source_mode() == "representative" + ) + self._refresh_available_profile_targets() + + def _selected_solvent_preset_token(self) -> object: + return self.solvent_preset_combo.currentData() + + def _selected_solvent_preset_name(self) -> str | None: + token = self._selected_solvent_preset_token() + if token in {None, _SOLVENT_PRESET_NONE}: + return None + return str(token).strip() or None + + def _reload_solvent_presets(self, *, selected_name: str | None) -> None: + self._solvent_presets = load_solvent_presets() + previous_name = ( + selected_name + if selected_name is not None + else self._selected_solvent_preset_name() + ) + self.solvent_preset_combo.blockSignals(True) + self.solvent_preset_combo.clear() + self.solvent_preset_combo.addItem("Custom entry", None) + self.solvent_preset_combo.addItem("None", _SOLVENT_PRESET_NONE) + selected_index = 0 + if previous_name == _SOLVENT_PRESET_NONE: + selected_index = 1 + 
for index, preset_name in enumerate( + ordered_solvent_preset_names(self._solvent_presets), + start=2, + ): + preset = self._solvent_presets[preset_name] + label = ( + preset.name if preset.builtin else f"{preset.name} (Custom)" + ) + self.solvent_preset_combo.addItem(label, preset_name) + if previous_name == preset_name: + selected_index = index + self.solvent_preset_combo.setCurrentIndex(selected_index) + self.solvent_preset_combo.blockSignals(False) + self._load_selected_solvent_preset() + + @Slot() + def _load_selected_solvent_preset(self) -> None: + if self._selected_solvent_preset_token() == _SOLVENT_PRESET_NONE: + self.delete_custom_solvent_button.setEnabled(False) + self._sync_density_method_controls() + return + preset_name = self._selected_solvent_preset_name() + preset = self._solvent_presets.get(preset_name or "") + if preset is None: + self.delete_custom_solvent_button.setEnabled(False) + self._sync_density_method_controls() + return + self.solvent_formula_edit.setText(preset.formula) + self.solvent_density_spin.setValue(preset.density_g_per_ml) + self.delete_custom_solvent_button.setEnabled(not preset.builtin) + self._sync_density_method_controls() + + @Slot() + def _save_current_solvent_preset(self) -> None: + suggested_name = self._selected_solvent_preset_name() or "" + preset_name, accepted = QInputDialog.getText( + self, + "Save Custom Solvent", + "Custom solvent name:", + text=suggested_name, + ) + if not accepted: + return + name = str(preset_name).strip() + if not name: + QMessageBox.warning( + self, + "Save Custom Solvent", + "Enter a solvent name before saving.", + ) + return + formula = self.solvent_formula_edit.text().strip() + density = float(self.solvent_density_spin.value()) + try: + preset = ContrastSolventPreset( + name=name, + formula=formula, + density_g_per_ml=density, + builtin=False, + ) + except ValueError as exc: + QMessageBox.warning(self, "Save Custom Solvent", str(exc)) + return + if name in self._solvent_presets: + response 
= QMessageBox.question( + self, + "Overwrite custom solvent?", + f"A solvent named '{name}' already exists. Overwrite it?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if response != QMessageBox.StandardButton.Yes: + return + save_custom_solvent_preset(preset) + self._reload_solvent_presets(selected_name=name) + self._append_status(f"Saved custom solvent preset: {name}.") + + @Slot() + def _delete_current_solvent_preset(self) -> None: + preset_name = self._selected_solvent_preset_name() + if preset_name is None: + QMessageBox.information( + self, + "Delete Custom Solvent", + "Select a saved custom solvent first.", + ) + return + preset = self._solvent_presets.get(preset_name) + if preset is None or preset.builtin: + QMessageBox.information( + self, + "Delete Custom Solvent", + "Only custom solvents can be deleted.", + ) + return + response = QMessageBox.question( + self, + "Delete Custom Solvent", + f"Delete the custom solvent preset '{preset_name}'?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if response != QMessageBox.StandardButton.Yes: + return + delete_custom_solvent_preset(preset_name) + self._reload_solvent_presets(selected_name="Water") + self._append_status(f"Deleted custom solvent preset: {preset_name}.") + + def _clear_solvent_contrast_requested_from_controls(self) -> bool: + method = str(self.solvent_method_combo.currentData() or "").strip() + return ( + method == CONTRAST_SOLVENT_METHOD_NEAT + and self._selected_solvent_preset_token() == _SOLVENT_PRESET_NONE + ) + + @Slot() + def _sync_density_method_controls(self) -> None: + method = str(self.solvent_method_combo.currentData() or "").strip() + using_neat = method == CONTRAST_SOLVENT_METHOD_NEAT + using_reference = method == CONTRAST_SOLVENT_METHOD_REFERENCE + using_direct = method == CONTRAST_SOLVENT_METHOD_DIRECT + for widget in ( + self.solvent_preset_combo, + 
self.solvent_formula_edit, + self.solvent_density_spin, + self.save_custom_solvent_button, + ): + widget.setEnabled(using_neat) + self.delete_custom_solvent_button.setEnabled( + using_neat + and self._selected_solvent_preset_name() is not None + and not self._solvent_presets.get( + self._selected_solvent_preset_name() or "", + ContrastSolventPreset("", "Vacuum", 0.0, builtin=True), + ).builtin + ) + self.reference_solvent_file_edit.setEnabled(using_reference) + self.reference_solvent_browse_button.setEnabled(using_reference) + self.direct_density_spin.setEnabled(using_direct) + if using_reference: + self.solvent_method_hint_label.setText( + "Reference structure mode estimates a uniform solvent electron density " + "from the full XYZ/PDB coordinate box spanned by the selected file." + ) + elif using_direct: + self.solvent_method_hint_label.setText( + "Direct value mode uses the electron density you provide in e-/ų. " + "Use 0.0 e-/ų to model vacuum without solvent subtraction." + ) + elif self._clear_solvent_contrast_requested_from_controls(): + self.solvent_method_hint_label.setText( + "The None solvent option clears the active solvent-density subtraction " + "and returns the 3D FFT workflow to bare atomic density." + ) + else: + self.solvent_method_hint_label.setText( + "Quick estimate mode uses the selected solvent stoichiometry and density. " + "Built-in presets include Water, Vacuum, DMF, and DMSO." 
+ ) + + @Slot() + def _choose_reference_solvent_file(self) -> None: + start_dir = ( + str( + Path(self.reference_solvent_file_edit.text()) + .expanduser() + .resolve() + .parent + ) + if self.reference_solvent_file_edit.text().strip() + else str(self._project_dir or Path.cwd()) + ) + selected_path, _selected_filter = QFileDialog.getOpenFileName( + self, + "Choose Reference Solvent Structure", + start_dir, + "Structure files (*.pdb *.xyz);;All files (*)", + ) + if not selected_path: + return + self.reference_solvent_file_edit.setText( + str(Path(selected_path).expanduser().resolve()) + ) + + def _contrast_settings_from_controls( + self, + ) -> ContrastSolventDensitySettings: + method = str(self.solvent_method_combo.currentData() or "").strip() + if method == CONTRAST_SOLVENT_METHOD_REFERENCE: + reference_path = self.reference_solvent_file_edit.text().strip() + if not reference_path: + raise ValueError( + "Choose a reference solvent XYZ or PDB file before computing the solvent electron density." + ) + return ContrastSolventDensitySettings.from_values( + method=CONTRAST_SOLVENT_METHOD_REFERENCE, + reference_structure_file=reference_path, + ) + if method == CONTRAST_SOLVENT_METHOD_DIRECT: + direct_density = float(self.direct_density_spin.value()) + if direct_density < 0.0: + raise ValueError( + "Enter a non-negative direct solvent electron density before computing the solvent contrast." + ) + return ContrastSolventDensitySettings.from_values( + method=CONTRAST_SOLVENT_METHOD_DIRECT, + direct_electron_density_e_per_a3=direct_density, + ) + formula = self.solvent_formula_edit.text().strip() + if not formula: + raise ValueError( + "Enter a solvent stoichiometry formula before computing the solvent electron density." 
+ ) + return ContrastSolventDensitySettings.from_values( + method=CONTRAST_SOLVENT_METHOD_NEAT, + solvent_formula=formula, + solvent_density_g_per_ml=self.solvent_density_spin.value(), + ) + + def _contrast_display_name_from_controls(self) -> str: + method = str(self.solvent_method_combo.currentData() or "").strip() + if method == CONTRAST_SOLVENT_METHOD_REFERENCE: + reference_path = self.reference_solvent_file_edit.text().strip() + return ( + Path(reference_path).stem + if reference_path + else "Reference solvent" + ) + if method == CONTRAST_SOLVENT_METHOD_DIRECT: + return "Direct solvent" + preset_name = self._selected_solvent_preset_name() + if preset_name: + return preset_name + return self.solvent_formula_edit.text().strip() or "Solvent" + + def _estimate_reference_solvent_density( + self, + settings: ContrastSolventDensitySettings, + ) -> ContrastElectronDensityEstimate: + reference_file = settings.reference_structure_file + if reference_file is None: + raise ValueError( + "Choose a reference solvent XYZ or PDB file before computing the solvent electron density." + ) + structure = load_electron_density_structure(reference_file) + coordinates = np.asarray(structure.coordinates, dtype=float) + spans = np.max(coordinates, axis=0) - np.min(coordinates, axis=0) + volume_a3 = float(np.prod(spans)) + if volume_a3 <= 0.0: + raise ValueError( + "The reference solvent structure must span a positive 3D coordinate box." 
+ ) + total_electrons = float(np.sum(structure.atomic_numbers, dtype=float)) + density_e_per_a3 = total_electrons / volume_a3 + return ContrastElectronDensityEstimate( + method=CONTRAST_SOLVENT_METHOD_REFERENCE, + label="Reference solvent structure", + volume_a3=volume_a3, + total_electrons=total_electrons, + electron_density_e_per_a3=float(density_e_per_a3), + electron_density_e_per_cm3=float( + density_e_per_a3 * ANGSTROM3_PER_CM3 + ), + atom_count=structure.atom_count, + element_counts=dict(sorted(structure.element_counts.items())), + reference_structure_file=reference_file, + reference_box_spans=( + float(spans[0]), + float(spans[1]), + float(spans[2]), + ), + ) + + def _estimate_solvent_density( + self, + settings: ContrastSolventDensitySettings, + ) -> ContrastElectronDensityEstimate: + if settings.method == CONTRAST_SOLVENT_METHOD_REFERENCE: + return self._estimate_reference_solvent_density(settings) + if settings.method == CONTRAST_SOLVENT_METHOD_DIRECT: + return _direct_solvent_electron_density(settings, volume_a3=1.0) + return _neat_solvent_electron_density(settings, volume_a3=1.0) + + @Slot() + def _apply_contrast_settings(self) -> None: + self._apply_contrast_settings_from_controls(announce=True) + + def _apply_contrast_settings_from_controls( + self, + *, + announce: bool, + ) -> bool: + if self._clear_solvent_contrast_requested_from_controls(): + self._active_contrast_settings = None + self._active_contrast_estimate = None + self._active_contrast_name = None + self._active_solvent_density_e_per_a3 = 0.0 + self._contrast_controls_dirty = False + self._sync_kernel_correction_option() + self._refresh_contrast_display() + self._refresh_fft_box_visualizer() + self._update_push_to_model_state() + if announce: + self._append_status( + "Cleared the active solvent-density subtraction for the 3D FFT Born workflow." 
+ ) + self.statusBar().showMessage("Cleared 3D FFT solvent contrast") + return True + try: + self._active_contrast_settings = ( + self._contrast_settings_from_controls() + ) + self._active_contrast_estimate = self._estimate_solvent_density( + self._active_contrast_settings + ) + self._active_contrast_name = ( + self._contrast_display_name_from_controls() + ) + self._active_solvent_density_e_per_a3 = float( + self._active_contrast_estimate.electron_density_e_per_a3 + ) + except Exception as exc: + self._show_error("Solvent Contrast Error", str(exc)) + return False + self._contrast_controls_dirty = False + self._sync_kernel_correction_option() + self._refresh_contrast_display() + self._refresh_fft_box_visualizer() + self._update_push_to_model_state() + if announce: + self._append_status( + "Applied solvent-density setup " + f"{self._active_contrast_name or 'Solvent'} at " + f"{self._active_solvent_density_e_per_a3:.6f} e/ų." + ) + self.statusBar().showMessage("Applied 3D FFT solvent contrast") + return True + + def _refresh_contrast_display(self) -> None: + estimate = self._active_contrast_estimate + if estimate is None: + self.active_contrast_value.setText( + "No active solvent electron density contrast. The 3D FFT run will use bare atomic density only." + ) + self.active_contrast_value.setStyleSheet("color: #475569;") + self.contrast_notice_value.setText( + "Apply contrast settings to enable constant solvent-density subtraction inside the atomic exclusion mask." 
+ ) + self.contrast_notice_value.setStyleSheet("color: #475569;") + self._refresh_legacy_1d_contrast_note() + return + active_name = self._active_contrast_name or estimate.label + self.active_contrast_value.setText( + f"{active_name}: {estimate.electron_density_e_per_a3:.6f} e/ų" + ) + if abs(float(estimate.electron_density_e_per_a3)) <= 1.0e-15: + self.active_contrast_value.setStyleSheet("color: #166534;") + self.contrast_notice_value.setText( + "This is effectively a zero-contrast run, so the kernel-corrected FFT overlay can be used as a point-scatterer diagnostic." + ) + self.contrast_notice_value.setStyleSheet("color: #166534;") + self._refresh_legacy_1d_contrast_note() + return + self.active_contrast_value.setStyleSheet("color: #166534;") + self.contrast_notice_value.setText( + "The next 3D FFT Born calculation will subtract this constant solvent electron density inside the union of atomic exclusion spheres." + ) + self.contrast_notice_value.setStyleSheet("color: #166534;") + self._refresh_legacy_1d_contrast_note() + + def _active_contrast_summary_for_visualizer(self) -> str | None: + estimate = self._active_contrast_estimate + if estimate is None: + return None + if abs(float(estimate.electron_density_e_per_a3)) <= 1.0e-15: + return "Active contrast: zero-density / vacuum" + return ( + f"Active contrast: {estimate.electron_density_e_per_a3:.4f} e/ų, " + f"scale={self.exclusion_radius_scale_spin.value():.3f}, " + f"padding={self.exclusion_radius_padding_spin.value():.3f} Å" + ) + + def _make_profile_target_key( + self, + *, + structure_name: str, + motif_name: str, + source_mode: str, + solvent_mode: str, + ) -> str: + return ( + f"{source_mode}|{solvent_mode}|{structure_name.strip()}|" + f"{motif_name.strip() or 'no_motif'}" + ) + + def _project_source(self): + if self._project_dir is None: + return None + try: + return load_rmc_project_source(self._project_dir) + except Exception: + return None + + def _project_average_input_path(self) -> Path | None: + if 
self._project_dir is None: + return None + try: + from saxshell.saxs.project_manager.project import ( + SAXSProjectManager, + ) + + settings = SAXSProjectManager().load_project(self._project_dir) + except Exception: + return None + clusters_dir = settings.resolved_clusters_dir + if clusters_dir is not None and clusters_dir.exists(): + return clusters_dir + return None + + def _cluster_targets_from_input_path( + self, + input_path: Path, + ) -> tuple[_FFTProfileTarget, ...]: + try: + cluster_bins = discover_cluster_bins(input_path) + except Exception: + cluster_bins = [] + if cluster_bins: + return tuple( + _FFTProfileTarget( + key=self._make_profile_target_key( + structure_name=cluster_bin.structure, + motif_name=cluster_bin.motif, + source_mode="average", + solvent_mode="input", + ), + display_name=( + cluster_bin.structure + if cluster_bin.motif == "no_motif" + else f"{cluster_bin.structure}/{cluster_bin.motif}" + ), + structure_name=cluster_bin.structure, + motif_name=cluster_bin.motif, + file_count=len(cluster_bin.files), + reference_file=cluster_bin.files[0], + source_files=tuple(cluster_bin.files), + representative=cluster_bin.representative, + source_mode="average", + solvent_mode="input", + ) + for cluster_bin in cluster_bins + ) + inspection = inspect_structure_input(input_path) + structure_name = ( + input_path.stem if input_path.is_file() else input_path.name + ) + return ( + _FFTProfileTarget( + key=self._make_profile_target_key( + structure_name=structure_name, + motif_name="no_motif", + source_mode="average", + solvent_mode="input", + ), + display_name=structure_name, + structure_name=structure_name, + motif_name="no_motif", + file_count=int(inspection.total_files), + reference_file=inspection.reference_file, + source_files=tuple(inspection.structure_files), + representative=inspection.reference_file.name, + source_mode="average", + solvent_mode="input", + ), + ) + + def _representative_targets_from_project_source( + self, + ) -> dict[tuple[str, str], 
tuple[_FFTProfileTarget, ...]]: + def source_solvent_mode(entry) -> str: + normalized = str( + getattr(entry, "source_solvent_mode", "") or "" + ).strip() + if normalized == "nosolv": + return "none" + if normalized == "fullsolv": + return "full" + return "partial" + + project_source = self._project_source() + if project_source is None: + return {} + targets: dict[tuple[str, str], list[_FFTProfileTarget]] = {} + metadata = project_source.representative_selection + solvent_metadata = project_source.solvent_handling + if metadata is not None: + source_entries: list[tuple[object, Path]] = [] + source_modes: set[str] = set() + for entry in metadata.representative_entries: + source_file = Path(entry.source_file).expanduser().resolve() + if not source_file.is_file(): + continue + source_entries.append((entry, source_file)) + source_modes.add(source_solvent_mode(entry)) + if source_entries: + aggregate_mode = ( + next(iter(source_modes)) + if len(source_modes) == 1 + else "partial" + ) + source_targets: list[_FFTProfileTarget] = [] + for entry, source_file in source_entries: + source_targets.append( + _FFTProfileTarget( + key=self._make_profile_target_key( + structure_name=entry.structure, + motif_name=entry.motif, + source_mode="representative", + solvent_mode=aggregate_mode, + ), + display_name=( + entry.structure + if entry.motif == "no_motif" + else f"{entry.structure}/{entry.motif}" + ), + structure_name=entry.structure, + motif_name=entry.motif, + file_count=1, + reference_file=source_file, + source_files=(source_file,), + representative=( + entry.source_file_name or source_file.name + ), + source_mode="representative", + solvent_mode=aggregate_mode, + ) + ) + has_built_mode_overlap = bool( + solvent_metadata is not None + and solvent_metadata.entries + and aggregate_mode in {"none", "full"} + ) + if not has_built_mode_overlap: + targets[("representative", aggregate_mode)] = ( + source_targets + ) + if solvent_metadata is not None and solvent_metadata.entries: + 
no_solvent_targets: list[_FFTProfileTarget] = [] + full_solvent_targets: list[_FFTProfileTarget] = [] + for entry in solvent_metadata.entries: + no_solvent_path = ( + Path(entry.no_solvent_pdb).expanduser().resolve() + ) + completed_path = ( + Path(entry.completed_pdb).expanduser().resolve() + ) + display_name = ( + entry.structure + if entry.motif == "no_motif" + else f"{entry.structure}/{entry.motif}" + ) + if no_solvent_path.is_file(): + no_solvent_targets.append( + _FFTProfileTarget( + key=self._make_profile_target_key( + structure_name=entry.structure, + motif_name=entry.motif, + source_mode="representative", + solvent_mode="none", + ), + display_name=display_name, + structure_name=entry.structure, + motif_name=entry.motif, + file_count=1, + reference_file=no_solvent_path, + source_files=(no_solvent_path,), + representative=no_solvent_path.name, + source_mode="representative", + solvent_mode="none", + ) + ) + if completed_path.is_file(): + full_solvent_targets.append( + _FFTProfileTarget( + key=self._make_profile_target_key( + structure_name=entry.structure, + motif_name=entry.motif, + source_mode="representative", + solvent_mode="full", + ), + display_name=display_name, + structure_name=entry.structure, + motif_name=entry.motif, + file_count=1, + reference_file=completed_path, + source_files=(completed_path,), + representative=completed_path.name, + source_mode="representative", + solvent_mode="full", + ) + ) + if no_solvent_targets: + targets[("representative", "none")] = no_solvent_targets + if full_solvent_targets: + targets[("representative", "full")] = full_solvent_targets + return {key: tuple(value) for key, value in targets.items() if value} + + def _resolve_available_profile_targets( + self, + input_path: Path, + ) -> dict[tuple[str, str], tuple[_FFTProfileTarget, ...]]: + targets: dict[tuple[str, str], tuple[_FFTProfileTarget, ...]] = {} + average_error: Exception | None = None + average_input_candidates: list[Path] = [] + project_average_input = 
self._project_average_input_path() + if project_average_input is not None: + average_input_candidates.append(project_average_input) + if not any( + candidate.resolve() == input_path.resolve() + for candidate in average_input_candidates + ): + average_input_candidates.append(input_path) + for average_input in average_input_candidates: + try: + average_targets = self._cluster_targets_from_input_path( + average_input + ) + except Exception as exc: + if average_error is None: + average_error = exc + continue + if average_targets: + targets[("average", "input")] = average_targets + break + targets.update(self._representative_targets_from_project_source()) + if not targets and average_error is not None: + raise average_error + return targets + + def _current_structure_source_mode(self) -> str: + return str(self.structure_source_combo.currentData() or "average") + + def _current_representative_solvent_mode(self) -> str: + return str( + self.representative_solvent_mode_combo.currentData() or "partial" + ) + + def _active_profile_target(self) -> _FFTProfileTarget | None: + if not self._current_profile_targets: + return None + for target in self._current_profile_targets: + if target.key == self._active_profile_key: + return target + return self._current_profile_targets[0] + + def _load_structure_for_target( + self, + target: _FFTProfileTarget | None, + ) -> ElectronDensityStructure | None: + if target is None: + return None + cache_key = str(target.reference_file) + cached = self._reference_structure_cache.get(cache_key) + if cached is not None: + return cached + try: + structure = load_electron_density_structure(target.reference_file) + except Exception: + return None + self._reference_structure_cache[cache_key] = structure + return structure + + def _set_active_profile_target( + self, + target: _FFTProfileTarget | None, + ) -> None: + self._active_profile_key = None if target is None else target.key + if target is None: + self._loaded_reference_file = None + 
self._loaded_structure = None + self._loaded_structure_count = 0 + self._refresh_fft_box_visualizer() + return + self._loaded_reference_file = target.reference_file + self._loaded_structure_count = int(target.file_count) + self._loaded_structure = self._load_structure_for_target(target) + if self._loaded_structure is not None: + preview_mesh = self._build_preview_mesh_geometry( + self._loaded_structure + ) + self._sync_legacy_1d_defaults_to_structure(self._loaded_structure) + self.structure_viewer.set_structure( + self._loaded_structure, + mesh_geometry=preview_mesh, + scene_key=str(target.reference_file), + ) + self.structure_viewer.mesh_contrast_spin.setValue(90.0) + self.structure_viewer.mesh_linewidth_spin.setValue(1.6) + self._update_loaded_input_summary() + self._refresh_fft_box_visualizer() + self._refresh_plot_controls() + self._sync_workspace_state() + + def _update_loaded_input_summary(self) -> None: + if self._loaded_input_path is None: + self.loaded_input_summary.setText( + "Load a structure file or folder to inspect the 3D FFT Born setup." + ) + return + target = self._active_profile_target() + if target is None: + self.loaded_input_summary.setText( + f"Loaded input from {self._loaded_input_path}, but no eligible 3D FFT profile targets were detected." + ) + return + self.loaded_input_summary.setText( + f"Loaded {len(self._current_profile_targets)} active 3D FFT profile " + f"target(s) from {self._loaded_input_path.name}.\n" + f"Selected profile: {target.display_name} using " + f"{target.file_count} structure file(s) from " + f"{target.source_mode} / {target.solvent_mode} mode." 
+ ) + + def _populate_active_profile_combo(self) -> None: + self.active_profile_combo.blockSignals(True) + self.active_profile_combo.clear() + selected_index = 0 + for index, target in enumerate(self._current_profile_targets): + label = ( + f"{target.display_name} ({target.file_count} file" + f"{'' if target.file_count == 1 else 's'})" + ) + self.active_profile_combo.addItem(label, target.key) + if target.key == self._active_profile_key: + selected_index = index + self.active_profile_combo.setEnabled( + bool(self._current_profile_targets) + ) + if self._current_profile_targets: + self.active_profile_combo.setCurrentIndex(selected_index) + self.active_profile_combo.blockSignals(False) + + @Slot() + def _refresh_available_profile_targets(self) -> None: + if self._loaded_input_path is None: + self.structure_source_combo.setEnabled(False) + self.representative_solvent_mode_combo.setEnabled(False) + self.active_profile_combo.setEnabled(False) + return + available_sources = { + source_mode + for source_mode, _solvent_mode in self._available_profile_targets + } + self.structure_source_combo.blockSignals(True) + self.structure_source_combo.setEnabled(bool(available_sources)) + preferred_source_mode = ( + "representative" + if self._prefer_representative_structures + and "representative" in available_sources + else "average" + ) + current_source_mode = self._current_structure_source_mode() + if ( + not self._current_profile_targets + and preferred_source_mode in available_sources + ): + current_source_mode = preferred_source_mode + elif current_source_mode not in available_sources: + current_source_mode = preferred_source_mode + target_source_index = self.structure_source_combo.findData( + current_source_mode + ) + if target_source_index < 0: + target_source_index = self.structure_source_combo.findData( + "average" + ) + self.structure_source_combo.setCurrentIndex( + max(target_source_index, 0) + ) + self.structure_source_combo.blockSignals(False) + + representative_modes = 
[ + solvent_mode + for source_mode, solvent_mode in self._available_profile_targets + if source_mode == "representative" + ] + representative_labels = { + "none": "No solvent representative", + "partial": "Partial / source representative", + "full": "Full solvent representative", + } + current_rep_mode = self._current_representative_solvent_mode() + if current_rep_mode not in representative_modes: + current_rep_mode = ( + "partial" + if "partial" in representative_modes + else ( + representative_modes[0] + if representative_modes + else "partial" + ) + ) + self.representative_solvent_mode_combo.blockSignals(True) + self.representative_solvent_mode_combo.clear() + if representative_modes: + for solvent_mode in ("none", "partial", "full"): + if solvent_mode not in representative_modes: + continue + self.representative_solvent_mode_combo.addItem( + representative_labels[solvent_mode], + solvent_mode, + ) + rep_index = self.representative_solvent_mode_combo.findData( + current_rep_mode + ) + self.representative_solvent_mode_combo.setCurrentIndex( + max(rep_index, 0) + ) + else: + self.representative_solvent_mode_combo.addItem( + "No representative solvent variants available", + None, + ) + self.representative_solvent_mode_combo.setCurrentIndex(0) + using_representatives = ( + self._current_structure_source_mode() == "representative" + and bool(representative_modes) + ) + self.representative_solvent_mode_combo.setEnabled( + using_representatives + ) + self.representative_solvent_mode_combo.blockSignals(False) + + selected_mode = self._current_structure_source_mode() + selected_solvent_mode = ( + self._current_representative_solvent_mode() + if selected_mode == "representative" + else "input" + ) + self._current_profile_targets = self._available_profile_targets.get( + (selected_mode, selected_solvent_mode), + (), + ) + if self._active_profile_key is None or all( + target.key != self._active_profile_key + for target in self._current_profile_targets + ): + 
self._active_profile_key = ( + None + if not self._current_profile_targets + else self._current_profile_targets[0].key + ) + self._populate_active_profile_combo() + if selected_mode == "representative" and representative_modes: + self.structure_source_hint.setText( + "Representative mode is active. Choose no-solvent, partial/source, or full-solvent representative files when those project artifacts are available." + ) + elif selected_mode == "representative": + self.structure_source_hint.setText( + "Representative mode was requested, but no saved representative structures are currently available for this project." + ) + else: + self.structure_source_hint.setText( + "Average structure mode is active. Each profile target averages all structure files in its active cluster bin or loaded input folder." + ) + self._set_active_profile_target(self._active_profile_target()) + self._update_push_to_model_state() + + @Slot() + def _on_active_profile_changed(self) -> None: + target_key = str(self.active_profile_combo.currentData() or "").strip() + if not target_key: + self._set_active_profile_target(None) + return + for target in self._current_profile_targets: + if target.key == target_key: + self._set_active_profile_target(target) + return + self._set_active_profile_target(self._active_profile_target()) + + def _active_profile_result(self) -> _FFTProfileComputationResult | None: + if self._active_profile_key is None: + return None + return self._computed_profile_results.get(self._active_profile_key) + + def _legacy_overlay_label(self) -> str: + if self._active_contrast_settings is None: + return "1D Born Approximation (Average)" + return "1D Born Approximation (Average, matched contrast)" + + def _refresh_fft_box_visualizer(self) -> None: + active_result = self._active_profile_result() + self.fft_box_visualizer.set_structure_preview( + self._loaded_structure, + fft_result=( + None if active_result is None else active_result.fft_result + ), + 
contrast_summary=self._active_contrast_summary_for_visualizer(), + ) + + def _update_curve_legend_button_text(self) -> None: + self.toggle_curve_legend_button.setText( + "Hide Legend" if self._curve_legend_visible else "Show Legend" + ) + + @Slot() + def _toggle_curve_legend(self) -> None: + self._curve_legend_visible = not self._curve_legend_visible + self._update_curve_legend_button_text() + self._refresh_plot_controls() + + def _current_curve_series( + self, + ) -> list[tuple[str, np.ndarray, np.ndarray]]: + active_result = self._active_profile_result() + if active_result is None: + return [] + primary_label = ( + "3D FFT Born Approximation (solvent contrast)" + if active_result.fft_result.density_subtraction_active + else "3D FFT Born Approximation" + ) + series: list[tuple[str, np.ndarray, np.ndarray]] = [ + ( + primary_label, + np.asarray(active_result.q_values, dtype=float), + np.asarray( + active_result.fft_result.raw_intensity, dtype=float + ), + ) + ] + if ( + self.show_kernel_corrected_checkbox.isChecked() + and active_result.fft_result.kernel_correction_supported + ): + series.append( + ( + "3D FFT kernel-corrected (diagnostic)", + np.asarray(active_result.q_values, dtype=float), + np.asarray( + active_result.fft_result.kernel_corrected_intensity, + dtype=float, + ), + ) + ) + if ( + self.compare_legacy_checkbox.isChecked() + and active_result.legacy_q_values is not None + and active_result.legacy_intensity is not None + ): + series.append( + ( + self._legacy_overlay_label(), + np.asarray(active_result.legacy_q_values, dtype=float), + np.asarray(active_result.legacy_intensity, dtype=float), + ) + ) + if ( + self.compare_exact_debye_checkbox.isChecked() + and active_result.exact_debye_intensity is not None + ): + series.append( + ( + "Exact Debye scattering", + np.asarray(active_result.q_values, dtype=float), + np.asarray( + active_result.exact_debye_intensity, dtype=float + ), + ) + ) + return series + + @Slot() + def _export_q_space_curves_csv(self) -> 
None: + series = self._current_curve_series() + if not series: + self._show_error( + "No q-Space Data", + "Run the 3D FFT Born approximation before exporting plot data.", + ) + return + start_dir = str(self._output_dir or self._project_dir or Path.cwd()) + selected_path, _selected_filter = QFileDialog.getSaveFileName( + self, + "Export q-Space Curves CSV", + str(Path(start_dir) / "fft_born_q_space_curves.csv"), + "CSV files (*.csv);;All files (*)", + ) + if not selected_path: + return + path = Path(selected_path).expanduser().resolve() + max_length = max( + int(q_values.size) for _label, q_values, _intensity in series + ) + header: list[str] = [] + for label, _q_values, _intensity in series: + prefix = ( + str(label) + .lower() + .replace(" ", "_") + .replace(",", "") + .replace("(", "") + .replace(")", "") + .replace("/", "_") + .replace("-", "_") + ) + header.extend([f"{prefix}_q_a_inverse", f"{prefix}_intensity"]) + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.writer(handle) + writer.writerow(header) + for index in range(max_length): + row: list[object] = [] + for _label, q_values, intensity in series: + if index < q_values.size: + row.append(f"{float(q_values[index]):.8f}") + row.append(f"{float(intensity[index]):.8e}") + else: + row.extend(["", ""]) + writer.writerow(row) + self._append_status(f"Exported q-space plot data CSV to {path}.") + self.statusBar().showMessage("Exported q-space plot data CSV") + + def _append_status(self, message: str) -> None: + text = str(message).strip() + if not text: + return + self.status_log_box.appendPlainText(text) + self.statusBar().showMessage(text) + + def _apply_preview_mode_title(self) -> None: + title = "3D FFT Born Approximation" + if self._preview_mode: + title += " (Preview)" + self.setWindowTitle(title) + + def _refresh_preview_mode_banner(self) -> None: + if self._preview_mode: + self.preview_mode_banner.setText( + "Preview Mode: inspect a structure file or folder and compare the " + 
"3D FFT Born approximation against optional legacy 1D Born and exact Debye references." + ) + self.preview_mode_banner.setToolTip( + "Preview mode does not push 3D FFT Born outputs into an active SAXS model." + ) + return + distribution_text = ( + self._distribution_id + if self._distribution_id is not None + else "active computed distribution" + ) + self.preview_mode_banner.setText( + "Computed Distribution Mode: this run is linked to " + f"{distribution_text}. Use this window to evaluate the separate 3D FFT Born workflow before full component-export integration." + ) + self.preview_mode_banner.setToolTip( + "This window was launched from Build SAXS Components using the 3D FFT Born Approximation mode." + ) + + def _browse_input_file(self) -> None: + selected, _filter = QFileDialog.getOpenFileName( + self, + "Open Structure File", + str(self._project_dir or Path.home()), + "Structure Files (*.pdb *.xyz);;All Files (*)", + ) + if selected: + self.input_path_edit.setText(selected) + + def _browse_input_folder(self) -> None: + selected = QFileDialog.getExistingDirectory( + self, + "Open Structure Folder", + str(self._project_dir or Path.home()), + ) + if selected: + self.input_path_edit.setText(selected) + + @Slot() + def _load_input_from_edit(self) -> None: + text = str(self.input_path_edit.text()).strip() + if not text: + self._show_error( + "Input required", + "Choose a structure file or folder before loading the 3D FFT Born input.", + ) + return + self._load_input_path(Path(text).expanduser().resolve()) + + def _load_input_path(self, path: Path) -> None: + try: + available_targets = self._resolve_available_profile_targets(path) + except Exception as exc: + self._show_error( + "Load failed", + f"Could not load 3D FFT Born input targets:\n{exc}", + ) + return + if not available_targets: + self._show_error( + "Load failed", + "No eligible 3D FFT Born profile targets were discovered for the selected input.", + ) + return + self._loaded_input_path = path + 
self._available_profile_targets = dict(available_targets) + self._current_payload = None + self._computed_profile_results = {} + self._computed_profile_run_signature = None + self.curve_plot.draw_placeholder() + self.shell_count_plot.draw_placeholder() + self.result_summary_box.clear() + self._restoring_workspace_state = True + try: + self._refresh_available_profile_targets() + finally: + self._restoring_workspace_state = False + self._restore_workspace_state_from_disk() + self._update_push_to_model_state() + self._append_status( + f"Loaded 3D FFT Born input from {path} with " + f"{sum(len(targets) for targets in self._available_profile_targets.values())} " + "discovered profile target(s)." + ) + + def _build_preview_mesh_geometry( + self, + structure: ElectronDensityStructure, + ): + return build_electron_density_mesh( + structure, + legacy_born_average_default_mesh_settings(structure), + ) + + def _load_deferred_input(self) -> None: + if self._deferred_initial_input_path is None: + return + self.input_path_edit.setText(str(self._deferred_initial_input_path)) + self._load_input_path(self._deferred_initial_input_path) + self._deferred_initial_input_path = None + + def _current_fft_settings(self) -> ContrastFFTSettings: + return ContrastFFTSettings( + spacing_a=float(self.spacing_spin.value()), + gaussian_sigma_a=float(self.sigma_spin.value()), + minimum_box_length_a=float(self.min_box_length_spin.value()), + padding_a=float(self.padding_spin.value()), + solvent_density_e_per_a3=float( + self._active_solvent_density_e_per_a3 + ), + exclusion_radius_scale=float( + self.exclusion_radius_scale_spin.value() + ), + exclusion_radius_padding_a=float( + self.exclusion_radius_padding_spin.value() + ), + ) + + @staticmethod + def _signature_float(value: object) -> float: + return round(float(value), 12) + + def _current_trace_configuration_signature(self) -> dict[str, object]: + fft_settings = self._current_fft_settings().normalized() + return { + "target_keys": [ + 
target.key for target in self._current_profile_targets + ], + "structure_source_mode": self._current_structure_source_mode(), + "representative_solvent_mode": ( + self._current_representative_solvent_mode() + if self._current_structure_source_mode() == "representative" + else "input" + ), + "q_min": self._signature_float(self.q_min_spin.value()), + "q_max": self._signature_float(self.q_max_spin.value()), + "q_step": self._signature_float(self.q_step_spin.value()), + "fft_settings": { + "spacing_a": self._signature_float(fft_settings.spacing_a), + "gaussian_sigma_a": self._signature_float( + fft_settings.gaussian_sigma_a + ), + "minimum_box_length_a": self._signature_float( + fft_settings.minimum_box_length_a + ), + "padding_a": self._signature_float(fft_settings.padding_a), + "support_sigma": self._signature_float( + fft_settings.support_sigma + ), + "solvent_density_e_per_a3": self._signature_float( + fft_settings.solvent_density_e_per_a3 + ), + "exclusion_radius_scale": self._signature_float( + fft_settings.exclusion_radius_scale + ), + "exclusion_radius_padding_a": self._signature_float( + fft_settings.exclusion_radius_padding_a + ), + "use_cubic_box": bool(fft_settings.use_cubic_box), + }, + "active_contrast_settings": ( + None + if self._active_contrast_settings is None + else self._active_contrast_settings.to_dict() + ), + "active_solvent_density_e_per_a3": self._signature_float( + self._active_solvent_density_e_per_a3 + ), + } + + def _requested_q_values_from_controls(self) -> np.ndarray: + return build_shared_q_grid( + float(self.q_min_spin.value()), + float(self.q_max_spin.value()), + q_step=float(self.q_step_spin.value()), + ) + + def _computed_profile_q_grid_matches_requested_grid(self) -> bool: + if not self._computed_profile_results: + return False + try: + requested_q_values = self._requested_q_values_from_controls() + except Exception: + return False + for result in self._computed_profile_results.values(): + q_values = np.asarray(result.q_values, 
dtype=float) + if q_values.shape != requested_q_values.shape: + return False + if not np.allclose( + q_values, + requested_q_values, + rtol=1.0e-10, + atol=1.0e-12, + ): + return False + return True + + def _results_match_current_configuration(self) -> bool: + if not self._computed_profile_results: + return False + if self._contrast_controls_dirty: + return False + run_signature = self._computed_profile_run_signature + if run_signature is None: + return True + return ( + run_signature == self._current_trace_configuration_signature() + and self._computed_profile_q_grid_matches_requested_grid() + ) + + @Slot() + def _start_calculation(self) -> None: + if ( + self._compute_thread is not None + and self._compute_thread.isRunning() + ): + return + if not self._current_profile_targets: + self._show_error( + "No input loaded", + "Load a structure file or folder with at least one active " + "profile target before running the 3D FFT Born approximation.", + ) + return + q_min = float(self.q_min_spin.value()) + q_max = float(self.q_max_spin.value()) + q_step = float(self.q_step_spin.value()) + if q_max <= q_min: + self._show_error( + "Invalid q range", + "q max must be greater than q min before starting the " + "3D FFT Born calculation.", + ) + return + if self._contrast_controls_dirty and not ( + self._apply_contrast_settings_from_controls(announce=True) + ): + return + self.compute_button.setEnabled(False) + self.load_input_button.setEnabled(False) + self._close_requested_while_running = False + self._begin_progress_dialog("Running 3D FFT Born Approximation...") + self._append_status("Starting 3D FFT Born approximation calculation.") + self._compute_thread = QThread(self) + self._compute_worker = _FFTComputationWorker( + targets=tuple(self._current_profile_targets), + fft_settings=self._current_fft_settings(), + legacy_mesh_settings=self._legacy_1d_mesh_settings_from_controls(), + legacy_smearing_settings=( + self._legacy_1d_smearing_settings_from_controls() + ), + 
legacy_fourier_settings=( + self._legacy_1d_fourier_settings_from_controls() + ), + active_contrast_settings=self._active_contrast_settings, + active_contrast_name=self._active_contrast_name, + q_min=q_min, + q_max=q_max, + q_step=q_step, + compare_legacy_1d=self.compare_legacy_checkbox.isChecked(), + compare_exact_debye=self.compare_exact_debye_checkbox.isChecked(), + ) + self._compute_worker.moveToThread(self._compute_thread) + self._compute_thread.started.connect(self._compute_worker.run) + self._compute_worker.progress.connect(self._on_worker_progress) + self._compute_worker.finished.connect(self._on_worker_finished) + self._compute_worker.failed.connect(self._on_worker_failed) + self._compute_worker.cancelled.connect(self._on_worker_cancelled) + self._compute_worker.finished.connect(self._compute_thread.quit) + self._compute_worker.failed.connect(self._compute_thread.quit) + self._compute_worker.cancelled.connect(self._compute_thread.quit) + self._compute_thread.finished.connect(self._compute_worker.deleteLater) + self._compute_thread.finished.connect(self._compute_thread.deleteLater) + self._compute_thread.finished.connect(self._clear_worker_handles) + self._compute_thread.finished.connect( + self._finalize_close_after_worker + ) + self._compute_thread.start(QThread.Priority.LowPriority) + + @Slot(str) + def _on_worker_progress(self, message: str) -> None: + self._append_status(message) + self._set_progress_message(message) + + @Slot(object) + def _on_worker_finished(self, payload: object) -> None: + self.compute_button.setEnabled(True) + self.load_input_button.setEnabled(True) + self._close_progress_dialog() + if not isinstance(payload, _FFTComputationPayload): + self._show_error( + "Calculation failed", + "The 3D FFT Born calculation returned an unexpected payload.", + ) + return + self._current_payload = payload + self._computed_profile_results = { + result.target.key: result for result in payload.profile_results + } + self._computed_profile_run_signature 
= ( + self._current_trace_configuration_signature() + ) + if ( + self._active_profile_key is None + or self._active_profile_key not in self._computed_profile_results + ) and payload.profile_results: + self._active_profile_key = payload.profile_results[0].target.key + self._refresh_plot_controls() + self._refresh_fft_box_visualizer() + self._update_result_summary(payload) + active_result = self._active_profile_result() + if active_result is not None: + self._append_status( + "3D FFT Nyquist limit: " + f"{active_result.fft_result.q_nyquist_a_inverse:.6f} Å^-1 using voxel spacing " + f"{active_result.fft_result.voxel_spacing_a[0]:.3f} Å." + ) + self._sync_workspace_state() + self._update_push_to_model_state() + self._append_status("3D FFT Born approximation calculation complete.") + + @Slot(str) + def _on_worker_failed(self, message: str) -> None: + self.compute_button.setEnabled(True) + self.load_input_button.setEnabled(True) + self._close_progress_dialog() + if self._close_requested_while_running: + return + self._show_error( + "Calculation failed", + "The 3D FFT Born approximation could not be completed:\n" + + str(message).strip(), + ) + + @Slot(str) + def _on_worker_cancelled(self, message: str) -> None: + self.compute_button.setEnabled(True) + self.load_input_button.setEnabled(True) + self._close_progress_dialog() + if self._close_requested_while_running: + return + text = str(message).strip() or "3D FFT Born calculation cancelled." 
+ self._append_status(text) + self.statusBar().showMessage(text) + + @Slot() + def _clear_worker_handles(self) -> None: + self._compute_worker = None + self._compute_thread = None + + @Slot() + def _finalize_close_after_worker(self) -> None: + if not self._close_requested_while_running: + return + self._close_progress_dialog() + self.hide() + self.deleteLater() + + def _update_result_summary(self, payload: _FFTComputationPayload) -> None: + active_result = self._active_profile_result() + if active_result is None: + self.result_summary_box.clear() + return + fft = active_result.fft_result + lines = [ + ("Computed profiles: " f"{len(payload.profile_results)}"), + ("Active profile: " f"{active_result.target.display_name}"), + f"Input: {self._loaded_reference_file or 'None'}", + f"Atoms: {0 if self._loaded_structure is None else self._loaded_structure.atom_count}", + f"Target file count: {active_result.target.file_count}", + ( + "Solvent density contrast: " + f"{fft.solvent_density_e_per_a3:.6f} e/ų" + ), + f"Contrast mode: {fft.contrast_mode}", + f"Grid shape: {'x'.join(str(value) for value in fft.grid_shape)}", + ( + "Box lengths (Å): " + + ", ".join(f"{value:.3f}" for value in fft.box_lengths_a) + ), + f"q Nyquist: {fft.q_nyquist_a_inverse:.6f} Å⁻¹", + ( + "First non-empty q bin: " + + ( + "None" + if fft.first_nonempty_q_a_inverse is None + else f"{fft.first_nonempty_q_a_inverse:.4f} Å⁻¹" + ) + ), + ( + "Atomic density integral / expected weight: " + f"{fft.density_integral:.6f} / {fft.expected_weight:.6f}" + ), + ( + "Contrast density integral / expected contrast weight: " + f"{fft.contrast_density_integral:.6f} / {fft.expected_contrast_weight:.6f}" + ), + ( + "Timing (s): deposit=" + f"{fft.timing.atomic_density_seconds:.4f}, contrast=" + f"{fft.timing.contrast_density_seconds:.4f}, fft=" + f"{fft.timing.fft_seconds:.4f}, shells=" + f"{fft.timing.shell_average_seconds:.4f}, total=" + f"{fft.timing.total_seconds:.4f}" + ), + ] + if 
active_result.legacy_elapsed_seconds is not None: + lines.append( + "Legacy 1D Born comparison time (s): " + f"{active_result.legacy_elapsed_seconds:.4f}" + ) + if self._active_contrast_settings is not None: + lines.append( + "Legacy 1D Born overlay reused the active solvent contrast before its Fourier transform." + ) + if active_result.debye_elapsed_seconds is not None: + lines.append( + f"Exact Debye comparison time (s): {active_result.debye_elapsed_seconds:.4f}" + ) + if fft.kernel_correction_supported: + lines.append( + "Kernel correction is available for this zero-contrast run as a diagnostic overlay." + ) + else: + lines.append( + "Kernel correction is disabled because solvent-contrast subtraction is active." + ) + self.result_summary_box.setPlainText("\n".join(lines)) + + @Slot() + def _refresh_plot_controls(self) -> None: + active_result = self._active_profile_result() + if active_result is None: + self.curve_plot.draw_placeholder() + self.shell_count_plot.draw_placeholder() + self._refresh_fft_box_visualizer() + return + additional_series: list[dict[str, object]] = [] + if ( + self.show_kernel_corrected_checkbox.isChecked() + and active_result.fft_result.kernel_correction_supported + ): + additional_series.append( + { + "q_values": active_result.q_values, + "intensity": active_result.fft_result.kernel_corrected_intensity, + "label": "3D FFT kernel-corrected (diagnostic)", + "color": "#7c3aed", + "linestyle": "--", + } + ) + if ( + self.compare_legacy_checkbox.isChecked() + and active_result.legacy_q_values is not None + and active_result.legacy_intensity is not None + ): + additional_series.append( + { + "q_values": active_result.legacy_q_values, + "intensity": active_result.legacy_intensity, + "label": self._legacy_overlay_label(), + "color": "#b45309", + "linestyle": "-.", + } + ) + if ( + self.compare_exact_debye_checkbox.isChecked() + and active_result.exact_debye_intensity is not None + ): + additional_series.append( + { + "q_values": 
active_result.q_values, + "intensity": active_result.exact_debye_intensity, + "label": "Exact Debye scattering", + "color": "#0f172a", + "linestyle": ":", + "linewidth": 1.9, + } + ) + primary_label = ( + "3D FFT Born Approximation (solvent contrast)" + if active_result.fft_result.density_subtraction_active + else "3D FFT Born Approximation" + ) + self.curve_plot.set_curves( + q_values=active_result.q_values, + primary_values=active_result.fft_result.raw_intensity, + primary_label=primary_label, + additional_series=additional_series, + log_q_axis=self.log_q_checkbox.isChecked(), + log_intensity_axis=self.log_intensity_checkbox.isChecked(), + show_legend=self._curve_legend_visible, + ) + self.shell_count_plot.set_counts( + active_result.q_values, + active_result.fft_result.q_shell_counts, + ) + self._refresh_fft_box_visualizer() + + def _sync_kernel_correction_option(self) -> None: + solvent_density = float(self._active_solvent_density_e_per_a3) + enabled = abs(solvent_density) <= 1.0e-15 + self.show_kernel_corrected_checkbox.setEnabled(enabled) + if not enabled: + self.show_kernel_corrected_checkbox.setChecked(False) + + @Slot() + def _clear_results(self) -> None: + self._current_payload = None + self._computed_profile_results = {} + self._computed_profile_run_signature = None + self.curve_plot.draw_placeholder() + self.shell_count_plot.draw_placeholder() + self.result_summary_box.clear() + self._refresh_fft_box_visualizer() + self._sync_workspace_state() + self._update_push_to_model_state() + self._append_status("Cleared 3D FFT Born outputs.") + + def _state_dir(self) -> Path | None: + if self._output_dir is not None: + return Path(self._output_dir).expanduser().resolve() + if self._distribution_root_dir is None: + return None + return ( + Path(self._distribution_root_dir).expanduser().resolve() + / "born_approximation_3d_fft" + ) + + def _workspace_state_path(self) -> Path | None: + state_dir = self._state_dir() + if state_dir is None: + return None + return 
state_dir / "workspace_state.json" + + def _component_summary_path(self) -> Path | None: + state_dir = self._state_dir() + if state_dir is None: + return None + return state_dir / "born_approximation_3d_fft_component_summary.json" + + def _component_artifact_targets(self) -> tuple[Path, Path] | None: + if self._distribution_root_dir is None: + return None + if self._use_predicted_structure_weights: + return ( + self._distribution_root_dir + / "scattering_components_predicted_structures", + self._distribution_root_dir + / "md_saxs_map_predicted_structures.json", + ) + return ( + self._distribution_root_dir / "scattering_components", + self._distribution_root_dir / "md_saxs_map.json", + ) + + @staticmethod + def _component_profile_filename(structure: str, motif: str) -> str: + safe_name = f"{structure}_{motif}".replace("/", "_") + return f"{safe_name}.txt" + + def _serialize_fft_result( + self, + result: ContrastFFTResult, + ) -> dict[str, object]: + return { + "settings": { + "spacing_a": float(result.settings.spacing_a), + "gaussian_sigma_a": float(result.settings.gaussian_sigma_a), + "minimum_box_length_a": float( + result.settings.minimum_box_length_a + ), + "padding_a": float(result.settings.padding_a), + "support_sigma": float(result.settings.support_sigma), + "solvent_density_e_per_a3": float( + result.settings.solvent_density_e_per_a3 + ), + "exclusion_radius_scale": float( + result.settings.exclusion_radius_scale + ), + "exclusion_radius_padding_a": float( + result.settings.exclusion_radius_padding_a + ), + "use_cubic_box": bool(result.settings.use_cubic_box), + }, + "q_values": np.asarray(result.q_values, dtype=float).tolist(), + "raw_intensity": np.asarray( + result.raw_intensity, + dtype=float, + ).tolist(), + "kernel_corrected_intensity": np.asarray( + result.kernel_corrected_intensity, + dtype=float, + ).tolist(), + "q_shell_counts": np.asarray( + result.q_shell_counts, + dtype=int, + ).tolist(), + "density_integral": float(result.density_integral), + 
"expected_weight": float(result.expected_weight), + "contrast_density_integral": float( + result.contrast_density_integral + ), + "expected_contrast_weight": float(result.expected_contrast_weight), + "solvent_exclusion_volume_a3": float( + result.solvent_exclusion_volume_a3 + ), + "grid_shape": [int(value) for value in result.grid_shape], + "box_lengths_a": [float(value) for value in result.box_lengths_a], + "voxel_spacing_a": [ + float(value) for value in result.voxel_spacing_a + ], + "q_nyquist_a_inverse": float(result.q_nyquist_a_inverse), + "q_frequency_step_a_inverse": [ + float(value) for value in result.q_frequency_step_a_inverse + ], + "q_convention": str(result.q_convention), + "uses_two_pi_frequency_conversion": bool( + result.uses_two_pi_frequency_conversion + ), + "density_subtraction_active": bool( + result.density_subtraction_active + ), + "first_nonempty_q_a_inverse": ( + None + if result.first_nonempty_q_a_inverse is None + else float(result.first_nonempty_q_a_inverse) + ), + "solvent_density_e_per_a3": float(result.solvent_density_e_per_a3), + "contrast_mode": str(result.contrast_mode), + "kernel_correction_supported": bool( + result.kernel_correction_supported + ), + "kernel_correction_applied": bool( + result.kernel_correction_applied + ), + "kernel_correction_model": result.kernel_correction_model, + "timing": { + "atomic_density_seconds": float( + result.timing.atomic_density_seconds + ), + "contrast_density_seconds": float( + result.timing.contrast_density_seconds + ), + "fft_seconds": float(result.timing.fft_seconds), + "shell_average_seconds": float( + result.timing.shell_average_seconds + ), + "total_seconds": float(result.timing.total_seconds), + }, + } + + def _deserialize_fft_result( + self, + payload: dict[str, object], + ) -> ContrastFFTResult: + settings_payload = dict(payload.get("settings", {})) + settings = ContrastFFTSettings( + spacing_a=float(settings_payload.get("spacing_a", 2.5)), + gaussian_sigma_a=float( + 
settings_payload.get("gaussian_sigma_a", 0.75) + ), + minimum_box_length_a=float( + settings_payload.get("minimum_box_length_a", 640.0) + ), + padding_a=float(settings_payload.get("padding_a", 24.0)), + support_sigma=float(settings_payload.get("support_sigma", 4.0)), + solvent_density_e_per_a3=float( + settings_payload.get("solvent_density_e_per_a3", 0.0) + ), + exclusion_radius_scale=float( + settings_payload.get("exclusion_radius_scale", 1.0) + ), + exclusion_radius_padding_a=float( + settings_payload.get("exclusion_radius_padding_a", 0.0) + ), + use_cubic_box=bool(settings_payload.get("use_cubic_box", True)), + ).normalized() + timing_payload = dict(payload.get("timing", {})) + return ContrastFFTResult( + settings=settings, + q_values=np.asarray(payload.get("q_values", []), dtype=float), + raw_intensity=np.asarray( + payload.get("raw_intensity", []), + dtype=float, + ), + kernel_corrected_intensity=np.asarray( + payload.get("kernel_corrected_intensity", []), + dtype=float, + ), + q_shell_counts=np.asarray( + payload.get("q_shell_counts", []), + dtype=int, + ), + density_integral=float(payload.get("density_integral", 0.0)), + expected_weight=float(payload.get("expected_weight", 0.0)), + contrast_density_integral=float( + payload.get("contrast_density_integral", 0.0) + ), + expected_contrast_weight=float( + payload.get("expected_contrast_weight", 0.0) + ), + solvent_exclusion_volume_a3=float( + payload.get("solvent_exclusion_volume_a3", 0.0) + ), + grid_shape=tuple( + int(value) for value in payload.get("grid_shape", (1, 1, 1)) + ), + box_lengths_a=tuple( + float(value) + for value in payload.get("box_lengths_a", (0.0, 0.0, 0.0)) + ), + voxel_spacing_a=tuple( + float(value) + for value in payload.get("voxel_spacing_a", (0.0, 0.0, 0.0)) + ), + q_nyquist_a_inverse=float(payload.get("q_nyquist_a_inverse", 0.0)), + q_frequency_step_a_inverse=tuple( + float(value) + for value in payload.get( + "q_frequency_step_a_inverse", + (0.0, 0.0, 0.0), + ) + ), + 
q_convention=str(payload.get("q_convention", "")).strip(), + uses_two_pi_frequency_conversion=bool( + payload.get("uses_two_pi_frequency_conversion", True) + ), + density_subtraction_active=bool( + payload.get("density_subtraction_active", False) + ), + first_nonempty_q_a_inverse=( + None + if payload.get("first_nonempty_q_a_inverse") is None + else float(payload.get("first_nonempty_q_a_inverse", 0.0)) + ), + solvent_density_e_per_a3=float( + payload.get("solvent_density_e_per_a3", 0.0) + ), + contrast_mode=str(payload.get("contrast_mode", "")).strip(), + kernel_correction_supported=bool( + payload.get("kernel_correction_supported", False) + ), + kernel_correction_applied=bool( + payload.get("kernel_correction_applied", False) + ), + kernel_correction_model=payload.get("kernel_correction_model"), + timing=ContrastFFTTiming( + atomic_density_seconds=float( + timing_payload.get("atomic_density_seconds", 0.0) + ), + contrast_density_seconds=float( + timing_payload.get("contrast_density_seconds", 0.0) + ), + fft_seconds=float(timing_payload.get("fft_seconds", 0.0)), + shell_average_seconds=float( + timing_payload.get("shell_average_seconds", 0.0) + ), + total_seconds=float(timing_payload.get("total_seconds", 0.0)), + ), + ) + + def _serialize_profile_result( + self, + result: _FFTProfileComputationResult, + ) -> dict[str, object]: + return { + "target": { + "key": result.target.key, + "display_name": result.target.display_name, + "structure_name": result.target.structure_name, + "motif_name": result.target.motif_name, + "file_count": int(result.target.file_count), + "reference_file": str(result.target.reference_file), + "source_files": [ + str(path) for path in result.target.source_files + ], + "representative": result.target.representative, + "source_mode": result.target.source_mode, + "solvent_mode": result.target.solvent_mode, + }, + "q_values": np.asarray(result.q_values, dtype=float).tolist(), + "fft_result": self._serialize_fft_result(result.fft_result), + 
"legacy_q_values": ( + None + if result.legacy_q_values is None + else np.asarray(result.legacy_q_values, dtype=float).tolist() + ), + "legacy_intensity": ( + None + if result.legacy_intensity is None + else np.asarray(result.legacy_intensity, dtype=float).tolist() + ), + "exact_debye_intensity": ( + None + if result.exact_debye_intensity is None + else np.asarray( + result.exact_debye_intensity, + dtype=float, + ).tolist() + ), + "legacy_elapsed_seconds": result.legacy_elapsed_seconds, + "debye_elapsed_seconds": result.debye_elapsed_seconds, + } + + def _deserialize_profile_result( + self, + payload: dict[str, object], + ) -> _FFTProfileComputationResult: + target_payload = dict(payload.get("target", {})) + target = _FFTProfileTarget( + key=str(target_payload.get("key", "")).strip(), + display_name=str(target_payload.get("display_name", "")).strip(), + structure_name=str( + target_payload.get("structure_name", "") + ).strip(), + motif_name=str( + target_payload.get("motif_name", "no_motif") + ).strip() + or "no_motif", + file_count=int(target_payload.get("file_count", 1)), + reference_file=Path( + str(target_payload.get("reference_file", "")).strip() + ) + .expanduser() + .resolve(), + source_files=tuple( + Path(str(path)).expanduser().resolve() + for path in target_payload.get("source_files", []) + ), + representative=( + None + if target_payload.get("representative") is None + else str(target_payload.get("representative", "")).strip() + ), + source_mode=str( + target_payload.get("source_mode", "average") + ).strip() + or "average", + solvent_mode=str( + target_payload.get("solvent_mode", "input") + ).strip() + or "input", + ) + return _FFTProfileComputationResult( + target=target, + q_values=np.asarray(payload.get("q_values", []), dtype=float), + fft_result=self._deserialize_fft_result( + dict(payload.get("fft_result", {})) + ), + legacy_q_values=( + None + if payload.get("legacy_q_values") is None + else np.asarray( + payload.get("legacy_q_values", []), 
dtype=float + ) + ), + legacy_intensity=( + None + if payload.get("legacy_intensity") is None + else np.asarray( + payload.get("legacy_intensity", []), dtype=float + ) + ), + exact_debye_intensity=( + None + if payload.get("exact_debye_intensity") is None + else np.asarray( + payload.get("exact_debye_intensity", []), + dtype=float, + ) + ), + legacy_elapsed_seconds=( + None + if payload.get("legacy_elapsed_seconds") is None + else float(payload.get("legacy_elapsed_seconds", 0.0)) + ), + debye_elapsed_seconds=( + None + if payload.get("debye_elapsed_seconds") is None + else float(payload.get("debye_elapsed_seconds", 0.0)) + ), + ) + + def _build_component_summary_payload(self) -> dict[str, object]: + return { + "schema_version": 1, + "generated_at": datetime.now().isoformat(timespec="seconds"), + "distribution_id": self._distribution_id, + "project_dir": ( + None if self._project_dir is None else str(self._project_dir) + ), + "input_path": ( + None + if self._loaded_input_path is None + else str(self._loaded_input_path) + ), + "structure_source_mode": self._current_structure_source_mode(), + "prefer_representative_structures": bool( + self._prefer_representative_structures + ), + "representative_solvent_mode": ( + self._current_representative_solvent_mode() + ), + "active_profile_key": self._active_profile_key, + "active_contrast_name": self._active_contrast_name, + "active_contrast_settings": ( + None + if self._active_contrast_settings is None + else self._active_contrast_settings.to_dict() + ), + "active_contrast_estimate": ( + None + if self._active_contrast_estimate is None + else self._active_contrast_estimate.to_dict() + ), + "active_solvent_density_e_per_a3": float( + self._active_solvent_density_e_per_a3 + ), + "q_min": float(self.q_min_spin.value()), + "q_max": float(self.q_max_spin.value()), + "q_step": float(self.q_step_spin.value()), + "fft_settings": { + "spacing_a": float(self.spacing_spin.value()), + "gaussian_sigma_a": 
float(self.sigma_spin.value()), + "minimum_box_length_a": float( + self.min_box_length_spin.value() + ), + "padding_a": float(self.padding_spin.value()), + "exclusion_radius_scale": float( + self.exclusion_radius_scale_spin.value() + ), + "exclusion_radius_padding_a": float( + self.exclusion_radius_padding_spin.value() + ), + }, + "legacy_mesh_settings": ( + self._legacy_1d_mesh_settings_from_controls().to_dict() + ), + "legacy_smearing_settings": ( + self._legacy_1d_smearing_settings_from_controls().to_dict() + ), + "legacy_fourier_settings": ( + self._legacy_1d_fourier_settings_from_controls().to_dict() + ), + "compare_legacy_1d": bool( + self.compare_legacy_checkbox.isChecked() + ), + "compare_exact_debye": bool( + self.compare_exact_debye_checkbox.isChecked() + ), + "show_kernel_corrected": bool( + self.show_kernel_corrected_checkbox.isChecked() + ), + "log_q_axis": bool(self.log_q_checkbox.isChecked()), + "log_intensity_axis": bool( + self.log_intensity_checkbox.isChecked() + ), + "curve_legend_visible": bool(self._curve_legend_visible), + "trace_configuration_signature": ( + self._computed_profile_run_signature + if self._computed_profile_run_signature is not None + else self._current_trace_configuration_signature() + ), + "profile_results": [ + self._serialize_profile_result(result) + for result in self._computed_profile_results.values() + ], + } + + def _restore_contrast_estimate( + self, + payload: dict[str, object] | None, + ) -> ContrastElectronDensityEstimate | None: + if not isinstance(payload, dict): + return None + reference_path = payload.get("reference_structure_file") + reference_box_spans = payload.get("reference_box_spans") + translated_volume_center = payload.get("translated_volume_center") + try: + return ContrastElectronDensityEstimate( + method=str(payload.get("method", "")).strip(), + label=str(payload.get("label", "")).strip(), + volume_a3=float(payload.get("volume_a3", 0.0)), + total_electrons=float(payload.get("total_electrons", 0.0)), + 
electron_density_e_per_a3=float( + payload.get("electron_density_e_per_a3", 0.0) + ), + electron_density_e_per_cm3=float( + payload.get("electron_density_e_per_cm3", 0.0) + ), + atom_count=( + None + if payload.get("atom_count") is None + else int(payload.get("atom_count", 0)) + ), + element_counts={ + str(key): int(value) + for key, value in dict( + payload.get("element_counts", {}) + ).items() + }, + formula=( + None + if payload.get("formula") is None + else str(payload.get("formula", "")).strip() or None + ), + source_density_g_per_cm3=( + None + if payload.get("source_density_g_per_cm3") is None + else float(payload.get("source_density_g_per_cm3", 0.0)) + ), + reference_structure_file=( + None + if reference_path is None + else Path(str(reference_path)).expanduser().resolve() + ), + reference_box_spans=( + None + if reference_box_spans is None + else tuple(float(value) for value in reference_box_spans) + ), + translated_volume_center=( + None + if translated_volume_center is None + else tuple( + float(value) for value in translated_volume_center + ) + ), + ) + except Exception: + return None + + def _restore_workspace_controls_from_payload( + self, + payload: dict[str, object], + ) -> None: + self.q_min_spin.setValue( + float(payload.get("q_min", self.q_min_spin.value())) + ) + self.q_max_spin.setValue( + float(payload.get("q_max", self.q_max_spin.value())) + ) + self.q_step_spin.setValue( + float(payload.get("q_step", self.q_step_spin.value())) + ) + fft_settings = payload.get("fft_settings") + if isinstance(fft_settings, dict): + self.spacing_spin.setValue( + float(fft_settings.get("spacing_a", self.spacing_spin.value())) + ) + self.sigma_spin.setValue( + float( + fft_settings.get( + "gaussian_sigma_a", + self.sigma_spin.value(), + ) + ) + ) + self.min_box_length_spin.setValue( + float( + fft_settings.get( + "minimum_box_length_a", + self.min_box_length_spin.value(), + ) + ) + ) + self.padding_spin.setValue( + float(fft_settings.get("padding_a", 
self.padding_spin.value())) + ) + self.exclusion_radius_scale_spin.setValue( + float( + fft_settings.get( + "exclusion_radius_scale", + self.exclusion_radius_scale_spin.value(), + ) + ) + ) + self.exclusion_radius_padding_spin.setValue( + float( + fft_settings.get( + "exclusion_radius_padding_a", + self.exclusion_radius_padding_spin.value(), + ) + ) + ) + legacy_mesh = payload.get("legacy_mesh_settings") + if isinstance(legacy_mesh, dict): + self.legacy_1d_rstep_spin.setValue( + float( + legacy_mesh.get( + "rstep_a", + self.legacy_1d_rstep_spin.value(), + ) + ) + ) + self.legacy_1d_theta_spin.setValue( + int( + legacy_mesh.get( + "theta_divisions", + self.legacy_1d_theta_spin.value(), + ) + ) + ) + self.legacy_1d_phi_spin.setValue( + int( + legacy_mesh.get( + "phi_divisions", + self.legacy_1d_phi_spin.value(), + ) + ) + ) + self.legacy_1d_rmax_spin.setValue( + float( + legacy_mesh.get( + "rmax_a", + self.legacy_1d_rmax_spin.value(), + ) + ) + ) + legacy_smearing = payload.get("legacy_smearing_settings") + if isinstance(legacy_smearing, dict): + self.legacy_1d_smearing_factor_spin.setValue( + float( + legacy_smearing.get( + "debye_waller_factor_a2", + self.legacy_1d_smearing_factor_spin.value(), + ) + ) + ) + legacy_fourier = payload.get("legacy_fourier_settings") + if isinstance(legacy_fourier, dict): + domain_index = self.legacy_1d_domain_combo.findData( + legacy_fourier.get("domain_mode", "legacy") + ) + if domain_index >= 0: + self.legacy_1d_domain_combo.setCurrentIndex(domain_index) + window_index = self.legacy_1d_window_combo.findData( + legacy_fourier.get("window_function", "none") + ) + if window_index >= 0: + self.legacy_1d_window_combo.setCurrentIndex(window_index) + self.legacy_1d_resampling_points_spin.setValue( + int( + legacy_fourier.get( + "resampling_points", + self.legacy_1d_resampling_points_spin.value(), + ) + ) + ) + self.compare_legacy_checkbox.setChecked( + bool(payload.get("compare_legacy_1d", True)) + ) + 
self.compare_exact_debye_checkbox.setChecked( + bool(payload.get("compare_exact_debye", False)) + ) + self.log_q_checkbox.setChecked(bool(payload.get("log_q_axis", True))) + self.log_intensity_checkbox.setChecked( + bool(payload.get("log_intensity_axis", True)) + ) + self._curve_legend_visible = bool( + payload.get("curve_legend_visible", True) + ) + self._update_curve_legend_button_text() + + contrast_settings_payload = payload.get("active_contrast_settings") + self._active_contrast_settings = None + self._active_contrast_estimate = None + self._active_contrast_name = None + self._active_solvent_density_e_per_a3 = 0.0 + if isinstance(contrast_settings_payload, dict): + restored_settings = ContrastSolventDensitySettings.from_values( + **contrast_settings_payload + ) + method_index = self.solvent_method_combo.findData( + restored_settings.method + ) + if method_index >= 0: + self.solvent_method_combo.setCurrentIndex(method_index) + self.solvent_formula_edit.setText( + restored_settings.solvent_formula or "" + ) + self.solvent_density_spin.setValue( + float(restored_settings.solvent_density_g_per_ml or 0.0) + ) + self.direct_density_spin.setValue( + float( + restored_settings.direct_electron_density_e_per_a3 or 0.0 + ) + ) + self.reference_solvent_file_edit.setText( + "" + if restored_settings.reference_structure_file is None + else str(restored_settings.reference_structure_file) + ) + self._active_contrast_settings = restored_settings + self._active_contrast_name = ( + str(payload.get("active_contrast_name", "")).strip() or None + ) + self._active_contrast_estimate = self._restore_contrast_estimate( + payload.get("active_contrast_estimate") + if isinstance(payload.get("active_contrast_estimate"), dict) + else None + ) + if self._active_contrast_estimate is None: + try: + self._active_contrast_estimate = ( + self._estimate_solvent_density(restored_settings) + ) + except Exception: + self._active_contrast_estimate = None + if self._active_contrast_estimate is not 
None: + self._active_solvent_density_e_per_a3 = float( + self._active_contrast_estimate.electron_density_e_per_a3 + ) + else: + self._active_solvent_density_e_per_a3 = float( + payload.get("active_solvent_density_e_per_a3", 0.0) + ) + self._sync_density_method_controls() + self._sync_kernel_correction_option() + self.show_kernel_corrected_checkbox.setChecked( + bool(payload.get("show_kernel_corrected", False)) + and self.show_kernel_corrected_checkbox.isEnabled() + ) + self._contrast_controls_dirty = False + self._refresh_contrast_display() + + def _delete_workspace_state(self, *, announce: bool) -> None: + workspace_state_path = self._workspace_state_path() + if workspace_state_path is None or not workspace_state_path.is_file(): + return + try: + workspace_state_path.unlink() + except Exception as exc: + if announce: + self._append_status( + "Could not remove the saved 3D FFT workspace state from " + f"{workspace_state_path}: {exc}" + ) + return + if announce: + self._append_status( + f"Removed the saved 3D FFT workspace state from {workspace_state_path}." 
+ ) + + def _sync_workspace_state(self) -> None: + if self._preview_mode or self._restoring_workspace_state: + return + workspace_state_path = self._workspace_state_path() + if workspace_state_path is None: + return + if not self._computed_profile_results: + self._delete_workspace_state(announce=False) + return + if not self._results_match_current_configuration(): + return + payload = self._build_component_summary_payload() + payload["state_kind"] = "workspace_session" + workspace_state_path.parent.mkdir(parents=True, exist_ok=True) + workspace_state_path.write_text( + json.dumps(payload, indent=2) + "\n", + encoding="utf-8", + ) + + def _restore_workspace_state_from_disk(self) -> None: + if self._preview_mode: + return + workspace_state_path = self._workspace_state_path() + if workspace_state_path is None or not workspace_state_path.is_file(): + return + try: + payload = json.loads( + workspace_state_path.read_text(encoding="utf-8") + ) + except Exception as exc: + self._append_status( + "Could not read the saved 3D FFT workspace state from " + f"{workspace_state_path}: {exc}" + ) + return + self._restoring_workspace_state = True + try: + self._prefer_representative_structures = bool( + payload.get( + "prefer_representative_structures", + self._prefer_representative_structures, + ) + ) + self._restore_workspace_controls_from_payload(payload) + source_mode = str(payload.get("structure_source_mode", "")).strip() + if source_mode: + index = self.structure_source_combo.findData(source_mode) + if index >= 0: + self.structure_source_combo.setCurrentIndex(index) + solvent_mode = str( + payload.get("representative_solvent_mode", "") + ).strip() + if solvent_mode: + index = self.representative_solvent_mode_combo.findData( + solvent_mode + ) + if index >= 0: + self.representative_solvent_mode_combo.setCurrentIndex( + index + ) + self._refresh_available_profile_targets() + restored_results: list[_FFTProfileComputationResult] = [] + for entry in payload.get("profile_results", 
[]): + if not isinstance(entry, dict): + continue + try: + restored_results.append( + self._deserialize_profile_result(dict(entry)) + ) + except Exception: + continue + self._computed_profile_results = { + result.target.key: result for result in restored_results + } + self._current_payload = ( + None + if not restored_results + else _FFTComputationPayload( + q_values=np.asarray( + restored_results[0].q_values, + dtype=float, + ), + profile_results=tuple(restored_results), + ) + ) + restored_signature = payload.get("trace_configuration_signature") + self._computed_profile_run_signature = ( + dict(restored_signature) + if restored_results and isinstance(restored_signature, dict) + else None + ) + restored_key = str(payload.get("active_profile_key", "")).strip() + if restored_key: + self._active_profile_key = restored_key + if self._active_profile_key not in self._computed_profile_results: + self._active_profile_key = ( + None + if not restored_results + else restored_results[0].target.key + ) + self._populate_active_profile_combo() + self._set_active_profile_target(self._active_profile_target()) + if ( + restored_results + and self._computed_profile_run_signature is None + ): + self._computed_profile_run_signature = ( + self._current_trace_configuration_signature() + ) + if self._current_payload is not None: + self._update_result_summary(self._current_payload) + self._refresh_plot_controls() + finally: + self._restoring_workspace_state = False + self._update_push_to_model_state() + + def _update_push_to_model_state(self) -> None: + enabled = bool( + (not self._preview_mode) + and self._distribution_root_dir is not None + and self._computed_profile_results + and self._results_match_current_configuration() + ) + self.push_to_model_button.setEnabled(enabled) + + def _ensure_linked_distribution_ready_for_push(self) -> None: + if ( + self._preview_mode + or self._distribution_root_dir is None + or self._project_dir is None + ): + return + metadata_path = 
self._distribution_root_dir / "distribution.json" + prior_weights_path = self._distribution_root_dir / ( + "md_prior_weights_predicted_structures.json" + if self._use_predicted_structure_weights + else "md_prior_weights.json" + ) + if metadata_path.is_file() and prior_weights_path.is_file(): + return + from saxshell.saxs.contrast.settings import ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, + ) + from saxshell.saxs.project_manager.project import ( + SAXSProjectManager, + project_artifact_paths, + ) + + project_manager = SAXSProjectManager() + settings = project_manager.load_project(self._project_dir) + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + settings.use_predicted_structure_weights = bool( + self._use_predicted_structure_weights + ) + settings.use_representative_structures = ( + self._current_structure_source_mode() == "representative" + ) + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + if artifact_paths.root_dir.resolve() != ( + self._distribution_root_dir.expanduser().resolve() + ): + raise ValueError( + "The linked computed distribution no longer matches the " + "active project settings. Reopen the 3D FFT Born workflow " + "from Project Setup and push again." 
+ ) + project_manager.generate_prior_weights(settings) + + def _write_component_trace_file( + self, + result: _FFTProfileComputationResult, + component_dir: Path, + ) -> str: + profile_file = self._component_profile_filename( + result.target.structure_name, + result.target.motif_name, + ) + output_path = component_dir / profile_file + q_values = np.asarray(result.q_values, dtype=float) + intensity = np.asarray(result.fft_result.raw_intensity, dtype=float) + data = np.column_stack( + [ + q_values, + intensity, + np.zeros_like(q_values, dtype=float), + np.zeros_like(q_values, dtype=float), + ] + ) + header = ( + f"# Number of files: {result.target.file_count}\n" + "# Columns: q, S(q)_avg, S(q)_std, S(q)_se\n" + ) + np.savetxt( + output_path, + data, + comments="", + header=header, + fmt=["%.8f", "%.8f", "%.8f", "%.8f"], + ) + return profile_file + + def _write_distribution_component_summary( + self, summary_path: Path + ) -> None: + summary_path.parent.mkdir(parents=True, exist_ok=True) + summary_path.write_text( + json.dumps(self._build_component_summary_payload(), indent=2) + + "\n", + encoding="utf-8", + ) + + @Slot() + def _push_components_to_model(self) -> None: + if self._preview_mode or self._distribution_root_dir is None: + self._show_error( + "Push to Model unavailable", + "This 3D FFT Born window is not linked to a computed distribution.", + ) + return + if not self._computed_profile_results: + self._show_error( + "Push to Model unavailable", + "Compute at least one 3D FFT Born profile set before pushing the outputs into the model.", + ) + return + if not self._results_match_current_configuration(): + self._show_error( + "Recompute 3D FFT Born traces", + "The active 3D FFT source, q range, grid settings, or electron density contrast has changed since these traces were computed. 
Recompute the traces before pushing them into the model.", + ) + return + artifact_targets = self._component_artifact_targets() + summary_path = self._component_summary_path() + if artifact_targets is None or summary_path is None: + self._show_error( + "Push to Model unavailable", + "The linked computed distribution paths are not available.", + ) + return + self._ensure_linked_distribution_ready_for_push() + component_dir, component_map_path = artifact_targets + component_dir.mkdir(parents=True, exist_ok=True) + saxs_map: dict[str, dict[str, str]] = {} + for result in self._computed_profile_results.values(): + profile_file = self._write_component_trace_file( + result, + component_dir, + ) + saxs_map.setdefault(result.target.structure_name, {}) + saxs_map[result.target.structure_name][ + result.target.motif_name + ] = profile_file + component_map_path.write_text( + json.dumps({"saxs_map": saxs_map}, indent=2) + "\n", + encoding="utf-8", + ) + self._write_distribution_component_summary(summary_path) + if self._project_dir is not None: + from saxshell.saxs.contrast.settings import ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, + ) + from saxshell.saxs.project_manager.project import ( + SAXSProjectManager, + project_artifact_paths, + ) + + project_manager = SAXSProjectManager() + settings = project_manager.load_project(self._project_dir) + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + settings.use_predicted_structure_weights = bool( + self._use_predicted_structure_weights + ) + settings.use_representative_structures = ( + self._current_structure_source_mode() == "representative" + ) + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + if artifact_paths.root_dir.resolve() == ( + self._distribution_root_dir.expanduser().resolve() + ): + project_manager._write_distribution_metadata( + settings, + artifact_paths=artifact_paths, + 
built_component_source_mode=( + self._current_structure_source_mode() + ), + ) + project_manager.save_project( + settings, + refresh_registered_paths=False, + ) + else: + raise ValueError( + "The linked computed distribution no longer matches the " + "active project settings. Reopen the 3D FFT Born workflow " + "from Project Setup and push again." + ) + self._append_status( + "Pushed 3D FFT Born component traces into the linked computed " + f"distribution: {component_map_path}" + ) + self.statusBar().showMessage("3D FFT Born components pushed to model") + self.born_components_built.emit( + { + "project_dir": ( + None + if self._project_dir is None + else str(self._project_dir) + ), + "distribution_id": self._distribution_id, + "distribution_dir": ( + None + if self._distribution_root_dir is None + else str(self._distribution_root_dir) + ), + "component_dir": str(component_dir), + "component_map_path": str(component_map_path), + "component_summary_path": str(summary_path), + } + ) + + def _show_error(self, title: str, message: str) -> None: + QMessageBox.critical(self, title, message) + self.statusBar().showMessage(title) + + def _ensure_progress_dialog(self) -> SAXSProgressDialog: + if self._progress_dialog is None: + dialog = SAXSProgressDialog(self) + dialog.setModal(True) + dialog.setWindowModality(Qt.WindowModality.WindowModal) + self._progress_dialog = dialog + return self._progress_dialog + + def _begin_progress_dialog(self, message: str) -> None: + dialog = self._ensure_progress_dialog() + dialog.begin_busy( + str(message).strip() or "Running 3D FFT Born Approximation...", + title="Calculating 3D FFT Born Approximation", + ) + QApplication.processEvents() + + def _set_progress_message(self, message: str) -> None: + dialog = self._ensure_progress_dialog() + stripped = ( + str(message).strip() or "Running 3D FFT Born Approximation..." 
+ ) + dialog.setWindowTitle("Calculating 3D FFT Born Approximation") + dialog.progress_bar.setRange(0, 0) + dialog.progress_bar.setValue(0) + dialog.progress_bar.setFormat("") + dialog.message_label.setText(stripped) + dialog.show() + dialog.raise_() + QApplication.processEvents() + + def _close_progress_dialog(self) -> None: + if self._progress_dialog is not None: + self._progress_dialog.close() + + def _ui_settings(self) -> QSettings: + return QSettings("SAXShell", "SAXS") + + def _load_auto_snap_setting(self) -> bool: + raw_value = self._ui_settings().value(AUTO_SNAP_PANES_KEY, True) + if isinstance(raw_value, bool): + return raw_value + return str(raw_value).strip().lower() not in { + "", + "0", + "false", + "no", + "off", + } + + @Slot(bool) + def _toggle_auto_snap_panes(self, enabled: bool) -> None: + self._set_auto_snap_panes_enabled(enabled, persist=True) + + def set_auto_snap_enabled(self, enabled: bool) -> None: + self._set_auto_snap_panes_enabled(enabled, persist=False) + + def _set_auto_snap_panes_enabled( + self, + enabled: bool, + *, + persist: bool, + ) -> None: + self._auto_snap_panes_enabled = bool(enabled) + if hasattr(self, "auto_snap_panes_action"): + self.auto_snap_panes_action.blockSignals(True) + self.auto_snap_panes_action.setChecked( + self._auto_snap_panes_enabled + ) + self.auto_snap_panes_action.blockSignals(False) + if hasattr(self, "_auto_snap_filter"): + self._auto_snap_filter.set_enabled(self._auto_snap_panes_enabled) + if persist: + self._ui_settings().setValue( + AUTO_SNAP_PANES_KEY, + self._auto_snap_panes_enabled, + ) + self.statusBar().showMessage( + "Auto-snap panes " + + ("enabled" if self._auto_snap_panes_enabled else "disabled") + ) + + def closeEvent( + self, event: QCloseEvent + ) -> None: # noqa: N802 - Qt override + if ( + self._compute_thread is not None + and self._compute_thread.isRunning() + ): + self._close_requested_while_running = True + self._close_progress_dialog() + if self._compute_worker is not None: + 
self._compute_worker.cancel() + self._compute_thread.quit() + self._compute_thread.wait(1000) + if ( + self._compute_thread is not None + and self._compute_thread.isRunning() + ): + self.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose, False) + self.hide() + event.accept() + return + self._close_requested_while_running = False + else: + self._close_requested_while_running = False + self._close_progress_dialog() + if self._progress_dialog is not None: + self._progress_dialog.deleteLater() + self._progress_dialog = None + self._reference_structure_cache.clear() + self._available_profile_targets = {} + self._current_profile_targets = () + super().closeEvent(event) + + +def _forget_open_window(window: FFTBornApproximationMainWindow) -> None: + global _OPEN_WINDOWS + _OPEN_WINDOWS = [ + existing for existing in _OPEN_WINDOWS if existing is not window + ] + + +def launch_3d_fft_born_approximation_ui( + *, + initial_project_dir: str | Path | None = None, + initial_input_path: str | Path | None = None, + initial_output_dir: str | Path | None = None, + initial_project_q_min: float | None = None, + initial_project_q_max: float | None = None, + initial_distribution_id: str | None = None, + initial_distribution_root_dir: str | Path | None = None, + initial_use_predicted_structure_weights: bool = False, + initial_use_representative_structures: bool = False, + preview_mode: bool = True, +) -> FFTBornApproximationMainWindow: + app = QApplication.instance() + if app is None: + prepare_saxshell_application_identity() + app = QApplication(sys.argv) + configure_saxshell_application(app) + window = FFTBornApproximationMainWindow( + initial_project_dir=( + None + if initial_project_dir is None + else Path(initial_project_dir).expanduser().resolve() + ), + initial_input_path=( + None + if initial_input_path is None + else Path(initial_input_path).expanduser().resolve() + ), + initial_output_dir=( + None + if initial_output_dir is None + else 
Path(initial_output_dir).expanduser().resolve() + ), + initial_project_q_min=initial_project_q_min, + initial_project_q_max=initial_project_q_max, + initial_distribution_id=initial_distribution_id, + initial_distribution_root_dir=( + None + if initial_distribution_root_dir is None + else Path(initial_distribution_root_dir).expanduser().resolve() + ), + initial_use_predicted_structure_weights=( + initial_use_predicted_structure_weights + ), + initial_use_representative_structures=( + initial_use_representative_structures + ), + preview_mode=preview_mode, + ) + window.show() + window.raise_() + _OPEN_WINDOWS.append(window) + window.destroyed.connect( + lambda _obj=None, win=window: _forget_open_window(win) + ) + return window + + +__all__ = [ + "FFTBornApproximationMainWindow", + "launch_3d_fft_born_approximation_ui", +] diff --git a/src/saxshell/saxs/electron_density_mapping/ui/__init__.py b/src/saxshell/saxs/electron_density_mapping/ui/__init__.py index 4a54153..265111b 100644 --- a/src/saxshell/saxs/electron_density_mapping/ui/__init__.py +++ b/src/saxshell/saxs/electron_density_mapping/ui/__init__.py @@ -1,9 +1,27 @@ -from .main_window import ( - ElectronDensityMappingMainWindow, - launch_electron_density_mapping_ui, -) +from __future__ import annotations __all__ = [ "ElectronDensityMappingMainWindow", "launch_electron_density_mapping_ui", ] + + +def __getattr__(name: str): + if name in { + "ElectronDensityMappingMainWindow", + "launch_electron_density_mapping_ui", + }: + from .main_window import ( + ElectronDensityMappingMainWindow, + launch_electron_density_mapping_ui, + ) + + return { + "ElectronDensityMappingMainWindow": ( + ElectronDensityMappingMainWindow + ), + "launch_electron_density_mapping_ui": ( + launch_electron_density_mapping_ui + ), + }[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/saxshell/saxs/electron_density_mapping/ui/main_window.py 
b/src/saxshell/saxs/electron_density_mapping/ui/main_window.py index 777f020..59442b5 100644 --- a/src/saxshell/saxs/electron_density_mapping/ui/main_window.py +++ b/src/saxshell/saxs/electron_density_mapping/ui/main_window.py @@ -3,7 +3,7 @@ import csv import json import sys -from dataclasses import dataclass +from dataclasses import dataclass, replace from datetime import datetime from pathlib import Path from typing import Callable @@ -104,6 +104,9 @@ compute_electron_density_scattering_profile, compute_single_atom_debye_scattering_profile_for_input, inspect_structure_input, + legacy_born_average_default_fourier_settings, + legacy_born_average_default_mesh_settings, + legacy_born_average_default_smearing_settings, load_electron_density_structure, prepare_electron_density_fourier_transform, prepare_single_atom_debye_scattering_preview, @@ -120,6 +123,10 @@ ) from saxshell.saxs.ui.progress_dialog import SAXSProgressDialog +# Use MathText for the inverse-Angstrom exponent so the active UI font does +# not need to provide the superscript minus glyph. +Q_A_INVERSE_LABEL = "q (Å$^{-1}$)" + _OPEN_WINDOWS: list["ElectronDensityMappingMainWindow"] = [] _CLUSTER_TRACE_COLORS = ( "#b45309", @@ -129,6 +136,7 @@ "#dc2626", "#0891b2", ) +_SOLVENT_PRESET_NONE = "__none__" AUTO_SNAP_PANES_KEY = "auto_snap_panes_enabled" _FT_COLUMN_STOICHIOMETRY = 0 _FT_COLUMN_STATUS = 1 @@ -220,6 +228,9 @@ class _SavedOutputEntry: profile_result: ElectronDensityProfileResult fourier_settings: ElectronDensityFourierTransformSettings transform_result: ElectronDensityScatteringTransformResult | None = None + debye_scattering_result: ( + ElectronDensityDebyeScatteringAverageResult | None + ) = None @dataclass(slots=True) @@ -240,6 +251,37 @@ class _DebyeComparisonEntry: info_text: str +@dataclass(slots=True, frozen=True) +class _DebyeScatteringWorkItem: + group_key: str | None + label: str + inspection: ElectronDensityInputInspection + q_values: tuple[float, ...] 
+ progress_total: int + + +@dataclass(slots=True, frozen=True) +class _DebyeScatteringCompletedItem: + group_key: str | None + label: str + result: ElectronDensityDebyeScatteringAverageResult + + +@dataclass(slots=True, frozen=True) +class _DebyeScatteringRunPayload: + results: tuple[_DebyeScatteringCompletedItem, ...] = () + failures: tuple[str, ...] = () + + +@dataclass(slots=True, frozen=True) +class _DebyeScatteringRunContext: + mode: str + scope_label: str = "" + selected_keys: tuple[str, ...] = () + skipped_pending: int = 0 + skipped_single_atom: int = 0 + + def _build_cluster_group_states_for_path( path: Path, *, @@ -547,7 +589,9 @@ def _build_ui(self) -> None: self.status_label = QLabel( "All selected entries are overlaid on shared axes. " - "Use the trace table to show/hide or recolour individual traces." + "Use the trace table to show/hide or recolour individual traces. " + "Saved Debye traces, when available, use dashed lines on the " + "right axis in the q-space panel." ) self.status_label.setWordWrap(True) self.status_label.setStyleSheet("color: #475569;") @@ -929,8 +973,10 @@ def _draw_scattering_profile(self) -> None: pairs = self._visible_entries() fig = self._scatter_plot.figure fig.clear() - ax = fig.add_subplot(111) + born_axis = fig.add_subplot(111) + debye_axis = None has_scatter = False + has_debye = False use_log_q = any( bool(e.fourier_settings.log_q_axis) for e, _ in pairs @@ -941,12 +987,15 @@ def _draw_scattering_profile(self) -> None: for e, _ in pairs if e.transform_result is not None ) + born_positive_i = True + debye_positive_i = True + plotted_lines: list[object] = [] for entry, color in pairs: - if entry.transform_result is None: + transform_result = entry.transform_result + if transform_result is None: continue - result = entry.transform_result - q = np.asarray(result.q_values, dtype=float) - intensity = np.asarray(result.intensity, dtype=float) + q = np.asarray(transform_result.q_values, dtype=float) + intensity = 
np.asarray(transform_result.intensity, dtype=float) mask = np.ones_like(q, dtype=bool) if use_log_q: mask &= q > 0.0 @@ -954,41 +1003,94 @@ def _draw_scattering_profile(self) -> None: mask &= intensity > 0.0 if not mask.any(): continue + born_positive_i = born_positive_i and bool(np.all(intensity > 0.0)) label = ( entry.group_label or ElectronDensityMappingMainWindow._saved_output_context_label( entry ) ) - ax.plot( + born_label = ( + f"{label} · Born" + if ElectronDensityMappingMainWindow._saved_output_entry_debye_result( + entry + ) + is not None + else label + ) + (born_line,) = born_axis.plot( q[mask], intensity[mask], color=color, linewidth=1.8, - label=label, + label=born_label, ) + plotted_lines.append(born_line) has_scatter = True + debye_result = ElectronDensityMappingMainWindow._saved_output_entry_debye_result( + entry + ) + if debye_result is None: + continue + debye_q = np.asarray(debye_result.q_values, dtype=float) + debye_i = np.asarray(debye_result.mean_intensity, dtype=float) + debye_mask = np.ones_like(debye_q, dtype=bool) + if use_log_q: + debye_mask &= debye_q > 0.0 + if use_log_i: + debye_mask &= debye_i > 0.0 + if not debye_mask.any(): + continue + if debye_axis is None: + debye_axis = born_axis.twinx() + debye_positive_i = debye_positive_i and bool(np.all(debye_i > 0.0)) + (debye_line,) = debye_axis.plot( + debye_q[debye_mask], + debye_i[debye_mask], + color=color, + linewidth=1.8, + linestyle="--", + label=f"{label} · Debye", + ) + plotted_lines.append(debye_line) + has_debye = True if not has_scatter: - ax.text( + born_axis.text( 0.5, 0.5, "No Fourier transform results available for selected entries.", ha="center", va="center", - transform=ax.transAxes, + transform=born_axis.transAxes, ) - ax.set_axis_off() + born_axis.set_axis_off() self._scatter_plot.canvas.draw_idle() return if use_log_q: - ax.set_xscale("log") - if use_log_i: - ax.set_yscale("log") - ax.set_xlabel("q (Å⁻¹)", labelpad=10.0) - ax.set_ylabel("Intensity (arb. 
units)") - ax.set_title("q-Space Scattering Profile") - ax.grid(True, which="both", alpha=0.28) - ax.legend(loc="upper right", frameon=True, fontsize=8) + born_axis.set_xscale("log") + if debye_axis is not None: + debye_axis.set_xscale("log") + if use_log_i and born_positive_i: + born_axis.set_yscale("log") + if debye_axis is not None and use_log_i and debye_positive_i: + debye_axis.set_yscale("log") + born_axis.set_xlabel(Q_A_INVERSE_LABEL, labelpad=10.0) + born_axis.set_ylabel( + "Born Approximation Intensity (arb. units)" + if has_debye + else "Intensity (arb. units)" + ) + if debye_axis is not None: + debye_axis.set_ylabel("Debye Scattering Intensity (arb. units)") + born_axis.set_title("q-Space Scattering Profile") + born_axis.grid(True, which="both", alpha=0.28) + born_axis.legend( + plotted_lines, + [line.get_label() for line in plotted_lines], + loc="upper right", + frameon=True, + fontsize=8, + ) fig.tight_layout() self._scatter_plot.canvas.draw_idle() @@ -1069,6 +1171,11 @@ def _export_all_csvs(self) -> None: profile_result=entry.profile_result, fourier_preview=preview, transform_result=entry.transform_result, + debye_scattering_result=( + ElectronDensityMappingMainWindow._saved_output_entry_debye_result( + entry + ) + ), ) written_count += 1 except Exception as exc: @@ -1355,7 +1462,7 @@ def _refresh_plot(self) -> None: if self.log_y_checkbox.isChecked() and debye_positive_i else "linear" ) - born_axis.set_xlabel("q (Å⁻¹)") + born_axis.set_xlabel(Q_A_INVERSE_LABEL) born_axis.set_ylabel("Born Approximation Intensity (arb. units)") debye_axis.set_ylabel("Debye Scattering Intensity (arb. 
units)") born_axis.set_title("Born Approximation vs Debye Scattering") @@ -1833,25 +1940,159 @@ def emit_group_progress( self.failed.emit(str(exc)) +class ElectronDensityDebyeScatteringWorker(QObject): + progress = Signal(int, int, str) + finished = Signal(object) + canceled = Signal(object) + failed = Signal(str) + + def __init__( + self, + *, + items: tuple[_DebyeScatteringWorkItem, ...], + scope_label: str | None = None, + ) -> None: + super().__init__() + self._items = tuple(items) + self._scope_label = str(scope_label).strip() + self._cancel_requested = False + + @Slot() + def cancel(self) -> None: + self._cancel_requested = True + + def _progress_prefix( + self, + item_index: int, + item_count: int, + item: _DebyeScatteringWorkItem, + ) -> str: + label = str(item.label).strip() + if item_count <= 1 and item.group_key is None: + return "Debye scattering" + prefix = ( + f"Debye {item_index}/{item_count}" + if item_count > 1 + else "Debye scattering" + ) + if label: + prefix += f" [{label}]" + if item_count > 1 and self._scope_label: + prefix += f" in {self._scope_label}" + return prefix + + @Slot() + def run(self) -> None: + if not self._items: + self.failed.emit( + "Debye scattering calculation started without any targets." + ) + return + results: list[_DebyeScatteringCompletedItem] = [] + failures: list[str] = [] + overall_total = max( + sum(max(int(item.progress_total), 1) for item in self._items), + 1, + ) + progress_offset = 0 + try: + for item_index, item in enumerate(self._items, start=1): + if self._cancel_requested: + raise ElectronDensityCalculationCanceled( + "Debye scattering calculation was stopped by the user." 
+ ) + item_total = max(int(item.progress_total), 1) + prefix = self._progress_prefix( + item_index, + len(self._items), + item, + ) + + def emit_item_progress( + current: int, + total: int, + message: str, + *, + _offset: int = progress_offset, + _overall_total: int = overall_total, + _prefix: str = prefix, + ) -> None: + bounded_total = max(int(total), 1) + bounded_current = min( + max(int(current), 0), + bounded_total, + ) + text = str(message).strip() + self.progress.emit( + _offset + bounded_current, + _overall_total, + ( + f"{_prefix}: {text}" + if text + else f"{_prefix}: running..." + ), + ) + + try: + result = ( + compute_average_debye_scattering_profile_for_input( + item.inspection, + q_values=np.asarray(item.q_values, dtype=float), + progress_callback=emit_item_progress, + cancel_callback=lambda: self._cancel_requested, + ) + ) + except ElectronDensityCalculationCanceled: + raise + except Exception as exc: + failures.append(f"{item.label}: {exc}") + progress_offset += item_total + self.progress.emit( + progress_offset, + overall_total, + f"{prefix} failed.", + ) + continue + if self._cancel_requested: + raise ElectronDensityCalculationCanceled( + "Debye scattering calculation was stopped by the user." 
+ ) + results.append( + _DebyeScatteringCompletedItem( + group_key=item.group_key, + label=item.label, + result=result, + ) + ) + progress_offset += item_total + self.finished.emit( + _DebyeScatteringRunPayload( + results=tuple(results), + failures=tuple(failures), + ) + ) + except ElectronDensityCalculationCanceled: + self.canceled.emit( + _DebyeScatteringRunPayload( + results=tuple(results), + failures=tuple(failures), + ) + ) + except Exception as exc: + self.failed.emit(str(exc)) + + class ElectronDensityMappingMainWindow(QMainWindow): """Interactive supporting tool for radial electron-density inspection.""" born_components_built = Signal(object) cancel_calculation_requested = Signal() + cancel_debye_scattering_requested = Signal() @staticmethod def _default_fourier_settings() -> ElectronDensityFourierTransformSettings: - return ElectronDensityFourierTransformSettings( - r_min=-1.0, - r_max=1.0, - domain_mode="mirrored", - window_function="hanning", - q_min=0.02, - q_max=1.2, - q_step=0.01, - resampling_points=2048, - ) + return legacy_born_average_default_fourier_settings() def __init__( self, @@ -1912,9 +2153,13 @@ def __init__( ) self._current_group_run_manual = False self._auto_snap_panes_enabled = self._load_auto_snap_panes_setting() - self._active_mesh_settings = ElectronDensityMeshSettings() + self._active_mesh_settings = ( + legacy_born_average_default_mesh_settings() + ) self._active_mesh_geometry: ElectronDensityMeshGeometry | None = None - self._active_smearing_settings = ElectronDensitySmearingSettings() + self._active_smearing_settings = ( + legacy_born_average_default_smearing_settings() + ) self._active_fourier_settings = self._default_fourier_settings() self._solvent_presets: dict[str, ContrastSolventPreset] = {} self._active_contrast_settings: ( @@ -1941,8 +2186,18 @@ def __init__( self._workspace_load_worker: ( ElectronDensityWorkspaceLoadWorker | None ) = None + self._debye_scattering_thread: QThread | None = None + 
self._debye_scattering_worker: ( + ElectronDensityDebyeScatteringWorker | None + ) = None self._workspace_load_progress_dialog: SAXSProgressDialog | None = None self._batch_operation_progress_dialog: SAXSProgressDialog | None = None + self._debye_scattering_progress_dialog: SAXSProgressDialog | None = ( + None + ) + self._active_debye_scattering_context: ( + _DebyeScatteringRunContext | None + ) = None self._restore_distribution_state_after_workspace_load = False self._saved_output_entries: list[_SavedOutputEntry] = [] self._output_history_compare_dialog: QDialog | None = None @@ -2499,6 +2754,16 @@ def _close_batch_operation_progress_dialog(self) -> None: if self._batch_operation_progress_dialog is not None: self._batch_operation_progress_dialog.close() + def _ensure_debye_scattering_progress_dialog( + self, + ) -> SAXSProgressDialog: + if self._debye_scattering_progress_dialog is None: + dialog = SAXSProgressDialog(self) + dialog.setModal(True) + dialog.setWindowModality(Qt.WindowModality.WindowModal) + self._debye_scattering_progress_dialog = dialog + return self._debye_scattering_progress_dialog + @Slot(int, int, str) def _on_workspace_load_progress( self, @@ -3101,6 +3366,8 @@ def _reset_debye_scattering_progress(self) -> None: self.debye_scattering_progress_bar.setValue(0) self.debye_scattering_progress_bar.setFormat("%v / %m steps") self.debye_scattering_progress_bar.setHidden(True) + if self._debye_scattering_progress_dialog is not None: + self._debye_scattering_progress_dialog.close() def _begin_debye_scattering_progress( self, @@ -3119,6 +3386,13 @@ def _begin_debye_scattering_progress( self.debye_scattering_progress_bar.setValue(0) self.debye_scattering_progress_bar.setFormat("%v / %m steps") self.debye_scattering_status_label.setText(stripped) + dialog = self._ensure_debye_scattering_progress_dialog() + dialog.begin( + bounded_total, + stripped, + unit_label="steps", + title="Computing Debye Scattering", + ) self.statusBar().showMessage(stripped) 
QApplication.processEvents() @@ -3137,6 +3411,13 @@ def _update_debye_scattering_progress( self.debye_scattering_progress_bar.setRange(0, bounded_total) self.debye_scattering_progress_bar.setValue(bounded_current) self.debye_scattering_progress_bar.setFormat("%v / %m steps") + if self._debye_scattering_progress_dialog is not None: + self._debye_scattering_progress_dialog.update_progress( + bounded_current, + bounded_total, + stripped, + unit_label="steps", + ) if stripped: self.debye_scattering_status_label.setText(stripped) self.statusBar().showMessage(stripped) @@ -4691,6 +4972,8 @@ def _format_saved_output_timestamp(timestamp: str) -> str: @staticmethod def _saved_output_entry_kind_label(entry_kind: str) -> str: normalized = str(entry_kind).strip().lower() + if normalized == "debye_scattering": + return "Debye Scattering" if normalized == "fourier_transform": return "Fourier Transform" if normalized == "smearing": @@ -4738,6 +5021,14 @@ def _saved_output_entry_summary_text(entry: _SavedOutputEntry) -> str: if entry.transform_result is not None else "Preview only" ) + debye_text = "" + if entry.debye_scattering_result is not None: + debye_text = ( + " Debye: " + f"{entry.debye_scattering_result.source_structure_count} structure" + f"{'' if entry.debye_scattering_result.source_structure_count == 1 else 's'} " + "on the stored Born q-grid." + ) return ( f"Saved {ElectronDensityMappingMainWindow._format_saved_output_timestamp(entry.created_at)} " f"in {'Preview' if entry.preview_mode else 'Computed Distribution'} mode. " @@ -4747,6 +5038,7 @@ def _saved_output_entry_summary_text(entry: _SavedOutputEntry) -> str: f"Solvent: {solvent_text}. " f"Fourier: {fourier_text} with window={entry.fourier_settings.window_function}, " f"r={entry.fourier_settings.r_min:.3f}–{entry.fourier_settings.r_max:.3f} Å." 
+ + debye_text ) @staticmethod @@ -4783,8 +5075,9 @@ def _update_output_history_summary(self) -> None: ) self.output_history_summary_label.setText( "Density calculations, solvent-subtracted outputs, Fourier " - "evaluations, and optional smearing snapshots will be captured " - "here for reload and comparison." + persistence_text + "evaluations, Debye scattering comparisons, and optional " + "smearing snapshots will be captured here for reload and " + "comparison." + persistence_text ) return self.output_history_summary_label.setText( @@ -4866,6 +5159,11 @@ def _populate_output_history_table(self) -> None: if entry.transform_result is not None else " (preview)" ) + + ( + "" + if entry.debye_scattering_result is None + else " · Debye saved" + ) ), ] for column_index, value in enumerate(values): @@ -5431,14 +5729,13 @@ def _build_fourier_transform_group(self) -> QWidget: intro = QLabel( "Prepare a spherical Born-approximation transform of the smeared " - "electron-density profile into q-space. The preview panel shows " - "the mirrored real-space source used by the transform. Mirrored " - "mode is the default: it reflects the profile about r = 0 and " - "evaluates the windowed transform over -rmax to rmax. Toggle " - "legacy mode to restore the historical rmin to rmax behavior. In " - "Apply to All mode, the table becomes the editable per-stoichiometry " - "Fourier settings view: q settings stay shared, while each row " - "keeps its own r range." + "electron-density profile into q-space. The validated default uses " + "a 0 to rmax transform with no window, matching the current Born " + "versus Debye backend comparison settings. Clear the legacy toggle " + "to mirror the profile about r = 0 for an EXAFS-style transform. " + "In Apply to All mode, the table becomes the editable " + "per-stoichiometry Fourier settings view: q settings stay shared, " + "while each row keeps its own r range." 
) intro.setWordWrap(True) outer.addWidget(intro) @@ -5722,8 +6019,9 @@ def _build_output_history_group(self) -> QWidget: group = QGroupBox("Saved Output Sets") layout = QVBoxLayout(group) self.output_history_summary_label = QLabel( - "Density calculations, solvent-subtracted outputs, and Fourier " - "evaluations will be captured here for reload and comparison." + "Density calculations, solvent-subtracted outputs, Fourier " + "evaluations, and Debye scattering comparisons will be captured " + "here for reload and comparison." ) self.output_history_summary_label.setWordWrap(True) layout.addWidget(self.output_history_summary_label) @@ -6411,22 +6709,47 @@ def _refresh_profile_plots(self) -> None: self.residual_section.expand() def _selected_solvent_preset_name(self) -> str | None: - return self.solvent_preset_combo.currentData() + selected_data = self.solvent_preset_combo.currentData() + if selected_data is None: + return None + preset_name = str(selected_data).strip() + if preset_name == _SOLVENT_PRESET_NONE: + return None + return preset_name or None + + def _selected_solvent_preset_token(self) -> str | None: + selected_data = self.solvent_preset_combo.currentData() + if selected_data is None: + return None + preset_name = str(selected_data).strip() + return preset_name or None + + def _clear_solvent_contrast_requested_from_controls(self) -> bool: + method = str(self.solvent_method_combo.currentData() or "").strip() + return ( + method == CONTRAST_SOLVENT_METHOD_NEAT + and self._selected_solvent_preset_token() == _SOLVENT_PRESET_NONE + ) def _reload_solvent_presets( self, *, selected_name: str | None = None, ) -> None: - previous_name = selected_name or self._selected_solvent_preset_name() + previous_name = ( + selected_name + if selected_name is not None + else self._selected_solvent_preset_token() + ) self._solvent_presets = load_solvent_presets() self.solvent_preset_combo.blockSignals(True) self.solvent_preset_combo.clear() 
self.solvent_preset_combo.addItem("Custom entry", None) - selected_index = 0 + self.solvent_preset_combo.addItem("None", _SOLVENT_PRESET_NONE) + selected_index = 1 if previous_name == _SOLVENT_PRESET_NONE else 0 for index, preset_name in enumerate( ordered_solvent_preset_names(self._solvent_presets), - start=1, + start=2, ): preset = self._solvent_presets[preset_name] label = ( @@ -6441,14 +6764,20 @@ def _reload_solvent_presets( @Slot() def _load_selected_solvent_preset(self) -> None: + if self._selected_solvent_preset_token() == _SOLVENT_PRESET_NONE: + self.delete_custom_solvent_button.setEnabled(False) + self._sync_density_method_controls() + return preset_name = self._selected_solvent_preset_name() preset = self._solvent_presets.get(preset_name or "") if preset is None: self.delete_custom_solvent_button.setEnabled(False) + self._sync_density_method_controls() return self.solvent_formula_edit.setText(preset.formula) self.solvent_density_spin.setValue(preset.density_g_per_ml) self.delete_custom_solvent_button.setEnabled(not preset.builtin) + self._sync_density_method_controls() @Slot() def _save_current_solvent_preset(self) -> None: @@ -6559,76 +6888,275 @@ def _sync_density_method_controls(self) -> None: "Use 0.0 e-/ ų to model vacuum." ) else: - self.solvent_method_hint_label.setText( - "Quick estimate mode uses the selected solvent stoichiometry and " - "density. Built-in presets include Water, Vacuum, DMF, and DMSO." - ) + if self._clear_solvent_contrast_requested_from_controls(): + self.solvent_method_hint_label.setText( + "The None solvent option clears the active solvent " + "subtraction from the current density profile or selected " + "stoichiometries." + ) + else: + self.solvent_method_hint_label.setText( + "Quick estimate mode uses the selected solvent " + "stoichiometry and density. Built-in presets include " + "Water, Vacuum, DMF, and DMSO." 
+ ) - @Slot() - def _choose_reference_solvent_file(self) -> None: - start_dir = ( - str( - Path(self.reference_solvent_file_edit.text()) - .expanduser() - .resolve() - .parent - ) - if self.reference_solvent_file_edit.text().strip() - else str(self._project_dir or Path.cwd()) + @staticmethod + def _clear_solvent_contrast_from_profile_result( + result: ElectronDensityProfileResult, + ) -> ElectronDensityProfileResult: + if result.solvent_contrast is None: + return result + return replace(result, solvent_contrast=None) + + def _clear_active_solvent_contrast(self) -> None: + had_pending_configuration = ( + self._active_contrast_settings is not None + or self._active_contrast_name is not None ) - selected_path, _selected_filter = QFileDialog.getOpenFileName( - self, - "Choose Reference Solvent Structure", - start_dir, - "Structure files (*.pdb *.xyz);;All files (*)", + had_applied_contrast = ( + self._profile_result is not None + and self._profile_result.solvent_contrast is not None ) - if not selected_path: + self._active_contrast_settings = None + self._active_contrast_name = None + if not had_applied_contrast: + self._refresh_contrast_display() + self._append_status( + "Cleared the configured solvent subtraction. Future density " + "runs will stay unsubtracted." + if had_pending_configuration + else "Solvent subtraction was already cleared." 
+ ) + self.statusBar().showMessage("Cleared solvent subtraction") + self._sync_workspace_state() return - self.reference_solvent_file_edit.setText( - str(Path(selected_path).expanduser().resolve()) + self._profile_result = ( + self._clear_solvent_contrast_from_profile_result( + self._profile_result + ) + ) + self._debye_scattering_result = None + self._close_debye_scattering_compare_dialog() + self._sync_fourier_controls_to_domain(reset_bounds=False) + self._refresh_profile_plots() + self._refresh_contrast_display() + self._refresh_fourier_preview_from_controls(clear_transform=True) + self._append_status( + "Removed solvent subtraction from the active density profile." ) + self.statusBar().showMessage("Cleared solvent subtraction") - def _contrast_settings_from_controls( + def _clear_solvent_contrast_from_target_clusters( self, - ) -> ContrastSolventDensitySettings: - method = str(self.solvent_method_combo.currentData() or "").strip() - if method == CONTRAST_SOLVENT_METHOD_REFERENCE: - reference_path = self.reference_solvent_file_edit.text().strip() - if not reference_path: - raise ValueError( - "Choose a reference solvent XYZ or PDB file before computing the solvent electron density." - ) - return ContrastSolventDensitySettings.from_values( - method=CONTRAST_SOLVENT_METHOD_REFERENCE, - reference_structure_file=reference_path, - ) - if method == CONTRAST_SOLVENT_METHOD_DIRECT: - direct_density = float(self.direct_density_spin.value()) - if direct_density < 0.0: - raise ValueError( - "Enter a non-negative direct solvent electron density before computing the solvent contrast." 
- ) - return ContrastSolventDensitySettings.from_values( - method=CONTRAST_SOLVENT_METHOD_DIRECT, - direct_electron_density_e_per_a3=direct_density, + *, + apply_to_all: bool, + ) -> None: + if not self._cluster_group_states: + self._show_error( + "No Stoichiometry Table", + "Load a cluster-folder input before removing solvent " + "subtraction across stoichiometries.", ) - formula = self.solvent_formula_edit.text().strip() - if not formula: - raise ValueError( - "Enter a solvent stoichiometry formula before computing the solvent electron density." + return + ( + target_states, + scope_label, + selected_keys, + ) = self._batch_target_cluster_group_states(apply_to_all=apply_to_all) + if not target_states: + self._show_error( + "No Stoichiometries Selected", + "Select at least one stoichiometry row before removing " + "solvent subtraction.", ) - return ContrastSolventDensitySettings.from_values( - method=CONTRAST_SOLVENT_METHOD_NEAT, - solvent_formula=formula, - solvent_density_g_per_ml=self.solvent_density_spin.value(), - ) - - def _contrast_display_name_from_controls(self) -> str: - method = str(self.solvent_method_combo.currentData() or "").strip() - if method == CONTRAST_SOLVENT_METHOD_REFERENCE: - reference_path = self.reference_solvent_file_edit.text().strip() - return ( - Path(reference_path).stem + return + self._active_contrast_settings = None + self._active_contrast_name = None + updated_count = 0 + already_clear_count = 0 + skipped_pending = 0 + skipped_debye = 0 + total_targets = len(target_states) + self._begin_batch_operation_progress( + total=total_targets, + message=( + "Preparing batch solvent-subtraction removal across " + f"{scope_label}..." 
+ ), + title="Removing Solvent Subtraction", + ) + try: + for index, state in enumerate(target_states, start=1): + self._update_batch_operation_progress( + index - 1, + total_targets, + "Removing solvent subtraction from " + f"{state.display_name} ({index}/{total_targets}).", + ) + if state.single_atom_only: + skipped_debye += 1 + self._update_batch_operation_progress( + index, + total_targets, + "Skipped Debye-only stoichiometry " + f"{state.display_name} ({index}/{total_targets}).", + ) + continue + if state.profile_result is None: + skipped_pending += 1 + self._update_batch_operation_progress( + index, + total_targets, + "Skipped pending stoichiometry " + f"{state.display_name} ({index}/{total_targets}).", + ) + continue + if state.profile_result.solvent_contrast is None: + already_clear_count += 1 + self._update_batch_operation_progress( + index, + total_targets, + "No solvent subtraction was active for " + f"{state.display_name} ({index}/{total_targets}).", + ) + continue + preserved_fourier_settings = ( + self._cluster_state_fourier_settings(state) + ) + state.profile_result = ( + self._clear_solvent_contrast_from_profile_result( + state.profile_result + ) + ) + state.transform_result = None + state.debye_scattering_result = None + self._sync_cluster_state_solvent_metadata(state) + self._set_cluster_state_fourier_settings( + state, + preserved_fourier_settings, + prefer_solvent_cutoff=False, + ) + updated_count += 1 + self._update_batch_operation_progress( + index, + total_targets, + "Removed solvent subtraction from " + f"{state.display_name} ({index}/{total_targets}).", + ) + self._update_batch_operation_progress( + total_targets, + total_targets, + "Refreshing stoichiometry views...", + ) + if updated_count > 0: + self._close_debye_scattering_compare_dialog() + self._refresh_cluster_views_after_batch_update( + target_states, + selected_keys=selected_keys, + ) + self._update_batch_operation_progress( + total_targets, + total_targets, + "Batch 
solvent-subtraction removal complete.", + ) + finally: + self._close_batch_operation_progress_dialog() + summary_parts = [ + f"removed solvent subtraction from {updated_count} stoichiometr" + f"{'y' if updated_count == 1 else 'ies'}" + ] + if already_clear_count > 0: + summary_parts.append( + f"left {already_clear_count} already-clear row" + f"{'' if already_clear_count == 1 else 's'} unchanged" + ) + if skipped_pending > 0: + summary_parts.append( + f"skipped {skipped_pending} pending density row" + f"{'' if skipped_pending == 1 else 's'}" + ) + if skipped_debye > 0: + summary_parts.append( + f"skipped {skipped_debye} Debye-only row" + f"{'' if skipped_debye == 1 else 's'}" + ) + self._append_status( + "Batch solvent update: " + + "; ".join(summary_parts) + + f" across {scope_label}. Future density runs will stay " + "unsubtracted until a new solvent is applied." + ) + self.statusBar().showMessage( + "Cleared solvent subtraction across batch" + ) + self._sync_workspace_state() + + @Slot() + def _choose_reference_solvent_file(self) -> None: + start_dir = ( + str( + Path(self.reference_solvent_file_edit.text()) + .expanduser() + .resolve() + .parent + ) + if self.reference_solvent_file_edit.text().strip() + else str(self._project_dir or Path.cwd()) + ) + selected_path, _selected_filter = QFileDialog.getOpenFileName( + self, + "Choose Reference Solvent Structure", + start_dir, + "Structure files (*.pdb *.xyz);;All files (*)", + ) + if not selected_path: + return + self.reference_solvent_file_edit.setText( + str(Path(selected_path).expanduser().resolve()) + ) + + def _contrast_settings_from_controls( + self, + ) -> ContrastSolventDensitySettings: + method = str(self.solvent_method_combo.currentData() or "").strip() + if method == CONTRAST_SOLVENT_METHOD_REFERENCE: + reference_path = self.reference_solvent_file_edit.text().strip() + if not reference_path: + raise ValueError( + "Choose a reference solvent XYZ or PDB file before computing the solvent electron density." 
+ ) + return ContrastSolventDensitySettings.from_values( + method=CONTRAST_SOLVENT_METHOD_REFERENCE, + reference_structure_file=reference_path, + ) + if method == CONTRAST_SOLVENT_METHOD_DIRECT: + direct_density = float(self.direct_density_spin.value()) + if direct_density < 0.0: + raise ValueError( + "Enter a non-negative direct solvent electron density before computing the solvent contrast." + ) + return ContrastSolventDensitySettings.from_values( + method=CONTRAST_SOLVENT_METHOD_DIRECT, + direct_electron_density_e_per_a3=direct_density, + ) + formula = self.solvent_formula_edit.text().strip() + if not formula: + raise ValueError( + "Enter a solvent stoichiometry formula before computing the solvent electron density." + ) + return ContrastSolventDensitySettings.from_values( + method=CONTRAST_SOLVENT_METHOD_NEAT, + solvent_formula=formula, + solvent_density_g_per_ml=self.solvent_density_spin.value(), + ) + + def _contrast_display_name_from_controls(self) -> str: + method = str(self.solvent_method_combo.currentData() or "").strip() + if method == CONTRAST_SOLVENT_METHOD_REFERENCE: + reference_path = self.reference_solvent_file_edit.text().strip() + return ( + Path(reference_path).stem if reference_path else "Reference solvent" ) @@ -6805,6 +7333,9 @@ def _apply_active_contrast_to_profile( @Slot() def _compute_solvent_contrast(self) -> None: + if self._clear_solvent_contrast_requested_from_controls(): + self._clear_active_solvent_contrast() + return try: self._active_contrast_settings = ( self._contrast_settings_from_controls() @@ -6838,6 +7369,11 @@ def _apply_solvent_contrast_to_target_clusters( *, apply_to_all: bool, ) -> None: + if self._clear_solvent_contrast_requested_from_controls(): + self._clear_solvent_contrast_from_target_clusters( + apply_to_all=apply_to_all + ) + return if not self._cluster_group_states: self._show_error( "No Stoichiometry Table", @@ -8069,9 +8605,19 @@ def _sync_controls_to_structure( if self._structure is None: return if 
sync_mesh_rmax: - self.rmax_spin.setValue(max(float(self._structure.rmax), 0.01)) + default_mesh_settings = legacy_born_average_default_mesh_settings( + self._structure + ) + self.rmax_spin.setValue(float(default_mesh_settings.rmax)) + if ( + self._active_mesh_geometry is None + and self._manual_mesh_lock_settings is None + ): + self._active_mesh_settings = default_mesh_settings self._sync_reference_element_controls() self._refresh_center_display() + self._refresh_active_mesh_display() + self._refresh_mesh_notice() self._refresh_structure_summary() def _refresh_active_mesh_display(self) -> None: @@ -9027,6 +9573,13 @@ def _serialize_saved_output_entry( entry.transform_result ) ), + "debye_scattering_result": ( + None + if entry.debye_scattering_result is None + else ElectronDensityMappingMainWindow._serialize_debye_scattering_result( + entry.debye_scattering_result + ) + ), } @staticmethod @@ -9056,6 +9609,12 @@ def _deserialize_saved_output_entry( single_atom_only=False, ) ) + debye_payload = payload.get("debye_scattering_result") + debye_scattering_result = None + if isinstance(debye_payload, dict): + debye_scattering_result = ElectronDensityMappingMainWindow._deserialize_debye_scattering_result( + debye_payload + ) input_path = payload.get("input_path") return _SavedOutputEntry( entry_id=str(payload.get("entry_id") or ""), @@ -9084,6 +9643,7 @@ def _deserialize_saved_output_entry( profile_result=profile_result, fourier_settings=fourier_settings, transform_result=transform_result, + debye_scattering_result=debye_scattering_result, ) def _current_input_path(self) -> Path | None: @@ -9224,6 +9784,18 @@ def _preview_for_saved_output_entry( except Exception: return None + @staticmethod + def _saved_output_entry_debye_result( + entry: _SavedOutputEntry, + ) -> ElectronDensityDebyeScatteringAverageResult | None: + debye_result = entry.debye_scattering_result + if ElectronDensityMappingMainWindow._debye_result_matches_transform( + entry.transform_result, + debye_result, 
+ ): + return debye_result + return None + @staticmethod def _constrain_fourier_settings_for_result( result: ElectronDensityProfileResult, @@ -9314,9 +9886,13 @@ def _capture_saved_output_entry( transform_result: ( ElectronDensityScatteringTransformResult | None ) = None, + debye_scattering_result: ( + ElectronDensityDebyeScatteringAverageResult | None + ) = None, ) -> None: if self._restoring_saved_output_history: return + normalized_entry_kind = str(entry_kind).strip() or "density" result = profile_result if result is None and group_state is not None: result = group_state.profile_result @@ -9329,6 +9905,15 @@ def _capture_saved_output_entry( current_transform = group_state.transform_result elif current_transform is None: current_transform = self._fourier_result + current_debye_result = debye_scattering_result + if ( + current_debye_result is None + and normalized_entry_kind == "debye_scattering" + ): + if group_state is not None: + current_debye_result = group_state.debye_scattering_result + else: + current_debye_result = self._debye_scattering_result if current_transform is not None: fourier_settings = current_transform.preview.settings else: @@ -9348,7 +9933,7 @@ def _capture_saved_output_entry( snapshot = _SavedOutputEntry( entry_id=datetime.now().strftime("%Y%m%dT%H%M%S%f"), created_at=datetime.now().isoformat(timespec="seconds"), - entry_kind=str(entry_kind).strip() or "density", + entry_kind=normalized_entry_kind, input_path=self._current_input_path(), output_basename=self._output_basename(), preview_mode=self._preview_mode, @@ -9360,6 +9945,7 @@ def _capture_saved_output_entry( profile_result=result, fourier_settings=fourier_settings, transform_result=current_transform, + debye_scattering_result=current_debye_result, ) self._saved_output_entries.append(snapshot) self._populate_output_history_table() @@ -9444,7 +10030,9 @@ def _restore_saved_output_entry( if target_state is not None: target_state.profile_result = entry.profile_result 
target_state.transform_result = entry.transform_result - target_state.debye_scattering_result = None + target_state.debye_scattering_result = ( + entry.debye_scattering_result + ) self._sync_cluster_state_solvent_metadata(target_state) self._populate_cluster_group_table() if row_index >= 0: @@ -9457,7 +10045,7 @@ def _restore_saved_output_entry( else: self._profile_result = entry.profile_result self._fourier_result = entry.transform_result - self._debye_scattering_result = None + self._debye_scattering_result = entry.debye_scattering_result self._structure = entry.profile_result.structure self._active_mesh_settings = ( entry.profile_result.mesh_geometry.settings @@ -9889,37 +10477,272 @@ def _calculate_debye_scattering_action(self) -> None: progress_total = self._debye_scattering_progress_total_for_inspection( self._inspection ) + self._start_debye_scattering_run( + items=( + _DebyeScatteringWorkItem( + group_key=None, + label="Active input", + inspection=self._inspection, + q_values=tuple( + float(value) + for value in np.asarray( + self._fourier_result.q_values, + dtype=float, + ) + ), + progress_total=progress_total, + ), + ), + context=_DebyeScatteringRunContext(mode="single"), + initial_message="Preparing Debye scattering average calculation...", + ) + + def _start_debye_scattering_run( + self, + *, + items: tuple[_DebyeScatteringWorkItem, ...], + context: _DebyeScatteringRunContext, + initial_message: str, + ) -> None: + if ( + self._debye_scattering_thread is not None + and self._debye_scattering_thread.isRunning() + ): + return + progress_total = max( + sum(max(int(item.progress_total), 1) for item in items), + 1, + ) + self._active_debye_scattering_context = context self._set_calculation_running(True) - self.stop_calculation_button.setEnabled(False) self._begin_debye_scattering_progress( total=progress_total, - message="Preparing Debye scattering average calculation...", + message=initial_message, ) - result: ElectronDensityDebyeScatteringAverageResult | 
None = None - error_message: str | None = None - try: - result = compute_average_debye_scattering_profile_for_input( - self._inspection, - q_values=np.asarray( - self._fourier_result.q_values, dtype=float - ), - progress_callback=self._update_debye_scattering_progress, + self._debye_scattering_thread = QThread(self) + self._debye_scattering_worker = ElectronDensityDebyeScatteringWorker( + items=items, + scope_label=context.scope_label or None, + ) + self._debye_scattering_worker.moveToThread( + self._debye_scattering_thread + ) + self._debye_scattering_thread.started.connect( + self._debye_scattering_worker.run + ) + self.cancel_debye_scattering_requested.connect( + self._debye_scattering_worker.cancel + ) + self._debye_scattering_worker.progress.connect( + self._on_debye_scattering_progress + ) + self._debye_scattering_worker.finished.connect( + self._on_debye_scattering_finished + ) + self._debye_scattering_worker.canceled.connect( + self._on_debye_scattering_canceled + ) + self._debye_scattering_worker.failed.connect( + self._on_debye_scattering_failed + ) + self._debye_scattering_worker.finished.connect( + self._debye_scattering_thread.quit + ) + self._debye_scattering_worker.canceled.connect( + self._debye_scattering_thread.quit + ) + self._debye_scattering_worker.failed.connect( + self._debye_scattering_thread.quit + ) + self._debye_scattering_thread.finished.connect( + self._debye_scattering_worker.deleteLater + ) + self._debye_scattering_thread.finished.connect( + self._debye_scattering_thread.deleteLater + ) + self._debye_scattering_thread.finished.connect( + self._clear_debye_scattering_handles + ) + self._debye_scattering_thread.start(QThread.Priority.LowPriority) + + @Slot() + def _clear_debye_scattering_handles(self) -> None: + self._debye_scattering_worker = None + self._debye_scattering_thread = None + self.stop_calculation_button.setEnabled(False) + self._refresh_run_action_state() + + @Slot(int, int, str) + def _on_debye_scattering_progress( + 
self, + current: int, + total: int, + message: str, + ) -> None: + self._update_debye_scattering_progress(current, total, message) + + @staticmethod + def _debye_scattering_payload_results( + payload: object, + ) -> tuple[_DebyeScatteringCompletedItem, ...]: + if isinstance(payload, _DebyeScatteringRunPayload): + return payload.results + return () + + @staticmethod + def _debye_scattering_payload_failures( + payload: object, + ) -> tuple[str, ...]: + if isinstance(payload, _DebyeScatteringRunPayload): + return payload.failures + return () + + def _apply_batch_debye_scattering_results( + self, + results: tuple[_DebyeScatteringCompletedItem, ...], + ) -> list[_ClusterDensityGroupState]: + state_by_key = { + state.key: state for state in self._cluster_group_states + } + updated_states: list[_ClusterDensityGroupState] = [] + for item in results: + if item.group_key is None: + continue + state = state_by_key.get(item.group_key) + if state is None: + continue + state.debye_scattering_result = item.result + if state.key == self._selected_cluster_group_key: + self._debye_scattering_result = item.result + self._capture_saved_output_entry( + "debye_scattering", + group_state=state, + profile_result=state.profile_result, + transform_result=state.transform_result, + debye_scattering_result=item.result, + ) + updated_states.append(state) + return updated_states + + def _finish_batch_debye_scattering_run( + self, + *, + results: tuple[_DebyeScatteringCompletedItem, ...], + failures: tuple[str, ...], + context: _DebyeScatteringRunContext, + canceled: bool, + ) -> None: + updated_states = self._apply_batch_debye_scattering_results(results) + updated_count = len(updated_states) + if updated_count <= 0: + self._refresh_debye_scattering_group() + if canceled: + self._append_status( + "Stopped the active Debye scattering calculation before any target rows finished." 
+ ) + self.statusBar().showMessage( + "Debye scattering calculation stopped" + ) + return + if failures: + self._show_error( + "Debye Scattering Error", + "\n".join(failures[:6]), + ) + return + if ( + context.skipped_single_atom > 0 + and context.skipped_pending == 0 + ): + self._show_error( + "No Debye Targets Updated", + "The selected target rows already use direct Debye " + "scattering only, so a separate Born-vs-Debye " + "comparison trace is not needed.", + ) + return + self._show_error( + "Born Transform Required", + "Evaluate the Born-approximation Fourier transform for the " + "target rows before computing Debye comparison traces.", ) - except Exception as exc: - error_message = str(exc) - finally: - self._set_calculation_running(False) - self.stop_calculation_button.setEnabled(False) - self._reset_debye_scattering_progress() - if error_message is not None: - self._show_error("Debye Scattering Error", error_message) return - if result is None: + self._refresh_cluster_views_after_batch_update( + updated_states, + selected_keys=list(context.selected_keys), + ) + summary_parts = [ + f"computed {updated_count} Debye scattering average" + f"{'' if updated_count == 1 else 's'}" + ] + if context.skipped_pending > 0: + summary_parts.append( + f"skipped {context.skipped_pending} row" + f"{'' if context.skipped_pending == 1 else 's'} without a Born trace" + ) + if context.skipped_single_atom > 0: + summary_parts.append( + f"skipped {context.skipped_single_atom} direct-Debye row" + f"{'' if context.skipped_single_atom == 1 else 's'}" + ) + if failures: + summary_parts.append( + f"{len(failures)} row" + f"{'' if len(failures) == 1 else 's'} failed" + ) + if canceled: + summary_parts.append("stopped before the remaining rows finished") + self._append_status( + "Debye scattering batch update: " + + "; ".join(summary_parts) + + f" across {context.scope_label}. Each trace reused the q-grid from its " + "matching Born-approximation transform." 
+ ) + for failure in failures: + self._append_status(f"Debye scattering warning: {failure}") + self.statusBar().showMessage( + ( + "Debye scattering calculation stopped" + if canceled + else "Debye scattering averages ready" + ) + ) + self._sync_workspace_state() + + @Slot(object) + def _on_debye_scattering_finished(self, payload: object) -> None: + context = self._active_debye_scattering_context + self._active_debye_scattering_context = None + self._set_calculation_running(False) + self._reset_debye_scattering_progress() + if context is None: + self._on_debye_scattering_failed( + "Debye scattering calculation finished without context." + ) + return + results = self._debye_scattering_payload_results(payload) + failures = self._debye_scattering_payload_failures(payload) + if context.mode == "batch": + self._finish_batch_debye_scattering_run( + results=results, + failures=failures, + context=context, + canceled=False, + ) + return + if not results: self._show_error( "Debye Scattering Error", - "Debye scattering calculation finished without a result.", + ( + "\n".join(failures[:6]) + if failures + else ( + "Debye scattering calculation finished without a result." + ) + ), ) return + result = results[0].result self._debye_scattering_result = result self._refresh_debye_scattering_group() q_values = np.asarray(result.q_values, dtype=float) @@ -9929,11 +10752,56 @@ def _calculate_debye_scattering_action(self) -> None: f"{'' if result.source_structure_count == 1 else 's'}, " f"{len(q_values)} q points." 
) + self._capture_saved_output_entry( + "debye_scattering", + group_state=self._active_cluster_group_state(), + profile_result=self._profile_result, + transform_result=self._fourier_result, + debye_scattering_result=result, + ) for note in result.notes: self._append_status(note) self.statusBar().showMessage("Debye scattering average ready") self._sync_workspace_state() + @Slot(object) + def _on_debye_scattering_canceled(self, payload: object) -> None: + context = self._active_debye_scattering_context + self._active_debye_scattering_context = None + self._set_calculation_running(False) + self._reset_debye_scattering_progress() + self.debye_scattering_status_label.setText( + "Debye scattering calculation stopped." + ) + if context is None: + self._append_status( + "Stopped the active Debye scattering calculation." + ) + self.statusBar().showMessage( + "Debye scattering calculation stopped" + ) + return + results = self._debye_scattering_payload_results(payload) + failures = self._debye_scattering_payload_failures(payload) + if context.mode == "batch": + self._finish_batch_debye_scattering_run( + results=results, + failures=failures, + context=context, + canceled=True, + ) + return + self._append_status("Stopped the active Debye scattering calculation.") + self.statusBar().showMessage("Debye scattering calculation stopped") + + @Slot(str) + def _on_debye_scattering_failed(self, message: str) -> None: + self._active_debye_scattering_context = None + self._set_calculation_running(False) + self._reset_debye_scattering_progress() + self.debye_scattering_status_label.setText("Debye scattering failed.") + self._show_error("Debye Scattering Error", str(message)) + def _calculate_debye_scattering_for_target_clusters( self, *, @@ -9951,10 +10819,8 @@ def _calculate_debye_scattering_for_target_clusters( scope_label, selected_keys, ) = self._batch_target_cluster_group_states(apply_to_all=apply_to_all) - updated_count = 0 skipped_pending = 0 skipped_single_atom = 0 - failures: 
list[str] = [] eligible_states = [ state for state in target_states @@ -9970,106 +10836,16 @@ def _calculate_debye_scattering_for_target_clusters( ) for state in eligible_states ) - if progress_total > 0: - skipped_count = len(target_states) - len(eligible_states) - progress_message = ( - "Preparing Debye scattering averages across " - f"{len(eligible_states)} target row" - f"{'' if len(eligible_states) == 1 else 's'} in " - f"{scope_label}." - ) - if skipped_count > 0: - progress_message += ( - f" {skipped_count} row" - f"{'' if skipped_count == 1 else 's'} will be skipped." - ) - self._set_calculation_running(True) - self.stop_calculation_button.setEnabled(False) - self._begin_debye_scattering_progress( - total=progress_total, - message=progress_message, - ) - progress_offset = 0 - eligible_index = 0 - try: - for state in target_states: - if state.single_atom_only: - skipped_single_atom += 1 - state.debye_scattering_result = None - continue - born_result = state.transform_result - if not self._is_density_fourier_transform(born_result): - skipped_pending += 1 - state.debye_scattering_result = None - continue - eligible_index += 1 - state_progress_total = ( - self._debye_scattering_progress_total_for_inspection( - state.inspection - ) - ) - - def emit_state_progress( - current: int, - total: int, - message: str, - *, - _state=state, - _offset=progress_offset, - _scope_label=scope_label, - _scope_index=eligible_index, - _scope_total=len(eligible_states), - ) -> None: - self._update_debye_scattering_progress( - _offset - + min(max(int(current), 0), max(int(total), 1)), - progress_total, - "Debye " - f"{_scope_index}/{_scope_total} " - f"[{_state.display_name}] in {_scope_label}: " - f"{str(message).strip()}", - ) - - try: - result = ( - compute_average_debye_scattering_profile_for_input( - state.inspection, - q_values=np.asarray( - born_result.q_values, - dtype=float, - ), - progress_callback=emit_state_progress, - ) - ) - except Exception as exc: - 
failures.append(f"{state.display_name}: {exc}") - progress_offset += state_progress_total - if progress_total > 0: - self._update_debye_scattering_progress( - progress_offset, - progress_total, - f"Debye {eligible_index}/{len(eligible_states)} " - f"[{state.display_name}] in {scope_label} failed.", - ) - continue - state.debye_scattering_result = result - if state.key == self._selected_cluster_group_key: - self._debye_scattering_result = result - updated_count += 1 - progress_offset += state_progress_total - finally: - if progress_total > 0: - self._set_calculation_running(False) - self.stop_calculation_button.setEnabled(False) - self._reset_debye_scattering_progress() - if updated_count <= 0: + for state in target_states: + if state.single_atom_only: + skipped_single_atom += 1 + state.debye_scattering_result = None + continue + if not self._is_density_fourier_transform(state.transform_result): + skipped_pending += 1 + state.debye_scattering_result = None + if progress_total <= 0: self._refresh_debye_scattering_group() - if failures: - self._show_error( - "Debye Scattering Error", - "\n".join(failures[:6]), - ) - return if skipped_single_atom > 0 and skipped_pending == 0: self._show_error( "No Debye Targets Updated", @@ -10084,39 +10860,47 @@ def emit_state_progress( "target rows before computing Debye comparison traces.", ) return - self._refresh_cluster_views_after_batch_update( - target_states, - selected_keys=selected_keys, - ) - summary_parts = [ - f"computed {updated_count} Debye scattering average" - f"{'' if updated_count == 1 else 's'}" - ] - if skipped_pending > 0: - summary_parts.append( - f"skipped {skipped_pending} row" - f"{'' if skipped_pending == 1 else 's'} without a Born trace" - ) - if skipped_single_atom > 0: - summary_parts.append( - f"skipped {skipped_single_atom} direct-Debye row" - f"{'' if skipped_single_atom == 1 else 's'}" - ) - if failures: - summary_parts.append( - f"{len(failures)} row" - f"{'' if len(failures) == 1 else 's'} failed" - 
) - self._append_status( - "Debye scattering batch update: " - + "; ".join(summary_parts) - + f" across {scope_label}. Each trace reused the q-grid from its " - "matching Born-approximation transform." + skipped_count = len(target_states) - len(eligible_states) + progress_message = ( + "Preparing Debye scattering averages across " + f"{len(eligible_states)} target row" + f"{'' if len(eligible_states) == 1 else 's'} in " + f"{scope_label}." + ) + if skipped_count > 0: + progress_message += ( + f" {skipped_count} row" + f"{'' if skipped_count == 1 else 's'} will be skipped." + ) + self._start_debye_scattering_run( + items=tuple( + _DebyeScatteringWorkItem( + group_key=state.key, + label=state.display_name, + inspection=state.inspection, + q_values=tuple( + float(value) + for value in np.asarray( + state.transform_result.q_values, + dtype=float, + ) + ), + progress_total=self._debye_scattering_progress_total_for_inspection( + state.inspection + ), + ) + for state in eligible_states + if state.transform_result is not None + ), + context=_DebyeScatteringRunContext( + mode="batch", + scope_label=scope_label, + selected_keys=tuple(selected_keys), + skipped_pending=skipped_pending, + skipped_single_atom=skipped_single_atom, + ), + initial_message=progress_message, ) - for failure in failures: - self._append_status(f"Debye scattering warning: {failure}") - self.statusBar().showMessage("Debye scattering averages ready") - self._sync_workspace_state() @Slot() def _open_debye_scattering_comparison_plot(self) -> None: @@ -10376,6 +11160,62 @@ def _update_distribution_metadata_after_push( encoding="utf-8", ) + def _ensure_linked_distribution_ready_for_push(self) -> None: + if ( + self._preview_mode + or self._distribution_root_dir is None + or self._project_dir is None + ): + return + metadata_path = self._distribution_root_dir / "distribution.json" + prior_weights_path = self._distribution_root_dir / ( + "md_prior_weights_predicted_structures.json" + if 
self._use_predicted_structure_weights + else "md_prior_weights.json" + ) + if metadata_path.is_file() and prior_weights_path.is_file(): + return + + from saxshell.saxs.contrast.settings import ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION, + ) + from saxshell.saxs.project_manager.project import ( + SAXSProjectManager, + project_artifact_paths, + ) + + project_manager = SAXSProjectManager() + settings = project_manager.load_project(self._project_dir) + settings.component_build_mode = COMPONENT_BUILD_MODE_BORN_APPROXIMATION + settings.use_predicted_structure_weights = bool( + self._use_predicted_structure_weights + ) + if self._project_q_min is not None: + settings.q_min = float(self._project_q_min) + if self._project_q_max is not None: + settings.q_max = float(self._project_q_max) + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + if artifact_paths.root_dir != self._distribution_root_dir: + raise ValueError( + "The linked computed distribution no longer matches the " + "active project settings. Reopen the Born Approximation " + "workflow from Project Setup and push the components again." + ) + project_manager.generate_prior_weights(settings) + if not metadata_path.is_file() or not prior_weights_path.is_file(): + raise ValueError( + "The linked computed distribution could not be prepared for " + "Born-approximation component export." + ) + self._append_status( + "Prepared the linked computed distribution metadata and prior " + "weights for Born-approximation component export." 
+ ) + @Slot() def _push_components_to_model(self) -> None: lock_reason = self._distribution_push_lock_reason() @@ -10408,6 +11248,7 @@ def _push_components_to_model(self) -> None: "Pending: " + ", ".join(pending), ) return + self._ensure_linked_distribution_ready_for_push() component_dir, component_map_path = artifact_targets component_dir.mkdir(parents=True, exist_ok=True) saxs_map: dict[str, dict[str, str]] = {} @@ -10865,6 +11706,21 @@ def _apply_mesh_from_controls(self) -> None: @Slot() def _request_calculation_stop(self) -> None: + if self._debye_scattering_worker is not None: + self.stop_calculation_button.setEnabled(False) + self._debye_scattering_worker.cancel() + self.cancel_debye_scattering_requested.emit() + self.debye_scattering_status_label.setText( + "Stopping Debye scattering calculation..." + ) + self.statusBar().showMessage( + "Stopping Debye scattering calculation...", + 5000, + ) + self._append_status( + "Requested a stop for the active Debye scattering calculation." + ) + return if self._calculation_worker is None: return self._calculation_cancel_requested = True @@ -11782,6 +12638,9 @@ def _build_plot_trace_columns( profile_result: ElectronDensityProfileResult | None, fourier_preview: ElectronDensityFourierTransformPreview | None, transform_result: ElectronDensityScatteringTransformResult | None, + debye_scattering_result: ( + ElectronDensityDebyeScatteringAverageResult | None + ) = None, ) -> tuple[list[str], list[list[str]]]: headers: list[str] = [] columns: list[list[str]] = [] @@ -11876,6 +12735,33 @@ def _build_plot_trace_columns( [f"{value:.10g}" for value in amplitude], [f"{value:.10g}" for value in intensity], ] + debye_result = debye_scattering_result + if debye_result is not None: + debye_q_values = np.asarray(debye_result.q_values, dtype=float) + debye_mean = np.asarray( + debye_result.mean_intensity, + dtype=float, + ) + debye_std = np.asarray( + debye_result.std_intensity, + dtype=float, + ) + debye_se = np.asarray( + 
debye_result.se_intensity, + dtype=float, + ) + headers += [ + "debye_q_a_inv", + "debye_mean_intensity", + "debye_std_intensity", + "debye_se_intensity", + ] + columns += [ + [f"{value:.10g}" for value in debye_q_values], + [f"{value:.10g}" for value in debye_mean], + [f"{value:.10g}" for value in debye_std], + [f"{value:.10g}" for value in debye_se], + ] return headers, columns @staticmethod @@ -11885,12 +12771,16 @@ def _write_plot_trace_csv( profile_result: ElectronDensityProfileResult | None, fourier_preview: ElectronDensityFourierTransformPreview | None, transform_result: ElectronDensityScatteringTransformResult | None, + debye_scattering_result: ( + ElectronDensityDebyeScatteringAverageResult | None + ) = None, ) -> None: headers, columns = ( ElectronDensityMappingMainWindow._build_plot_trace_columns( profile_result=profile_result, fourier_preview=fourier_preview, transform_result=transform_result, + debye_scattering_result=debye_scattering_result, ) ) max_rows = max((len(column) for column in columns), default=0) @@ -11947,6 +12837,7 @@ def _export_plot_traces(self) -> None: profile_result=self._profile_result, fourier_preview=self._fourier_preview, transform_result=self._fourier_result, + debye_scattering_result=self._debye_scattering_result, ) except Exception as exc: self._show_error("Export Failed", str(exc)) @@ -11976,6 +12867,17 @@ def closeEvent(self, event: QCloseEvent) -> None: self._sync_workspace_state() self._close_workspace_load_progress_dialog() self._close_batch_operation_progress_dialog() + if self._debye_scattering_progress_dialog is not None: + self._debye_scattering_progress_dialog.close() + if ( + self._debye_scattering_worker is not None + and self._debye_scattering_thread is not None + and self._debye_scattering_thread.isRunning() + ): + self._debye_scattering_worker.cancel() + self.cancel_debye_scattering_requested.emit() + self._debye_scattering_thread.quit() + self._debye_scattering_thread.wait(1000) if ( self._calculation_worker is 
not None and self._calculation_thread is not None diff --git a/src/saxshell/saxs/electron_density_mapping/ui/viewer.py b/src/saxshell/saxs/electron_density_mapping/ui/viewer.py index 57ee6d7..1a1b892 100644 --- a/src/saxshell/saxs/electron_density_mapping/ui/viewer.py +++ b/src/saxshell/saxs/electron_density_mapping/ui/viewer.py @@ -26,6 +26,7 @@ QWidget, ) +from saxshell.plotting import Q_A_INVERSE_LABEL from saxshell.saxs.electron_density_mapping.workflow import ( ElectronDensityFourierTransformPreview, ElectronDensityMeshGeometry, @@ -827,7 +828,7 @@ def set_transform_result( axis.set_xscale("log") if log_intensity_axis: axis.set_yscale("log") - axis.set_xlabel("q (Å⁻¹)", labelpad=10.0) + axis.set_xlabel(Q_A_INVERSE_LABEL, labelpad=10.0) axis.set_ylabel("Intensity (arb. units)") axis.set_title("q-Space Scattering Profile") axis.grid(True, which="both", alpha=0.28) diff --git a/src/saxshell/saxs/electron_density_mapping/workflow.py b/src/saxshell/saxs/electron_density_mapping/workflow.py index 9deec41..a843c5e 100644 --- a/src/saxshell/saxs/electron_density_mapping/workflow.py +++ b/src/saxshell/saxs/electron_density_mapping/workflow.py @@ -18,6 +18,7 @@ except ImportError: # pragma: no cover - optional import until runtime xraydb = None +from saxshell.saxs.born_refinement.backend import build_shared_q_grid from saxshell.saxs.contrast.electron_density import ( CONTRAST_SOLVENT_METHOD_DIRECT, CONTRAST_SOLVENT_METHOD_NEAT, @@ -987,6 +988,47 @@ def nearest_atom_coordinates(self) -> np.ndarray: ) +def legacy_born_average_default_mesh_settings( + structure: ElectronDensityStructure | None = None, +) -> ElectronDensityMeshSettings: + target_rmax = 8.0 + if structure is not None: + target_rmax = max(float(np.ceil(float(structure.rmax) + 2.0)), 0.01) + return ElectronDensityMeshSettings( + rstep=0.25, + theta_divisions=120, + phi_divisions=60, + rmax=target_rmax, + ).normalized() + + +def legacy_born_average_default_smearing_settings() -> ( + 
ElectronDensitySmearingSettings +): + return ElectronDensitySmearingSettings( + debye_waller_factor=0.0 + ).normalized() + + +def legacy_born_average_default_fourier_settings( + *, + q_min: float = 0.01, + q_max: float = 1.2, + q_step: float = 0.01, + r_max: float = 1.0, +) -> ElectronDensityFourierTransformSettings: + return ElectronDensityFourierTransformSettings( + r_min=0.0, + r_max=max(float(r_max), 0.01), + domain_mode="legacy", + window_function="none", + q_min=float(q_min), + q_max=float(q_max), + q_step=float(q_step), + resampling_points=4096, + ).normalized() + + @dataclass(slots=True, frozen=True) class ElectronDensityMeshGeometry: settings: ElectronDensityMeshSettings @@ -2293,15 +2335,11 @@ def _q_values_from_transform_settings( settings: ElectronDensityFourierTransformSettings, ) -> np.ndarray: normalized = settings.normalized() - q_values = np.arange( + return build_shared_q_grid( float(normalized.q_min), - float(normalized.q_max) + float(normalized.q_step) * 0.5, - float(normalized.q_step), - dtype=float, + float(normalized.q_max), + q_step=float(normalized.q_step), ) - if q_values.size > 0 and q_values[-1] > float(normalized.q_max) + 1.0e-12: - q_values = q_values[:-1] - return np.asarray(q_values, dtype=float) def _validated_debye_q_values( @@ -3593,6 +3631,9 @@ def write_electron_density_profile_outputs( "compute_electron_density_scattering_profile", "compute_single_atom_debye_scattering_profile_for_input", "inspect_structure_input", + "legacy_born_average_default_fourier_settings", + "legacy_born_average_default_mesh_settings", + "legacy_born_average_default_smearing_settings", "load_electron_density_structure", "prepare_single_atom_debye_scattering_preview", "prepare_electron_density_fourier_transform", diff --git a/src/saxshell/saxs/prefit/cluster_geometry.py b/src/saxshell/saxs/prefit/cluster_geometry.py index e853658..c2be36f 100644 --- a/src/saxshell/saxs/prefit/cluster_geometry.py +++ b/src/saxshell/saxs/prefit/cluster_geometry.py @@ 
-435,6 +435,7 @@ def from_dict( def compute_cluster_geometry_metadata( clusters_dir: str | Path, *, + cluster_bins: list[ClusterBin] | tuple[ClusterBin, ...] | None = None, extra_cluster_bins: ( list[ClusterBin] | tuple[ClusterBin, ...] | None ) = None, @@ -458,7 +459,11 @@ def compute_cluster_geometry_metadata( active_ionic_radius_type, default=DEFAULT_IONIC_RADIUS_TYPE, ) - cluster_bins = list(discover_cluster_bins(resolved_clusters_dir)) + cluster_bins = ( + list(cluster_bins) + if cluster_bins is not None + else list(discover_cluster_bins(resolved_clusters_dir)) + ) if extra_cluster_bins: cluster_bins.extend(extra_cluster_bins) total_files = sum(len(cluster_bin.files) for cluster_bin in cluster_bins) diff --git a/src/saxshell/saxs/prefit/workflow.py b/src/saxshell/saxs/prefit/workflow.py index 62f15a9..2037f08 100644 --- a/src/saxshell/saxs/prefit/workflow.py +++ b/src/saxshell/saxs/prefit/workflow.py @@ -17,6 +17,10 @@ load_template_module, load_template_spec, ) +from saxshell.saxs.contrast.settings import ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, + normalize_component_build_mode, +) from saxshell.saxs.prefit.cluster_geometry import ( DEFAULT_IONIC_RADIUS_TYPE, DEFAULT_RADIUS_TYPE, @@ -34,8 +38,8 @@ ProjectSettings, SAXSProjectManager, build_project_paths, + component_source_mode_label, distribution_id_for_settings, - effective_q_range_for_settings, load_built_component_q_range, project_artifact_paths, ) @@ -116,6 +120,29 @@ def normalize_requested_q_range_to_supported( return normalized_min, normalized_max +def component_q_range_boundary_tolerance( + component_build_mode: object, + q_values: np.ndarray, + supported_min: float, + supported_max: float, +) -> float: + tolerance = q_range_boundary_tolerance(supported_min, supported_max) + if ( + normalize_component_build_mode(component_build_mode) + != COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ): + return tolerance + q_grid = np.asarray(q_values, dtype=float) + finite_q = 
np.unique(q_grid[np.isfinite(q_grid)]) + if finite_q.size < 2: + return tolerance + positive_diffs = np.diff(np.sort(finite_q)) + positive_diffs = positive_diffs[positive_diffs > 0.0] + if positive_diffs.size == 0: + return tolerance + return max(tolerance, float(np.median(positive_diffs)) * 1.01) + + @dataclass(slots=True) class PrefitComponent: structure: str @@ -899,6 +926,57 @@ def _predicted_structure_cluster_bins_for_active_components( included_components=predicted_components, ) + def _representative_structure_cluster_bins_for_active_components( + self, + ) -> list: + if not self.settings.use_representative_structures: + return [] + active_components = { + ( + str(component.structure).strip(), + str(component.motif).strip() or "no_motif", + ) + for component in self.components + if not str(component.motif).strip().startswith("predicted_rank") + } + if not active_components: + return [] + try: + inventory = self.project_manager._representative_cluster_inventory( + self.settings + ) + except Exception as exc: + if self.settings.resolved_clusters_dir is not None: + return [] + raise ValueError( + "Use Representative Structures is enabled, but SAXS Prefit " + "could not load the saved representative structure sources. " + "Reopen Project Setup, verify the representative selection, " + "and push or rebuild the SAXS components again." 
+ ) from exc + return [ + cluster_bin + for cluster_bin in inventory.cluster_bins + if ( + str(cluster_bin.structure).strip(), + str(cluster_bin.motif).strip() or "no_motif", + ) + in active_components + ] + + @staticmethod + def _cluster_geometry_source_dir_for_bins( + cluster_bins: list, + ) -> Path | None: + source_dirs = [ + Path(cluster_bin.source_dir).expanduser().resolve() + for cluster_bin in cluster_bins + if getattr(cluster_bin, "source_dir", None) is not None + ] + if not source_dirs: + return None + return source_dirs[0] + def run_fit( self, parameter_entries: list[PrefitParameterEntry] | None = None, @@ -1518,16 +1596,28 @@ def recommend_scale_settings( if evaluation.solvent_contribution is not None else np.zeros_like(evaluation.model_intensities, dtype=float) ) - target = np.asarray( - evaluation.experimental_intensities - - offset_value - - solvent_contribution, - dtype=float, - ) - model = np.asarray( - evaluation.model_intensities - offset_value - solvent_contribution, - dtype=float, - ) + if self.solvent_contribution_is_scaled_by_global_scale(): + target = np.asarray( + evaluation.experimental_intensities - offset_value, + dtype=float, + ) + model = np.asarray( + evaluation.model_intensities - offset_value, + dtype=float, + ) + else: + target = np.asarray( + evaluation.experimental_intensities + - offset_value + - solvent_contribution, + dtype=float, + ) + model = np.asarray( + evaluation.model_intensities + - offset_value + - solvent_contribution, + dtype=float, + ) mask = np.isfinite(target) & np.isfinite(model) if not np.any(mask): raise ValueError( @@ -1609,12 +1699,20 @@ def recommend_scale_settings( if current_scale > 0.0 else float("nan") ) + adaptive_bounds = ( + self.template_spec.prefit_support.autoscale_bounds_mode + == "adaptive" + ) recommended_minimum = max(recommended_scale / span_factor, 1e-12) recommended_maximum = max( recommended_scale * span_factor, recommended_scale * 1.5, - float(scale_entry.maximum), ) + if not 
adaptive_bounds: + recommended_maximum = max( + recommended_maximum, + float(scale_entry.maximum), + ) recommended_offset_minimum: float | None = None recommended_offset_maximum: float | None = None if offset_entry is not None and recommended_offset is not None: @@ -1626,14 +1724,22 @@ def recommend_scale_settings( abs(float(recommended_offset)) / span_factor, 1e-12, ) - recommended_offset_minimum = min( - float(offset_entry.minimum), - float(recommended_offset) - offset_padding, - ) - recommended_offset_maximum = max( - float(offset_entry.maximum), - float(recommended_offset) + offset_padding, - ) + if adaptive_bounds: + recommended_offset_minimum = ( + float(recommended_offset) - offset_padding + ) + recommended_offset_maximum = ( + float(recommended_offset) + offset_padding + ) + else: + recommended_offset_minimum = min( + float(offset_entry.minimum), + float(recommended_offset) - offset_padding, + ) + recommended_offset_maximum = max( + float(offset_entry.maximum), + float(recommended_offset) + offset_padding, + ) return PrefitScaleRecommendation( current_scale=raw_current_scale, recommended_scale=recommended_scale, @@ -1647,7 +1753,84 @@ def recommend_scale_settings( points_used=int(np.count_nonzero(mask)), ) + def apply_scale_recommendation_to_entries( + self, + parameter_entries: list[PrefitParameterEntry] | None = None, + *, + recommendation: PrefitScaleRecommendation | None = None, + ) -> list[PrefitParameterEntry]: + entries = self._copy_entries( + parameter_entries or self.parameter_entries + ) + resolved_recommendation = ( + recommendation or self.recommend_scale_settings(entries) + ) + for entry in entries: + if entry.name == "scale": + entry.value = resolved_recommendation.recommended_scale + entry.minimum = resolved_recommendation.recommended_minimum + entry.maximum = resolved_recommendation.recommended_maximum + entry.vary = True + elif ( + entry.name == "offset" + and resolved_recommendation.recommended_offset is not None + ): + entry.value = 
resolved_recommendation.recommended_offset + if ( + resolved_recommendation.recommended_offset_minimum + is not None + ): + entry.minimum = ( + resolved_recommendation.recommended_offset_minimum + ) + if ( + resolved_recommendation.recommended_offset_maximum + is not None + ): + entry.maximum = ( + resolved_recommendation.recommended_offset_maximum + ) + return entries + + def current_prefit_state_exists(self) -> bool: + return (self.prefit_dir / "prefit_state.json").is_file() + + def should_auto_apply_autoscale_on_load(self) -> bool: + return ( + self.template_spec.prefit_support.auto_apply_autoscale_on_load + and self.can_run_prefit() + and not self.has_best_prefit_entries() + and not self.current_prefit_state_exists() + ) + + def auto_apply_autoscale_on_load( + self, + ) -> PrefitScaleRecommendation | None: + if not self.should_auto_apply_autoscale_on_load(): + return None + recommendation = self.recommend_scale_settings(self.parameter_entries) + self.parameter_entries = self.apply_scale_recommendation_to_entries( + self.parameter_entries, + recommendation=recommendation, + ) + return recommendation + def volume_fraction_estimator_target(self) -> tuple[str, str] | None: + target = self.solution_scattering_volume_fraction_target() + if target is not None: + return target[:2] + return None + + def solution_scattering_volume_fraction_target( + self, + ) -> tuple[str, str, str] | None: + support = self.template_spec.solution_scattering_support + if support.volume_fraction_parameter is not None: + return ( + support.volume_fraction_parameter, + support.volume_fraction_kind, + support.volume_fraction_source, + ) parameter_names = { str(parameter.name).strip() for parameter in self.template_spec.parameters @@ -1655,10 +1838,10 @@ def volume_fraction_estimator_target(self) -> tuple[str, str] | None: } for candidate in SOLUTE_VOLUME_FRACTION_PARAMETER_NAMES: if candidate in parameter_names: - return candidate, "solute" + return candidate, "solute", "saxs_effective" for 
candidate in SOLVENT_VOLUME_FRACTION_PARAMETER_NAMES: if candidate in parameter_names: - return candidate, "solvent" + return candidate, "solvent", "saxs_effective" return None def supports_volume_fraction_estimator(self) -> bool: @@ -1675,6 +1858,12 @@ def solvent_weight_estimator_target(self) -> str | None: return candidate return None + def solvent_contribution_is_scaled_by_global_scale(self) -> bool: + return ( + self.template_spec.solution_scattering_support.solvent_contribution_scale_mode + == "global_scale" + ) + def supports_cluster_geometry_metadata(self) -> bool: return bool(self.template_spec.cluster_geometry_support.supported) @@ -1782,7 +1971,16 @@ def compute_cluster_geometry_table_with_progress( *, progress_callback=None, ) -> ClusterGeometryMetadataTable: - clusters_dir = self.settings.resolved_clusters_dir + representative_cluster_bins = ( + self._representative_structure_cluster_bins_for_active_components() + ) + clusters_dir = ( + self._cluster_geometry_source_dir_for_bins( + representative_cluster_bins + ) + if representative_cluster_bins + else self.settings.resolved_clusters_dir + ) if clusters_dir is None: raise ValueError( "Select a clusters directory in Project Setup before " @@ -1790,6 +1988,7 @@ def compute_cluster_geometry_table_with_progress( ) table = compute_cluster_geometry_metadata( clusters_dir, + cluster_bins=representative_cluster_bins or None, extra_cluster_bins=( self._predicted_structure_cluster_bins_for_active_components() ), @@ -2225,7 +2424,36 @@ def _entry_signature( (entry.structure, entry.motif, entry.name) for entry in entries ] + @staticmethod + def _load_numeric_table(path: Path, *, min_columns: int = 2) -> np.ndarray: + raw_data = np.asarray(np.loadtxt(path, comments="#"), dtype=float) + if raw_data.ndim == 1: + raw_data = raw_data.reshape(1, -1) + if raw_data.ndim != 2 or raw_data.shape[1] < min_columns: + raise ValueError( + f"Expected at least {min_columns} numeric columns in {path}." 
+ ) + return raw_data + def _load_components(self) -> list[PrefitComponent]: + if not self.project_manager.component_artifacts_match_settings( + self.settings, + artifact_paths=self.artifact_paths, + ): + saved_mode = self.project_manager.built_component_source_mode( + self.settings, + artifact_paths=self.artifact_paths, + ) + if saved_mode is not None: + raise FileNotFoundError( + "The saved SAXS components for this computed " + "distribution were built from " + f"{component_source_mode_label(saved_mode)}, but the " + "current Project Setup selection expects " + f"{component_source_mode_label('representative' if self.settings.use_representative_structures else 'average')}. " + "Rebuild SAXS components in Project Setup before " + "running Prefit." + ) if not self.component_map_path.is_file(): if self.settings.use_predicted_structure_weights: predicted_state = ( @@ -2291,7 +2519,7 @@ def _load_components(self) -> list[PrefitComponent]: for motif in sorted(motif_map, key=_natural_sort_key): profile_file = str(motif_map[motif]) profile_path = self.component_dir / profile_file - raw_data = np.loadtxt(profile_path, comments="#") + raw_data = self._load_numeric_table(profile_path) q_values = np.asarray(raw_data[:, 0], dtype=float) intensities = np.asarray(raw_data[:, 1], dtype=float) components.append( @@ -2345,7 +2573,7 @@ def _load_solvent_trace(self) -> np.ndarray | None: self.paths.experimental_data_dir.glob("solv_*") ): if candidate.is_file(): - raw_data = np.loadtxt(candidate, comments="#") + raw_data = self._load_numeric_table(candidate) return np.interp( q_values, np.asarray(raw_data[:, 0], dtype=float), @@ -2400,7 +2628,19 @@ def _requested_q_bounds( ) if q_values.size == 0: raise ValueError("No SAXS component q-values are available.") - return effective_q_range_for_settings(self.settings, q_values) + requested_min = ( + float(self.settings.q_min) + if self.settings.q_min is not None + else float(np.min(q_values)) + ) + requested_max = ( + 
float(self.settings.q_max) + if self.settings.q_max is not None + else float(np.max(q_values)) + ) + if requested_min > requested_max: + raise ValueError("q min must be less than or equal to q max.") + return requested_min, requested_max def _ensure_requested_q_range_supported( self, @@ -2409,18 +2649,15 @@ def _ensure_requested_q_range_supported( q_values = np.asarray(source_q_values, dtype=float) requested_min, requested_max = self._requested_q_bounds(q_values) supported_min, supported_max = self._supported_component_q_range() - requested_min, requested_max = ( - normalize_requested_q_range_to_supported( - requested_min, - requested_max, - supported_min, - supported_max, - ) - ) - tolerance = q_range_boundary_tolerance( + tolerance = self._component_q_range_boundary_tolerance( + q_values, supported_min, supported_max, ) + if abs(requested_min - supported_min) <= tolerance: + requested_min = float(supported_min) + if abs(requested_max - supported_max) <= tolerance: + requested_max = float(supported_max) if requested_min < (supported_min - tolerance) or requested_max > ( supported_max + tolerance ): @@ -2434,6 +2671,19 @@ def _ensure_requested_q_range_supported( ) return requested_min, requested_max + def _component_q_range_boundary_tolerance( + self, + q_values: np.ndarray, + supported_min: float, + supported_max: float, + ) -> float: + return component_q_range_boundary_tolerance( + self.settings.component_build_mode, + q_values, + supported_min, + supported_max, + ) + def _component_q_values_from_candidates( self, candidates: list[PrefitComponent] | None = None, @@ -2454,7 +2704,7 @@ def _component_q_values_from_candidates( "SAXS components in Project Setup before previewing the model." 
) else: - raw_data = np.loadtxt(component_files[0], comments="#") + raw_data = self._load_numeric_table(component_files[0]) source_q_values = np.asarray(raw_data[:, 0], dtype=float) requested_min, requested_max = ( @@ -2550,7 +2800,7 @@ def _solvent_trace_for_q_values( ) for candidate in component_files: if candidate.is_file(): - raw_data = np.loadtxt(candidate, comments="#") + raw_data = self._load_numeric_table(candidate) return np.interp( np.asarray(q_values, dtype=float), np.asarray(raw_data[:, 0], dtype=float), diff --git a/src/saxshell/saxs/project_manager/__init__.py b/src/saxshell/saxs/project_manager/__init__.py index 49a7904..bd2a5c5 100644 --- a/src/saxshell/saxs/project_manager/__init__.py +++ b/src/saxshell/saxs/project_manager/__init__.py @@ -5,6 +5,9 @@ export_prior_plot_data, list_secondary_filter_elements, plot_md_prior_histogram, + prior_histogram_default_legend_title, + prior_histogram_default_title, + prior_histogram_default_y_label, ) from .project import ( ClusterImportResult, @@ -19,8 +22,10 @@ ProjectSettings, RegisteredFileSnapshot, RegisteredFolderSnapshot, + RepresentativeStructuresProjectState, SAXSProjectManager, build_project_paths, + component_source_mode_label, distribution_id_for_settings, effective_q_range_for_settings, guess_experimental_header_rows, @@ -39,6 +44,7 @@ "RegisteredFileSnapshot", "RegisteredFolderSnapshot", "PredictedStructuresProjectState", + "RepresentativeStructuresProjectState", "ProjectBuildResult", "ProjectArtifactPaths", "ProjectComponentEntry", @@ -46,6 +52,7 @@ "ProjectSettings", "SAXSProjectManager", "build_project_paths", + "component_source_mode_label", "distribution_id_for_settings", "effective_q_range_for_settings", "guess_experimental_header_rows", @@ -60,4 +67,7 @@ "export_prior_plot_data", "list_secondary_filter_elements", "plot_md_prior_histogram", + "prior_histogram_default_legend_title", + "prior_histogram_default_title", + "prior_histogram_default_y_label", ] diff --git 
a/src/saxshell/saxs/project_manager/prior_plot.py b/src/saxshell/saxs/project_manager/prior_plot.py index 98cbb79..e07823d 100644 --- a/src/saxshell/saxs/project_manager/prior_plot.py +++ b/src/saxshell/saxs/project_manager/prior_plot.py @@ -7,8 +7,12 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib.lines import Line2D +from saxshell.plotting.stacked_histogram import ( + StackedHistogramPlotDefaults, + StackedHistogramPlotSettings, + render_stacked_histogram_export_payload, +) from saxshell.saxs.stoichiometry import ( format_stoich_for_axis, parse_stoich_label, @@ -341,7 +345,6 @@ def plot_md_prior_histogram( custom_label_order: list[tuple[str, str]] | None = None, ax=None, ): - small_total_threshold = 1.0 export_payload = build_prior_histogram_export_payload( json_path, mode=mode, @@ -349,121 +352,71 @@ def plot_md_prior_histogram( secondary_element=secondary_element, custom_label_order=custom_label_order, ) - labels = [str(label) for label in export_payload["labels"]] - axis_labels = [str(label) for label in export_payload["axis_labels"]] - segments = [str(segment) for segment in export_payload["segments"]] - segment_labels = [str(label) for label in export_payload["segment_labels"]] plot_mode = str(export_payload["plot_mode"]) - matrix = np.asarray(export_payload["matrix"], dtype=float) - color_keys = [ - [None if key is None else str(key) for key in row] - for row in export_payload.get("color_keys", []) - ] if ax is None: fig, ax = plt.subplots(figsize=figsize) else: fig = ax.figure - ax.clear() + defaults = StackedHistogramPlotDefaults( + title=prior_histogram_default_title(plot_mode, secondary_element), + x_label="Structure", + y_label=prior_histogram_default_y_label(plot_mode), + legend_title=prior_histogram_default_legend_title( + plot_mode, + secondary_element, + ), + default_colormap_name=cmap, + raw_category_labels=tuple( + str(label) for label in export_payload.get("labels", ()) + ), + default_label_entries=tuple( + ( + 
str(raw_label), + str(display_label), + ) + for raw_label, display_label in zip( + export_payload.get("labels", ()), + export_payload.get("axis_labels", ()), + strict=False, + ) + ), + ) + render_stacked_histogram_export_payload( + export_payload, + ax=ax, + defaults=defaults, + settings=StackedHistogramPlotSettings(), + cmap=cmap, + structure_segment_colors=structure_motif_colors, + show_percent=show_percent, + ) + return fig, ax - if not labels: - ax.set_title("No prior-weight data available") - ax.set_xlabel("Structure") - ax.set_ylabel("Fraction") - return fig, ax - colors = plt.get_cmap(cmap)( - np.linspace(0.1, 0.9, max(len(segment_labels), 1), endpoint=True) - ) +def prior_histogram_default_title( + mode: str, + secondary_element: str | None = None, +) -> str: + return _prior_plot_title(mode, secondary_element) - bottoms = np.zeros(len(labels), dtype=float) - for index, segment_label in enumerate(segment_labels): - heights_array = matrix[:, index] - bar_colors = colors[index] - if structure_motif_colors and not plot_mode.startswith("solvent_sort"): - bar_colors = [ - structure_motif_colors.get( - ( - color_keys[row_index][index] - if row_index < len(color_keys) - and index < len(color_keys[row_index]) - else f"{label}_{segments[index]}" - ), - fallback_color, - ) - for row_index, (label, fallback_color) in enumerate( - zip( - labels, - [colors[index]] * len(labels), - ) - ) - ] - ax.bar( - labels, - heights_array, - bottom=bottoms, - label=segment_label, - color=bar_colors, - edgecolor="white", - width=0.8, - ) - bottoms += heights_array - - if show_percent: - showed_small_total_marker = False - for index, total in enumerate(bottoms): - if total >= small_total_threshold: - ax.text( - index, - total + 1.0, - f"{total:.1f}%", - ha="center", - va="bottom", - fontsize=9, - ) - else: - ax.scatter(index, total + 1.0, color="red", s=16, zorder=4) - showed_small_total_marker = True - ax.set_ylim(0.0, max(bottoms.max(initial=0.0) + 4.0, 10.0)) - 
ax.set_xlabel("Structure") - ax.set_ylabel( +def prior_histogram_default_y_label(mode: str) -> str: + return ( "Percentage of Total Atom-Weighted Count (%)" - if plot_mode in {"atom_fraction", "solvent_sort_atom_fraction"} + if mode in {"atom_fraction", "solvent_sort_atom_fraction"} else "Percentage of Total Structures (%)" ) - ax.set_title(_prior_plot_title(plot_mode, secondary_element)) - ax.set_xticks(range(len(labels))) - ax.set_xticklabels( - axis_labels, - rotation=45, - ha="right", - ) - legend_handles, legend_labels = ax.get_legend_handles_labels() - if show_percent and showed_small_total_marker: - legend_handles.append( - Line2D( - [], - [], - marker="o", - color="red", - linestyle="None", - markersize=5, - ) - ) - legend_labels.append("< 1% total") - ax.legend( - legend_handles, - legend_labels, - title=( - "Motif" - if not plot_mode.startswith("solvent_sort") - else f"{secondary_element} count" - ), - bbox_to_anchor=(1.02, 1.0), - loc="upper left", + + +def prior_histogram_default_legend_title( + mode: str, + secondary_element: str | None = None, +) -> str: + return ( + "Motif" + if not str(mode).startswith("solvent_sort") + else f"{secondary_element or 'Secondary'} count" ) - fig.tight_layout() - return fig, ax def list_secondary_filter_elements( diff --git a/src/saxshell/saxs/project_manager/project.py b/src/saxshell/saxs/project_manager/project.py index 6695f7d..487679d 100644 --- a/src/saxshell/saxs/project_manager/project.py +++ b/src/saxshell/saxs/project_manager/project.py @@ -27,6 +27,8 @@ analyze_contrast_representatives, ) from saxshell.saxs.contrast.settings import ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION, + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, COMPONENT_BUILD_MODE_CONTRAST, COMPONENT_BUILD_MODE_NO_CONTRAST, component_build_mode_label, @@ -44,6 +46,40 @@ from .prior_plot import export_prior_plot_data ProgressCallback = Callable[[int, int, str], None] +COMPONENT_SOURCE_MODE_AVERAGE = "average" +COMPONENT_SOURCE_MODE_REPRESENTATIVE 
= "representative" +_EXTERNAL_COMPONENT_BUILD_MODES = { + COMPONENT_BUILD_MODE_BORN_APPROXIMATION, + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, + COMPONENT_BUILD_MODE_CONTRAST, +} + + +def normalize_component_source_mode(value: object) -> str: + text = str(value or "").strip().lower() + if text in { + COMPONENT_SOURCE_MODE_REPRESENTATIVE, + "representative_structure", + "representative_structures", + "representative structures", + }: + return COMPONENT_SOURCE_MODE_REPRESENTATIVE + return COMPONENT_SOURCE_MODE_AVERAGE + + +def component_source_mode_for_settings(settings: "ProjectSettings") -> str: + return ( + COMPONENT_SOURCE_MODE_REPRESENTATIVE + if settings.use_representative_structures + else COMPONENT_SOURCE_MODE_AVERAGE + ) + + +def component_source_mode_label(value: object) -> str: + normalized = normalize_component_source_mode(value) + if normalized == COMPONENT_SOURCE_MODE_REPRESENTATIVE: + return "Representative Structures" + return "Average (default)" @dataclass(slots=True) @@ -222,6 +258,117 @@ class PredictedStructuresProjectState: prior_artifacts_ready: bool +@dataclass(slots=True, frozen=True) +class RepresentativeStructuresProjectState: + representative_selection_file: Path | None + representative_count: int + partialsolv_dir: Path | None + nosolv_dir: Path | None + fullsolv_dir: Path | None + source_files_ready: bool = False + available_modes: tuple[str, ...] 
= () + expected_representative_count: int = 0 + missing_representative_count: int = 0 + invalid_representative_count: int = 0 + selection_ready: bool = False + updated_at: str | None = None + + +def _representative_bin_key( + structure: object, + motif: object, +) -> tuple[str, str] | None: + structure_text = str(structure or "").strip() + if not structure_text: + return None + return ( + structure_text, + _normalized_nonempty_text(motif, default="no_motif"), + ) + + +def _representative_issue_key( + issue: object, +) -> tuple[str, str] | None: + return _representative_bin_key( + getattr(issue, "structure", ""), + getattr(issue, "motif", "no_motif"), + ) + + +def _representative_weight_value(value: object) -> float: + try: + return float(value or 0.0) + except (TypeError, ValueError): + return 0.0 + + +def _representative_expected_bins_from_prior_weights( + prior_weights_path: str | Path | None, +) -> tuple[tuple[str, str], ...] | None: + if prior_weights_path is None: + return None + path = Path(prior_weights_path).expanduser().resolve() + if not path.is_file(): + return None + try: + payload = json.loads(path.read_text(encoding="utf-8")) + except Exception: + return None + if not isinstance(payload, dict): + return None + raw_structures = payload.get("structures", {}) + if not isinstance(raw_structures, dict): + return None + + value_kind = str(payload.get("value_kind", "count")).strip().lower() + expected_bins: set[tuple[str, str]] = set() + for structure, raw_motifs in raw_structures.items(): + if not isinstance(raw_motifs, dict): + continue + for motif, raw_motif_payload in raw_motifs.items(): + if not isinstance(raw_motif_payload, dict): + continue + source_kind = ( + str(raw_motif_payload.get("source_kind", "cluster_dir")) + .strip() + .lower() + ) + if source_kind == "predicted_structure": + continue + if value_kind == "normalized_weight": + weight_value = _representative_weight_value( + raw_motif_payload.get( + "normalized_weight", + 
raw_motif_payload.get("weight", 0.0), + ) + or 0.0 + ) + else: + weight_value = _representative_weight_value( + raw_motif_payload.get( + "count", + raw_motif_payload.get("weight", 0.0), + ) + or 0.0 + ) + if weight_value <= 0.0: + continue + key = _representative_bin_key(structure, motif) + if key is not None: + expected_bins.add(key) + + return tuple( + sorted( + expected_bins, + key=lambda item: ( + _natural_sort_key(item[0]), + _natural_sort_key(item[1]), + ), + ) + ) + + @dataclass(slots=True, frozen=True) class SavedDistributionRecord: distribution_id: str @@ -233,6 +380,7 @@ class SavedDistributionRecord: template_name: str | None = None component_build_mode: str = COMPONENT_BUILD_MODE_NO_CONTRAST use_predicted_structure_weights: bool = False + use_representative_structures: bool = False exclude_elements: tuple[str, ...] = () clusters_dir: str | None = None q_min: float | None = None @@ -241,6 +389,7 @@ class SavedDistributionRecord: q_points: int | None = None component_artifacts_ready: bool = False prior_artifacts_ready: bool = False + built_component_source_mode: str | None = None @dataclass(slots=True) @@ -505,6 +654,7 @@ class ProjectSettings: project_dir: str model_only_mode: bool = False use_predicted_structure_weights: bool = False + use_representative_structures: bool = False frames_dir: str | None = None pdb_frames_dir: str | None = None clusters_dir: str | None = None @@ -541,10 +691,12 @@ class ProjectSettings: exclude_elements: list[str] = field(default_factory=list) component_trace_colors: dict[str, str] = field(default_factory=dict) component_trace_color_scheme: str = "default" + component_plot_state: dict[str, object] = field(default_factory=dict) experimental_trace_visible: bool = True experimental_trace_color: str = "#000000" solvent_trace_visible: bool = True solvent_trace_color: str = "#008000" + prior_plot_state: dict[str, object] = field(default_factory=dict) runtime_bundle_opener: str | None = None template_reset_template: str | None = 
None template_reset_parameter_entries: list[dict[str, object]] = field( @@ -645,6 +797,9 @@ def to_dict(self) -> dict[str, object]: payload["component_trace_color_scheme"] = ( str(self.component_trace_color_scheme).strip() or "default" ) + payload["component_plot_state"] = _normalized_json_object( + self.component_plot_state + ) payload["experimental_trace_visible"] = bool( self.experimental_trace_visible ) @@ -653,6 +808,9 @@ def to_dict(self) -> dict[str, object]: ) payload["solvent_trace_visible"] = bool(self.solvent_trace_visible) payload["solvent_trace_color"] = str(self.solvent_trace_color) + payload["prior_plot_state"] = _normalized_json_object( + self.prior_plot_state + ) payload["runtime_bundle_opener"] = _optional_str( self.runtime_bundle_opener ) @@ -717,6 +875,9 @@ def from_dict(cls, payload: dict[str, object]) -> "ProjectSettings": use_predicted_structure_weights=bool( payload.get("use_predicted_structure_weights", False) ), + use_representative_structures=bool( + payload.get("use_representative_structures", False) + ), frames_dir=_optional_str(payload.get("frames_dir")), pdb_frames_dir=_optional_str(payload.get("pdb_frames_dir")), clusters_dir=_optional_str(payload.get("clusters_dir")), @@ -803,6 +964,9 @@ def from_dict(cls, payload: dict[str, object]) -> "ProjectSettings": payload.get("component_trace_color_scheme", "default") ).strip() or "default", + component_plot_state=_normalized_json_object( + payload.get("component_plot_state", {}) + ), experimental_trace_visible=bool( payload.get("experimental_trace_visible", True) ), @@ -817,6 +981,9 @@ def from_dict(cls, payload: dict[str, object]) -> "ProjectSettings": payload.get("solvent_trace_color", "#008000") ).strip() or "#008000", + prior_plot_state=_normalized_json_object( + payload.get("prior_plot_state", {}) + ), runtime_bundle_opener=_optional_str( payload.get("runtime_bundle_opener") ), @@ -867,6 +1034,16 @@ def _normalized_prior_x_axis_order(raw: object) -> list[list[str]]: return result +def 
_normalized_json_object(raw: object) -> dict[str, object]: + if not isinstance(raw, dict): + return {} + try: + normalized = json.loads(json.dumps(raw)) + except (TypeError, ValueError): + return {} + return normalized if isinstance(normalized, dict) else {} + + def _distribution_id_for_settings( settings: ProjectSettings, *, @@ -905,31 +1082,81 @@ def distribution_id_for_settings(settings: ProjectSettings) -> str: def distribution_label_for_settings(settings: ProjectSettings) -> str: + return _distribution_label( + use_predicted_structure_weights=( + settings.use_predicted_structure_weights + ), + component_build_mode=settings.component_build_mode, + template_name=settings.selected_model_template, + exclude_elements=settings.exclude_elements, + q_min=settings.q_min, + q_max=settings.q_max, + use_experimental_grid=settings.use_experimental_grid, + q_points=settings.q_points, + ) + + +def _distribution_label( + *, + use_predicted_structure_weights: bool, + component_build_mode: object, + template_name: str | None, + exclude_elements: object, + q_min: float | None, + q_max: float | None, + use_experimental_grid: bool, + q_points: int | None, +) -> str: mode = ( "Observed + Predicted Structures" - if settings.use_predicted_structure_weights + if use_predicted_structure_weights else "Observed Only" ) - build_mode = component_build_mode_label(settings.component_build_mode) - template_name = ( - str(settings.selected_model_template or "").strip() or "Unspecified" - ) - excluded = ", ".join(sorted(set(settings.exclude_elements))) or "None" - if settings.q_min is None or settings.q_max is None: + build_mode = component_build_mode_label(component_build_mode) + resolved_template_name = str(template_name or "").strip() or "Unspecified" + excluded = ", ".join(sorted(set(_normalized_elements(exclude_elements)))) + excluded = excluded or "None" + if q_min is None or q_max is None: q_range = "default" else: - q_range = f"{float(settings.q_min):.6g} to {float(settings.q_max):.6g}" 
+ q_range = f"{float(q_min):.6g} to {float(q_max):.6g}" grid = ( "experimental grid" - if settings.use_experimental_grid - else f"resample {int(settings.q_points or 0)}" + if use_experimental_grid + else f"resample {int(q_points or 0)}" ) return ( - f"{mode} | Build: {build_mode} | Template: {template_name} | " + f"{mode} | Build: {build_mode} | Template: {resolved_template_name} | " f"Excluded: {excluded} | q-range: {q_range} | Grid: {grid}" ) +def _distribution_label_from_metadata_payload( + payload: dict[str, object], +) -> str: + stored_label = str(payload.get("label", "")).strip() + template_name = _optional_str(payload.get("template_name")) + if template_name is None: + for part in stored_label.split(" | "): + if part.startswith("Template: "): + template_name = ( + str(part.removeprefix("Template: ")).strip() or None + ) + break + return _distribution_label( + use_predicted_structure_weights=bool( + payload.get("use_predicted_structure_weights", False) + ), + component_build_mode=payload.get("component_build_mode"), + template_name=template_name, + exclude_elements=payload.get("exclude_elements", []), + q_min=_optional_float(payload.get("q_min")), + q_max=_optional_float(payload.get("q_max")), + use_experimental_grid=bool(payload.get("use_experimental_grid", True)), + q_points=_optional_int(payload.get("q_points")), + ) + + def _distribution_id_candidates_for_settings( settings: ProjectSettings, ) -> tuple[str, ...]: @@ -1177,21 +1404,43 @@ def _project_has_saved_distributions(project_dir: str | Path) -> bool: return any(path.is_dir() for path in saved_dir.iterdir()) +def _distribution_built_component_source_mode_from_payload( + payload: dict[str, object], +) -> str | None: + explicit_mode = _optional_str(payload.get("built_component_source_mode")) + if explicit_mode is not None: + return normalize_component_source_mode(explicit_mode) + if bool(payload.get("component_artifacts_ready", False)): + return normalize_component_source_mode( + 
COMPONENT_SOURCE_MODE_REPRESENTATIVE + if bool(payload.get("use_representative_structures", False)) + else COMPONENT_SOURCE_MODE_AVERAGE + ) + return None + + def _distribution_metadata_from_payload( distribution_dir: Path, metadata_path: Path, payload: dict[str, object], ) -> SavedDistributionRecord | None: distribution_id = str(payload.get("distribution_id", "")).strip() - label = str(payload.get("label", "")).strip() - if not distribution_id or not label: + if not distribution_id: return None exclude_elements = tuple( _normalized_elements(payload.get("exclude_elements", [])) ) + built_component_source_mode = ( + _distribution_built_component_source_mode_from_payload(payload) + ) + use_representative_structures = bool( + payload.get("use_representative_structures", False) + ) + if built_component_source_mode == COMPONENT_SOURCE_MODE_REPRESENTATIVE: + use_representative_structures = True return SavedDistributionRecord( distribution_id=distribution_id, - label=label, + label=_distribution_label_from_metadata_payload(payload), distribution_dir=distribution_dir, metadata_path=metadata_path, created_at=_optional_str(payload.get("created_at")), @@ -1203,6 +1452,7 @@ def _distribution_metadata_from_payload( use_predicted_structure_weights=bool( payload.get("use_predicted_structure_weights", False) ), + use_representative_structures=use_representative_structures, exclude_elements=exclude_elements, clusters_dir=_optional_str(payload.get("clusters_dir")), q_min=_optional_float(payload.get("q_min")), @@ -1215,6 +1465,7 @@ def _distribution_metadata_from_payload( prior_artifacts_ready=bool( payload.get("prior_artifacts_ready", False) ), + built_component_source_mode=built_component_source_mode, ) @@ -1841,6 +2092,305 @@ def inspect_predicted_structures( prior_artifacts_ready=prior_artifacts_ready, ) + def inspect_representative_structures( + self, + project_dir: str | Path, + *, + prior_weights_path: str | Path | None = None, + ) -> RepresentativeStructuresProjectState: + 
from saxshell.fullrmc.project_model import ensure_rmcsetup_structure + from saxshell.fullrmc.representatives import ( + load_representative_selection_metadata, + normalize_representative_source_solvent_mode, + representative_structure_variant_path, + ) + from saxshell.fullrmc.solvent_handling import ( + load_solvent_handling_metadata, + ) + + paths = ensure_rmcsetup_structure(project_dir) + metadata = load_representative_selection_metadata( + paths.representative_selection_path + ) + solvent_metadata = load_solvent_handling_metadata( + paths.solvent_handling_path + ) + representative_entries = ( + list(metadata.representative_entries) + if metadata is not None + else [] + ) + representative_count = len(representative_entries) + representative_bin_keys = [ + key + for entry in representative_entries + if (key := _representative_bin_key(entry.structure, entry.motif)) + is not None + ] + representative_bin_key_set = set(representative_bin_keys) + expected_bin_keys = _representative_expected_bins_from_prior_weights( + prior_weights_path + ) + expected_bin_key_set = ( + None if expected_bin_keys is None else set(expected_bin_keys) + ) + expected_representative_count = 0 + missing_representative_count = 0 + invalid_representative_count = 0 + if metadata is not None: + metadata_expected_bin_keys = [ + key + for entry in metadata.distribution_selection.active_entries( + metadata.settings.minimum_cluster_count_for_analysis + ) + if ( + key := _representative_bin_key( + entry.structure, entry.motif + ) + ) + is not None + ] + if expected_bin_key_set is None: + expected_bin_key_set = set(metadata_expected_bin_keys) + expected_representative_count = len(expected_bin_key_set) + missing_representative_count = len(metadata.missing_bins) + invalid_representative_count = len(metadata.invalid_bins) + if expected_representative_count <= 0: + expected_representative_count = ( + representative_count + + missing_representative_count + + invalid_representative_count + ) + else: + 
invalid_issue_keys = { + key + for issue in metadata.invalid_bins + if (key := _representative_issue_key(issue)) + in expected_bin_key_set + } + unresolved_expected_keys = ( + expected_bin_key_set - representative_bin_key_set + ) + missing_representative_count = len( + unresolved_expected_keys - invalid_issue_keys + ) + invalid_representative_count = len(invalid_issue_keys) + invalid_representative_count += len( + representative_bin_key_set - expected_bin_key_set + ) + invalid_representative_count += max( + representative_count - len(representative_bin_key_set), + 0, + ) + expected_representative_count = len(expected_bin_key_set) + elif expected_bin_key_set is not None: + expected_representative_count = len(expected_bin_key_set) + missing_representative_count = len( + expected_bin_key_set - representative_bin_key_set + ) + invalid_representative_count = len( + representative_bin_key_set - expected_bin_key_set + ) + invalid_representative_count += max( + representative_count - len(representative_bin_key_set), + 0, + ) + available_modes: list[str] = [] + source_files_ready = False + if representative_count > 0: + source_paths = [ + Path(entry.source_file).expanduser().resolve() + for entry in representative_entries + ] + source_files_ready = bool(source_paths) and all( + path.is_file() for path in source_paths + ) + if source_files_ready: + for mode in ("nosolv", "partialsolv", "fullsolv"): + if ( + all( + representative_structure_variant_path( + entry.source_file, + mode, + ) + is not None + or normalize_representative_source_solvent_mode( + entry.source_solvent_mode + ) + == mode + for entry in representative_entries + ) + and mode not in available_modes + ): + available_modes.append(mode) + if solvent_metadata is not None and solvent_metadata.entries: + no_solvent_paths = [ + Path(entry.no_solvent_pdb).expanduser().resolve() + for entry in solvent_metadata.entries + ] + full_solvent_paths = [ + Path(entry.completed_pdb).expanduser().resolve() + for entry in 
solvent_metadata.entries + ] + if ( + no_solvent_paths + and all(path.is_file() for path in no_solvent_paths) + and "nosolv" not in available_modes + ): + available_modes.append("nosolv") + if ( + full_solvent_paths + and all(path.is_file() for path in full_solvent_paths) + and "fullsolv" not in available_modes + ): + available_modes.append("fullsolv") + selection_count_matches = ( + representative_count == expected_representative_count + if expected_bin_key_set is not None + else representative_count >= expected_representative_count + ) + selection_bins_match = ( + representative_bin_key_set == expected_bin_key_set + if expected_bin_key_set is not None + else True + ) + selection_ready = bool( + representative_count > 0 + and source_files_ready + and missing_representative_count == 0 + and invalid_representative_count == 0 + and selection_count_matches + and selection_bins_match + ) + return RepresentativeStructuresProjectState( + representative_selection_file=( + paths.representative_selection_path + if paths.representative_selection_path.is_file() + else None + ), + representative_count=int(representative_count), + partialsolv_dir=( + paths.representative_partial_solvent_dir + if paths.representative_partial_solvent_dir.is_dir() + else None + ), + nosolv_dir=( + paths.pdb_no_solvent_dir + if paths.pdb_no_solvent_dir.is_dir() + else None + ), + fullsolv_dir=( + paths.pdb_with_solvent_dir + if paths.pdb_with_solvent_dir.is_dir() + else None + ), + source_files_ready=bool(source_files_ready), + available_modes=tuple(available_modes), + expected_representative_count=max( + int(expected_representative_count), + 0, + ), + missing_representative_count=max( + int(missing_representative_count), + 0, + ), + invalid_representative_count=max( + int(invalid_representative_count), + 0, + ), + selection_ready=selection_ready, + updated_at=( + None + if metadata is None + else _optional_str(metadata.updated_at) + ), + ) + + def component_artifacts_ready( + self, + settings: 
ProjectSettings, + *, + artifact_paths: ProjectArtifactPaths | None = None, + ) -> bool: + active_artifact_paths = artifact_paths or project_artifact_paths( + settings + ) + return bool( + active_artifact_paths.component_dir.is_dir() + and any(active_artifact_paths.component_dir.glob("*.txt")) + and active_artifact_paths.component_map_file.is_file() + ) + + def built_component_source_mode( + self, + settings: ProjectSettings, + *, + artifact_paths: ProjectArtifactPaths | None = None, + ) -> str | None: + active_artifact_paths = artifact_paths or project_artifact_paths( + settings + ) + metadata_path = active_artifact_paths.distribution_metadata_file + if metadata_path is None or not metadata_path.is_file(): + return None + try: + payload = json.loads(metadata_path.read_text(encoding="utf-8")) + except Exception: + return None + if not isinstance(payload, dict): + return None + return _distribution_built_component_source_mode_from_payload(payload) + + def component_artifacts_match_settings( + self, + settings: ProjectSettings, + *, + artifact_paths: ProjectArtifactPaths | None = None, + ) -> bool: + normalized_build_mode = normalize_component_build_mode( + settings.component_build_mode + ) + if normalized_build_mode in { + COMPONENT_BUILD_MODE_CONTRAST, + COMPONENT_BUILD_MODE_BORN_APPROXIMATION, + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, + }: + return self.component_artifacts_ready( + settings, + artifact_paths=artifact_paths, + ) + active_artifact_paths = artifact_paths or project_artifact_paths( + settings + ) + if not self.component_artifacts_ready( + settings, + artifact_paths=active_artifact_paths, + ): + return False + saved_mode = self.built_component_source_mode( + settings, + artifact_paths=active_artifact_paths, + ) + if saved_mode is None: + return True + return saved_mode == component_source_mode_for_settings(settings) + + def _full_observed_cluster_inventory( + self, + settings: ProjectSettings, + *, + progress_callback: ProgressCallback | None = 
None, + ) -> _ClusterInventory: + clusters_dir = settings.resolved_clusters_dir + if clusters_dir is None: + raise ValueError( + "Select a clusters directory before building models." + ) + return self._collect_cluster_inventory( + clusters_dir, + progress_callback=progress_callback, + ) + def list_saved_distributions( self, project_dir: str | Path, @@ -1921,6 +2471,9 @@ def settings_for_saved_distribution( working_settings.use_predicted_structure_weights = bool( record.use_predicted_structure_weights ) + working_settings.use_representative_structures = bool( + record.use_representative_structures + ) working_settings.exclude_elements = list(record.exclude_elements) working_settings.clusters_dir = record.clusters_dir working_settings.q_min = record.q_min @@ -2210,13 +2763,9 @@ def build_scattering_components( storage_mode="distribution", ) self.ensure_artifact_dirs(artifact_paths) - clusters_dir = settings.resolved_clusters_dir - if clusters_dir is None: - raise ValueError( - "Select a clusters directory before building models." 
- ) predicted_dataset_file: Path | None = None predicted_component_count = 0 + built_component_source_mode: str | None = None if ( normalize_component_build_mode(settings.component_build_mode) == COMPONENT_BUILD_MODE_CONTRAST @@ -2239,16 +2788,27 @@ def build_scattering_components( cluster_inventory=cluster_inventory, progress_callback=progress_callback, ) + built_component_source_mode = component_source_mode_for_settings( + settings + ) else: - cluster_inventory = self._collect_cluster_inventory( - clusters_dir, + full_cluster_inventory = self._full_observed_cluster_inventory( + settings, progress_callback=progress_callback, ) - settings.available_elements = cluster_inventory.available_elements - settings.cluster_inventory_rows = cluster_inventory.cluster_rows - cluster_rows = list(cluster_inventory.cluster_rows) - component_entries = self._component_entries_from_cluster_bins( - cluster_inventory.cluster_bins + cluster_inventory = self._active_observed_cluster_inventory( + settings, + progress_callback=progress_callback, + ) + settings.available_elements = ( + full_cluster_inventory.available_elements + ) + settings.cluster_inventory_rows = ( + full_cluster_inventory.cluster_rows + ) + cluster_rows = list(full_cluster_inventory.cluster_rows) + component_entries = self._component_entries_from_cluster_inventory( + cluster_inventory ) reused_observed_components = ( settings.use_predicted_structure_weights @@ -2270,17 +2830,15 @@ def build_scattering_components( progress_callback=progress_callback, progress_total=max(cluster_inventory.total_files, 1), ) - component_entries = [ - ProjectComponentEntry( - structure=component.structure, - motif=component.motif, - file_count=component.file_count, - representative=component.representative, - profile_file=component.output_path.name, - source_dir=str(component.source_dir), + component_entries = ( + self._component_entries_from_averaged_components( + averaged_components, + cluster_inventory=cluster_inventory, ) - for 
component in averaged_components - ] + ) + built_component_source_mode = component_source_mode_for_settings( + settings + ) if ( normalize_component_build_mode(settings.component_build_mode) != COMPONENT_BUILD_MODE_CONTRAST @@ -2297,11 +2855,11 @@ def build_scattering_components( artifact_paths=artifact_paths, q_values=q_values, component_entries=component_entries, - cluster_inventory=cluster_inventory, + cluster_inventory=full_cluster_inventory, ) settings.available_elements = sorted( { - *settings.available_elements, + *full_cluster_inventory.available_elements, *predicted_available_elements, }, key=_natural_sort_key, @@ -2313,6 +2871,7 @@ def build_scattering_components( self._write_distribution_metadata( settings, artifact_paths=artifact_paths, + built_component_source_mode=built_component_source_mode, ) self.save_project( settings, @@ -2572,13 +3131,8 @@ def generate_prior_weights( storage_mode="distribution", ) self.ensure_artifact_dirs(artifact_paths) - clusters_dir = settings.resolved_clusters_dir - if clusters_dir is None: - raise ValueError( - "Select a clusters directory before generating prior weights." 
- ) - cluster_inventory = self._collect_cluster_inventory( - clusters_dir, + cluster_inventory = self._full_observed_cluster_inventory( + settings, progress_callback=progress_callback, ) settings.available_elements = cluster_inventory.available_elements @@ -2586,13 +3140,23 @@ def generate_prior_weights( component_entries = self._component_entries_for_prior_weights( settings, artifact_paths=artifact_paths, - cluster_bins=cluster_inventory.cluster_bins, + cluster_inventory=cluster_inventory, ) md_prior_weights_path = artifact_paths.prior_weights_file prior_plot_data_path = artifact_paths.prior_plot_data_file cluster_rows = list(cluster_inventory.cluster_rows) predicted_dataset_file: Path | None = None predicted_component_count = 0 + prior_origin_dir = ( + cluster_inventory.cluster_bins[0].source_dir.parent + if settings.use_representative_structures + and cluster_inventory.cluster_bins + else settings.resolved_clusters_dir + ) + if prior_origin_dir is None: + raise ValueError( + "Select a clusters directory before generating prior weights." 
+ ) if progress_callback is not None: progress_callback( 0, @@ -2617,7 +3181,7 @@ def generate_prior_weights( predicted_component_count, ) = self._write_md_prior_weights_with_predicted_structures( md_prior_weights_path=md_prior_weights_path, - clusters_dir=clusters_dir, + clusters_dir=prior_origin_dir, component_entries=component_entries, cluster_bins=cluster_inventory.cluster_bins, available_elements=cluster_inventory.available_elements, @@ -2635,13 +3199,18 @@ def generate_prior_weights( else: self._write_md_prior_weights( md_prior_weights_path=md_prior_weights_path, - clusters_dir=clusters_dir, + clusters_dir=prior_origin_dir, component_entries=component_entries, cluster_bins=cluster_inventory.cluster_bins, available_elements=cluster_inventory.available_elements, q_values=q_values, ) export_prior_plot_data(md_prior_weights_path, prior_plot_data_path) + if ( + normalize_component_build_mode(settings.component_build_mode) + in _EXTERNAL_COMPONENT_BUILD_MODES + ): + self._reset_distribution_component_artifacts(artifact_paths) self._write_distribution_metadata( settings, artifact_paths=artifact_paths, @@ -2670,17 +3239,30 @@ def generate_prior_weights( predicted_component_count=predicted_component_count, ) + def _reset_distribution_component_artifacts( + self, + artifact_paths: ProjectArtifactPaths, + ) -> None: + if artifact_paths.component_dir.is_dir(): + shutil.rmtree(artifact_paths.component_dir) + elif artifact_paths.component_dir.exists(): + artifact_paths.component_dir.unlink() + if artifact_paths.component_map_file.is_file(): + artifact_paths.component_map_file.unlink() + def _write_distribution_metadata( self, settings: ProjectSettings, *, artifact_paths: ProjectArtifactPaths, + built_component_source_mode: str | None = None, ) -> None: metadata_path = artifact_paths.distribution_metadata_file if metadata_path is None: return now = datetime.now().isoformat(timespec="seconds") created_at = now + existing: dict[str, object] = {} if metadata_path.is_file(): 
try: existing = json.loads( @@ -2689,14 +3271,27 @@ def _write_distribution_metadata( except Exception: existing = {} created_at = str(existing.get("created_at", now)).strip() or now - component_ready = bool( - artifact_paths.component_dir.is_dir() - and any(artifact_paths.component_dir.glob("*.txt")) - and artifact_paths.component_map_file.is_file() + component_ready = self.component_artifacts_ready( + settings, + artifact_paths=artifact_paths, ) prior_ready = bool(artifact_paths.prior_weights_file.is_file()) + resolved_component_source_mode: str | None = None + if component_ready: + if built_component_source_mode is not None: + resolved_component_source_mode = ( + normalize_component_source_mode( + built_component_source_mode + ) + ) + else: + resolved_component_source_mode = ( + _distribution_built_component_source_mode_from_payload( + existing + ) + ) payload = { - "schema_version": 1, + "schema_version": 2, "distribution_id": artifact_paths.distribution_id, "label": distribution_label_for_settings(settings), "created_at": created_at, @@ -2708,6 +3303,9 @@ def _write_distribution_metadata( "use_predicted_structure_weights": bool( settings.use_predicted_structure_weights ), + "use_representative_structures": bool( + settings.use_representative_structures + ), "exclude_elements": sorted(set(settings.exclude_elements)), "clusters_dir": ( None @@ -2724,6 +3322,7 @@ def _write_distribution_metadata( ), "component_artifacts_ready": component_ready, "prior_artifacts_ready": prior_ready, + "built_component_source_mode": resolved_component_source_mode, } metadata_path.parent.mkdir(parents=True, exist_ok=True) metadata_path.write_text( @@ -2970,23 +3569,87 @@ def _component_entries_from_clusters( discover_cluster_bins(clusters_dir) ) + def _component_entries_from_cluster_inventory( + self, + cluster_inventory: _ClusterInventory, + ) -> list[ProjectComponentEntry]: + count_lookup = { + ( + str(row.get("structure", "")).strip(), + _normalized_nonempty_text( + 
row.get("motif"), + default="no_motif", + ), + ): max(int(row.get("count", 0) or 0), 1) + for row in cluster_inventory.cluster_rows + } + entries = self._component_entries_from_cluster_bins( + cluster_inventory.cluster_bins + ) + return [ + replace( + entry, + file_count=count_lookup.get( + (entry.structure, entry.motif), + max(int(entry.file_count), 1), + ), + ) + for entry in entries + ] + def _component_entries_for_prior_weights( self, settings: ProjectSettings, *, artifact_paths: ProjectArtifactPaths, - cluster_bins: list[ClusterBin], + cluster_inventory: _ClusterInventory, ) -> list[ProjectComponentEntry]: if ( normalize_component_build_mode(settings.component_build_mode) != COMPONENT_BUILD_MODE_CONTRAST ): - return self._component_entries_from_cluster_bins(cluster_bins) + return self._component_entries_from_cluster_inventory( + cluster_inventory + ) return self._contrast_component_entries_from_distribution_artifacts( artifact_paths=artifact_paths, - cluster_bins=cluster_bins, + cluster_bins=cluster_inventory.cluster_bins, ) + def _component_entries_from_averaged_components( + self, + averaged_components: list[object], + *, + cluster_inventory: _ClusterInventory, + ) -> list[ProjectComponentEntry]: + count_lookup = { + ( + str(row.get("structure", "")).strip(), + _normalized_nonempty_text( + row.get("motif"), + default="no_motif", + ), + ): max(int(row.get("count", 0) or 0), 1) + for row in cluster_inventory.cluster_rows + } + entries: list[ProjectComponentEntry] = [] + for component in averaged_components: + key = (component.structure, component.motif) + entries.append( + ProjectComponentEntry( + structure=component.structure, + motif=component.motif, + file_count=count_lookup.get( + key, + max(int(component.file_count), 1), + ), + representative=component.representative, + profile_file=component.output_path.name, + source_dir=str(component.source_dir), + ) + ) + return entries + def _component_entries_from_cluster_bins( self, cluster_bins: list[ClusterBin], 
@@ -3136,6 +3799,180 @@ def _contrast_representative_file_for_trace( return distribution_snapshot_path return distribution_snapshot_path + def _active_observed_cluster_inventory( + self, + settings: ProjectSettings, + *, + progress_callback: ProgressCallback | None = None, + ) -> _ClusterInventory: + if not settings.use_representative_structures: + return self._full_observed_cluster_inventory( + settings, + progress_callback=progress_callback, + ) + return self._representative_cluster_inventory( + settings, + progress_callback=progress_callback, + ) + + def _representative_cluster_inventory( + self, + settings: ProjectSettings, + *, + progress_callback: ProgressCallback | None = None, + force_refresh: bool = False, + ) -> _ClusterInventory: + from saxshell.fullrmc.project_model import ensure_rmcsetup_structure + from saxshell.fullrmc.representatives import ( + load_representative_selection_metadata, + ) + + rmcsetup_paths = ensure_rmcsetup_structure(settings.project_dir) + cache_key = rmcsetup_paths.representative_selection_path.resolve() + if not force_refresh: + cached = self._cluster_inventory_cache.get(cache_key) + if cached is not None: + return cached + + metadata = load_representative_selection_metadata( + rmcsetup_paths.representative_selection_path + ) + representative_entries = ( + [] if metadata is None else list(metadata.representative_entries) + ) + if not representative_entries: + raise ValueError( + "Use Representative Structures is enabled, but this project " + "does not have any saved representative structures yet." 
+ ) + + total_files = len(representative_entries) + if progress_callback is not None: + progress_callback( + 0, + max(total_files, 1), + "Importing representative structures...", + ) + + cluster_bins: list[ClusterBin] = [] + cluster_rows: list[dict[str, object]] = [] + available_elements: set[str] = set() + processed_files = 0 + selected_weight_total = sum( + max(float(entry.selected_weight), 0.0) + for entry in representative_entries + ) + row_weight_seed: list[tuple[str, str, int, float]] = [] + for entry in representative_entries: + source_file = Path(entry.source_file).expanduser().resolve() + if not source_file.is_file(): + raise FileNotFoundError( + "Representative structure file was not found: " + f"{source_file}" + ) + representative_name = ( + _optional_str(entry.source_file_name) or source_file.name + ) + motif = _normalized_nonempty_text(entry.motif, default="no_motif") + cluster_bins.append( + ClusterBin( + structure=entry.structure, + motif=motif, + source_dir=source_file.parent, + files=(source_file,), + representative=representative_name, + ) + ) + available_elements.update(scan_structure_elements(source_file)) + row_weight_seed.append( + ( + entry.structure, + motif, + max(int(entry.cluster_count), 1), + max(float(entry.selected_weight), 0.0), + ) + ) + processed_files += 1 + if progress_callback is not None: + progress_callback( + processed_files, + max(total_files, 1), + ( + "Importing representative structures: " + f"{entry.structure}/{motif}" + ), + ) + + fallback_weight_total = sum( + cluster_count + for _structure, _motif, cluster_count, _weight in row_weight_seed + ) + atom_weight_total = 0.0 + row_weights: dict[tuple[str, str], float] = {} + for ( + structure, + motif, + cluster_count, + selected_weight, + ) in row_weight_seed: + weight = ( + selected_weight / selected_weight_total + if selected_weight_total > 0.0 + else cluster_count / max(fallback_weight_total, 1) + ) + row_weights[(structure, motif)] = float(weight) + atom_weight_total += 
float(weight) * _structure_atom_weight( + structure + ) + + for entry, cluster_bin in zip( + representative_entries, + cluster_bins, + strict=True, + ): + source_file = cluster_bin.files[0] + key = (cluster_bin.structure, cluster_bin.motif) + structure_weight = row_weights.get(key, 0.0) + cluster_rows.append( + { + "structure": cluster_bin.structure, + "motif": cluster_bin.motif, + "count": max(int(entry.cluster_count), 1), + "weight": float(structure_weight), + "source_kind": "representative_structure", + "source_dir": str(cluster_bin.source_dir), + "source_file": str(source_file), + "source_file_name": cluster_bin.representative or "", + "representative": cluster_bin.representative or "", + "structure_fraction_percent": float(structure_weight) + * 100.0, + "atom_fraction_percent": ( + float(structure_weight) + * _structure_atom_weight(cluster_bin.structure) + / atom_weight_total + * 100.0 + if atom_weight_total > 0.0 + else 0.0 + ), + } + ) + cluster_rows.sort( + key=lambda row: ( + _natural_sort_key(str(row["structure"])), + _natural_sort_key(str(row["motif"])), + ) + ) + inventory = _ClusterInventory( + cluster_bins=cluster_bins, + available_elements=sorted( + available_elements, key=_natural_sort_key + ), + cluster_rows=cluster_rows, + total_files=max(total_files, 1), + ) + self._cluster_inventory_cache[cache_key] = inventory + return inventory + def _collect_cluster_inventory( self, clusters_dir: str | Path, @@ -3390,6 +4227,15 @@ def _reuse_observed_component_artifacts( ) if observed_artifact_paths is None: return False + observed_only_settings = replace( + settings, + use_predicted_structure_weights=False, + ) + if not self.component_artifacts_match_settings( + observed_only_settings, + artifact_paths=observed_artifact_paths, + ): + return False if not observed_artifact_paths.component_map_file.is_file(): return False try: @@ -3637,33 +4483,66 @@ def _combined_cluster_rows( observed_motif_weights: dict[tuple[str, str], float], predicted_payloads: 
list[dict[str, object]], ) -> list[dict[str, object]]: + base_row_lookup = { + ( + str(row.get("structure", "")).strip(), + _normalized_nonempty_text( + row.get("motif"), + default="no_motif", + ), + ): dict(row) + for row in cluster_inventory.cluster_rows + } rows: list[dict[str, object]] = [] for cluster_bin in cluster_inventory.cluster_bins: - source_file = ( - str( + base_row = base_row_lookup.get( + (cluster_bin.structure, cluster_bin.motif), + {}, + ) + source_file = str( + base_row.get( + "source_file", ( - cluster_bin.source_dir / cluster_bin.representative - ).resolve() + ( + cluster_bin.source_dir / cluster_bin.representative + ).resolve() + if cluster_bin.representative + else "" + ), ) - if cluster_bin.representative - else "" ) rows.append( { "structure": cluster_bin.structure, "motif": cluster_bin.motif, - "count": int(len(cluster_bin.files)), + "count": int( + base_row.get("count", len(cluster_bin.files)) or 0 + ), "weight": float( observed_motif_weights.get( (cluster_bin.structure, cluster_bin.motif), 0.0, ) ), - "source_kind": "cluster_dir", - "source_dir": str(cluster_bin.source_dir), + "source_kind": str( + base_row.get("source_kind", "cluster_dir") + ), + "source_dir": str( + base_row.get("source_dir", cluster_bin.source_dir) + ), "source_file": source_file, - "source_file_name": cluster_bin.representative or "", - "representative": cluster_bin.representative or "", + "source_file_name": str( + base_row.get( + "source_file_name", + cluster_bin.representative or "", + ) + ), + "representative": str( + base_row.get( + "representative", + cluster_bin.representative or "", + ) + ), } ) for payload in predicted_payloads: diff --git a/src/saxshell/saxs/ui/__init__.py b/src/saxshell/saxs/ui/__init__.py index 1ef40d4..81ae6b1 100644 --- a/src/saxshell/saxs/ui/__init__.py +++ b/src/saxshell/saxs/ui/__init__.py @@ -1,8 +1,4 @@ -from .distribution_window import DistributionSetupWindow -from .experimental_data_loader import ExperimentalDataHeaderDialog 
-from .main_window import SAXSMainWindow, launch_saxs_ui -from .prior_histogram_window import PriorHistogramWindow -from .progress_dialog import SAXSProgressDialog +from __future__ import annotations __all__ = [ "DistributionSetupWindow", @@ -12,3 +8,30 @@ "SAXSMainWindow", "launch_saxs_ui", ] + + +def __getattr__(name: str): + if name == "DistributionSetupWindow": + from .distribution_window import DistributionSetupWindow + + return DistributionSetupWindow + if name == "ExperimentalDataHeaderDialog": + from .experimental_data_loader import ExperimentalDataHeaderDialog + + return ExperimentalDataHeaderDialog + if name == "PriorHistogramWindow": + from .prior_histogram_window import PriorHistogramWindow + + return PriorHistogramWindow + if name == "SAXSProgressDialog": + from .progress_dialog import SAXSProgressDialog + + return SAXSProgressDialog + if name in {"SAXSMainWindow", "launch_saxs_ui"}: + from .main_window import SAXSMainWindow, launch_saxs_ui + + return { + "SAXSMainWindow": SAXSMainWindow, + "launch_saxs_ui": launch_saxs_ui, + }[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/saxshell/saxs/ui/_pane_snap.py b/src/saxshell/saxs/ui/_pane_snap.py index c6675ae..e80f24b 100644 --- a/src/saxshell/saxs/ui/_pane_snap.py +++ b/src/saxshell/saxs/ui/_pane_snap.py @@ -28,7 +28,7 @@ def __init__( def set_enabled(self, enabled: bool) -> None: enabled = bool(enabled) - if enabled == self._enabled: + if enabled == getattr(self, "_enabled", False): return self._enabled = enabled app = QApplication.instance() @@ -40,23 +40,29 @@ def set_enabled(self, enabled: bool) -> None: app.removeEventFilter(self) def is_enabled(self) -> bool: - return self._enabled + return getattr(self, "_enabled", False) # ------------------------------------------------------------------ # QObject interface # ------------------------------------------------------------------ def eventFilter(self, watched: QObject, event: QEvent) -> bool: + splitter 
= getattr(self, "_splitter", None) if ( - not self._enabled - or self._splitter.orientation() != Qt.Orientation.Horizontal - or not self._splitter.isVisible() + not getattr(self, "_enabled", False) + or splitter is None + or splitter.orientation() != Qt.Orientation.Horizontal + or not splitter.isVisible() ): return False + left_widget = getattr(self, "_left_widget", None) + right_widget = getattr(self, "_right_widget", None) + if left_widget is None or right_widget is None: + return False if event.type() == QEvent.Type.MouseButtonPress: - if self._is_descendant(watched, self._left_widget): + if self._is_descendant(watched, left_widget): self._snap_to(0) - elif self._is_descendant(watched, self._right_widget): + elif self._is_descendant(watched, right_widget): self._snap_to(1) return False # never consume events diff --git a/src/saxshell/saxs/ui/main_window.py b/src/saxshell/saxs/ui/main_window.py index d2d7abf..5e04d24 100644 --- a/src/saxshell/saxs/ui/main_window.py +++ b/src/saxshell/saxs/ui/main_window.py @@ -70,12 +70,20 @@ QWidget, ) +from saxshell.fullrmc.packmol_docker import ( + PackmolDockerLink, + load_packmol_docker_link_metadata, + save_packmol_docker_link_metadata, +) +from saxshell.fullrmc.project_model import build_rmcsetup_paths +from saxshell.fullrmc.ui.packmol_docker_dialog import PackmolDockerLinkDialog from saxshell.saxs._model_templates import ( list_template_specs, load_template_spec, ) from saxshell.saxs.contrast.settings import ( COMPONENT_BUILD_MODE_BORN_APPROXIMATION, + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, COMPONENT_BUILD_MODE_CONTRAST, component_build_mode_label, normalize_component_build_mode, @@ -108,8 +116,8 @@ SOLUTE_VOLUME_FRACTION_PARAMETER_NAMES, SOLVENT_VOLUME_FRACTION_PARAMETER_NAMES, SOLVENT_WEIGHT_PARAMETER_NAMES, + component_q_range_boundary_tolerance, normalize_requested_q_range_to_supported, - q_range_boundary_tolerance, ) from saxshell.saxs.project_manager import ( ClusterImportResult, @@ -118,6 +126,7 @@ 
ProjectSettings, SAXSProjectManager, build_project_paths, + component_source_mode_label, effective_q_range_for_settings, export_prior_histogram_npy, export_prior_histogram_table, @@ -167,6 +176,7 @@ RECENT_PROJECTS_KEY = "recent_project_dirs" CONSOLE_AUTOSCROLL_KEY = "console_autoscroll_enabled" AUTO_SNAP_PANES_KEY = "auto_snap_panes_enabled" +PACKMOL_DOCKER_PRESETS_KEY = "packmol_docker_presets" MAX_RECENT_PROJECTS = 10 PROJECT_LOAD_PREP_STEPS = 4 PROJECT_LOAD_TOTAL_STEPS = 12 @@ -272,6 +282,7 @@ class ProjectLoadPrefitPayload: workflow: SAXSPrefitWorkflow | None evaluation: PrefitEvaluation | None scale_recommendation: PrefitScaleRecommendation | None = None + autoscale_applied_on_load: bool = False workflow_error: str | None = None preview_error: str | None = None @@ -1216,6 +1227,9 @@ def _build_ui(self) -> None: self.project_setup_tab.open_clusterdynamicsml_requested.connect( self._open_clusterdynamicsml_tool ) + self.project_setup_tab.open_representative_finder_requested.connect( + self._open_representative_finder_tool + ) self.project_setup_tab.open_debye_waller_requested.connect( self._open_debye_waller_analysis_tool ) @@ -1225,6 +1239,9 @@ def _build_ui(self) -> None: self.project_setup_tab.predicted_structure_weights_changed.connect( self._on_predicted_structure_weights_changed ) + self.project_setup_tab.representative_structures_changed.connect( + self._on_representative_structures_changed + ) self.project_setup_tab.load_distribution_requested.connect( self._load_saved_distribution ) @@ -1389,6 +1406,7 @@ def _build_ui(self) -> None: def _build_menu_bar(self) -> None: menu_bar = self.menuBar() + menu_bar.clear() self.file_menu = menu_bar.addMenu("File") self.create_project_action = QAction("Create Project", self) @@ -1418,6 +1436,15 @@ def _build_menu_bar(self) -> None: self.save_project_as_action.triggered.connect(self.save_project_as) self.file_menu.addAction(self.save_project_as_action) + self.link_packmol_docker_action = QAction( + "Link Packmol 
Docker Container...", + self, + ) + self.link_packmol_docker_action.triggered.connect( + self._open_packmol_docker_link_dialog + ) + self.file_menu.addAction(self.link_packmol_docker_action) + self.tools_menu = menu_bar.addMenu("Tools") self.md_extraction_menu = self.tools_menu.addMenu("MD Extraction") self.mdtrajectory_action = QAction( @@ -1452,9 +1479,6 @@ def _build_menu_bar(self) -> None: self.debye_waller_analysis_action.triggered.connect( self._open_debye_waller_analysis_tool ) - self.structure_analysis_menu.addAction( - self.debye_waller_analysis_action - ) self.cluster_dynamics_menu = self.tools_menu.addMenu( "Cluster Dynamics" @@ -1466,7 +1490,6 @@ def _build_menu_bar(self) -> None: self.clusterdynamics_action.triggered.connect( self._open_clusterdynamics_tool ) - self.cluster_dynamics_menu.addAction(self.clusterdynamics_action) self.clusterdynamicsml_action = QAction( "Open Cluster Dynamics (ML)", @@ -1482,7 +1505,7 @@ def _build_menu_bar(self) -> None: self.pdfsetup_action.triggered.connect(self._open_pdfsetup_tool) self.pdf_menu.addAction(self.pdfsetup_action) - self.fullrmc_action = QAction("Open fullrmc Setup", self) + self.fullrmc_action = QAction("Open RMC Setup (fullrmc)", self) self.fullrmc_action.triggered.connect(self._open_fullrmc_tool) self.pdf_menu.addAction(self.fullrmc_action) @@ -1507,7 +1530,7 @@ def _build_menu_bar(self) -> None: "SAXS Calculation Preview" ) self.contrast_mode_action = QAction( - "Open SAXS Contrast Mode", + "Open Debye Scattering (Contrast Mode)", self, ) self.contrast_mode_action.triggered.connect( @@ -1520,7 +1543,7 @@ def _build_menu_bar(self) -> None: ) self.electron_density_mapping_action = QAction( - "Open Electron Density Mapping", + "Open 1D Born Approximation", self, ) self.electron_density_mapping_action.triggered.connect( @@ -1531,6 +1554,29 @@ def _build_menu_bar(self) -> None: self.component_calculation_preview_menu.addAction( self.electron_density_mapping_action ) + self.fft_born_approximation_action = 
QAction( + "Open 3D FFT Born Approximation", + self, + ) + self.fft_born_approximation_action.triggered.connect( + lambda _checked=False: self._open_3d_fft_born_approximation_tool( + preview_mode=True + ) + ) + self.component_calculation_preview_menu.addAction( + self.fft_born_approximation_action + ) + + self.representative_finder_action = QAction( + "Open Representative Structures", + self, + ) + self.representative_finder_action.triggered.connect( + self._open_representative_finder_tool + ) + self.structure_analysis_menu.addAction( + self.representative_finder_action + ) self.xray_toolkit_menu = self.tools_menu.addMenu("X-ray Toolkit") self.estimation_menu = self.xray_toolkit_menu @@ -1564,6 +1610,26 @@ def _build_menu_bar(self) -> None: self._open_fluorescence_tool ) self.xray_toolkit_menu.addAction(self.fluorescence_estimate_action) + self.cli_setup_menu = self.tools_menu.addMenu("CLI Setup") + self.representative_cli_setup_action = QAction( + "Open Representative CLI Setup (Beta)", + self, + ) + self.representative_cli_setup_action.triggered.connect( + self._open_representative_cli_setup_tool + ) + self.cli_setup_menu.addAction(self.representative_cli_setup_action) + self.beta_menu = self.tools_menu.addMenu("(beta)") + self.beta_menu.addAction(self.clusterdynamics_action) + self.beta_menu.addAction(self.debye_waller_analysis_action) + self.solvent_shell_builder_action = QAction( + "Open Solvent Shell Builder (Beta)", + self, + ) + self.solvent_shell_builder_action.triggered.connect( + self._open_solvent_shell_builder_tool + ) + self.beta_menu.addAction(self.solvent_shell_builder_action) self.settings_menu = menu_bar.addMenu("Settings") self.console_autoscroll_action = QAction( "Autoscroll Console Output", @@ -1924,6 +1990,12 @@ def advance_step( prefit_payload=payload.prefit, dream_payload=payload.dream, ) + self.project_setup_tab.set_component_plot_state( + payload.settings.component_plot_state + ) + self.project_setup_tab.set_prior_plot_state( + 
payload.settings.prior_plot_state + ) self._append_project_load_output( f"Recording recent project: {payload.settings.project_dir}" ) @@ -2090,19 +2162,27 @@ def _build_project_load_prefit_payload( evaluation: PrefitEvaluation | None = None preview_error: str | None = None scale_recommendation: PrefitScaleRecommendation | None = None + autoscale_applied_on_load = False + try: + scale_recommendation = workflow.auto_apply_autoscale_on_load() + autoscale_applied_on_load = scale_recommendation is not None + except Exception: + scale_recommendation = None try: evaluation = workflow.evaluate() except Exception as exc: preview_error = str(exc) else: - try: - scale_recommendation = workflow.recommend_scale_settings() - except Exception: - scale_recommendation = None + if not autoscale_applied_on_load: + try: + scale_recommendation = workflow.recommend_scale_settings() + except Exception: + scale_recommendation = None return ProjectLoadPrefitPayload( workflow=workflow, evaluation=evaluation, scale_recommendation=scale_recommendation, + autoscale_applied_on_load=autoscale_applied_on_load, preview_error=preview_error, ) @@ -2249,9 +2329,9 @@ def build_project_components(self) -> None: ) self.current_settings = settings self.project_setup_tab.append_summary( - "Build SAXS Components requested in Born Approximation " - "(Average).\n" - "Launching the electron-density mapping workflow with the " + "Build SAXS Components requested in 1D Born " + "Approximation (Average).\n" + "Launching the legacy 1D Born workflow with the " "active computed-distribution context.\n" "Use that workspace to calculate per-stoichiometry " "electron-density profiles and Fourier-transformed Born " @@ -2259,7 +2339,34 @@ def build_project_components(self) -> None: ) self._open_electron_density_mapping_tool(preview_mode=False) self.statusBar().showMessage( - "Opened electron-density Born Approximation workflow" + "Opened 1D Born Approximation workflow" + ) + return + if ( + settings.component_build_mode 
+ == COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ): + self._save_settings( + settings, + status_message=( + "Project auto-saved before launching the 3D FFT Born " + "Approximation workflow" + ), + ) + self.current_settings = settings + self.project_setup_tab.append_summary( + "Build SAXS Components requested in 3D FFT Born " + "Approximation.\n" + "Launching the separate 3D FFT Born workflow with the " + "active computed-distribution context.\n" + "Use that workspace to evaluate the Cartesian 3D FFT " + "Born calculation, optional solvent-density contrast, " + "and comparison overlays while the component-export " + "integration is prepared." + ) + self._open_3d_fft_born_approximation_tool(preview_mode=False) + self.statusBar().showMessage( + "Opened 3D FFT Born Approximation workflow" ) return self._save_settings( @@ -2305,8 +2412,10 @@ def _current_project_has_built_components(self) -> bool: ) if settings is None: return False - component_dir = project_artifact_paths(settings).component_dir - return component_dir.is_dir() and any(component_dir.glob("*.txt")) + return self.project_manager.component_artifacts_match_settings( + settings, + artifact_paths=project_artifact_paths(settings), + ) def _on_prefit_field_interaction_requested(self) -> None: if self._current_project_has_built_components(): @@ -2622,12 +2731,15 @@ def _on_show_deprecated_templates_changed(self, enabled: bool) -> None: def save_project_state(self) -> None: try: settings = self._settings_from_project_tab() + rebuild_message = self._component_q_range_rebuild_message(settings) saved_path = self._save_settings(settings) - self._sync_live_project_settings_after_save(settings) + if rebuild_message is None: + self._sync_live_project_settings_after_save(settings) + else: + self.current_settings = settings self.project_setup_tab.append_summary( f"Saved project state to {saved_path}" ) - rebuild_message = self._component_q_range_rebuild_message(settings) if rebuild_message is not None: 
self._show_error( "Expanded q-range requires rebuilding SAXS components", @@ -2759,6 +2871,26 @@ def _normalized_registered_path_value(value: object) -> str | None: text = str(value or "").strip() return text or None + @staticmethod + def _component_q_values_for_q_range_tolerance( + component_dir: Path, + ) -> np.ndarray: + for component_path in sorted(component_dir.glob("*.txt")): + try: + raw_data = np.loadtxt(component_path, comments="#") + except Exception: + continue + if raw_data.size == 0: + continue + if raw_data.ndim == 1: + raw_data = raw_data.reshape(1, -1) + if raw_data.shape[1] < 1: + continue + q_values = np.asarray(raw_data[:, 0], dtype=float) + if q_values.size > 0: + return q_values + return np.asarray([], dtype=float) + def _component_q_range_rebuild_message( self, settings: ProjectSettings, @@ -2825,7 +2957,11 @@ def _component_q_range_rebuild_message( supported_max, ) ) - tolerance = q_range_boundary_tolerance( + tolerance = component_q_range_boundary_tolerance( + settings.component_build_mode, + self._component_q_values_for_q_range_tolerance( + artifact_paths.component_dir + ), supported_min, supported_max, ) @@ -3533,25 +3669,37 @@ def _prefit_cluster_geometry_matches_sync_snapshot( ) def _refresh_prefit_volume_fraction_section(self) -> None: - target = self._current_volume_fraction_target() + target = self._current_volume_fraction_target_details() solvent_weight_target = self._current_solvent_weight_target() visible = self.prefit_workflow is not None self.prefit_tab.set_solute_volume_fraction_visible(visible) parameter_name = None fraction_kind = None + fraction_source = "saxs_effective" if target is not None: - parameter_name, fraction_kind = target + parameter_name, fraction_kind, fraction_source = target self.prefit_tab.set_solute_volume_fraction_target( parameter_name, fraction_kind, + fraction_source, solvent_weight_target, ) self._sync_solution_scattering_tool_targets() def _current_volume_fraction_target(self) -> tuple[str, str] | 
None: + target = self._current_volume_fraction_target_details() + if target is None: + return None + return target[:2] + + def _current_volume_fraction_target_details( + self, + ) -> tuple[str, str, str] | None: if self.prefit_workflow is None: return None - return self.prefit_workflow.volume_fraction_estimator_target() + return ( + self.prefit_workflow.solution_scattering_volume_fraction_target() + ) def _current_solvent_weight_target(self) -> str | None: if self.prefit_workflow is None: @@ -3559,12 +3707,13 @@ def _current_solvent_weight_target(self) -> str | None: return self.prefit_workflow.solvent_weight_estimator_target() def _sync_solution_scattering_tool_targets(self) -> None: - target = self._current_volume_fraction_target() + target = self._current_volume_fraction_target_details() if target is None: parameter_name = None fraction_kind = None + fraction_source = "saxs_effective" else: - parameter_name, fraction_kind = target + parameter_name, fraction_kind, fraction_source = target solvent_weight_parameter = self._current_solvent_weight_target() for window in ( self._solute_volume_fraction_tool_window, @@ -3577,6 +3726,7 @@ def _sync_solution_scattering_tool_targets(self) -> None: window.estimator_widget.set_target_parameter( parameter_name, fraction_kind, + fraction_source, solvent_weight_parameter, ) @@ -3600,7 +3750,6 @@ def _apply_estimator_parameter_to_prefit( maximum = max( float(current_entry.maximum), parameter_value, - 1.0, ) self.prefit_tab.set_parameter_row( parameter_name, @@ -3634,57 +3783,87 @@ def _on_solution_scattering_estimate_calculated( preview_changed = False try: - volume_target = self._current_volume_fraction_target() + volume_target = self._current_volume_fraction_target_details() interaction_estimate = ( estimate_payload.interaction_contrast_estimate ) - if interaction_estimate is not None and volume_target is not None: - parameter_name, fraction_kind = volume_target - parameter_value = ( - float( - 
interaction_estimate.saxs_effective_solute_interaction_ratio + if volume_target is not None: + parameter_name, fraction_kind, fraction_source = volume_target + parameter_value: float | None = None + if fraction_source == "physical": + volume_estimate = estimate_payload.volume_fraction_estimate + if volume_estimate is not None: + parameter_value = ( + float(volume_estimate.solute_volume_fraction) + if fraction_kind == "solute" + else float(volume_estimate.solvent_volume_fraction) + ) + elif interaction_estimate is not None: + parameter_value = ( + float( + interaction_estimate.saxs_effective_solute_interaction_ratio + ) + if fraction_kind == "solute" + else float( + interaction_estimate.saxs_effective_solvent_background_ratio + ) ) - if fraction_kind == "solute" - else float( - interaction_estimate.saxs_effective_solvent_background_ratio + if parameter_value is None: + if hasattr(widget, "append_application_note"): + cast(object, widget).append_application_note( + "Calculated solution-scattering estimates, but the " + f"{parameter_name} target could not be populated " + f"from the requested {fraction_source} volume " + "fraction source." + ) + else: + self._apply_estimator_parameter_to_prefit( + parameter_name, + parameter_value, ) - ) - self._apply_estimator_parameter_to_prefit( - parameter_name, - parameter_value, - ) - applied_notes.append( - f"Applied {parameter_name} = " - f"{parameter_value:.{DISPLAY_FRACTION_DECIMALS}f} " - "from the SAXS-effective interaction ratio." 
- ) - log_lines.extend( - [ - f"Volume-fraction target: {parameter_name}", - f"Model fraction kind: {fraction_kind}", - ( - "Physical solute-associated volume fraction: " - f"{interaction_estimate.physical_solute_associated_volume_fraction:.{DISPLAY_FRACTION_DECIMALS}f}" - ), - ( - "Physical solvent-associated volume fraction: " - f"{interaction_estimate.physical_solvent_associated_volume_fraction:.{DISPLAY_FRACTION_DECIMALS}f}" - ), - ( - "SAXS-effective solute interaction ratio: " - f"{interaction_estimate.saxs_effective_solute_interaction_ratio:.{DISPLAY_FRACTION_DECIMALS}f}" - ), - ( - "SAXS-effective solvent background ratio: " - f"{interaction_estimate.saxs_effective_solvent_background_ratio:.{DISPLAY_FRACTION_DECIMALS}f}" - ), - ( - "Contrast weight factor: " - f"{interaction_estimate.contrast_weight_factor:.6g}" - ), - ] - ) - preview_changed = True + source_label = ( + "physical bulk volume fraction" + if fraction_source == "physical" + else "SAXS-effective interaction ratio" + ) + applied_notes.append( + f"Applied {parameter_name} = " + f"{parameter_value:.{DISPLAY_FRACTION_DECIMALS}f} " + f"from the {source_label}." 
+ ) + log_lines.extend( + [ + f"Volume-fraction target: {parameter_name}", + f"Model fraction kind: {fraction_kind}", + f"Target fraction source: {fraction_source}", + ] + ) + if interaction_estimate is not None: + log_lines.extend( + [ + ( + "Physical solute-associated volume fraction: " + f"{interaction_estimate.physical_solute_associated_volume_fraction:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ( + "Physical solvent-associated volume fraction: " + f"{interaction_estimate.physical_solvent_associated_volume_fraction:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ( + "SAXS-effective solute interaction ratio: " + f"{interaction_estimate.saxs_effective_solute_interaction_ratio:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ( + "SAXS-effective solvent background ratio: " + f"{interaction_estimate.saxs_effective_solvent_background_ratio:.{DISPLAY_FRACTION_DECIMALS}f}" + ), + ( + "Contrast weight factor: " + f"{interaction_estimate.contrast_weight_factor:.6g}" + ), + ] + ) + preview_changed = True elif ( estimate_payload.volume_fraction_estimate is not None and hasattr(widget, "append_application_note") @@ -3704,7 +3883,10 @@ def _on_solution_scattering_estimate_calculated( attenuation_scale = float( estimate_payload.attenuation_estimate.solvent_scattering_scale_factor ) - uses_split_fraction_parameter = volume_target is not None + uses_split_fraction_parameter = ( + volume_target is not None + and volume_target[2] == "saxs_effective" + ) solvent_scale = attenuation_scale if ( not uses_split_fraction_parameter @@ -7752,6 +7934,15 @@ def start_load_step( f"{len(available_elements)} available element" f"{'' if len(available_elements) == 1 else 's'}." 
) + start_load_step( + "Loading representative structure status...", + log_message="Checking Representative Structures readiness.", + ) + representative_status = ( + self._refresh_representative_structure_status(settings) + ) + if log_callback is not None: + log_callback(representative_status) start_load_step( "Loading predicted structure status...", log_message="Checking Predicted Structures readiness.", @@ -7902,10 +8093,6 @@ def _apply_loaded_prefit_payload( self._format_prefit_summary(payload.evaluation) ) self._update_prefit_stoichiometry_status() - if payload.scale_recommendation is not None: - self._append_scale_recommendation_log( - payload.scale_recommendation - ) self.prefit_tab.set_log_text( self._format_prefit_console_intro( evaluation=payload.evaluation, @@ -7916,6 +8103,13 @@ def _apply_loaded_prefit_payload( self.prefit_tab.append_log( "Loaded the Best Prefit preset from the project file." ) + elif ( + payload.autoscale_applied_on_load + and payload.scale_recommendation is not None + ): + self._append_initial_autoscale_log(payload.scale_recommendation) + elif payload.scale_recommendation is not None: + self._append_scale_recommendation_log(payload.scale_recommendation) self._refresh_saved_prefit_states() def _apply_loaded_dream_payload( @@ -7966,16 +8160,25 @@ def _refresh_saved_distributions( active_settings = settings or self.current_settings if active_settings is None: self.project_setup_tab.set_available_distributions([]) + self.project_setup_tab.set_active_distribution(None) self.project_setup_tab.set_current_distribution_details(None) self._update_active_contrast_distribution_view_state(None) return 0 + artifact_paths = project_artifact_paths(active_settings) + active_distribution_id = ( + str(artifact_paths.distribution_id or "").strip() or None + ) records = self.project_manager.list_saved_distributions( active_settings.project_dir ) labels = [] + active_distribution_labels: dict[str, str] = {} tooltips: dict[str, str] = {} details: dict[str, 
str] = {} for record in records: + dream_count = self._distribution_dream_run_count( + record.distribution_dir + ) readiness = [] if record.component_artifacts_ready: readiness.append("components") @@ -7988,10 +8191,12 @@ def _refresh_saved_distributions( ) labels.append( ( - f"{record.label}{readiness_text}", + f"{record.label}{readiness_text} | " + f"{self._distribution_dream_fit_text(dream_count)}", record.distribution_id, ) ) + active_distribution_labels[record.distribution_id] = record.label tooltips[record.distribution_id] = ( self._distribution_tooltip_for_record(record) ) @@ -8000,15 +8205,25 @@ def _refresh_saved_distributions( ) selected_id = None if labels: - selected_id = project_artifact_paths( - active_settings - ).distribution_id + selected_id = active_distribution_id self.project_setup_tab.set_available_distributions( labels, selected_distribution_id=selected_id, distribution_tooltips=tooltips, distribution_details=details, ) + active_distribution_text, active_distribution_tooltip = ( + self._active_distribution_field_state( + active_settings, + active_distribution_id=active_distribution_id, + active_distribution_labels=active_distribution_labels, + distribution_details=details, + ) + ) + self.project_setup_tab.set_active_distribution( + active_distribution_text, + tooltip=active_distribution_tooltip, + ) self.project_setup_tab.set_current_distribution_details( self._current_distribution_details_text(active_settings) ) @@ -8039,6 +8254,18 @@ def _distribution_time_text(value: str | None) -> str: text = str(value or "").strip() return text or "Unavailable" + @staticmethod + def _distribution_dream_fit_text(dream_count: int) -> str: + return f"DREAM fits: {max(int(dream_count), 0)}" + + @staticmethod + def _distribution_component_source_preference_text( + use_representative_structures: bool, + ) -> str: + return component_source_mode_label( + "representative" if use_representative_structures else "average" + ) + @staticmethod def 
_distribution_prefit_summary( distribution_dir: Path, @@ -8095,12 +8322,19 @@ def _distribution_component_count(distribution_dir: Path) -> int: return 0 def _distribution_tooltip_for_record(self, record) -> str: + dream_count = self._distribution_dream_run_count( + record.distribution_dir + ) return "\n".join( [ record.label, f"Distribution ID: {record.distribution_id}", "Build mode: " + component_build_mode_label(record.component_build_mode), + "Component source preference: " + + self._distribution_component_source_preference_text( + record.use_representative_structures + ), "Template: " + str(record.template_name or "Unspecified"), "q-range: " + self._distribution_q_range_text( @@ -8112,6 +8346,7 @@ def _distribution_tooltip_for_record(self, record) -> str: use_experimental_grid=record.use_experimental_grid, q_points=record.q_points, ), + self._distribution_dream_fit_text(dream_count), "Updated: " + self._distribution_time_text(record.updated_at), ] ) @@ -8133,6 +8368,10 @@ def _distribution_details_for_record(self, record) -> str: "Template: " + str(record.template_name or "Unspecified"), "Build mode: " + component_build_mode_label(record.component_build_mode), + "Component source preference: " + + self._distribution_component_source_preference_text( + record.use_representative_structures + ), "Structure weighting: " + ( "Observed + Predicted Structures" @@ -8151,7 +8390,9 @@ def _distribution_details_for_record(self, record) -> str: ), "Saved components: " + ( - f"ready ({component_count} traces)" + "ready " + f"({component_count} traces, built from " + f"{component_source_mode_label(record.built_component_source_mode or 'average')})" if record.component_artifacts_ready else "not built yet" ), @@ -8163,7 +8404,7 @@ def _distribution_details_for_record(self, record) -> str: ), f"Saved prefits: {prefit_count}", f"Best Prefit R^2: {best_prefit_r_squared}", - f"Saved DREAM runs: {dream_count}", + f"Saved {self._distribution_dream_fit_text(dream_count)}", "Best 
DREAM R^2: Available after loading a DREAM run", "Created: " + self._distribution_time_text(record.created_at), "Updated: " + self._distribution_time_text(record.updated_at), @@ -8194,6 +8435,10 @@ def _current_distribution_details_text( + str(settings.selected_model_template or "Unspecified"), "Build mode: " + component_build_mode_label(settings.component_build_mode), + "Component source preference: " + + self._distribution_component_source_preference_text( + settings.use_representative_structures + ), "Structure weighting: " + ( "Observed + Predicted Structures" @@ -8222,6 +8467,33 @@ def _current_distribution_details_text( ] ) + def _active_distribution_field_state( + self, + settings: ProjectSettings, + *, + active_distribution_id: str | None, + active_distribution_labels: dict[str, str], + distribution_details: dict[str, str], + ) -> tuple[str, str]: + if active_distribution_id: + active_label = active_distribution_labels.get( + active_distribution_id + ) + active_details = distribution_details.get(active_distribution_id) + if active_label: + return active_label, ( + active_details + or self._current_distribution_details_text(settings) + ) + return ( + f"Pending: {active_distribution_id}", + self._current_distribution_details_text(settings), + ) + return ( + "No active computed distribution loaded.", + self._current_distribution_details_text(settings), + ) + def _set_dream_tab_enabled(self, enabled: bool) -> None: dream_index = self.tabs.indexOf(self.dream_tab) if dream_index < 0: @@ -8296,6 +8568,51 @@ def _on_predicted_structure_weights_changed(self, enabled: bool) -> None: str(exc), ) + @Slot(bool) + def _on_representative_structures_changed(self, enabled: bool) -> None: + if self.current_settings is None: + return + try: + settings = self._settings_from_project_tab() + settings.use_representative_structures = bool(enabled) + self._save_settings(settings) + self.current_settings = settings + self._apply_project_settings(settings) + 
self.statusBar().showMessage( + "Use Representative Structures enabled" + if enabled + else "Use Representative Structures disabled" + ) + except Exception as exc: + self._show_error( + "Update Representative Structures mode failed", + str(exc), + ) + + @Slot(str) + def _handle_representative_structure_results_changed( + self, + project_dir: str, + ) -> None: + if self.current_settings is None: + return + active_project_dir = Path(self.current_settings.project_dir).resolve() + if active_project_dir != Path(project_dir).expanduser().resolve(): + return + try: + status_text = self._refresh_representative_structure_status( + self.current_settings + ) + self.project_setup_tab.append_summary(status_text) + self.statusBar().showMessage( + "Representative structures updated for the active project" + ) + except Exception as exc: + self.project_setup_tab.append_summary( + "Representative structures changed, but the main UI " + f"could not refresh their status automatically:\n{exc}" + ) + @Slot(str) def _load_saved_distribution(self, distribution_id: str) -> None: if self.current_settings is None: @@ -8340,6 +8657,9 @@ def _settings_from_project_tab(self) -> ProjectSettings: base.use_predicted_structure_weights = ( self.project_setup_tab.use_predicted_structure_weights() ) + base.use_representative_structures = ( + self.project_setup_tab.use_representative_structures() + ) base.frames_dir = ( str(self.project_setup_tab.frames_dir()) if self.project_setup_tab.frames_dir() is not None @@ -8405,6 +8725,9 @@ def _settings_from_project_tab(self) -> ProjectSettings: base.component_trace_color_scheme = ( self.project_setup_tab.component_trace_color_scheme() ) + base.component_plot_state = ( + self.project_setup_tab.component_plot_state() + ) base.experimental_trace_visible = ( self.project_setup_tab.experimental_trace_visible() ) @@ -8425,6 +8748,7 @@ def _settings_from_project_tab(self) -> ProjectSettings: base.prior_histogram_x_axis_order = ( 
self.project_setup_tab.prior_histogram_x_axis_order() ) + base.prior_plot_state = self.project_setup_tab.prior_plot_state() return base def _load_prefit_workflow(self) -> SAXSPrefitWorkflow: @@ -8434,6 +8758,13 @@ def _load_prefit_workflow(self) -> SAXSPrefitWorkflow: self.current_settings.project_dir, template_name=self.current_settings.selected_model_template, ) + initial_autoscale_recommendation = None + try: + initial_autoscale_recommendation = ( + self.prefit_workflow.auto_apply_autoscale_on_load() + ) + except Exception: + initial_autoscale_recommendation = None self.current_settings = self.prefit_workflow.settings self._sync_active_template_controls( self.prefit_workflow.template_spec.name, @@ -8468,6 +8799,10 @@ def _load_prefit_workflow(self) -> SAXSPrefitWorkflow: self.prefit_tab.append_log( "Loaded the Best Prefit preset from the project file." ) + elif initial_autoscale_recommendation is not None: + self._append_initial_autoscale_log( + initial_autoscale_recommendation + ) self._refresh_saved_prefit_states() self._refresh_model_only_mode_state() return self.prefit_workflow @@ -8514,6 +8849,7 @@ def _apply_prefit_template_fallback( self.prefit_tab.set_solute_volume_fraction_target( None, None, + "saxs_effective", solvent_weight_target, ) else: @@ -8548,9 +8884,16 @@ def _apply_prefit_template_fallback( @staticmethod def _volume_fraction_target_for_template_spec( template_spec, - ) -> tuple[str, str] | None: + ) -> tuple[str, str, str] | None: if template_spec is None: return None + support = template_spec.solution_scattering_support + if support.volume_fraction_parameter is not None: + return ( + support.volume_fraction_parameter, + support.volume_fraction_kind, + support.volume_fraction_source, + ) parameter_names = { str(parameter.name).strip() for parameter in template_spec.parameters @@ -8558,10 +8901,10 @@ def _volume_fraction_target_for_template_spec( } for candidate in SOLUTE_VOLUME_FRACTION_PARAMETER_NAMES: if candidate in parameter_names: - 
return candidate, "solute" + return candidate, "solute", "saxs_effective" for candidate in SOLVENT_VOLUME_FRACTION_PARAMETER_NAMES: if candidate in parameter_names: - return candidate, "solvent" + return candidate, "solvent", "saxs_effective" return None @staticmethod @@ -8720,6 +9063,12 @@ def _refresh_component_plot(self) -> list[Path]: self.project_setup_tab.draw_component_plot(None) return [] artifact_paths = project_artifact_paths(self.current_settings) + if not self.project_manager.component_artifacts_match_settings( + self.current_settings, + artifact_paths=artifact_paths, + ): + self.project_setup_tab.draw_component_plot(None) + return [] component_paths = sorted(artifact_paths.component_dir.glob("*.txt")) self.project_setup_tab.draw_component_plot(component_paths or None) return component_paths @@ -8904,10 +9253,11 @@ def _refresh_predicted_structure_status( ) return status_text artifact_paths = project_artifact_paths(active_settings) - component_artifacts_ready = bool( - artifact_paths.component_dir.is_dir() - and any(artifact_paths.component_dir.glob("*.txt")) - and artifact_paths.component_map_file.is_file() + component_artifacts_ready = ( + self.project_manager.component_artifacts_match_settings( + active_settings, + artifact_paths=artifact_paths, + ) ) prior_artifacts_ready = bool( artifact_paths.prior_weights_file.is_file() @@ -8954,6 +9304,155 @@ def _refresh_predicted_structure_status( self.project_setup_tab.set_predicted_structure_status_text(status_text) return status_text + def _refresh_representative_structure_status( + self, + settings: ProjectSettings | None = None, + ) -> str: + active_settings = settings or self.current_settings + if active_settings is None: + self.project_setup_tab.set_representative_structure_availability( + False + ) + status_text = ( + self.project_setup_tab._default_representative_structure_status_text() + ) + self.project_setup_tab.set_representative_structure_status_text( + status_text + ) + return status_text + 
representative_state = ( + self.project_manager.inspect_representative_structures( + active_settings.project_dir, + prior_weights_path=project_artifact_paths( + active_settings + ).prior_weights_file, + ) + ) + representative_available = bool(representative_state.selection_ready) + mode_labels = { + "partialsolv": "partialsolv", + "nosolv": "nosolv", + "fullsolv": "fullsolv", + } + available_modes = tuple( + mode_labels.get(mode, mode) + for mode in representative_state.available_modes + ) + if ( + active_settings.use_representative_structures + and not representative_available + and self.current_settings is not None + and Path(self.current_settings.project_dir).resolve() + == Path(active_settings.project_dir).expanduser().resolve() + ): + corrected_settings = ProjectSettings.from_dict( + active_settings.to_dict() + ) + corrected_settings.use_representative_structures = False + self.project_manager.save_project( + corrected_settings, + refresh_registered_paths=False, + ) + self.current_settings = corrected_settings + active_settings = corrected_settings + effective_enabled = bool( + active_settings.use_representative_structures + and representative_available + ) + self.project_setup_tab.set_representative_structure_availability( + representative_available, + representative_count=representative_state.representative_count, + available_modes=available_modes, + ) + self.project_setup_tab.set_use_representative_structures( + effective_enabled + ) + if not representative_available: + if representative_state.representative_count > 0: + status_text = ( + "Representative Structures mode is off.\n" + f"Saved {representative_state.representative_count} of " + f"{max(representative_state.expected_representative_count, 1)} " + "required representative structure" + f"{'' if max(representative_state.expected_representative_count, 1) == 1 else 's'}.\n" + f"Missing bins: {representative_state.missing_representative_count}\n" + f"Invalid bins: 
{representative_state.invalid_representative_count}\n" + "Finish computing every representative structure in the " + "selected distribution before enabling Use Representative " + "Structures here." + ) + else: + status_text = ( + "Representative Structures mode is off.\n" + "No saved project representative structures were found yet.\n" + "Open Tools > Structure Analysis > Open Representative " + "Structures, compute the representative set, then enable " + "Use Representative Structures here." + ) + self.project_setup_tab.set_representative_structure_status_text( + status_text + ) + return status_text + mode_text = ( + ", ".join(available_modes) if available_modes else "partialsolv" + ) + selection_file = representative_state.representative_selection_file + location_text = ( + "" + if selection_file is None + else f"\nRepresentative metadata: {selection_file.name}" + ) + if not effective_enabled: + status_text = ( + "Representative Structures mode is off.\n" + f"Found {representative_state.representative_count} saved " + "representative structure" + f"{'' if representative_state.representative_count == 1 else 's'} " + f"for this project.\nAvailable saved sets: {mode_text}." + f"{location_text}\nEnable Use Representative Structures to " + "make compatible Debye, 1D Born, 3D FFT, and RMCSetup workflows " + "prefer the saved project representatives instead of average cluster folders." 
+ ) + self.project_setup_tab.set_representative_structure_status_text( + status_text + ) + return status_text + artifact_paths = project_artifact_paths(active_settings) + component_artifacts_ready = ( + self.project_manager.component_artifacts_ready( + active_settings, + artifact_paths=artifact_paths, + ) + ) + component_artifacts_match = ( + self.project_manager.component_artifacts_match_settings( + active_settings, + artifact_paths=artifact_paths, + ) + ) + rebuild_note = "" + if component_artifacts_ready and not component_artifacts_match: + rebuild_note = ( + "\nRebuild SAXS components to replace the saved average traces " + "with representative-structure traces for this computed " + "distribution. The saved prior weights still come from the full " + "stoichiometry distribution." + ) + status_text = ( + "Representative Structures mode is on.\n" + f"Using {representative_state.representative_count} saved project " + "representative structure" + f"{'' if representative_state.representative_count == 1 else 's'} " + f"with saved sets: {mode_text}.{location_text}\n" + "Compatible Debye, 1D Born, 3D FFT, and RMCSetup workflows will " + "prefer these project representatives over average cluster folders." 
+ f"{rebuild_note}" + ) + self.project_setup_tab.set_representative_structure_status_text( + status_text + ) + return status_text + def _refresh_saved_prefit_states( self, *, @@ -9517,6 +10016,98 @@ def _set_auto_snap_panes_enabled( def _recent_projects_settings(self) -> QSettings: return QSettings("SAXShell", "SAXS") + def _packmol_docker_settings(self) -> QSettings: + return QSettings("SAXShell", "RMCSetup") + + def _recent_packmol_docker_presets(self) -> list[PackmolDockerLink]: + raw_value = self._packmol_docker_settings().value( + PACKMOL_DOCKER_PRESETS_KEY, + "[]", + ) + if isinstance(raw_value, str): + try: + payload = json.loads(raw_value) + except Exception: + payload = [] + elif isinstance(raw_value, (list, tuple)): + payload = list(raw_value) + else: + payload = [] + presets: list[PackmolDockerLink] = [] + for entry in payload: + preset = PackmolDockerLink.from_dict( + dict(entry) if isinstance(entry, dict) else None + ) + if preset is not None: + presets.append(preset) + return presets + + def _remember_packmol_docker_preset( + self, + link: PackmolDockerLink, + ) -> None: + preset = PackmolDockerLink.from_dict(link.to_preset_dict()) + if preset is None: + return + signature = ( + preset.container_name, + preset.packmol_command, + preset.shell_command, + preset.container_project_root, + ) + kept = [ + existing + for existing in self._recent_packmol_docker_presets() + if ( + existing.container_name, + existing.packmol_command, + existing.shell_command, + existing.container_project_root, + ) + != signature + ] + payload = [preset.to_preset_dict()] + [ + item.to_preset_dict() for item in kept[:7] + ] + self._packmol_docker_settings().setValue( + PACKMOL_DOCKER_PRESETS_KEY, + json.dumps(payload), + ) + + def _open_packmol_docker_link_dialog(self) -> None: + if self.current_settings is None: + QMessageBox.information( + self, + "No project loaded", + "Open or create a SAXS project before linking a Packmol " + "Docker container from the main UI.", + ) + 
return + project_dir = Path(self.current_settings.project_dir).resolve() + rmcsetup_paths = build_rmcsetup_paths(project_dir) + current_link = load_packmol_docker_link_metadata( + rmcsetup_paths.packmol_docker_link_path + ) + dialog = PackmolDockerLinkDialog( + current_link=current_link, + recent_presets=self._recent_packmol_docker_presets(), + parent=self, + ) + if not dialog.exec(): + return + link = dialog.selected_link() + if link is None: + return + link.linked_at = datetime.now().isoformat(timespec="seconds") + save_packmol_docker_link_metadata( + rmcsetup_paths.packmol_docker_link_path, + link, + ) + self._remember_packmol_docker_preset(link) + self.statusBar().showMessage( + f"Linked Packmol Docker container {link.container_name}" + ) + def _recent_project_paths(self) -> list[str]: raw_value = self._recent_projects_settings().value( RECENT_PROJECTS_KEY, @@ -9582,6 +10173,7 @@ def _update_file_menu_state(self) -> None: has_project = self.current_settings is not None self.save_project_action.setEnabled(has_project) self.save_project_as_action.setEnabled(has_project) + self.link_packmol_docker_action.setEnabled(has_project) def _remap_copied_project_paths( self, @@ -9658,6 +10250,55 @@ def _active_project_launch_settings(self) -> ProjectSettings | None: except Exception: return self.current_settings + def _preferred_representative_input_path( + self, + settings: ProjectSettings | None, + ) -> Path | None: + if settings is None or not settings.use_representative_structures: + return None + try: + representative_state = ( + self.project_manager.inspect_representative_structures( + settings.project_dir, + prior_weights_path=project_artifact_paths( + settings + ).prior_weights_file, + ) + ) + except Exception: + return None + if not representative_state.selection_ready: + return None + available_modes = set(representative_state.available_modes) + candidate_dirs = [ + ( + representative_state.partialsolv_dir + if "partialsolv" in available_modes + else None + ), + 
( + representative_state.nosolv_dir + if "nosolv" in available_modes + else None + ), + ( + representative_state.fullsolv_dir + if "fullsolv" in available_modes + else None + ), + ] + existing_dirs = [ + directory + for directory in candidate_dirs + if directory is not None and directory.is_dir() + ] + if len(existing_dirs) == 1: + return existing_dirs[0] + selection_file = representative_state.representative_selection_file + if selection_file is not None and selection_file.parent.is_dir(): + return selection_file.parent + return existing_dirs[0] if existing_dirs else None + def _active_contrast_distribution_artifact_context( self, settings: ProjectSettings | None = None, @@ -10192,6 +10833,84 @@ def _open_structure_viewer_tool(self) -> None: self._track_child_tool_window(window) self.statusBar().showMessage("Opened Structure Viewer") + def _open_solvent_shell_builder_tool(self) -> None: + from saxshell.fullrmc.ui.solvent_shell_builder_window import ( + launch_solvent_shell_builder_ui, + ) + + project_dir = None + initial_input_path = None + if self.current_settings is not None: + project_dir = Path(self.current_settings.project_dir).resolve() + initial_input_path = ( + self.current_settings.resolved_pdb_frames_dir + or self.current_settings.resolved_frames_dir + ) + window = launch_solvent_shell_builder_ui( + initial_project_dir=project_dir, + initial_input_path=initial_input_path, + ) + self._track_child_tool_window(window) + if project_dir is not None: + self.statusBar().showMessage( + f"Opened solvent shell builder (beta) for {project_dir}" + ) + else: + self.statusBar().showMessage("Opened solvent shell builder (beta)") + + def _open_representative_finder_tool(self) -> None: + from saxshell.representativefinder.ui.main_window import ( + launch_representativefinder_ui, + ) + + project_dir = None + initial_input_path = None + if self.current_settings is not None: + project_dir = Path(self.current_settings.project_dir).resolve() + initial_input_path = 
self.current_settings.resolved_clusters_dir + window = launch_representativefinder_ui( + initial_project_dir=project_dir, + initial_input_path=initial_input_path, + ) + project_results_changed = getattr( + window, "project_results_changed", None + ) + if project_results_changed is not None and hasattr( + project_results_changed, "connect" + ): + project_results_changed.connect( + self._handle_representative_structure_results_changed + ) + self._track_child_tool_window(window) + if project_dir is not None: + self.statusBar().showMessage( + "Opened representative structures for " f"{project_dir}" + ) + else: + self.statusBar().showMessage("Opened representative structures") + + def _open_representative_cli_setup_tool(self) -> None: + from saxshell.representativefinder.ui.run_file_window import ( + launch_representativefinder_run_file_ui, + ) + + project_dir = None + initial_input_path = None + if self.current_settings is not None: + project_dir = Path(self.current_settings.project_dir).resolve() + initial_input_path = self.current_settings.resolved_clusters_dir + window = launch_representativefinder_run_file_ui( + initial_project_dir=project_dir, + initial_input_path=initial_input_path, + ) + self._track_child_tool_window(window) + if project_dir is not None: + self.statusBar().showMessage( + "Opened representative CLI setup for " f"{project_dir}" + ) + else: + self.statusBar().showMessage("Opened representative CLI setup") + def _open_contrast_mode_tool( self, *, @@ -10333,6 +11052,7 @@ def _open_electron_density_mapping_tool( distribution_id = None distribution_root_dir = None use_predicted_structure_weights = False + preferred_representative_input = None if settings is not None: project_dir = Path(settings.project_dir).resolve() q_min = settings.q_min @@ -10340,13 +11060,18 @@ def _open_electron_density_mapping_tool( use_predicted_structure_weights = bool( settings.use_predicted_structure_weights ) + preferred_representative_input = ( + 
self._preferred_representative_input_path(settings) + ) candidate_paths = ( ( + preferred_representative_input, settings.resolved_pdb_frames_dir, settings.resolved_frames_dir, ) if preview_mode else ( + preferred_representative_input, settings.resolved_clusters_dir, settings.resolved_pdb_frames_dir, settings.resolved_frames_dir, @@ -10383,29 +11108,145 @@ def _open_electron_density_mapping_tool( window.set_auto_snap_enabled(self._auto_snap_panes_enabled) self._connect_born_approximation_updates(window) self._track_child_tool_window(window) + if input_path is not None: + self.statusBar().showMessage( + ("Opened 1D Born Approximation preview for " f"{input_path}") + if preview_mode + else f"Opened 1D Born Approximation for {input_path}" + ) + elif project_dir is not None: + self.statusBar().showMessage( + ("Opened 1D Born Approximation preview for " f"{project_dir}") + if preview_mode + else f"Opened 1D Born Approximation for {project_dir}" + ) + else: + self.statusBar().showMessage( + "Opened 1D Born Approximation preview" + if preview_mode + else "Opened 1D Born Approximation" + ) + + def _open_3d_fft_born_approximation_tool( + self, + *, + preview_mode: bool = True, + ) -> None: + try: + from saxshell.saxs.contrast_fft.ui.main_window import ( + launch_3d_fft_born_approximation_ui, + ) + except Exception as exc: + self._show_error( + "3D FFT Born Approximation Error", + "Could not load the 3D FFT Born Approximation window:\n" + f"{exc}", + ) + return + + try: + settings = self._active_project_launch_settings() + project_dir = None + input_path = None + output_dir = None + q_min = None + q_max = None + distribution_id = None + distribution_root_dir = None + use_predicted_structure_weights = False + use_representative_structures = False + preferred_representative_input = None + if settings is not None: + project_dir = Path(settings.project_dir).resolve() + q_min = settings.q_min + q_max = settings.q_max + use_predicted_structure_weights = bool( + 
settings.use_predicted_structure_weights + ) + use_representative_structures = bool( + settings.use_representative_structures + ) + preferred_representative_input = ( + self._preferred_representative_input_path(settings) + ) + candidate_paths = ( + ( + preferred_representative_input, + settings.resolved_pdb_frames_dir, + settings.resolved_frames_dir, + ) + if preview_mode + else ( + preferred_representative_input, + settings.resolved_clusters_dir, + settings.resolved_pdb_frames_dir, + settings.resolved_frames_dir, + ) + ) + for candidate in candidate_paths: + if candidate is not None and candidate.exists(): + input_path = candidate + break + if not preview_mode: + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + ) + distribution_id = artifact_paths.distribution_id + distribution_root_dir = artifact_paths.root_dir + output_dir = ( + artifact_paths.root_dir / "born_approximation_3d_fft" + ) + window = launch_3d_fft_born_approximation_ui( + initial_project_dir=project_dir, + initial_input_path=input_path, + initial_output_dir=output_dir, + initial_project_q_min=q_min, + initial_project_q_max=q_max, + initial_distribution_id=distribution_id, + initial_distribution_root_dir=distribution_root_dir, + initial_use_predicted_structure_weights=( + use_predicted_structure_weights + ), + initial_use_representative_structures=( + use_representative_structures + ), + preview_mode=preview_mode, + ) + except Exception as exc: + self._show_error( + "3D FFT Born Approximation Error", + "Could not open the 3D FFT Born Approximation window:\n" + f"{exc}", + ) + return + if hasattr(window, "set_auto_snap_enabled"): + window.set_auto_snap_enabled(self._auto_snap_panes_enabled) + self._connect_born_approximation_updates(window) + self._track_child_tool_window(window) if input_path is not None: self.statusBar().showMessage( ( - "Opened electron density mapping preview for " + "Opened 3D FFT Born Approximation preview for " f"{input_path}" ) if 
preview_mode - else f"Opened electron density mapping for {input_path}" + else f"Opened 3D FFT Born Approximation for {input_path}" ) elif project_dir is not None: self.statusBar().showMessage( ( - "Opened electron density mapping preview for " + "Opened 3D FFT Born Approximation preview for " f"{project_dir}" ) if preview_mode - else f"Opened electron density mapping for {project_dir}" + else f"Opened 3D FFT Born Approximation for {project_dir}" ) else: self.statusBar().showMessage( - "Opened electron density mapping preview" + "Opened 3D FFT Born Approximation preview" if preview_mode - else "Opened electron density mapping" + else "Opened 3D FFT Born Approximation" ) @Slot(object) @@ -11937,6 +12778,34 @@ def _append_scale_recommendation_log( ) self.prefit_tab.append_log(message) + def _append_initial_autoscale_log( + self, + recommendation: PrefitScaleRecommendation, + ) -> None: + offset_lines: list[str] = [] + if recommendation.recommended_offset is not None: + offset_lines.append( + f"Offset: {recommendation.recommended_offset:.6g}" + ) + if ( + recommendation.recommended_offset_minimum is not None + and recommendation.recommended_offset_maximum is not None + ): + offset_lines.append( + "Offset range: " + f"{recommendation.recommended_offset_minimum:.6g} to " + f"{recommendation.recommended_offset_maximum:.6g}" + ) + message = ( + "Applied initial autoscale settings for this template.\n" + + f"Scale: {recommendation.recommended_scale:.6g}\n" + + f"Scale range: {recommendation.recommended_minimum:.6g} " + + f"to {recommendation.recommended_maximum:.6g}\n" + + ("\n".join(offset_lines) + "\n" if offset_lines else "") + + f"Points used: {recommendation.points_used}" + ) + self.prefit_tab.append_log(message) + def _maybe_append_scale_recommendation( self, entries=None, diff --git a/src/saxshell/saxs/ui/prefit_tab.py b/src/saxshell/saxs/ui/prefit_tab.py index 2bcf0e6..191a2e2 100644 --- a/src/saxshell/saxs/ui/prefit_tab.py +++ b/src/saxshell/saxs/ui/prefit_tab.py 
@@ -35,6 +35,14 @@ QWidget, ) +from saxshell.plotting import ( + Q_A_INVERSE_LABEL, + LinePlotDefaults, + LinePlotEditorControls, + LinePlotSeriesDefaults, + LinePlotSettings, + PlotEditorWindow, +) from saxshell.saxs._model_templates import TemplateSpec from saxshell.saxs.prefit import ( ClusterGeometryMetadataRow, @@ -230,6 +238,9 @@ def __init__(self, parent: QWidget | None = None) -> None: DEFAULT_IONIC_RADIUS_TYPE ) self._active_template_name: str | None = None + self._line_plot_settings = LinePlotSettings() + self._plot_editor_window: PlotEditorWindow | None = None + self._plot_editor_controls: LinePlotEditorControls | None = None self._suspend_template_selection_signal = False self._build_ui() self._install_field_interaction_watchers() @@ -600,6 +611,8 @@ def _build_plot_group(self) -> QGroupBox: self.log_y_checkbox = QCheckBox("Log Y") self.log_y_checkbox.setChecked(True) self.log_y_checkbox.toggled.connect(self._redraw_current_plot) + self.open_plot_editor_button = QPushButton("Open Plot Editor") + self.open_plot_editor_button.clicked.connect(self.open_plot_editor) self.save_plot_data_button = QPushButton("Export Plot Data") self.save_plot_data_button.clicked.connect( self.save_plot_data_requested.emit @@ -610,6 +623,7 @@ def _build_plot_group(self) -> QGroupBox: controls.addWidget(self.show_structure_factor_trace_checkbox) controls.addWidget(self.log_x_checkbox) controls.addWidget(self.log_y_checkbox) + controls.addWidget(self.open_plot_editor_button) controls.addWidget(self.save_plot_data_button) controls.addStretch(1) layout.addLayout(controls) @@ -1037,6 +1051,7 @@ def set_solute_volume_fraction_target( self, parameter_name: str | None, fraction_kind: str | None, + fraction_source: str = "saxs_effective", solvent_weight_parameter: str | None = None, ) -> None: target_messages: list[str] = [] @@ -1046,11 +1061,20 @@ def set_solute_volume_fraction_target( if str(fraction_kind).strip() == "solute" else "solvent" ) + source_label = ( + "physical volume 
fraction" + if str(fraction_source).strip() == "physical" + else "SAXS-effective interaction fraction" + ) target_messages.append( - f"{target_label} SAXS-effective interaction fraction -> {parameter_name}" + f"{target_label} {source_label} -> {parameter_name}" ) if solvent_weight_parameter: - if parameter_name and fraction_kind: + if ( + parameter_name + and fraction_kind + and str(fraction_source).strip() == "saxs_effective" + ): target_messages.append( f"attenuation solvent scale -> {solvent_weight_parameter}" ) @@ -1071,6 +1095,7 @@ def set_solute_volume_fraction_target( self.solute_volume_fraction_widget.set_target_parameter( parameter_name, fraction_kind, + fraction_source, solvent_weight_parameter, ) @@ -1708,20 +1733,236 @@ def run_config(self) -> PrefitRunConfig: max_nfev=int(self.nfev_spin.value()), ) - def plot_evaluation( + def open_plot_editor(self) -> None: + if self._plot_editor_window is not None: + self._plot_editor_window.show() + self._plot_editor_window.raise_() + self._plot_editor_window.activateWindow() + self._plot_editor_window.refresh_preview() + return + + defaults = self._current_plot_defaults(self._current_evaluation) + self._line_plot_settings.sync_series(defaults.series_defaults) + self._plot_editor_controls = LinePlotEditorControls( + settings=self._line_plot_settings, + defaults=defaults, + parent=self, + ) + self._plot_editor_controls.label_settings_changed.connect( + self._redraw_current_plot + ) + self._plot_editor_controls.settings_changed.connect( + self._redraw_current_plot + ) + self._plot_editor_window = PlotEditorWindow( + window_title="SAXS Prefit Plot Editor", + controls_widget=self._plot_editor_controls, + render_preview=self._render_plot_editor_preview, + pickle_state_provider=self._plot_editor_pickle_state, + apply_loaded_pickle_state=self._apply_loaded_plot_editor_pickle_state, + parent=self, + ) + self._plot_editor_window.closed.connect(self._on_plot_editor_closed) + self._plot_editor_window.refresh_preview() + 
self._plot_editor_window.show() + self._plot_editor_window.raise_() + self._plot_editor_window.activateWindow() + + def _on_plot_editor_closed(self) -> None: + self._plot_editor_window = None + self._plot_editor_controls = None + + def _current_plot_defaults( + self, + evaluation: PrefitEvaluation | None, + ) -> LinePlotDefaults: + has_evaluation = evaluation is not None + has_experimental = ( + has_evaluation and evaluation.experimental_intensities is not None + ) + has_residuals = has_evaluation and evaluation.residuals is not None + structure_values = ( + np.asarray(evaluation.structure_factor_trace, dtype=float) + if has_evaluation and evaluation.structure_factor_trace is not None + else np.asarray([], dtype=float) + ) + has_structure_factor_axis = bool( + has_evaluation + and self.show_structure_factor_trace_checkbox.isChecked() + and np.any(np.isfinite(structure_values)) + ) + series_defaults: list[LinePlotSeriesDefaults] = [] + if ( + has_experimental + and self.show_experimental_trace_checkbox.isChecked() + ): + series_defaults.append( + LinePlotSeriesDefaults( + key="experimental", + label="Experimental", + axis_label="Main", + ) + ) + if ( + has_evaluation + and self.show_solvent_trace_checkbox.isChecked() + and evaluation.solvent_contribution is not None + ): + solvent_values = np.asarray( + evaluation.solvent_contribution, + dtype=float, + ) + solvent_mask = np.isfinite(solvent_values) + if self.log_y_checkbox.isChecked(): + solvent_mask &= solvent_values > 0.0 + if np.any(solvent_mask): + series_defaults.append( + LinePlotSeriesDefaults( + key="solvent_contribution", + label="Solvent contribution", + axis_label="Main", + ) + ) + if has_structure_factor_axis: + series_defaults.append( + LinePlotSeriesDefaults( + key="structure_factor", + label="Structure factor S(q)", + axis_label="Structure Factor", + ) + ) + if has_evaluation and self.show_model_trace_checkbox.isChecked(): + series_defaults.append( + LinePlotSeriesDefaults( + key="model", + 
label="Model", + axis_label="Main", + ) + ) + if has_experimental and has_residuals: + series_defaults.append( + LinePlotSeriesDefaults( + key="residual", + label="Residual", + axis_label="Residual", + ) + ) + return LinePlotDefaults( + title="", + x_label=Q_A_INVERSE_LABEL, + primary_y_label="Intensity (arb. units)", + secondary_y_label="S(q)", + residual_y_label="Residual", + has_secondary_y_axis=has_structure_factor_axis, + has_residual_y_axis=bool(has_experimental and has_residuals), + has_annotation=has_evaluation, + default_legend_location="best", + default_show_annotation=True, + series_defaults=tuple(series_defaults), + ) + + def _refresh_plot_editor_controls(self, *, force: bool = False) -> None: + if self._plot_editor_controls is None: + return + defaults = self._current_plot_defaults(self._current_evaluation) + self._line_plot_settings.sync_series(defaults.series_defaults) + if force or self._plot_editor_controls.needs_default_sync(defaults): + self._plot_editor_controls.sync_defaults(defaults) + + def _plot_editor_pickle_state(self) -> dict[str, object]: + return { + "plot_editor_state": { + "kind": "line_plot_editor_state", + "version": 1, + "plot_context": "saxs_prefit_preview", + "line_plot_settings": self._line_plot_settings.to_dict(), + "panel_state": { + "show_experimental": bool( + self.show_experimental_trace_checkbox.isChecked() + ), + "show_model": bool( + self.show_model_trace_checkbox.isChecked() + ), + "show_solvent": bool( + self.show_solvent_trace_checkbox.isChecked() + ), + "show_structure_factor": bool( + self.show_structure_factor_trace_checkbox.isChecked() + ), + "log_x": bool(self.log_x_checkbox.isChecked()), + "log_y": bool(self.log_y_checkbox.isChecked()), + }, + } + } + + def _apply_loaded_plot_editor_pickle_state( + self, + payload: dict[str, object], + ) -> bool: + editor_state = payload.get("plot_editor_state") + if not isinstance(editor_state, dict): + return False + if str(editor_state.get("kind")) != 
"line_plot_editor_state": + return False + if str(editor_state.get("plot_context")) != "saxs_prefit_preview": + return False + plot_settings = editor_state.get("line_plot_settings") + if isinstance(plot_settings, dict): + self._line_plot_settings.update_from_dict(plot_settings) + panel_state = editor_state.get("panel_state") + if isinstance(panel_state, dict): + self.show_experimental_trace_checkbox.setChecked( + bool(panel_state.get("show_experimental", True)) + ) + self.show_model_trace_checkbox.setChecked( + bool(panel_state.get("show_model", True)) + ) + self.show_solvent_trace_checkbox.setChecked( + bool(panel_state.get("show_solvent", False)) + ) + self.show_structure_factor_trace_checkbox.setChecked( + bool(panel_state.get("show_structure_factor", False)) + ) + self.log_x_checkbox.setChecked( + bool(panel_state.get("log_x", True)) + ) + self.log_y_checkbox.setChecked( + bool(panel_state.get("log_y", True)) + ) + self._refresh_plot_editor_controls(force=True) + self._redraw_current_plot() + return True + + def _render_plot_editor_preview(self, figure: Figure) -> None: + self._render_evaluation_figure( + figure, + self._current_evaluation, + interactive=False, + ) + + def _render_evaluation_figure( self, + figure: Figure, evaluation: PrefitEvaluation | None, + *, + interactive: bool, ) -> None: - self._current_evaluation = evaluation - self._legend_line_map.clear() - self._legend_handle_lookup.clear() - for axis in self.figure.axes: - axis.set_xscale("linear") - self.figure.clear() + if interactive: + self._legend_line_map.clear() + self._legend_handle_lookup.clear() + for axis in figure.axes: + try: + axis.set_xscale("linear") + axis.set_yscale("linear") + except Exception: + continue + figure.clear() self._update_prefit_trace_toggle_state(evaluation) self._update_plot_group_title() + defaults = self._current_plot_defaults(evaluation) + self._line_plot_settings.sync_series(defaults.series_defaults) if evaluation is None: - axis = self.figure.add_subplot(111) 
+ axis = figure.add_subplot(111) axis.text( 0.5, 0.5, @@ -1730,21 +1971,24 @@ def plot_evaluation( va="center", ) axis.set_axis_off() - self.canvas.draw() + if interactive: + self._refresh_plot_editor_controls(force=True) + figure.tight_layout() return has_experimental = evaluation.experimental_intensities is not None has_residuals = evaluation.residuals is not None if has_experimental and has_residuals: - grid = self.figure.add_gridspec(2, 1, height_ratios=[3, 1]) - top = self.figure.add_subplot(grid[0, 0]) - bottom = self.figure.add_subplot(grid[1, 0], sharex=top) + grid = figure.add_gridspec(2, 1, height_ratios=[3, 1]) + top = figure.add_subplot(grid[0, 0]) + bottom = figure.add_subplot(grid[1, 0], sharex=top) else: - top = self.figure.add_subplot(111) + top = figure.add_subplot(111) bottom = None - plotted_lines = [] + plotted_lines: list[object] = [] structure_axis = None + font_family = self._line_plot_settings.font_family.strip() if ( has_experimental @@ -1754,7 +1998,10 @@ def plot_evaluation( evaluation.q_values, evaluation.experimental_intensities, color="black", - label="Experimental", + label=self._line_plot_settings.display_series_label( + "experimental", + "Experimental", + ), ) plotted_lines.append(experimental_line) @@ -1775,7 +2022,10 @@ def plot_evaluation( solvent_values[solvent_mask], color="green", linewidth=1.5, - label="Solvent contribution", + label=self._line_plot_settings.display_series_label( + "solvent_contribution", + "Solvent contribution", + ), ) plotted_lines.append(solvent_line) @@ -1801,9 +2051,17 @@ def plot_evaluation( color="tab:purple", linestyle="--", linewidth=1.5, - label="Structure factor S(q)", + label=self._line_plot_settings.display_series_label( + "structure_factor", + "Structure factor S(q)", + ), + ) + structure_axis.set_ylabel( + self._line_plot_settings.resolve_secondary_y_label( + defaults + ), + color="tab:purple", ) - structure_axis.set_ylabel("S(q)", color="tab:purple") structure_axis.tick_params(axis="y", 
colors="tab:purple") structure_axis.spines["right"].set_color("tab:purple") plotted_lines.append(structure_line) @@ -1813,29 +2071,81 @@ def plot_evaluation( evaluation.q_values, evaluation.model_intensities, color="tab:red", - label="Model", + label=self._line_plot_settings.display_series_label( + "model", + "Model", + ), ) plotted_lines.append(model_line) + top.set_xscale("log" if self.log_x_checkbox.isChecked() else "linear") top.set_yscale("log" if self.log_y_checkbox.isChecked() else "linear") - top.set_ylabel("Intensity (arb. units)") - top.text( - 0.02, - 0.02, - "\n".join(self._prefit_metric_lines(evaluation)), - transform=top.transAxes, - ha="left", - va="bottom", - fontsize=9, - bbox={ - "boxstyle": "round,pad=0.35", - "facecolor": "white", - "edgecolor": "0.6", - "alpha": 0.85, - }, - ) - if plotted_lines: - self._build_interactive_legend(top, plotted_lines) + top.set_ylabel( + self._line_plot_settings.resolve_primary_y_label(defaults) + ) + if self._line_plot_settings.resolve_show_annotation(defaults): + annotation_kwargs: dict[str, object] = { + "transform": top.transAxes, + "ha": "left", + "va": "bottom", + "fontsize": self._line_plot_settings.annotation_font_size, + "bbox": { + "boxstyle": "round,pad=0.35", + "facecolor": "white", + "edgecolor": "0.6", + "alpha": 0.85, + }, + } + if font_family: + annotation_kwargs["fontfamily"] = font_family + top.text( + 0.02, + 0.02, + "\n".join(self._prefit_metric_lines(evaluation)), + **annotation_kwargs, + ) + + title = self._line_plot_settings.resolve_title(defaults) + if title: + title_kwargs: dict[str, object] = { + "x": self._line_plot_settings.resolve_title_position_x( + defaults + ), + "y": self._line_plot_settings.resolve_title_position_y( + defaults + ), + "fontsize": self._line_plot_settings.title_font_size, + } + if font_family: + title_kwargs["fontfamily"] = font_family + top.set_title(title, **title_kwargs) + else: + top.set_title("") + + if plotted_lines and 
self._line_plot_settings.resolve_show_legend( + defaults + ): + legend_location = self._line_plot_settings.resolve_legend_location( + defaults + ) + legend_font_size = self._line_plot_settings.legend_font_size + if interactive: + self._build_interactive_legend( + top, + plotted_lines, + location=legend_location, + font_size=legend_font_size, + font_family=font_family, + ) + else: + preview_legend = top.legend( + handles=plotted_lines, + loc=legend_location, + fontsize=legend_font_size, + ) + if preview_legend is not None and font_family: + for text in preview_legend.get_texts(): + text.set_fontfamily(font_family) if bottom is not None and evaluation.residuals is not None: bottom.axhline(0.0, color="0.5", linewidth=1.0) @@ -1843,16 +2153,122 @@ def plot_evaluation( evaluation.q_values, evaluation.residuals, color="tab:blue", + label=self._line_plot_settings.display_series_label( + "residual", + "Residual", + ), ) bottom.set_xscale( "log" if self.log_x_checkbox.isChecked() else "linear" ) - bottom.set_xlabel("q (Å⁻¹)") - bottom.set_ylabel("Residual") + bottom.set_xlabel( + self._line_plot_settings.resolve_x_label(defaults) + ) + bottom.set_ylabel( + self._line_plot_settings.resolve_residual_y_label(defaults) + ) else: - top.set_xlabel("q (Å⁻¹)") - self.figure.tight_layout() + top.set_xlabel(self._line_plot_settings.resolve_x_label(defaults)) + + x_axis_label_font_size = self._line_plot_settings.axis_label_font_size + x_tick_label_font_size = self._line_plot_settings.tick_label_font_size + primary_axis_label_font_size = ( + self._line_plot_settings.resolve_primary_axis_label_font_size( + defaults + ) + ) + primary_tick_label_font_size = ( + self._line_plot_settings.resolve_primary_tick_label_font_size( + defaults + ) + ) + secondary_axis_label_font_size = ( + self._line_plot_settings.resolve_secondary_axis_label_font_size( + defaults + ) + ) + secondary_tick_label_font_size = ( + self._line_plot_settings.resolve_secondary_tick_label_font_size( + defaults + ) + ) + + 
for axis in figure.axes: + axis.xaxis.label.set_fontsize(x_axis_label_font_size) + axis.tick_params( + axis="x", + which="both", + labelsize=x_tick_label_font_size, + ) + if font_family: + axis.xaxis.label.set_fontfamily(font_family) + for label in list(axis.get_xticklabels()) + list( + axis.get_xticklabels(minor=True) + ): + if font_family: + label.set_fontfamily(font_family) + + for axis in (top, bottom): + if axis is None: + continue + axis.yaxis.label.set_fontsize(primary_axis_label_font_size) + axis.tick_params( + axis="y", + which="both", + labelsize=primary_tick_label_font_size, + ) + axis.yaxis.get_offset_text().set_fontsize( + primary_tick_label_font_size + ) + if font_family: + axis.yaxis.label.set_fontfamily(font_family) + axis.yaxis.get_offset_text().set_fontfamily(font_family) + for label in list(axis.get_yticklabels()) + list( + axis.get_yticklabels(minor=True) + ): + if font_family: + label.set_fontfamily(font_family) + + if structure_axis is not None: + structure_axis.yaxis.label.set_fontsize( + secondary_axis_label_font_size + ) + structure_axis.tick_params( + axis="y", + which="both", + labelsize=secondary_tick_label_font_size, + ) + structure_axis.yaxis.get_offset_text().set_fontsize( + secondary_tick_label_font_size + ) + if font_family: + structure_axis.yaxis.label.set_fontfamily(font_family) + structure_axis.yaxis.get_offset_text().set_fontfamily( + font_family + ) + for label in list(structure_axis.get_yticklabels()) + list( + structure_axis.get_yticklabels(minor=True) + ): + if font_family: + label.set_fontfamily(font_family) + + if interactive: + self._refresh_plot_editor_controls() + figure.tight_layout() + + def plot_evaluation( + self, + evaluation: PrefitEvaluation | None, + ) -> None: + self._current_evaluation = evaluation + self._render_evaluation_figure( + self.figure, + evaluation, + interactive=True, + ) self.canvas.draw() + if self._plot_editor_window is not None: + self._plot_editor_window.refresh_preview() def 
current_evaluation(self) -> PrefitEvaluation | None: return self._current_evaluation @@ -2657,10 +3073,21 @@ def reset_cluster_geometry_progress(self) -> None: self.cluster_geometry_progress_bar.setValue(0) self.cluster_geometry_progress_bar.setFormat("%v / %m files") - def _build_interactive_legend(self, axis, lines: list[object]) -> None: - legend = axis.legend(handles=lines, loc="best") + def _build_interactive_legend( + self, + axis, + lines: list[object], + *, + location: str = "best", + font_size: float = 9.0, + font_family: str = "", + ) -> None: + legend = axis.legend(handles=lines, loc=location, fontsize=font_size) if legend is None: return + if font_family: + for text in legend.get_texts(): + text.set_fontfamily(font_family) legend_handles = getattr(legend, "legend_handles", None) if legend_handles is None: legend_handles = getattr(legend, "legendHandles", []) diff --git a/src/saxshell/saxs/ui/project_setup_tab.py b/src/saxshell/saxs/ui/project_setup_tab.py index 960e850..5346c64 100644 --- a/src/saxshell/saxs/ui/project_setup_tab.py +++ b/src/saxshell/saxs/ui/project_setup_tab.py @@ -45,6 +45,20 @@ QWidget, ) +from saxshell.plotting import ( + Q_A_INVERSE_LABEL, + LinePlotDefaults, + LinePlotEditorControls, + LinePlotSeriesDefaults, + LinePlotSettings, + PlotEditorWindow, + StackedHistogramPlotDefaults, + StackedHistogramPlotEditorControls, + StackedHistogramPlotSettings, +) +from saxshell.plotting.stacked_histogram import ( + render_stacked_histogram_export_payload, +) from saxshell.saxs._model_templates import TemplateSpec from saxshell.saxs.contrast.settings import ( COMPONENT_BUILD_MODE_NO_CONTRAST, @@ -64,7 +78,9 @@ build_prior_histogram_export_payload, build_project_paths, load_experimental_data_file, - plot_md_prior_histogram, + prior_histogram_default_legend_title, + prior_histogram_default_title, + prior_histogram_default_y_label, ) from saxshell.saxs.stoichiometry import parse_stoich_label from saxshell.saxs.ui._pane_snap import PaneSnapFilter 
@@ -191,6 +207,7 @@ class ProjectSetupTab(QWidget): open_xyz2pdb_requested = Signal() open_cluster_requested = Signal() open_clusterdynamicsml_requested = Signal() + open_representative_finder_requested = Signal() open_debye_waller_requested = Signal() scan_clusters_requested = Signal() build_components_requested = Signal() @@ -207,6 +224,7 @@ class ProjectSetupTab(QWidget): show_deprecated_templates_changed = Signal(bool) model_only_mode_changed = Signal(bool) predicted_structure_weights_changed = Signal(bool) + representative_structures_changed = Signal(bool) def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) @@ -241,20 +259,42 @@ def __init__(self, parent: QWidget | None = None) -> None: self._preview_update_suspend_depth = 0 self._pending_saxs_preview_redraw = False self._pending_prior_preview_redraw = False + self._pending_component_plot_axes_state: ( + list[dict[str, object]] | None + ) = None + self._pending_prior_plot_axes_state: list[dict[str, object]] | None = ( + None + ) self._active_template_name: str | None = None self._project_selected = False self._debye_waller_ready = False self._debye_waller_status_note: str | None = None + self._representative_structures_available = False + self._representative_structure_count = 0 + self._representative_structure_modes: tuple[str, ...] = () self._predicted_structures_available = False self._predicted_structure_count = 0 self._distribution_tooltips: dict[str, str] = {} self._distribution_details: dict[str, str] = {} + self._active_distribution_text = ( + "No active computed distribution loaded." + ) self._current_distribution_details_text = ( "Create or load a computed distribution to review its saved " "build attributes here." 
) self._suspend_template_selection_signal = False + self._component_plot_settings = LinePlotSettings() + self._component_plot_editor_window: PlotEditorWindow | None = None + self._component_plot_editor_controls: LinePlotEditorControls | None = ( + None + ) self._prior_x_axis_custom_order: list[tuple[str, str]] = [] + self._prior_plot_settings = StackedHistogramPlotSettings() + self._prior_plot_editor_window: PlotEditorWindow | None = None + self._prior_plot_editor_controls: ( + StackedHistogramPlotEditorControls | None + ) = None self._build_ui() self._update_data_trace_control_state() self._update_component_trace_control_state() @@ -409,6 +449,17 @@ def _build_inputs_group(self) -> QGroupBox: self.use_predicted_structure_weights_checkbox.toggled.connect( self._on_predicted_structure_weights_toggled ) + layout.addRow("", self._representative_structure_controls_row()) + self.representative_structure_status_label = QLabel( + self._default_representative_structure_status_text() + ) + self.representative_structure_status_label.setWordWrap(True) + self.representative_structure_status_label.setMinimumHeight(58) + self.representative_structure_status_label.setSizePolicy( + QSizePolicy.Policy.Expanding, + QSizePolicy.Policy.MinimumExpanding, + ) + layout.addRow("", self.representative_structure_status_label) layout.addRow("", self._predicted_structure_controls_row()) self.predicted_structure_status_label = QLabel( @@ -481,6 +532,39 @@ def _predicted_structure_controls_row(self) -> QWidget: layout.addWidget(self.predict_structures_button) return row + def _representative_structure_controls_row(self) -> QWidget: + row = QWidget() + layout = QHBoxLayout(row) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(6) + self.use_representative_structures_checkbox = QCheckBox( + "Use Representative Structures" + ) + self.use_representative_structures_checkbox.setChecked(False) + self.use_representative_structures_checkbox.toggled.connect( + 
self._on_representative_structures_toggled + ) + layout.addWidget( + self.use_representative_structures_checkbox, + stretch=1, + ) + self.representative_structure_ready_indicator = QLabel() + self.representative_structure_ready_indicator.setFixedSize(14, 14) + layout.addWidget(self.representative_structure_ready_indicator) + self.open_representative_finder_button = QPushButton( + "Representative Structures" + ) + self.open_representative_finder_button.setToolTip( + "Open the representative-structures workflow for this project." + ) + self.open_representative_finder_button.clicked.connect( + self.open_representative_finder_requested.emit + ) + layout.addWidget(self.open_representative_finder_button) + self._refresh_representative_structure_indicator() + self._refresh_representative_structure_controls() + return row + def _build_model_group(self) -> QGroupBox: group = QGroupBox("Model and Build") layout = QVBoxLayout(group) @@ -542,7 +626,9 @@ def _build_model_group(self) -> QGroupBox: self.build_prior_weights_button.clicked.connect( self.build_prior_weights_requested.emit ) - self.debye_waller_button = QPushButton("Compute Debye-Waller Factors") + self.debye_waller_button = QPushButton( + "Compute Debye-Waller Factors (beta)" + ) self.debye_waller_button.clicked.connect( self.open_debye_waller_requested.emit ) @@ -555,7 +641,8 @@ def _build_model_group(self) -> QGroupBox: self.component_build_mode_combo.setToolTip( "Choose how this computed distribution will prepare SAXS " "components. Debye modes use the direct component builders, " - "while Born Approximation launches the electron-density mapping workflow." + "1D Born Approximation launches the legacy averaged-density workflow, " + "and 3D FFT Born Approximation launches the separate Cartesian FFT workflow." 
) for mode, label in component_build_mode_choices(): self.component_build_mode_combo.addItem(label, userData=mode) @@ -590,6 +677,13 @@ def _build_model_group(self) -> QGroupBox: self.computed_distribution_combo.currentIndexChanged.connect( self._update_distribution_details_panel ) + self.active_distribution_edit = QLineEdit() + self.active_distribution_edit.setReadOnly(True) + self.active_distribution_edit.setMinimumWidth(420) + self.active_distribution_edit.setText(self._active_distribution_text) + self.active_distribution_edit.setToolTip( + self._active_distribution_text + ) self.load_distribution_button = QPushButton("Load Distribution") self.load_distribution_button.setEnabled(False) self.load_distribution_button.clicked.connect( @@ -852,6 +946,14 @@ def _distribution_row(self) -> QWidget: top_layout.addWidget(self.load_distribution_button) layout.addWidget(top_row) + active_row = QWidget() + active_layout = QHBoxLayout(active_row) + active_layout.setContentsMargins(0, 0, 0, 0) + active_layout.setSpacing(6) + active_layout.addWidget(QLabel("Active"), stretch=0) + active_layout.addWidget(self.active_distribution_edit, stretch=1) + layout.addWidget(active_row) + install_row = QWidget() install_layout = QHBoxLayout(install_row) install_layout.setContentsMargins(0, 0, 0, 0) @@ -945,6 +1047,12 @@ def _build_component_group(self) -> QGroupBox: self.save_component_plot_data_button.clicked.connect( self.save_component_plot_data_requested.emit ) + self.open_component_plot_editor_button = QPushButton( + "Open Plot Editor" + ) + self.open_component_plot_editor_button.clicked.connect( + self.open_component_plot_editor + ) controls.addWidget(self.component_log_x_checkbox) controls.addWidget(self.component_log_y_checkbox) controls.addWidget(self.component_legend_toggle_button) @@ -954,6 +1062,7 @@ def _build_component_group(self) -> QGroupBox: controls.addWidget(self.component_predicted_traces_button) controls.addWidget(QLabel("Trace Colors")) 
controls.addWidget(self.component_trace_color_scheme_combo) + controls.addWidget(self.open_component_plot_editor_button) controls.addWidget(self.save_component_plot_data_button) controls.addStretch(1) layout.addLayout(controls) @@ -1004,7 +1113,7 @@ def _build_prior_group(self) -> QGroupBox: userData="solvent_sort_atom_fraction", ) self.prior_mode_combo.currentTextChanged.connect( - self._update_prior_control_state + self._on_prior_mode_changed ) self.secondary_filter_label = QLabel("Secondary atom") self.secondary_filter_combo = QComboBox() @@ -1045,11 +1154,16 @@ def _build_prior_group(self) -> QGroupBox: self.edit_prior_x_axis_button.clicked.connect( self._on_edit_prior_x_axis_order ) + self.open_prior_plot_editor_button = QPushButton("Open Plot Editor") + self.open_prior_plot_editor_button.clicked.connect( + self.open_prior_plot_editor + ) controls.addWidget(QLabel("Mode")) controls.addWidget(self.prior_mode_combo) controls.addWidget(QLabel("X-Axis Ordering")) controls.addWidget(self.prior_x_axis_order_combo) controls.addWidget(self.edit_prior_x_axis_button) + controls.addWidget(self.open_prior_plot_editor_button) controls.addWidget(self.secondary_filter_label) controls.addWidget(self.secondary_filter_combo) controls.addWidget(self.generate_prior_plot_button) @@ -1214,6 +1328,8 @@ def set_project_selected(self, selected: bool) -> None: selected and self.prior_x_axis_order_combo.currentData() == "custom" ) + self.open_component_plot_editor_button.setEnabled(selected) + self.open_prior_plot_editor_button.setEnabled(selected) self.generate_prior_plot_button.setEnabled(selected) self.save_prior_png_button.setEnabled(selected) self.save_component_plot_data_button.setEnabled(selected) @@ -1223,7 +1339,9 @@ def set_project_selected(self, selected: bool) -> None: if not selected: self._experimental_summary = None self._solvent_summary = None + self.set_representative_structure_availability(False) self.set_predicted_structure_availability(False) + 
self.set_active_distribution(None) self._refresh_data_status_labels() self._apply_model_only_mode_state() self._refresh_debye_waller_controls() @@ -1275,6 +1393,9 @@ def set_project_settings( self.set_use_predicted_structure_weights( settings.use_predicted_structure_weights ) + self.set_use_representative_structures( + settings.use_representative_structures + ) displayed_data_path = settings.experimental_data_path or ( settings.copied_experimental_data_file or "" ) @@ -1345,6 +1466,8 @@ def set_project_settings( self.set_prior_histogram_x_axis_order( settings.prior_histogram_x_axis_order ) + self.set_component_plot_state(settings.component_plot_state) + self.set_prior_plot_state(settings.prior_plot_state) if settings.cluster_inventory_rows: self._recognized_cluster_rows = list( @@ -1603,6 +1726,25 @@ def set_current_distribution_details(self, text: str | None) -> None: ) self._update_distribution_details_panel() + def set_active_distribution( + self, + text: str | None, + *, + tooltip: str | None = None, + ) -> None: + active_text = ( + str(text or "").strip() + or "No active computed distribution loaded." 
+ ) + self._active_distribution_text = active_text + self.active_distribution_edit.setText(active_text) + self.active_distribution_edit.setToolTip( + str(tooltip or active_text).strip() or active_text + ) + + def active_distribution_text(self) -> str: + return self.active_distribution_edit.text().strip() + def set_active_contrast_distribution_view_available( self, available: bool, @@ -1713,6 +1855,95 @@ def set_use_predicted_structure_weights(self, enabled: bool) -> None: self._refresh_predicted_structure_controls() self._update_component_trace_control_state() + def use_representative_structures(self) -> bool: + return bool(self.use_representative_structures_checkbox.isChecked()) + + def set_use_representative_structures(self, enabled: bool) -> None: + self.use_representative_structures_checkbox.blockSignals(True) + self.use_representative_structures_checkbox.setChecked(bool(enabled)) + self.use_representative_structures_checkbox.blockSignals(False) + self._refresh_representative_structure_controls() + + def set_representative_structure_availability( + self, + available: bool, + *, + representative_count: int = 0, + available_modes: tuple[str, ...] | list[str] = (), + ) -> None: + self._representative_structures_available = bool(available) + self._representative_structure_count = max( + int(representative_count), + 0, + ) + self._representative_structure_modes = tuple(available_modes) + if not self._representative_structures_available: + self.set_use_representative_structures(False) + self._refresh_representative_structure_indicator() + self._refresh_representative_structure_controls() + + def _default_representative_structure_status_text(self) -> str: + return ( + "Representative Structures mode is off.\n" + "Run Representative Structures to compute one saved project " + "representative per stoichiometry before enabling this option." 
+ ) + + def set_representative_structure_status_text(self, text: str) -> None: + normalized = str(text).strip() + self.representative_structure_status_label.setText( + normalized or self._default_representative_structure_status_text() + ) + + def _refresh_representative_structure_indicator(self) -> None: + if self._representative_structures_available: + self.representative_structure_ready_indicator.setStyleSheet( + "background-color: #16a34a; border-radius: 7px;" + ) + mode_text = ( + ", ".join(self._representative_structure_modes) + if self._representative_structure_modes + else "partialsolv" + ) + self.representative_structure_ready_indicator.setToolTip( + f"{self._representative_structure_count} representative structure" + f"{'' if self._representative_structure_count == 1 else 's'} " + "are available in this project.\n" + f"Saved sets: {mode_text}" + ) + else: + self.representative_structure_ready_indicator.setStyleSheet( + "background-color: #6b7280; border-radius: 7px;" + ) + self.representative_structure_ready_indicator.setToolTip( + "Representative structures have not been saved for this project yet." + ) + + def _refresh_representative_structure_controls(self) -> None: + checkbox_enabled = bool( + self._project_selected + and ( + self._representative_structures_available + or self.use_representative_structures() + ) + ) + self.use_representative_structures_checkbox.setEnabled( + checkbox_enabled + ) + if self._representative_structures_available: + self.use_representative_structures_checkbox.setToolTip( + "Prefer the saved project representative structures in compatible " + "SAXS and RMCSetup workflows. When disabled, the project falls " + "back to average cluster folders." + ) + else: + self.use_representative_structures_checkbox.setToolTip( + "Run Representative Structures for this project before enabling this option." 
+ ) + self.open_representative_finder_button.setEnabled( + self._project_selected + ) + def set_predicted_structure_availability( self, available: bool, @@ -1999,6 +2230,7 @@ def set_prior_histogram_x_axis_order( self.prior_x_axis_order_combo.setCurrentIndex(idx) self.prior_x_axis_order_combo.blockSignals(False) self.edit_prior_x_axis_button.setEnabled(False) + self._refresh_prior_plot_editor_controls(force=True) def _active_prior_x_axis_order(self) -> list[tuple[str, str]] | None: if ( @@ -2011,6 +2243,7 @@ def _active_prior_x_axis_order(self) -> list[tuple[str, str]] | None: def _on_prior_x_axis_order_changed(self, _index: int) -> None: is_custom = self.prior_x_axis_order_combo.currentData() == "custom" self.edit_prior_x_axis_button.setEnabled(is_custom) + self._refresh_prior_plot_editor_controls(force=True) self._redraw_prior_preview_if_needed() def _on_edit_prior_x_axis_order(self) -> None: @@ -2041,9 +2274,572 @@ def _on_edit_prior_x_axis_order(self) -> None: self.autosave_project_requested.emit( "updated prior histogram x-axis order" ) + self._refresh_prior_plot_editor_controls(force=True) if self.prior_x_axis_order_combo.currentData() == "custom": self._redraw_prior_preview_if_needed() + def open_component_plot_editor(self) -> None: + if self._component_plot_editor_window is not None: + self._component_plot_editor_window.show() + self._component_plot_editor_window.raise_() + self._component_plot_editor_window.activateWindow() + self._component_plot_editor_window.refresh_preview() + return + + defaults = self._current_component_plot_defaults() + self._component_plot_settings.sync_series(defaults.series_defaults) + self._component_plot_editor_controls = LinePlotEditorControls( + settings=self._component_plot_settings, + defaults=defaults, + parent=self, + ) + self._component_plot_editor_controls.label_settings_changed.connect( + self._redraw_saxs_preview + ) + self._component_plot_editor_controls.settings_changed.connect( + self._redraw_saxs_preview + ) + 
self._component_plot_editor_window = PlotEditorWindow( + window_title="SAXS Component Plot Editor", + controls_widget=self._component_plot_editor_controls, + render_preview=self._render_component_plot_editor_preview, + pickle_state_provider=self._component_plot_editor_pickle_state, + apply_loaded_pickle_state=self._apply_loaded_component_plot_editor_pickle_state, + parent=self, + ) + self._component_plot_editor_window.closed.connect( + self._on_component_plot_editor_closed + ) + self._component_plot_editor_window.refresh_preview() + self._component_plot_editor_window.show() + self._component_plot_editor_window.raise_() + self._component_plot_editor_window.activateWindow() + + def _on_component_plot_editor_closed(self) -> None: + self._component_plot_editor_window = None + self._component_plot_editor_controls = None + + def _current_component_plot_defaults(self) -> LinePlotDefaults: + show_data_preview = not self.model_only_mode() + has_data_preview = show_data_preview and ( + self._experimental_summary is not None + or self._solvent_summary is not None + ) + has_components = bool(self._component_paths) + has_secondary_axis = bool(has_data_preview and has_components) + series_defaults: list[LinePlotSeriesDefaults] = [] + + if ( + self._experimental_summary is not None + and show_data_preview + and self.experimental_trace_visible() + ): + series_defaults.append( + LinePlotSeriesDefaults( + key="experimental_data", + label="Experimental data", + axis_label="Data" if has_secondary_axis else "Main", + ) + ) + q_values = np.asarray( + self._experimental_summary.q_values, dtype=float + ) + selected_mask = self._selected_q_mask(q_values) + if ( + selected_mask is not None + and np.any(selected_mask) + and not np.all(selected_mask) + ): + series_defaults.append( + LinePlotSeriesDefaults( + key="selected_q_range", + label="Selected q-range", + axis_label="Data" if has_secondary_axis else "Main", + ) + ) + + if ( + self._solvent_summary is not None + and show_data_preview + 
and self.solvent_trace_visible() + ): + series_defaults.append( + LinePlotSeriesDefaults( + key="solvent_data", + label="Solvent data", + axis_label="Data" if has_secondary_axis else "Main", + ) + ) + solvent_q = np.asarray(self._solvent_summary.q_values, dtype=float) + solvent_mask = self._selected_q_mask(solvent_q) + if ( + solvent_mask is not None + and np.any(solvent_mask) + and not np.all(solvent_mask) + ): + series_defaults.append( + LinePlotSeriesDefaults( + key="selected_solvent_q_range", + label="Selected solvent q-range", + axis_label="Data" if has_secondary_axis else "Main", + ) + ) + + for component_path in self._component_paths or []: + series_defaults.append( + LinePlotSeriesDefaults( + key=f"component::{component_path.stem}", + label=component_path.stem, + axis_label="Model" if has_secondary_axis else "Main", + ) + ) + + if has_components: + if self._experimental_summary is not None and has_data_preview: + title = "Experimental Data and SAXS Components" + primary_y_label = "Experimental Intensity (arb. units)" + secondary_y_label = "Model Intensity (arb. units)" + elif has_data_preview: + title = "Data and SAXS Components" + primary_y_label = "Intensity (arb. units)" + secondary_y_label = "Model Intensity (arb. units)" + else: + title = "SAXS Component Preview" + primary_y_label = "Model Intensity (arb. units)" + secondary_y_label = "" + elif has_data_preview: + title = ( + "Experimental Data Preview" + if self._experimental_summary is not None + else "Data Preview" + ) + primary_y_label = "Intensity (arb. units)" + secondary_y_label = "" + else: + title = "" + primary_y_label = "Intensity (arb. 
units)" + secondary_y_label = "" + + return LinePlotDefaults( + title=title, + x_label=Q_A_INVERSE_LABEL, + primary_y_label=primary_y_label, + secondary_y_label=secondary_y_label, + has_secondary_y_axis=has_secondary_axis, + default_legend_location="upper right", + series_defaults=tuple(series_defaults), + ) + + def _refresh_component_plot_editor_controls( + self, *, force: bool = False + ) -> None: + if self._component_plot_editor_controls is None: + return + defaults = self._current_component_plot_defaults() + self._component_plot_settings.sync_series(defaults.series_defaults) + if force or self._component_plot_editor_controls.needs_default_sync( + defaults + ): + self._component_plot_editor_controls.sync_defaults(defaults) + + def _component_plot_editor_pickle_state(self) -> dict[str, object]: + return { + "plot_editor_state": { + "kind": "line_plot_editor_state", + "version": 1, + "plot_context": "saxs_component_preview", + "line_plot_settings": self._component_plot_settings.to_dict(), + "panel_state": { + "log_x": bool(self.component_log_x_checkbox.isChecked()), + "log_y": bool(self.component_log_y_checkbox.isChecked()), + "show_legend": bool( + self.component_legend_toggle_button.isChecked() + ), + "autoscale_to_model_range": bool( + self.component_model_range_button.isChecked() + ), + }, + } + } + + def component_plot_state(self) -> dict[str, object]: + payload = self._component_plot_editor_pickle_state() + payload["axes"] = self._capture_figure_axes_state( + self.component_figure + ) + return payload + + def _apply_loaded_component_plot_editor_pickle_state( + self, + payload: dict[str, object], + ) -> bool: + editor_state = payload.get("plot_editor_state") + if not isinstance(editor_state, dict): + return False + if str(editor_state.get("kind")) != "line_plot_editor_state": + return False + if str(editor_state.get("plot_context")) != "saxs_component_preview": + return False + + plot_settings = editor_state.get("line_plot_settings") + if 
isinstance(plot_settings, dict): + self._component_plot_settings.update_from_dict(plot_settings) + panel_state = editor_state.get("panel_state") + if isinstance(panel_state, dict): + self.component_log_x_checkbox.setChecked( + bool(panel_state.get("log_x", True)) + ) + self.component_log_y_checkbox.setChecked( + bool(panel_state.get("log_y", True)) + ) + self.component_legend_toggle_button.setChecked( + bool(panel_state.get("show_legend", True)) + ) + self.component_model_range_button.setChecked( + bool(panel_state.get("autoscale_to_model_range", False)) + ) + + self._refresh_component_plot_editor_controls(force=True) + self._redraw_saxs_preview() + return True + + def set_component_plot_state( + self, + payload: dict[str, object] | None, + ) -> None: + normalized = dict(payload or {}) + self._pending_component_plot_axes_state = ( + self._normalized_figure_axes_state(normalized.get("axes")) + ) + applied = False + if normalized: + applied = self._apply_loaded_component_plot_editor_pickle_state( + normalized + ) + if ( + not applied + and self._pending_component_plot_axes_state is not None + and not self._preview_updates_suspended() + ): + self._redraw_saxs_preview() + + def open_prior_plot_editor(self) -> None: + if self._prior_plot_editor_window is not None: + self._prior_plot_editor_window.show() + self._prior_plot_editor_window.raise_() + self._prior_plot_editor_window.activateWindow() + self._prior_plot_editor_window.refresh_preview() + return + + defaults = self._current_prior_plot_defaults() + self._apply_prior_plot_label_state(defaults) + self._prior_plot_editor_controls = StackedHistogramPlotEditorControls( + settings=self._prior_plot_settings, + defaults=defaults, + parent=self, + ) + self._prior_plot_editor_controls.label_settings_changed.connect( + self._on_prior_plot_editor_label_settings_changed + ) + self._prior_plot_editor_controls.settings_changed.connect( + self._redraw_prior_preview_if_needed + ) + 
self._prior_plot_editor_controls.colormap_changed.connect( + self._on_prior_plot_editor_colormap_changed + ) + self._prior_plot_editor_window = PlotEditorWindow( + window_title="Prior Histogram Plot Editor", + controls_widget=self._prior_plot_editor_controls, + render_preview=self._render_prior_plot_figure, + pickle_state_provider=self._prior_plot_editor_pickle_state, + apply_loaded_pickle_state=self._apply_loaded_prior_plot_editor_pickle_state, + parent=self, + ) + self._prior_plot_editor_window.closed.connect( + self._on_prior_plot_editor_closed + ) + self._prior_plot_editor_window.refresh_preview() + self._prior_plot_editor_window.show() + self._prior_plot_editor_window.raise_() + self._prior_plot_editor_window.activateWindow() + + def _on_prior_plot_editor_closed(self) -> None: + self._prior_plot_editor_window = None + self._prior_plot_editor_controls = None + + def _on_prior_plot_editor_colormap_changed( + self, colormap_name: str + ) -> None: + index = self.prior_color_combo.findText(colormap_name) + if index < 0 or index == self.prior_color_combo.currentIndex(): + return + self.prior_color_combo.setCurrentIndex(index) + + def _on_prior_plot_editor_label_settings_changed(self) -> None: + defaults = self._current_prior_plot_defaults() + entries = self._prior_plot_settings.ordered_label_entries(defaults) + default_entries = list(defaults.default_label_entries) + if entries == default_entries: + self._prior_x_axis_custom_order = [] + combo_index = self.prior_x_axis_order_combo.findData("auto") + self.prior_x_axis_order_combo.blockSignals(True) + self.prior_x_axis_order_combo.setCurrentIndex(max(combo_index, 0)) + self.prior_x_axis_order_combo.blockSignals(False) + self.edit_prior_x_axis_button.setEnabled(False) + else: + self._prior_x_axis_custom_order = list(entries) + combo_index = self.prior_x_axis_order_combo.findData("custom") + self.prior_x_axis_order_combo.blockSignals(True) + self.prior_x_axis_order_combo.setCurrentIndex(max(combo_index, 0)) + 
self.prior_x_axis_order_combo.blockSignals(False) + self.edit_prior_x_axis_button.setEnabled(True) + self.autosave_project_requested.emit( + "updated prior histogram x-axis order" + ) + + def _current_prior_plot_defaults(self) -> StackedHistogramPlotDefaults: + default_label_entries: tuple[tuple[str, str], ...] = () + raw_labels: tuple[str, ...] = () + if self._current_prior_json_path is not None: + try: + payload = build_prior_histogram_export_payload( + self._current_prior_json_path, + mode=self.prior_mode(), + secondary_element=self.prior_secondary_element(), + ) + except Exception: + payload = None + if payload is not None: + raw_labels = tuple(str(label) for label in payload["labels"]) + default_label_entries = tuple( + ( + str(raw_label), + str(display_label), + ) + for raw_label, display_label in zip( + payload["labels"], + payload["axis_labels"], + strict=False, + ) + ) + + return StackedHistogramPlotDefaults( + title=prior_histogram_default_title( + self.prior_mode(), + self.prior_secondary_element(), + ), + x_label="Structure", + y_label=prior_histogram_default_y_label(self.prior_mode()), + legend_title=prior_histogram_default_legend_title( + self.prior_mode(), + self.prior_secondary_element(), + ), + default_colormap_name=self.prior_cmap(), + available_colormap_names=tuple(HISTOGRAM_COLORMAP_NAMES), + raw_category_labels=raw_labels, + default_label_entries=default_label_entries, + ) + + def _apply_prior_plot_label_state( + self, + defaults: StackedHistogramPlotDefaults, + ) -> None: + self._prior_plot_settings.sync_labels( + defaults.raw_category_labels, + default_label_entries=defaults.default_label_entries, + ) + active_entries = self._active_prior_x_axis_order() + if active_entries is None: + self._prior_plot_settings.label_order = list( + defaults.raw_category_labels + ) + default_map = dict(defaults.default_label_entries) + self._prior_plot_settings.label_map = { + raw_label: default_map.get(raw_label, raw_label) + for raw_label in 
defaults.raw_category_labels + } + return + + ordered = [ + raw_label + for raw_label, _display_label in active_entries + if raw_label in defaults.raw_category_labels + ] + ordered_set = set(ordered) + ordered.extend( + raw_label + for raw_label in defaults.raw_category_labels + if raw_label not in ordered_set + ) + default_map = dict(defaults.default_label_entries) + custom_map = { + str(raw_label): str(display_label) + for raw_label, display_label in active_entries + } + self._prior_plot_settings.label_order = list(ordered) + self._prior_plot_settings.label_map = { + raw_label: custom_map.get( + raw_label, + default_map.get(raw_label, raw_label), + ) + for raw_label in ordered + } + + def _refresh_prior_plot_editor_controls( + self, *, force: bool = False + ) -> None: + if self._prior_plot_editor_controls is None: + return + defaults = self._current_prior_plot_defaults() + self._apply_prior_plot_label_state(defaults) + if force or self._prior_plot_editor_controls.needs_default_sync( + defaults + ): + self._prior_plot_editor_controls.sync_defaults(defaults) + + def _prior_plot_editor_pickle_state(self) -> dict[str, object]: + return { + "plot_editor_state": { + "kind": "stacked_histogram_plot_editor_state", + "version": 1, + "stacked_histogram_settings": self._prior_plot_settings.to_dict(), + "panel_state": { + "mode": self.prior_mode(), + "secondary_element": self.selected_prior_secondary_element(), + "colormap_name": self.prior_cmap(), + "match_trace_colors": self.prior_match_trace_colors(), + "x_axis_order_mode": str( + self.prior_x_axis_order_combo.currentData() or "auto" + ), + }, + } + } + + def prior_plot_state(self) -> dict[str, object]: + payload = self._prior_plot_editor_pickle_state() + payload["axes"] = self._capture_figure_axes_state(self.prior_figure) + return payload + + def _apply_loaded_prior_plot_editor_pickle_state( + self, + payload: dict[str, object], + ) -> bool: + editor_state = payload.get("plot_editor_state") + if not 
isinstance(editor_state, dict): + return False + if ( + str(editor_state.get("kind")) + != "stacked_histogram_plot_editor_state" + ): + return False + + histogram_settings = editor_state.get("stacked_histogram_settings") + if isinstance(histogram_settings, dict): + self._prior_plot_settings.update_from_dict(histogram_settings) + + panel_state = editor_state.get("panel_state") + if isinstance(panel_state, dict): + self._apply_prior_plot_panel_state_from_pickle(panel_state) + + defaults = self._current_prior_plot_defaults() + requested_mode = ( + str(panel_state.get("x_axis_order_mode") or "").strip() + if isinstance(panel_state, dict) + else "" + ) + loaded_entries = self._prior_plot_settings.ordered_label_entries( + defaults + ) + if requested_mode == "custom" or loaded_entries != list( + defaults.default_label_entries + ): + self._prior_x_axis_custom_order = list(loaded_entries) + combo_index = self.prior_x_axis_order_combo.findData("custom") + self.prior_x_axis_order_combo.blockSignals(True) + self.prior_x_axis_order_combo.setCurrentIndex(max(combo_index, 0)) + self.prior_x_axis_order_combo.blockSignals(False) + self.edit_prior_x_axis_button.setEnabled(True) + else: + self._prior_x_axis_custom_order = [] + combo_index = self.prior_x_axis_order_combo.findData("auto") + self.prior_x_axis_order_combo.blockSignals(True) + self.prior_x_axis_order_combo.setCurrentIndex(max(combo_index, 0)) + self.prior_x_axis_order_combo.blockSignals(False) + self.edit_prior_x_axis_button.setEnabled(False) + + self._refresh_prior_plot_editor_controls(force=True) + self.draw_prior_plot(self._current_prior_json_path) + return True + + def set_prior_plot_state( + self, + payload: dict[str, object] | None, + ) -> None: + normalized = dict(payload or {}) + self._pending_prior_plot_axes_state = ( + self._normalized_figure_axes_state(normalized.get("axes")) + ) + applied = False + if normalized: + applied = self._apply_loaded_prior_plot_editor_pickle_state( + normalized + ) + if ( + not 
applied + and self._pending_prior_plot_axes_state is not None + and not self._preview_updates_suspended() + ): + self.draw_prior_plot(self._current_prior_json_path) + + def _apply_prior_plot_panel_state_from_pickle( + self, + panel_state: dict[str, object], + ) -> None: + self.prior_mode_combo.blockSignals(True) + self.secondary_filter_combo.blockSignals(True) + self.prior_color_combo.blockSignals(True) + self.prior_match_trace_colors_checkbox.blockSignals(True) + try: + self._set_combo_data_if_present( + self.prior_mode_combo, + panel_state.get("mode"), + ) + secondary_element = panel_state.get("secondary_element") + if secondary_element is not None: + index = self.secondary_filter_combo.findText( + str(secondary_element) + ) + if index >= 0: + self.secondary_filter_combo.setCurrentIndex(index) + self._set_combo_text_if_present( + self.prior_color_combo, + panel_state.get("colormap_name"), + ) + if "match_trace_colors" in panel_state: + self.prior_match_trace_colors_checkbox.setChecked( + bool(panel_state["match_trace_colors"]) + ) + finally: + self.prior_mode_combo.blockSignals(False) + self.secondary_filter_combo.blockSignals(False) + self.prior_color_combo.blockSignals(False) + self.prior_match_trace_colors_checkbox.blockSignals(False) + self._update_prior_control_state() + + @staticmethod + def _set_combo_data_if_present(combo: QComboBox, value: object) -> None: + index = combo.findData(value) + if index >= 0: + combo.setCurrentIndex(index) + + @staticmethod + def _set_combo_text_if_present(combo: QComboBox, value: object) -> None: + if value is None: + return + index = combo.findText(str(value)) + if index >= 0: + combo.setCurrentIndex(index) + def prior_mode(self) -> str: return str(self.prior_mode_combo.currentData() or "structure_fraction") @@ -2449,23 +3245,31 @@ def reset_activity_progress(self) -> None: self.activity_progress_bar.setValue(0) self.activity_progress_bar.setFormat("%v / %m items") - def _redraw_saxs_preview(self) -> None: - if 
self._preview_updates_suspended(): - self._pending_saxs_preview_redraw = True - return - for axis in list(self.component_figure.axes): + def _render_component_plot_editor_preview(self, figure: Figure) -> None: + self._render_component_plot_figure(figure, interactive=False) + + def _render_component_plot_figure( + self, + figure: Figure, + *, + interactive: bool, + ) -> None: + if interactive: + self._legend_line_map.clear() + self._component_legend_lookup.clear() + self._component_line_lookup.clear() + self._component_color_lookup.clear() + self._observed_component_keys = [] + self._predicted_component_keys = [] + for axis in list(figure.axes): try: axis.set_xscale("linear") axis.set_yscale("linear") except Exception: continue - self.component_figure.clear() - self._legend_line_map.clear() - self._component_legend_lookup.clear() - self._component_line_lookup.clear() - self._component_color_lookup.clear() - self._observed_component_keys = [] - self._predicted_component_keys = [] + figure.clear() + defaults = self._current_component_plot_defaults() + self._component_plot_settings.sync_series(defaults.series_defaults) show_data_preview = not self.model_only_mode() has_data_preview = show_data_preview and ( self._experimental_summary is not None @@ -2473,7 +3277,7 @@ def _redraw_saxs_preview(self) -> None: ) has_components = bool(self._component_paths) if not has_data_preview and not has_components: - axis = self.component_figure.add_subplot(111) + axis = figure.add_subplot(111) axis.text( 0.5, 0.5, @@ -2494,13 +3298,14 @@ def _redraw_saxs_preview(self) -> None: wrap=True, ) axis.set_axis_off() - self._update_component_table_visuals() - self._update_component_trace_control_state() - self.component_figure.tight_layout() - self.component_canvas.draw() + if interactive: + self._refresh_component_plot_editor_controls() + self._update_component_table_visuals() + self._update_component_trace_control_state() + figure.tight_layout() return - base_axis = 
self.component_figure.add_subplot(111) + base_axis = figure.add_subplot(111) experimental_axis = base_axis if has_data_preview else None component_axis = ( base_axis if has_components and not has_data_preview else None @@ -2523,42 +3328,97 @@ def _redraw_saxs_preview(self) -> None: self._draw_component_profiles( component_axis, self._component_paths or [], + track_lines=interactive, ) ) if ( self._experimental_summary is not None and experimental_axis is not None + and component_axis is not experimental_axis ): self._normalize_component_axis( experimental_axis, component_axis, ) - experimental_axis.set_ylabel( - "Experimental Intensity (arb. units)" + + x_label = self._component_plot_settings.resolve_x_label(defaults) + title = self._component_plot_settings.resolve_title(defaults) + title_x = self._component_plot_settings.resolve_title_position_x( + defaults + ) + title_y = self._component_plot_settings.resolve_title_position_y( + defaults + ) + title_font_size = self._component_plot_settings.title_font_size + x_axis_label_font_size = ( + self._component_plot_settings.axis_label_font_size + ) + x_tick_label_font_size = ( + self._component_plot_settings.tick_label_font_size + ) + primary_axis_label_font_size = ( + self._component_plot_settings.resolve_primary_axis_label_font_size( + defaults + ) + ) + primary_tick_label_font_size = ( + self._component_plot_settings.resolve_primary_tick_label_font_size( + defaults + ) + ) + secondary_axis_label_font_size = self._component_plot_settings.resolve_secondary_axis_label_font_size( + defaults + ) + secondary_tick_label_font_size = self._component_plot_settings.resolve_secondary_tick_label_font_size( + defaults + ) + font_family = self._component_plot_settings.font_family.strip() + primary_y_label = ( + self._component_plot_settings.resolve_primary_y_label(defaults) + ) + + if experimental_axis is not None: + experimental_axis.set_xlabel(x_label) + experimental_axis.set_ylabel(primary_y_label) + elif component_axis is not 
None: + component_axis.set_xlabel(x_label) + component_axis.set_ylabel(primary_y_label) + + if ( + component_axis is not None + and experimental_axis is not None + and component_axis is not experimental_axis + ): + component_axis.set_ylabel( + self._component_plot_settings.resolve_secondary_y_label( + defaults ) - component_axis.set_ylabel("Model Intensity (arb. units)") - base_axis.set_title("Experimental Data and SAXS Components") - elif has_data_preview and experimental_axis is not None: - experimental_axis.set_ylabel("Intensity (arb. units)") - component_axis.set_ylabel("Model Intensity (arb. units)") - base_axis.set_title("Data and SAXS Components") - else: - component_axis.set_ylabel("Model Intensity (arb. units)") - base_axis.set_title("SAXS Component Preview") - elif experimental_axis is not None: - experimental_axis.set_ylabel("Intensity (arb. units)") - if self._experimental_summary is not None: - base_axis.set_title("Experimental Data Preview") - else: - base_axis.set_title("Data Preview") + ) + + if title: + title_kwargs: dict[str, object] = { + "x": title_x, + "y": title_y, + "fontsize": title_font_size, + } + if font_family: + title_kwargs["fontfamily"] = font_family + base_axis.set_title(title, **title_kwargs) + else: + base_axis.set_title("") if ( component_axis is not None and self.component_model_range_button.isChecked() ): self._autoscale_to_model_range( - experimental_axis, + ( + experimental_axis + if component_axis is not experimental_axis + else None + ), component_axis, + list(component_axis.get_lines()), ) anchor_axis = experimental_axis or component_axis @@ -2566,13 +3426,116 @@ def _redraw_saxs_preview(self) -> None: anchor_axis is not None and plotted_lines and self.component_legend_toggle_button.isChecked() + and self._component_plot_settings.resolve_show_legend(defaults) ): - self._build_interactive_legend(anchor_axis, plotted_lines) + legend_location = ( + self._component_plot_settings.resolve_legend_location(defaults) + ) + 
legend_font_size = self._component_plot_settings.legend_font_size + if interactive: + self._build_interactive_legend( + anchor_axis, + plotted_lines, + location=legend_location, + font_size=legend_font_size, + font_family=font_family, + ) + else: + preview_legend = anchor_axis.legend( + plotted_lines, + [line.get_label() for line in plotted_lines], + loc=legend_location, + fontsize=legend_font_size, + framealpha=0.9, + ) + if preview_legend is not None and font_family: + for text in preview_legend.get_texts(): + text.set_fontfamily(font_family) + + for axis in figure.axes: + axis.xaxis.label.set_fontsize(x_axis_label_font_size) + axis.tick_params( + axis="x", + which="both", + labelsize=x_tick_label_font_size, + ) + if font_family: + axis.xaxis.label.set_fontfamily(font_family) + for label in list(axis.get_xticklabels()) + list( + axis.get_xticklabels(minor=True) + ): + if font_family: + label.set_fontfamily(font_family) - self._update_component_table_visuals() - self._update_component_trace_control_state() - self.component_figure.tight_layout() + primary_y_axis = experimental_axis or component_axis + if primary_y_axis is not None: + primary_y_axis.yaxis.label.set_fontsize( + primary_axis_label_font_size + ) + primary_y_axis.tick_params( + axis="y", + which="both", + labelsize=primary_tick_label_font_size, + ) + primary_y_axis.yaxis.get_offset_text().set_fontsize( + primary_tick_label_font_size + ) + if font_family: + primary_y_axis.yaxis.label.set_fontfamily(font_family) + primary_y_axis.yaxis.get_offset_text().set_fontfamily( + font_family + ) + for label in list(primary_y_axis.get_yticklabels()) + list( + primary_y_axis.get_yticklabels(minor=True) + ): + if font_family: + label.set_fontfamily(font_family) + + if ( + component_axis is not None + and experimental_axis is not None + and component_axis is not experimental_axis + ): + component_axis.yaxis.label.set_fontsize( + secondary_axis_label_font_size + ) + component_axis.tick_params( + axis="y", + 
which="both", + labelsize=secondary_tick_label_font_size, + ) + component_axis.yaxis.get_offset_text().set_fontsize( + secondary_tick_label_font_size + ) + if font_family: + component_axis.yaxis.label.set_fontfamily(font_family) + component_axis.yaxis.get_offset_text().set_fontfamily( + font_family + ) + for label in list(component_axis.get_yticklabels()) + list( + component_axis.get_yticklabels(minor=True) + ): + if font_family: + label.set_fontfamily(font_family) + + if interactive: + self._refresh_component_plot_editor_controls() + self._update_component_table_visuals() + self._update_component_trace_control_state() + figure.tight_layout() + + def _redraw_saxs_preview(self) -> None: + if self._preview_updates_suspended(): + self._pending_saxs_preview_redraw = True + return + self._render_component_plot_figure( + self.component_figure, + interactive=True, + ) + self._apply_pending_component_plot_axes_state(self.component_figure) self.component_canvas.draw() + if self._component_plot_editor_window is not None: + self._component_plot_editor_window.refresh_preview() def draw_prior_plot(self, json_path: str | Path | None) -> None: self._current_prior_json_path = ( @@ -2583,9 +3546,19 @@ def draw_prior_plot(self, json_path: str | Path | None) -> None: if self._preview_updates_suspended(): self._pending_prior_preview_redraw = True return - self.prior_figure.clear() - if json_path is None: - axis = self.prior_figure.add_subplot(111) + self._render_prior_plot_figure(self.prior_figure) + self.prior_canvas.draw_idle() + if self._prior_plot_editor_window is not None: + self._prior_plot_editor_window.refresh_preview() + + def _render_prior_plot_figure(self, figure: Figure) -> None: + defaults = self._current_prior_plot_defaults() + self._apply_prior_plot_label_state(defaults) + self._refresh_prior_plot_editor_controls() + + figure.clear() + if self._current_prior_json_path is None: + axis = figure.add_subplot(111) axis.text( 0.5, 0.5, @@ -2594,29 +3567,38 @@ def 
draw_prior_plot(self, json_path: str | Path | None) -> None: va="center", ) axis.set_axis_off() - else: - axis = self.prior_figure.add_subplot(111) - try: - plot_md_prior_histogram( - json_path, - mode=self.prior_mode(), - secondary_element=self.prior_secondary_element(), - cmap=self.prior_cmap(), - structure_motif_colors=self.prior_structure_motif_colors(), - custom_label_order=self._active_prior_x_axis_order(), - ax=axis, - ) - except Exception as exc: - axis.text( - 0.5, - 0.5, - str(exc), - ha="center", - va="center", - wrap=True, - ) - axis.set_axis_off() - self.prior_canvas.draw() + return + + axis = figure.add_subplot(111) + try: + export_payload = build_prior_histogram_export_payload( + self._current_prior_json_path, + mode=self.prior_mode(), + secondary_element=self.prior_secondary_element(), + custom_label_order=self._active_prior_x_axis_order(), + ) + render_stacked_histogram_export_payload( + export_payload, + ax=axis, + defaults=defaults, + settings=self._prior_plot_settings, + cmap=self.prior_cmap(), + structure_segment_colors=self.prior_structure_motif_colors(), + show_percent=True, + ) + except Exception as exc: + axis.text( + 0.5, + 0.5, + str(exc), + ha="center", + va="center", + wrap=True, + ) + axis.set_axis_off() + return + + self._apply_pending_prior_plot_axes_state(figure) def refresh_available_elements(self) -> None: self.request_cluster_scan() @@ -3404,6 +4386,15 @@ def _on_predicted_structure_weights_toggled(self, enabled: bool) -> None: else "disabled Use Predicted Structure Weights" ) + def _on_representative_structures_toggled(self, enabled: bool) -> None: + self._refresh_representative_structure_controls() + self.representative_structures_changed.emit(bool(enabled)) + self.autosave_project_requested.emit( + "enabled Use Representative Structures" + if enabled + else "disabled Use Representative Structures" + ) + def _update_resample_grid_state(self) -> None: self.resample_points_spin.setEnabled( not 
self.use_experimental_grid_checkbox.isChecked() @@ -3464,6 +4455,11 @@ def _update_secondary_filter_options( def _prior_mode_uses_secondary_filter(self) -> bool: return self.prior_mode().startswith("solvent_sort") + def _on_prior_mode_changed(self, _text: str) -> None: + self._update_prior_control_state() + self._refresh_prior_plot_editor_controls(force=True) + self._redraw_prior_preview_if_needed() + def _update_prior_control_state(self) -> None: uses_secondary = self._prior_mode_uses_secondary_filter() has_secondary_options = self.secondary_filter_combo.count() > 0 @@ -3490,8 +4486,6 @@ def _update_prior_control_state(self) -> None: self.secondary_filter_combo.setToolTip("") def _redraw_prior_preview_if_needed(self) -> None: - if self._current_prior_json_path is None: - return if self._preview_updates_suspended(): self._pending_prior_preview_redraw = True return @@ -3501,6 +4495,116 @@ def _on_component_trace_color_scheme_changed(self) -> None: self._redraw_saxs_preview() self._redraw_prior_preview_if_needed() + @staticmethod + def _capture_figure_axes_state( + figure: Figure, + ) -> list[dict[str, object]]: + states: list[dict[str, object]] = [] + for axis in figure.axes: + if not axis.has_data(): + continue + x_limits = axis.get_xlim() + y_limits = axis.get_ylim() + states.append( + { + "xscale": str(axis.get_xscale()), + "yscale": str(axis.get_yscale()), + "xlim": [float(x_limits[0]), float(x_limits[1])], + "ylim": [float(y_limits[0]), float(y_limits[1])], + } + ) + return states + + @staticmethod + def _normalized_figure_axes_state( + raw: object, + ) -> list[dict[str, object]] | None: + if not isinstance(raw, list): + return None + states: list[dict[str, object]] = [] + for entry in raw: + if not isinstance(entry, dict): + continue + state: dict[str, object] = {} + xscale = str(entry.get("xscale", "")).strip() + yscale = str(entry.get("yscale", "")).strip() + if xscale: + state["xscale"] = xscale + if yscale: + state["yscale"] = yscale + for key in ("xlim", 
"ylim"): + limits = entry.get(key) + if not isinstance(limits, (list, tuple)) or len(limits) != 2: + continue + try: + lower = float(limits[0]) + upper = float(limits[1]) + except (TypeError, ValueError): + continue + if ( + np.isfinite(lower) + and np.isfinite(upper) + and lower != upper + ): + state[key] = [lower, upper] + if state: + states.append(state) + return states or None + + def _apply_figure_axes_state( + self, + figure: Figure, + axes_state: list[dict[str, object]] | None, + ) -> bool: + if not axes_state: + return True + data_axes = [axis for axis in figure.axes if axis.has_data()] + if not data_axes: + return False + applied = False + for axis, state in zip(data_axes, axes_state): + try: + xscale = str(state.get("xscale", "")).strip() + if xscale and axis.get_xscale() != xscale: + axis.set_xscale(xscale) + yscale = str(state.get("yscale", "")).strip() + if yscale and axis.get_yscale() != yscale: + axis.set_yscale(yscale) + x_limits = state.get("xlim") + if isinstance(x_limits, list) and len(x_limits) == 2: + axis.set_xlim(float(x_limits[0]), float(x_limits[1])) + y_limits = state.get("ylim") + if isinstance(y_limits, list) and len(y_limits) == 2: + axis.set_ylim(float(y_limits[0]), float(y_limits[1])) + except Exception: + continue + applied = True + return applied + + def _apply_pending_component_plot_axes_state( + self, + figure: Figure, + ) -> None: + if self._pending_component_plot_axes_state is None: + return + if self._apply_figure_axes_state( + figure, + self._pending_component_plot_axes_state, + ): + self._pending_component_plot_axes_state = None + + def _apply_pending_prior_plot_axes_state( + self, + figure: Figure, + ) -> None: + if self._pending_prior_plot_axes_state is None: + return + if self._apply_figure_axes_state( + figure, + self._pending_prior_plot_axes_state, + ): + self._pending_prior_plot_axes_state = None + def _draw_experimental_preview( self, axis, @@ -3511,31 +4615,43 @@ def _draw_experimental_preview( q_values = 
np.asarray(summary.q_values, dtype=float) intensities = np.asarray(summary.intensities, dtype=float) exp_color = self.experimental_trace_color() + experimental_label = ( + self._component_plot_settings.display_series_label( + "experimental_data", + "Experimental data", + ) + ) (full_line,) = axis.plot( q_values, intensities, color=exp_color, alpha=0.35, linewidth=1.3, - label="Experimental data", + label=experimental_label, ) lines.append(full_line) selected_mask = self._selected_q_mask(q_values) if selected_mask is not None and np.any(selected_mask): if not np.all(selected_mask): + selected_label = ( + self._component_plot_settings.display_series_label( + "selected_q_range", + "Selected q-range", + ) + ) (selected_line,) = axis.plot( q_values[selected_mask], intensities[selected_mask], color=exp_color, linewidth=1.8, - label="Selected q-range", + label=selected_label, ) lines.append(selected_line) else: full_line.set_alpha(1.0) full_line.set_linewidth(1.8) - full_line.set_label("Experimental data") + full_line.set_label(experimental_label) else: axis.text( 0.5, @@ -3557,13 +4673,17 @@ def _draw_experimental_preview( dtype=float, ) solvent_color = self.solvent_trace_color() + solvent_label = self._component_plot_settings.display_series_label( + "solvent_data", + "Solvent data", + ) (solvent_line,) = axis.plot( solvent_q_values, solvent_intensities, color=solvent_color, alpha=0.45, linewidth=1.3, - label="Solvent data", + label=solvent_label, ) lines.append(solvent_line) @@ -3572,18 +4692,24 @@ def _draw_experimental_preview( solvent_selected_mask ): if not np.all(solvent_selected_mask): + selected_solvent_label = ( + self._component_plot_settings.display_series_label( + "selected_solvent_q_range", + "Selected solvent q-range", + ) + ) (selected_solvent_line,) = axis.plot( solvent_q_values[solvent_selected_mask], solvent_intensities[solvent_selected_mask], color=solvent_color, linewidth=1.8, - label="Selected solvent q-range", + label=selected_solvent_label, ) 
lines.append(selected_solvent_line) else: solvent_line.set_alpha(1.0) solvent_line.set_linewidth(1.8) - solvent_line.set_label("Solvent data") + solvent_line.set_label(solvent_label) self._apply_saxs_axis_style(axis, is_component_axis=False) return lines @@ -3592,6 +4718,8 @@ def _draw_component_profiles( self, axis, component_paths: list[Path], + *, + track_lines: bool = True, ) -> list[object]: if not component_paths: axis.text( @@ -3615,22 +4743,29 @@ def _draw_component_profiles( visible = self._component_visibility.get(component_key, True) color_override = self._component_color_overrides.get(component_key) line_color = color_override or scheme_colors.get(component_key) + display_label = self._component_plot_settings.display_series_label( + f"component::{component_key}", + component_path.stem, + ) (line,) = axis.plot( q_values, intensities, - label=component_path.stem, + label=display_label, linewidth=1.4, visible=visible, color=line_color, ) line.set_gid(component_key) - self._component_visibility.setdefault(component_key, visible) - self._component_line_lookup[component_key] = line - self._component_color_lookup[component_key] = str(line.get_color()) - if source_kind == "predicted_structure": - self._predicted_component_keys.append(component_key) - else: - self._observed_component_keys.append(component_key) + if track_lines: + self._component_visibility.setdefault(component_key, visible) + self._component_line_lookup[component_key] = line + self._component_color_lookup[component_key] = str( + line.get_color() + ) + if source_kind == "predicted_structure": + self._predicted_component_keys.append(component_key) + else: + self._observed_component_keys.append(component_key) lines.append(line) self._apply_saxs_axis_style(axis, is_component_axis=True) return lines @@ -3673,7 +4808,7 @@ def _apply_saxs_axis_style(self, axis, *, is_component_axis: bool) -> None: "log" if self.component_log_y_checkbox.isChecked() else "linear" ) if not is_component_axis or 
self._experimental_summary is None: - axis.set_xlabel("q (Å⁻¹)") + axis.set_xlabel(Q_A_INVERSE_LABEL) if not is_component_axis: axis.set_ylabel("Intensity (arb. units)") @@ -3731,22 +4866,37 @@ def _normalize_component_axis( ) component_axis.set_ylim(right_limits) - def _build_interactive_legend(self, axis, lines: list[object]) -> None: + def _build_interactive_legend( + self, + axis, + lines: list[object], + *, + location: str = "upper right", + font_size: float = 9.0, + font_family: str = "", + ) -> None: legend_columns = max(1, int(np.ceil(len(lines) / 5.0))) + legend_kwargs: dict[str, object] = { + "fontsize": font_size, + "loc": location, + "borderaxespad": 0.3, + "framealpha": 0.9, + "ncols": legend_columns, + "columnspacing": 0.9, + "handlelength": 1.5, + } + if location == "upper right": + legend_kwargs["bbox_to_anchor"] = (0.985, 0.985) legend = axis.legend( lines, [line.get_label() for line in lines], - fontsize="small", - loc="upper right", - bbox_to_anchor=(0.985, 0.985), - borderaxespad=0.3, - framealpha=0.9, - ncols=legend_columns, - columnspacing=0.9, - handlelength=1.5, + **legend_kwargs, ) if legend is None: return + if font_family: + for text in legend.get_texts(): + text.set_fontfamily(font_family) self._legend_line_map.clear() self._component_legend_lookup.clear() legend_handles = getattr(legend, "legend_handles", None) @@ -4050,12 +5200,18 @@ def _autoscale_to_model_range( self, experimental_axis, component_axis, + component_lines: list[object] | None = None, ) -> None: - component_lines = [ - line - for line in self._component_line_lookup.values() - if line.get_visible() - ] + if component_lines is None: + component_lines = [ + line + for line in self._component_line_lookup.values() + if line.get_visible() + ] + else: + component_lines = [ + line for line in component_lines if line.get_visible() + ] if not component_lines: return model_q_values = np.concatenate( diff --git a/src/saxshell/saxs/ui/solution_scattering_widget.py 
b/src/saxshell/saxs/ui/solution_scattering_widget.py index c20d3e0..75007b2 100644 --- a/src/saxshell/saxs/ui/solution_scattering_widget.py +++ b/src/saxshell/saxs/ui/solution_scattering_widget.py @@ -641,16 +641,24 @@ def set_target_parameter( self, parameter_name: str | None, fraction_kind: str | None, + fraction_source: str = "saxs_effective", solvent_weight_parameter: str | None = None, ) -> None: messages: list[str] = [] if parameter_name and fraction_kind: label = "solute" if fraction_kind == "solute" else "solvent" - messages.append( - f"{parameter_name} ({label} SAXS-effective interaction fraction)" + source_label = ( + "physical volume fraction" + if str(fraction_source).strip() == "physical" + else "SAXS-effective interaction fraction" ) + messages.append(f"{parameter_name} ({label} {source_label})") if solvent_weight_parameter: - if parameter_name and fraction_kind: + if ( + parameter_name + and fraction_kind + and str(fraction_source).strip() == "saxs_effective" + ): messages.append( f"{solvent_weight_parameter} (attenuation solvent scale)" ) diff --git a/src/saxshell/saxshell.py b/src/saxshell/saxshell.py index 750038c..0efae7e 100644 --- a/src/saxshell/saxshell.py +++ b/src/saxshell/saxshell.py @@ -141,6 +141,18 @@ def main(argv: list[str] | None = None) -> int: nargs=argparse.REMAINDER, help="Arguments passed through to the fullrmc command.", ) + representativefinder_parser = subparsers.add_parser( + "representativefinder", + help=( + "Build or run project-backed representative-structure analysis " + "run files." 
+ ), + ) + representativefinder_parser.add_argument( + "args", + nargs=argparse.REMAINDER, + help="Arguments passed through to the representativefinder command.", + ) args = parser.parse_args(argv) @@ -226,6 +238,16 @@ def main(argv: list[str] | None = None) -> int: forwarded_args = forwarded_args[1:] return fullrmc_main(forwarded_args) + if args.command == "representativefinder": + from saxshell.representativefinder.cli import ( + main as representativefinder_main, + ) + + forwarded_args = list(args.args) + if forwarded_args[:1] == ["--"]: + forwarded_args = forwarded_args[1:] + return representativefinder_main(forwarded_args) + parser.print_help() return 0 diff --git a/tests/test_contrast_fft_backend.py b/tests/test_contrast_fft_backend.py new file mode 100644 index 0000000..3282b75 --- /dev/null +++ b/tests/test_contrast_fft_backend.py @@ -0,0 +1,44 @@ +import numpy as np + +from saxshell.saxs.born_refinement.backend import build_shared_q_grid +from saxshell.saxs.contrast_fft import ( + ContrastFFTSettings, + compute_contrast_fft_intensity, +) + + +def test_shared_q_grid_preserves_requested_upper_bound_for_partial_step(): + q_values = build_shared_q_grid(0.0101, 1.1976, q_step=0.01) + + np.testing.assert_allclose(q_values[0], 0.0101) + np.testing.assert_allclose(q_values[-2], 1.1901) + np.testing.assert_allclose(q_values[-1], 1.1976) + + +def test_single_atom_bare_density_uses_direct_born_trace_when_fft_bins_are_empty(): + coordinates = np.asarray([[0.0, 0.0, 0.0]], dtype=float) + weights = np.asarray([6.0], dtype=float) + q_values = np.asarray([0.01, 0.02, 0.03], dtype=float) + settings = ContrastFFTSettings( + spacing_a=2.5, + gaussian_sigma_a=0.75, + minimum_box_length_a=80.0, + padding_a=4.0, + ).normalized() + + result = compute_contrast_fft_intensity( + coordinates, + weights, + q_values, + settings, + elements=("C",), + ) + + np.testing.assert_allclose(result.raw_intensity, [36.0, 36.0, 36.0]) + np.testing.assert_allclose( + 
result.kernel_corrected_intensity, + result.raw_intensity, + ) + assert np.all(result.q_shell_counts == 1) + assert result.contrast_mode == "single_atom_bare_density_direct_born" + assert result.first_nonempty_q_a_inverse == q_values[0] diff --git a/tests/test_electron_density_mapping.py b/tests/test_electron_density_mapping.py index f7c12aa..905f91e 100644 --- a/tests/test_electron_density_mapping.py +++ b/tests/test_electron_density_mapping.py @@ -19,6 +19,7 @@ import saxshell.saxs.electron_density_mapping.workflow as density_workflow from saxshell.saxs.contrast.electron_density import ( CONTRAST_SOLVENT_METHOD_DIRECT, + CONTRAST_SOLVENT_METHOD_NEAT, ContrastSolventDensitySettings, ) from saxshell.saxs.debye import atomic_form_factor @@ -70,6 +71,22 @@ def _accept_default_questions(monkeypatch): ) +def test_fourier_transform_q_grid_preserves_partial_step_endpoint(): + settings = ElectronDensityFourierTransformSettings( + r_min=0.0, + r_max=1.0, + q_min=0.0101, + q_max=1.1976, + q_step=0.01, + ) + + q_values = density_workflow._q_values_from_transform_settings(settings) + + np.testing.assert_allclose(q_values[0], 0.0101) + np.testing.assert_allclose(q_values[-2], 1.1901) + np.testing.assert_allclose(q_values[-1], 1.1976) + + def _write_xyz(path: Path, lines: list[str]) -> Path: path.write_text("\n".join(lines) + "\n", encoding="utf-8") return path @@ -457,6 +474,19 @@ def _assert_batch_progress_dialog_visible( return dialog +def _assert_debye_progress_dialog_visible( + window: ElectronDensityMappingMainWindow, + *, + total: int, +): + dialog = window._debye_scattering_progress_dialog + assert dialog is not None + assert dialog.isVisible() + assert dialog.windowTitle() == "Computing Debye Scattering" + assert dialog.progress_bar.maximum() == total + return dialog + + def _table_column_index(table, header_text: str) -> int: for column_index in range(table.columnCount()): header_item = table.horizontalHeaderItem(column_index) @@ -2500,6 +2530,109 @@ def 
test_main_window_output_history_restores_preview_entries_in_non_preview_mode restored_window.close() +def test_main_window_saved_debye_output_can_be_compared_and_reloaded_later( + qapp, + tmp_path, +): + structure_path = _write_xyz( + tmp_path / "history_debye_restore.xyz", + [ + "4", + "history debye restore", + "N 0.0 0.0 0.1", + "H 0.94 0.0 -0.2", + "H -0.47 0.81 -0.2", + "H -0.47 -0.81 -0.2", + ], + ) + window = ElectronDensityMappingMainWindow( + initial_input_path=structure_path + ) + + window.run_button.click() + _wait_for( + lambda: window._profile_result is not None + and window._calculation_thread is None, + qapp, + ) + window.evaluate_fourier_button.click() + _wait_for( + lambda: window._fourier_result is not None + and window.calculate_debye_scattering_button.isEnabled(), + qapp, + ) + window.calculate_debye_scattering_button.click() + _wait_for( + lambda: window._debye_scattering_result is not None + and window._debye_scattering_thread is None, + qapp, + ) + + assert window._debye_scattering_result is not None + assert window.output_history_table.rowCount() == 3 + assert window.output_history_table.item(0, 1).text() == "Debye Scattering" + saved_debye_entry = window._saved_output_entries[-1] + assert saved_debye_entry.debye_scattering_result is not None + assert saved_debye_entry.transform_result is not None + + window.fourier_qmax_spin.setValue(window.fourier_qmax_spin.value() + 0.1) + window.evaluate_fourier_button.click() + qapp.processEvents() + + assert window._fourier_result is not None + assert window._debye_scattering_result is None + assert window.output_history_table.rowCount() == 4 + + debye_row = next( + row_index + for row_index in range(window.output_history_table.rowCount()) + if ( + window.output_history_table.item(row_index, 0).data( + Qt.ItemDataRole.UserRole + ) + == saved_debye_entry.entry_id + ) + ) + selection_model = window.output_history_table.selectionModel() + selection_model.clearSelection() + selection_model.select( + 
window.output_history_table.model().index(debye_row, 0), + QItemSelectionModel.SelectionFlag.Select + | QItemSelectionModel.SelectionFlag.Rows, + ) + qapp.processEvents() + + window.compare_output_history_button.click() + qapp.processEvents() + + saved_compare_dialog = window._output_history_compare_dialog + assert saved_compare_dialog is not None + assert len(saved_compare_dialog._scatter_plot.figure.axes) == 2 + saved_born_axis, saved_debye_axis = ( + saved_compare_dialog._scatter_plot.figure.axes + ) + assert "Born Approximation" in saved_born_axis.get_ylabel() + assert "Debye Scattering" in saved_debye_axis.get_ylabel() + assert saved_debye_axis.get_lines() + saved_compare_dialog.close() + + window.load_output_history_button.click() + qapp.processEvents() + + assert window._fourier_result is not None + assert window._debye_scattering_result is not None + assert window.open_debye_scattering_compare_button.isEnabled() + + window.open_debye_scattering_compare_button.click() + qapp.processEvents() + + live_compare_dialog = window._debye_scattering_compare_dialog + assert live_compare_dialog is not None + assert len(live_compare_dialog._plot_widget.figure.axes) == 2 + live_compare_dialog.close() + window.close() + + def test_main_window_push_controls_render_between_fourier_and_saved_outputs( qapp, ): @@ -2627,7 +2760,7 @@ def test_main_window_runs_profile_and_writes_outputs(qapp, tmp_path): ) assert ( window._profile_result.smearing_settings.debye_waller_factor - == pytest.approx(0.006) + == pytest.approx(0.0) ) assert window.run_button.isEnabled() assert "complete" in window.calculation_progress_message.text().lower() @@ -3080,7 +3213,7 @@ def test_main_window_evaluates_fourier_transform_and_toggles_log_axes( assert window._fourier_result is not None assert window.scattering_plot.current_result is window._fourier_result scattering_axis = window.scattering_plot.figure.axes[0] - assert scattering_axis.get_xlabel() == "q (Å⁻¹)" + assert scattering_axis.get_xlabel() == 
"q (Å$^{-1}$)" assert scattering_axis.get_xscale() == "log" assert scattering_axis.get_yscale() == "log" @@ -3097,10 +3230,10 @@ def test_main_window_uses_updated_fourier_defaults_and_inherited_q_range( qapp, ): window = ElectronDensityMappingMainWindow() - assert window.fourier_qmin_spin.value() == pytest.approx(0.02) + assert window.fourier_qmin_spin.value() == pytest.approx(0.01) assert window.fourier_qmax_spin.value() == pytest.approx(1.2) assert window.fourier_qstep_spin.value() == pytest.approx(0.01) - assert window.fourier_resampling_points_spin.value() == 2048 + assert window.fourier_resampling_points_spin.value() == 4096 inherited_window = ElectronDensityMappingMainWindow( initial_project_q_min=0.15, @@ -3109,7 +3242,7 @@ def test_main_window_uses_updated_fourier_defaults_and_inherited_q_range( assert inherited_window.fourier_qmin_spin.value() == pytest.approx(0.15) assert inherited_window.fourier_qmax_spin.value() == pytest.approx(0.9) assert inherited_window.fourier_qstep_spin.value() == pytest.approx(0.01) - assert inherited_window.fourier_resampling_points_spin.value() == 2048 + assert inherited_window.fourier_resampling_points_spin.value() == 4096 window.close() inherited_window.close() @@ -3142,26 +3275,27 @@ def test_main_window_keeps_fourier_controls_editable_without_cluster_groups( window.close() -def test_main_window_defaults_to_mirrored_fourier_domain_and_can_toggle_legacy( +def test_main_window_defaults_to_legacy_fourier_domain_and_can_toggle_mirrored( qapp, ): window = ElectronDensityMappingMainWindow() window.apply_fourier_to_all_button.setChecked(False) qapp.processEvents() - assert not window.fourier_legacy_mode_checkbox.isChecked() - assert window.fourier_rmin_label.text() == "-r max" - assert not window.fourier_rmin_spin.isEnabled() + assert window.fourier_legacy_mode_checkbox.isChecked() + assert window.fourier_rmin_label.text() == "r min" + assert window.fourier_rmin_spin.isEnabled() + assert window.fourier_rmin_spin.value() == 
pytest.approx(0.0) window.fourier_rmax_spin.setValue(3.25) qapp.processEvents() - assert window.fourier_rmin_spin.value() == pytest.approx(-3.25) + assert window.fourier_rmin_spin.value() == pytest.approx(0.0) - window.fourier_legacy_mode_checkbox.setChecked(True) + window.fourier_legacy_mode_checkbox.setChecked(False) qapp.processEvents() - assert window.fourier_rmin_label.text() == "r min" - assert window.fourier_rmin_spin.isEnabled() - assert window.fourier_rmin_spin.value() == pytest.approx(0.0) + assert window.fourier_rmin_label.text() == "-r max" + assert not window.fourier_rmin_spin.isEnabled() + assert window.fourier_rmin_spin.value() == pytest.approx(-3.25) window.close() @@ -3176,7 +3310,7 @@ def test_main_window_exposes_centered_exafs_window_options(qapp): assert "kaiser_bessel" in window_names assert "hanning" in window_names - assert window.fourier_window_combo.currentData() == "hanning" + assert window.fourier_window_combo.currentData() == "none" window.close() @@ -3747,7 +3881,15 @@ def test_main_window_debye_scattering_pane_builds_cluster_comparison_plot( assert window.calculate_debye_scattering_button.isEnabled() window.apply_debye_to_all_button.setChecked(True) window.calculate_debye_scattering_button.click() - qapp.processEvents() + _wait_for( + lambda: all( + state.debye_scattering_result is not None + for state in window._cluster_group_states + if not state.single_atom_only + ) + and window._debye_scattering_thread is None, + qapp, + ) assert all( state.debye_scattering_result is not None @@ -3781,6 +3923,51 @@ def test_main_window_debye_scattering_pane_builds_cluster_comparison_plot( window.close() +def test_cluster_folder_batch_debye_runs_create_saved_output_entries( + qapp, + tmp_path, +): + window = _open_ready_cluster_folder_window( + qapp, + tmp_path, + folder_name="debye_saved_outputs_batch", + ) + + window.apply_fourier_to_all_button.setChecked(True) + window.evaluate_fourier_button.click() + qapp.processEvents() + 
window.apply_debye_to_all_button.setChecked(True) + window.calculate_debye_scattering_button.click() + _wait_for( + lambda: all( + state.debye_scattering_result is not None + for state in window._cluster_group_states + if not state.single_atom_only + ) + and window._debye_scattering_thread is None, + qapp, + ) + + expected_debye_entries = sum( + 1 + for state in window._cluster_group_states + if not state.single_atom_only + ) + debye_entries = [ + entry + for entry in window._saved_output_entries + if entry.entry_kind == "debye_scattering" + ] + + assert len(debye_entries) == expected_debye_entries + assert all( + entry.debye_scattering_result is not None for entry in debye_entries + ) + assert all(entry.transform_result is not None for entry in debye_entries) + assert window.output_history_table.item(0, 1).text() == "Debye Scattering" + window.close() + + def test_main_window_debye_scattering_progress_bar_updates_for_single_run( qapp, tmp_path, @@ -3813,21 +4000,15 @@ def test_main_window_debye_scattering_progress_bar_updates_for_single_run( qapp, ) - progress_snapshots: list[tuple[bool, int, int, str]] = [] + progress_messages: list[tuple[int, int, str]] = [] def wrapped_compute(*args, progress_callback=None, **kwargs): assert progress_callback is not None def wrapped_progress(current, total, message): + progress_messages.append((int(current), int(total), str(message))) progress_callback(current, total, message) - progress_snapshots.append( - ( - not window.debye_scattering_progress_bar.isHidden(), - window.debye_scattering_progress_bar.value(), - window.debye_scattering_progress_bar.maximum(), - window.debye_scattering_status_label.text(), - ) - ) + time.sleep(0.01) return compute_average_debye_scattering_profile_for_input( *args, @@ -3842,16 +4023,51 @@ def wrapped_progress(current, total, message): window.calculate_debye_scattering_button.click() - assert progress_snapshots + _wait_for( + lambda: window._debye_scattering_progress_dialog is not None + and 
window._debye_scattering_progress_dialog.isVisible(), + qapp, + ) + dialog = _assert_debye_progress_dialog_visible(window, total=4) + + progress_snapshots: list[tuple[bool, int, int, str, str]] = [] + _wait_for(lambda: bool(progress_messages), qapp) + while window._debye_scattering_thread is not None: + qapp.processEvents() + progress_snapshots.append( + ( + not window.debye_scattering_progress_bar.isHidden(), + window.debye_scattering_progress_bar.value(), + window.debye_scattering_progress_bar.maximum(), + window.debye_scattering_status_label.text(), + dialog.message_label.text(), + ) + ) + time.sleep(0.01) + _wait_for( + lambda: window._debye_scattering_result is not None + and window._debye_scattering_thread is None, + qapp, + ) + + assert progress_messages assert any(visible for visible, *_rest in progress_snapshots) assert any( - maximum == 4 for _visible, _value, maximum, _text in progress_snapshots + maximum == 4 + for _visible, _value, maximum, _text, _dialog_text in progress_snapshots ) assert any( "Debye scattering average calculation" in text or "Debye trace" in text - for _visible, _value, _maximum, text in progress_snapshots + for _visible, _value, _maximum, text, _dialog_text in progress_snapshots + ) + assert any( + "Debye scattering average calculation" in dialog_text + or "Debye trace" in dialog_text + for _visible, _value, _maximum, _text, dialog_text in progress_snapshots ) assert window.debye_scattering_progress_bar.isHidden() + assert window._debye_scattering_progress_dialog is not None + assert not window._debye_scattering_progress_dialog.isVisible() assert window._debye_scattering_result is not None window.close() @@ -3870,20 +4086,15 @@ def test_main_window_debye_scattering_progress_bar_updates_for_batch_run( window.evaluate_fourier_button.click() qapp.processEvents() - progress_snapshots: list[tuple[int, int, str]] = [] + progress_messages: list[tuple[int, int, str]] = [] def wrapped_compute(*args, progress_callback=None, **kwargs): assert 
progress_callback is not None def wrapped_progress(current, total, message): + progress_messages.append((int(current), int(total), str(message))) progress_callback(current, total, message) - progress_snapshots.append( - ( - window.debye_scattering_progress_bar.value(), - window.debye_scattering_progress_bar.maximum(), - window.debye_scattering_status_label.text(), - ) - ) + time.sleep(0.01) return compute_average_debye_scattering_profile_for_input( *args, @@ -3899,15 +4110,56 @@ def wrapped_progress(current, total, message): window.apply_debye_to_all_button.setChecked(True) window.calculate_debye_scattering_button.click() - assert progress_snapshots - assert any(maximum == 8 for _value, maximum, _text in progress_snapshots) + _wait_for( + lambda: window._debye_scattering_progress_dialog is not None + and window._debye_scattering_progress_dialog.isVisible(), + qapp, + ) + dialog = _assert_debye_progress_dialog_visible(window, total=8) + + progress_snapshots: list[tuple[int, int, str, str]] = [] + _wait_for(lambda: bool(progress_messages), qapp) + while window._debye_scattering_thread is not None: + qapp.processEvents() + progress_snapshots.append( + ( + window.debye_scattering_progress_bar.value(), + window.debye_scattering_progress_bar.maximum(), + window.debye_scattering_status_label.text(), + dialog.message_label.text(), + ) + ) + time.sleep(0.01) + _wait_for( + lambda: all( + state.debye_scattering_result is not None + for state in window._cluster_group_states + if not state.single_atom_only + ) + and window._debye_scattering_thread is None, + qapp, + ) + + assert progress_messages + assert any( + maximum == 8 + for _value, maximum, _text, _dialog_text in progress_snapshots + ) assert any( - "Debye 1/2" in text for _value, _maximum, text in progress_snapshots + "Debye 1/2" in text + for _value, _maximum, text, _dialog_text in progress_snapshots ) assert any( - "Debye 2/2" in text for _value, _maximum, text in progress_snapshots + "Debye 2/2" in text + for 
_value, _maximum, text, _dialog_text in progress_snapshots + ) + assert any( + "Debye 1/2" in dialog_text or "Debye 2/2" in dialog_text + for _value, _maximum, _text, dialog_text in progress_snapshots ) assert window.debye_scattering_progress_bar.isHidden() + assert window._debye_scattering_progress_dialog is not None + assert not window._debye_scattering_progress_dialog.isVisible() assert all( state.debye_scattering_result is not None for state in window._cluster_group_states @@ -4276,6 +4528,77 @@ def test_main_window_computes_solvent_contrast_and_updates_fourier_source( window.close() +def test_main_window_saved_solvent_none_clears_active_solvent_subtraction( + qapp, + tmp_path, +): + structure_path = _write_xyz( + tmp_path / "solvent_none_clear.xyz", + [ + "4", + "solvent none clear", + "N 0.0 0.0 0.1", + "H 0.94 0.0 -0.2", + "H -0.47 0.81 -0.2", + "H -0.47 -0.81 -0.2", + ], + ) + window = ElectronDensityMappingMainWindow( + initial_input_path=structure_path + ) + + window.run_button.click() + _wait_for( + lambda: window._profile_result is not None + and window._calculation_thread is None, + qapp, + ) + + assert window._profile_result is not None + direct_density = float( + min( + np.max( + np.asarray( + window._profile_result.smeared_orientation_average_density, + dtype=float, + ) + ) + * 0.01, + 50.0, + ) + ) + window.solvent_method_combo.setCurrentIndex( + window.solvent_method_combo.findData(CONTRAST_SOLVENT_METHOD_DIRECT) + ) + window.direct_density_spin.setValue(direct_density) + window.compute_solvent_density_button.click() + qapp.processEvents() + + assert window._profile_result.solvent_contrast is not None + assert window._fourier_preview is not None + assert window._fourier_preview.source_profile_label.startswith( + "Solvent-subtracted" + ) + + window.solvent_method_combo.setCurrentIndex( + window.solvent_method_combo.findData(CONTRAST_SOLVENT_METHOD_NEAT) + ) + none_index = window.solvent_preset_combo.findText("None") + assert none_index >= 0 + 
window.solvent_preset_combo.setCurrentIndex(none_index) + qapp.processEvents() + window.compute_solvent_density_button.click() + qapp.processEvents() + + assert window._active_contrast_settings is None + assert window._profile_result is not None + assert window._profile_result.solvent_contrast is None + assert window.residual_profile_plot.current_contrast is None + assert window._fourier_preview is not None + assert window._fourier_preview.source_profile_label == "Smeared ρ(r)" + window.close() + + def test_fourier_preview_plot_renders_mirrored_profile_and_solvent_subtraction( qapp, tmp_path, @@ -4442,9 +4765,9 @@ def test_main_window_defaults_rstep_and_rmax_from_structure(qapp, tmp_path): initial_input_path=structure_path ) - assert window.rstep_spin.value() == pytest.approx(0.05) + assert window.rstep_spin.value() == pytest.approx(0.25) assert window.rmax_spin.value() == pytest.approx( - window._structure.rmax, + np.ceil(float(window._structure.rmax) + 2.0), abs=1.0e-4, ) assert "geometric mass center" in ( @@ -5303,7 +5626,15 @@ def test_main_window_restores_debye_scattering_workspace_state( qapp.processEvents() window.apply_debye_to_all_button.setChecked(True) window.calculate_debye_scattering_button.click() - qapp.processEvents() + _wait_for( + lambda: all( + state.debye_scattering_result is not None + for state in window._cluster_group_states + if not state.single_atom_only + ) + and window._debye_scattering_thread is None, + qapp, + ) assert workspace_state_path.is_file() workspace_payload = json.loads( @@ -6611,6 +6942,74 @@ def test_cluster_group_selection_syncs_fourier_rmax_to_solvent_cutoff( window.close() +def test_cluster_folder_saved_solvent_none_clears_batch_solvent_subtraction( + qapp, + tmp_path, +): + clusters_dir = _write_cluster_folder_input( + tmp_path / "clusters_none_clear" + ) + + window = ElectronDensityMappingMainWindow(initial_input_path=clusters_dir) + + window.run_button.click() + _wait_for( + lambda: all( + state.profile_result is not 
None + for state in window._cluster_group_states + ) + and window._calculation_thread is None, + qapp, + ) + + window.solvent_method_combo.setCurrentIndex( + window.solvent_method_combo.findData(CONTRAST_SOLVENT_METHOD_DIRECT) + ) + window.direct_density_spin.setValue(0.2) + window.compute_solvent_density_button.click() + qapp.processEvents() + + assert all( + state.profile_result is not None + and state.profile_result.solvent_contrast is not None + for state in window._cluster_group_states + ) + + window.evaluate_fourier_button.click() + qapp.processEvents() + assert all( + state.transform_result is not None + for state in window._cluster_group_states + ) + + window.solvent_method_combo.setCurrentIndex( + window.solvent_method_combo.findData(CONTRAST_SOLVENT_METHOD_NEAT) + ) + none_index = window.solvent_preset_combo.findText("None") + assert none_index >= 0 + window.solvent_preset_combo.setCurrentIndex(none_index) + qapp.processEvents() + window.compute_solvent_density_button.click() + qapp.processEvents() + + assert window._active_contrast_settings is None + assert all( + state.profile_result is not None + and state.profile_result.solvent_contrast is None + for state in window._cluster_group_states + ) + assert all( + state.solvent_density_e_per_a3 is None + and state.solvent_cutoff_radius_a is None + for state in window._cluster_group_states + ) + assert all( + state.transform_result is None + for state in window._cluster_group_states + ) + window.close() + + def test_cluster_folder_manual_mode_runs_selected_row_and_locks_mesh( qapp, tmp_path, diff --git a/tests/test_saxs_prefit.py b/tests/test_saxs_prefit.py index 13c571d..25dbac5 100644 --- a/tests/test_saxs_prefit.py +++ b/tests/test_saxs_prefit.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from pathlib import Path from types import SimpleNamespace import numpy as np @@ -13,8 +14,22 @@ ClusterDynamicsMLTrainingObservation, PredictedClusterCandidate, ) +from saxshell.fullrmc.project_model 
import build_rmcsetup_paths +from saxshell.fullrmc.representatives import ( + DistributionSelectionMetadata, + RepresentativeSelectionEntry, + RepresentativeSelectionMetadata, + RepresentativeSelectionSettings, + save_representative_selection_metadata, +) from saxshell.fullrmc.solution_properties import SolutionPropertiesSettings -from saxshell.saxs._model_templates import load_template_module +from saxshell.saxs._model_templates import ( + load_template_module, + load_template_spec, +) +from saxshell.saxs.contrast.settings import ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, +) from saxshell.saxs.debye.profiles import AveragedComponent, ClusterBin from saxshell.saxs.prefit import ( SAXSPrefitWorkflow, @@ -24,6 +39,7 @@ ) from saxshell.saxs.prefit.workflow import constrained_prefit_residuals from saxshell.saxs.project_manager import ( + DreamBestFitSelection, SAXSProjectManager, build_project_paths, project_artifact_paths, @@ -40,6 +56,9 @@ POLY_LMA_HS_TEMPLATE = "template_pydream_poly_lma_hs" POLY_LMA_HS_MIX_TEMPLATE = "template_pydream_poly_lma_hs_mix_approx" +SCALED_SOLVENT_MONOSQ_TEMPLATE = ( + "template_pydream_monosq_normalized_scaled_solvent" +) def _write_component_file(path, q_values, intensities): @@ -1015,6 +1034,219 @@ def fake_build_profiles( assert records[0].prior_artifacts_ready +def test_prefit_loads_pushed_3d_fft_components_after_source_toggle_change( + tmp_path, +): + project_dir, paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + settings.use_representative_structures = True + manager.save_project(settings) + + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + manager.ensure_artifact_dirs(artifact_paths) + artifact_paths.component_dir.mkdir(parents=True, exist_ok=True) + (artifact_paths.component_dir / 
"A_no_motif.txt").write_text( + (paths.scattering_components_dir / "A_no_motif.txt").read_text( + encoding="utf-8" + ), + encoding="utf-8", + ) + artifact_paths.component_map_file.write_text( + (paths.project_dir / "md_saxs_map.json").read_text(encoding="utf-8"), + encoding="utf-8", + ) + artifact_paths.prior_weights_file.write_text( + (paths.project_dir / "md_prior_weights.json").read_text( + encoding="utf-8" + ), + encoding="utf-8", + ) + manager._write_distribution_metadata( + settings, + artifact_paths=artifact_paths, + built_component_source_mode="average", + ) + + workflow = SAXSPrefitWorkflow(project_dir) + + assert manager.component_artifacts_match_settings( + settings, + artifact_paths=artifact_paths, + ) + assert [ + (component.structure, component.motif) + for component in workflow.components + ] == [("A", "no_motif")] + assert np.allclose( + workflow.evaluate().q_values, + np.linspace(0.05, 0.3, 8), + ) + + +def test_prefit_snaps_legacy_3d_fft_endpoint_to_component_q_grid(tmp_path): + project_dir, paths = _build_minimal_saxs_project(tmp_path) + q_values = 0.0101 + 0.01 * np.arange(119, dtype=float) + component = np.linspace(10.0, 17.0, q_values.size) + _write_component_file( + paths.scattering_components_dir / "A_no_motif.txt", + q_values, + component, + ) + + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + settings.model_only_mode = True + settings.q_min = 0.0101 + settings.q_max = 1.1976 + manager.save_project(settings) + + workflow = SAXSPrefitWorkflow(project_dir) + evaluation = workflow.evaluate() + + assert np.allclose(evaluation.q_values, q_values) + + +def test_prefit_cluster_geometry_uses_representative_source_for_3d_fft_components( + tmp_path, +): + project_dir, paths, _effective_radius = _build_poly_lma_geometry_project( + tmp_path + ) + rmcsetup_paths = build_rmcsetup_paths(project_dir) + representative_dir = 
rmcsetup_paths.representative_partial_solvent_dir + representative_dir.mkdir(parents=True, exist_ok=True) + representative_path = representative_dir / "selected_rep.xyz" + representative_path.write_text( + "\n".join( + [ + "4", + "selected representative", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + "I 0.0 0.0 2.8", + ] + ) + + "\n", + encoding="utf-8", + ) + selection = DreamBestFitSelection( + run_name="Representative Structure Finder", + run_relative_path="rmcsetup/representative_structures", + label="Representative Structure Finder", + selection_source="representativefinder", + selected_at="2026-05-06T12:00:00", + ) + save_representative_selection_metadata( + rmcsetup_paths.representative_selection_path, + RepresentativeSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + distribution_selection=DistributionSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + run_dir="rmcsetup/representative_structures", + updated_at="2026-05-06T12:00:00", + entries=[], + ), + settings=RepresentativeSelectionSettings(), + updated_at="2026-05-06T12:00:00", + representative_entries=[ + RepresentativeSelectionEntry( + structure="A", + motif="no_motif", + param="w0", + selected_weight=1.0, + cluster_count=2, + source_dir=str(representative_dir), + source_file=str(representative_path), + source_file_name=representative_path.name, + atom_count=4, + element_counts={"Pb": 1, "I": 3}, + source_solvent_mode="partialsolv", + ) + ], + missing_bins=[], + invalid_bins=[], + ), + ) + + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + settings.use_representative_structures = True + settings.clusters_dir = None + manager.save_project(settings) + + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + 
manager.ensure_artifact_dirs(artifact_paths) + artifact_paths.component_dir.mkdir(parents=True, exist_ok=True) + (artifact_paths.component_dir / "A_no_motif.txt").write_text( + (paths.scattering_components_dir / "A_no_motif.txt").read_text( + encoding="utf-8" + ), + encoding="utf-8", + ) + artifact_paths.component_map_file.write_text( + json.dumps({"saxs_map": {"A": {"no_motif": "A_no_motif.txt"}}}) + "\n", + encoding="utf-8", + ) + artifact_paths.prior_weights_file.write_text( + json.dumps( + { + "origin": "representative_structures", + "total_files": 2, + "structures": { + "A": { + "no_motif": { + "count": 2, + "weight": 1.0, + "representative": representative_path.name, + "profile_file": "A_no_motif.txt", + "source_kind": "representative_structure", + "source_dir": str(representative_dir), + "source_file": str(representative_path), + "source_file_name": representative_path.name, + } + } + }, + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + manager._write_distribution_metadata( + settings, + artifact_paths=artifact_paths, + built_component_source_mode="representative", + ) + + workflow = SAXSPrefitWorkflow(project_dir) + table = workflow.compute_cluster_geometry_table() + + assert [row.cluster_id for row in table.rows] == ["A"] + assert Path(table.rows[0].cluster_path) == representative_dir.resolve() + assert table.rows[0].mapped_parameter == "w0" + + def test_prefit_cluster_geometry_includes_predicted_structures_when_enabled( tmp_path, monkeypatch, @@ -1399,6 +1631,170 @@ def test_saxs_prefit_workflow_recommends_scale_with_weighted_solvent_trace( assert recommendation.points_used == 8 +def test_scaled_solvent_monosq_template_scales_solvent_with_global_scale(): + template_module = load_template_module(SCALED_SOLVENT_MONOSQ_TEMPLATE) + q_values = np.linspace(0.05, 0.3, 8) + component = np.linspace(10.0, 17.0, 8) + solvent = np.linspace(1.5, 2.2, 8) + + result = template_module.lmfit_model_profile( + q_values, + solvent, + [component], + w0=0.6, + 
solv_w=0.5, + offset=0.05, + eff_r=9.0, + vol_frac=0.0, + scale=2e-3, + ) + + assert np.allclose( + result, + 2e-3 * ((0.6 * component) + (0.5 * solvent)) + 0.05, + ) + + +def test_scaled_solvent_monosq_prefit_evaluates_scaled_solvent_contribution( + tmp_path, +): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + solvent_q = np.linspace(0.05, 0.3, 8) + solvent_intensity = np.linspace(1.5, 2.2, 8) + solvent_path = tmp_path / "scaled_weighted_solvent_trace.dat" + np.savetxt(solvent_path, np.column_stack([solvent_q, solvent_intensity])) + + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.selected_model_template = SCALED_SOLVENT_MONOSQ_TEMPLATE + settings.solvent_data_path = str(solvent_path) + settings.copied_solvent_data_file = None + manager.save_project(settings) + + workflow = SAXSPrefitWorkflow(project_dir) + entries = workflow.load_parameter_entries() + for entry in entries: + if entry.name == "solv_w": + entry.value = 0.5 + if entry.name == "scale": + entry.value = 2e-3 + + evaluation = workflow.evaluate(entries) + + assert evaluation.solvent_intensities is not None + assert evaluation.solvent_contribution is not None + assert np.allclose(evaluation.solvent_intensities, solvent_intensity) + assert np.allclose( + evaluation.solvent_contribution, + solvent_intensity * 0.5 * 2e-3, + ) + + +def test_scaled_solvent_monosq_recommends_scale_with_scaled_solvent_branch( + tmp_path, +): + project_dir, paths = _build_minimal_saxs_project(tmp_path) + solvent_q = np.linspace(0.05, 0.3, 8) + solvent_intensity = np.linspace(1.5, 2.2, 8) + solvent_path = tmp_path / "scaled_autoscale_solvent_trace.dat" + np.savetxt(solvent_path, np.column_stack([solvent_q, solvent_intensity])) + + template_module = load_template_module(SCALED_SOLVENT_MONOSQ_TEMPLATE) + q_values = np.linspace(0.05, 0.3, 8) + component = np.linspace(10.0, 17.0, 8) + experimental = template_module.lmfit_model_profile( + q_values, + solvent_intensity, + 
[component], + w0=0.6, + solv_w=0.5, + offset=0.05, + eff_r=9.0, + vol_frac=0.0, + scale=5e-4, + ) + experimental_path = paths.experimental_data_dir / "scaled_exp_demo.txt" + np.savetxt(experimental_path, np.column_stack([q_values, experimental])) + + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.selected_model_template = SCALED_SOLVENT_MONOSQ_TEMPLATE + settings.experimental_data_path = str(experimental_path) + settings.copied_experimental_data_file = str(experimental_path) + settings.solvent_data_path = str(solvent_path) + settings.copied_solvent_data_file = None + manager.save_project(settings) + + workflow = SAXSPrefitWorkflow(project_dir) + entries = workflow.load_parameter_entries() + for entry in entries: + if entry.name == "solv_w": + entry.value = 0.5 + if entry.name == "scale": + entry.value = 1e-6 + entry.minimum = 1e-7 + entry.maximum = 1e-5 + + recommendation = workflow.recommend_scale_settings(entries) + + assert recommendation.current_scale == pytest.approx(1e-6) + assert recommendation.recommended_scale == pytest.approx(5e-4) + assert recommendation.recommended_minimum == pytest.approx(5e-5) + assert recommendation.recommended_maximum == pytest.approx(5e-3) + assert recommendation.current_offset == pytest.approx(0.0) + assert recommendation.recommended_offset == pytest.approx(0.05) + assert recommendation.points_used == 8 + + +def test_scaled_solvent_monosq_recommendation_uses_adaptive_bounds(tmp_path): + project_dir, paths = _build_minimal_saxs_project(tmp_path) + q_values = np.linspace(0.05, 0.3, 8) + component = np.linspace(10.0, 17.0, 8) + experimental_scale = 2e-7 + experimental_offset = 0.01 + template_module = load_template_module(SCALED_SOLVENT_MONOSQ_TEMPLATE) + experimental = template_module.lmfit_model_profile( + q_values, + np.zeros_like(q_values), + [component], + w0=0.6, + solv_w=0.0, + offset=experimental_offset, + eff_r=3.0, + vol_frac=0.0, + scale=experimental_scale, + ) + experimental_path 
= paths.experimental_data_dir / "adaptive_exp_demo.txt" + np.savetxt(experimental_path, np.column_stack([q_values, experimental])) + + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.selected_model_template = SCALED_SOLVENT_MONOSQ_TEMPLATE + settings.experimental_data_path = str(experimental_path) + settings.copied_experimental_data_file = str(experimental_path) + manager.save_project(settings) + + workflow = SAXSPrefitWorkflow(project_dir) + recommendation = workflow.recommend_scale_settings( + workflow.load_parameter_entries() + ) + + assert recommendation.recommended_scale == pytest.approx( + experimental_scale + ) + assert recommendation.recommended_minimum == pytest.approx( + experimental_scale / 10.0 + ) + assert recommendation.recommended_maximum == pytest.approx( + experimental_scale * 10.0 + ) + assert recommendation.recommended_offset == pytest.approx( + experimental_offset + ) + assert recommendation.recommended_offset_minimum == pytest.approx(0.009) + assert recommendation.recommended_offset_maximum == pytest.approx(0.011) + + def test_solute_volume_fraction_estimate_uses_component_densities(): estimate = calculate_solute_volume_fraction_estimate( SoluteVolumeFractionSettings( @@ -1500,6 +1896,50 @@ def test_monosq_prefit_workflow_exposes_solvent_weight_target(tmp_path): assert workflow.solvent_weight_estimator_target() == "solv_w" +def test_scaled_solvent_monosq_prefit_exposes_physical_vol_frac_target( + tmp_path, +): + spec = load_template_spec(SCALED_SOLVENT_MONOSQ_TEMPLATE) + assert ( + spec.solution_scattering_support.volume_fraction_parameter + == "vol_frac" + ) + assert spec.solution_scattering_support.volume_fraction_kind == "solute" + assert ( + spec.solution_scattering_support.volume_fraction_source == "physical" + ) + assert ( + spec.solution_scattering_support.solvent_contribution_scale_mode + == "global_scale" + ) + assert spec.prefit_support.auto_apply_autoscale_on_load + assert 
spec.prefit_support.autoscale_bounds_mode == "adaptive" + eff_r_entry = next( + parameter for parameter in spec.parameters if parameter.name == "eff_r" + ) + assert eff_r_entry.initial_value == pytest.approx(3.0) + + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.selected_model_template = SCALED_SOLVENT_MONOSQ_TEMPLATE + manager.save_project(settings) + workflow = SAXSPrefitWorkflow(project_dir) + + assert workflow.supports_volume_fraction_estimator() + assert workflow.volume_fraction_estimator_target() == ( + "vol_frac", + "solute", + ) + assert workflow.solution_scattering_volume_fraction_target() == ( + "vol_frac", + "solute", + "physical", + ) + assert workflow.solvent_weight_estimator_target() == "solv_w" + assert workflow.solvent_contribution_is_scaled_by_global_scale() + + def test_poly_lma_prefit_workflow_exposes_solvent_weight_target(tmp_path): project_dir, _paths, _radius = _build_poly_lma_geometry_project(tmp_path) workflow = SAXSPrefitWorkflow(project_dir) diff --git a/tests/test_saxs_ui.py b/tests/test_saxs_ui.py index 4a97c05..c19c535 100644 --- a/tests/test_saxs_ui.py +++ b/tests/test_saxs_ui.py @@ -41,11 +41,25 @@ from scipy import stats import saxshell.saxs.project_manager.project as project_module +import saxshell.saxs.ui.main_window as saxs_ui_main_window_module import saxshell.saxs.ui.prefit_tab as prefit_tab_module from saxshell.clusterdynamicsml.workflow import ( ClusterDynamicsMLTrainingObservation, PredictedClusterCandidate, ) +from saxshell.fullrmc import ( + PackmolDockerLink, + load_packmol_docker_link_metadata, +) +from saxshell.fullrmc.project_model import build_rmcsetup_paths +from saxshell.fullrmc.representatives import ( + DistributionSelectionMetadata, + RepresentativeSelectionEntry, + RepresentativeSelectionIssue, + RepresentativeSelectionMetadata, + RepresentativeSelectionSettings, +) +from saxshell.plotting import 
Q_A_INVERSE_LABEL, load_pickled_plot_figure from saxshell.saxs._model_templates import ( list_template_specs, load_template_module, @@ -64,11 +78,24 @@ ) from saxshell.saxs.contrast.settings import ( COMPONENT_BUILD_MODE_BORN_APPROXIMATION, + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, COMPONENT_BUILD_MODE_CONTRAST, COMPONENT_BUILD_MODE_NO_CONTRAST, ContrastRepresentativeSamplerSettings, ) from saxshell.saxs.contrast.ui.main_window import ContrastModeMainWindow +from saxshell.saxs.contrast_fft import ( + ContrastFFTResult, + ContrastFFTSettings, + ContrastFFTTiming, +) +from saxshell.saxs.contrast_fft.ui.main_window import ( + FFTBornApproximationMainWindow, + _FFTComputationPayload, + _FFTComputationWorker, + _FFTProfileComputationResult, + _FFTProfileTarget, +) from saxshell.saxs.debye.profiles import AveragedComponent, ClusterBin from saxshell.saxs.dream import ( DreamParameterEntry, @@ -77,6 +104,9 @@ SAXSDreamWorkflow, load_dream_settings, ) +from saxshell.saxs.electron_density_mapping.workflow import ( + load_electron_density_structure, +) from saxshell.saxs.model_report import export_dream_model_report_pptx from saxshell.saxs.prefit import ( PrefitEvaluation, @@ -87,6 +117,7 @@ from saxshell.saxs.prefit.workflow import PrefitFitResult from saxshell.saxs.project_manager import ( ClusterImportResult, + DreamBestFitSelection, ExperimentalDataSummary, PowerPointExportSettings, ProjectSettings, @@ -112,6 +143,7 @@ ) from saxshell.saxs.ui.main_window import ( AUTO_SNAP_PANES_KEY, + PACKMOL_DOCKER_PRESETS_KEY, PROJECT_LOAD_TOTAL_STEPS, InstallModelDialog, RuntimeBundleOpener, @@ -131,6 +163,9 @@ POLY_LMA_HS_TEMPLATE = "template_pydream_poly_lma_hs" POLY_LMA_HS_MIX_TEMPLATE = "template_pydream_poly_lma_hs_mix_approx" +SCALED_SOLVENT_MONOSQ_TEMPLATE = ( + "template_pydream_monosq_normalized_scaled_solvent" +) def _table_column_index(table, label: str) -> int: @@ -341,6 +376,22 @@ def _build_minimal_saxs_project( return project_dir, paths +def 
_write_representative_selection_metadata( + project_dir: Path, + metadata: RepresentativeSelectionMetadata, +) -> Path: + rmcsetup_paths = build_rmcsetup_paths(project_dir) + rmcsetup_paths.representative_selection_path.parent.mkdir( + parents=True, + exist_ok=True, + ) + rmcsetup_paths.representative_selection_path.write_text( + json.dumps(metadata.to_dict(), indent=2) + "\n", + encoding="utf-8", + ) + return rmcsetup_paths.representative_selection_path + + def _write_predicted_structure_artifacts( paths, *, @@ -1471,6 +1522,79 @@ def test_prefit_single_solvent_weight_uses_combined_saxs_effective_multiplier( window.close() +def test_prefit_scaled_solvent_monosq_applies_physical_vol_frac_target( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.selected_model_template = SCALED_SOLVENT_MONOSQ_TEMPLATE + manager.save_project(settings) + window = SAXSMainWindow(initial_project_dir=project_dir) + + window.prefit_tab.solute_volume_fraction_collapse_button.click() + widget = window.prefit_tab.solute_volume_fraction_widget + widget.solution_density_spin.setValue(1.0) + widget.solute_stoich_edit.setText("Cs1Pb1I3") + widget.solvent_stoich_edit.setText("H2O") + widget.molar_mass_solute_spin.setValue(620.0) + widget.molar_mass_solvent_spin.setValue(18.015) + widget.mass_solute_spin.setValue(1.0) + widget.mass_solvent_spin.setValue(9.0) + widget.solute_density_spin.setValue(2.0) + widget.solvent_density_spin.setValue(1.0) + + expected_settings = widget.current_estimator_settings() + expected_estimate = calculate_solution_scattering_estimate( + SolutionScatteringEstimatorSettings( + solution=expected_settings.solution, + solute_density_g_per_ml=expected_settings.solute_density_g_per_ml, + solvent_density_g_per_ml=expected_settings.solvent_density_g_per_ml, + calculate_number_density=expected_settings.calculate_number_density, + 
calculate_solute_volume_fraction=( + expected_settings.calculate_solute_volume_fraction + ), + calculate_solvent_scattering_contribution=( + expected_settings.calculate_solvent_scattering_contribution + ), + calculate_sample_fluorescence_yield=( + expected_settings.calculate_sample_fluorescence_yield + ), + beam=expected_settings.beam, + ) + ) + + widget.calculate_button.click() + + vol_frac_row = window.prefit_tab.find_parameter_row("vol_frac") + assert vol_frac_row >= 0 + assert float( + window.prefit_tab.parameter_table.item(vol_frac_row, 3).text() + ) == pytest.approx( + expected_estimate.volume_fraction_estimate.solute_volume_fraction, + rel=1e-3, + ) + assert float( + window.prefit_tab.parameter_table.item(vol_frac_row, 6).text() + ) == pytest.approx(0.5) + solvent_row = window.prefit_tab.find_parameter_row("solv_w") + assert solvent_row >= 0 + assert float( + window.prefit_tab.parameter_table.item(solvent_row, 3).text() + ) == pytest.approx( + expected_estimate.attenuation_estimate.solvent_scattering_scale_factor + * expected_estimate.interaction_contrast_estimate.saxs_effective_solvent_background_ratio, + rel=1e-3, + ) + output_text = widget.output_box.toPlainText() + assert "Applied vol_frac =" in output_text + assert "physical bulk volume fraction" in output_text + assert "Applied solv_w =" in output_text + window.close() + + def test_prefit_solute_volume_fraction_widget_hides_solute_density_in_molarity_mode( qapp, tmp_path, @@ -2740,6 +2864,10 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): assert window.save_project_action.text() == "Save Project" assert window.save_project_action.shortcuts() assert window.save_project_as_action.text() == "Save Project As..." + assert window.link_packmol_docker_action.text() == ( + "Link Packmol Docker Container..." 
+ ) + assert window.link_packmol_docker_action.isEnabled() is True assert all( button.text() != "Save Project State" for button in window.project_setup_tab.project_group.findChildren( @@ -2751,6 +2879,8 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): assert window.md_extraction_menu.title() == "MD Extraction" assert window.structure_analysis_menu.title() == "Structure Analysis" assert window.visualization_menu.title() == "Visualization" + assert window.cli_setup_menu.title() == "CLI Setup" + assert window.beta_menu.title() == "(beta)" assert window.mdtrajectory_action.text() == "Open MD Trajectory Extraction" assert window.xyz2pdb_action.text() == "Open XYZ -> PDB Conversion" assert window.cluster_action.text() == "Open Cluster Extraction" @@ -2759,6 +2889,10 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): window.debye_waller_analysis_action.text() == "Open Debye-Waller Analysis" ) + assert ( + window.project_setup_tab.debye_waller_button.text() + == "Compute Debye-Waller Factors (beta)" + ) assert [action.text() for action in window.tools_menu.actions()] == [ "MD Extraction", "Structure Analysis", @@ -2767,6 +2901,8 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): "Visualization", "SAXS Calculation Preview", "X-ray Toolkit", + "CLI Setup", + "(beta)", ] assert [ action.text() for action in window.md_extraction_menu.actions() @@ -2779,7 +2915,25 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): action.text() for action in window.structure_analysis_menu.actions() ] == [ "Open Bond Analysis", - "Open Debye-Waller Analysis", + "Open Representative Structures", + ] + window._build_menu_bar() + assert [action.text() for action in window.tools_menu.actions()] == [ + "MD Extraction", + "Structure Analysis", + "Cluster Dynamics", + "PDF", + "Visualization", + "SAXS Calculation Preview", + "X-ray Toolkit", + "CLI Setup", + "(beta)", + ] + assert [ + action.text() 
for action in window.structure_analysis_menu.actions() + ] == [ + "Open Bond Analysis", + "Open Representative Structures", ] assert ( window.clusterdynamics_action.text() == "Open Cluster Dynamics (only)" @@ -2787,16 +2941,30 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): assert ( window.clusterdynamicsml_action.text() == "Open Cluster Dynamics (ML)" ) - assert window.fullrmc_action.text() == "Open fullrmc Setup" + assert [ + action.text() for action in window.cluster_dynamics_menu.actions() + ] == ["Open Cluster Dynamics (ML)"] + assert window.fullrmc_action.text() == "Open RMC Setup (fullrmc)" assert window.structure_viewer_action.text() == "Structure Viewer" assert window.blenderxyz_action.text() == "Open Blender XYZ Renderer" + assert ( + window.representative_finder_action.text() + == "Open Representative Structures" + ) assert window.component_calculation_preview_menu.title() == ( "SAXS Calculation Preview" ) - assert window.contrast_mode_action.text() == "Open SAXS Contrast Mode" + assert ( + window.contrast_mode_action.text() + == "Open Debye Scattering (Contrast Mode)" + ) assert ( window.electron_density_mapping_action.text() - == "Open Electron Density Mapping" + == "Open 1D Born Approximation" + ) + assert ( + window.fft_born_approximation_action.text() + == "Open 3D FFT Born Approximation" ) assert window.xray_toolkit_menu.title() == "X-ray Toolkit" assert ( @@ -2805,6 +2973,22 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): assert ( window.number_density_action.text() == "Open Number Density Estimate" ) + assert ( + window.solvent_shell_builder_action.text() + == "Open Solvent Shell Builder (Beta)" + ) + assert ( + window.representative_cli_setup_action.text() + == "Open Representative CLI Setup (Beta)" + ) + assert [action.text() for action in window.cli_setup_menu.actions()] == [ + "Open Representative CLI Setup (Beta)", + ] + assert [action.text() for action in window.beta_menu.actions()] 
== [ + "Open Cluster Dynamics (only)", + "Open Debye-Waller Analysis", + "Open Solvent Shell Builder (Beta)", + ] assert window.settings_menu.title() == "Settings" assert ( @@ -2825,6 +3009,75 @@ def test_main_window_menus_expose_project_tools_and_help(qapp, tmp_path): assert "External Display (1440p / QHD)" in preset_labels +def test_file_menu_can_link_packmol_docker_before_opening_fullrmc( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + + class _FakeSettings: + def __init__(self): + self.values: dict[str, object] = {} + + def value(self, key, default=None): + return self.values.get(key, default) + + def setValue(self, key, value): + self.values[key] = value + + settings_store = _FakeSettings() + linked = PackmolDockerLink( + display_name="Main UI Packmol", + container_name="packmol-main-ui", + container_project_root="/packmol_input_files/project_alpha", + packmol_command="packmol", + shell_command="sh", + packmol_version="Packmol version 20.14.4", + last_verified_at="2026-04-17T13:00:00", + container_id="sha256:mainui", + image_name="packmol:test-image", + packmol_command_path="/usr/local/bin/packmol", + ) + + class _FakeDialog: + def __init__(self, *args, **kwargs) -> None: + del args, kwargs + + def exec(self): + return 1 + + def selected_link(self): + return PackmolDockerLink.from_dict(linked.to_dict()) + + monkeypatch.setattr( + SAXSMainWindow, + "_packmol_docker_settings", + lambda self: settings_store, + ) + monkeypatch.setattr( + saxs_ui_main_window_module, + "PackmolDockerLinkDialog", + _FakeDialog, + ) + + window = SAXSMainWindow(initial_project_dir=project_dir) + window._open_packmol_docker_link_dialog() + + metadata_path = project_dir / "rmcsetup" / "packmol_docker_link.json" + saved = load_packmol_docker_link_metadata(metadata_path) + assert saved is not None + assert saved.container_name == "packmol-main-ui" + assert saved.packmol_version == "Packmol version 20.14.4" + assert 
window.statusBar().currentMessage() == ( + "Linked Packmol Docker container packmol-main-ui" + ) + raw_presets = settings_store.value(PACKMOL_DOCKER_PRESETS_KEY, "[]") + preset_payload = json.loads(raw_presets) + assert preset_payload[0]["container_name"] == "packmol-main-ui" + + def test_project_setup_shows_prep_help_tooltips(qapp): del qapp window = SAXSMainWindow() @@ -4712,7 +4965,7 @@ def fake_launch_electron_density_mapping_ui(**kwargs): window.close() -def test_open_contrast_mode_tool_does_not_auto_start_build( +def test_3d_fft_born_tool_uses_active_structure_folder( qapp, tmp_path, monkeypatch, @@ -4720,103 +4973,605 @@ def test_open_contrast_mode_tool_does_not_auto_start_build( del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) window = SAXSMainWindow(initial_project_dir=project_dir) - launched: dict[str, object] = {"start_calls": 0} + pdb_frames_dir = tmp_path / "pdb_frames" + pdb_frames_dir.mkdir() + xyz_frames_dir = tmp_path / "xyz_frames" + xyz_frames_dir.mkdir() + window.current_settings.pdb_frames_dir = str(pdb_frames_dir.resolve()) + window.current_settings.frames_dir = str(xyz_frames_dir.resolve()) + window.current_settings.use_representative_structures = True + window.project_manager.save_project(window.current_settings) + launched: dict[str, object] = {} - class FakeContrastModeWindow(QWidget): + class FakeFFTBornWindow(QWidget): def __init__(self): super().__init__() launched["instance"] = self - def start_contrast_component_build(self): - launched["start_calls"] += 1 - - def raise_(self): - launched["raised"] = True - - def activateWindow(self): - launched["activated"] = True - - def fake_launch_contrast_mode_ui(**kwargs): + def fake_launch_3d_fft_born_approximation_ui(**kwargs): launched.update(kwargs) - return FakeContrastModeWindow() + return FakeFFTBornWindow() monkeypatch.setattr( - window, - "_confirm_default_q_range_for_component_build", - lambda: True, - ) - monkeypatch.setattr( - 
"saxshell.saxs.contrast.ui.main_window.launch_contrast_mode_ui", - fake_launch_contrast_mode_ui, + "saxshell.saxs.contrast_fft.ui.main_window.launch_3d_fft_born_approximation_ui", + fake_launch_3d_fft_born_approximation_ui, ) - window.project_setup_tab.set_component_build_mode( - COMPONENT_BUILD_MODE_CONTRAST - ) - window.build_project_components() + window._open_3d_fft_born_approximation_tool() + assert launched["initial_project_dir"] == Path(project_dir).resolve() + assert launched["initial_input_path"] == pdb_frames_dir.resolve() + assert launched["initial_use_representative_structures"] is True + assert launched["preview_mode"] is True assert launched["instance"] in window._child_tool_windows - assert launched["start_calls"] == 0 window.close() -def test_contrast_mode_window_can_close_while_workflow_thread_is_running( +def test_fft_project_representative_targets_preserve_fullsolv_source_mode( qapp, + tmp_path, monkeypatch, ): del qapp - window = ContrastModeMainWindow() + structure_path = tmp_path / "representative_fullsolv.xyz" + structure_path.write_text( + "\n".join( + [ + "7", + "full solvent representative", + "Pb 0.0 0.0 0.0", + "O 3.0 0.0 0.0", + "H 3.6 0.7 0.0", + "H 2.4 0.7 0.0", + "O 6.0 0.0 0.0", + "H 6.6 0.7 0.0", + "H 5.4 0.7 0.0", + ] + ) + + "\n", + encoding="utf-8", + ) + selection = DreamBestFitSelection( + run_name="Representative Structure Finder", + run_relative_path="rmcsetup/representative_structures", + label="Representative Structure Finder", + selection_source="representativefinder", + selected_at="2026-05-06T12:00:00", + ) + metadata = RepresentativeSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + distribution_selection=DistributionSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + run_dir="rmcsetup/representative_structures", + updated_at="2026-05-06T12:00:00", + entries=[], + ), + settings=RepresentativeSelectionSettings(), + updated_at="2026-05-06T12:00:00", + 
representative_entries=[ + RepresentativeSelectionEntry( + structure="Pb", + motif="no_motif", + param="Pb", + selected_weight=1.0, + cluster_count=1, + source_dir=str(structure_path.parent), + source_file=str(structure_path), + source_file_name=structure_path.name, + atom_count=7, + element_counts={"Pb": 1, "O": 2, "H": 4}, + source_solvent_mode="fullsolv", + ) + ], + missing_bins=[], + invalid_bins=[], + ) + project_source = SimpleNamespace( + representative_selection=metadata, + solvent_handling=None, + ) + window = FFTBornApproximationMainWindow( + preview_mode=True, + initial_project_dir=tmp_path, + ) + monkeypatch.setattr(window, "_project_source", lambda: project_source) - class FakeRunningThread: - def isRunning(self) -> bool: - return True + targets = window._representative_targets_from_project_source() - warnings: list[tuple[object, ...]] = [] - monkeypatch.setattr( - "saxshell.saxs.contrast.ui.main_window.QMessageBox.warning", - lambda *args, **kwargs: warnings.append(args), - ) - window._workflow_thread = FakeRunningThread() + assert set(targets) == {("representative", "full")} + assert len(targets[("representative", "full")]) == 1 + target = targets[("representative", "full")][0] + assert target.solvent_mode == "full" + assert target.reference_file == structure_path.resolve() + window.close() - assert window.close() is True - assert warnings == [] +def test_fft_average_cluster_targets_use_every_structure_in_cluster_folder( + qapp, + tmp_path, +): + del qapp + clusters_dir = tmp_path / "clusters" + structure_dir = clusters_dir / "PbI" + structure_dir.mkdir(parents=True) + frame_paths = [] + for index in range(2): + frame_path = structure_dir / f"frame_{index + 1:04d}.xyz" + frame_path.write_text( + "\n".join( + [ + "2", + frame_path.stem, + "Pb 0.0 0.0 0.0", + f"I {2.8 + index:.1f} 0.0 0.0", + ] + ) + + "\n", + encoding="utf-8", + ) + frame_paths.append(frame_path.resolve()) + + window = FFTBornApproximationMainWindow(preview_mode=True) + 
window._load_input_path(clusters_dir) -def test_main_window_loads_contrast_distribution_after_tool_build_signal( + target = window._active_profile_target() + assert target is not None + assert target.source_mode == "average" + assert target.file_count == 2 + assert target.source_files == tuple(frame_paths) + window.close() + + +def test_fft_representative_root_loads_singular_metadata_sources( qapp, tmp_path, monkeypatch, ): del qapp - project_dir, _settings, artifact_paths, _build_result = ( - _build_saved_contrast_distribution_project(tmp_path) + representative_root = tmp_path / "rmcsetup" / "representative_structures" + source_dir = representative_root / "partialsolv" / "PbI" + source_dir.mkdir(parents=True) + representative_path = source_dir / "selected_rep.xyz" + stale_path = source_dir / "stale_extra.xyz" + representative_text = ( + "\n".join( + [ + "2", + "selected representative", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + ] + ) + + "\n" + ) + representative_path.write_text(representative_text, encoding="utf-8") + stale_path.write_text(representative_text, encoding="utf-8") + selection = DreamBestFitSelection( + run_name="Representative Structure Finder", + run_relative_path="rmcsetup/representative_structures", + label="Representative Structure Finder", + selection_source="representativefinder", + selected_at="2026-05-06T12:00:00", + ) + metadata = RepresentativeSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + distribution_selection=DistributionSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + run_dir="rmcsetup/representative_structures", + updated_at="2026-05-06T12:00:00", + entries=[], + ), + settings=RepresentativeSelectionSettings(), + updated_at="2026-05-06T12:00:00", + representative_entries=[ + RepresentativeSelectionEntry( + structure="PbI", + motif="no_motif", + param="PbI", + selected_weight=1.0, + cluster_count=5, + source_dir=str(source_dir), + 
source_file=str(representative_path), + source_file_name=representative_path.name, + atom_count=2, + element_counts={"Pb": 1, "I": 1}, + source_solvent_mode="partialsolv", + ) + ], + missing_bins=[], + invalid_bins=[], ) - window = SAXSMainWindow(initial_project_dir=project_dir) - loaded_distribution_ids: list[str] = [] - + project_source = SimpleNamespace( + representative_selection=metadata, + solvent_handling=None, + ) + window = FFTBornApproximationMainWindow( + preview_mode=True, + initial_project_dir=tmp_path, + initial_use_representative_structures=True, + ) + monkeypatch.setattr(window, "_project_source", lambda: project_source) + errors: list[tuple[str, str]] = [] monkeypatch.setattr( window, - "_load_saved_distribution", - lambda distribution_id: loaded_distribution_ids.append( - distribution_id - ), + "_show_error", + lambda title, message: errors.append((title, message)), ) - window._on_contrast_components_built( - { - "project_dir": str(project_dir.resolve()), - "distribution_id": artifact_paths.distribution_id, - "distribution_dir": str(artifact_paths.root_dir), - "component_dir": str(artifact_paths.component_dir), - "component_map_path": str(artifact_paths.component_map_file), - } - ) + window._load_input_path(representative_root) - assert loaded_distribution_ids == [artifact_paths.distribution_id] + assert errors == [] + assert window.structure_source_combo.currentData() == "representative" + target = window._active_profile_target() + assert target is not None + assert target.file_count == 1 + assert target.source_files == (representative_path.resolve(),) + assert stale_path.resolve() not in target.source_files + window.close() + + +def test_solvent_shell_builder_tool_uses_active_structure_folder( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + pdb_frames_dir = tmp_path / "pdb_frames" + pdb_frames_dir.mkdir() + xyz_frames_dir 
= tmp_path / "xyz_frames" + xyz_frames_dir.mkdir() + window.current_settings.pdb_frames_dir = str(pdb_frames_dir.resolve()) + window.current_settings.frames_dir = str(xyz_frames_dir.resolve()) + window.project_manager.save_project(window.current_settings) + launched: dict[str, object] = {} + + class FakeSolventShellBuilderWindow(QWidget): + def __init__(self): + super().__init__() + launched["instance"] = self + + def fake_launch_solvent_shell_builder_ui(**kwargs): + launched.update(kwargs) + return FakeSolventShellBuilderWindow() + + monkeypatch.setattr( + "saxshell.fullrmc.ui.solvent_shell_builder_window.launch_solvent_shell_builder_ui", + fake_launch_solvent_shell_builder_ui, + ) + + window._open_solvent_shell_builder_tool() + + assert launched["initial_project_dir"] == Path(project_dir).resolve() + assert launched["initial_input_path"] == pdb_frames_dir.resolve() + assert launched["instance"] in window._child_tool_windows + window.close() + + +def test_open_contrast_mode_tool_does_not_auto_start_build( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + launched: dict[str, object] = {"start_calls": 0} + + class FakeContrastModeWindow(QWidget): + def __init__(self): + super().__init__() + launched["instance"] = self + + def start_contrast_component_build(self): + launched["start_calls"] += 1 + + def raise_(self): + launched["raised"] = True + + def activateWindow(self): + launched["activated"] = True + + def fake_launch_contrast_mode_ui(**kwargs): + launched.update(kwargs) + return FakeContrastModeWindow() + + monkeypatch.setattr( + window, + "_confirm_default_q_range_for_component_build", + lambda: True, + ) + monkeypatch.setattr( + "saxshell.saxs.contrast.ui.main_window.launch_contrast_mode_ui", + fake_launch_contrast_mode_ui, + ) + + window.project_setup_tab.set_component_build_mode( + COMPONENT_BUILD_MODE_CONTRAST + ) + 
window.build_project_components() + + assert launched["instance"] in window._child_tool_windows + assert launched["start_calls"] == 0 + window.close() + + +def test_contrast_mode_window_can_close_while_workflow_thread_is_running( + qapp, + monkeypatch, +): + del qapp + window = ContrastModeMainWindow() + + class FakeRunningThread: + def isRunning(self) -> bool: + return True + + warnings: list[tuple[object, ...]] = [] + monkeypatch.setattr( + "saxshell.saxs.contrast.ui.main_window.QMessageBox.warning", + lambda *args, **kwargs: warnings.append(args), + ) + window._workflow_thread = FakeRunningThread() + + assert window.close() is True + assert warnings == [] + + +def test_main_window_loads_contrast_distribution_after_tool_build_signal( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _settings, artifact_paths, _build_result = ( + _build_saved_contrast_distribution_project(tmp_path) + ) + window = SAXSMainWindow(initial_project_dir=project_dir) + loaded_distribution_ids: list[str] = [] + + monkeypatch.setattr( + window, + "_load_saved_distribution", + lambda distribution_id: loaded_distribution_ids.append( + distribution_id + ), + ) + + window._on_contrast_components_built( + { + "project_dir": str(project_dir.resolve()), + "distribution_id": artifact_paths.distribution_id, + "distribution_dir": str(artifact_paths.root_dir), + "component_dir": str(artifact_paths.component_dir), + "component_map_path": str(artifact_paths.component_map_file), + } + ) + + assert loaded_distribution_ids == [artifact_paths.distribution_id] assert "built and loaded" in window.statusBar().currentMessage().lower() window.close() +def test_born_push_bootstraps_distribution_so_prefit_and_dream_can_load( + qapp, + tmp_path, +): + del qapp + from saxshell.saxs.electron_density_mapping.ui.main_window import ( + ElectronDensityMappingMainWindow, + _ClusterDensityGroupState, + ) + from saxshell.saxs.electron_density_mapping.workflow import ( + 
ElectronDensityFourierTransformPreview, + ElectronDensityFourierTransformSettings, + ElectronDensityInputInspection, + ElectronDensityScatteringTransformResult, + ElectronDensityStructure, + ) + + manager = SAXSProjectManager() + project_dir = tmp_path / "born_push_project" + settings = manager.create_project(project_dir) + paths = build_project_paths(project_dir) + + q_values = np.linspace(0.05, 0.3, 8) + component = np.linspace(10.0, 17.0, 8) + template_name = "template_pd_likelihood_monosq_decoupled" + template_module = load_template_module(template_name) + experimental = template_module.lmfit_model_profile( + q_values, + np.zeros_like(q_values), + [component], + w0=0.6, + solv_w=0.0, + offset=0.05, + eff_r=9.0, + vol_frac=0.0, + scale=5e-4, + ) + + experimental_path = paths.experimental_data_dir / "exp_demo.txt" + np.savetxt(experimental_path, np.column_stack([q_values, experimental])) + + cluster_dir = paths.project_dir / "clusters" / "A" + cluster_dir.mkdir(parents=True, exist_ok=True) + frame_path = cluster_dir / "frame_0001.xyz" + frame_path.write_text( + "\n".join( + [ + "3", + "frame_0001", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + ] + ) + + "\n", + encoding="utf-8", + ) + + settings.clusters_dir = str(paths.project_dir / "clusters") + settings.experimental_data_path = str(experimental_path) + settings.copied_experimental_data_file = str(experimental_path) + settings.selected_model_template = template_name + settings.component_build_mode = COMPONENT_BUILD_MODE_BORN_APPROXIMATION + settings.use_experimental_grid = False + settings.q_min = float(q_values.min()) + settings.q_max = float(q_values.max()) + settings.q_points = int(q_values.size) + manager.save_project(settings) + + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + assert not artifact_paths.distribution_metadata_file.is_file() + assert not artifact_paths.prior_weights_file.is_file() + + window = 
ElectronDensityMappingMainWindow( + initial_project_dir=project_dir.resolve(), + initial_distribution_id=artifact_paths.distribution_id, + initial_distribution_root_dir=artifact_paths.root_dir, + preview_mode=False, + ) + + inspection = ElectronDensityInputInspection( + selection_path=cluster_dir, + input_mode="folder", + structure_files=(frame_path,), + reference_file=frame_path, + format_counts={"xyz": 1}, + ) + reference_structure = ElectronDensityStructure( + file_path=frame_path, + display_label="frame_0001.xyz", + structure_comment="frame_0001", + coordinates=np.asarray( + [ + [0.0, 0.0, 0.0], + [2.8, 0.0, 0.0], + [0.0, 2.8, 0.0], + ], + dtype=float, + ), + centered_coordinates=np.asarray( + [ + [0.0, 0.0, 0.0], + [2.8, 0.0, 0.0], + [0.0, 2.8, 0.0], + ], + dtype=float, + ), + elements=("Pb", "I", "I"), + element_counts={"Pb": 1, "I": 2}, + atomic_numbers=np.asarray([82.0, 53.0, 53.0], dtype=float), + atomic_masses=np.asarray([207.2, 126.9, 126.9], dtype=float), + center_of_mass=np.asarray([0.0, 0.0, 0.0], dtype=float), + geometric_center=np.asarray([0.0, 0.0, 0.0], dtype=float), + reference_element="Pb", + reference_element_geometric_center=np.asarray( + [0.0, 0.0, 0.0], dtype=float + ), + reference_element_offset_from_geometric_center=0.0, + active_center=np.asarray([0.0, 0.0, 0.0], dtype=float), + center_mode="center_of_mass", + nearest_atom_index=0, + nearest_atom_distance=0.0, + bonds=((0, 1), (0, 2)), + rmax=2.8, + ) + fourier_settings = ElectronDensityFourierTransformSettings( + r_min=0.0, + r_max=5.0, + domain_mode="mirrored", + window_function="hanning", + resampling_points=2048, + q_min=float(q_values.min()), + q_max=float(q_values.max()), + q_step=float(q_values[1] - q_values[0]), + ) + preview = ElectronDensityFourierTransformPreview( + settings=fourier_settings, + source_profile_label="A", + source_radial_values=np.asarray([0.0, 1.0], dtype=float), + source_density_values=np.asarray([1.0, 0.5], dtype=float), + resampled_r_values=np.asarray([0.0, 
1.0], dtype=float), + resampled_density_values=np.asarray([1.0, 0.5], dtype=float), + window_values=np.asarray([1.0, 1.0], dtype=float), + windowed_density_values=np.asarray([1.0, 0.5], dtype=float), + available_r_min=0.0, + available_r_max=1.0, + resampling_step_a=0.5, + nyquist_q_max_a_inverse=1.0, + independent_q_step_a_inverse=float(q_values[1] - q_values[0]), + q_grid_is_oversampled=False, + q_max_was_clamped=False, + notes=(), + source_mode="density_fourier", + ) + transform_result = ElectronDensityScatteringTransformResult( + preview=preview, + q_values=np.asarray(q_values, dtype=float), + scattering_amplitude=np.sqrt(component), + intensity=np.asarray(component, dtype=float), + ) + window._cluster_group_states = [ + _ClusterDensityGroupState( + key="A", + display_name="A", + structure_name="A", + motif_name="no_motif", + source_dir=cluster_dir, + inspection=inspection, + reference_structure=reference_structure, + average_atom_count=3.0, + single_atom_only=False, + trace_color="#2563eb", + transform_result=transform_result, + ) + ] + window._selected_cluster_group_key = "A" + + window._push_components_to_model() + + assert artifact_paths.distribution_metadata_file.is_file() + assert artifact_paths.prior_weights_file.is_file() + assert artifact_paths.component_map_file.is_file() + assert any(artifact_paths.component_dir.glob("*.txt")) + + saved_record = manager.load_saved_distribution( + project_dir, + artifact_paths.distribution_id, + ) + assert saved_record.component_artifacts_ready + assert saved_record.prior_artifacts_ready + + prefit_workflow = SAXSPrefitWorkflow(project_dir) + assert [ + (component.structure, component.motif) + for component in prefit_workflow.components + ] == [("A", "no_motif")] + assert np.allclose(prefit_workflow.evaluate().q_values, q_values) + + dream_workflow = SAXSDreamWorkflow(project_dir) + assert [ + (component.structure, component.motif) + for component in dream_workflow.prefit_workflow.components + ] == [("A", "no_motif")] 
+ assert np.allclose( + dream_workflow.prefit_workflow.evaluate().q_values, + q_values, + ) + window.close() + + def test_contrast_mode_workspace_tracks_manual_representative_selection( qapp, tmp_path, monkeypatch ): @@ -5612,7 +6367,7 @@ def test_contrast_mode_workspace_reloads_saved_distribution_artifacts( experimental_axis.get_title() == "Experimental Data and Contrast Traces" ) - assert experimental_axis.get_xlabel() == "q (Å⁻¹)" + assert experimental_axis.get_xlabel() == Q_A_INVERSE_LABEL assert ( experimental_axis.get_ylabel() == "Experimental Intensity (arb. units)" ) @@ -7654,7 +8409,18 @@ def test_project_setup_component_build_mode_defaults_and_round_trips(qapp): ) assert ( tab.component_build_mode_combo.currentText() - == "Born Approximation (Average)" + == "1D Born Approximation (Average)" + ) + tab.set_component_build_mode( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + assert ( + tab.component_build_mode() + == COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + assert ( + tab.component_build_mode_combo.currentText() + == "3D FFT Born Approximation" ) tab.set_project_settings(settings, []) @@ -8804,7 +9570,7 @@ def test_dream_model_plot_includes_residual_subplot(qapp, tmp_path): residual_axis = window.dream_tab.model_figure.axes[1] assert top_axis.get_title().startswith("DREAM refinement:") assert residual_axis.get_ylabel() == "Residual" - assert residual_axis.get_xlabel() == "q (Å⁻¹)" + assert residual_axis.get_xlabel() == Q_A_INVERSE_LABEL assert residual_axis.get_xscale() == top_axis.get_xscale() residual_line = residual_axis.get_lines()[-1] @@ -10670,18 +11436,69 @@ def test_prefit_recommended_scale_button_updates_scale_bounds(qapp, tmp_path): ) -def test_prefit_tab_reorders_controls_and_parameter_actions(qapp): +def test_scaled_solvent_monosq_auto_autoscales_on_project_load( + qapp, + tmp_path, +): del qapp - tab = PrefitTab() - - assert tab._parameter_action_layout.itemAt(0).widget() is ( - tab.recommended_scale_button - ) - assert 
tab._parameter_action_layout.itemAt(1).widget() is ( - tab.update_button + project_dir, paths = _build_minimal_saxs_project(tmp_path) + q_values = np.linspace(0.05, 0.3, 8) + component = np.linspace(10.0, 17.0, 8) + experimental_scale = 2e-7 + experimental_offset = 0.01 + template_module = load_template_module(SCALED_SOLVENT_MONOSQ_TEMPLATE) + experimental = template_module.lmfit_model_profile( + q_values, + np.zeros_like(q_values), + [component], + w0=0.6, + solv_w=0.0, + offset=experimental_offset, + eff_r=3.0, + vol_frac=0.0, + scale=experimental_scale, ) - assert tab._parameter_action_layout.itemAt(2).widget() is ( - tab.auto_update_checkbox + experimental_path = paths.experimental_data_dir / "adaptive_exp_demo.txt" + np.savetxt(experimental_path, np.column_stack([q_values, experimental])) + + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.selected_model_template = SCALED_SOLVENT_MONOSQ_TEMPLATE + settings.experimental_data_path = str(experimental_path) + settings.copied_experimental_data_file = str(experimental_path) + manager.save_project(settings) + + window = SAXSMainWindow(initial_project_dir=project_dir) + entries = { + entry.name: entry for entry in window.prefit_tab.parameter_entries() + } + + assert entries["eff_r"].value == pytest.approx(3.0) + assert entries["scale"].value == pytest.approx(experimental_scale) + assert entries["scale"].minimum == pytest.approx(experimental_scale / 10.0) + assert entries["scale"].maximum == pytest.approx(experimental_scale * 10.0) + assert entries["scale"].vary + assert entries["offset"].value == pytest.approx(experimental_offset) + assert entries["offset"].minimum == pytest.approx(0.009) + assert entries["offset"].maximum == pytest.approx(0.011) + assert "Applied initial autoscale settings" in ( + window.prefit_tab.output_box.toPlainText() + ) + window.close() + + +def test_prefit_tab_reorders_controls_and_parameter_actions(qapp): + del qapp + tab = PrefitTab() + + assert 
tab._parameter_action_layout.itemAt(0).widget() is ( + tab.recommended_scale_button + ) + assert tab._parameter_action_layout.itemAt(1).widget() is ( + tab.update_button + ) + assert tab._parameter_action_layout.itemAt(2).widget() is ( + tab.auto_update_checkbox ) assert tab._parameter_action_layout.itemAt(3).widget() is ( tab.scrollable_parameter_checkbox @@ -11879,6 +12696,285 @@ def test_solvent_sort_prior_histogram_does_not_auto_match_component_colors( assert prior_window.structure_motif_colors is None +def test_project_setup_prior_histogram_plot_editor_applies_settings( + qapp, + tmp_path, +): + del qapp + json_path = tmp_path / "md_prior_weights.json" + json_path.write_text( + json.dumps( + { + "origin": "clusters", + "total_files": 5, + "structures": { + "PbI2": { + "motif_A": { + "count": 2, + "weight": 0.4, + "profile_file": "PbI2_motif_A.txt", + } + }, + "Pb2I4": { + "motif_B": { + "count": 3, + "weight": 0.6, + "profile_file": "Pb2I4_motif_B.txt", + } + }, + }, + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + + tab = ProjectSetupTab() + tab.set_project_selected(True) + tab.show() + tab.draw_prior_plot(json_path) + QApplication.processEvents() + + assert tab.open_prior_plot_editor_button.isEnabled() + + tab.open_prior_plot_editor_button.click() + QApplication.processEvents() + + assert tab._prior_plot_editor_window is not None + assert tab._prior_plot_editor_controls is not None + + controls = tab._prior_plot_editor_controls + controls.title_edit.setText("Edited Prior Histogram") + controls.x_label_edit.setText("Cluster Bin") + controls.y_label_edit.setText("Population (%)") + controls.legend_title_edit.setText("Segment Class") + controls.colormap_combo.setCurrentIndex( + controls.colormap_combo.findData("viridis") + ) + controls.show_total_annotations_checkbox.setChecked(False) + controls.legend_location_combo.setCurrentIndex( + controls.legend_location_combo.findData("upper_left") + ) + controls.x_tick_rotation_spin.setValue(12) + 
controls.label_table.setCurrentCell(1, 1) + controls._move_label_up() + controls.label_table.item(0, 1).setText(r"\f01Pb$_{2}$I$_{4}$") + QApplication.processEvents() + + axis = tab.prior_figure.axes[0] + tick_texts = [ + tick.get_text() + for tick in axis.get_xticklabels() + if tick.get_text().strip() + ] + legend = axis.get_legend() + assert axis.get_title() == "Edited Prior Histogram" + assert axis.get_xlabel() == "Cluster Bin" + assert axis.get_ylabel() == "Population (%)" + assert legend is not None + assert legend.get_title().get_text() == "Segment Class" + assert legend._loc == 2 + assert tab.prior_color_combo.currentText() == "viridis" + assert tab.prior_x_axis_order_combo.currentData() == "custom" + assert tab.prior_histogram_x_axis_order()[0][0] == "Pb2I4" + assert r"\mathbf{Pb}" in tick_texts[0] + assert not any(text.get_text().endswith("%") for text in axis.texts) + assert axis.patches + assert to_hex(axis.patches[0].get_facecolor()) == to_hex( + colormaps["viridis"](0.1) + ) + + tab._prior_plot_editor_window.close() + tab.close() + + +def test_project_setup_prior_histogram_restored_axes_keep_auto_labels( + qapp, + tmp_path, +): + json_path = tmp_path / "md_prior_weights.json" + json_path.write_text( + json.dumps( + { + "origin": "clusters", + "total_files": 5, + "structures": { + "PbI2": { + "motif_A": { + "count": 2, + "weight": 0.4, + "profile_file": "PbI2_motif_A.txt", + } + }, + "Pb2I4": { + "motif_B": { + "count": 3, + "weight": 0.6, + "profile_file": "Pb2I4_motif_B.txt", + } + }, + }, + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + + tab = ProjectSetupTab() + tab.set_project_selected(True) + tab.draw_prior_plot(json_path) + qapp.processEvents() + + tab.set_prior_plot_state( + { + "axes": [ + { + "xscale": "linear", + "yscale": "linear", + "xlim": [-0.25, 1.25], + "ylim": [0.0, 120.0], + } + ] + } + ) + qapp.processEvents() + + axis = tab.prior_figure.axes[0] + tick_texts = [ + tick.get_text() + for tick in axis.get_xticklabels() + if 
tick.get_text().strip() + ] + + assert tick_texts == ["PbI$_{2}$", "Pb$_{2}$I$_{4}$"] + assert axis.get_xlim() == pytest.approx((-0.25, 1.25)) + assert axis.get_ylim() == pytest.approx((0.0, 120.0)) + tab.close() + + +def test_project_setup_prior_histogram_plot_editor_can_save_and_load_pickled_state( + qapp, + tmp_path, + monkeypatch, +): + del qapp + json_path = tmp_path / "md_prior_weights.json" + json_path.write_text( + json.dumps( + { + "origin": "clusters", + "total_files": 5, + "structures": { + "A": { + "motif_A": { + "count": 3, + "weight": 0.6, + "profile_file": "A_motif_A.txt", + } + }, + "A2": { + "motif_B": { + "count": 2, + "weight": 0.4, + "profile_file": "A2_motif_B.txt", + } + }, + }, + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + + tab = ProjectSetupTab() + tab.set_project_selected(True) + tab.show() + tab.draw_prior_plot(json_path) + QApplication.processEvents() + tab.open_prior_plot_editor_button.click() + QApplication.processEvents() + + assert tab._prior_plot_editor_window is not None + assert tab._prior_plot_editor_controls is not None + + editor = tab._prior_plot_editor_window + controls = tab._prior_plot_editor_controls + pickle_path = tmp_path / "saved_prior_histogram.pkl" + + monkeypatch.setattr( + "saxshell.plotting.plot_editor.QFileDialog.getSaveFileName", + lambda *args, **kwargs: ( + str(pickle_path), + "Pickled Plot Files (*.pkl)", + ), + ) + monkeypatch.setattr( + "saxshell.plotting.plot_editor.QFileDialog.getOpenFileName", + lambda *args, **kwargs: ( + str(pickle_path), + "Pickled Plot Files (*.pkl)", + ), + ) + + controls.title_edit.setText("Pickled Prior Histogram") + controls.legend_title_edit.setText("Saved Legend") + controls.colormap_combo.setCurrentIndex( + controls.colormap_combo.findData("magma") + ) + controls.label_table.item(0, 1).setText("Alpha") + QApplication.processEvents() + + assert editor.save_pickled_plot_as() == pickle_path + assert pickle_path.is_file() + + pickled_figure = 
load_pickled_plot_figure(pickle_path) + assert pickled_figure.axes[0].get_title() == "Pickled Prior Histogram" + assert pickled_figure.axes[0].get_legend() is not None + assert ( + pickled_figure.axes[0].get_legend().get_title().get_text() + == "Saved Legend" + ) + + controls.title_edit.setText("Live Prior Histogram") + controls.legend_title_edit.setText("Live Legend") + controls.colormap_combo.setCurrentIndex( + controls.colormap_combo.findData("summer") + ) + controls.label_table.item(0, 1).setText("Live Alpha") + QApplication.processEvents() + assert tab.prior_figure.axes[0].get_title() == "Live Prior Histogram" + assert tab.prior_color_combo.currentText() == "summer" + + assert editor.load_pickled_plot_as() == pickle_path + QApplication.processEvents() + + assert not editor.is_showing_pickled_plot() + assert controls.title_edit.text() == "Pickled Prior Histogram" + assert controls.legend_title_edit.text() == "Saved Legend" + assert controls.colormap_combo.currentData() == "magma" + assert tab.prior_color_combo.currentText() == "magma" + assert tab.prior_x_axis_order_combo.currentData() == "custom" + assert ( + tab.prior_figure.axes[0].get_xticklabels()[0].get_text().strip() + == "Alpha" + ) + assert ( + tab.prior_figure.axes[0].get_legend().get_title().get_text() + == "Saved Legend" + ) + + controls.title_edit.setText("Editable After Load") + QApplication.processEvents() + assert tab.prior_figure.axes[0].get_title() == "Editable After Load" + assert editor.figure.axes[0].get_title() == "Editable After Load" + + tab._prior_plot_editor_window.close() + tab.close() + + def test_create_project_warns_before_overwriting_existing_folder( qapp, tmp_path, monkeypatch ): @@ -12222,7 +13318,7 @@ def test_project_setup_preview_updates_with_experimental_q_range( assert tab.component_log_x_checkbox.isChecked() assert tab.component_log_y_checkbox.isChecked() assert preview_axis.get_title() == "Experimental Data Preview" - assert preview_axis.get_xlabel() == "q (Å⁻¹)" + assert 
preview_axis.get_xlabel() == Q_A_INVERSE_LABEL assert preview_axis.get_ylabel() == "Intensity (arb. units)" assert preview_axis.get_xscale() == "log" assert preview_axis.get_yscale() == "log" @@ -12505,6 +13601,9 @@ def run_task_sync( QApplication.processEvents() assert window.project_setup_tab.computed_distribution_combo.count() == 1 + assert "DREAM fits: 0" in ( + window.project_setup_tab.computed_distribution_combo.itemText(0) + ) assert ( window.project_setup_tab.selected_distribution_id() == expected_distribution_id @@ -12561,75 +13660,1873 @@ def test_build_components_in_contrast_mode_launches_scaffold_instead_of_builder_ == COMPONENT_BUILD_MODE_CONTRAST ) assert ( - "Contrast (Debye)" - in window.project_setup_tab.summary_box.toPlainText() + "Contrast (Debye)" + in window.project_setup_tab.summary_box.toPlainText() + ) + window.close() + + +def test_build_components_in_born_approximation_launches_electron_density_workflow( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + launched: list[dict[str, object]] = [] + start_calls: list[str] = [] + clusters_dir = tmp_path / "clusters" + (clusters_dir / "PbI2").mkdir(parents=True) + (clusters_dir / "PbI2" / "frame_0001.xyz").write_text( + "\n".join( + [ + "3", + "frame_0001", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + ] + ) + + "\n", + encoding="utf-8", + ) + window.project_setup_tab.clusters_dir_edit.setText( + str(clusters_dir.resolve()) + ) + window.project_setup_tab.set_component_build_mode( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION + ) + monkeypatch.setattr( + window, + "_confirm_default_q_range_for_component_build", + lambda: True, + ) + monkeypatch.setattr( + window, + "_start_project_task", + lambda task_name, *args, **kwargs: start_calls.append(task_name), + ) + monkeypatch.setattr( + window, + "_open_electron_density_mapping_tool", + lambda **kwargs: 
launched.append(kwargs), + ) + + window.build_project_components() + + assert not start_calls + assert launched == [{"preview_mode": False}] + assert window.current_settings is not None + assert ( + window.current_settings.component_build_mode + == COMPONENT_BUILD_MODE_BORN_APPROXIMATION + ) + assert ( + "1D Born Approximation (Average)" + in window.project_setup_tab.summary_box.toPlainText() + ) + window.close() + + +def test_build_project_components_opens_3d_fft_born_window( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + window = SAXSMainWindow(initial_project_dir=project_dir) + start_calls: list[str] = [] + launched: list[dict[str, object]] = [] + clusters_dir = project_dir / "clusters" + clusters_dir.mkdir() + stoich_dir = clusters_dir / "PbI2" + stoich_dir.mkdir() + (stoich_dir / "frame_0001.xyz").write_text( + "\n".join( + [ + "3", + "frame_0001", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + ] + ) + + "\n", + encoding="utf-8", + ) + window.project_setup_tab.project_dir_edit.setText( + str(project_dir.resolve()) + ) + window.project_setup_tab.clusters_dir_edit.setText( + str(clusters_dir.resolve()) + ) + window.project_setup_tab.set_component_build_mode( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + monkeypatch.setattr( + window, + "_confirm_default_q_range_for_component_build", + lambda: True, + ) + monkeypatch.setattr( + window, + "_start_project_task", + lambda task_name, *args, **kwargs: start_calls.append(task_name), + ) + monkeypatch.setattr( + window, + "_open_3d_fft_born_approximation_tool", + lambda **kwargs: launched.append(kwargs), + ) + + window.build_project_components() + + assert not start_calls + assert launched == [{"preview_mode": False}] + assert window.current_settings is not None + assert ( + window.current_settings.component_build_mode + == COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + assert ( + "3D FFT Born Approximation" + in 
window.project_setup_tab.summary_box.toPlainText() + ) + window.close() + + +def test_create_fft_distribution_resets_pushed_component_indicator( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + clusters_dir = project_dir / "clusters" + structure_dir = clusters_dir / "A" + structure_dir.mkdir(parents=True, exist_ok=True) + (structure_dir / "frame_0001.xyz").write_text( + "1\nframe 1\nA 0.0 0.0 0.0\n", + encoding="utf-8", + ) + settings.clusters_dir = str(clusters_dir) + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + manager.save_project(settings) + + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + manager.ensure_artifact_dirs(artifact_paths) + shutil.copytree( + paths.scattering_components_dir, + artifact_paths.component_dir, + dirs_exist_ok=True, + ) + shutil.copy2( + paths.project_dir / "md_saxs_map.json", + artifact_paths.component_map_file, + ) + manager._write_distribution_metadata( + settings, + artifact_paths=artifact_paths, + built_component_source_mode="average", + ) + + window = SAXSMainWindow(initial_project_dir=project_dir) + + assert "#16a34a" in ( + window.project_setup_tab.components_built_indicator.styleSheet() + ) + + def run_task_sync( + task_name, + task_fn, + *, + start_message, + settings=None, + ): + del start_message, settings + result = task_fn(lambda *_args, **_kwargs: None) + window._on_task_finished(task_name, result) + + monkeypatch.setattr(window, "_start_project_task", run_task_sync) + + window.build_prior_weights() + QApplication.processEvents() + + metadata_payload = json.loads( + artifact_paths.distribution_metadata_file.read_text(encoding="utf-8") + ) + component_files = ( + list(artifact_paths.component_dir.glob("*.txt")) + if artifact_paths.component_dir.is_dir() + else [] + ) 
+ + assert "#6b7280" in ( + window.project_setup_tab.components_built_indicator.styleSheet() + ) + assert not artifact_paths.component_map_file.exists() + assert component_files == [] + assert metadata_payload["component_artifacts_ready"] is False + assert metadata_payload["prior_artifacts_ready"] is True + window.close() + + +def test_fft_born_main_window_uses_split_scrollable_layout(qapp): + del qapp + window = FFTBornApproximationMainWindow(preview_mode=True) + assert window.windowTitle() == "3D FFT Born Approximation (Preview)" + assert isinstance(window._pane_splitter, QSplitter) + assert isinstance(window._left_scroll_area, QScrollArea) + assert isinstance(window._right_scroll_area, QScrollArea) + assert "Preview Mode" in window.preview_mode_banner.text() + assert window.log_q_checkbox.isChecked() + assert window.log_intensity_checkbox.isChecked() + assert window.contrast_section.is_expanded + assert window.fft_settings_section.is_expanded + assert window.legacy_1d_settings_section.is_expanded is False + assert window.toggle_curve_legend_button.text() == "Hide Legend" + assert window.export_curve_csv_button.text() == "Export Plot CSV" + assert window.structure_viewer.show_mesh_checkbox.isChecked() + assert len(window.fft_box_visualizer.figure.axes) == 1 + assert window.fft_box_visualizer.figure.axes[0].name == "3d" + for widget in ( + window.q_min_spin, + window.q_max_spin, + window.q_step_spin, + window.spacing_spin, + window.sigma_spin, + window.min_box_length_spin, + window.padding_spin, + ): + assert widget.toolTip().strip() + assert "project q-range" in window.q_min_spin.toolTip() + assert "Nyquist limit" in window.spacing_spin.toolTip() + window.close() + + +def test_tools_menu_3d_fft_action_triggers_preview_launch(qapp, monkeypatch): + del qapp + window = SAXSMainWindow() + launched: list[dict[str, object]] = [] + monkeypatch.setattr( + window, + "_open_3d_fft_born_approximation_tool", + lambda **kwargs: launched.append(kwargs), + ) + + 
window.fft_born_approximation_action.trigger() + QApplication.processEvents() + + assert launched == [{"preview_mode": True}] + window.close() + + +def test_fft_born_main_window_loads_mesh_preview_and_exports_curve_csv( + qapp, + tmp_path, + monkeypatch, +): + del qapp + structure_path = tmp_path / "preview.xyz" + structure_path.write_text( + "\n".join( + [ + "3", + "preview", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + ] + ) + + "\n", + encoding="utf-8", + ) + window = FFTBornApproximationMainWindow(preview_mode=True) + window._load_input_path(structure_path) + + assert window.structure_viewer.current_structure is not None + assert window.structure_viewer.current_mesh_geometry is not None + assert window.structure_viewer.show_mesh_checkbox.isChecked() + preview_axes = window.fft_box_visualizer.figure.axes + assert len(preview_axes) == 1 + assert preview_axes[0].name == "3d" + + q_values = np.asarray([0.01, 0.02, 0.03], dtype=float) + fft_settings = ContrastFFTSettings( + spacing_a=2.5, + gaussian_sigma_a=0.75, + minimum_box_length_a=640.0, + padding_a=24.0, + ).normalized() + fft_result = ContrastFFTResult( + settings=fft_settings, + q_values=q_values, + raw_intensity=np.asarray([1.0, 0.8, 0.6], dtype=float), + kernel_corrected_intensity=np.asarray([1.0, 0.85, 0.7], dtype=float), + q_shell_counts=np.asarray([8, 12, 16], dtype=int), + density_integral=180.0, + expected_weight=180.0, + contrast_density_integral=180.0, + expected_contrast_weight=180.0, + solvent_exclusion_volume_a3=0.0, + grid_shape=(257, 257, 257), + box_lengths_a=(642.5, 642.5, 642.5), + voxel_spacing_a=(2.5, 2.5, 2.5), + q_nyquist_a_inverse=float(np.pi / 2.5), + q_frequency_step_a_inverse=( + float(2.0 * np.pi / 642.5), + float(2.0 * np.pi / 642.5), + float(2.0 * np.pi / 642.5), + ), + q_convention="q = 2πf", + uses_two_pi_frequency_conversion=True, + density_subtraction_active=False, + first_nonempty_q_a_inverse=0.01, + solvent_density_e_per_a3=0.0, + 
contrast_mode="bare_atomic_density_only", + kernel_correction_supported=True, + kernel_correction_applied=True, + kernel_correction_model="Gaussian deposition intensity factor exp(-sigma^2 q^2)", + timing=ContrastFFTTiming( + atomic_density_seconds=0.1, + contrast_density_seconds=0.0, + fft_seconds=0.2, + shell_average_seconds=0.3, + total_seconds=0.6, + ), + ) + payload = _FFTComputationPayload( + q_values=q_values, + profile_results=( + _FFTProfileComputationResult( + target=window._active_profile_target(), + q_values=q_values, + fft_result=fft_result, + legacy_q_values=None, + legacy_intensity=None, + exact_debye_intensity=None, + legacy_elapsed_seconds=None, + debye_elapsed_seconds=None, + ), + ), + ) + window._on_worker_finished(payload) + assert "3D FFT Nyquist limit:" in window.status_log_box.toPlainText() + result_axes = window.fft_box_visualizer.figure.axes + assert len(result_axes) == 2 + assert all(axis.name == "3d" for axis in result_axes) + + export_path = tmp_path / "fft_curves.csv" + monkeypatch.setattr( + QFileDialog, + "getSaveFileName", + lambda *args, **kwargs: (str(export_path), "CSV files (*.csv)"), + ) + window._export_q_space_curves_csv() + + exported_text = export_path.read_text(encoding="utf-8") + assert "3d_fft_born_approximation_q_a_inverse" in exported_text + assert "3d_fft_born_approximation_intensity" in exported_text + window.close() + + +def test_fft_worker_applies_active_contrast_to_legacy_overlay( + tmp_path, + monkeypatch, +): + structure_path = tmp_path / "contrast_overlay.xyz" + structure_path.write_text( + "\n".join( + [ + "3", + "contrast_overlay", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + ] + ) + + "\n", + encoding="utf-8", + ) + load_electron_density_structure(structure_path) + q_values = np.asarray([0.01, 0.02, 0.03], dtype=float) + captured: dict[str, object] = {} + + monkeypatch.setattr( + "saxshell.saxs.contrast_fft.ui.main_window.build_shared_q_grid", + lambda q_min, q_max, q_step: q_values, + ) + 
monkeypatch.setattr( + "saxshell.saxs.contrast_fft.ui.main_window.compute_contrast_fft_intensity", + lambda *args, **kwargs: ContrastFFTResult( + settings=ContrastFFTSettings( + spacing_a=2.5, + gaussian_sigma_a=0.75, + minimum_box_length_a=640.0, + padding_a=24.0, + ).normalized(), + q_values=q_values, + raw_intensity=np.asarray([1.0, 0.8, 0.6], dtype=float), + kernel_corrected_intensity=np.asarray( + [1.0, 0.8, 0.6], dtype=float + ), + q_shell_counts=np.asarray([8, 12, 16], dtype=int), + density_integral=180.0, + expected_weight=180.0, + contrast_density_integral=180.0, + expected_contrast_weight=180.0, + solvent_exclusion_volume_a3=0.0, + grid_shape=(257, 257, 257), + box_lengths_a=(642.5, 642.5, 642.5), + voxel_spacing_a=(2.5, 2.5, 2.5), + q_nyquist_a_inverse=float(np.pi / 2.5), + q_frequency_step_a_inverse=( + float(2.0 * np.pi / 642.5), + float(2.0 * np.pi / 642.5), + float(2.0 * np.pi / 642.5), + ), + q_convention="q = 2πf", + uses_two_pi_frequency_conversion=True, + density_subtraction_active=True, + first_nonempty_q_a_inverse=0.01, + solvent_density_e_per_a3=0.334, + contrast_mode="constant_solvent_density_inside_union_of_atomic_spheres", + kernel_correction_supported=False, + kernel_correction_applied=False, + kernel_correction_model=None, + timing=ContrastFFTTiming( + atomic_density_seconds=0.1, + contrast_density_seconds=0.1, + fft_seconds=0.1, + shell_average_seconds=0.1, + total_seconds=0.4, + ), + ), + ) + monkeypatch.setattr( + "saxshell.saxs.contrast_fft.ui.main_window.compute_electron_density_profile", + lambda *args, **kwargs: SimpleNamespace( + radial_centers=np.asarray([0.0, 1.0, 2.0], dtype=float), + shell_volumes=np.asarray([1.0, 1.0, 1.0], dtype=float), + solvent_contrast=None, + ), + ) + + def _fake_apply_contrast(profile, settings, *, solvent_name=None): + captured["contrast_settings"] = settings + captured["contrast_name"] = solvent_name + return SimpleNamespace( + radial_centers=np.asarray([0.0, 1.0, 2.0], dtype=float), + 
shell_volumes=np.asarray([1.0, 1.0, 1.0], dtype=float), + solvent_contrast=SimpleNamespace( + solvent_density_e_per_a3=0.334, + solvent_subtracted_smeared_density=np.asarray( + [2.0, 1.5, 1.0], dtype=float + ), + ), + ) + + monkeypatch.setattr( + "saxshell.saxs.contrast_fft.ui.main_window.apply_solvent_contrast_to_profile_result", + _fake_apply_contrast, + ) + + def _fake_scattering(profile, settings): + captured["use_solvent_subtracted_profile"] = ( + settings.use_solvent_subtracted_profile + ) + return SimpleNamespace( + q_values=q_values, + intensity=np.asarray([1.0, 0.85, 0.7], dtype=float), + ) + + monkeypatch.setattr( + "saxshell.saxs.contrast_fft.ui.main_window.compute_electron_density_scattering_profile", + _fake_scattering, + ) + + worker = _FFTComputationWorker( + targets=( + _FFTProfileTarget( + key="average|input|contrast_overlay|no_motif", + display_name="contrast_overlay", + structure_name="contrast_overlay", + motif_name="no_motif", + file_count=1, + reference_file=structure_path, + source_files=(structure_path,), + representative=structure_path.name, + source_mode="average", + solvent_mode="input", + ), + ), + fft_settings=ContrastFFTSettings( + spacing_a=2.5, + gaussian_sigma_a=0.75, + minimum_box_length_a=640.0, + padding_a=24.0, + solvent_density_e_per_a3=0.334, + ).normalized(), + legacy_mesh_settings=None, + legacy_smearing_settings=None, + legacy_fourier_settings=None, + active_contrast_settings=ContrastSolventDensitySettings.from_values( + method=CONTRAST_SOLVENT_METHOD_DIRECT, + direct_electron_density_e_per_a3=0.334, + ), + active_contrast_name="Direct solvent", + q_min=0.01, + q_max=1.20, + q_step=0.01, + compare_legacy_1d=True, + compare_exact_debye=False, + ) + payloads: list[object] = [] + worker.finished.connect(payloads.append) + + worker.run() + + assert len(payloads) == 1 + assert isinstance(payloads[0], _FFTComputationPayload) + assert captured["contrast_name"] == "Direct solvent" + assert captured["use_solvent_subtracted_profile"] 
is True + + +def test_fft_born_workspace_state_restores_representative_profile_results( + qapp, + tmp_path, + monkeypatch, +): + del qapp + structure_path = tmp_path / "representative.xyz" + structure_path.write_text( + "\n".join( + [ + "3", + "representative", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + ] + ) + + "\n", + encoding="utf-8", + ) + q_values = np.asarray([0.01, 0.02, 0.03], dtype=float) + fft_result = ContrastFFTResult( + settings=ContrastFFTSettings( + spacing_a=2.5, + gaussian_sigma_a=0.75, + minimum_box_length_a=640.0, + padding_a=24.0, + ).normalized(), + q_values=q_values, + raw_intensity=np.asarray([1.0, 0.9, 0.75], dtype=float), + kernel_corrected_intensity=np.asarray([1.0, 0.94, 0.81], dtype=float), + q_shell_counts=np.asarray([8, 12, 16], dtype=int), + density_integral=180.0, + expected_weight=180.0, + contrast_density_integral=180.0, + expected_contrast_weight=180.0, + solvent_exclusion_volume_a3=0.0, + grid_shape=(257, 257, 257), + box_lengths_a=(642.5, 642.5, 642.5), + voxel_spacing_a=(2.5, 2.5, 2.5), + q_nyquist_a_inverse=float(np.pi / 2.5), + q_frequency_step_a_inverse=( + float(2.0 * np.pi / 642.5), + float(2.0 * np.pi / 642.5), + float(2.0 * np.pi / 642.5), + ), + q_convention="q = 2πf", + uses_two_pi_frequency_conversion=True, + density_subtraction_active=False, + first_nonempty_q_a_inverse=0.01, + solvent_density_e_per_a3=0.0, + contrast_mode="bare_atomic_density_only", + kernel_correction_supported=True, + kernel_correction_applied=True, + kernel_correction_model="Gaussian deposition intensity factor exp(-sigma^2 q^2)", + timing=ContrastFFTTiming( + atomic_density_seconds=0.1, + contrast_density_seconds=0.0, + fft_seconds=0.2, + shell_average_seconds=0.3, + total_seconds=0.6, + ), + ) + average_target = _FFTProfileTarget( + key="average|input|representative|no_motif", + display_name="representative", + structure_name="representative", + motif_name="no_motif", + file_count=1, + reference_file=structure_path, + 
source_files=(structure_path,), + representative=structure_path.name, + source_mode="average", + solvent_mode="input", + ) + representative_target = _FFTProfileTarget( + key="representative|partial|representative|no_motif", + display_name="representative", + structure_name="representative", + motif_name="no_motif", + file_count=1, + reference_file=structure_path, + source_files=(structure_path,), + representative=structure_path.name, + source_mode="representative", + solvent_mode="partial", + ) + + def _mock_targets(_self, _path): + return { + ("average", "input"): (average_target,), + ("representative", "partial"): (representative_target,), + } + + monkeypatch.setattr( + FFTBornApproximationMainWindow, + "_resolve_available_profile_targets", + _mock_targets, + ) + + output_dir = tmp_path / "fft_state" + project_dir = tmp_path / "project" + project_dir.mkdir() + window = FFTBornApproximationMainWindow( + preview_mode=False, + initial_project_dir=project_dir, + initial_output_dir=output_dir, + ) + window._load_input_path(structure_path) + rep_index = window.structure_source_combo.findData("representative") + window.structure_source_combo.setCurrentIndex(rep_index) + window.direct_density_spin.setValue(0.334) + direct_index = window.solvent_method_combo.findData( + CONTRAST_SOLVENT_METHOD_DIRECT + ) + window.solvent_method_combo.setCurrentIndex(direct_index) + window._apply_contrast_settings() + window._on_worker_finished( + _FFTComputationPayload( + q_values=q_values, + profile_results=( + _FFTProfileComputationResult( + target=representative_target, + q_values=q_values, + fft_result=fft_result, + legacy_q_values=None, + legacy_intensity=None, + exact_debye_intensity=None, + legacy_elapsed_seconds=None, + debye_elapsed_seconds=None, + ), + ), + ) + ) + state_path = output_dir / "workspace_state.json" + assert state_path.is_file() + window.close() + + restored_window = FFTBornApproximationMainWindow( + preview_mode=False, + initial_project_dir=project_dir, + 
initial_output_dir=output_dir, + ) + restored_window._load_input_path(structure_path) + + assert ( + restored_window.structure_source_combo.currentData() + == "representative" + ) + assert ( + restored_window.representative_solvent_mode_combo.currentData() + == "partial" + ) + assert restored_window._active_profile_key == representative_target.key + assert ( + representative_target.key in restored_window._computed_profile_results + ) + assert restored_window._active_contrast_name == "Direct solvent" + assert restored_window._current_payload is not None + assert state_path.is_file() + restored_window.close() + + +def test_fft_born_representative_input_can_switch_to_project_average_clusters( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + clusters_dir = project_dir / "clusters" + structure_dir = clusters_dir / "A" + structure_dir.mkdir(parents=True) + frame_paths = [] + for index in range(2): + frame_path = structure_dir / f"frame_{index + 1:04d}.xyz" + frame_path.write_text( + "3\nframe\nPb 0.0 0.0 0.0\nI 2.8 0.0 0.0\nI 0.0 2.8 0.0\n", + encoding="utf-8", + ) + frame_paths.append(frame_path.resolve()) + settings.clusters_dir = str(clusters_dir) + settings.use_representative_structures = True + manager.save_project(settings) + + rmcsetup_paths = build_rmcsetup_paths(project_dir) + representative_dir = ( + rmcsetup_paths.representative_partial_solvent_dir / "A" + ) + representative_dir.mkdir(parents=True) + representative_path = representative_dir / "selected_rep.xyz" + representative_path.write_text( + frame_paths[0].read_text(encoding="utf-8"), + encoding="utf-8", + ) + selection = DreamBestFitSelection( + run_name="Representative Structure Finder", + run_relative_path="rmcsetup/representative_structures", + label="Representative Structure Finder", + selection_source="representativefinder", + 
selected_at="2026-05-06T12:00:00", + ) + metadata = RepresentativeSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + distribution_selection=DistributionSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + run_dir="rmcsetup/representative_structures", + updated_at="2026-05-06T12:00:00", + entries=[], + ), + settings=RepresentativeSelectionSettings(), + updated_at="2026-05-06T12:00:00", + representative_entries=[ + RepresentativeSelectionEntry( + structure="A", + motif="no_motif", + param="w0", + selected_weight=1.0, + cluster_count=2, + source_dir=str(representative_dir), + source_file=str(representative_path), + source_file_name=representative_path.name, + atom_count=3, + element_counts={"Pb": 1, "I": 2}, + source_solvent_mode="partialsolv", + ) + ], + missing_bins=[], + invalid_bins=[], + ) + _write_representative_selection_metadata(project_dir, metadata) + project_source = SimpleNamespace( + representative_selection=metadata, + solvent_handling=None, + ) + + window = FFTBornApproximationMainWindow( + preview_mode=False, + initial_project_dir=project_dir, + initial_use_representative_structures=True, + ) + monkeypatch.setattr(window, "_project_source", lambda: project_source) + window._load_input_path(rmcsetup_paths.representative_clusters_dir) + + assert window.structure_source_combo.currentData() == "representative" + assert ("average", "input") in window._available_profile_targets + average_index = window.structure_source_combo.findData("average") + window.structure_source_combo.setCurrentIndex(average_index) + target = window._active_profile_target() + + assert target is not None + assert target.source_mode == "average" + assert target.file_count == 2 + assert target.source_files == tuple(frame_paths) + window.close() + + +def _make_fft_result_for_test( + q_values: np.ndarray, + raw_intensity: np.ndarray, +) -> ContrastFFTResult: + fft_settings = ContrastFFTSettings( + spacing_a=2.5, + 
gaussian_sigma_a=0.75, + minimum_box_length_a=640.0, + padding_a=24.0, + ).normalized() + return ContrastFFTResult( + settings=fft_settings, + q_values=np.asarray(q_values, dtype=float), + raw_intensity=np.asarray(raw_intensity, dtype=float), + kernel_corrected_intensity=np.asarray(raw_intensity, dtype=float), + q_shell_counts=np.ones_like(q_values, dtype=int), + density_integral=10.0, + expected_weight=10.0, + contrast_density_integral=10.0, + expected_contrast_weight=10.0, + solvent_exclusion_volume_a3=0.0, + grid_shape=(257, 257, 257), + box_lengths_a=(642.5, 642.5, 642.5), + voxel_spacing_a=(2.5, 2.5, 2.5), + q_nyquist_a_inverse=float(np.pi / 2.5), + q_frequency_step_a_inverse=( + float(2.0 * np.pi / 642.5), + float(2.0 * np.pi / 642.5), + float(2.0 * np.pi / 642.5), + ), + q_convention="q = 2πf", + uses_two_pi_frequency_conversion=True, + density_subtraction_active=False, + first_nonempty_q_a_inverse=float(q_values[0]), + solvent_density_e_per_a3=0.0, + contrast_mode="bare_atomic_density_only", + kernel_correction_supported=True, + kernel_correction_applied=True, + kernel_correction_model=( + "Gaussian deposition intensity factor exp(-sigma^2 q^2)" + ), + timing=ContrastFFTTiming( + atomic_density_seconds=0.1, + contrast_density_seconds=0.0, + fft_seconds=0.2, + shell_average_seconds=0.3, + total_seconds=0.6, + ), + ) + + +def test_fft_restored_traces_require_recompute_after_source_or_contrast_change( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + settings.use_representative_structures = True + manager.save_project(settings) + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + manager.ensure_artifact_dirs(artifact_paths) + + q_values = 0.05 + 0.05 * np.arange(6, dtype=float) 
+ average_path = tmp_path / "average.xyz" + representative_path = tmp_path / "representative.xyz" + structure_text = ( + "3\ncomponent\nPb 0.0 0.0 0.0\nI 2.8 0.0 0.0\nI 0.0 2.8 0.0\n" + ) + average_path.write_text(structure_text, encoding="utf-8") + representative_path.write_text(structure_text, encoding="utf-8") + average_target = _FFTProfileTarget( + key="average|input|A|no_motif", + display_name="A", + structure_name="A", + motif_name="no_motif", + file_count=1, + reference_file=average_path, + source_files=(average_path,), + representative=average_path.name, + source_mode="average", + solvent_mode="input", + ) + representative_target = _FFTProfileTarget( + key="representative|partial|A|no_motif", + display_name="A", + structure_name="A", + motif_name="no_motif", + file_count=1, + reference_file=representative_path, + source_files=(representative_path,), + representative=representative_path.name, + source_mode="representative", + solvent_mode="partial", + ) + window = FFTBornApproximationMainWindow( + preview_mode=False, + initial_project_dir=project_dir, + initial_distribution_id=artifact_paths.distribution_id, + initial_distribution_root_dir=artifact_paths.root_dir, + initial_use_representative_structures=True, + ) + window._loaded_input_path = project_dir + window._available_profile_targets = { + ("average", "input"): (average_target,), + ("representative", "partial"): (representative_target,), + } + representative_index = window.structure_source_combo.findData( + "representative" + ) + window.structure_source_combo.setCurrentIndex(representative_index) + window.q_min_spin.setValue(float(q_values[0])) + window.q_max_spin.setValue(float(q_values[-1])) + window.q_step_spin.setValue(0.05) + window._computed_profile_results = { + representative_target.key: _FFTProfileComputationResult( + target=representative_target, + q_values=q_values, + fft_result=_make_fft_result_for_test( + q_values, + np.linspace(11.0, 18.0, q_values.size), + ), + legacy_q_values=None, + 
legacy_intensity=None, + exact_debye_intensity=None, + legacy_elapsed_seconds=None, + debye_elapsed_seconds=None, + ) + } + window._computed_profile_run_signature = ( + window._current_trace_configuration_signature() + ) + window._update_push_to_model_state() + assert window.push_to_model_button.isEnabled() + + average_index = window.structure_source_combo.findData("average") + window.structure_source_combo.setCurrentIndex(average_index) + assert not window.push_to_model_button.isEnabled() + + window._on_worker_finished( + _FFTComputationPayload( + q_values=q_values, + profile_results=( + _FFTProfileComputationResult( + target=average_target, + q_values=q_values, + fft_result=_make_fft_result_for_test( + q_values, + np.linspace(21.0, 28.0, q_values.size), + ), + legacy_q_values=None, + legacy_intensity=None, + exact_debye_intensity=None, + legacy_elapsed_seconds=None, + debye_elapsed_seconds=None, + ), + ), + ) + ) + assert window.push_to_model_button.isEnabled() + + direct_index = window.solvent_method_combo.findData( + CONTRAST_SOLVENT_METHOD_DIRECT + ) + window.solvent_method_combo.setCurrentIndex(direct_index) + window.direct_density_spin.setValue(0.334) + + assert not window.push_to_model_button.isEnabled() + window.close() + + +def test_fft_push_to_model_records_representative_component_source( + qapp, + tmp_path, + monkeypatch, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + settings.use_representative_structures = True + manager.save_project(settings) + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + manager.ensure_artifact_dirs(artifact_paths) + q_values = np.asarray([0.01, 0.02, 0.03], dtype=float) + fft_result = ContrastFFTResult( + settings=ContrastFFTSettings( + spacing_a=2.5, + 
gaussian_sigma_a=0.75, + minimum_box_length_a=640.0, + padding_a=24.0, + ).normalized(), + q_values=q_values, + raw_intensity=np.asarray([1.0, 0.9, 0.8], dtype=float), + kernel_corrected_intensity=np.asarray([1.0, 0.9, 0.8], dtype=float), + q_shell_counts=np.asarray([8, 12, 16], dtype=int), + density_integral=10.0, + expected_weight=10.0, + contrast_density_integral=10.0, + expected_contrast_weight=10.0, + solvent_exclusion_volume_a3=0.0, + grid_shape=(257, 257, 257), + box_lengths_a=(642.5, 642.5, 642.5), + voxel_spacing_a=(2.5, 2.5, 2.5), + q_nyquist_a_inverse=float(np.pi / 2.5), + q_frequency_step_a_inverse=( + float(2.0 * np.pi / 642.5), + float(2.0 * np.pi / 642.5), + float(2.0 * np.pi / 642.5), + ), + q_convention="q = 2πf", + uses_two_pi_frequency_conversion=True, + density_subtraction_active=False, + first_nonempty_q_a_inverse=0.01, + solvent_density_e_per_a3=0.0, + contrast_mode="bare_atomic_density_only", + kernel_correction_supported=True, + kernel_correction_applied=True, + kernel_correction_model="Gaussian deposition intensity factor exp(-sigma^2 q^2)", + timing=ContrastFFTTiming( + atomic_density_seconds=0.1, + contrast_density_seconds=0.0, + fft_seconds=0.2, + shell_average_seconds=0.3, + total_seconds=0.6, + ), + ) + target = _FFTProfileTarget( + key="representative|partial|A|no_motif", + display_name="A", + structure_name="A", + motif_name="no_motif", + file_count=1, + reference_file=tmp_path / "representative.xyz", + source_files=(tmp_path / "representative.xyz",), + representative="representative.xyz", + source_mode="representative", + solvent_mode="partial", + ) + window = FFTBornApproximationMainWindow( + preview_mode=False, + initial_project_dir=project_dir, + initial_distribution_id=artifact_paths.distribution_id, + initial_distribution_root_dir=artifact_paths.root_dir, + initial_use_representative_structures=True, + ) + window._loaded_input_path = project_dir + window._available_profile_targets = { + ("representative", "partial"): (target,) 
+ } + window._current_profile_targets = (target,) + window._active_profile_key = target.key + window._computed_profile_results = { + target.key: _FFTProfileComputationResult( + target=target, + q_values=q_values, + fft_result=fft_result, + legacy_q_values=None, + legacy_intensity=None, + exact_debye_intensity=None, + legacy_elapsed_seconds=None, + debye_elapsed_seconds=None, + ) + } + representative_index = window.structure_source_combo.findData( + "representative" + ) + window.structure_source_combo.setCurrentIndex(representative_index) + monkeypatch.setattr( + FFTBornApproximationMainWindow, + "_ensure_linked_distribution_ready_for_push", + lambda self: None, + ) + + window._push_components_to_model() + + payload = json.loads( + artifact_paths.distribution_metadata_file.read_text(encoding="utf-8") + ) + assert payload["built_component_source_mode"] == "representative" + assert payload["use_representative_structures"] is True + assert artifact_paths.component_map_file.is_file() + window.close() + + +def test_fft_restored_trace_requires_recompute_when_endpoint_is_missing( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings, artifact_paths = _seed_saved_distribution_from_root( + project_dir, + component_build_mode=COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, + ) + manager.save_project(settings) + q_values = 0.0101 + 0.01 * np.arange(119, dtype=float) + target = _FFTProfileTarget( + key="average|input|A|no_motif", + display_name="A", + structure_name="A", + motif_name="no_motif", + file_count=1, + reference_file=tmp_path / "frame_0001.xyz", + source_files=(tmp_path / "frame_0001.xyz",), + representative="frame_0001.xyz", + source_mode="average", + solvent_mode="input", + ) + window = FFTBornApproximationMainWindow( + preview_mode=False, + initial_project_dir=project_dir, + initial_distribution_id=artifact_paths.distribution_id, + 
initial_distribution_root_dir=artifact_paths.root_dir, + initial_use_representative_structures=False, + ) + window._loaded_input_path = project_dir + window.q_min_spin.setValue(0.0101) + window.q_max_spin.setValue(1.1976) + window.q_step_spin.setValue(0.01) + window._computed_profile_results = { + target.key: _FFTProfileComputationResult( + target=target, + q_values=q_values, + fft_result=_make_fft_result_for_test( + q_values, + np.linspace(21.0, 28.0, q_values.size), + ), + legacy_q_values=None, + legacy_intensity=None, + exact_debye_intensity=None, + legacy_elapsed_seconds=None, + debye_elapsed_seconds=None, + ) + } + window._computed_profile_run_signature = ( + window._current_trace_configuration_signature() + ) + window._update_push_to_model_state() + + assert not window._results_match_current_configuration() + assert not window.push_to_model_button.isEnabled() + window.close() + + +def test_fft_push_to_model_registers_components_with_prefit_and_dream( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings, artifact_paths = _seed_saved_distribution_from_root( + project_dir, + component_build_mode=COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT, + ) + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + settings.use_representative_structures = False + manager.save_project(settings) + + q_values = np.linspace(0.05, 0.3, 8) + pushed_intensity = np.linspace(21.0, 28.0, 8) + structure_path = tmp_path / "frame_0001.xyz" + structure_path.write_text( + "3\nframe_0001\nPb 0.0 0.0 0.0\nI 2.8 0.0 0.0\nI 0.0 2.8 0.0\n", + encoding="utf-8", + ) + target = _FFTProfileTarget( + key="average|input|A|no_motif", + display_name="A", + structure_name="A", + motif_name="no_motif", + file_count=1, + reference_file=structure_path, + source_files=(structure_path,), + representative=structure_path.name, + source_mode="average", + solvent_mode="input", + ) + window = 
FFTBornApproximationMainWindow( + preview_mode=False, + initial_project_dir=project_dir, + initial_distribution_id=artifact_paths.distribution_id, + initial_distribution_root_dir=artifact_paths.root_dir, + initial_use_representative_structures=False, + ) + main_window = SAXSMainWindow(initial_project_dir=project_dir) + window.born_components_built.connect(main_window._on_born_components_built) + window._loaded_input_path = project_dir + window._computed_profile_results = { + target.key: _FFTProfileComputationResult( + target=target, + q_values=q_values, + fft_result=_make_fft_result_for_test(q_values, pushed_intensity), + legacy_q_values=None, + legacy_intensity=None, + exact_debye_intensity=None, + legacy_elapsed_seconds=None, + debye_elapsed_seconds=None, + ) + } + + window._push_components_to_model() + QApplication.processEvents() + + saved_settings = manager.load_project(project_dir) + assert ( + saved_settings.component_build_mode + == COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + assert saved_settings.use_representative_structures is False + assert main_window.prefit_workflow is not None + assert [ + (component.structure, component.motif) + for component in main_window.prefit_workflow.components + ] == [("A", "no_motif")] + assert np.allclose( + main_window.prefit_workflow.components[0].intensities, + pushed_intensity, + ) + assert main_window.dream_workflow is not None + assert [ + (component.structure, component.motif) + for component in main_window.dream_workflow.prefit_workflow.components + ] == [("A", "no_motif")] + assert np.allclose( + main_window.dream_workflow.prefit_workflow.components[0].intensities, + pushed_intensity, + ) + + prefit_workflow = SAXSPrefitWorkflow(project_dir) + assert [ + (component.structure, component.motif) + for component in prefit_workflow.components + ] == [("A", "no_motif")] + assert np.allclose( + prefit_workflow.components[0].intensities, + pushed_intensity, + ) + assert 
np.allclose(prefit_workflow.evaluate().q_values, q_values) + + dream_workflow = SAXSDreamWorkflow(project_dir) + assert [ + (component.structure, component.motif) + for component in dream_workflow.prefit_workflow.components + ] == [("A", "no_motif")] + assert np.allclose( + dream_workflow.prefit_workflow.components[0].intensities, + pushed_intensity, + ) + assert np.allclose( + dream_workflow.prefit_workflow.evaluate().q_values, + q_values, + ) + main_window.close() + window.close() + + +def test_saved_distribution_roundtrips_representative_structure_preference( + tmp_path, +): + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.use_representative_structures = True + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + manager.save_project(settings) + artifact_paths = project_artifact_paths( + settings, + storage_mode="distribution", + allow_legacy_fallback=False, + ) + manager.ensure_artifact_dirs(artifact_paths) + artifact_paths.component_dir.mkdir(parents=True, exist_ok=True) + (artifact_paths.component_dir / "A_no_motif.txt").write_text( + "0.01 1.0 0.0 0.0\n", + encoding="utf-8", + ) + artifact_paths.component_map_file.write_text( + json.dumps({"saxs_map": {"A": {"no_motif": "A_no_motif.txt"}}}) + "\n", + encoding="utf-8", + ) + artifact_paths.prior_weights_file.write_text( + json.dumps({"A_no_motif.txt": 1.0}) + "\n", + encoding="utf-8", + ) + manager._write_distribution_metadata( + settings, artifact_paths=artifact_paths + ) + + loaded_settings = manager.settings_for_saved_distribution( + project_dir, + artifact_paths.distribution_id, + ) + + assert loaded_settings.use_representative_structures is True + + +def test_load_saved_distribution_checks_representative_toggle_for_representative_components( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = 
SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.use_representative_structures = False + manager.save_project(settings) + _settings, artifact_paths = _seed_saved_distribution_from_root(project_dir) + + representative_dir = build_rmcsetup_paths( + project_dir + ).representative_partial_solvent_dir + representative_dir.mkdir(parents=True, exist_ok=True) + representative_path = representative_dir / "selected_rep.xyz" + representative_path.write_text( + "\n".join( + [ + "4", + "selected representative", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + "I 0.0 0.0 2.8", + ] + ) + + "\n", + encoding="utf-8", + ) + selection = DreamBestFitSelection( + run_name="Representative Structure Finder", + run_relative_path="rmcsetup/representative_structures", + label="Representative Structure Finder", + selection_source="representativefinder", + selected_at="2026-05-06T12:00:00", + ) + _write_representative_selection_metadata( + project_dir, + RepresentativeSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + distribution_selection=DistributionSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + run_dir="rmcsetup/representative_structures", + updated_at="2026-05-06T12:00:00", + entries=[], + ), + settings=RepresentativeSelectionSettings(), + updated_at="2026-05-06T12:00:00", + representative_entries=[ + RepresentativeSelectionEntry( + structure="A", + motif="no_motif", + param="w0", + selected_weight=0.6, + cluster_count=1, + source_dir=str(representative_dir), + source_file=str(representative_path), + source_file_name=representative_path.name, + atom_count=4, + element_counts={"Pb": 1, "I": 3}, + source_solvent_mode="partialsolv", + ) + ], + missing_bins=[], + invalid_bins=[], + ), + ) + + metadata_path = artifact_paths.distribution_metadata_file + assert metadata_path is not None + metadata_payload = json.loads(metadata_path.read_text(encoding="utf-8")) + 
metadata_payload["use_representative_structures"] = False + metadata_payload["built_component_source_mode"] = "representative" + metadata_path.write_text( + json.dumps(metadata_payload, indent=2) + "\n", + encoding="utf-8", + ) + + window = SAXSMainWindow(initial_project_dir=project_dir) + assert not window.project_setup_tab.use_representative_structures() + + window.project_setup_tab.load_distribution_button.click() + QApplication.processEvents() + + assert window.current_settings is not None + assert window.current_settings.use_representative_structures is True + assert window.project_setup_tab.use_representative_structures() + assert "Representative Structures mode is on." in ( + window.project_setup_tab.representative_structure_status_label.text() + ) + window.close() + + +def test_generate_prior_weights_keeps_full_distribution_when_representatives_selected( + tmp_path, +): + project_dir, paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + + clusters_dir = project_dir / "clusters" + structure_dir = clusters_dir / "A" + structure_dir.mkdir(parents=True, exist_ok=True) + frame_a = structure_dir / "frame_0001.xyz" + frame_b = structure_dir / "frame_0002.xyz" + for frame_path in (frame_a, frame_b): + frame_path.write_text( + "\n".join( + [ + "4", + frame_path.stem, + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + "I 0.0 0.0 2.8", + ] + ) + + "\n", + encoding="utf-8", + ) + + representative_dir = build_rmcsetup_paths( + project_dir + ).representative_partial_solvent_dir + representative_dir.mkdir(parents=True, exist_ok=True) + representative_path = representative_dir / "selected_rep.xyz" + representative_path.write_text( + frame_a.read_text(encoding="utf-8"), + encoding="utf-8", + ) + + selection = DreamBestFitSelection( + run_name="Representative Structure Finder", + run_relative_path="rmcsetup/representative_structures", + label="Representative Structure Finder", + 
selection_source="representativefinder", + selected_at="2026-05-06T12:00:00", + ) + _write_representative_selection_metadata( + project_dir, + RepresentativeSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + distribution_selection=DistributionSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + run_dir="rmcsetup/representative_structures", + updated_at="2026-05-06T12:00:00", + entries=[], + ), + settings=RepresentativeSelectionSettings(), + updated_at="2026-05-06T12:00:00", + representative_entries=[ + RepresentativeSelectionEntry( + structure="A", + motif="no_motif", + param="w0", + selected_weight=1.0, + cluster_count=2, + source_dir=str(representative_dir), + source_file=str(representative_path), + source_file_name=representative_path.name, + atom_count=4, + element_counts={"Pb": 1, "I": 3}, + source_solvent_mode="partialsolv", + ) + ], + missing_bins=[], + invalid_bins=[], + ), + ) + + settings.clusters_dir = str(clusters_dir) + settings.use_representative_structures = True + result = manager.generate_prior_weights(settings) + prior_payload = json.loads( + result.md_prior_weights_path.read_text(encoding="utf-8") + ) + prior_entry = prior_payload["structures"]["A"]["no_motif"] + + assert prior_payload["origin"] == clusters_dir.name + assert prior_entry["source_kind"] == "cluster_dir" + assert Path(prior_entry["source_dir"]).resolve() == structure_dir.resolve() + assert Path(prior_entry["source_file"]).resolve() == frame_a.resolve() + + +def test_representative_checkbox_stays_disabled_until_all_bins_are_computed( + qapp, + tmp_path, +): + del qapp + project_dir, paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.use_representative_structures = True + manager.save_project(settings) + (paths.project_dir / "md_prior_weights.json").write_text( + json.dumps( + { + "origin": "clusters", + "total_files": 2, + 
"structures": { + "A": { + "no_motif": { + "count": 1, + "weight": 0.6, + "representative": "frame_0001.xyz", + "profile_file": "A_no_motif.txt", + } + }, + "B": { + "no_motif": { + "count": 1, + "weight": 0.4, + "representative": "frame_0001.xyz", + "profile_file": "B_no_motif.txt", + } + }, + }, + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + + representative_dir = build_rmcsetup_paths( + project_dir + ).representative_partial_solvent_dir + representative_dir.mkdir(parents=True, exist_ok=True) + representative_path = representative_dir / "partial_rep.xyz" + representative_path.write_text( + "\n".join( + [ + "4", + "partial representative", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + "I 0.0 0.0 2.8", + ] + ) + + "\n", + encoding="utf-8", + ) + selection = DreamBestFitSelection( + run_name="Representative Structure Finder", + run_relative_path="rmcsetup/representative_structures", + label="Representative Structure Finder", + selection_source="representativefinder", + selected_at="2026-05-06T12:00:00", + ) + _write_representative_selection_metadata( + project_dir, + RepresentativeSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + distribution_selection=DistributionSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + run_dir="rmcsetup/representative_structures", + updated_at="2026-05-06T12:00:00", + entries=[], + ), + settings=RepresentativeSelectionSettings(), + updated_at="2026-05-06T12:00:00", + representative_entries=[ + RepresentativeSelectionEntry( + structure="A", + motif="no_motif", + param="w0", + selected_weight=0.6, + cluster_count=1, + source_dir=str(representative_dir), + source_file=str(representative_path), + source_file_name=representative_path.name, + atom_count=4, + element_counts={"Pb": 1, "I": 3}, + source_solvent_mode="partialsolv", + ) + ], + missing_bins=[ + RepresentativeSelectionIssue( + structure="B", + motif="no_motif", + param="w1", + 
message="Representative structure has not been computed yet.", + ) + ], + invalid_bins=[], + ), + ) + + window = SAXSMainWindow(initial_project_dir=project_dir) + + assert not window.project_setup_tab.use_representative_structures() + assert ( + not window.project_setup_tab.use_representative_structures_checkbox.isEnabled() + ) + assert "Saved 1 of 2 required representative structures." in ( + window.project_setup_tab.representative_structure_status_label.text() + ) + window.close() + + +def test_representative_checkbox_uses_prior_histogram_bins_for_readiness( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.use_representative_structures = True + manager.save_project(settings) + + representative_dir = build_rmcsetup_paths( + project_dir + ).representative_partial_solvent_dir + representative_dir.mkdir(parents=True, exist_ok=True) + representative_path = representative_dir / "partial_rep.xyz" + representative_path.write_text( + "\n".join( + [ + "4", + "partial representative", + "Pb 0.0 0.0 0.0", + "I 2.8 0.0 0.0", + "I 0.0 2.8 0.0", + "I 0.0 0.0 2.8", + ] + ) + + "\n", + encoding="utf-8", + ) + selection = DreamBestFitSelection( + run_name="Representative Structure Finder", + run_relative_path="rmcsetup/representative_structures", + label="Representative Structure Finder", + selection_source="representativefinder", + selected_at="2026-05-06T12:00:00", + ) + _write_representative_selection_metadata( + project_dir, + RepresentativeSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + distribution_selection=DistributionSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + run_dir="rmcsetup/representative_structures", + updated_at="2026-05-06T12:00:00", + entries=[], + ), + settings=RepresentativeSelectionSettings(), + updated_at="2026-05-06T12:00:00", + 
representative_entries=[ + RepresentativeSelectionEntry( + structure="A", + motif="no_motif", + param="w0", + selected_weight=1.0, + cluster_count=1, + source_dir=str(representative_dir), + source_file=str(representative_path), + source_file_name=representative_path.name, + atom_count=4, + element_counts={"Pb": 1, "I": 3}, + source_solvent_mode="partialsolv", + ) + ], + missing_bins=[ + RepresentativeSelectionIssue( + structure="B", + motif="no_motif", + param="w1", + message="Stale missing bin from a previous distribution.", + ) + ], + invalid_bins=[], + ), + ) + + window = SAXSMainWindow(initial_project_dir=project_dir) + + assert window.project_setup_tab.use_representative_structures() + assert ( + window.project_setup_tab.use_representative_structures_checkbox.isEnabled() + ) + assert "Representative Structures mode is on." in ( + window.project_setup_tab.representative_structure_status_label.text() ) window.close() -def test_build_components_in_born_approximation_launches_electron_density_workflow( +def test_representative_toggle_keeps_prior_histogram_and_hides_stale_average_components( qapp, tmp_path, - monkeypatch, ): del qapp project_dir, _paths = _build_minimal_saxs_project(tmp_path) - window = SAXSMainWindow(initial_project_dir=project_dir) - launched: list[dict[str, object]] = [] - start_calls: list[str] = [] - clusters_dir = tmp_path / "clusters" - (clusters_dir / "PbI2").mkdir(parents=True) - (clusters_dir / "PbI2" / "frame_0001.xyz").write_text( + _settings, artifact_paths = _seed_saved_distribution_from_root(project_dir) + + representative_dir = build_rmcsetup_paths( + project_dir + ).representative_partial_solvent_dir + representative_dir.mkdir(parents=True, exist_ok=True) + representative_path = representative_dir / "selected_rep.xyz" + representative_path.write_text( "\n".join( [ - "3", - "frame_0001", + "4", + "selected representative", "Pb 0.0 0.0 0.0", "I 2.8 0.0 0.0", "I 0.0 2.8 0.0", + "I 0.0 0.0 2.8", ] ) + "\n", encoding="utf-8", ) - 
window.project_setup_tab.clusters_dir_edit.setText( - str(clusters_dir.resolve()) - ) - window.project_setup_tab.set_component_build_mode( - COMPONENT_BUILD_MODE_BORN_APPROXIMATION - ) - monkeypatch.setattr( - window, - "_confirm_default_q_range_for_component_build", - lambda: True, - ) - monkeypatch.setattr( - window, - "_start_project_task", - lambda task_name, *args, **kwargs: start_calls.append(task_name), + selection = DreamBestFitSelection( + run_name="Representative Structure Finder", + run_relative_path="rmcsetup/representative_structures", + label="Representative Structure Finder", + selection_source="representativefinder", + selected_at="2026-05-06T12:00:00", ) - monkeypatch.setattr( - window, - "_open_electron_density_mapping_tool", - lambda **kwargs: launched.append(kwargs), + _write_representative_selection_metadata( + project_dir, + RepresentativeSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + distribution_selection=DistributionSelectionMetadata( + selection_mode="representative_finder", + selection=selection, + run_dir="rmcsetup/representative_structures", + updated_at="2026-05-06T12:00:00", + entries=[], + ), + settings=RepresentativeSelectionSettings(), + updated_at="2026-05-06T12:00:00", + representative_entries=[ + RepresentativeSelectionEntry( + structure="A", + motif="no_motif", + param="w0", + selected_weight=1.0, + cluster_count=1, + source_dir=str(representative_dir), + source_file=str(representative_path), + source_file_name=representative_path.name, + atom_count=4, + element_counts={"Pb": 1, "I": 3}, + source_solvent_mode="partialsolv", + ) + ], + missing_bins=[], + invalid_bins=[], + ), ) - window.build_project_components() + window = SAXSMainWindow(initial_project_dir=project_dir) - assert not start_calls - assert launched == [{"preview_mode": False}] - assert window.current_settings is not None + assert window.project_setup_tab._component_paths is not None + assert 
len(window.project_setup_tab._component_paths) == 1 assert ( - window.current_settings.component_build_mode - == COMPONENT_BUILD_MODE_BORN_APPROXIMATION + window.project_setup_tab.current_prior_json_path() + == artifact_paths.prior_weights_file.resolve() + ) + + window.project_setup_tab.use_representative_structures_checkbox.setChecked( + True ) + + assert window.project_setup_tab.use_representative_structures() + assert window.project_setup_tab._component_paths is None assert ( - "Born Approximation (Average)" - in window.project_setup_tab.summary_box.toPlainText() + window.project_setup_tab.current_prior_json_path() + == artifact_paths.prior_weights_file.resolve() + ) + assert all( + str(row.get("source_kind", "")) == "cluster_dir" + for row in window.project_setup_tab.recognized_cluster_rows() ) window.close() +def test_saved_distribution_details_show_component_source_preference( + qapp, + tmp_path, +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.use_representative_structures = True + manager.save_project(settings) + _settings, artifact_paths = _seed_saved_distribution_from_root(project_dir) + + window = SAXSMainWindow(initial_project_dir=project_dir) + details = window.project_setup_tab._distribution_details[ + artifact_paths.distribution_id + ] + + assert "Component source preference: Representative Structures" in details + assert "built from Representative Structures" in details + window.close() + + +def test_fft_born_close_cancels_running_calculation_without_blocking(qapp): + del qapp + window = FFTBornApproximationMainWindow(preview_mode=True) + calls: list[object] = [] + + class DummyWorker: + def cancel(self): + calls.append("cancel") + + class DummyThread: + def __init__(self): + self.running = True + + def isRunning(self): + return self.running + + def quit(self): + calls.append("quit") + self.running = False + + def wait(self, 
timeout_ms): + calls.append(("wait", timeout_ms)) + return True + + window._compute_worker = DummyWorker() + window._compute_thread = DummyThread() + + assert window.close() is True + assert calls == ["cancel", "quit", ("wait", 1000)] + + def test_saved_distributions_can_coexist_and_load_by_component_build_mode( qapp, tmp_path, @@ -12756,6 +15653,52 @@ def test_save_project_state_ignores_tiny_q_range_edge_mismatch( window.close() +def test_save_project_state_allows_legacy_3d_fft_one_step_endpoint_gap( + qapp, tmp_path, monkeypatch +): + del qapp + project_dir, paths = _build_minimal_saxs_project(tmp_path) + q_values = 0.0101 + 0.01 * np.arange(118, dtype=float) + _write_component_file( + paths.scattering_components_dir / "A_no_motif.txt", + q_values, + np.linspace(10.0, 17.0, q_values.size), + ) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + settings.component_build_mode = ( + COMPONENT_BUILD_MODE_BORN_APPROXIMATION_3D_FFT + ) + settings.model_only_mode = True + settings.q_min = 0.0101 + settings.q_max = 1.19 + manager.save_project(settings) + window = SAXSMainWindow(initial_project_dir=project_dir) + captured: dict[str, str] = {} + + monkeypatch.setattr( + window, + "_show_error", + lambda title, message: captured.update( + { + "title": title, + "message": message, + } + ), + ) + + window.project_setup_tab.qmin_edit.setText("0.0101") + window.project_setup_tab.qmax_edit.setText("1.19") + window.save_project_state() + + assert captured == {} + assert window.prefit_workflow is not None + evaluation_q_values = window.prefit_workflow.evaluate().q_values + assert evaluation_q_values[0] == pytest.approx(q_values[0]) + assert evaluation_q_values[-1] == pytest.approx(q_values[-1]) + window.close() + + def test_project_setup_preview_plots_solvent_data_in_green(qapp, tmp_path): del qapp experimental_path = tmp_path / "exp_preview.txt" @@ -12884,7 +15827,7 @@ def test_project_setup_component_overlay_uses_secondary_y_axis(qapp, tmp_path): 
experimental_axis.get_title() == "Experimental Data and SAXS Components" ) - assert experimental_axis.get_xlabel() == "q (Å⁻¹)" + assert experimental_axis.get_xlabel() == Q_A_INVERSE_LABEL assert ( experimental_axis.get_ylabel() == "Experimental Intensity (arb. units)" ) @@ -12895,6 +15838,89 @@ def test_project_setup_component_overlay_uses_secondary_y_axis(qapp, tmp_path): ) +def test_project_setup_component_plot_editor_updates_dual_axis_plot( + qapp, tmp_path +): + q_values = np.asarray([0.05, 0.08, 0.12, 0.18], dtype=float) + data_path = tmp_path / "exp_overlay_editor.txt" + np.savetxt( + data_path, + np.column_stack( + [ + q_values, + np.asarray([100.0, 80.0, 55.0, 30.0], dtype=float), + ] + ), + ) + component_path = tmp_path / "A_no_motif.txt" + _write_component_file( + component_path, + q_values, + np.asarray([4.8, 4.6, 4.4, 4.2], dtype=float), + ) + + summary = load_experimental_data_file(data_path) + tab = ProjectSetupTab() + tab.set_project_selected(True) + tab._apply_experimental_file(data_path, summary) + tab.draw_component_plot([component_path]) + + tab.open_component_plot_editor() + qapp.processEvents() + + controls = tab._component_plot_editor_controls + assert controls is not None + assert not controls.secondary_y_label_edit.isHidden() + assert not controls.primary_axis_label_font_spin.isHidden() + assert not controls.primary_tick_label_font_spin.isHidden() + assert not controls.secondary_axis_label_font_spin.isHidden() + assert not controls.secondary_tick_label_font_spin.isHidden() + assert controls.residual_y_label_edit.isHidden() + + controls.title_edit.setText("Custom SAXS Overlay") + controls.primary_axis_label_font_spin.setValue(18.5) + controls.primary_tick_label_font_spin.setValue(14.5) + qapp.processEvents() + + experimental_axis, component_axis = tab.component_figure.axes + assert controls.secondary_axis_label_font_spin.value() == pytest.approx( + 11.0 + ) + assert controls.secondary_tick_label_font_spin.value() == pytest.approx( + 9.0 + ) 
+ assert experimental_axis.yaxis.label.get_fontsize() == pytest.approx(18.5) + experimental_ticklabels = experimental_axis.get_yticklabels() + assert experimental_ticklabels + assert experimental_ticklabels[0].get_fontsize() == pytest.approx(14.5) + component_ticklabels = component_axis.get_yticklabels() + assert component_ticklabels + assert component_ticklabels[0].get_fontsize() == pytest.approx(9.0) + + controls.secondary_axis_label_font_spin.setValue(17.5) + controls.secondary_tick_label_font_spin.setValue(13.5) + qapp.processEvents() + + label_item = controls.label_table.item(1, 2) + assert label_item is not None + label_item.setText("Component Average") + qapp.processEvents() + + experimental_axis, component_axis = tab.component_figure.axes + legend = experimental_axis.get_legend() + assert legend is not None + legend_labels = [text.get_text() for text in legend.get_texts()] + + assert experimental_axis.get_title() == "Custom SAXS Overlay" + assert "Component Average" in legend_labels + assert experimental_axis.yaxis.label.get_fontsize() == pytest.approx(18.5) + assert component_axis.get_ylabel() == "Model Intensity (arb. 
units)" + assert component_axis.yaxis.label.get_fontsize() == pytest.approx(17.5) + component_ticklabels = component_axis.get_yticklabels() + assert component_ticklabels + assert component_ticklabels[0].get_fontsize() == pytest.approx(13.5) + + def test_component_legend_toggle_hides_and_shows_legend(qapp, tmp_path): del qapp q_values = np.asarray([0.05, 0.08, 0.12, 0.18], dtype=float) @@ -13162,6 +16188,71 @@ def test_component_trace_color_scheme_persists_with_project_state( ) +def test_project_setup_plot_view_state_persists_with_project_state( + qapp, tmp_path +): + del qapp + project_dir, _paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + window = SAXSMainWindow(initial_project_dir=project_dir) + + window.project_setup_tab.component_log_x_checkbox.setChecked(False) + window.project_setup_tab.component_log_y_checkbox.setChecked(False) + window.project_setup_tab.component_legend_toggle_button.setChecked(False) + component_axes = window.project_setup_tab.component_figure.axes + assert len(component_axes) == 2 + component_axes[0].set_xlim(0.08, 0.18) + component_axes[0].set_ylim(0.045, 0.06) + component_axes[1].set_ylim(9.0, 15.0) + + prior_axes = window.project_setup_tab.prior_figure.axes + assert len(prior_axes) == 1 + prior_axes[0].set_xlim(-0.05, 0.75) + prior_axes[0].set_ylim(0.0, 120.0) + + window.save_project_state() + + saved_settings = manager.load_project(project_dir) + assert saved_settings.component_plot_state["axes"][0]["xlim"] == ( + pytest.approx([0.08, 0.18]) + ) + assert saved_settings.component_plot_state["axes"][1]["ylim"] == ( + pytest.approx([9.0, 15.0]) + ) + assert saved_settings.prior_plot_state["axes"][0]["xlim"] == ( + pytest.approx([-0.05, 0.75]) + ) + + reloaded_window = SAXSMainWindow(initial_project_dir=project_dir) + reloaded_component_axes = ( + reloaded_window.project_setup_tab.component_figure.axes + ) + reloaded_prior_axes = reloaded_window.project_setup_tab.prior_figure.axes + + assert ( + not 
reloaded_window.project_setup_tab.component_log_x_checkbox.isChecked() + ) + assert ( + not reloaded_window.project_setup_tab.component_log_y_checkbox.isChecked() + ) + assert ( + reloaded_window.project_setup_tab.component_legend_toggle_button.isChecked() + is False + ) + assert reloaded_component_axes[0].get_xscale() == "linear" + assert reloaded_component_axes[0].get_yscale() == "linear" + assert reloaded_component_axes[0].get_xlim() == pytest.approx((0.08, 0.18)) + assert reloaded_component_axes[0].get_ylim() == pytest.approx( + (0.045, 0.06) + ) + assert reloaded_component_axes[1].get_ylim() == pytest.approx((9.0, 15.0)) + assert reloaded_prior_axes[0].get_xlim() == pytest.approx((-0.05, 0.75)) + assert reloaded_prior_axes[0].get_ylim() == pytest.approx((0.0, 120.0)) + + window.close() + reloaded_window.close() + + def test_prefit_plot_shows_solvent_contribution_and_legend_pick_toggles_model( qapp, tmp_path ): @@ -13384,6 +16475,76 @@ def test_prefit_plot_trace_toggles_control_visible_series(qapp, tmp_path): assert "Experimental" not in line_labels +def test_prefit_plot_editor_updates_residual_plot_fields(qapp): + q_values = np.asarray([0.05, 0.08, 0.12, 0.18], dtype=float) + experimental = np.asarray([10.0, 8.0, 6.0, 4.0], dtype=float) + model = np.asarray([9.5, 7.9, 6.2, 4.3], dtype=float) + residuals = model - experimental + structure_factor = np.asarray([1.1, 1.0, 0.95, 0.9], dtype=float) + + tab = PrefitTab() + tab.show_structure_factor_trace_checkbox.setChecked(True) + evaluation = PrefitEvaluation( + q_values=q_values, + experimental_intensities=experimental, + model_intensities=model, + residuals=residuals, + structure_factor_trace=structure_factor, + ) + tab.plot_evaluation(evaluation) + tab.open_plot_editor() + qapp.processEvents() + + controls = tab._plot_editor_controls + assert controls is not None + assert not controls.secondary_y_label_edit.isHidden() + assert not controls.primary_axis_label_font_spin.isHidden() + assert not 
controls.primary_tick_label_font_spin.isHidden() + assert not controls.secondary_axis_label_font_spin.isHidden() + assert not controls.secondary_tick_label_font_spin.isHidden() + assert not controls.residual_y_label_edit.isHidden() + assert not controls.show_annotation_checkbox.isHidden() + + controls.residual_y_label_edit.setText("Model - Data") + controls.primary_axis_label_font_spin.setValue(15.0) + controls.primary_tick_label_font_spin.setValue(11.0) + qapp.processEvents() + + top_axis, bottom_axis, structure_axis = tab.figure.axes + assert controls.secondary_axis_label_font_spin.value() == pytest.approx( + 11.0 + ) + assert controls.secondary_tick_label_font_spin.value() == pytest.approx( + 9.0 + ) + assert top_axis.yaxis.label.get_fontsize() == pytest.approx(15.0) + top_ticklabels = top_axis.get_yticklabels() + assert top_ticklabels + assert top_ticklabels[0].get_fontsize() == pytest.approx(11.0) + bottom_ticklabels = bottom_axis.get_yticklabels() + assert bottom_ticklabels + assert bottom_ticklabels[0].get_fontsize() == pytest.approx(11.0) + structure_ticklabels = structure_axis.get_yticklabels() + assert structure_ticklabels + assert structure_ticklabels[0].get_fontsize() == pytest.approx(9.0) + + controls.secondary_axis_label_font_spin.setValue(16.0) + controls.secondary_tick_label_font_spin.setValue(12.0) + controls.show_annotation_checkbox.setChecked(False) + qapp.processEvents() + + top_axis, bottom_axis, structure_axis = tab.figure.axes + + assert bottom_axis.get_ylabel() == "Model - Data" + assert top_axis.yaxis.label.get_fontsize() == pytest.approx(15.0) + assert structure_axis.get_ylabel() == "S(q)" + assert structure_axis.yaxis.label.get_fontsize() == pytest.approx(16.0) + structure_ticklabels = structure_axis.get_yticklabels() + assert structure_ticklabels + assert structure_ticklabels[0].get_fontsize() == pytest.approx(12.0) + assert not top_axis.texts + + def test_prefit_field_interaction_warns_before_components_are_built( qapp, tmp_path, 
monkeypatch ): @@ -14656,8 +17817,20 @@ def fake_build_profiles( manager.save_project(observed_settings) window = SAXSMainWindow(initial_project_dir=project_dir) + observed_label = project_module.distribution_label_for_settings( + observed_settings + ) + excluded_label = project_module.distribution_label_for_settings( + excluded_settings + ) assert window.project_setup_tab.computed_distribution_combo.count() == 2 + assert "DREAM fits: 1" in ( + window.project_setup_tab.computed_distribution_combo.currentText() + ) + assert window.project_setup_tab.active_distribution_text() == ( + observed_label + ) assert window.project_setup_tab.current_prior_json_path() == ( observed_artifacts.prior_weights_file ) @@ -14675,12 +17848,23 @@ def fake_build_profiles( ) ) assert target_index >= 0 + assert "DREAM fits: 1" in ( + window.project_setup_tab.computed_distribution_combo.itemText( + target_index + ) + ) window.project_setup_tab.computed_distribution_combo.setCurrentIndex( target_index ) + assert window.project_setup_tab.active_distribution_text() == ( + observed_label + ) window.project_setup_tab.load_distribution_button.click() QApplication.processEvents() + assert window.project_setup_tab.active_distribution_text() == ( + excluded_label + ) assert window.project_setup_tab.exclude_elements() == ["H"] assert window.project_setup_tab.current_prior_json_path() == ( excluded_artifacts.prior_weights_file @@ -14822,6 +18006,78 @@ def write_legacy_distribution( ) +def test_project_setup_saved_distribution_dropdown_shows_build_mode_for_legacy_metadata( + qapp, + tmp_path, +): + del qapp + project_dir, paths = _build_minimal_saxs_project(tmp_path) + manager = SAXSProjectManager() + settings = manager.load_project(project_dir) + legacy_distribution_id = project_module._distribution_id_for_settings( + settings, + include_template=True, + include_build_mode=False, + ) + distribution_dir = ( + build_project_paths(project_dir).saved_distributions_dir + / legacy_distribution_id + ) + 
component_dir = distribution_dir / "scattering_components" + component_dir.mkdir(parents=True, exist_ok=True) + shutil.copytree( + paths.scattering_components_dir, + component_dir, + dirs_exist_ok=True, + ) + shutil.copy2( + paths.project_dir / "md_saxs_map.json", + distribution_dir / "md_saxs_map.json", + ) + shutil.copy2( + paths.project_dir / "md_prior_weights.json", + distribution_dir / "md_prior_weights.json", + ) + (distribution_dir / "distribution.json").write_text( + json.dumps( + { + "schema_version": 1, + "distribution_id": legacy_distribution_id, + "label": ( + "Observed Only | Template: " + f"{settings.selected_model_template} | Excluded: None | " + "q-range: default | Grid: experimental grid" + ), + "template_name": settings.selected_model_template, + "use_predicted_structure_weights": False, + "exclude_elements": [], + "clusters_dir": None, + "q_min": None, + "q_max": None, + "use_experimental_grid": True, + "q_points": None, + "component_artifacts_ready": True, + "prior_artifacts_ready": True, + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + + window = SAXSMainWindow(initial_project_dir=project_dir) + + assert window.project_setup_tab.computed_distribution_combo.count() == 1 + assert "Build: No Contrast (Debye)" in ( + window.project_setup_tab.computed_distribution_combo.itemText(0) + ) + assert "Build: No Contrast (Debye)" in ( + window.project_setup_tab.active_distribution_text() + ) + + window.close() + + def test_project_setup_loads_saved_distribution_with_cropped_q_range( qapp, tmp_path, From c6f4152b79460c28e0bc4609d3210523676368c0 Mon Sep 17 00:00:00 2001 From: kewh5868 Date: Thu, 7 May 2026 16:50:10 -0600 Subject: [PATCH 6/7] docs: document expanded SAXS and fullrmc workflows Refresh the docs landing pages, install/setup guides, and user-guide navigation for representative structures, 1D/3D Born workflows, Packmol Docker linking, and updated SAXS/fullrmc terminology. 
--- README.md | 117 ++++---- README.rst | 163 ++++------- docs/api/overview.md | 20 +- docs/development/contributing.md | 24 +- docs/development/saxs-contrast-mode.md | 2 +- docs/getting-started/installation.md | 183 ++++++------- docs/getting-started/project-setup.md | 94 +++++-- docs/getting-started/quickstart.md | 183 ++++++++----- docs/hooks.py | 47 ++++ docs/index.md | 60 ++-- docs/tutorials/example-workflow.md | 32 ++- docs/tutorials/md-to-saxs-pipeline.md | 33 ++- docs/user-guide/blender-structure-renderer.md | 22 +- docs/user-guide/cluster-dynamics-ml.md | 3 +- docs/user-guide/cluster-extraction.md | 9 +- docs/user-guide/debye-waller-analysis.md | 12 +- docs/user-guide/electron-density-mapping.md | 31 ++- docs/user-guide/fft-born-approximation.md | 246 +++++++++++++++++ docs/user-guide/fullrmc-packmol-docker.md | 258 ++++++++++++++++++ docs/user-guide/gui-overview.md | 115 ++++++-- docs/user-guide/preloaded-saxs-models.md | 129 ++++++--- docs/user-guide/project-configuration.md | 3 +- docs/user-guide/pydream-workflow.md | 7 + .../representative-structure-cli.md | 103 +++++++ docs/user-guide/saxs-prefit.md | 75 +++-- docs/user-guide/template-system.md | 13 +- docs/user-guide/xray-toolkit.md | 5 +- docs/user-guide/xyz2pdb-conversion.md | 12 +- mkdocs.yml | 16 +- 29 files changed, 1454 insertions(+), 563 deletions(-) create mode 100644 docs/hooks.py create mode 100644 docs/user-guide/fft-born-approximation.md create mode 100644 docs/user-guide/fullrmc-packmol-docker.md create mode 100644 docs/user-guide/representative-structure-cli.md diff --git a/README.md b/README.md index 3ea833b..5789376 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ # SAXSShell SAXSShell is a Python toolkit for simulation-driven scattering workflows. 
It -combines Qt applications, command-line tools, and reusable Python workflows for: +combines Qt applications, supporting tools, and reusable Python workflows for: - trajectory inspection and frame export - XYZ to PDB conversion using reference molecules @@ -19,10 +19,13 @@ Project documentation is published at: - https://kewh5868.github.io/SAXSShell/ -The documentation is organized by workflow rather than by source file, so the -best starting points are: +The documentation is organized by workflow rather than by source file. First +learn how to process your molecular dynamics trajectory into frames and +clusters, then create a project folder for the SAXSShell session: - [Getting Started](https://kewh5868.github.io/SAXSShell/getting-started/installation/) +- [MD Extraction and Cluster Preparation](https://kewh5868.github.io/SAXSShell/user-guide/cluster-extraction/) +- [Quickstart](https://kewh5868.github.io/SAXSShell/getting-started/quickstart/) - [XYZ to PDB Conversion](https://kewh5868.github.io/SAXSShell/user-guide/xyz2pdb-conversion/) - [SAXS Prefit](https://kewh5868.github.io/SAXSShell/user-guide/saxs-prefit/) - [pyDREAM Workflow](https://kewh5868.github.io/SAXSShell/user-guide/pydream-workflow/) @@ -30,31 +33,38 @@ best starting points are: ## Installation -Use Python 3.12 for the smoothest experience with the current Qt stack. - -If you want an isolated environment first: +SAXSShell is not pip-installable yet. Run it from a source checkout with the +repository conda environment file. 
```bash -conda create -n saxshell-py312 python=3.12 +git clone https://github.com/kewh5868/SAXSShell.git +cd SAXSShell +conda env create -f requirements/saxshell-py312.yml ``` +If the `saxshell-py312` environment already exists, update it from the same +file: + ```bash -conda run --no-capture-output -n saxshell-py312 python -m pip install saxshell +conda env update -n saxshell-py312 -f requirements/saxshell-py312.yml --prune ``` -For editable local development: +Launch the main SAXSShell application from the repository root: ```bash -conda run --no-capture-output -n saxshell-py312 python -m pip install -e . +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ``` -You can also launch the code directly from a source checkout without installing -entry points: +## First Project -```bash -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs --help -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ui -``` +Start by preparing the simulation data that the SAXS project will consume: + +1. Inspect the MD trajectory and export usable frames with `mdtrajectory`. +2. Convert frames with `xyz2pdb` only if residue-aware PDB files are needed. +3. Extract stoichiometry-sorted clusters with `clusters`. +4. Create a dedicated project folder for the SAXSShell session. +5. Open the SAXSShell application and choose that project folder in + **Project Setup** before building SAXS components. ## Docs Local Preview @@ -62,44 +72,45 @@ Install the pinned docs dependencies and start the local preview server from the repository root: ```bash -python -m pip install -r requirements/docs.txt -mkdocs serve +conda run --no-capture-output -n saxshell-py312 python -m pip install -r requirements/docs.txt +conda run --no-capture-output -n saxshell-py312 mkdocs serve ``` Then open `http://127.0.0.1:8000/`. 
-## CLI Entry Points - -The installed umbrella entry point is: - -```bash -conda run --no-capture-output -n saxshell-py312 saxshell --help -conda run --no-capture-output -n saxshell-py312 saxshell saxs --help -``` - -Standalone tools that install directly include: - -- `bondanalysis` -- `blenderxyz` -- `clusterdynamics` -- `clusterdynamicsml` -- `clusters` -- `mdtrajectory` -- `pdfsetup` -- `saxshell` -- `xyz2pdb` - -The SAXS and fullrmc interfaces are currently reached through the umbrella -command: - -```bash -saxshell saxs ui -saxshell fullrmc ui /path/to/project -``` - -From a source checkout, the equivalent module launches are: - -```bash -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ui -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.fullrmc ui /path/to/project -``` +## Standalone Tools + +These supporting tools can be used independently from the main SAXSShell +window, or opened from the main UI when a project-backed workflow is needed: + +- `mdtrajectory`: inspect MD trajectories, review optional CP2K energy data, + choose an equilibration cutoff, and export selected frames. +- `xyz2pdb`: convert extracted XYZ frames into residue-aware PDB files with + reference molecule definitions. +- `clusters`: extract stoichiometry-sorted cluster folders from exported XYZ + or PDB frames. +- `bondanalysis`: measure bond-pair and angle distributions from cluster + folders. +- `clusterdynamics`: build time-binned cluster population heatmaps and lifetime + summaries. +- `clusterdynamicsml`: extend observed cluster dynamics with predicted larger + structures and model-comparison outputs. +- `pdfsetup`: run Debyer-backed trajectory-averaged PDF and partial-PDF + calculations. +- `blenderxyz`: render publication-style structure images with Blender. +- `representativefinder`: select representative structures from project-backed + stoichiometry folders. 
+- `structureviewer`: inspect individual structure files in the SAXSShell + structure viewer. + +## External Applications + +The conda environment file installs the Python stack, but several optional +SAXSShell applications call external software that must be installed separately: + +| External software | Required by | Install / docs | +| ----------------- | ----------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| Debyer | `pdfsetup` PDF and partial-PDF calculations | [Debyer docs](https://debyer.readthedocs.io/en/latest/) and [Debyer GitHub](https://github.com/wojdyr/debyer) | +| Blender | `blenderxyz` structure rendering | [Blender download](https://www.blender.org/download/) and [Blender installation manual](https://docs.blender.org/manual/en/latest/getting_started/installing/index.html) | +| Packmol | `fullrmc` Packmol setup and solvent packing workflows | [Packmol GitHub](https://github.com/m3g/packmol) and [Packmol user guide](https://m3g.github.io/packmol/) | +| Docker | `fullrmc` Packmol Docker link workflow | [Get Docker](https://docs.docker.com/get-started/get-docker/) | diff --git a/README.rst b/README.rst index 8e2a3c1..356596b 100644 --- a/README.rst +++ b/README.rst @@ -43,10 +43,9 @@ molecular dynamics-derived liquid structures. scriptable workflows that connect atomistic simulation output to scattering observables and structural interpretation. -The project is being developed as a Python library with command-line -entry points and documentation for simulation-driven scattering -analysis, especially for liquid-state structure and solvation-focused -studies. +The project is being developed as a Python library with Qt applications, +supporting tools, and documentation for simulation-driven scattering analysis, +especially for liquid-state structure and solvation-focused studies. 
For more information about the saxshell library, please consult our `online documentation `_. @@ -60,51 +59,35 @@ If you use saxshell in a scientific publication, we would like you to cite this Installation ------------ -The preferred method is to use `Miniconda Python -`_ -and install from the "conda-forge" channel of Conda packages. +SAXSShell is not pip-installable yet. The current user-facing path is to clone +the repository and create the conda environment from the checked-in +``requirements/saxshell-py312.yml`` file. -To add "conda-forge" to the conda channels, run the following in a terminal. :: +From a terminal, run :: - conda config --add channels conda-forge + git clone https://github.com/kewh5868/SAXSShell.git + cd SAXSShell + conda env create -f requirements/saxshell-py312.yml -We want to install our packages in a suitable conda environment. -The following creates and activates a new environment named ``saxshell_env`` :: +If the environment already exists, update it with :: - conda create -n saxshell_env saxshell - conda activate saxshell_env + conda env update -n saxshell-py312 -f requirements/saxshell-py312.yml --prune -The output should print the latest version displayed on the badges above. +Launch the main SAXSShell application from the repository root with :: -If the above does not work, you can use ``pip`` to download and install the latest release from -`Python Package Index `_. -To install using ``pip`` into your ``saxshell_env`` environment, type :: + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs - pip install saxshell +You can also verify that the source checkout imports inside the conda +environment with :: -If you prefer to install from sources, after installing the dependencies, obtain the source archive from -`GitHub `_. Once installed, ``cd`` into your ``saxshell`` directory -and run the following :: - - pip install . - -This package also provides command-line utilities. 
To check the software has been installed correctly, type :: - - saxshell --version - -You can also type the following command to verify the installation. :: - - python -c "import saxshell; print(saxshell.__version__)" - - -To view the basic usage and available commands, type :: - - saxshell -h + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -c "import saxshell; print(saxshell.__version__)" Getting Started --------------- You may consult our `online documentation `_ for tutorials and API references. +Start by learning how to process your MD trajectory into frames and clusters, +then create a dedicated project folder for the SAXSShell session. Workflow Process Tree --------------------- @@ -186,15 +169,12 @@ based steady-state cutoff suggestions. The same ``mdtrajectory`` workflow can be used in three ways: 1. As a Qt desktop application for interactive use. -2. As a terminal command for scripted or batch workflows. +2. Through the source-checkout module launch for scripted or batch workflows. 3. As a Python class in notebooks and other Python scripts. -To launch the Qt application, use one of the following commands :: +To launch the Qt application from the repository root :: - mdtrajectory - mdtrajectory ui - saxshell mdtrajectory - python -m saxshell.mdtrajectory + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory Terminal Use Cases ------------------ @@ -205,11 +185,11 @@ preprocessing inside a larger shell workflow. 
Inspect a trajectory and optional energy file :: - mdtrajectory inspect traj.xyz --energy-file traj.ener + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory inspect traj.xyz --energy-file traj.ener Suggest a steady-state cutoff from a CP2K energy profile :: - mdtrajectory suggest-cutoff traj.xyz \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory suggest-cutoff traj.xyz \ --energy-file traj.ener \ --temp-target-k 300 \ --temp-tol-k 1.0 \ @@ -217,7 +197,7 @@ Suggest a steady-state cutoff from a CP2K energy profile :: Preview the selected export range without writing files :: - mdtrajectory preview traj.xyz \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory preview traj.xyz \ --use-cutoff \ --cutoff-fs 50 \ --start 0 \ @@ -225,16 +205,12 @@ Preview the selected export range without writing files :: Export frames directly from the terminal :: - mdtrajectory export traj.xyz \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz \ --energy-file traj.ener \ --use-suggested-cutoff \ --temp-target-k 300 \ --window 3 -The ``saxshell`` command also forwards to the same workflow :: - - saxshell mdtrajectory inspect traj.xyz - Python and Notebook Use ----------------------- @@ -274,15 +250,12 @@ files into residue-labeled PDB files. It uses: The same ``xyz2pdb`` workflow can be used in three ways: 1. As a Qt desktop application for interactive use. -2. As a terminal command for scripted or batch workflows. +2. Through the source-checkout module launch for scripted or batch workflows. 3. As a Python class in notebooks and other Python scripts. 
-To launch the Qt application, use one of the following commands :: +To launch the Qt application from the repository root :: - xyz2pdb - xyz2pdb ui - saxshell xyz2pdb - python -m saxshell.xyz2pdb + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb Reference Library ----------------- @@ -296,11 +269,11 @@ add it to the current library folder. List the available reference molecules :: - xyz2pdb references list + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb references list Create a new reference molecule from the terminal :: - xyz2pdb references add ref.xyz --name dmf --residue-name DMF + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb references add ref.xyz --name dmf --residue-name DMF Residue Assignment JSON ----------------------- @@ -342,11 +315,11 @@ Suppose you have: Launch the UI :: - xyz2pdb + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb Or prefill the main inputs from the terminal :: - xyz2pdb ui splitxyz \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb ui splitxyz \ --config dmf_assignments.json \ --library-dir src/saxshell/xyz2pdb/reference_library @@ -372,24 +345,24 @@ terminal. 
List the available references first if you want to confirm that the expected PDB is in the library :: - xyz2pdb references list \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb references list \ --library-dir src/saxshell/xyz2pdb/reference_library Inspect the selected XYZ input and JSON config :: - xyz2pdb inspect splitxyz \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb inspect splitxyz \ --config dmf_assignments.json \ --library-dir src/saxshell/xyz2pdb/reference_library Preview the first frame before writing files :: - xyz2pdb preview splitxyz \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb preview splitxyz \ --config dmf_assignments.json \ --library-dir src/saxshell/xyz2pdb/reference_library Export the converted PDB files :: - xyz2pdb export splitxyz \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb export splitxyz \ --config dmf_assignments.json \ --library-dir src/saxshell/xyz2pdb/reference_library \ --output-dir xyz2pdb_splitxyz @@ -399,26 +372,22 @@ Terminal Use Cases Inspect the selected XYZ input, reference library, and config :: - xyz2pdb inspect splitxyz \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb inspect splitxyz \ --config assignments.json \ --library-dir references Preview the first-frame residue assignments and suggested output folder :: - xyz2pdb preview splitxyz \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb preview splitxyz \ --config assignments.json \ --library-dir references Export PDB files from one XYZ or a folder of XYZ files :: - xyz2pdb export splitxyz \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb export splitxyz \ --config assignments.json \ --library-dir references -The ``saxshell`` command also forwards to the same 
workflow :: - - saxshell xyz2pdb inspect splitxyz --config assignments.json - Python and Class-Based Use Case ------------------------------- @@ -495,15 +464,12 @@ same folder. The same ``clusters`` workflow can be used in three ways: 1. As a Qt desktop application for interactive use. -2. As a terminal command for scripted or batch workflows. +2. Through the source-checkout module launch for scripted or batch workflows. 3. As a Python class in notebooks and other Python scripts. -To launch the Qt application, use one of the following commands :: +To launch the Qt application from the repository root :: - clusters - clusters ui - saxshell cluster - python -m saxshell.clusters + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster Cluster UI Use -------------- @@ -514,11 +480,11 @@ enable periodic boundary conditions before exporting clusters. Launch the UI :: - clusters + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster Or open the UI with a frames folder preloaded :: - clusters ui splitxyz0001 + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster ui splitxyz0001 Inside the window: @@ -540,11 +506,11 @@ operations directly from the terminal. 
Inspect an extracted frames folder :: - clusters inspect splitxyz0001 + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster inspect splitxyz0001 Preview a run using explicit cluster rules and periodic wrapping :: - clusters preview splitxyz0001 \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster preview splitxyz0001 \ --use-pbc \ --search-mode kdtree \ --node Pb \ @@ -555,7 +521,7 @@ Preview a run using explicit cluster rules and periodic wrapping :: Export clusters without opening the UI :: - clusters export splitxyz0001 \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster export splitxyz0001 \ --use-pbc \ --search-mode kdtree \ --node Pb \ @@ -621,15 +587,12 @@ folders produced by ``clusters``. A typical workflow is: The same ``bondanalysis`` workflow can be used in three ways: 1. As a Qt desktop application for interactive use. -2. As a terminal command for scripted or batch workflows. +2. Through the source-checkout module launch for scripted or batch workflows. 3. As a Python workflow in notebooks and other Python scripts. -To launch the Qt application, use one of the following commands :: +To launch the Qt application from the repository root :: - bondanalysis - bondanalysis ui - saxshell bondanalysis - python -m saxshell.bondanalysis + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.bondanalysis Bondanalysis UI Use ------------------- @@ -641,11 +604,11 @@ updated. 
Launch the UI :: - bondanalysis + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.bondanalysis Or open the UI with a clusters folder preloaded :: - bondanalysis ui clusters_splitxyz0001 + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.bondanalysis ui clusters_splitxyz0001 Inside the window: @@ -681,11 +644,11 @@ Bondanalysis Terminal Use Inspect a clusters directory before running analysis :: - bondanalysis inspect clusters_splitxyz0001 + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.bondanalysis inspect clusters_splitxyz0001 Run bond-pair and angle analysis headlessly on every cluster type :: - bondanalysis run clusters_splitxyz0001 \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.bondanalysis run clusters_splitxyz0001 \ --bond-pair Pb:I:3.50 \ --bond-pair Pb:O:3.20 \ --angle-triplet Pb:I:I:3.50:3.50 @@ -693,17 +656,13 @@ Run bond-pair and angle analysis headlessly on every cluster type :: Restrict the run to selected stoichiometry folders and choose an explicit output directory :: - bondanalysis run clusters_splitxyz0001 \ + PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.bondanalysis run clusters_splitxyz0001 \ --output-dir bondanalysis_clusters_splitxyz0001 \ --cluster-type PbI2 \ --cluster-type Pb2I3 \ --bond-pair Pb:I:3.50 \ --angle-triplet Pb:I:I:3.50:3.50 -The top-level ``saxshell`` command forwards to the same workflow :: - - saxshell bondanalysis inspect clusters_splitxyz0001 - Bondanalysis Python and Notebook Use ------------------------------------ @@ -740,12 +699,10 @@ Support and Contribute If you see a bug or want to request a feature, please `report it as an issue `_ and/or `submit a fix as a PR `_. -Feel free to fork the project and contribute. 
To install saxshell -in a development mode, with its sources being directly used by Python -rather than copied to a package directory, use the following in the root -directory :: - - pip install -e . +Feel free to fork the project and contribute. For the current source-checkout +workflow, create or update the ``saxshell-py312`` conda environment from +``requirements/saxshell-py312.yml`` and run tools with ``PYTHONPATH=src`` from +the repository root. To ensure code quality and to prevent accidental commits into the default branch, please set up the use of our pre-commit hooks. diff --git a/docs/api/overview.md b/docs/api/overview.md index 844dc6e..784099a 100644 --- a/docs/api/overview.md +++ b/docs/api/overview.md @@ -4,7 +4,7 @@ This section is intentionally lightweight. The repository does not yet include a fully automated API reference pipeline, so this page focuses on the workflow classes that are most likely to be imported directly. -## Recommended entry points +## Recommended Python workflow classes ### Trajectory processing @@ -47,22 +47,14 @@ The SAXS stack also includes reusable modules for: These modules are usable from Python, but their interfaces are evolving faster than the main workflow classes above. -## CLI-first note +## Source-checkout launch note -Several parts of the repository are easier to discover from their CLI help than -from their Python surface: +Several tools can also be launched through their Python modules while the +public Python API stabilizes. 
From the repository root, start the main SAXS +application with: ```bash -saxshell --help -saxshell saxs --help -clusters --help -mdtrajectory --help -``` - -From a source checkout, you can inspect the SAXS CLI directly with: - -```bash -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs --help +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ``` ## TODO diff --git a/docs/development/contributing.md b/docs/development/contributing.md index ea15c72..b151c02 100644 --- a/docs/development/contributing.md +++ b/docs/development/contributing.md @@ -2,17 +2,23 @@ ## Local setup -Create a Python 3.12 environment and install the package in editable mode: +Create or update the Python 3.12 conda environment from the repository root: ```bash -python -m pip install -e . +conda env create -f requirements/saxshell-py312.yml +``` + +If the environment already exists: + +```bash +conda env update -n saxshell-py312 -f requirements/saxshell-py312.yml --prune ``` Install pre-commit and enable the hooks: ```bash -python -m pip install pre-commit -pre-commit install +conda run --no-capture-output -n saxshell-py312 python -m pip install pre-commit +conda run --no-capture-output -n saxshell-py312 pre-commit install ``` ## Running tests @@ -20,7 +26,7 @@ pre-commit install Run the full suite: ```bash -pytest -q +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 pytest -q ``` For focused work, prefer the smallest test slice that exercises your change. @@ -30,14 +36,14 @@ For focused work, prefer the smallest test slice that exercises your change. 
Install the pinned docs dependencies and start the local server: ```bash -python -m pip install -r requirements/docs.txt -mkdocs serve +conda run --no-capture-output -n saxshell-py312 python -m pip install -r requirements/docs.txt +conda run --no-capture-output -n saxshell-py312 mkdocs serve ``` Build the site locally the same way CI does: ```bash -mkdocs build --strict +conda run --no-capture-output -n saxshell-py312 mkdocs build --strict ``` ## Formatting and linting @@ -54,7 +60,7 @@ The repository uses pre-commit hooks for: Run them manually if needed: ```bash -pre-commit run --all-files +conda run --no-capture-output -n saxshell-py312 pre-commit run --all-files ``` ## Branch and PR expectations diff --git a/docs/development/saxs-contrast-mode.md b/docs/development/saxs-contrast-mode.md index e319acf..f51f46c 100644 --- a/docs/development/saxs-contrast-mode.md +++ b/docs/development/saxs-contrast-mode.md @@ -6,7 +6,7 @@ workflow. The design goal is simple: keep the contrast workflow fully separated from the legacy no-contrast SAXS builder except at explicit routing and reload seams. -## Main entry points +## Main code paths The feature spans three layers: diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md index a2c04a0..790bca1 100644 --- a/docs/getting-started/installation.md +++ b/docs/getting-started/installation.md @@ -1,142 +1,137 @@ # Installation -## Python version +## Current install path -Use Python 3.12 for the current supported path. +SAXSShell is not pip-installable yet. Run it from a source checkout and create +the pinned conda environment from the repository environment file. -The repository CI and local conda guidance are pinned around Python 3.12 -because the current PySide6 stack in this repo is not ready for Python 3.14. +Use Python 3.12 through the provided `saxshell-py312` environment. 
The +repository CI and local guidance are pinned around Python 3.12 because the +current Qt stack is not ready for Python 3.14. -## Create a conda environment - -The current docs examples assume a Python 3.12 environment named -`saxshell-py312`: +## Clone the repository ```bash -conda create -n saxshell-py312 python=3.12 +git clone https://github.com/kewh5868/SAXSShell.git +cd SAXSShell ``` -You can either activate that environment first or keep commands explicit with -`conda run --no-capture-output -n saxshell-py312 ...`. The -`--no-capture-output` flag is useful for Qt applications because terminal logs -and tracebacks stay visible. +Run the rest of the commands from the repository root unless a page says +otherwise. -## Install from PyPI - -```bash -conda run --no-capture-output -n saxshell-py312 python -m pip install saxshell -``` +## Create the conda environment -After installation, confirm the umbrella CLI is available: +Create the environment from the checked-in `.yml` file: ```bash -conda run --no-capture-output -n saxshell-py312 saxshell --help +conda env create -f requirements/saxshell-py312.yml ``` -## Install from source - -Clone the repository and install it in editable mode: +If the environment already exists, update it from the same file: ```bash -git clone https://github.com/kewh5868/SAXSShell.git -cd SAXSShell -conda run --no-capture-output -n saxshell-py312 python -m pip install -e . +conda env update -n saxshell-py312 -f requirements/saxshell-py312.yml --prune ``` -This installs the package and the command-line entry points defined in the -project metadata. +The examples use `conda run --no-capture-output -n saxshell-py312 ...` so Qt +logs and tracebacks remain visible in the terminal. 
-## Run directly from a source checkout +## Launch SAXSShell -If you want to launch the software from the repository root without installing -editable entry points yet, export `PYTHONPATH=src` and run the relevant module -inside the conda environment: +Start the main SAXSShell application from the repository root: ```bash -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxshell --help -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs --help -PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ui +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ``` -Common translations from installed CLI form to source-checkout form are: +The application opens to the main SAXS workflow. Create or select a dedicated +project folder in **Project Setup** after your trajectory-derived frames and +clusters are ready. -- `mdtrajectory ...` -> `PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory ...` -- `clusters ...` -> `PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster ...` -- `blenderxyz ...` -> `PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.toolbox.blender.cli ...` -- `saxshell saxs ...` -> `PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ...` -- `saxshell fullrmc ...` -> `PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.fullrmc ...` +## Recommended starting point -## Optional docs dependencies +Before spending time in Prefit or DREAM, prepare the simulation data that the +SAXS project will consume: -If you want to preview this documentation site locally: +1. Inspect the MD trajectory and export a frame folder with `mdtrajectory`. +2. Convert exported XYZ frames with `xyz2pdb` only if downstream analysis needs + residue-aware PDB files. +3. 
Extract stoichiometry-sorted clusters with `clusters`. +4. Create a dedicated project folder for the SAXSShell session. +5. Launch SAXSShell and choose that project folder in **Project Setup**. -```bash -python -m pip install -r requirements/docs.txt -mkdocs serve -``` +See [MD Extraction and Cluster Preparation](../user-guide/cluster-extraction.md) +for the trajectory-to-clusters path. -## Installed commands +## Standalone tools -The current package exposes these top-level tools: +These supporting tools can be used independently from the main SAXSShell +window, or opened from the main UI when a project-backed workflow is needed: -- `saxshell` -- `mdtrajectory` -- `clusters` -- `blenderxyz` -- `clusterdynamics` -- `clusterdynamicsml` -- `bondanalysis` -- `pdfsetup` -- `xyz2pdb` +| Tool | Short description | +| ---------------------- | ---------------------------------------------------------------------------------------------------------------------------- | +| `mdtrajectory` | Inspects MD trajectories, reads optional CP2K energy files, helps choose equilibration cutoffs, and exports selected frames. | +| `xyz2pdb` | Converts extracted XYZ frames into residue-aware PDB files using reference molecule definitions. | +| `clusters` | Extracts stoichiometry-sorted cluster folders from exported XYZ or PDB frames. | +| `bondanalysis` | Measures bond-pair and angle distributions from cluster folders. | +| `clusterdynamics` | Builds time-binned cluster population heatmaps, energy overlays, and lifetime summaries. | +| `clusterdynamicsml` | Extends observed cluster dynamics with predicted larger structures and model-comparison outputs. | +| `pdfsetup` | Runs Debyer-backed trajectory-averaged PDF and partial-PDF calculations. | +| `blenderxyz` | Creates publication-style structure renders with Blender. | +| `representativefinder` | Selects representative structures from project-backed stoichiometry folders. 
| +| `structureviewer` | Opens individual structure files in the SAXSShell structure viewer. | -The SAXS and fullrmc CLIs are currently reached through the umbrella command -rather than separate installed scripts: +From the source checkout, these tools are reached through their Python modules +inside the `saxshell-py312` environment. For example: ```bash -saxshell saxs --help -saxshell fullrmc --help +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory inspect traj.xyz --energy-file traj.ener ``` -## Installation notes +## External application dependencies -- The SAXS UI and several other tools use PySide6 for the GUI. -- The SAXS workflow also depends on scientific Python packages such as NumPy, - SciPy, and lmfit. -- The SAXS Debye component builder uses `xraydb`. -- The `blenderxyz` application also requires a separate Blender installation: - -- The Blender renderer works best when `blender` is on `PATH`, but you can also - browse to the Blender executable or `.app` bundle from inside the UI. +The conda environment file installs the Python stack. Some optional +SAXSShell applications call external software that must be installed separately. 
-## Debyer installation for PDF calculations +| External software | Required by | Install / docs | +| ----------------- | ----------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| Debyer | `pdfsetup` PDF and partial-PDF calculations | [Debyer docs](https://debyer.readthedocs.io/en/latest/) and [Debyer GitHub](https://github.com/wojdyr/debyer) | +| Blender | `blenderxyz` structure rendering | [Blender download](https://www.blender.org/download/) and [Blender installation manual](https://docs.blender.org/manual/en/latest/getting_started/installing/index.html) | +| Packmol | `fullrmc` Packmol setup and solvent packing workflows | [Packmol GitHub](https://github.com/m3g/packmol) and [Packmol user guide](https://m3g.github.io/packmol/) | +| Docker | `fullrmc` Packmol Docker link workflow | [Get Docker](https://docs.docker.com/get-started/get-docker/) | -The `pdfsetup` application uses -[Debyer](https://debyer.readthedocs.io/en/latest/) as an external backend. That -means SAXSShell does not bundle the Debyer binary itself. You need to install -Debyer separately and make sure the `debyer` executable is available on your -`PATH`. +### Debyer -Useful upstream links: +The `pdfsetup` application launches the `debyer` executable as an external +backend. SAXSShell does not bundle that binary. Install Debyer separately and +make sure `debyer` is available on `PATH` before running trajectory-averaged +PDF calculations. -- Debyer documentation: -- Debyer GitHub repository: +Debyer's upstream repository documents source builds with `autoconf`, +`automake`, `gengetopt`, `./configure`, and `make`. Follow the current upstream +instructions for your platform. -Debyer's official project documentation describes a native build based on its -own C/C++ source tree and autotools-style setup. 
SAXSShell's Debyer integration -does **not** require a Fortran runtime from Debyer itself. If you are installing -Debyer from source, follow the current upstream instructions rather than -assuming a Fortran toolchain is needed. +### Blender -When the PDF application starts, it runs a quick Debyer availability check by: +The `blenderxyz` renderer needs Blender installed separately. It works best +when `blender` is on `PATH`, but the renderer UI can also browse to a Blender +executable or `.app` bundle. -1. locating `debyer` on `PATH` -2. attempting a lightweight `debyer --help` subprocess call +### Packmol and Docker -If that startup check fails, the PDF UI reports that immediately so you can -resolve the Debyer installation or local execution permissions before launching -a long trajectory-average job. +The `fullrmc` Packmol setup workflow needs Packmol when you are preparing packed +coordinate files. The Packmol Docker link workflow additionally needs Docker so +SAXSShell can validate and sync files into a Packmol-ready container. -## TODO +Install Docker only if you plan to use the Packmol container workflow. If you +use Packmol outside Docker, make sure the selected workflow can reach the +`packmol` executable in that environment. + +## Optional docs dependencies -TODO: add a short platform-specific troubleshooting section once the current -conda packaging and GUI runtime guidance are finalized. +If you want to preview this documentation site locally: + +```bash +conda run --no-capture-output -n saxshell-py312 python -m pip install -r requirements/docs.txt +conda run --no-capture-output -n saxshell-py312 mkdocs serve +``` diff --git a/docs/getting-started/project-setup.md b/docs/getting-started/project-setup.md index fefcc82..dd0318e 100644 --- a/docs/getting-started/project-setup.md +++ b/docs/getting-started/project-setup.md @@ -5,6 +5,15 @@ distribution. 
This is the point where you define the Project Setup snapshot, optionally compute Debye-Waller factors, and decide how SAXS components should be built for the active modeling branch. +In plain language, this is where you tell SAXSShell what experimental SAXS data +you want to match, which simulation-derived cluster set should be compared +against that data, and which modeling branch should be saved for later fitting. + +!!! info "Image placeholder" +    Add a screenshot of the full **Project Setup** tab with a loaded project, +    showing the project path, input selectors, computed-distribution controls, +    and preview panels. + ## What lives here The current UI code shows Project Setup as the first tab in the SAXS @@ -16,26 +25,42 @@ application. This is where you typically: - choose a model template - set the q-range, grid behavior, recognized elements, and excluded elements - create or load computed distributions -- optionally compute project-backed Debye-Waller factors +- optionally compute project-backed representative structures or + Debye-Waller factors - build SAXS components with the selected component-build mode - preview component traces and prior histograms before moving to Prefit ## Typical setup order -1. Open the SAXS application with `saxshell saxs ui` or - `PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ui`. -2. Create a new project or load an existing project directory. -3. Select the experimental dataset and the cluster folder you want to model. -4. Choose the template, q-range, grid mode, and excluded elements. -5. Pick the **Component build mode** for the current modeling branch. -6. Click **Create Computed Distribution**. -7. Optionally click **Compute Debye-Waller Factors** if the active clusters are - PDB files and you want saved disorder terms for later workflows. -8. Click **Build SAXS Components**. -9. Review the component preview, cluster table, and prior histogram preview. -10. 
Move to **SAXS Prefit** once the active distribution has the component +1. Process the MD trajectory first: inspect it, export frames, optionally + convert XYZ frames to PDB, and extract the cluster folder you want to model. +2. Create a dedicated project folder for this SAXSShell session. +3. Open the SAXS application from the repository root: + `PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs`. +4. Create a new project or load the project directory you prepared. +5. Select the experimental dataset and the cluster folder you want to model. +6. Choose the template, q-range, grid mode, and excluded elements. +7. Pick the **Component build mode** for the current modeling branch. +8. Click **Create Computed Distribution**. +9. Optionally compute representative structures from the full UI or the beta + CLI setup if later workflows should use representative files instead of + average cluster folders. +10. Optionally click **Compute Debye-Waller Factors (beta)** if the active + clusters are PDB files and you want saved disorder terms for later + workflows. +11. Click **Build SAXS Components**. +12. Review the component preview, cluster table, and prior histogram preview. +13. Move to **SAXS Prefit** once the active distribution has the component traces you want to fit. +!!! info "Image placeholder" +    Add a screenshot focused on the project and input-selection controls used +    in steps 4 through 6. + +!!! info "Image placeholder" +    Add a screenshot focused on the computed-distribution and component-build +    controls used in steps 7 through 11. + ## Computed distributions The Project Setup tab now makes computed distributions explicit. @@ -62,23 +87,44 @@ into the saved distribution identity. Because the component build mode is part of that identity, the same project can hold multiple otherwise-similar distributions at once, for example one built -with `No Contrast (Debye)` and one built with -`Born Approximation (Average)`. 
+with `No Contrast (Debye)`, one built with +`1D Born Approximation (Average)`, and one built with +`3D FFT Born Approximation`. ## Component build modes The **Component build mode** dropdown controls what happens when you click **Build SAXS Components**. -| Mode | What happens | -| -------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| **No Contrast (Debye)** | The main SAXS UI runs the direct Debye component builder and saves the component traces into the active computed distribution. | -| **Contrast (Debye)** | SAXSShell opens the linked **SAXS Contrast Mode** workflow so you can analyze representative structures, compute electron-density terms, and build contrast-aware Debye traces for that computed distribution. | -| **Born Approximation (Average)** | SAXSShell opens the linked **Electron Density Mapping** workflow in computed-distribution mode so you can compute per-stoichiometry electron-density profiles, apply optional solvent subtraction, evaluate Fourier transforms, and then push the resulting Born-approximation components back into the model. | +| Mode | What happens | +| ----------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **No Contrast (Debye)** | The main SAXS UI runs the direct Debye component builder and saves the component traces into the active computed distribution. 
| +| **Contrast (Debye)** | SAXSShell opens the linked **SAXS Contrast Mode** workflow so you can analyze representative structures, compute electron-density terms, and build contrast-aware Debye traces for that computed distribution. | +| **1D Born Approximation (Average)** | SAXSShell opens the linked legacy radial-density workflow in computed-distribution mode so you can compute per-stoichiometry electron-density profiles, apply optional solvent subtraction, evaluate spherical Fourier transforms, and then push the resulting Born-approximation components back into the model. | +| **3D FFT Born Approximation** | SAXSShell opens the separate Cartesian FFT workflow so you can build a 3D electron-density map, optionally apply a constant solvent-density contrast subtraction in real space, compare the q-shell-averaged FFT result against 1D Born and Debye references, and push computed traces back into the linked distribution. | + +## Representative structures + +Representative structures are optional project-backed files that compatible +Debye, Born, FFT, and RMCSetup workflows can use instead of average cluster +folders. Use **Tools > Structure Analysis > Open Representative Structures** for +the full interactive analysis UI, or use **Tools > (beta) > Open Representative +CLI Setup (Beta)** to save `representative_structure_cli_run.json` and run the +same backend from the source checkout: + +```bash +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.representativefinder run /path/to/project +``` ## Debye-Waller factors -**Compute Debye-Waller Factors** is an optional linked step in Project Setup. +**Compute Debye-Waller Factors (beta)** is an optional linked step in Project +Setup. + +!!! warning "Current testing status" +The linked **Compute Debye-Waller Factors (beta)** workflow is currently in +testing and has a known bug. Use it cautiously and verify any saved +Debye-Waller outputs before treating them as reliable downstream inputs. 
Important current behavior: @@ -95,6 +141,10 @@ You do not need Debye-Waller factors to create a computed distribution, but the tool is intended to be run before component building when you plan to reuse those saved disorder terms in later SAXSShell workflows. +!!! info "Image placeholder" +    Add a screenshot of the Debye-Waller readiness indicator and button state +    inside **Project Setup**, including an example tooltip if available. + ## Model and Build section Project Setup also includes **Install Custom Template**. This is for templates @@ -126,6 +176,8 @@ Examples from the current codebase include: - Treat **Create Computed Distribution** as the point where you intentionally branch the project into a specific build configuration. +- Finish the basic project definition in **Project Setup** before you spend + time interpreting Prefit or DREAM behavior. - Finish the Project Setup steps before judging Prefit behavior. Prefit and DREAM both depend on the active computed distribution and its saved component artifacts. diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index d40adbc..7b1eaf6 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -1,106 +1,147 @@ # Quickstart -This quickstart is intentionally practical. It is not the full workflow, but it -gets you from a trajectory or cluster folder to the relevant applications. +This quickstart starts where a new SAXSShell project usually starts: with a +molecular dynamics trajectory that needs to become a reusable SAXS project +folder. -The command examples below assume the package is installed in your active -environment. If you are launching directly from a source checkout, use the -`PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m ...` -pattern from [Installation](installation.md). For example, `saxshell saxs ui` -maps to `python -m saxshell.saxs ui`. 
+In plain language, the goal is to turn a trajectory into exported frames, +clusters, and a dedicated SAXSShell project directory, then compare those +simulation-derived structures against experimental SAXS data. -## 1. Inspect and export frames +Run commands from the repository root after creating the conda environment in +[Installation](installation.md). -If you are starting from a trajectory: +## Prepare the MD trajectory first + +Begin by confirming that the trajectory is readable and exporting the frame set +that downstream tools should use: ```bash -mdtrajectory inspect traj.xyz --energy-file traj.ener -mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory inspect traj.xyz --energy-file traj.ener +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 ``` -This gives you a folder of exported frames that can feed the next stage. - -## 2. Optional XYZ to PDB conversion - -If you need residue-aware PDB frames before clustering: +If residue identity matters for your downstream analysis, convert the exported +XYZ frames before clustering: ```bash -xyz2pdb preview splitxyz --config residue_map.json -xyz2pdb export splitxyz --config residue_map.json +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb export splitxyz --config residue_map.json ``` -Skip this stage if plain XYZ cluster extraction is enough for your system. - -## 3. 
Extract clusters - -Launch the cluster UI: +Extract the cluster folder that the SAXS project will consume: ```bash -clusters +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster inspect splitxyz +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster export splitxyz ``` -Or inspect a frame folder from the terminal: +Create a dedicated project folder for the SAXSShell session. Keep the project +folder separate from raw trajectory output so saved SAXS state, computed +distributions, fit results, and optional project-backed calculations stay +together. ```bash -clusters inspect splitxyz +mkdir -p my_saxshell_project ``` -## 4. Analyze bond and angle distributions +The fastest way to understand the SAXS UI is to treat the first three tabs as a +sequence: -```bash -bondanalysis inspect clusters_splitxyz0001 -bondanalysis run clusters_splitxyz0001 -``` +- **Project Setup** defines the project inputs and creates a saved computed + distribution for one modeling branch. +- **SAXS Prefit** lets you inspect whether the chosen template and built + components produce a sensible model preview. +- **SAXS DREAM Fit** takes the Prefit state and runs Bayesian sampling when you + want a posterior distribution instead of just a single editable preview. -## 5. Launch the SAXS workflow +## Start in Project Setup -Open the SAXS UI: +Launch the SAXS UI: ```bash -saxshell saxs ui +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ``` -Inside the UI, the normal path is: - -1. Create or open a SAXS project. -2. Point the project at your experimental data and cluster folder. -3. In **Project Setup**, choose the template, q-range, grid behavior, excluded - elements, and SAXS component build mode. -4. Click **Create Computed Distribution** to save that Project Setup snapshot - and generate the matching prior-weight inputs. -5. 
Optionally click **Compute Debye-Waller Factors** when the active clusters - folder contains PDB files and you want saved disorder terms for later - workflows. -6. Click **Build SAXS Components**. -7. If the build mode is `No Contrast (Debye)`, the main UI runs the direct - component builder. -8. If the build mode is `Contrast (Debye)`, the linked SAXS Contrast Mode - window opens. -9. If the build mode is `Born Approximation (Average)`, the linked Electron - Density Mapping window opens in computed-distribution mode. -10. Review the model in **SAXS Prefit**. -11. Run **SAXS DREAM Fit** if you need Bayesian refinement. - -A computed distribution is the saved Project Setup configuration for one SAXS -modeling branch. In practice it tracks the active template, component-build -mode, cluster source, q-range and grid choices, excluded elements, and whether -the run uses observed structures only or observed plus predicted structures. - -## 6. If you are starting from an existing project - -You can open a project directly: +### What to do first -```bash -saxshell saxs ui /path/to/project -``` +1. Create a new project directory or open an existing SAXS project. +2. Select the experimental SAXS dataset. +3. Select the cluster folder you want to model. +4. Optionally select the solvent SAXS dataset if the workflow needs it. +5. Choose the template, q-range, grid behavior, and excluded elements. +6. Choose the **Component build mode** for the modeling branch you want to + save. +7. Click **Create Computed Distribution**. +8. Click **Build SAXS Components**. + +!!! info "Image placeholder" +Add a screenshot of the **Project Setup** tab after a project is loaded, +with the project path, data selectors, computed-distribution controls, and +component-build controls visible. + +### About computed distributions + +A computed distribution is SAXSShell's saved record of one Project Setup branch. 
+In practice it captures the active template, cluster source, q-range choices, +component-build mode, and related settings that define how SAXS components +should be generated for that branch. + +### Debye-Waller note + +!!! warning "Debye-Waller status" +    **Compute Debye-Waller Factors (beta)** is currently in testing and has a +    known bug. Treat that path as provisional, and verify any saved outputs +    before you rely on them in later workflows. + +## Move to SAXS Prefit + +After components exist for the active computed distribution, move to +**SAXS Prefit**. + +This is the tab where you answer practical questions such as: + +- does the current template behave sensibly against the experimental trace +- do the built components look reasonable +- do any geometry-aware templates need cluster geometry metadata before the + model can update +- which parameters should stay fixed, vary, or be expressed through simple + relationships + +!!! info "Image placeholder" +    Add a screenshot of the **SAXS Prefit** tab showing the main plot, the +    parameter table, and any geometry or solution-estimator controls that +    should be called out to a first-time user. + +## Use SAXS DREAM Fit when Prefit is stable + +Only move to **SAXS DREAM Fit** after Prefit looks reasonable. + +The DREAM tab uses the current Prefit state to prepare a pyDREAM runtime bundle +and then sample plausible parameter combinations. Use it when you want +uncertainty estimates, posterior summaries, or a more formal Bayesian fit. + +!!! info "Image placeholder" +    Add a screenshot of the **SAXS DREAM Fit** tab showing the parameter map, +    runtime settings, and result-preview area that a new user should inspect +    first. + +## Optional upstream analysis -The same pattern also exists for the `fullrmc` UI: +The SAXS tabs assume you already have a usable project input set. 
Depending on +your question, you may also want to analyze the prepared clusters before +building SAXS components: ```bash -saxshell fullrmc ui /path/to/project +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.bondanalysis run clusters_splitxyz0001 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.clusterdynamics splitxyz --project-dir my_saxshell_project ``` -## Next step +## Next steps -Go to [Project Setup](project-setup.md) for the first SAXS-specific workflow in -the GUI. +- Go to [Project Setup](project-setup.md) for a more detailed setup sequence. +- Use [GUI Overview](../user-guide/gui-overview.md) if you want the main window + mapped out before exploring the deeper user-guide pages. +- Use [SAXS Prefit](../user-guide/saxs-prefit.md) and + [pyDREAM Workflow](../user-guide/pydream-workflow.md) once your computed + distribution is in place. diff --git a/docs/hooks.py b/docs/hooks.py new file mode 100644 index 0000000..ab3e8c7 --- /dev/null +++ b/docs/hooks.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +_GLOBAL_NOTICE = """ +!!! warning "Documentation status" + This page was generated and edited with the assistance of an LLM and is still in development. + It has not been fully vetted by the developer. Verify commands, UI labels, file paths, + workflow descriptions, and scientific claims against the current code and your local + workflow before relying on it. + + If you notice an error, omission, or outdated guidance, please open an issue on + [GitHub](https://github.com/kewh5868/SAXSShell/issues). +""".strip() + +_SECTION_NOTICES = { + "tutorials/": """ +!!! note "Tutorials section status" + The Tutorials section is still being built out. Treat this page as a draft scaffold + rather than a complete end-to-end tutorial. +""".strip(), + "api/": """ +!!! note "API section status" + The API section has not been fully created yet. 
Use this page as a provisional pointer + to likely workflow classes, not as a complete or stable API reference. +""".strip(), + "development/": """ +!!! note "Development section status" + The Development section is still incomplete. Current pages are working notes for + contributors rather than a fully vetted maintenance guide. +""".strip(), +} + + +def _section_notice(src_path: str) -> str | None: + normalized = str(src_path).strip().replace("\\", "/") + for prefix, notice in _SECTION_NOTICES.items(): + if normalized.startswith(prefix): + return notice + return None + + +def on_page_markdown(markdown, page, config, files): + notices = [_GLOBAL_NOTICE] + section_notice = _section_notice(page.file.src_path) + if section_notice is not None: + notices.append(section_notice) + notices.append(str(markdown).lstrip()) + return "\n\n".join(notices) diff --git a/docs/index.md b/docs/index.md index e124f34..45307f8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,30 +2,40 @@ ![SAXSShell icon](source/img/saxshell_icon.svg){ width="220" } -SAXSShell is a workflow-oriented toolkit for turning molecular simulation output -into structural analysis and SAXS fitting workflows. The repository combines: +SAXSShell is a workflow-oriented toolkit for comparing molecular dynamics +structures against small-angle X-ray scattering (`SAXS`) data. + +In plain language, the main goal is to help you answer a solvation-structure +question: "which cluster populations, representative structures, and model +choices from my simulation best match the experimental SAXS signal?" 
+ +The repository combines: - Qt desktop applications for interactive use -- command-line interfaces for reproducible batch runs +- source-checkout module launches for reproducible batch runs - Python workflow classes that mirror the same operations for notebooks and scripts The project is aimed at researchers who need to move from trajectory files to -cluster populations, distribution analysis, SAXS component building, prefit +cluster populations, distribution analysis, SAXS component building, Prefit screening, and Bayesian fitting without constantly switching tools. ## Who this documentation is for This site is organized around tasks: +- users preparing raw MD trajectories for a new SAXSShell project - users preparing a new SAXS project from cluster folders - users running SAXS Prefit and pyDREAM refinement - users maintaining or extending the template system - contributors working on the codebase itself -If you want a quick entry point, start with: +If you want a quick starting path, begin with installation, then learn how to +process your MD trajectories and create a project folder for the SAXSShell +session: - [Installation](getting-started/installation.md) +- [MD Extraction and Cluster Preparation](user-guide/cluster-extraction.md) - [Quickstart](getting-started/quickstart.md) - [Project Setup](getting-started/project-setup.md) @@ -40,18 +50,24 @@ The current repo supports an end-to-end path that usually looks like this: 5. Optionally predict larger clusters and representative predicted structures with `clusterdynamicsml`. 6. Measure bond and angle distributions with `bondanalysis`. -7. Optionally compute project-backed Debye-Waller factors from sorted PDB - cluster folders with **Debye-Waller Analysis**. -8. Optionally compute trajectory-averaged PDFs and partial PDFs with `pdfsetup`. -9. Build a SAXS project with `saxshell saxs`, create or load a computed - distribution in **Project Setup**, and choose how SAXS components will be - prepared. -10. 
Build components with one of the supported modes: - `No Contrast (Debye)`, `Contrast (Debye)`, or - `Born Approximation (Average)`. -11. Refine the project in **SAXS Prefit** and, if needed, run **pyDREAM**. -12. Use the resulting distributions and selected structures in downstream tools - such as `saxshell fullrmc`. +7. Optionally compute project-backed representative structures with the full + Representative Structures UI or the beta run-file based + `representativefinder` source-module path. +8. Optionally compute project-backed Debye-Waller factors from sorted PDB + cluster folders with **Debye-Waller Analysis**. This path is currently in + testing and should be treated cautiously because the linked + **Compute Debye-Waller Factors (beta)** workflow still has a known bug. +9. Optionally compute trajectory-averaged PDFs and partial PDFs with `pdfsetup`. +10. Create a dedicated SAXSShell project folder, launch the main application + from the source checkout, create or load a computed distribution in + **Project Setup**, and choose how SAXS components will be prepared. +11. Build components with one of the supported modes: + `No Contrast (Debye)`, `Contrast (Debye)`, + `1D Born Approximation (Average)`, or + `3D FFT Born Approximation`. +12. Refine the project in **SAXS Prefit** and, if needed, run **pyDREAM**. +13. Use the resulting distributions and selected structures in downstream tools + such as the `fullrmc` workflow. ## Documentation map @@ -69,27 +85,27 @@ into the main SAXS UI. 
The user guide is split into: -- **Main UI workflow elements** for the `saxshell saxs` application and its +- **Main UI workflow elements** for the SAXSShell application and its project, computed distributions, Prefit, DREAM, template, and export behavior - **Supporting applications** grouped the same way as the main `Tools` menu: `MD Extraction`, `Structure Analysis`, `Cluster Dynamics`, `PDF`, - `Visualization`, `SAXS Calculation Preview`, and `X-ray Toolkit` + `Visualization`, `SAXS Calculation Preview`, `X-ray Toolkit`, and `(beta)` ### Tutorials Use this section for longer, task-based walkthroughs that connect several tools -in sequence. +in sequence. These pages are still early drafts rather than complete tutorials. ### API Use this section if you want the shortest route to the reusable workflow -classes. +classes. It is currently a provisional overview, not a complete API reference. ### Development Use this section if you are contributing code, working on CI, or changing the -documentation site itself. +documentation site itself. These contributor notes are still being built out. ## Scope notes diff --git a/docs/tutorials/example-workflow.md b/docs/tutorials/example-workflow.md index 5541156..bcc1d09 100644 --- a/docs/tutorials/example-workflow.md +++ b/docs/tutorials/example-workflow.md @@ -3,13 +3,16 @@ This walkthrough shows a realistic high-level sequence without assuming a specific chemistry beyond "simulation frames eventually become a SAXS project." +Run the commands from the repository root after creating the `saxshell-py312` +conda environment from `requirements/saxshell-py312.yml`. + ## Step 1: inspect the trajectory Start with the trajectory tool to confirm that the input is readable and, if available, that the accompanying energy file is usable for cutoff analysis. 
```bash -mdtrajectory inspect traj.xyz --energy-file traj.ener +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory inspect traj.xyz --energy-file traj.ener ``` ## Step 2: export usable frames @@ -17,7 +20,7 @@ mdtrajectory inspect traj.xyz --energy-file traj.ener Use either a manual cutoff or the suggested one: ```bash -mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 ``` ## Step 3: convert to PDB only if needed @@ -25,7 +28,7 @@ mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --te If downstream logic needs molecule identity, convert the exported XYZ frames: ```bash -xyz2pdb export splitxyz --config residue_map.json +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb export splitxyz --config residue_map.json ``` ## Step 4: extract clusters @@ -33,8 +36,8 @@ xyz2pdb export splitxyz --config residue_map.json Use the cluster workflow on the exported frame folder: ```bash -clusters preview splitxyz -clusters export splitxyz +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster preview splitxyz +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster export splitxyz ``` ## Step 5: inspect distributions @@ -43,26 +46,25 @@ Run bond analysis on the resulting cluster folder if you need bond-pair or angle summaries: ```bash -bondanalysis run clusters_splitxyz0001 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.bondanalysis run clusters_splitxyz0001 ``` ## Step 6: build a SAXS project -Open the SAXS UI and configure the project: +Create a dedicated project folder, open the SAXS UI, and configure the project: ```bash 
-saxshell saxs ui +mkdir -p my_saxshell_project +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ``` In the UI: -1. select the experimental data -2. select the cluster folder -3. choose the template -4. build the project inputs - -From a source checkout, use -`PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ui`. +1. choose or create the project folder +2. select the experimental data +3. select the cluster folder +4. choose the template +5. create the computed distribution and build the project inputs ## Step 7: refine the Prefit model diff --git a/docs/tutorials/md-to-saxs-pipeline.md b/docs/tutorials/md-to-saxs-pipeline.md index 673da6e..961d454 100644 --- a/docs/tutorials/md-to-saxs-pipeline.md +++ b/docs/tutorials/md-to-saxs-pipeline.md @@ -11,26 +11,30 @@ Assume you have: - optionally a CP2K energy file such as `traj.ener` - optionally a residue-mapping JSON file for `xyz2pdb` +Also assume you have cloned the repository, created the `saxshell-py312` +conda environment from `requirements/saxshell-py312.yml`, and are running these +commands from the repository root. 
+ ## Export frames ```bash -mdtrajectory inspect traj.xyz --energy-file traj.ener -mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 3 -mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory inspect traj.xyz --energy-file traj.ener +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 ``` ## Convert to residue-aware PDB, if needed ```bash -xyz2pdb preview splitxyz --config residue_map.json -xyz2pdb export splitxyz --config residue_map.json +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb preview splitxyz --config residue_map.json +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.xyz2pdb export splitxyz --config residue_map.json ``` ## Extract clusters ```bash -clusters inspect splitxyz -clusters export splitxyz --use-pbc +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster inspect splitxyz +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.cluster export splitxyz --use-pbc ``` The exact node, linker, shell, and cutoff settings depend on the chemistry of @@ -40,21 +44,20 @@ hard-coding them into ad hoc notebooks. 
## Analyze distributions ```bash -bondanalysis inspect clusters_splitxyz0001 -bondanalysis run clusters_splitxyz0001 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.bondanalysis inspect clusters_splitxyz0001 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.bondanalysis run clusters_splitxyz0001 ``` ## Build the SAXS project ```bash -saxshell saxs ui +mkdir -p my_saxshell_project +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ``` -There is currently no one-shot CLI that replaces the full Project Setup tab, so -the usual next step is interactive project configuration in the SAXS UI. - -From a source checkout, use -`PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs ui`. +The usual next step is interactive project configuration in the SAXS UI. Choose +the project folder you created, then select the experimental SAXS dataset and +the cluster folder produced above. ## Prefit and DREAM diff --git a/docs/user-guide/blender-structure-renderer.md b/docs/user-guide/blender-structure-renderer.md index ec2cca8..d924916 100644 --- a/docs/user-guide/blender-structure-renderer.md +++ b/docs/user-guide/blender-structure-renderer.md @@ -1,7 +1,7 @@ # Blender Structure Renderer The Blender tool is SAXSShell's publication-rendering application for atomistic -structures. It is exposed as the standalone `blenderxyz` program and from the +structures. It is available as the standalone `blenderxyz` tool and from the main SAXSShell window through `Tools > Visualization > Open Blender XYZ Renderer`. @@ -52,23 +52,6 @@ Blender application manually. You do not need to open the main SAXSShell UI first. The Blender renderer can be launched directly as its own Qt application. 
-### Installed package - -If SAXSShell is installed into your environment, start the Blender tool with: - -```bash -blenderxyz -``` - -You can also prefill a structure file or Blender location: - -```bash -blenderxyz path/to/structure.xyz -blenderxyz path/to/structure.pdb --blender-executable /Applications/Blender.app -``` - -### Source checkout - From the repository root, launch the same standalone application with: ```bash @@ -86,7 +69,8 @@ PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 \ ## Typical workflow -1. Launch `blenderxyz`. +1. Launch the Blender renderer from the source checkout or from the main + SAXSShell UI. 2. Choose an `XYZ` or `PDB` file. 3. Confirm or browse to the Blender executable if needed. 4. Review the generated orientation rows and duplicate or add custom rows. diff --git a/docs/user-guide/cluster-dynamics-ml.md b/docs/user-guide/cluster-dynamics-ml.md index 5961f0f..4b1dc64 100644 --- a/docs/user-guide/cluster-dynamics-ml.md +++ b/docs/user-guide/cluster-dynamics-ml.md @@ -1,8 +1,7 @@ # Cluster Dynamics ML `clusterdynamicsml` is the predictive companion to `clusterdynamics`. It can be -launched directly from the `clusterdynamicsml` application entry point or from -the main SAXSShell UI. It +launched directly from the source checkout or from the main SAXSShell UI. It combines: - time-binned cluster dynamics from extracted XYZ or PDB frame folders diff --git a/docs/user-guide/cluster-extraction.md b/docs/user-guide/cluster-extraction.md index d5146ce..e7ea150 100644 --- a/docs/user-guide/cluster-extraction.md +++ b/docs/user-guide/cluster-extraction.md @@ -16,6 +16,9 @@ In this repository, that bridge spans more than one tool. pairwise disorder coefficients from sorted PDB cluster folders. 7. Feed the resulting cluster folder into the SAXS project. +Run the examples from the repository root after creating the +`saxshell-py312` conda environment. 
+ ## `mdtrajectory` This tool is responsible for: @@ -30,9 +33,9 @@ This tool is responsible for: Example: ```bash -mdtrajectory inspect traj.xyz --energy-file traj.ener -mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 3 -mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory inspect traj.xyz --energy-file traj.ener +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory suggest-cutoff traj.xyz --energy-file traj.ener --temp-target-k 300 --window 3 +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.mdtrajectory export traj.xyz --energy-file traj.ener --use-suggested-cutoff --temp-target-k 300 --window 3 ``` When a cutoff is applied, the default folder name now uses the form diff --git a/docs/user-guide/debye-waller-analysis.md b/docs/user-guide/debye-waller-analysis.md index ec53d3d..401e8e6 100644 --- a/docs/user-guide/debye-waller-analysis.md +++ b/docs/user-guide/debye-waller-analysis.md @@ -5,10 +5,20 @@ estimating pair-resolved thermal-displacement coefficients from sorted PDB cluster folders. It is designed to stay project-aware so the results can be saved, reopened, and reused by later SAXSShell workflows. +!!! warning "Current testing status" +The linked **Compute Debye-Waller Factors (beta)** workflow is currently +in testing and has a known bug. Treat the tool and its saved outputs as +provisional until they have been manually checked for your project. + +!!! info "Image placeholder" +Add a screenshot of the Debye-Waller Analysis window showing the cluster +selector, results tables, and run log that users should verify during a +test run. + ## Launching the application Open the tool from the main SAXS UI through -`Tools > Structure Analysis > Open Debye-Waller Analysis`. 
+`Tools > (beta) > Open Debye-Waller Analysis`. If the tool is launched from an active project, it uses that project to: diff --git a/docs/user-guide/electron-density-mapping.md b/docs/user-guide/electron-density-mapping.md index 9775433..aa5dde7 100644 --- a/docs/user-guide/electron-density-mapping.md +++ b/docs/user-guide/electron-density-mapping.md @@ -1,6 +1,6 @@ -# Electron Density Mapping +# 1D Born Approximation (Average) -The **Electron Density Mapping** tool is SAXSShell's supporting application for +The **1D Born Approximation (Average)** tool is SAXSShell's supporting application for building radial electron-density profiles from XYZ or PDB inputs and, when needed, turning those profiles into q-space scattering estimates. The current UI supports three working styles: @@ -11,7 +11,11 @@ UI supports three working styles: stoichiometry gets its own density, solvent, Fourier, and saved-output state. In Project Setup, this is the linked component-build workspace for -**Born Approximation (Average)**. +**1D Born Approximation (Average)**. + +This page documents the legacy radial-profile workflow. For the separate +Cartesian FFT workflow, see +[3D FFT Born Approximation](fft-born-approximation.md). In the main SAXS workflow the tool can run either in **Preview Mode** or in **Computed Distribution Mode**. Preview mode is for exploratory work. Computed @@ -29,12 +33,12 @@ PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshel ### From the main SAXS UI -- **Tools > SAXS Calculation Preview > Open Electron Density Mapping** opens +- **Tools > SAXS Calculation Preview > Open 1D Born Approximation** opens the tool in **Preview Mode**. The window title shows - `Electron Density Mapping (Preview)`, and the banner explains that pushed + `1D Born Approximation (Preview)`, and the banner explains that pushed model components are disabled in this mode. 
- **Build SAXS Components** with the - **Born Approximation (Average)** build mode opens the tool in + **1D Born Approximation (Average)** build mode opens the tool in **Computed Distribution Mode**. The tool inherits the active project q-range, the active computed distribution, the preferred input folder, and the distribution output directory. @@ -187,10 +191,11 @@ The **Fourier Transform** section converts the active density profile into a Born-approximation scattering estimate. The preview panel always shows the resampled and windowed real-space data that will be used. -The default transform domain is **mirrored mode**, which reflects the profile -about `r = 0` and evaluates the windowed transform over `-rmax` to `rmax`. -The UI also keeps a **legacy r min to r max transform** toggle for historical -behavior. +The validated default transform domain in the current UI is the historical +one-sided transform from `r = 0` to `r = rmax` with no window. The +**Legacy r min to r max transform** checkbox remains checked by default to +match the backend Born-versus-Debye comparison tests. Clearing that checkbox +switches to the mirrored-domain transform over `-rmax` to `rmax`. The evaluated transform is: @@ -208,7 +213,7 @@ with `I(q) = |F(q)|^2`. | Field | Description | | ------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | **r min / r max** | Real-space bounds used for the transform. | -| **Legacy r min to r max transform** | Switches from the default mirrored-domain transform back to the historical one-sided `r min` to `r max` behavior. In mirrored mode, the left bound is shown as `-r max`. | +| **Legacy r min to r max transform** | Keeps the current default one-sided `r min` to `r max` behavior. 
Clear it to switch to the mirrored-domain transform, where the left bound is shown as `-r max`. | | **q min / q max / q step** | Requested q grid for the output scattering profile. | | **Window** | Real-space apodization window. Options are `None`, `Lorch`, `Cosine`, `Hanning`, `Parzen`, `Welch`, `Gaussian`, `Sine`, and `Kaiser-Bessel`. | | **Resample pts** | Number of resampled real-space points used by the transform. | @@ -350,7 +355,7 @@ The structure viewer is an interactive **3D** Matplotlib viewer, not a static ### Standalone preview -1. Launch **Open Electron Density Mapping** from the main UI or start the tool +1. Launch **Open 1D Born Approximation** from the main UI or start the tool from the terminal. 2. Load a single structure file or a folder of structures. 3. Review the structure summary and set the mesh. @@ -363,7 +368,7 @@ The structure viewer is an interactive **3D** Matplotlib viewer, not a static ### Computed distribution / Born approximation 1. Open the tool from **Build SAXS Components** with the - **Born Approximation (Average)** workflow. + **1D Born Approximation (Average)** workflow. 2. Confirm that the inherited cluster folder, output directory, and q-range are correct. 3. Review the stoichiometry table and choose whether to run the full batch or diff --git a/docs/user-guide/fft-born-approximation.md b/docs/user-guide/fft-born-approximation.md new file mode 100644 index 0000000..2b4d0ef --- /dev/null +++ b/docs/user-guide/fft-born-approximation.md @@ -0,0 +1,246 @@ +# 3D FFT Born Approximation + +The **3D FFT Born Approximation** tool is SAXSShell's separate Cartesian +Fourier workflow for building scattering curves directly from a voxelized +three-dimensional electron-density map. + +Use this tool when you want to keep the full 3D structure instead of reducing +it to a spherically averaged radial density first. 
That makes it the right +place to experiment with constant solvent-density contrast subtraction in a way +that stays tied to the actual 3D molecular shape. + +## Launching the application + +### From the main SAXS UI + +- **Tools > SAXS Calculation Preview > Open 3D FFT Born Approximation** opens + the tool in **Preview Mode**. +- **Build SAXS Components** with the + **3D FFT Born Approximation** build mode opens the tool in + **Computed Distribution Mode**. + +Both launch paths inherit the active project q-range when that information is +available from the main UI. + +When the tool is opened from **Build SAXS Components**, it also inherits the +active project structure-source preference: + +- **Average cluster folders / input structures** builds each profile from every + structure file in the matching cluster folder. +- **Representative structures** builds each profile from the single saved + representative file recorded in + `rmcsetup/representative_structures/representative_selection.json`. + +If no-solvent, partial/source, or full-solvent representative variants are +available, the **Representative solvent** selector lets you choose which saved +variant is used for the 3D FFT run. + +## Layout + +The window follows the same broad layout style as the 1D Born tool: + +- a scrollable **left pane** for input, FFT settings, electron-density + contrast setup, overlay choices, plot options, actions, and the status log +- a scrollable **right pane** for the structure viewer, q-space curves, the + FFT real-space visualizer, shell diagnostics, and the run summary + +!!! info "Image placeholder" +Add a screenshot of the 3D FFT Born Approximation window showing the split +left and right panes, the q-space plot, and the FFT real-space visualizer. + +## 1D versus 3D Born + +The two Born tools do related but different jobs. 
+ +| Workflow | What is transformed | What is preserved | Typical use | +| ----------------------------------- | ----------------------------------------------------------- | ------------------------------------------ | ----------------------------------------------------------------------------------------------------------- | +| **1D Born Approximation (Average)** | A spherically averaged radial density profile `rho(r)` | Radial structure only | Fast legacy workflow, radial diagnostics, comparison to the historical SAXSShell behavior | +| **3D FFT Born Approximation** | A full Cartesian contrast-density grid `Delta rho(x, y, z)` | Full 3D structure before q-shell averaging | 3D density studies, constant solvent-density subtraction, comparison against exact Debye and legacy 1D Born | + +In plain language: + +- the **1D Born** workflow first averages the structure into a radial profile + and then Fourier-transforms that profile +- the **3D FFT Born** workflow keeps the full 3D map, Fourier-transforms that + map, and only then averages intensity over q-shells + +## Mathematical model + +### Continuous 3D Born amplitude + +For a contrast density `Delta rho(r)`, the Born amplitude is + +$$ +A(\mathbf{q}) = +\int \Delta \rho(\mathbf{r}) +\exp\!\left(i \mathbf{q} \cdot \mathbf{r}\right) +\mathrm{d}\mathbf{r} +$$ + +and the orientationally averaged scattering intensity is + +$$ +I(q) = +\left\langle +\left|A(\mathbf{q})\right|^2 +\right\rangle_{\lVert \mathbf{q} \rVert = q}. 
+$$ + +### Constant solvent-density contrast + +The current 3D FFT workflow uses the contrast-density form + +$$ +\Delta \rho(\mathbf{r}) = +\rho_{\mathrm{atom}}(\mathbf{r}) +- \rho_0 \chi(\mathbf{r}), +$$ + +where: + +- `rho_atom(r)` is the voxelized atomic electron-density map +- `rho_0` is the constant solvent electron density in `e / Å^3` +- `chi(r)` is the exclusion mask built from the union of atomic exclusion + spheres + +This is why the 3D FFT workflow is the correct place for constant solvent +subtraction: the subtraction is applied in real space on the 3D density field, +not retrofitted into a purely radial post-processing step. + +### Discrete FFT form + +On a Cartesian grid with voxel volume `Delta V`, the tool evaluates + +$$ +A(\mathbf{q}_{ijk}) \approx +\Delta V +\sum_n +\Delta \rho(\mathbf{r}_n) +\exp\!\left(i \mathbf{q}_{ijk} \cdot \mathbf{r}_n\right). +$$ + +The FFT frequencies are converted with + +$$ +\mathbf{q} = 2 \pi \mathbf{f}, +$$ + +so the q values remain in `Å^-1` when the real-space coordinates are in `Å`. + +After the FFT, the tool computes shell-averaged intensity: + +$$ +I(q_t) = +\left\langle +\left|A(q_x, q_y, q_z)\right|^2 +\right\rangle_{q_t}. +$$ + +That last step is what makes the result comparable to orientationally averaged +Debye scattering. + +### Relation to the 1D Born workflow + +The 1D Born tool uses a radial transform of the form + +$$ +F(q) = +4 \pi +\int +\rho(r)\,W(r)\,r^2\, +\mathrm{sinc}\!\left(\frac{q r}{\pi}\right) +\mathrm{d}r +$$ + +with + +$$ +I(q) = |F(q)|^2. +$$ + +That is useful when the radial density itself is the object you want to model. +It is not the same as taking a full 3D FFT of the original molecular density +and then performing q-shell averaging. + +## Input fields and current defaults + +### 3D FFT Settings + +These defaults match the current backend debug and benchmark tests, except that +`q min` and `q max` are inherited from the main UI when available. 
+ +| Field | Default | Meaning | +| -------------------------- | --------------------------- | ------------------------------------------------------------------ | +| **q min (Å^-1)** | inherited, otherwise `0.01` | Lower bound of the shared q grid | +| **q max (Å^-1)** | inherited, otherwise `1.20` | Upper bound of the shared q grid | +| **q step (Å^-1)** | `0.01` | q-grid spacing | +| **Voxel spacing (Å)** | `2.5` | Cartesian voxel spacing used for the FFT density grid | +| **Gaussian sigma (Å)** | `0.75` | Atomic deposition width for the voxelized density | +| **Minimum box length (Å)** | `640.0` | Minimum FFT box length before padding and odd-grid rounding | +| **Extra padding (Å)** | `24.0` | Additional vacuum padding around the structure before voxelization | + +### Electron Density Contrast + +The contrast section is separate from the FFT settings so you can configure and +apply solvent subtraction explicitly before the next FFT run. + +| Field | Default | Meaning | +| -------------------------------- | --------------------------- | ----------------------------------------------------------------------------------------- | +| **Compute option** | solvent formula and density | Chooses the contrast-density setup path | +| **Saved solvents** | `Water` | Preset solvent formula and density entry | +| **Solvent formula** | `H2O` when Water is loaded | Stoichiometry for neat-solvent estimation | +| **Density (g/mL)** | `1.0` when Water is loaded | Bulk density for neat-solvent estimation | +| **Direct density (e-/Å^3)** | `0.334` | Manual solvent electron density | +| **Reference solvent file** | empty | XYZ or PDB file used to estimate a uniform reference density from its full coordinate box | +| **Exclusion radius scale** | `1.0` | Multiplier applied to the atomic exclusion radii | +| **Exclusion radius padding (Å)** | `0.0` | Extra radius added to each exclusion sphere | +| **Active contrast** | none until applied | The density that will actually be used on the 
next run | + +### Comparison and plot defaults + +| Field | Default | Meaning | +| ------------------------------------------- | ------- | -------------------------------------------------------- | +| **Overlay 1D Born Approximation (Average)** | on | Computes and displays the legacy radial comparison curve | +| **Overlay exact Debye scattering** | off | Computes the exact Debye comparison trace on demand | +| **Show kernel-corrected FFT overlay** | off | Diagnostic overlay for zero-contrast runs only | +| **Log q axis** | on | Default q-space display scaling | +| **Log intensity axis** | on | Default intensity display scaling | + +## Kernel correction + +Kernel correction is a **diagnostic**, not a production solvent-contrast step. + +When the 3D density map is built by depositing atoms as Gaussians, that +deposition introduces a known smoothing response in q-space. For a zero-contrast +run, the current backend can divide out that Gaussian intensity factor so the +FFT result is easier to compare with the point-scatterer Debye limit. + +For solvent-contrast calculations, leave kernel correction off. Once constant +solvent-density subtraction is active, a single global Gaussian correction is +no longer the physically clean description of the full contrast-density field. + +## Outputs + +After a run, the 3D FFT window reports: + +- the q-space scattering curves +- an optional overlay against the legacy 1D Born and exact Debye curves +- an FFT real-space visualizer showing the centered structure and FFT box +- q-shell population diagnostics +- run timing, Nyquist limit, density integrals, and contrast metadata +- CSV export of the currently displayed q-space curves + +In **Computed Distribution Mode**, **Push to Model** writes the computed traces +and component map into the linked computed distribution so the main SAXS UI can +load them for SAXS Prefit and SAXS DREAM Fit. 
The saved distribution metadata +records whether the traces were built from average folders or representative +structures. + +Bare single-atom clusters use a direct single-atom Born trace for the 3D FFT +result so low-q bins do not become empty FFT-shell `NaN` values. + +## Related pages + +- [1D Born Approximation (Average)](electron-density-mapping.md) +- [GUI Overview](gui-overview.md) +- [Project Setup](../getting-started/project-setup.md) diff --git a/docs/user-guide/fullrmc-packmol-docker.md b/docs/user-guide/fullrmc-packmol-docker.md new file mode 100644 index 0000000..135e7f9 --- /dev/null +++ b/docs/user-guide/fullrmc-packmol-docker.md @@ -0,0 +1,258 @@ +# fullrmc Packmol Docker Link + +SAXSShell can link a pre-existing Docker container that has Packmol installed. +In the `fullrmc` workflow, Packmol is the external packer that consumes the +generated `packmol_inputs` folder and writes the packed coordinate output for +downstream RMC preparation. + +SAXSShell does **not** install Packmol for you and does **not** create the +Docker container automatically. The UI only validates the container, stores the +link metadata for the active project, and syncs generated Packmol input files +into a container-side project folder. + +## What Packmol is + +According to the official Packmol project, Packmol builds initial molecular +configurations by packing molecules into user-defined regions while enforcing +minimum-distance constraints between atoms from different molecules. The +upstream project documentation describes it as a tool for building initial +configurations for molecular dynamics simulations and notes support for `PDB`, +`TINKER`, and `XYZ` inputs. + +For SAXSShell users, the practical role of Packmol is simpler: + +- SAXSShell prepares the representative structures, composition plan, and + `packmol_combined.inp` input file. +- Packmol is then run separately inside the linked container to build the + packed coordinate output. 
+- Because the Packmol working folder is bind-mounted from the host, the input + and output files remain visible to both the container and your local project + environment. + +## How to cite Packmol + +The upstream Packmol README asks users to cite one of the Packmol papers when +the software contributes to a publication. For most SAXSShell workflows, the +general package paper is the clearest reference: + +- L. Martinez, R. Andrade, E. G. Birgin, and J. M. Martinez, + _Packmol: A package for building initial configurations for molecular + dynamics simulations_, + _Journal of Computational Chemistry_ **30** (2009), 2157-2164. + DOI: <https://doi.org/10.1002/jcc.21224> + +The older packing-optimization paper is also listed by Packmol upstream: + +- J. M. Martinez and L. Martinez, + _Packing optimization for the automated generation of complex system's + initial configurations for molecular dynamics and docking_, + _Journal of Computational Chemistry_ **24** (2003), 819-825. + DOI: <https://doi.org/10.1002/jcc.10216> + +If you describe the software tool itself, cite the 2009 paper first unless +your workflow specifically depends on the earlier algorithm paper. + +## SAXSShell container convention + +The Packmol Docker link in SAXSShell currently requires the selected +container-side project root to live under: + +```text +/packmol_input_files +``` + +This is a SAXSShell convention enforced by the current UI. It is **not** a +general Docker requirement and **not** a Packmol requirement. + +The recommended project layout inside the container is: + +```text +/packmol_input_files/my_project +``` + +When a link is active, SAXSShell stores that project folder as the +container-side root and syncs generated Packmol inputs into: + +```text +/packmol_input_files/my_project/rmcsetup/packmol_inputs +``` + +## Build a Docker image with Packmol installed + +The example below follows the Packmol upstream manual compilation path +(`./configure` then `make`) inside a Debian-based image. Replace the release +tag if you want a different Packmol version. 
+ +```dockerfile +FROM debian:bookworm-slim + +ARG PACKMOL_VERSION=21.2.1 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + gfortran \ + git \ + make \ + && rm -rf /var/lib/apt/lists/* + +RUN git clone --branch "v${PACKMOL_VERSION}" --depth 1 \ + https://github.com/m3g/packmol.git /opt/packmol \ + && cd /opt/packmol \ + && ./configure \ + && make \ + && cp /opt/packmol/packmol /usr/local/bin/packmol + +RUN mkdir -p /packmol_input_files +WORKDIR /packmol_input_files + +# Keeps the container available for docker exec-based validation from SAXSShell. +CMD ["sleep", "infinity"] +``` + +Build the image: + +```bash +docker build -t saxshell-packmol:21.2.1 -f Dockerfile.packmol . +``` + +!!! note "Alternative installation paths" +The Packmol upstream README also documents `pip install packmol` for many +platforms. The source-build path above is shown here because it follows the +upstream manual compilation instructions directly and keeps the container +behavior explicit. + +## Bind-mount a host folder for read/write access + +Docker's bind-mount documentation recommends the `--mount` flag and notes that +bind mounts are read-write by default. 
That default is the right choice for +this SAXSShell workflow, because: + +- SAXSShell needs to copy Packmol input files into the container-visible folder +- Packmol needs to write packed output files back to the same host-backed + folder + +Create a host folder that will hold one or more Packmol-linked project roots: + +```bash +mkdir -p "$HOME/saxshell_packmol_projects/my_project" +``` + +Start the container with that host folder mounted at +`/packmol_input_files`: + +```bash +docker run -d \ + --name saxshell-packmol \ + --mount type=bind,src="$HOME/saxshell_packmol_projects",dst=/packmol_input_files \ + saxshell-packmol:21.2.1 +``` + +Because the mount is read-write, files created inside +`/packmol_input_files/my_project` in the container will also appear under +`$HOME/saxshell_packmol_projects/my_project` on the host. + +Verify that Packmol resolves and that the container-side project folder exists: + +```bash +docker exec -it saxshell-packmol sh -lc \ + 'packmol --help >/dev/null && mkdir -p /packmol_input_files/my_project && ls -ld /packmol_input_files/my_project' +``` + +!!! warning "Do not mount this path read-only" +A read-only bind mount is useful for inspection, but it will block the +normal SAXSShell Packmol workflow because the UI sync step and Packmol +itself both need write access to the mounted folder. + +## Linking the container from the UI + +Open either: + +- the main SAXS window: `File > Link Packmol Docker Container...` +- the `fullrmc` window: `Tools > Link Packmol Docker Container` + +The dialog lets you enter: + +- a preset name for reuse across projects +- a discovered-container list pulled from Docker +- the Docker container name +- the Packmol command inside the container, usually `packmol` +- the shell command used for validation, usually `sh` +- the container-side project folder, which must be inside + `/packmol_input_files` + +After you press `Test Container`, SAXSShell will: + +1. check that Docker is reachable +2. 
use the selected container name and start the container if it is not + already running +3. verify that the selected folder exists inside the container +4. verify that the Packmol command resolves inside the container +5. capture a Packmol version/help line so the executable can be confirmed as + runnable inside the container +6. load a directory tree in the dialog so you can pick the exact + container-side project folder + +!!! info "Image placeholder" +Add a screenshot of the **Link Packmol Docker Container** dialog after a +successful `Test Container`, with the container list, selected project +folder, and directory tree visible. + +When the link is accepted, SAXSShell stores: + +- a project-specific link file at `rmcsetup/packmol_docker_link.json` +- a reusable recent preset in the application settings so the same container + can be linked to a new project more quickly later + +## Running Packmol after SAXSShell syncs the inputs + +When you run `Build Packmol Setup` in the `fullrmc` window: + +1. SAXSShell writes the normal local `rmcsetup/packmol_inputs` files +2. if a Docker link is active, SAXSShell syncs those inputs into the linked + container project folder +3. the Packmol section in the UI records the last sync status and the remote + Packmol input/output paths + +At that point, a typical Packmol run inside the container looks like: + +```bash +docker exec -it saxshell-packmol sh -lc \ + 'cd /packmol_input_files/my_project/rmcsetup/packmol_inputs && packmol < packmol_combined.inp' +``` + +Because the working folder is bind-mounted, the resulting packed output file is +available both inside the container and on the host filesystem. + +## Manual startup reference + +If the selected container is stopped, SAXSShell will try to start it before +validation. If your container only stays alive when attached to a terminal, the +manual reference command is: + +```bash +docker start -i <container-name> +``` + +Use that command outside SAXSShell, then retry the link dialog. 
+ +## Practical notes + +- Keep `/packmol_input_files` dedicated to the host bind mount so there is no + confusion about where Packmol inputs and outputs are actually being written. +- Docker bind mounts obscure pre-existing files at the target path inside the + container, so it is best to mount into a directory that is intentionally + reserved for Packmol exchange files. +- If you are using a remote Docker daemon instead of a local one, bind mounts + refer to paths on the Docker daemon host, not necessarily the machine where + you launched the client command. + +## References + +- Packmol upstream project and installation instructions: + <https://github.com/m3g/packmol> +- Packmol user guide and examples: + <https://m3g.github.io/packmol/> +- Docker bind mounts: + <https://docs.docker.com/engine/storage/bind-mounts/> +- Docker `run` / `--mount` reference: + <https://docs.docker.com/reference/cli/docker/container/run/> diff --git a/docs/user-guide/gui-overview.md b/docs/user-guide/gui-overview.md index 322ead2..6db4b48 100644 --- a/docs/user-guide/gui-overview.md +++ b/docs/user-guide/gui-overview.md @@ -1,17 +1,21 @@ # GUI Overview SAXSShell is not a single-window application. The repository contains multiple -Qt workflows, each with its own UI and CLI entry point. +Qt workflows, each with its own focused UI and source-checkout module launch. -Most tools install as direct top-level commands. The SAXS and fullrmc -applications currently route through the umbrella `saxshell` CLI, or through -`python -m` module execution when you are running from a source checkout. +The current user-facing install path runs from a source checkout with the +`saxshell-py312` conda environment. Start the main application from the +repository root with: + +```bash +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs +``` ## Main UI workflow elements -The primary SAXS workflow lives in `saxshell saxs`. Its tabs are not isolated: -the active template, component list, geometry metadata, and saved state all -move between them. +The primary SAXS workflow lives in the main SAXSShell application. 
Its tabs are +not isolated: the active template, component list, geometry metadata, and saved +state all move between them. Project Setup now separates two linked actions: @@ -25,14 +29,24 @@ The linked supporting-application launches are build-mode aware: - `No Contrast (Debye)` stays in the main SAXS UI and runs the direct Debye component builder. - `Contrast (Debye)` opens the contrast workflow window. -- `Born Approximation (Average)` opens Electron Density Mapping in +- `1D Born Approximation (Average)` opens the legacy radial-density workflow in + computed-distribution mode. +- `3D FFT Born Approximation` opens the separate Cartesian FFT Born workflow in computed-distribution mode. -### `saxshell saxs` +### Main SAXSShell Application Use this for SAXS project management, prefit modeling, pyDREAM refinement, and template-driven workflows. +In plain language, this main window is where SAXSShell turns simulation-derived +cluster populations and representative structures into a model that can be +compared against experimental SAXS data for solvation-structure analysis. + +!!! info "Image placeholder" +Add a screenshot of the main SAXSShell window showing the tab bar and +the overall page layout a first-time user sees after opening a project. + ### Project Setup Defines the project inputs, computed distributions, and component-build choice. @@ -49,6 +63,10 @@ The **Active Computed Distribution** panel on this tab summarizes the saved distribution identity and whether component, prior, Prefit, and DREAM artifacts already exist for that branch. +!!! info "Image placeholder" +Add a screenshot of **Project Setup** with the computed-distribution panel +and component-build controls visible. + ### SAXS Prefit Builds the lmfit-side preview around the current template, parameter table, and @@ -57,10 +75,18 @@ cluster geometry metadata. Geometry-aware templates can require component metadata derived from the cluster-support workflow before Prefit updates are possible. +!!! 
info "Image placeholder" +Add a screenshot of **SAXS Prefit** with the main plot, parameter table, +and any geometry-aware controls visible. + ### SAXS DREAM Fit Builds and runs the pyDREAM workflow once Prefit is in a usable state. +!!! info "Image placeholder" +Add a screenshot of **SAXS DREAM Fit** with the runtime settings, prior-map +editor button, and results panes visible. + ### Results and export Stores the saved project state, fit artifacts, and downstream handoff files @@ -111,9 +137,9 @@ Use these tools when you want to analyze the sorted clusters themselves. - `bondanalysis` measures bond-pair and angle distributions from the cluster folders. -- `Debye-Waller Analysis` estimates intra-molecular and inter-molecular - Debye-Waller coefficients from sorted PDB cluster folders and saves them in - the active project when requested from Project Setup or the Tools menu. +- `Representative Structures` selects project-backed representative structures + from stoichiometry folders. The beta CLI setup path writes a run file for the + same backend when headless execution is preferred. ### Cluster Dynamics @@ -131,14 +157,23 @@ extend the observed structure series. Use this section for pair-distribution workflows tied to the active project. - `pdfsetup` runs Debyer-backed trajectory-averaged PDF and partial-PDF - calculations and stores the saved calculation sets in the project. -- `saxshell fullrmc` remains the downstream setup path for fullrmc-oriented - project artifacts. + calculations and stores the saved calculation sets in the project. It + requires a separate [Debyer](https://debyer.readthedocs.io/en/latest/) + installation with `debyer` available to the process. +- The `fullrmc` workflow remains the downstream setup path for + fullrmc-oriented project artifacts. +- The main SAXS window also exposes `File > Link Packmol Docker Container...` + so Packmol can be linked before opening the `fullrmc` setup window. 
This + path requires Docker plus Packmol installed inside the linked container. +- Inside the `fullrmc` window, `Tools > Link Packmol Docker Container` can + validate a Packmol-ready Docker container and remember the selected + `/packmol_input_files` project folder for later Packmol input syncs. ### Visualization Use `blenderxyz` when you need publication-style structure renders that go -beyond the inline previewer. +beyond the inline previewer. It requires a separate +[Blender](https://www.blender.org/download/) installation. ### SAXS Calculation Preview @@ -147,14 +182,35 @@ settings outside the main computed-distribution flow. - `SAXS Contrast Mode` is the `Contrast (Debye)` representative-structure workflow. -- `Electron Density Mapping` is the - `Born Approximation (Average)` density-profile and Fourier-transform workflow. +- `1D Born Approximation` is the + `1D Born Approximation (Average)` density-profile and Fourier-transform + workflow. +- `3D FFT Born Approximation` is the separate Cartesian contrast-density FFT + workflow. ### X-ray Toolkit Use this section for smaller estimate windows such as volume-fraction, number density, attenuation, and fluorescence calculators. +### (beta) + +Use this section for early-access workflows that are exposed from the main +`Tools` menu but still need extra caution. + +- `Debye-Waller Analysis` estimates intra-molecular and inter-molecular + Debye-Waller coefficients from sorted PDB cluster folders and saves them in + the active project when requested from Project Setup or the Tools menu. +- `Representative CLI Setup` saves + `representative_structure_cli_run.json` in the project folder so + the `representativefinder` source module can execute the same representative + selection without the plotting and viewer UI. + +!!! warning "Debye-Waller status" +The linked **Compute Debye-Waller Factors (beta)** workflow is currently +in testing and has a known bug. 
Keep that in mind when documenting or +using the Project Setup integration path. + ## Supporting application references ### MD Extraction @@ -165,7 +221,7 @@ density, attenuation, and fluorescence calculators. ### Structure Analysis - [Bond Analysis](bond-analysis.md) -- [Debye-Waller Analysis](debye-waller-analysis.md) +- [Representative Structure CLI](representative-structure-cli.md) ### Cluster Dynamics @@ -175,6 +231,7 @@ density, attenuation, and fluorescence calculators. ### PDF - [PDF Calculation](pdf-calculation.md) +- [fullrmc Packmol Docker Link](fullrmc-packmol-docker.md) ### Visualization @@ -183,16 +240,28 @@ density, attenuation, and fluorescence calculators. ### SAXS Calculation Preview - [SAXS Contrast Mode](saxs-contrast-mode.md) -- [Electron Density Mapping](electron-density-mapping.md) +- [1D Born Approximation (Average)](electron-density-mapping.md) +- [3D FFT Born Approximation](fft-born-approximation.md) ### X-ray Toolkit - [X-ray Toolkit](xray-toolkit.md) -## TODO +### (beta) + +- [Debye-Waller Analysis](debye-waller-analysis.md) + +## External software summary + +The conda environment provides the Python dependencies, but these optional +applications need external software: -TODO: add screenshots once the docs site has a stable asset pipeline and the -UI labels settle after the current SAXS workflow changes. 
+| External software | Required by | Install / docs | +| ----------------- | ----------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| Debyer | `pdfsetup` PDF and partial-PDF calculations | [Debyer docs](https://debyer.readthedocs.io/en/latest/) and [Debyer GitHub](https://github.com/wojdyr/debyer) | +| Blender | `blenderxyz` structure rendering | [Blender download](https://www.blender.org/download/) and [Blender installation manual](https://docs.blender.org/manual/en/latest/getting_started/installing/index.html) | +| Packmol | `fullrmc` Packmol setup and solvent packing workflows | [Packmol GitHub](https://github.com/m3g/packmol) and [Packmol user guide](https://m3g.github.io/packmol/) | +| Docker | `fullrmc` Packmol Docker link workflow | [Get Docker](https://docs.docker.com/get-started/get-docker/) | ??? note "Artwork Attribution" The SAXSShell application icon used across the UI, documentation site, and diff --git a/docs/user-guide/preloaded-saxs-models.md b/docs/user-guide/preloaded-saxs-models.md index f77986c..3dbc5dd 100644 --- a/docs/user-guide/preloaded-saxs-models.md +++ b/docs/user-guide/preloaded-saxs-models.md @@ -9,15 +9,16 @@ single paper. 
## Template Catalog -| Template file | GUI name | Status | Model family | -| -------------------------------------------- | ------------------------------------------------------ | ---------- | -------------------------------------------- | -| `template_pydream_monosq_normalized.py` | `pyDREAM MonoSQ Normalized` | current | MonoSQ hard-sphere | -| `template_pydream_poly_lma_hs.py` | `pyDREAM Poly LMA Hard-Sphere` | current | sphere-only Poly LMA hard-sphere | -| `template_pydream_poly_lma_hs_mix_approx.py` | `pyDREAM Poly LMA Hard-Sphere/Ellipsoid Mix (Approx.)` | current | mixed-shape approximate Poly LMA hard-sphere | -| `template_likelihood_monosq.py` | `MonoSQ Basic (archived)` | archived | MonoSQ hard-sphere | -| `template_pd_likelihood_monosq.py` | `MonoSQ PD (archived)` | archived | MonoSQ hard-sphere | -| `template_pd_likelihood_monosq_decoupled.py` | `MonoSQ Decoupled (archived)` | archived | MonoSQ hard-sphere | -| `template_pydream_poly_lma_hs_legacy.py` | `pyDREAM Poly LMA Hard-Sphere (deprecated)` | deprecated | mixed-shape approximate Poly LMA hard-sphere | +| Template file | GUI name | Status | Model family | +| ------------------------------------------------------ | ------------------------------------------------------ | ---------- | --------------------------------------------- | +| `template_pydream_monosq_normalized.py` | `pyDREAM MonoSQ Normalized` | current | MonoSQ hard-sphere | +| `template_pydream_monosq_normalized_scaled_solvent.py` | `pyDREAM MonoSQ Normalized (Scaled Solvent Weight)` | current | MonoSQ hard-sphere with scale-coupled solvent | +| `template_pydream_poly_lma_hs.py` | `pyDREAM Poly LMA Hard-Sphere` | current | sphere-only Poly LMA hard-sphere | +| `template_pydream_poly_lma_hs_mix_approx.py` | `pyDREAM Poly LMA Hard-Sphere/Ellipsoid Mix (Approx.)` | current | mixed-shape approximate Poly LMA hard-sphere | +| `template_likelihood_monosq.py` | `MonoSQ Basic (archived)` | archived | MonoSQ hard-sphere | +| 
`template_pd_likelihood_monosq.py` | `MonoSQ PD (archived)` | archived | MonoSQ hard-sphere | +| `template_pd_likelihood_monosq_decoupled.py` | `MonoSQ Decoupled (archived)` | archived | MonoSQ hard-sphere | +| `template_pydream_poly_lma_hs_legacy.py` | `pyDREAM Poly LMA Hard-Sphere (deprecated)` | deprecated | mixed-shape approximate Poly LMA hard-sphere | ## Shared Notation @@ -37,38 +38,102 @@ Across the bundled templates: Applies to: - `template_pydream_monosq_normalized.py` +- `template_pydream_monosq_normalized_scaled_solvent.py` - `template_likelihood_monosq.py` - `template_pd_likelihood_monosq.py` - `template_pd_likelihood_monosq_decoupled.py` These templates treat the MD-derived component profiles as a weighted solute mixture modulated by a **single** monodisperse hard-sphere structure factor. - -\[ -I*{\mathrm{mix}}(q) = \sum*{i=0}^{N-1} w_i I_i(q) -\] - -\[ -\begin{aligned} -I*{\mathrm{model}}(q) ={}& -\mathrm{scale}\, I*{\mathrm{mix}}(q)\, -S*{\mathrm{HS}}(q; R*{\mathrm{eff}}, \phi*{\mathrm{vol}}) \\ -&+ w*{\mathrm{solv}} I\_{\mathrm{solv}}(q) - -- \mathrm{offset} - \end{aligned} - \] +They differ mainly in where the experimental solvent trace enters the scaled +model expression. + +All MonoSQ templates start from the same solute branch: + +$$ +I_{\mathrm{mix}}(q) = \sum_{i=0}^{N-1} w_i I_i(q) +$$ + +$$ +I_{\mathrm{solute}}(q) = +I_{\mathrm{mix}}(q) S_{\mathrm{HS}}(q; R_{\mathrm{eff}}, \phi_{\mathrm{vol}}) +$$ + +### Current normalized MonoSQ + +The original `pyDREAM MonoSQ Normalized` template keeps the historical +unscaled-solvent convention: + +$$ +I_{\mathrm{model}}(q) = +\mathrm{scale}\, I_{\mathrm{solute}}(q) ++ w_{\mathrm{solv}} I_{\mathrm{solv}}(q) ++ \mathrm{offset}. +$$ + +In this template, `scale` applies only to the MD-derived solute branch. 
The +solvent trace is added after the global scale, so `solv_w` must carry both the +physical solvent-background multiplier and any remaining intensity-unit +mismatch between the imported solvent data and the scaled MD model. This +preserves the behavior of existing projects, but it can make fitted `solv_w` +values look much smaller than a physical solvent volume fraction when the +experimental solvent trace is orders of magnitude larger than the model trace. + +The solution-scattering calculator can still seed `solv_w` for this template +with the combined solvent-background multiplier. It does **not** seed `vol_frac` +for the original MonoSQ template; `vol_frac` remains a fitted hard-sphere +packing term. + +### Scaled Solvent Weight MonoSQ + +The `pyDREAM MonoSQ Normalized (Scaled Solvent Weight)` template keeps the same +MonoSQ solute branch and point-normalized likelihood, but moves the solvent +trace inside the global scale: + +$$ +I_{\mathrm{model}}(q) = +\mathrm{scale} +\left[ +I_{\mathrm{solute}}(q) ++ w_{\mathrm{solv}} I_{\mathrm{solv}}(q) +\right] ++ \mathrm{offset}. +$$ + +Here `scale` applies to the combined solute-plus-solvent model. This makes +`solv_w` a model-facing solvent-background multiplier rather than a parameter +that also has to absorb the arbitrary MD-model intensity scale. In practice, +this is the safer MonoSQ starting point when the solvent blank intensity is +much larger than the unscaled MD component profiles. + +This template declares calculator targets in its metadata: + +- `vol_frac` receives the physical solute-associated volume fraction computed + from the solution composition. +- `solv_w` receives the combined solvent-background multiplier from attenuation + and SAXS-effective solvent contrast. + +It also declares Prefit startup behavior. When experimental data are available +and there is no saved Best Prefit or current Prefit state for this template, +Prefit applies the autoscale recommendation as soon as the template loads. 
The +new `scale` and `offset` limits are centered around the autoscale result rather +than preserving the broad template-default ranges. The default `eff_r` starts at +3 A, the lower bound of the effective-radius search range. + +Because the solvent branch is scale-coupled, Prefit's scale recommendation also +treats the solvent term as part of the scaled model instead of subtracting it as +an already-scaled background contribution. ### Variables -| Symbol / parameter | Meaning in SAXSShell | -| ------------------------------------- | ------------------------------------------------------------------- | -| \(w_i\) | generated component weight for cluster profile \(i\) | -| \(w\_{\mathrm{solv}}\) / `solv_w` | bounded solvent contribution weight | -| \(R\_{\mathrm{eff}}\) / `eff_r` | effective hard-sphere radius used in `calc_monodisperse_sq(...)` | -| \(\phi\_{\mathrm{vol}}\) / `vol_frac` | effective hard-sphere volume fraction inside the Percus-Yevick term | -| `scale` | solute intensity scale factor | -| `offset` | constant additive background | +| Symbol / parameter | Meaning in SAXSShell | +| ------------------------------------- | -------------------------------------------------------------------------------------------------------------------------- | +| \(w_i\) | generated component weight for cluster profile \(i\) | +| \(w_{\mathrm{solv}}\) / `solv_w` | bounded solvent contribution weight | +| \(R_{\mathrm{eff}}\) / `eff_r` | effective hard-sphere radius used in `calc_monodisperse_sq(...)`; scaled-solvent MonoSQ defaults to 3 A | +| \(\phi_{\mathrm{vol}}\) / `vol_frac` | effective hard-sphere volume fraction inside the Percus-Yevick term | +| `scale` | global intensity scale; original MonoSQ applies it only to solute, scaled-solvent MonoSQ applies it to solute plus solvent | +| `offset` | constant additive background | ### Likelihood conventions diff --git a/docs/user-guide/project-configuration.md b/docs/user-guide/project-configuration.md index 
23a5b8b..b4bc5bc 100644 --- a/docs/user-guide/project-configuration.md +++ b/docs/user-guide/project-configuration.md @@ -44,7 +44,8 @@ distribution identity: When experimental-grid mode is active, the experimental data source also feeds into that saved identity. This is why two distributions can coexist in one project even when they differ only by build mode, for example -`No Contrast (Debye)` versus `Born Approximation (Average)`. +`No Contrast (Debye)` versus `1D Born Approximation (Average)` or +`3D FFT Born Approximation`. ## Why this matters diff --git a/docs/user-guide/pydream-workflow.md b/docs/user-guide/pydream-workflow.md index 91dff4f..c5eed0e 100644 --- a/docs/user-guide/pydream-workflow.md +++ b/docs/user-guide/pydream-workflow.md @@ -7,6 +7,13 @@ to generate a pyDREAM bundle and launch a Bayesian refinement. Like Prefit, DREAM depends on the upstream component and geometry inputs prepared earlier in the main UI workflow and by the supporting applications. +In plain language, use this tab when you want uncertainty estimates and a +posterior distribution, not just one hand-tuned Prefit curve. + +!!! info "Image placeholder" +Add a screenshot of the **SAXS DREAM Fit** tab showing the runtime +controls, prior-map editor button, and the main results panels. + SAXSShell uses **pyDREAM**, a Python implementation of the **MT-DREAM(ZS)** sampler. In plain language, pyDREAM runs several Markov chains at once, explores parameter space, and then uses the accepted samples to diff --git a/docs/user-guide/representative-structure-cli.md b/docs/user-guide/representative-structure-cli.md new file mode 100644 index 0000000..eb516a4 --- /dev/null +++ b/docs/user-guide/representative-structure-cli.md @@ -0,0 +1,103 @@ +# Representative Structure CLI (beta) + +The representative structure CLI is a headless alternative to the full +Representative Structures UI. 
It uses the same representative-selection +workflow and writes the same project registry under +`rmcsetup/representative_structures/`, but avoids plot drawing, structure +viewer updates, and Qt progress refreshes during the actual analysis. + +## Workflow + +1. Open the main SAXSShell application from the source checkout. +2. Open **Tools > (beta) > Open Representative CLI Setup (Beta)**. +3. Select the project folder and representative input folder. +4. Load or enter the bond-pair and angle-triplet definitions. +5. Save the run file. +6. Run the printed command from the repository root in the target environment: + +```bash +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.representativefinder run /path/to/project +``` + +The default run file is saved at: + +```text +/path/to/project/representative_structure_cli_run.json +``` + +## Run File + +The beta setup window stores paths relative to the project folder when +possible. A typical run file includes: + +```json +{ + "version": 1, + "input_dir": "clusters_splitxyz0001", + "output_dir": "representative_finder/representativefinder_batch_clusters_splitxyz0001", + "analysis_mode": "all", + "overwrite_existing": false, + "settings": { + "selection_algorithm": "target_distribution_quantile_distance", + "bond_weight": 1.0, + "angle_weight": 1.0, + "solvent_weight": 1.0, + "parallel_workers": 8 + } +} +``` + +`analysis_mode` can be `all` or `single`. In `single` mode, the run file can +also include `selected_stoichiometry`. 
+ +## CLI Commands + +Inspect the targets without running: + +```bash +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.representativefinder inspect /path/to/project +``` + +Run the saved setup: + +```bash +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.representativefinder run /path/to/project +``` + +Use a non-default run file: + +```bash +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.representativefinder run /path/to/project --run-file /path/to/run.json +``` + +Override the saved worker count: + +```bash +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.representativefinder run /path/to/project --workers 12 +``` + +Recompute bins that already have project representatives: + +```bash +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.representativefinder run /path/to/project --overwrite-existing +``` + +## Outputs + +Each CLI target writes the same analysis artifacts as the full UI: + +- `representative_selection.json` +- `candidate_scores.tsv` +- `selection_summary.txt` +- the copied representative structure + +After each successful target, the CLI calls the same project-persistence path +as the UI. The project registry and reusable representative files are written +under: + +```text +/path/to/project/rmcsetup/representative_structures/ +``` + +For equivalent inputs, the full UI and CLI path should leave compatible +project metadata and representative structure sets. diff --git a/docs/user-guide/saxs-prefit.md b/docs/user-guide/saxs-prefit.md index 4c672d5..ed51759 100644 --- a/docs/user-guide/saxs-prefit.md +++ b/docs/user-guide/saxs-prefit.md @@ -14,6 +14,23 @@ cluster-derived component set and an active template. The supporting applications prepare those upstream inputs before the main SAXS UI turns them into a model preview. 
+In plain language, Prefit is the "does this model make sense before I run a +heavier fit?" tab. + +!!! info "Image placeholder" +Add a screenshot of the **SAXS Prefit** tab showing the main plot, the +parameter table, and the controls a first-time user should inspect before +moving to DREAM. + +## When to use Prefit first + +Use Prefit before DREAM when you need to: + +- confirm that the built SAXS components load cleanly +- sanity-check the overall shape and scale of the model against the data +- decide whether a template or q-range choice is obviously wrong +- prepare geometry-aware metadata before a longer Bayesian run + ## What you can do in Prefit From the current UI implementation, Prefit supports: @@ -122,8 +139,11 @@ solution SAXS than the older additive-volume estimate $V_{\mathrm{solute}} / (V_{\mathrm{solute}} + V_{\mathrm{solvent}})$. SAXSShell still prints this physical estimate in the output console for -reference, but it is no longer written directly into the model-facing -`phi_solute` / `phi_solvent` defaults. +reference. It is written to a Prefit parameter only when the active template +explicitly declares a physical volume-fraction target, such as `vol_frac` in +`pyDREAM MonoSQ Normalized (Scaled Solvent Weight)`. Older templates that expose +`phi_solute` / `phi_solvent` keep using the SAXS-effective ratio described +below. ### SAXS-effective interaction contrast ratio @@ -221,21 +241,40 @@ factor that answers the practical SAXS question, "how much more scattering intensity does the neat solvent have than the solvent fraction inside the real sample?" -If the active template includes both a model-facing fraction parameter -(`phi_solute` / `phi_solvent`) and a solvent-weight parameter such as -`solvent_scale`, Prefit writes the attenuation factor above into -`solvent_scale` and uses `R_saxs(E)` for the fraction parameter. 
The solvent -term therefore becomes +Prefit writes these estimates into different model parameters depending on the +active template's declared convention. + +For split-fraction templates such as the Poly LMA hard-sphere templates, +Prefit writes the attenuation factor above into `solvent_scale` and writes +`R_saxs(E)` or its solvent complement into the model-facing fraction parameter +(`phi_solute` / `phi_solvent`). The solvent term therefore uses the attenuation +scale together with the SAXS-effective solvent fraction: \(w*{\mathrm{solv}} (1 - R*{\mathrm{saxs}}) I\_{\mathrm{solv}}(q)\). -If the template only exposes a single solvent-weight parameter such as -`solv_w`, Prefit writes the combined solvent-background multiplier +For the original `pyDREAM MonoSQ Normalized` template, Prefit preserves the +historical single-parameter convention. There is no automatic `vol_frac` target, +and `solv_w` receives the combined solvent-background multiplier $$ w_{\mathrm{model}} = \left(1 - R_{\mathrm{saxs}}(E)\right) w_{\mathrm{solv}} $$ -into that parameter directly. +directly. In that original MonoSQ model, the solvent branch is added after the +global `scale`, so `solv_w` may also absorb any intensity-unit mismatch between +the experimental solvent blank and the scaled MD-derived model. + +For `pyDREAM MonoSQ Normalized (Scaled Solvent Weight)`, Prefit uses the same +combined solvent-background multiplier for `solv_w`, but the template places +that weighted solvent trace inside the global scale. It also declares `vol_frac` +as a calculator target, so Prefit writes the physical solute-associated volume +fraction into `vol_frac` while keeping `solv_w` as the solvent-background +multiplier. + +That scaled-solvent MonoSQ template also asks Prefit to autoscale on load. 
If +experimental data are present and the project does not already have a saved Best +Prefit or current Prefit state for that template, Prefit applies the autoscale +estimate immediately and narrows the `scale` and `offset` limits around the +computed values. ### Fluorescence background proxy @@ -375,22 +414,6 @@ exists and is mapped correctly. - If a geometry-aware template refuses to update, check the mapping column and the active radii values before looking elsewhere. -## TODO - -TODO: re-check the consistency between the attenuation estimator output -(`solvent_scale` / `solv_w`) and the solvent contribution used by the model. -At least some current workflows appear to produce model-facing solvent weights -around `0.15` even when the imported solvent trace still looks about `100x` -too large by eye at the same q-position maxima. Revisit whether this comes -from a mismatch between: - -- the attenuation-only estimate and the actual model solvent term -- split-fraction versus single-solvent-weight template conventions -- solvent blank versus sample normalization, transmission, thickness, or exposure -- partially pre-subtracted or otherwise inconsistently reduced solvent traces -- the possible need for a separate empirical solvent-normalization factor in - addition to the attenuation scaling - ## Related pages - [Project Configuration](project-configuration.md) diff --git a/docs/user-guide/template-system.md b/docs/user-guide/template-system.md index 1e81b1d..111c372 100644 --- a/docs/user-guide/template-system.md +++ b/docs/user-guide/template-system.md @@ -361,18 +361,7 @@ anisotropic hard-ellipsoid closure. - [Henriques J, Arleth L, Lindorff-Larsen K, Skepö M. _On the Calculation of SAXS Profiles of Folded and Intrinsically Disordered Proteins from Computer Simulations_. Journal of Molecular Biology (2018).](https://pubmed.ncbi.nlm.nih.gov/29548755/) - [Edwards-Gayle CJC, Khunti N, Hamley IW, Inoue K, Cowieson N, Rambo RP. 
_Design of a multipurpose sample cell holder for the Diamond Light Source high-throughput SAXS beamline B21_. Journal of Synchrotron Radiation (2021).](https://pmc.ncbi.nlm.nih.gov/articles/PMC7842227/) -## CLI support - -The SAXS CLI includes template management commands through the installed -umbrella command: - -```bash -saxshell saxs templates -saxshell saxs templates validate path/to/template.py -saxshell saxs templates install path/to/template.py -``` - -From a source checkout, use the module directly: +## Template management from source ```bash PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 python -m saxshell.saxs templates diff --git a/docs/user-guide/xray-toolkit.md b/docs/user-guide/xray-toolkit.md index 1595dea..3e14884 100644 --- a/docs/user-guide/xray-toolkit.md +++ b/docs/user-guide/xray-toolkit.md @@ -19,8 +19,9 @@ project-backed workflow. ## Current scope -Unlike project-backed tools such as PDF Calculation, Electron Density Mapping, -or Debye-Waller Analysis, the X-ray Toolkit windows are intended as focused +Unlike project-backed tools such as PDF Calculation, +1D Born Approximation (Average), 3D FFT Born Approximation, or +Debye-Waller Analysis, the X-ray Toolkit windows are intended as focused calculators rather than long-lived saved analysis workspaces. 
## Related pages diff --git a/docs/user-guide/xyz2pdb-conversion.md b/docs/user-guide/xyz2pdb-conversion.md index 2a8e8d0..8048633 100644 --- a/docs/user-guide/xyz2pdb-conversion.md +++ b/docs/user-guide/xyz2pdb-conversion.md @@ -52,23 +52,21 @@ available, and successful exports write the output folder back to the project's ### From the terminal -Installed package: - ```bash -xyz2pdb -xyz2pdb ui path/to/frame_folder +PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 \ + python -m saxshell.xyz2pdb ``` -From a source checkout: +You can prefill a frame folder from the source checkout: ```bash PYTHONPATH=src conda run --no-capture-output -n saxshell-py312 \ - python -m saxshell.xyz2pdb.cli + python -m saxshell.xyz2pdb ui path/to/frame_folder ``` ## Typical workflow -1. Open `xyz2pdb`. +1. Open the XYZ-to-PDB converter from the main UI or from the source checkout. 2. Choose an `XYZ` file or a folder of `XYZ` files. 3. Click `Analyze Input`. 4. Check the sample analysis and confirm the detected elements. diff --git a/mkdocs.yml b/mkdocs.yml index a80231b..d8dc889 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,7 +1,8 @@ site_name: SAXSShell site_description: >- - Workflow-oriented documentation for SAXSShell, from trajectory processing and - cluster extraction to SAXS project setup, prefit modeling, and pyDREAM runs. + Workflow-oriented documentation for SAXSShell, focused on comparing + molecular-dynamics-derived solvation structures against small-angle X-ray + scattering data. 
site_url: https://kewh5868.github.io/SAXSShell/ repo_url: https://github.com/kewh5868/SAXSShell repo_name: kewh5868/SAXSShell @@ -37,6 +38,9 @@ theme: plugins: - search +hooks: + - docs/hooks.py + markdown_extensions: - admonition - attr_list @@ -82,19 +86,23 @@ nav: - XYZ to PDB Conversion: user-guide/xyz2pdb-conversion.md - Structure Analysis: - Bond Analysis: user-guide/bond-analysis.md - - Debye-Waller Analysis: user-guide/debye-waller-analysis.md - Cluster Dynamics: - Cluster Dynamics: user-guide/cluster-dynamics.md - Cluster Dynamics ML: user-guide/cluster-dynamics-ml.md - PDF: - PDF Calculation: user-guide/pdf-calculation.md + - fullrmc Packmol Docker Link: user-guide/fullrmc-packmol-docker.md - Visualization: - Blender Structure Renderer: user-guide/blender-structure-renderer.md - SAXS Calculation Preview: - SAXS Contrast Mode: user-guide/saxs-contrast-mode.md - - Electron Density Mapping: user-guide/electron-density-mapping.md + - 1D Born Approximation (Average): user-guide/electron-density-mapping.md + - 3D FFT Born Approximation: user-guide/fft-born-approximation.md - X-ray Toolkit: - X-ray Toolkit: user-guide/xray-toolkit.md + - (beta): + - Representative Structure CLI: user-guide/representative-structure-cli.md + - Debye-Waller Analysis: user-guide/debye-waller-analysis.md - Tutorials: - Example Workflow: tutorials/example-workflow.md - MD to SAXS Pipeline: tutorials/md-to-saxs-pipeline.md From 104168a61ae45996f01b74640feccf03db5debd0 Mon Sep 17 00:00:00 2001 From: kewh5868 Date: Thu, 7 May 2026 16:56:29 -0600 Subject: [PATCH 7/7] ci: pin docformatter hook to Python 3.12 Run the docformatter hook with Python 3.12 so its untokenize dependency does not build under Python 3.14, where the package setup fails while reading AST constants. 
--- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 31dbd01..1be1809 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,5 +65,6 @@ repos: rev: 5757c5190d95e5449f102ace83df92e7d3b06c6c hooks: - id: docformatter + language_version: python3.12 additional_dependencies: [tomli] args: [--in-place, --config, ./pyproject.toml]